Using bvlc_googlenet as pretrained model in digits

Published 2019-01-29 04:34

Question:

DIGITS 4.0 / 0.14.0-rc.3 / Ubuntu (AWS)

I am training a 5-class GoogLeNet model with about 800 training samples in each class. I was trying to use the BVLC GoogLeNet model (trained on ImageNet) as the pre-trained model. These are the steps I took:

  1. Downloaded the ImageNet-trained model from http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel and placed it in /home/ubuntu/models

  2. a. Pasted the "train_val.prototxt" from here https://github.com/BVLC/caffe/blob/master/models/bvlc_reference_caffenet/train_val.prototxt into the custom network tab, and

     b. commented out the "source" and "backend" lines with '#' (since it was complaining about them)

  3. In the pre-trained models text box, pasted the path to the '.caffemodel'. In my case: "/home/ubuntu/models/bvlc_googlenet.caffemodel"

I get this error:

ERROR: Cannot copy param 0 weights from layer 'loss1/classifier'; shape mismatch. Source param shape is 1 1 1000 1024 (1024000); target param shape is 6 1024 (6144). To learn this layer's parameters from scratch rather than copying from a saved net, rename the layer.

I have pasted various train_val.prototxt files from GitHub issues and elsewhere, with no luck unfortunately.

I am not sure why this is getting so complicated. In older versions of DIGITS, we could just enter the path to the folder and it worked great for transfer learning.

Could someone help?

Answer 1:

Rename the layer from "loss1/classifier" to "loss1/classifier_retrain".
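If the network description is long, the rename can be scripted. The following is a minimal sketch, not part of DIGITS or Caffe: the helper `rename_classifiers` and the `_retrain` suffix default are illustrative names, and the sample layer is a hand-written fragment modelled on the GoogLeNet definition. It only changes the `name:` field, since the layer name is what Caffe matches on; the `top`/`bottom` blob names are left alone. The pattern also covers "loss2/classifier" and "loss3/classifier", on the assumption that every 1000-way classifier in the pasted network needs the same treatment.

```python
import re

# Hypothetical helper for illustration; not a DIGITS or Caffe function.
def rename_classifiers(prototxt: str, suffix: str = "_retrain") -> str:
    """Append `suffix` to the name of every lossN/classifier layer."""
    # Match only the layer's `name:` field, not `top:` or `bottom:` blobs.
    pattern = re.compile(r'(name:\s*")(loss\d+/classifier)(")')
    return pattern.sub(
        lambda m: m.group(1) + m.group(2) + suffix + m.group(3), prototxt
    )

# Hand-written sample layer, assumed to resemble the one in the network.
example = '''layer {
  name: "loss1/classifier"
  type: "InnerProduct"
  bottom: "loss1/fc"
  top: "loss1/classifier"
}'''

renamed = rename_classifiers(example)
print(renamed)
```

Running the helper a second time leaves the text unchanged, because the pattern requires the closing quote directly after "classifier".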

When fine-tuning a model, here's what Caffe does:

# pseudo-code
for layer in new_model:
  if layer.name in old_model:
    # shapes must match, otherwise Caffe stops with a shape-mismatch error
    layer.weights = old_model[layer.name].weights

You're getting an error because the weights for "loss1/classifier" were for a 1000-class classification problem (1000x1024), and you're trying to copy them into a layer for a 6-class classification problem (6x1024). When you rename the layer, Caffe doesn't try to copy the weights for that layer and you get randomly initialized weights - which is what you want.
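The copy-by-name behaviour described above can be simulated in a few lines. This is an illustrative sketch only: it does not call Caffe, the function `copy_trained_layers` and the dictionary layout are made up for the example, and the strings "pretrained" and "random" stand in for real weight arrays. It shows why the original name fails and why the renamed layer is simply skipped.

```python
# Illustrative simulation; none of these names are Caffe's API.
def copy_trained_layers(target, source):
    """Copy weights layer by layer, matching on name; fail on shape mismatch."""
    copied = []
    for name, src in source.items():
        if name not in target:
            continue  # no layer with this name in the new net: nothing to copy
        if target[name]["shape"] != src["shape"]:
            raise ValueError(
                f"Cannot copy weights from layer '{name}'; shape mismatch: "
                f"source {src['shape']}, target {target[name]['shape']}"
            )
        target[name]["weights"] = src["weights"]
        copied.append(name)
    return copied

# Stand-in for the saved 1000-class model.
pretrained = {
    "conv1/7x7_s2": {"shape": (64, 3, 7, 7), "weights": "pretrained"},
    "loss1/classifier": {"shape": (1000, 1024), "weights": "pretrained"},
}

def new_net(classifier_name):
    """Stand-in for the new 6-output net with a configurable classifier name."""
    return {
        "conv1/7x7_s2": {"shape": (64, 3, 7, 7), "weights": "random"},
        classifier_name: {"shape": (6, 1024), "weights": "random"},
    }

# Same name, different shape: the copy is attempted and fails.
try:
    copy_trained_layers(new_net("loss1/classifier"), pretrained)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print(err)

# Renamed layer: skipped, so it keeps its random initialization.
net = new_net("loss1/classifier_retrain")
copied = copy_trained_layers(net, pretrained)
print(copied)
```

All other layers keep their pretrained weights, which is the point of fine-tuning: only the renamed classifier is learned from scratch.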

Also, I suggest you use this network description which is already set up as an all-in-one network description for GoogLeNet. It will save you some trouble.

https://github.com/NVIDIA/DIGITS/blob/digits-4.0/digits/standard-networks/caffe/googlenet.prototxt