How do I implement transfer learning in NiftyNet?

Asked 2019-06-26 05:11

I'd like to perform some transfer learning using the NiftyNet stack, as my dataset of labeled images is rather small. In TensorFlow this is possible: I can load a variety of pre-trained networks and work with their layers directly. To fine-tune a network, I could freeze the intermediate layers and train only the final layer, or I could use the output of the intermediate layers as a feature vector to feed into another classifier.
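For reference, the freezing workflow described above looks roughly like this in plain TensorFlow/Keras (a minimal sketch: `weights=None` keeps it self-contained, whereas `weights='imagenet'` would download the pre-trained parameters; the two-class head is illustrative):

```python
import tensorflow as tf

# Sketch of fine-tuning by freezing a backbone and training only a new head.
# weights=None keeps the example self-contained; weights='imagenet' would
# load pre-trained parameters instead.
base = tf.keras.applications.ResNet50(
    include_top=False, weights=None,
    input_shape=(224, 224, 3), pooling='avg')
base.trainable = False  # freeze every layer in the backbone

inputs = tf.keras.Input(shape=(224, 224, 3))
features = base(inputs, training=False)  # backbone as a fixed feature extractor
outputs = tf.keras.layers.Dense(2, activation='softmax')(features)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer='adam', loss='categorical_crossentropy')
# Only the new Dense head (kernel + bias) is updated during training.
```

The question is how to express the equivalent of `base.trainable = False` through NiftyNet's training pipeline.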

How do I do this in NiftyNet? The only mention of "transfer learning" in the documentation or the source code is in reference to the model zoo, but for my task (image classification), there are no networks available in the zoo. The ResNet architecture seems to be implemented and available to use, but as far as I can tell, it's not trained on anything yet. In addition, it seems the only way I can train a network is by running net_classify train, using the various TRAIN configuration options in the config file, none of which have options for freezing networks. The various layers in niftynet.layer also do not seem to have options to enable them to be trained or not.

I suppose the questions I have are:

  1. Is it possible to port over a pre-trained TensorFlow network?
    • If I manually recreate the layer architecture in NiftyNet, is there a way to import the weights from a pre-trained TF network?
  2. How do I access the intermediate weights and layers of a model? (The question "How can I get access to intermediate activation maps of the pre-trained models in NiftyNet?" covers models from the model zoo, which can be fetched with net_download, but not arbitrary models.)
  3. As an aside, it also seems that the learning rate is constant: to vary it over time, would I have to train for some number of iterations, change lr, then restart training from the last checkpoint?
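On question 2: NiftyNet writes standard TensorFlow checkpoints to its model directory, so the weights of any trained model (zoo or not) can at least be inspected offline with TensorFlow's checkpoint utilities. A sketch (it writes its own tiny checkpoint so the example is self-contained; with NiftyNet you would point at the model.ckpt-* files in model_dir instead):

```python
import os
import tempfile

import tensorflow as tf

# Write a tiny checkpoint so the example is self-contained; with NiftyNet,
# ckpt_path would be something like <model_dir>/models/model.ckpt-<iter>.
ckpt_dir = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(w=tf.Variable([1.0, 2.0]))
ckpt_path = ckpt.save(os.path.join(ckpt_dir, 'model.ckpt'))

# List every variable stored in the checkpoint as (name, shape) pairs.
names = [name for name, shape in tf.train.list_variables(ckpt_path)]

# Read one variable back as a NumPy array.
reader = tf.train.load_checkpoint(ckpt_path)
w_name = [n for n in names if n.startswith('w/')][0]
values = reader.get_tensor(w_name)  # the stored weight values
```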

1 Answer
叛逆
Answered 2019-06-26 05:31

[Edit]: Here are the docs for transfer learning with NiftyNet.

This feature is currently being worked on. See here for full details.

Intended capabilities include the following:

  • Command for printing all trainable variable names (with optional regular expression matching)
  • Ability to randomly initialize a subset of variables, selected by regex name matching
  • Ability to restore a subset of the variables from an existing checkpoint and continue updating them, handling optimizer-specific variables (e.g. momentum) if the optimization method changes
  • Ability to restore the rest of the variables from an existing checkpoint and freeze their trained weights
  • Saving all trainable variables after training
  • Configuration parameters for fine-tuning and variable-name regexes, plus unit tests
  • A demo/tutorial
  • Preprocessing of checkpoints to resolve compatibility issues
  • Handling of batch-norm and dropout layers (editing networks to remove batch-norm variables)
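Combined with the docs linked above, the regex-based restore/freeze capabilities are meant to be driven from the configuration file. A hypothetical sketch (the parameter names vars_to_restore and vars_to_freeze appear in the transfer-learning docs, but final_fc is an illustrative layer name and the exact matching semantics should be checked against the docs):

```ini
[TRAINING]
# Restore every variable except the final classification layer from the
# checkpoint; non-matching variables are randomly re-initialized.
vars_to_restore = ^((?!(final_fc)).)*$
# Keep the restored backbone fixed during fine-tuning, so that only the
# new final layer is updated.
vars_to_freeze = ^((?!(final_fc)).)*$
```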