Stop and Restart Training on VGG-16

Posted 2019-03-01 08:16

I am using a pre-trained VGG-16 model for image classification. I have added a custom last layer because my problem has 10 classes. I am training the model for 200 epochs.

My question is: if I stop the training at some arbitrary epoch (by closing the Python window), say epoch 50, is there any way to resume from there? I have read about saving and reloading models, but my understanding is that this works only for our own custom models, not for pre-trained models like VGG-16.

2 Answers
可以哭但决不认输i
#2 · 2019-03-01 09:16

Here is a customised version of ModelCheckpoint that I use to resume training from a given epoch (available as a gist). It saves the epoch and other logs to a companion JSON file and, when it is constructed, loads that file to decide whether training is being resumed. To resume from that epoch, call get_last_epoch and pass the result as initial_epoch to model.fit.

import json

from keras.callbacks import ModelCheckpoint

class StatefulCheckpoint(ModelCheckpoint):
  """Save extra checkpoint data to resume training."""
  def __init__(self, weight_file, state_file=None, **kwargs):
    """Save the state (epoch etc.) along side weights."""
    super().__init__(weight_file, **kwargs)
    self.state_f = state_file
    self.state = dict()
    if self.state_f:
      # Load the last state if any
      try:
        with open(self.state_f, 'r') as f:
          self.state = json.load(f)
        self.best = self.state['best']
      except Exception as e: # pylint: disable=broad-except
        print("Skipping last state:", e)

  def on_train_begin(self, logs=None):
    prefix = "Resuming" if self.state else "Starting"
    print("{} training...".format(prefix))

  def on_epoch_end(self, epoch, logs=None):
    """Saves training state as well as weights."""
    super().on_epoch_end(epoch, logs)
    if self.state_f:
      state = {'epoch': epoch+1, 'best': self.best}
      state.update(logs or {})  # logs may be None
      state.update(self.params)
      with open(self.state_f, 'w') as f:
        json.dump(state, f, default=float)  # default=float converts NumPy scalar metrics

  def get_last_epoch(self, initial_epoch=0):
    """Return last saved epoch if any, or return default argument."""
    return self.state.get('epoch', initial_epoch)
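
The resume step described in this answer boils down to reading the saved epoch back and handing it to model.fit. The following is only a stdlib sketch of that round trip, not part of the original gist: `read_last_epoch`, the file name `state.json`, and the stored values are hypothetical and stand in for what the callback would write.

```python
import json
import os
import tempfile


def read_last_epoch(state_file, default=0):
    """Return the 'epoch' stored in the JSON state file, or default if unreadable."""
    try:
        with open(state_file, "r") as f:
            return int(json.load(f).get("epoch", default))
    except (OSError, ValueError):
        # Missing file or invalid JSON: treat as a fresh start.
        return default


# Simulate the state written after 50 completed epochs (values are made up).
state_path = os.path.join(tempfile.mkdtemp(), "state.json")
with open(state_path, "w") as f:
    json.dump({"epoch": 50, "best": 0.42}, f)

initial_epoch = read_last_epoch(state_path)  # epoch to resume from
remaining = 200 - initial_epoch              # epochs still to run out of 200
print(initial_epoch, remaining)
```

With the class above, the equivalent call would presumably look like `model.fit(..., epochs=200, initial_epoch=ckpt.get_last_epoch(), callbacks=[ckpt])`, where `ckpt` is a StatefulCheckpoint instance. In Keras, `epochs` is the index of the final epoch rather than a count of additional epochs, so it stays at 200. Note that the class restores only the epoch and the best metric; the weights themselves still have to be loaded before calling fit, for example with `model.load_weights` on the saved weight file.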
骚年 ilove
#3 · 2019-03-01 09:18

You can use ModelCheckpoint callback to save your model regularly. To use it, pass a callbacks parameter to the fit method:

from keras.callbacks import ModelCheckpoint
checkpointer = ModelCheckpoint(filepath='model-{epoch:02d}.hdf5', ...)
model.fit(..., callbacks=[checkpointer])

Later, you can load the last saved model and continue training from it. For further customization of this callback, take a look at the documentation.
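
To pick up "the last saved model" after a restart, one option is to scan the checkpoint directory for the highest epoch number in the file names. This is a minimal sketch that assumes the `model-{epoch:02d}.hdf5` naming from the snippet above; `latest_checkpoint` is a hypothetical helper, not a Keras function, and the demo files are empty placeholders.

```python
import os
import re
import tempfile

# Matches file names produced by filepath='model-{epoch:02d}.hdf5'.
CKPT_RE = re.compile(r"^model-(\d+)\.hdf5$")


def latest_checkpoint(directory):
    """Return (path, epoch) for the highest-numbered checkpoint, or (None, 0)."""
    best_path, best_epoch = None, 0
    for name in os.listdir(directory):
        match = CKPT_RE.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(directory, name)
    return best_path, best_epoch


# Demo with empty placeholder files standing in for real checkpoints.
demo_dir = tempfile.mkdtemp()
for n in (1, 2, 50):
    open(os.path.join(demo_dir, "model-{:02d}.hdf5".format(n)), "w").close()

path, epoch = latest_checkpoint(demo_dir)
print(os.path.basename(path), epoch)
```

As far as I know, ModelCheckpoint formats `{epoch}` as a 1-based number, so the value parsed from the file name equals the number of completed epochs and can be passed directly as `initial_epoch`. The model itself would then be restored with `keras.models.load_model(path)` before calling `model.fit(..., initial_epoch=epoch, callbacks=[checkpointer])`. Because the checkpoint stores the whole model by default (the `save_weights_only` option is off), this should work the same way for a fine-tuned VGG-16 as for a model built from scratch.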
