Updating configuration protobuffer in tensorflow o

Posted 2019-08-19 06:41

In their respective config files, the Faster R-CNN models enable only random horizontal flip by default, while the SSD models enable random horizontal flip followed by SSD random crop. I want to add more augmentation options, so I wrote the following snippet:

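```python
import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2, preprocessor_pb2


def get_configs_from_pipeline_file(pipeline_config_path):
  """Parse a text-format pipeline config file into a TrainEvalPipelineConfig."""
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  # tf.io.gfile.GFile is available in TF 2.x and later TF 1.x releases
  # (older TF 1.x code uses tf.gfile.GFile).
  with tf.io.gfile.GFile(pipeline_config_path, 'r') as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline_config)
  return pipeline_config


config = get_configs_from_pipeline_file(
    'ssd_inception_v2_coco.config')

# Build a PreprocessingStep containing an empty random_vertical_flip
# message, i.e. random vertical flip with default parameters.
VERTICALFLIP = preprocessor_pb2.PreprocessingStep()
VERTICALFLIP.random_vertical_flip.SetInParent()

# extend() appends to the end of the repeated field (after ssd_random_crop).
config.train_config.data_augmentation_options.extend([VERTICALFLIP])
print(config.train_config.data_augmentation_options)
```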

The snippet above appends random vertical flip as the last augmentation step in ssd_inception_v2_coco.config, after SSD random crop. I want to insert it between random horizontal flip and SSD random crop instead, because I think the augmentation functions are applied in this order during training (correct me if I am wrong).
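The behaviour described above follows from how repeated message fields work in Python protobuf: the field keeps its elements in insertion order, and `extend()` only ever appends. A minimal sketch, assuming `google.protobuf.descriptor_pb2` is used as a stand-in (its `FileDescriptorProto.message_type` is a repeated message field, like `data_augmentation_options`); the step names are just labels here, not real preprocessing steps:

```python
from google.protobuf import descriptor_pb2

# Stand-in for train_config.data_augmentation_options: a repeated message field.
file_proto = descriptor_pb2.FileDescriptorProto()
for name in ("random_horizontal_flip", "ssd_random_crop"):
    file_proto.message_type.add(name=name)

# extend() appends to the end, exactly like list.extend().
flip = descriptor_pb2.DescriptorProto(name="random_vertical_flip")
file_proto.message_type.extend([flip])

order = [m.name for m in file_proto.message_type]
print(order)  # ['random_horizontal_flip', 'ssd_random_crop', 'random_vertical_flip']
```

Whether the training pipeline then applies the steps in this order is a separate question about the Object Detection API; the sketch only shows the order stored in the config.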

UPDATE: A workaround I found is to delete the SSD random crop step and re-append it at the end:

```python
config = get_configs_from_pipeline_file(
    'ssd_inception_v2_coco.config')

# Remove the last step (ssd_random_crop in this config).
del config.train_config.data_augmentation_options[-1]

VERTICALFLIP = preprocessor_pb2.PreprocessingStep()
VERTICALFLIP.random_vertical_flip.SetInParent()

# Re-create ssd_random_crop with default parameters.
SSD_RANDOM_CROP = preprocessor_pb2.PreprocessingStep()
SSD_RANDOM_CROP.ssd_random_crop.SetInParent()

# Append in the desired order: vertical flip, then SSD random crop.
config.train_config.data_augmentation_options.extend([VERTICALFLIP])
config.train_config.data_augmentation_options.extend([SSD_RANDOM_CROP])
```
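One caveat with this workaround: rebuilding `ssd_random_crop` from an empty message with `SetInParent()` only reproduces the original step when that step had no custom parameters. Copying the deleted element with `CopyFrom()` before removing it keeps whatever it contained. A minimal sketch of the difference, again assuming `descriptor_pb2` as a stand-in; `custom_param` is a hypothetical placeholder for a non-default parameter:

```python
from google.protobuf import descriptor_pb2

# Stand-in for train_config.data_augmentation_options (hypothetical names).
file_proto = descriptor_pb2.FileDescriptorProto()
steps = file_proto.message_type
steps.add(name="random_horizontal_flip")
crop = steps.add(name="ssd_random_crop")
crop.field.add(name="custom_param")  # hypothetical non-default parameter

# Rebuilding from scratch: the nested contents are not carried over.
rebuilt = descriptor_pb2.DescriptorProto(name="ssd_random_crop")

# Copying before deleting: the nested contents survive the move.
saved = descriptor_pb2.DescriptorProto()
saved.CopyFrom(steps[-1])
del steps[-1]

steps.add(name="random_vertical_flip")
steps.extend([saved])

print([m.name for m in steps])  # ['random_horizontal_flip', 'random_vertical_flip', 'ssd_random_crop']
print(len(rebuilt.field), len(steps[-1].field))  # 0 1
```

With the real protos, the equivalent would presumably be a `preprocessor_pb2.PreprocessingStep()` that calls `CopyFrom(config.train_config.data_augmentation_options[-1])` before the `del`.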

Is there a direct way to insert a step at a specific index, the way we can insert into a Python list? For example:

```python
a = [10, 20, 30]
a.insert(-1, 100)  # a is now [10, 20, 100, 30]
```
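A minimal sketch of one way to get `list.insert`-style behaviour on a repeated message field: snapshot the elements as independent copies, insert into an ordinary Python list, clear the field with `del repeated[:]`, and `extend()` everything back. The helper name `insert_into_repeated` is hypothetical, and `descriptor_pb2` again stands in for the Object Detection API protos:

```python
from google.protobuf import descriptor_pb2


def insert_into_repeated(repeated, index, message):
    """Insert `message` into a repeated message field at `index`.

    Uses the same index rules as list.insert, so negative indices work.
    """
    # Snapshot every element as an independent copy before clearing.
    items = []
    for existing in repeated:
        clone = type(existing)()
        clone.CopyFrom(existing)
        items.append(clone)
    items.insert(index, message)  # plain Python list insert
    del repeated[:]               # clear the repeated field
    repeated.extend(items)        # extend() copies each message back in


file_proto = descriptor_pb2.FileDescriptorProto()
for name in ("random_horizontal_flip", "ssd_random_crop"):
    file_proto.message_type.add(name=name)

vertical_flip = descriptor_pb2.DescriptorProto(name="random_vertical_flip")
insert_into_repeated(file_proto.message_type, -1, vertical_flip)

print([m.name for m in file_proto.message_type])
# ['random_horizontal_flip', 'random_vertical_flip', 'ssd_random_crop']
```

Applied to the question's config, the call would presumably be `insert_into_repeated(config.train_config.data_augmentation_options, -1, VERTICALFLIP)`; this sketch has only been exercised on `descriptor_pb2`, not on the Object Detection API protos. Copying before clearing avoids relying on whether removed sub-messages stay valid across the different protobuf Python backends.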

I am new to working with Google Protocol Buffers.
