How to fine-tune a NiftyNet pre-trained model for cu

Posted 2019-08-18 03:10

I want to use a NiftyNet pre-trained segmentation model to segment custom data. I downloaded the pre-trained weights and changed the model_dir path to point to the downloaded directory.

However when I run

python3 net_segment.py train -c /home/Container_data/config/promise12_demo_train_config.ini

I am getting the error below.

Caused by op 'save/Assign_17', defined at:
    File "net_segment.py", line 8, in <module>
      sys.exit(main())
    File "/home/NiftyNet/niftynet/__init__.py", line 142, in main
      app_driver.run(app_driver.app)
    File "/home/NiftyNet/niftynet/engine/application_driver.py", line 197, in run
      SESS_STARTED.send(application, iter_msg=None)
    File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in send
      for receiver in self.receivers_for(sender)]
    File "/usr/local/lib/python3.5/dist-packages/blinker/base.py", line 267, in <listcomp>
      for receiver in self.receivers_for(sender)]
    File "/home/NiftyNet/niftynet/engine/handler_model.py", line 109, in restore_model
      var_list=to_restore, save_relative_paths=True)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1102, in __init__
      self.build()
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1114, in build
      self._build(self._filename, build_save=True, build_restore=True)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 1151, in _build
      build_save=build_save, build_restore=build_restore)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 795, in _build_internal
      restore_sequentially, reshape)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 428, in _AddRestoreOps
      assign_ops.append(saveable.restore(saveable_tensors, shapes))
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py", line 119, in restore
      self.op.get_shape().is_fully_defined())
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/state_ops.py", line 221, in assign
      validate_shape=validate_shape)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
      use_locking=use_locking, name=name)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
      op_def=op_def)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
      op_def=op_def)
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
      self._traceback = tf_stack.extract_stack()
  InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
  Assign requires shapes of both tensors to match. lhs shape= [3,3,61,256] rhs shape= [3,3,3,61,9]
           [[node save/Assign_17 (defined at /home/NiftyNet/niftynet/engine/handler_model.py:109)  = Assign[T=DT_FLOAT, _class=["loc:@DenseVNet/conv/conv_/w"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](DenseVNet/conv/conv_/w, save/RestoreV2/_35)

The issue at https://github.com/tensorflow/models/issues/5390 suggests adding

--initialize_last_layer = False
--last_layers_contain_logits_only = False

Can someone help me get rid of this error?

Tags: niftynet
1 answer
ゆ 、 Hurt°
#2 · 2019-08-18 03:52

It seems you are having problems with your last layer. When you use a pretrained model on a new task you probably need to change your last layer to fit your new requirements.
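The two shapes in the error message can be compared axis by axis to see what differs. The sketch below is only an illustration: it assumes that "lhs" is the variable in the graph being built and "rhs" is the tensor read from the checkpoint, and it relies on TensorFlow's convolution kernel layout of [spatial dims..., in_channels, out_channels]. The helper name describe_kernel is made up for this example.

```python
# Shapes copied from the "Assign requires shapes of both tensors to match" line.
graph_shape = [3, 3, 61, 256]        # lhs: assumed to be the variable in the current graph
checkpoint_shape = [3, 3, 3, 61, 9]  # rhs: assumed to be the tensor stored in the checkpoint


def describe_kernel(shape):
    """Split a conv kernel shape laid out as [*spatial, in_channels, out_channels]."""
    return {
        "spatial_dims": len(shape) - 2,
        "kernel_size": tuple(shape[:-2]),
        "in_channels": shape[-2],
        "out_channels": shape[-1],
    }


graph_kernel = describe_kernel(graph_shape)
checkpoint_kernel = describe_kernel(checkpoint_shape)

# Keep only the properties on which the two kernels disagree.
mismatched = [key for key in graph_kernel if graph_kernel[key] != checkpoint_kernel[key]]

for key in mismatched:
    print("{}: graph={} checkpoint={}".format(key, graph_kernel[key], checkpoint_kernel[key]))
```

Read this way, the input channels agree (61) while the output channels differ (256 vs 9). The kernel rank also differs (2 spatial dims vs 3), so it may be worth checking that the spatial settings in the config match the ones the pre-trained model was trained with.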

To do that, edit your config file so that every variable except the last layer is restored, where last_layer_name stands for the name of your network's last layer:

vars_to_restore = ^((?!(last_layer_name)).)*$

Then set num_classes to match your new segmentation problem.
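To see what the vars_to_restore pattern does, here is a minimal Python sketch of the regex on its own. It only demonstrates the negative lookahead; it does not reproduce how NiftyNet applies the pattern internally. Choosing the variable from the traceback as the one to skip is an assumption for illustration, and the other two variable names are hypothetical.

```python
import re

# Assumption for illustration: skip the variable named in the traceback.
layer_to_skip = "DenseVNet/conv/conv_/w"

# Same structure as the pattern in the answer: match only names that
# do not contain layer_to_skip anywhere.
pattern = re.compile(r"^((?!({})).)*$".format(layer_to_skip))

variable_names = [
    "DenseVNet/conv/conv_/w",            # taken from the traceback
    "DenseVNet/example_block/conv_/w",   # hypothetical name
    "DenseVNet/example_block/bn_/beta",  # hypothetical name
]

restored = [name for name in variable_names if pattern.search(name)]
skipped = [name for name in variable_names if not pattern.search(name)]

print("restored:", restored)
print("skipped:", skipped)
```

Only the name containing the excluded string is skipped; everything else would still be loaded from the checkpoint.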

The transfer learning docs cover this in more detail: https://niftynet.readthedocs.io/en/dev/transfer_learning.html
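Putting the two settings together, the following sketch uses Python's configparser to print a config fragment. It rests on assumptions that should be checked against the docs linked above: that vars_to_restore belongs in the [TRAINING] section and num_classes in the [SEGMENTATION] section. The value 2 for num_classes is a made-up example, and last_layer_name is still a placeholder.

```python
import configparser
import io

# interpolation=None keeps the regex from being treated as a format string.
config = configparser.ConfigParser(interpolation=None)

# Assumed section names; verify them against your own config file.
config["TRAINING"] = {
    "vars_to_restore": "^((?!(last_layer_name)).)*$",  # placeholder layer name
}
config["SEGMENTATION"] = {
    "num_classes": "2",  # hypothetical value for the new task
}

# Render the fragment to a string instead of touching any file on disk.
buffer = io.StringIO()
config.write(buffer)
fragment = buffer.getvalue()
print(fragment)
```

The printed lines would be merged by hand into the existing sections of the config file passed with -c, rather than replacing that file.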
