Keras model subclassing examples

Posted 2020-08-23 04:49

Question:

Starting with Keras 2.2.0, a third API for defining models has been released: Model subclassing.

According to the FAQ:

However, in subclassed models, the model's topology is defined as Python code (rather than as a static graph of layers). That means the model's topology cannot be inspected or serialized. As a result, the following methods and attributes are not available for subclassed models:

- model.inputs and model.outputs
- model.to_yaml() and model.to_json()
- model.get_config() and model.save()
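
For illustration, the limitation can be reproduced with a minimal subclassed model (the TinyModel below is just a toy, not my actual code); on the tf.keras versions this question targets, the serialization methods typically raise NotImplementedError:

import tensorflow as tf

# Hypothetical minimal subclassed model, only to illustrate the FAQ limitation.
class TinyModel(tf.keras.Model):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.dense(x)

model = TinyModel()
model(tf.zeros((1, 4)))  # create the weights by calling the model once
try:
    model.to_json()      # topology is Python code, not a static layer graph
except NotImplementedError as err:
    print('Cannot serialize subclassed model:', err)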

The only option for saving the trained model for inference is the model.save_weights method. However, I have not had any luck loading the model back for inference. The error messages I encountered include:

- This model has never been called, thus its weights have not yet been created, so no summary can be displayed. Build the model first (e.g. by calling it on some test data).
- You are trying to load a weight file containing 4 layers into a model with 0 layers.
- NotImplementedError

Can anyone give a full toy example of creating a subclassed Keras model, training it, saving its weights with save_weights, and then loading them back for inference?

Answer 1:

You need to call the tf.keras.Model.build method before you try to save a subclassed model's weights. An alternative is to call tf.keras.Model.fit, or simply to call the model itself (which invokes tf.keras.Model.call) on some inputs, before saving the weights. The same applies when loading weights into a newly created instance of your subclassed model: you need to call one of the above-mentioned methods before you try to load the weights. Here is an example showing both saving and loading weights for a subclassed model:

import tensorflow as tf

print('TensorFlow', tf.__version__)

class ResidualBlock(tf.keras.Model):
    def __init__(self, block_type=None, n_filters=None):
        super(ResidualBlock, self).__init__()
        self.n_filters = n_filters
        if block_type == 'identity':
            self.strides = 1
        elif block_type == 'conv':
            self.strides = 2
            self.conv_shortcut = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=1, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
            self.bn_shortcut = tf.keras.layers.BatchNormalization(momentum=0.9)

        self.conv_1 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()

        self.conv_2 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same', 
                               kernel_initializer='he_normal')
        self.bn_2 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_2 = tf.keras.layers.ReLU()

    def call(self, x, training=False):
        shortcut = x
        if self.strides == 2:
            shortcut = self.conv_shortcut(x)
            shortcut = self.bn_shortcut(shortcut, training=training)
        # propagate the training flag so BatchNormalization updates its statistics only during training
        y = self.conv_1(x)
        y = self.bn_1(y, training=training)
        y = self.relu_1(y)
        y = self.conv_2(y)
        y = self.bn_2(y, training=training)
        y = tf.add(shortcut, y)
        y = self.relu_2(y)
        return y

class ResNet34(tf.keras.Model):
    def __init__(self, include_top=True, n_classes=1000):
        super(ResNet34, self).__init__()

        self.n_classes = n_classes
        self.include_top = include_top
        self.conv_1 = tf.keras.layers.Conv2D(filters=64, 
                                               kernel_size=7, 
                                               padding='same', 
                                               strides=2, 
                                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()
        self.maxpool = tf.keras.layers.MaxPool2D(3, 2, padding='same')
        self.residual_blocks = tf.keras.Sequential()
        for n_filters, reps, downscale in zip([64, 128, 256, 512], 
                                              [3, 4, 6, 3], 
                                              [False, True, True, True]):
            for i in range(reps):
                if i == 0 and downscale:
                    self.residual_blocks.add(ResidualBlock(block_type='conv', 
                                                              n_filters=n_filters))
                else:
                    self.residual_blocks.add(ResidualBlock(block_type='identity', 
                                                              n_filters=n_filters))
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=self.n_classes)

    def call(self, x, training=False):
        y = self.conv_1(x)
        y = self.bn_1(y, training=training)
        y = self.relu_1(y)
        y = self.maxpool(y)
        y = self.residual_blocks(y, training=training)
        if self.include_top:
            y = self.GAP(y)
            y = self.fc(y)
        return y

## saving weights
model = ResNet34()
model.build((1, 224, 224, 3))  # build expects the input shape including the batch dimension
model.summary()
model.save_weights('model_weights.h5')

## loading saved weights
model_new = ResNet34()
model_new.build((1, 224, 224, 3))  # build the new instance before loading the weights
model_new.load_weights('model_weights.h5')
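
As a variation on the same idea (my own sketch, not part of the code above), you can create the weights by calling the model on dummy inputs instead of build, and then verify that the reloaded instance reproduces the same outputs:

import numpy as np
import tensorflow as tf

# Alternative to build(): call the model once on dummy data so the weights get created.
model = ResNet34()
dummy = tf.random.normal((1, 224, 224, 3))
_ = model(dummy, training=False)
model.save_weights('model_weights.h5')

# Load into a fresh instance, again calling it once before load_weights.
model_new = ResNet34()
_ = model_new(dummy, training=False)
model_new.load_weights('model_weights.h5')

# Sanity check: both instances should now produce identical predictions.
np.testing.assert_allclose(model(dummy, training=False).numpy(),
                           model_new(dummy, training=False).numpy(),
                           rtol=1e-5, atol=1e-5)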