How to give a constant input to Keras

Posted 2019-05-02 05:47

Question:

My network has two time-series inputs. One of the inputs is a fixed vector that repeats at every time step. Is there an elegant way to load this fixed vector into the model just once and use it for computation?
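
To make the setup concrete, here is a minimal NumPy sketch of the situation described above. The sizes and array names are made up for illustration and are not taken from the question. It compares materializing one copy of the fixed vector per time step with storing the vector once and relying on broadcasting.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch, timesteps, features = 2, 4, 3

# Time-varying input: one feature vector per sample and time step.
series = np.arange(batch * timesteps * features, dtype="float32")
series = series.reshape(batch, timesteps, features)

# Fixed vector that is identical at every time step.
fixed = np.array([1.0, 2.0, 3.0], dtype="float32")

# Naive approach: build an explicit copy for every sample and time step.
tiled = np.tile(fixed, (batch, timesteps, 1))

# Broadcasting approach: keep the vector once; shape (features,)
# lines up with the last axis of (batch, timesteps, features).
combined_tiled = series + tiled
combined_broadcast = series + fixed
```

Both sums are identical, so the repeated copies are not needed for elementwise operations like this one.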

Answer 1:

You can create a static input with the tensor argument of Input, as described by jdehesa. However, the tensor should be a Keras variable, not a TensorFlow one. You can create it as follows:

from keras.layers import Input
from keras import backend as K

constants = [1,2,3]
k_constants = K.variable(constants)
fixed_input = Input(tensor=k_constants)


Answer 2:

EDIT: Apparently the answer below does not work (nowadays, anyway). See "Creating constant value in Keras" for a related answer.


Looking at the source (I haven't been able to find a reference in the docs), it looks like you can just use Input and pass it a constant Theano/TensorFlow tensor.

from keras.layers import Input
import tensorflow as tf

fixed_input = Input(tensor=tf.constant([1, 2, 3, 4]))

This will "wrap" the tensor (actually more like "extend" it with metadata) so you can use it with any Keras layer.



Answer 3:

Something to add: when you build and compile the model, you need to pass the constant input as one of the model's inputs; otherwise the graph is disconnected.

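from keras import backend as K
from keras.layers import Input, Subtract
from keras.models import Model

# your input (input_shape is an integer feature count defined elsewhere)
inputs = Input(shape=(input_shape,))

# an array of ones
constants = [1] * input_shape

# make the array a Keras variable
k_constants = K.variable(constants, name="ones_variable")

# wrap the variable in an Input layer
ones_tensor = Input(tensor=k_constants, name="ones_tensor")

# do some layers (Some_Layers is a placeholder for your own layers);
# use a new name so that inputs still refers to the Input layer
x = Some_Layers()(inputs)

# get the complement of the outputs
output = Subtract()([ones_tensor, x])

# list both inputs so the graph stays connected
model = Model([inputs, ones_tensor], output)
model.compile(some_params)  # some_params stands for your optimizer, loss, etc.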

When you train, you can just feed in the data you have; you do not need to supply the constant input again.

I have found that, whatever you try, it is usually easier to write a custom layer and take advantage of the power of NumPy:

from keras.layers import Layer

class Complementry(Layer):

    def __init__(self, **kwargs):
        super(Complementry, self).__init__(**kwargs)

    def build(self, input_shape):
        super(Complementry, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return 1 - x  # to use a fixed array instead, return MyArray + x
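
As a hedged illustration of the "MyArray + x" idea in the comment above, the following sketch stores a fixed vector inside a custom layer so that it never has to be passed as a model input. It assumes TensorFlow 2 with its bundled Keras (tf.keras), which differs from the standalone keras imports used in the answers. The class name AddConstant and the example values are hypothetical.

```python
import numpy as np
import tensorflow as tf


class AddConstant(tf.keras.layers.Layer):
    """Adds a fixed vector to every time step of its input (hypothetical name)."""

    def __init__(self, constant, **kwargs):
        super().__init__(**kwargs)
        # Keep the vector once as a NumPy array; it is not a trainable weight.
        self.constant = np.asarray(constant, dtype="float32")

    def call(self, x):
        # Shape (features,) broadcasts over (batch, timesteps, features).
        return x + tf.constant(self.constant)


# Usage: call the layer eagerly on a small all-zero batch.
layer = AddConstant([1.0, 2.0, 3.0])
series = np.zeros((2, 4, 3), dtype="float32")
result = np.asarray(layer(series))
```

Because the constant lives inside the layer, the model only needs the time-varying series as its input. Saving and reloading such a layer would additionally need a get_config method, which this sketch leaves out.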