Question:
How is train_on_batch() different from fit()? In what cases should we use train_on_batch()?
Answer 1:
I believe you mean to compare `train_on_batch` with `fit` (and variations like `fit_generator`), since `train` is not a commonly available API function for Keras.
For this question, there is a simple answer from the primary author of Keras:
With fit_generator, you can use a generator for the validation data as well. In general I would recommend using fit_generator, but using train_on_batch works fine too. These methods only exist as for the sake of convenience in different use cases, there is no "correct" method.
`train_on_batch` allows you to expressly update weights based on a collection of samples you provide, without regard to any fixed batch size. You would use it when that is exactly what you want: to train on an explicit collection of samples. You could use this approach to maintain your own iteration over multiple batches of a traditional training set, but letting `fit` or `fit_generator` iterate over the batches for you is likely simpler.
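A minimal sketch of that difference, assuming TensorFlow 2.x-style Keras (`from tensorflow import keras`); the toy model, the synthetic data and the helper `make_model` are hypothetical. `fit` slices the arrays into batches and runs the epochs for you, while the manual loop keeps its own cursor and calls `train_on_batch` once per batch it chooses.

```python
import numpy as np
from tensorflow import keras


def make_model():
    # Hypothetical toy model: 4 input features -> 1 regression output.
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)  # synthetic target

# Option 1: let fit() handle batching, shuffling and epochs.
fit_model = make_model()
history = fit_model.fit(x, y, batch_size=32, epochs=2, verbose=0)

# Option 2: maintain your own cursor and call train_on_batch() per batch.
loop_model = make_model()
batch_size = 32
loop_losses = []
for epoch in range(2):
    for start in range(0, len(x), batch_size):
        xb = x[start:start + batch_size]
        yb = y[start:start + batch_size]
        # One gradient update on exactly these samples; returns the batch loss.
        loss = loop_model.train_on_batch(xb, yb)
        loop_losses.append(float(loss))
```

Note that `fit` shuffles the samples each epoch by default, whereas this loop visits them in a fixed order; any shuffling, callbacks or logging become your responsibility once you drive `train_on_batch` yourself.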
One case where it might be nice to use `train_on_batch` is updating a pre-trained model on a single new batch of samples. Suppose you have already trained and deployed a model, and some time later you receive a new set of training samples that were never used before. You could use `train_on_batch` to update the existing model directly on only those samples. Other methods can do this too, but `train_on_batch` makes the intent explicit in this case.
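A sketch of that update, again assuming TensorFlow 2.x-style Keras. The briefly trained model stands in for the deployed one (in practice you might restore it with `keras.models.load_model`), and `x_new`/`y_new` are hypothetical new samples. A single `train_on_batch` call applies one gradient step on exactly those 20 samples, independent of the batch size used originally.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(1)

# Stand-in for an already trained, deployed model.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
x_old = rng.normal(size=(512, 4)).astype("float32")
model.fit(x_old, x_old.sum(axis=1, keepdims=True), epochs=3, verbose=0)

# Later: a small batch of never-before-seen samples arrives.
x_new = rng.normal(size=(20, 4)).astype("float32")
y_new = x_new.sum(axis=1, keepdims=True)

weights_before = [w.copy() for w in model.get_weights()]
# One gradient step on exactly these samples; the optimizer state
# accumulated during the earlier fit() is reused.
loss_new = model.train_on_batch(x_new, y_new)
weights_after = model.get_weights()
```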
Apart from special cases like this (either where you have some pedagogical reason to maintain your own cursor across different training batches, or for some kind of semi-online training update on a special batch), it is probably better to always use `fit` (for data that fits in memory) or `fit_generator` (for streaming batches of data from a generator).
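For the streaming case, a hedged sketch: in newer TensorFlow releases `fit_generator` is deprecated because `fit` accepts Python generators directly, so this example passes generators to `fit` for both training and validation (the validation generator mirrors the quote above). The generator `batch_generator` and its data are hypothetical; since it never ends, `steps_per_epoch` and `validation_steps` bound each epoch.

```python
import numpy as np
from tensorflow import keras


def batch_generator(n_features=4, batch_size=32, seed=0):
    """Hypothetical endless stream of (inputs, targets) batches."""
    rng = np.random.default_rng(seed)
    while True:
        xb = rng.normal(size=(batch_size, n_features)).astype("float32")
        yield xb, xb.sum(axis=1, keepdims=True)


model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# fit() pulls batches from the generators; the step counts tell it
# where each training epoch and each validation pass end.
history = model.fit(
    batch_generator(seed=0),
    steps_per_epoch=10,
    epochs=2,
    validation_data=batch_generator(seed=1),
    validation_steps=3,
    verbose=0,
)
```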
Answer 2:
`train_on_batch()` gives you greater control over the state of an LSTM, for example when you use a stateful LSTM and need to control when `model.reset_states()` is called. You may have multi-series data and need to reset the state after each series, which you can do with `train_on_batch()`; if you used `.fit()` instead, the network would be trained on all the series without resetting the state in between. There's no right or wrong; it depends on what data you're using and how you want the network to behave.