4-step Alternating RPN / Faster R-CNN Training?

Posted 2020-07-26 15:50

Question:

I have been going through the recently released tensorflow/models/../object_detection models, particularly Faster R-CNN.

The paper mentions 4-step alternating training, where you would

  1. train the RPN, then freeze the RPN layers,
  2. train the RCNN, then freeze the RCNN layers,
  3. train the RPN again, then freeze the RPN layers,
  4. train the RCNN again.

From what I gather, at stage 2 (RCNN training) the RPN is indeed frozen with:

if self._is_training:
    proposal_boxes = tf.stop_gradient(proposal_boxes) 

So training the RPN and freezing its layers, followed by RCNN training, is covered, but where are the other 3 steps performed?

Am I missing something?

Answer 1:

Our implementation of Faster R-CNN in the TF Object Detection API follows the paper quite closely but differs in a few ways. One of those differences is that we train the model end-to-end instead of using the alternating training described in the paper.

The stop_gradient that you mention doesn't actually freeze the RPN. It ignores the contribution of the gradient through the proposal coordinates, but still allows the gradient to pass through the RPN features.
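As a rough illustration of that distinction (a toy scalar sketch, not code from the Object Detection API), the snippet below treats the proposal coordinate as a constant when differentiating the second-stage loss, which mimics what stop_gradient does. All names here (forward, numeric_grad, the weights w and v, the target values) are made up for illustration, and the sketch assumes the first-stage localization loss is computed on the raw RPN prediction rather than on the stopped proposals.

```python
def forward(w, v, x=3.0, frozen_box=None):
    """Toy two-stage model with a scalar 'feature' and a scalar 'box'."""
    feat = w * x                      # shared feature produced by the first stage
    box = v * feat                    # proposal coordinate predicted from the feature
    rpn_loc_loss = (box - 1.0) ** 2   # first-stage localization loss on the raw prediction
    # With frozen_box set, the second stage sees the box as a constant,
    # which mimics tf.stop_gradient(proposal_boxes).
    box_used = box if frozen_box is None else frozen_box
    det_loss = (feat * box_used - 2.0) ** 2  # second-stage loss on the feature "cropped" at the box
    return rpn_loc_loss, det_loss, box


def numeric_grad(fn, w, v, eps=1e-6):
    """Central-difference gradient of fn(w, v) with respect to (w, v)."""
    dw = (fn(w + eps, v) - fn(w - eps, v)) / (2 * eps)
    dv = (fn(w, v + eps) - fn(w, v - eps)) / (2 * eps)
    return dw, dv


w0, v0 = 0.5, 2.0
box0 = forward(w0, v0)[2]  # value of the proposal at the current weights

# Gradients of the detection loss with the box held constant (stop_gradient case).
det_stopped = numeric_grad(lambda w, v: forward(w, v, frozen_box=box0)[1], w0, v0)
# Gradients of the detection loss when the box is allowed to carry gradient.
det_full = numeric_grad(lambda w, v: forward(w, v)[1], w0, v0)
# Gradients of the first-stage localization loss, unaffected by the stop.
rpn_loc = numeric_grad(lambda w, v: forward(w, v)[0], w0, v0)

print("detection loss, box frozen (dw, dv):", det_stopped)
print("detection loss, no stop    (dw, dv):", det_full)
print("RPN localization loss      (dw, dv):", rpn_loc)
```

In this toy setup the box weight v receives no gradient from the detection loss once the box is frozen, the feature weight w still does, and v keeps receiving a gradient from the first-stage localization loss.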

Hope this helps!



Answer 2:

Because of StackOverflow's ridiculous rule, I cannot add a comment, so I have to write this here as an "answer". It is actually a follow-up to the points in @Jonathan Huang's response.

I am still confused about the stop gradient. If we stop the gradient of the boxes, how can the RPN box accuracy be improved? In this case, it seems that only the detection and RPN objectness accuracies are improved, while the RPN box accuracy can never be improved.

Although the RPN loss is composed of a box loss and an objectness loss, disabling gradients for location may cause the parameters of the layer that estimates the 4K coordinates from the 256-D tensor, for example, to become constant. Then how can the RPN box locations be improved?

Could anyone provide an explanation? Thank you.



Answer 3:

I am also looking into performing the 4-step alternating training mentioned in the paper. My understanding is that the correct implementation should be:

  1. Train the shared conv layers + RPN, then retrieve the region proposals.
  2. Train Fast RCNN with the region proposals as input (note: not Faster RCNN).
  3. Initialize Faster RCNN with the weights from the Fast RCNN in step 2, and train the RPN part only.
  4. Fix the shared conv layers and the RPN, and train only the bottom network.

Step 2 requires some amendment to the tf-faster rcnn implementation. For the other steps, you should be able to fix the weights by setting the trainable flag to False in the network.py module:

def _region_proposal(self, net_conv, is_training, initializer):
    # Set trainable=False on each layer that should stay frozen in this stage.
    rpn = slim.conv2d(net_conv, cfg.RPN_CHANNELS, [3, 3], trainable=False,
                      weights_initializer=initializer, scope="rpn_conv/3x3")
    self._act_summaries.append(rpn)
    rpn_cls_score = slim.conv2d(rpn, self._num_anchors * 2, [1, 1], trainable=False,
                                weights_initializer=initializer,
                                padding='VALID', activation_fn=None, scope='rpn_cls_score')
    # ... (remainder of the method omitted in this excerpt)


Answer 4:

Multi-stage training might still be possible: the Object Detection API config file has a provision to freeze some layers of the network during training (the freeze_variables parameter in train_config).

If you closely inspect the checkpoints that a model from the TF Object Detection API generates, these are the outer variable name scopes corresponding to the network architecture from the Faster RCNN paper:

Region proposal network:

  • 'FirstStageFeatureExtractor/InceptionResnetV2' (shared)
  • 'Conv/biases', 'Conv/weights'
  • 'FirstStageBoxPredictor'

Detector:

  • 'FirstStageFeatureExtractor/InceptionResnetV2' (shared)
  • 'SecondStageFeatureExtractor/InceptionResnetV2'
  • 'SecondStageBoxPredictor'

So what you could do is perform successive rounds of training while freezing the layers you don't want to update. Also note the classification and localization loss weights for the first and second stage can be set to zero in the config files if you don't want them to contribute to updates.
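To plan such rounds before editing any config, a small sketch like the one below can help check which variables each set of freeze patterns would leave trainable. It is only an illustration: the variable name suffixes (example_layer) are invented, the four-stage assignment in STAGES is one possible reading of the paper's schedule, and it assumes that freeze patterns are matched against the start of each variable name as regular expressions, which you should verify against your version of the API.

```python
import re

# Illustrative variable names: the outer scopes follow the list above,
# the "example_layer" suffixes are made up.
VARIABLES = [
    "FirstStageFeatureExtractor/InceptionResnetV2/example_layer/weights",
    "Conv/weights",
    "Conv/biases",
    "FirstStageBoxPredictor/example_layer/weights",
    "SecondStageFeatureExtractor/InceptionResnetV2/example_layer/weights",
    "SecondStageBoxPredictor/example_layer/weights",
]

# One possible (hypothetical) mapping of the four stages to freeze patterns.
STAGES = {
    "1_train_rpn": ["SecondStage"],
    "2_train_detector": ["Conv/", "FirstStageBoxPredictor"],
    "3_train_rpn_shared_frozen": ["FirstStageFeatureExtractor", "SecondStage"],
    "4_train_detector_shared_frozen": ["FirstStage", "Conv/"],
}


def trainable(variables, freeze_patterns):
    """Return the variables that no freeze pattern matches (anchored at the name start)."""
    return [
        name for name in variables
        if not any(re.match(pattern, name) for pattern in freeze_patterns)
    ]


for stage, patterns in STAGES.items():
    print(stage)
    for name in trainable(VARIABLES, patterns):
        print("   trainable:", name)
```

Each list of patterns would then go into the freeze_variables entries of the config used for that round, with the next round starting from the previous round's checkpoint.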

Hope this helps!