How to make the generated data stay on a remote worker across training iterations?


Question:

I use the in-graph replication of TensorFlow to do distributed training. To reduce communication cost, I need to hold some generated data (such as the cell states of an LSTM) on a remote worker from one training iteration to the next, but I have not been able to achieve this.

If I use the fetch mechanism of the 'session.run' interface to retrieve the data generated on a remote worker, and then feed that data back to the same worker in the next training iteration, unnecessary network traffic is produced, as the code below shows:

cluster = tf.train.ClusterSpec({"worker": ["remoteIP0:port", "remoteIP1:port"]})
...

for i in xrange(2):
  with tf.device("/job:worker/task:%d" % i):
    with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
      # Build the model replica and one training step on worker i.
      ...
      initial_state[i] = ...
      ...
      weight[i] = ...
      bias[i] = ...
      cost[i] = ...
      ...
      gradient[i] = ...
      final_state[i] = ...
      ...
grad = aggregate_func(gradient[0], gradient[1])
optimizer = tf.train.GradientDescentOptimizer(lr)
train_op = optimizer.apply_gradients(grad)

...
with tf.Session("grpc://localhost:port") as session:
  ...
  for k in xrange(max_step):
    # Fetch the final state to the master and feed it back to the workers
    # on the next iteration (this round trip is the problem).
    cost_vals, state_vals, _ = session.run(
        [cost, final_state, train_op],
        {initial_state[i]: state_vals[i] for i in xrange(2)})
  ...

The 'final_state[i]' generated in iteration k needs to be assigned to 'initial_state[i]' of iteration k+1 on every remote worker. How can we do this assignment on the remote worker machine, without fetching the data to the master (grpc://localhost:port) and feeding it back to the remote workers?

Answer 1:

As Yaroslav proposed, Variable objects and persistent tensors can replace feed_dict, so the state stays on the remote workers. Thanks Yaroslav.
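
A minimal sketch of the Variable-based approach (TensorFlow 1.x), assuming a two-worker cluster as in the question: the recurrent state is held in a non-trainable tf.Variable placed on each worker's device, and an in-graph tf.assign op carries the new state into the next iteration, so the state never travels back to the master. The state shape, the toy state update, and the port number below are hypothetical placeholders for the real model code:

import tensorflow as tf

NUM_WORKERS = 2
STATE_SHAPE = [32, 128]   # hypothetical [batch_size, hidden_size]

state_vars, update_state = [], []
for i in range(NUM_WORKERS):
  with tf.device("/job:worker/task:%d" % i):
    # Persistent state kept on worker i between session.run calls.
    state = tf.get_variable("state_%d" % i, STATE_SHAPE,
                            initializer=tf.zeros_initializer(),
                            trainable=False)
    # ... build the model replica here, reading `state` as the initial
    # state and producing `new_state`; a toy update stands in for it.
    new_state = state * 0.9
    # In-graph assignment: the new state stays on this worker's device.
    update_state.append(tf.assign(state, new_state))
    state_vars.append(state)

with tf.Session("grpc://localhost:2222") as session:
  session.run(tf.global_variables_initializer())
  for k in range(1000):
    # Running the assign ops (alongside train_op in a real model) rolls the
    # state forward on the workers; only fetched scalars cross the network.
    session.run(update_state)

Alternatively, persistent tensors can be kept alive with session handles (tf.get_session_handle to produce a handle in one run call and tf.get_session_tensor to consume it in the next), which avoids declaring the state shape up front; for fixed-shape RNN state, the Variable approach above is usually the simpler fit.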