Keras Visualization of Model Built from Functional API

Posted 2020-02-10 08:30

I wanted to ask whether there is an easy way to visualize a Keras model built with the Functional API.

Right now, the best way for me to debug a Sequential model at a high level is:

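from keras.models import Sequential
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG

model = Sequential()
model.add(...)  # add your layers here
# ...

model.summary()  # summary() prints the table itself and returns None
SVG(model_to_dot(model).create(prog='dot', format='svg'))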

However, I am having a hard time finding a good way to visualize a model built with the Keras Functional API when it is a more complex, non-sequential model.

2 answers
地球回转人心会变
Reply #2 · 2020-02-10 08:59

After a bit of googling and trial and error, it turns out you just have to wrap the entire Functional API graph back into a Model object.

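from keras.models import Model

model = some_model()            # your base model (the poster's own helper)
output_layer = _build_output()  # the new output tensor (the poster's own helper)
finalmodel = Model(inputs=model.input, outputs=output_layer)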

Then you can call finalmodel.summary(), or use any of the plotting utilities you would use for a Sequential model.

However, I guess this requires keeping careful track of the model's inputs and outputs, which I admittedly did not do.
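A minimal sketch of that approach, assuming a recent Keras (2.x via TensorFlow, or Keras 3) is installed. The layer sizes and names (features, branch_a, branch_b, merged, score) are made up for illustration. It builds a small two-branch, non-sequential base model, attaches a new output head to base.output, and wraps the whole graph in a new Model so that summary() and the plotting utilities work on it:

```python
import keras
from keras import layers

# Hypothetical two-branch base model built with the Functional API.
inputs = keras.Input(shape=(16,), name="features")
branch_a = layers.Dense(8, activation="relu", name="branch_a")(inputs)
branch_b = layers.Dense(4, activation="relu", name="branch_b")(inputs)
merged = layers.concatenate([branch_a, branch_b], name="merged")
base = keras.Model(inputs=inputs, outputs=merged, name="base")

# Attach a new output head to the base model's output tensor.
score = layers.Dense(1, name="score")(base.output)

# Wrap the whole graph, from the original input to the new output, in a Model.
finalmodel = keras.Model(inputs=base.input, outputs=score, name="final")
finalmodel.summary()
```

Because outputs is built from base.output, the new Model traces all the way back to base.input, so the branches show up as individual layers in the summary.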

chillily
Reply #3 · 2020-02-10 09:13

Yes, there is: check keras.utils, which has a plot_model() function, as explained in detail here. It seems you are already familiar with keras.utils.vis_utils and the model_to_dot function, but this is another option. Its usage is something like:

from keras.utils import plot_model
plot_model(model, to_file='model.png')

To be honest, that is the best I have managed to find using Keras alone. Using model.summary() as you did is also useful sometimes. I also wish there were a tool for better visualization of one's models, perhaps one that even shows the weights per layer to help decide on optimal network structures and initializations (if you know of one, please tell :] ).
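If you want to keep the summary around (for example, to compare two architectures), summary() also accepts a print_fn argument. A minimal sketch, assuming a recent Keras; the summary_text helper and the layer names are hypothetical, and since the exact table layout differs between Keras versions, the sketch relies only on layer names appearing in the output:

```python
import keras
from keras import layers


def summary_text(model):
    """Return the output of model.summary() as a string instead of printing it."""
    lines = []
    # **kwargs absorbs extra arguments that some Keras versions pass to print_fn.
    model.summary(print_fn=lambda line, **kwargs: lines.append(line))
    return "\n".join(lines)


# Hypothetical non-sequential model: two branches merged into one output.
inputs = keras.Input(shape=(16,), name="features")
a = layers.Dense(8, name="branch_a")(inputs)
b = layers.Dense(4, name="branch_b")(inputs)
outputs = layers.Dense(1, name="score")(layers.concatenate([a, b]))
model = keras.Model(inputs, outputs)

text = summary_text(model)
print(text)
```

For Functional models, the summary table also includes a "Connected to" column showing which layers feed each layer, which partly covers the non-sequential structure in text form.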


Probably the best option you currently have is to visualize things in TensorBoard, which you can hook into Keras with the TensorBoard callback. This lets you visualize your training and the metrics of interest, as well as some information on your layers' activations, biases, kernels, etc. Basically, you have to add this code to your program before fitting your model:

from keras.callbacks import TensorBoard

# indicate the folder to save logs to, plus other options
tensorboard = TensorBoard(log_dir='./logs/run1', histogram_freq=1,
                          write_graph=True, write_images=False)

# put it in your callback list, where you can include other callbacks
callbacks_list = [tensorboard]
# then pass it to fit() as a callback; remember to pass validation_data too
# (regressor is your compiled model; X, Y, X_test, Y_test are your data)
regressor.fit(X, Y, callbacks=callbacks_list, epochs=64,
              validation_data=(X_test, Y_test), shuffle=True)

You can then start TensorBoard (which serves a local web page) with the following command in your terminal:

tensorboard --logdir=./logs/run1

It will then tell you which port to open in your browser to view your training (6006 by default). If you have several runs, you can pass --logdir=./logs instead to visualize them together for comparison. There are of course more options for using TensorBoard, so I suggest you check the included links if you are considering it.
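To compare runs that way, each run just needs its own subdirectory under a shared root. A minimal sketch, assuming Keras with TensorFlow installed (the TensorBoard callback writes its logs through TensorFlow); the train_run helper, the tiny model, and the random data are hypothetical stand-ins for your own:

```python
import os

import numpy as np
import keras
from keras import layers


def train_run(root, run_name, units):
    """Train a tiny hypothetical model, logging to root/run_name for TensorBoard."""
    log_dir = os.path.join(root, run_name)
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        layers.Dense(units, activation="relu"),
        layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    # Random placeholder data; substitute your own training and validation sets.
    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")
    tensorboard = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    model.fit(x, y, epochs=2, validation_data=(x, y),
              callbacks=[tensorboard], verbose=0)
    return log_dir


# Two runs under the same root; compare them with: tensorboard --logdir=./logs
run1_dir = train_run("logs", "run1", units=8)
run2_dir = train_run("logs", "run2", units=16)
```

TensorBoard lists each subdirectory of the --logdir root as a separate run, so both runs appear side by side.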
