I am trying to train a CNN with two input branches. These two branches (b1, b2) are to be merged into a densely connected layer of 256 neurons with a dropout rate of 0.25. This is what I have so far:
batch_size, epochs = 32, 3
ksize = 2
l2_lambda = 0.0001
### My first model (b1)
b1 = Sequential()
b1.add(Conv1D(128*2, kernel_size=ksize,
              activation='relu',
              input_shape=(xtest.shape[1], xtest.shape[2]),
              kernel_regularizer=keras.regularizers.l2(l2_lambda)))
b1.add(Conv1D(128*2, kernel_size=ksize, activation='relu',
              kernel_regularizer=keras.regularizers.l2(l2_lambda)))
b1.add(MaxPooling1D(pool_size=ksize))
b1.add(Dropout(0.2))
b1.add(Conv1D(128*2, kernel_size=ksize, activation='relu',
              kernel_regularizer=keras.regularizers.l2(l2_lambda)))
b1.add(MaxPooling1D(pool_size=ksize))
b1.add(Dropout(0.2))
b1.add(Flatten())
### My second model (b2)
b2 = Sequential()
b2.add(Dense(64, input_shape = (5000,), activation='relu',kernel_regularizer=keras.regularizers.l2(l2_lambda)))
b2.add(Dropout(0.1))
## Merging the two models
model = Sequential()
model.add(concatenate([b1, b2],axis = -1))
model.add(Dense(256, activation='relu', kernel_initializer='normal',kernel_regularizer=keras.regularizers.l2(l2_lambda)))
model.add(Dropout(0.25))
model.add(Dense(num_classes, activation='softmax'))
But when I concatenate it gives me the following error:
I first tried using the following command:
model.add(Merge([b1, b2], mode = 'concat'))
But I got the error: ImportError: cannot import name 'Merge'. I am using Keras 2.2.2 and Python 3.6.
You need to use the functional API to achieve what you are looking for. You can use either the Concatenate layer or its functional equivalent, concatenate.
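A minimal, self-contained sketch of the idea follows. The layer sizes, input shapes, and num_classes below are illustrative stand-ins for your b1 and b2 (your actual branches have more layers), but the merge pattern is the same: concatenate the output tensors of the two branches, not the Sequential models themselves, then build a two-input Model from the branch inputs.

```python
import keras
from keras.models import Sequential, Model
from keras.layers import Conv1D, Flatten, Dense, Dropout, concatenate

l2_lambda = 0.0001
num_classes = 10  # illustrative

# First branch: stand-in for your Conv1D stack (b1)
b1 = Sequential()
b1.add(Conv1D(32, kernel_size=2, activation='relu', input_shape=(100, 4)))
b1.add(Flatten())

# Second branch: stand-in for your Dense branch (b2)
b2 = Sequential()
b2.add(Dense(64, input_shape=(5000,), activation='relu'))

# Merge the OUTPUT tensors of the two branches, not the models themselves
merged = concatenate([b1.output, b2.output])
x = Dense(256, activation='relu',
          kernel_regularizer=keras.regularizers.l2(l2_lambda))(merged)
x = Dropout(0.25)(x)
out = Dense(num_classes, activation='softmax')(x)

# A two-input model: pass the INPUT tensors of both branches
model = Model(inputs=[b1.input, b2.input], outputs=out)
model.summary()
```

When you train or predict, feed the model a list of two arrays, one per branch, e.g. model.fit([x1_train, x2_train], y_train, ...).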
Note that I have only converted the last part of your model to functional form. You can do the same thing for the other two models, b1 and b2 (actually, it seems that the architecture you are trying to define is one single model that consists of two branches merged together). At the end, use model.summary() to see and recheck the architecture of the model.