How does mask_zero in Keras Embedding layer work?

Posted 2019-02-24 07:23

I thought mask_zero=True would make the layer output zeros wherever the input value is 0, so that the following layers could skip those positions.

How does mask_zero actually work?

Example:

import numpy as np
from keras.layers import Input, Embedding
from keras.models import Model

data_in = np.array([
  [1, 2, 0, 0]
])
print(data_in.shape)
# (1, 4)

# model
x = Input(shape=(4,))
e = Embedding(5, 5, mask_zero=True)(x)

m = Model(inputs=x, outputs=e)
p = m.predict(data_in)
print(p.shape)
print(p)

The actual output is (the values come from the randomly initialized weights):

(1, 4, 5)
[[[ 0.02499047  0.04617121  0.01586803  0.0338897   0.009652  ]
  [ 0.04782704 -0.04035913 -0.0341589   0.03020919 -0.01157228]
  [ 0.00451764 -0.01433611  0.02606953  0.00328832  0.02650392]
  [ 0.00451764 -0.01433611  0.02606953  0.00328832  0.02650392]]]

However, I expected the output to be:

[[[ 0.02499047  0.04617121  0.01586803  0.0338897   0.009652  ]
  [ 0.04782704 -0.04035913 -0.0341589   0.03020919 -0.01157228]
  [ 0 0 0 0 0]
  [ 0 0 0 0 0]]]

1 Answer
看我几分像从前
Reply #2 · 2019-02-24 07:31

Actually, setting mask_zero=True for the Embedding layer does not make it return a zero vector. The layer's own output does not change: for an input of 0 it still returns the embedding vector stored at index zero. You can confirm this by comparing the output with the Embedding layer's weights (in the example you mentioned, that would be m.layers[1].get_weights()[0][0], since m.layers[0] is the InputLayer, which has no weights). What mask_zero changes is the behavior of the following layers that support masking, such as RNN layers.
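To make the lookup behavior concrete, here is a toy sketch in plain Python rather than Keras. The weights table and the embed helper are hypothetical stand-ins I made up for illustration, not the real layer; the only point is that index 0 is looked up like any other index.

```python
import random

random.seed(0)
vocab_size, dim = 5, 5

# Hypothetical stand-in for the Embedding weight matrix (vocab_size x dim).
weights = [[random.uniform(-0.05, 0.05) for _ in range(dim)]
           for _ in range(vocab_size)]

def embed(tokens):
    """Return one weight row per token; index 0 gets no special treatment."""
    return [weights[t] for t in tokens]

out = embed([1, 2, 0, 0])

# Both padded positions return row 0 of the table, not a zero vector.
print(out[2] == weights[0], out[3] == weights[0])  # True True
```

This matches what you observed: the last two rows of your output are identical to each other and non-zero.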

If you inspect the source code of the Embedding layer, you will see a method called compute_mask:

def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
        return None
    output_mask = K.not_equal(inputs, 0)
    return output_mask
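To see what that mask looks like for the input in the question, here is a minimal plain-Python sketch. It only imitates K.not_equal(inputs, 0) with a list comprehension, and the standalone compute_mask function below is my own stand-in, not the Keras method itself.

```python
def compute_mask(batch, mask_zero=True):
    """Return one boolean per token (True = keep), or None when masking is off."""
    if not mask_zero:
        return None
    # Stand-in for K.not_equal(inputs, 0): True wherever the token is non-zero.
    return [[token != 0 for token in row] for row in batch]

data_in = [[1, 2, 0, 0]]
mask = compute_mask(data_in)
print(mask)  # [[True, True, False, False]]
```

So the mask has the same (batch, timesteps) shape as the integer input, and it travels alongside the embedded output instead of modifying it.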

This output mask is passed, as the mask argument, to the following layers that support masking. This is implemented in the __call__ method of the base class, Layer:

# Handle mask propagation.
previous_mask = _collect_previous_mask(inputs)
user_kwargs = copy.copy(kwargs)
if not is_all_none(previous_mask):
    # The previous layer generated a mask.
    if has_arg(self.call, 'mask'):
        if 'mask' not in kwargs:
            # If mask is explicitly passed to __call__,
            # we should override the default mask.
            kwargs['mask'] = previous_mask

As a result, the following layers ignore the masked input steps, i.e. they do not include them in their computations. Here is a minimal example:

data_in = np.array([
  [1, 0, 2, 0]
])

x = Input(shape=(4,))
e = Embedding(5, 5, mask_zero=True)(x)
rnn = LSTM(3, return_sequences=True)(e)

m = Model(inputs=x, outputs=rnn)
m.predict(data_in)

array([[[-0.00084503, -0.00413611,  0.00049972],
        [-0.00084503, -0.00413611,  0.00049972],
        [-0.00144554, -0.00115775, -0.00293898],
        [-0.00144554, -0.00115775, -0.00293898]]], dtype=float32)

As you can see, the outputs of the LSTM layer at the second and fourth timesteps are the same as the outputs at the first and third timesteps, respectively. This means those timesteps were masked: the LSTM skipped them and repeated its previous output.
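The repeated rows can be reproduced with a toy recurrent loop. This is only a sketch of the idea, assuming a scalar tanh RNN with made-up weights w_x and w_h; it is not the actual LSTM or the Keras backend code.

```python
import math

def masked_rnn(inputs, mask, w_x=0.5, w_h=0.8):
    """Toy scalar RNN that leaves its state untouched at masked steps."""
    h = 0.0
    outputs = []
    for x, keep in zip(inputs, mask):
        if keep:
            h = math.tanh(w_x * x + w_h * h)
        # When keep is False, h is not updated, so the previous output repeats.
        outputs.append(h)
    return outputs

inputs = [1.0, 9.9, 2.0, 9.9]      # the values at masked steps are never read
mask = [True, False, True, False]  # same pattern as the tokens [1, 0, 2, 0]
outs = masked_rnn(inputs, mask)

print(outs[0] == outs[1], outs[2] == outs[3])  # True True
```

Note that the embedding vector for index 0 still exists at the masked steps (the 9.9 values here play that role); it is simply never used by the recurrent layer.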

Update: The mask is also taken into account when computing the loss, since the loss functions are internally wrapped to support masking using weighted_masked_objective:

def weighted_masked_objective(fn):
    """Adds support for masking and sample-weighting to an objective function.
    It transforms an objective function `fn(y_true, y_pred)`
    into a sample-weighted, cost-masked objective function
    `fn(y_true, y_pred, weights, mask)`.
    # Arguments
        fn: The objective function to wrap,
            with signature `fn(y_true, y_pred)`.
    # Returns
        A function with signature `fn(y_true, y_pred, weights, mask)`.
    """

This wrapping is applied when compiling the model:

weighted_losses = [weighted_masked_objective(fn) for fn in loss_functions]

You can verify this using the following example:

data_in = np.array([[1, 2, 0, 0]])
data_out = np.arange(12).reshape(1,4,3)

x = Input(shape=(4,))
e = Embedding(5, 5, mask_zero=True)(x)
d = Dense(3)(e)

m = Model(inputs=x, outputs=d)
m.compile(loss='mse', optimizer='adam')
preds = m.predict(data_in)
loss = m.evaluate(data_in, data_out, verbose=0)
print(preds)
print('Computed Loss:', loss)

[[[ 0.009682    0.02505393 -0.00632722]
  [ 0.01756451  0.05928303  0.0153951 ]
  [-0.00146054 -0.02064196 -0.04356086]
  [-0.00146054 -0.02064196 -0.04356086]]]
Computed Loss: 9.041069030761719

# verify that only the first two outputs 
# have been considered in the computation of loss
print(np.square(preds[0,0:2] - data_out[0,0:2]).mean())

9.041070036475277
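For intuition, here is a plain-Python sketch of a masked loss for a single sample. It assumes the wrapper zeroes the per-timestep loss at masked steps and then rescales by the fraction of kept steps, which is how I understand the Keras 2.x implementation; the masked_mse and step_mse helpers are my own names, and the all-zero predictions are made up.

```python
def step_mse(y_true_step, y_pred_step):
    """MSE of one timestep, averaged over the feature axis."""
    return sum((t - p) ** 2
               for t, p in zip(y_true_step, y_pred_step)) / len(y_true_step)

def masked_mse(y_true, y_pred, mask):
    """Toy masked MSE for one sample (lists of timesteps, one bool per step)."""
    per_step = [step_mse(t, p) for t, p in zip(y_true, y_pred)]
    weights = [1.0 if keep else 0.0 for keep in mask]
    # Zero out the masked steps.
    masked = [s * w for s, w in zip(per_step, weights)]
    # Rescale so the result is an average over the kept steps only.
    keep_fraction = sum(weights) / len(weights)
    return sum(masked) / len(masked) / keep_fraction

y_true = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]  # np.arange(12) as (4, 3)
y_pred = [[0.0, 0.0, 0.0]] * 4                            # hypothetical predictions
loss = masked_mse(y_true, y_pred, [True, True, False, False])

# Same value as averaging the per-step MSE over the first two timesteps only.
first_two = (step_mse(y_true[0], y_pred[0]) + step_mse(y_true[1], y_pred[1])) / 2
print(abs(loss - first_two) < 1e-9)  # True
```

The targets at the last two timesteps never influence the result, which is the same effect the evaluate call above shows.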