Content Overview
- Shared layers
- Extract and reuse nodes in the graph of layers
- Extend the API using custom layers
- When to use the Functional API
- Functional API strengths
- Functional API weaknesses
- Mix-and-match API styles
Shared layers
Another good use for the functional API is models that use shared layers. Shared layers are layer instances that are reused multiple times in the same model: they learn features that correspond to multiple paths in the graph of layers.
Shared layers are often used to encode inputs from similar spaces (say, two different pieces of text that feature similar vocabulary). They enable sharing of information across these different inputs, and they make it possible to train such a model on less data. If a given word is seen in one of the inputs, that will benefit the processing of all inputs that pass through the shared layer.
To share a layer in the functional API, call the same layer instance multiple times. For instance, here’s an Embedding layer shared across two different text inputs:
# Imports used by the examples in this guide
import numpy as np
import tensorflow as tf
import keras
from keras import layers

# Embedding for 1000 unique words mapped to 128-dimensional vectors
shared_embedding = layers.Embedding(1000, 128)
# Variable-length sequence of integers
text_input_a = keras.Input(shape=(None,), dtype="int32")
# Variable-length sequence of integers
text_input_b = keras.Input(shape=(None,), dtype="int32")
# Reuse the same layer to encode both inputs
encoded_input_a = shared_embedding(text_input_a)
encoded_input_b = shared_embedding(text_input_b)
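To turn these two encoded inputs into a trainable model, you would add layers downstream. The shared LSTM, concatenation, and sigmoid head below are an assumption made for illustration; they are not part of the original snippet:
# Hypothetical continuation: share an LSTM encoder as well, then merge both branches.
shared_lstm = layers.LSTM(64)
vector_a = shared_lstm(encoded_input_a)
vector_b = shared_lstm(encoded_input_b)
merged = layers.concatenate([vector_a, vector_b])
prediction = layers.Dense(1, activation="sigmoid")(merged)
two_input_model = keras.Model([text_input_a, text_input_b], prediction)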
Extract and reuse nodes in the graph of layers
Because the graph of layers you are manipulating is a static data structure, it can be accessed and inspected. And this is how you are able to plot functional models as images.
This also means that you can access the activations of intermediate layers (“nodes” in the graph) and reuse them elsewhere — which is very useful for something like feature extraction.
Let’s look at an example. This is a VGG19 model with weights pretrained on ImageNet:
vgg19 = keras.applications.VGG19()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5
574710816/574710816 [==============================] - 4s 0us/step
And these are the intermediate activations of the model, obtained by querying the graph data structure:
features_list = [layer.output for layer in vgg19.layers]
Use these features to create a new feature-extraction model that returns the values of the intermediate layer activations:
feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)
img = np.random.random((1, 224, 224, 3)).astype("float32")
extracted_features = feat_extraction_model(img)
This comes in handy for tasks like neural style transfer, among other things.
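For instance, instead of extracting every activation, you can query only the layers you care about by name. A minimal sketch, assuming the standard VGG19 layer names (block1_conv1, block5_conv2, and so on) and the vgg19 and img objects defined above:
# Pick a few named layers, as a style-transfer pipeline typically would.
content_layers = ["block5_conv2"]
style_layers = ["block1_conv1", "block2_conv1", "block3_conv1"]

selected_outputs = [
    vgg19.get_layer(name).output for name in content_layers + style_layers
]
style_content_model = keras.Model(inputs=vgg19.input, outputs=selected_outputs)
selected_features = style_content_model(img)  # one tensor per selected layer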
Extend the API using custom layers
keras includes a wide range of built-in layers, for example:
- Convolutional layers: Conv1D, Conv2D, Conv3D, Conv2DTranspose
- Pooling layers: MaxPooling1D, MaxPooling2D, MaxPooling3D, AveragePooling1D
- RNN layers: GRU, LSTM, ConvLSTM2D
- BatchNormalization, Dropout, Embedding, etc.
But if you don’t find what you need, it’s easy to extend the API by creating your own layers. All layers subclass the Layer class and implement:
- a call method, which specifies the computation done by the layer
- a build method, which creates the weights of the layer (this is just a style convention, since you can also create weights in __init__)
To learn more about creating layers from scratch, read the custom layers and models guide.
The following is a basic implementation of keras.layers.Dense:
class CustomDense(layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Create the kernel and bias once the shape of the input is known.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        # Affine transformation: inputs @ w + b
        return tf.matmul(inputs, self.w) + self.b
inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)
model = keras.Model(inputs, outputs)
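The resulting model can be compiled and trained like any other. A quick sketch with arbitrary random data, just to show the custom layer in use (the optimizer and loss choices here are assumptions, not part of the original example):
model.compile(optimizer="adam", loss="mse")
x_data = np.random.random((64, 4)).astype("float32")
y_data = np.random.random((64, 10)).astype("float32")
model.fit(x_data, y_data, epochs=1, batch_size=16)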
For serialization support in your custom layer, define a get_config method that returns the constructor arguments of the layer instance:
@keras.saving.register_keras_serializable()
class CustomDense(layers.Layer):
def __init__(self, units=32):
super().__init__()
self.units = units
def build(self, input_shape):
self.w = self.add_weight(
shape=(input_shape[-1], self.units),
initializer="random_normal",
trainable=True,
)
self.b = self.add_weight(
shape=(self.units,), initializer="random_normal", trainable=True
)
def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b
def get_config(self):
return {"units": self.units}
inputs = keras.Input((4,))
outputs = CustomDense(10)(inputs)
model = keras.Model(inputs, outputs)
config = model.get_config()
new_model = keras.Model.from_config(config)
Optionally, implement the class method from_config(cls, config), which is used when recreating a layer instance given its config dictionary. The default implementation of from_config is:
def from_config(cls, config):
return cls(**config)
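In other words, with the get_config defined above, a layer can be rebuilt from its config alone. A quick sketch:
layer = CustomDense(units=16)
config = layer.get_config()                   # {"units": 16}
new_layer = CustomDense.from_config(config)   # equivalent to CustomDense(**config)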
When to use the functional API
Should you use the Keras functional API to create a new model, or just subclass the Model class directly? In general, the functional API is higher-level, easier, and safer, and has a number of features that subclassed models do not support.
However, model subclassing provides greater flexibility when building models that are not easily expressible as directed acyclic graphs of layers. For example, you could not implement a Tree-RNN with the functional API and would have to subclass Model directly.
For an in-depth look at the differences between the functional API and model subclassing, read What are Symbolic and Imperative APIs in TensorFlow 2.0?.
Functional API strengths:
The following properties are also true for Sequential models (which are also data structures), but are not true for subclassed models (which are Python bytecode, not data structures).
Less verbose
There is no super().__init__(...), no def call(self, ...):, etc.
Compare:
inputs = keras.Input(shape=(32,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10)(x)
With the subclassed version:
class MLP(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dense_1 = layers.Dense(64, activation='relu')
self.dense_2 = layers.Dense(10)
def call(self, inputs):
x = self.dense_1(inputs)
return self.dense_2(x)
# Instantiate the model.
mlp = MLP()
# Necessary to create the model's state.
# The model doesn't have a state until it's called at least once.
_ = mlp(tf.zeros((1, 32)))
Model validation while defining its connectivity graph
In the functional API, the input specification (shape and dtype) is created in advance (using Input). Every time you call a layer, the layer checks that the specification passed to it matches its assumptions, and it will raise a helpful error message if not.
This guarantees that any model you can build with the functional API will run. All debugging — other than convergence-related debugging — happens statically during the model construction and not at execution time. This is similar to type checking in a compiler.
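For example, calling a layer on an input whose rank it cannot handle fails immediately, at model-definition time rather than during training. A minimal sketch (the exact error message depends on your Keras version):
flat_inputs = keras.Input(shape=(32,))
try:
    # Conv2D expects 4D inputs (batch, height, width, channels),
    # so this raises a ValueError while the graph of layers is being built.
    layers.Conv2D(filters=16, kernel_size=3)(flat_inputs)
except ValueError as e:
    print("Caught at construction time:", e)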
A functional model is plottable and inspectable
You can plot the model as a graph, and you can easily access intermediate nodes in this graph. For example, to extract and reuse the activations of intermediate layers (as seen in a previous example):
features_list = [layer.output for layer in vgg19.layers]
feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)
A functional model can be serialized or cloned
Because a functional model is a data structure rather than a piece of code, it is safely serializable and can be saved as a single file that allows you to recreate the exact same model without having access to any of the original code. See the serialization & saving guide.
To serialize a subclassed model, it is necessary for the implementer to specify a get_config() and from_config() method at the model level.
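As a quick sketch of what this buys you for functional models (the model and file name below are arbitrary, chosen only for illustration):
# Build a small functional model, save it to a single file, and recreate it
# without any of the original construction code.
save_inputs = keras.Input(shape=(4,))
save_outputs = layers.Dense(10)(save_inputs)
small_model = keras.Model(save_inputs, save_outputs)

small_model.save("small_model.keras")
restored_model = keras.models.load_model("small_model.keras")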
Functional API weakness:
It does not support dynamic architectures
The functional API treats models as DAGs of layers. This is true for most deep learning architectures, but not all — for example, recursive networks or Tree RNNs do not follow this assumption and cannot be implemented in the functional API.
Mix-and-match API styles
Choosing between the functional API and Model subclassing isn’t a binary decision that restricts you to one category of models. All models in the keras API can interact with each other, whether they’re Sequential models, functional models, or subclassed models that are written from scratch.
You can always use a functional model or Sequential model as part of a subclassed model or layer:
units = 32
timesteps = 10
input_dim = 5
# Define a Functional model
inputs = keras.Input((None, units))
x = layers.GlobalAveragePooling1D()(inputs)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
@keras.saving.register_keras_serializable()
class CustomRNN(layers.Layer):
def __init__(self):
super().__init__()
self.units = units
self.projection_1 = layers.Dense(units=units, activation="tanh")
self.projection_2 = layers.Dense(units=units, activation="tanh")
# Our previously-defined Functional model
self.classifier = model
def call(self, inputs):
outputs = []
state = tf.zeros(shape=(inputs.shape[0], self.units))
for t in range(inputs.shape[1]):
x = inputs[:, t, :]
h = self.projection_1(x)
y = h + self.projection_2(state)
state = y
outputs.append(y)
features = tf.stack(outputs, axis=1)
print(features.shape)
return self.classifier(features)
rnn_model = CustomRNN()
_ = rnn_model(tf.zeros((1, timesteps, input_dim)))
(1, 10, 32)
You can use any subclassed layer or model in the functional API as long as it implements a call method that follows one of the following patterns:
- call(self, inputs, **kwargs), where inputs is a tensor or a nested structure of tensors (e.g. a list of tensors), and where **kwargs are non-tensor arguments (non-inputs).
- call(self, inputs, training=None, **kwargs), where training is a boolean indicating whether the layer should behave in training mode or inference mode (see the sketch after this list).
- call(self, inputs, mask=None, **kwargs), where mask is a boolean mask tensor (useful for RNNs, for instance).
- call(self, inputs, training=None, mask=None, **kwargs), where, of course, you can have both masking and training-specific behavior at the same time.
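For example, here is a small hypothetical layer (not from the original guide) whose call follows the second pattern, injecting noise only in training mode:
class NoisyDense(layers.Layer):
    """Hypothetical example: a Dense projection with Gaussian noise added during training."""

    def __init__(self, units=32, stddev=0.1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.stddev = stddev
        self.dense = layers.Dense(units)

    def call(self, inputs, training=None):
        x = self.dense(inputs)
        if training:
            # Noise is applied only in training mode, not at inference time.
            x = x + tf.random.normal(tf.shape(x), stddev=self.stddev)
        return x

    def get_config(self):
        return {"units": self.units, "stddev": self.stddev}

noisy_inputs = keras.Input(shape=(8,))
noisy_outputs = NoisyDense(4)(noisy_inputs)
noisy_model = keras.Model(noisy_inputs, noisy_outputs)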
Additionally, if you implement the get_config method on your custom Layer or model, the functional models you create will still be serializable and cloneable.
Here’s a quick example of a custom RNN, written from scratch, being used in a functional model:
units = 32
timesteps = 10
input_dim = 5
batch_size = 16
@keras.saving.register_keras_serializable()
class CustomRNN(layers.Layer):
def __init__(self):
super().__init__()
self.units = units
self.projection_1 = layers.Dense(units=units, activation="tanh")
self.projection_2 = layers.Dense(units=units, activation="tanh")
self.classifier = layers.Dense(1)
def call(self, inputs):
outputs = []
state = tf.zeros(shape=(inputs.shape[0], self.units))
for t in range(inputs.shape[1]):
x = inputs[:, t, :]
h = self.projection_1(x)
y = h + self.projection_2(state)
state = y
outputs.append(y)
features = tf.stack(outputs, axis=1)
return self.classifier(features)
# Note that you specify a static batch size for the inputs with the `batch_shape`
# arg, because the inner computation of `CustomRNN` requires a static batch size
# (when you create the `state` zeros tensor).
inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))
x = layers.Conv1D(32, 3)(inputs)
outputs = CustomRNN()(x)
model = keras.Model(inputs, outputs)
rnn_model = CustomRNN()
_ = rnn_model(tf.zeros((1, 10, 5)))
Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.