
Parallel processes in distributed TensorFlow

Posted 2020-07-17 06:02

Question:

I have a neural network in TensorFlow with trained parameters; it is the "policy" for an agent. The network is updated in a training loop in the main TensorFlow session of the core program.

At the end of each training cycle I need to pass this network to a few parallel processes ("workers"). The workers use it to collect samples by running the agent's policy against the environment.

I need to do this in parallel because simulating the environment takes most of the time and runs on a single core only, so several parallel sampling processes are needed. I am struggling with how to structure this in distributed TensorFlow. What I have considered so far:

  1. Create a main session in the core program, where the global network is updated. Spawn processes with Python multiprocessing and pass them the network's global parameters (can I just pass the network as an argument?). Then a separate Session is created in each process, in which the network runs.

A minimal example for this is here (the code is also included below): https://gist.github.com/dd210/e1808efcc4362cab949ad0337ba600a9

The problem with this example is that it sometimes hangs on sess.run in the second process, and sometimes it runs smoothly (!). So there must be some fundamental issue with this approach. In my real code, the second agent always hangs on sess.run.

  2. Somehow use between-graph replication and a Supervisor (for managing sessions) to create one master session (in the core program) and use replicas of the global network on the workers. This seems more correct, but I don't know how to structure it. The example code from the official tutorial is quite different from my setup.
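The per-cycle hand-off described in option 1 can be sketched without TensorFlow, using only the standard library. In this hand-off, the main process owns the parameters, and the workers receive a copy and return samples. This is a minimal sketch under stated assumptions, not TensorFlow code:

* The "policy" is a hypothetical two-parameter linear rule.
* The environment is a hypothetical toy task.
* `policy_action`, `collect_samples` and `sample_in_parallel` are illustrative names.

What it shows is that only plain, picklable parameter values cross the process boundary. Each worker then evaluates the policy locally from those values.

```python
import random
from concurrent.futures import ProcessPoolExecutor


def policy_action(params, obs):
    """Hypothetical linear policy: action 1 if w * obs + b > 0, else 0."""
    w, b = params
    return 1 if w * obs + b > 0 else 0


def collect_samples(params, seed, n_steps):
    """Hypothetical single-core rollout against a toy environment.

    Each step draws an observation in [-1, 1]; the "correct" action is 1
    when the observation is positive. Returns (obs, action, reward) tuples.
    """
    rng = random.Random(seed)  # per-worker seed so workers see different data
    samples = []
    for _ in range(n_steps):
        obs = rng.uniform(-1.0, 1.0)
        action = policy_action(params, obs)
        target = 1 if obs > 0 else 0
        reward = 1.0 if action == target else 0.0
        samples.append((obs, action, reward))
    return samples


def sample_in_parallel(params, n_workers, n_steps, map_fn=map, seed_offset=0):
    """Give every worker a copy of `params` and gather their samples in order.

    `map_fn` defaults to the builtin (serial) `map`; pass `pool.map` from a
    ProcessPoolExecutor to run the rollouts in separate processes.
    """
    seeds = range(seed_offset, seed_offset + n_workers)
    return list(map_fn(collect_samples,
                       [params] * n_workers,
                       seeds,
                       [n_steps] * n_workers))


if __name__ == "__main__":
    params = [0.5, 0.1]  # hypothetical initial policy parameters
    with ProcessPoolExecutor(max_workers=2) as pool:
        for cycle in range(3):
            # Each call pickles the *current* params and ships a copy to the
            # worker processes, so they always sample with the latest policy.
            batches = sample_in_parallel(params, 2, 1000, pool.map,
                                         seed_offset=cycle * 2)
            rewards = [r for batch in batches for (_, _, r) in batch]
            print(f"cycle {cycle}: mean reward {sum(rewards) / len(rewards):.3f}")
            # A real training step on `batches` would update `params` here.
```

The sketch keeps one pool alive across cycles, which avoids restarting the worker processes every iteration. Passing no pool falls back to the builtin `map`, which is handy for quick serial checks.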

I would be grateful for any advice.

Code for option 1:

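import time
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Server)

from multiprocessing import Process

N_WORKERS = 2
# One cluster with two worker tasks, each listening on its own local port
SPEC = {'worker': ['127.0.0.1:12824', '127.0.0.1:12825']}

def run_worker(task):
    # Each process starts the in-process server for its own task index...
    spec = tf.train.ClusterSpec(SPEC)
    server = tf.train.Server(spec, job_name='worker', task_index=task)
    # ...and opens a session against that server
    sess = tf.Session(server.target)
    x = tf.Variable(0., dtype=tf.float32, name='x')
    # In the second process, sess.run sometimes hangs (see above)
    sess.run(tf.global_variables_initializer())
    print('result: ', sess.run(x))

def main(_):
    workers = []
    for i in range(N_WORKERS):
        # One OS process per worker task, started one second apart
        p = Process(target=run_worker, args=(i,))
        p.start()
        workers.append(p)
        time.sleep(1)
    for w in workers:
        w.join()

if __name__ == '__main__':
    tf.app.run()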
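In the example above, only the integer task index crosses the process boundary through `args`. This bears on the question in option 1 about passing the network itself. Under the "spawn" and "forkserver" start methods, `multiprocessing` pickles `Process` arguments, and tasks sent to a pool are always pickled, so whatever is passed must be picklable. The small stdlib-only check below illustrates the difference between plain parameter values and an object holding a native resource. The names `can_pickle`, `params` and `handle` are hypothetical. The lock is only an analogy for live handles such as a session or server. Whether a particular TensorFlow object can be pickled is an assumption to verify separately; this sketch does not test it.

```python
import pickle
import threading


def can_pickle(obj):
    """Return True if `obj` survives pickle.dumps, i.e. it could be sent to a
    child process under the 'spawn' start method or through a pool."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError, AttributeError):
        return False


# Plain parameter values (e.g. weights fetched from the main session and
# converted to nested lists) serialize without trouble.
params = {"x": 0.0, "weights": [[0.1, -0.2], [0.3, 0.4]]}

# An object wrapping a native resource does not; a lock stands in here for
# live handles (an analogy only, no TensorFlow object is tested).
handle = threading.Lock()

print(can_pickle(params))  # True
print(can_pickle(handle))  # False
```

Under this reading, one option is to send the current parameter values and have each worker load them into its own copy of the network.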