TensorFlow and Multiprocessing: Passing Sessions

Posted 2019-01-22 04:37

I have recently been working on a project that uses a neural network for virtual robot control. I used TensorFlow to code it up, and it runs smoothly. So far, I have used sequential simulations to evaluate how good the neural network is; however, I want to run several simulations in parallel to reduce the time it takes to collect data.

To do this, I am using Python's multiprocessing module. Initially I was passing the sess variable (sess = tf.Session()) to a function that runs the simulation. However, as soon as execution reaches any statement that uses this sess variable, the process quits without a warning. After searching around for a bit, I found these two posts: "Tensorflow: Passing a session to a python multiprocess" and "Running multiple tensorflow sessions concurrently".

While they are highly related, I haven't been able to figure out how to make it work. I tried creating a session for each individual process and assigning the weights of the neural net to its trainable parameters, without success. I've also tried saving the session to a file and then loading it within each process, but no luck there either.
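A minimal sketch of the per-process pattern, assuming the model can be rebuilt from plain weight values: the parent keeps only picklable data (a list of floats here), and each worker builds its own model from those weights before running its simulation. `build_model`, `run_simulation` and the linear "model" are hypothetical stand-ins so the sketch runs with the standard library alone; with TensorFlow 1.x, the rebuild step is where the graph and a fresh `tf.Session` would be created inside the worker and the weights assigned to the trainable variables.

```python
from functools import partial
from multiprocessing import Pool


def build_model(weights):
    """Hypothetical stand-in for building the network inside a worker.

    With TensorFlow, the graph and a new session would be created here,
    and `weights` would then be assigned to the trainable variables.
    """
    def model(inputs):
        # A toy linear "network": dot product of weights and inputs.
        return sum(w * x for w, x in zip(weights, inputs))
    return model


def run_simulation(weights, seed):
    """Hypothetical simulation: builds its own model and returns a score."""
    model = build_model(weights)  # created inside this process, never shared
    inputs = [seed, seed + 1, seed + 2]
    return model(inputs)


if __name__ == "__main__":
    # Only plain data (weights, seeds) is sent to the workers.
    weights = [0.5, -1.0, 2.0]
    with Pool(processes=4) as pool:
        scores = pool.map(partial(run_simulation, weights), range(4))
    print(scores)  # [3.0, 4.5, 6.0, 7.5]
```

Only the weights and the seed cross the process boundary; the model object is created and used entirely inside the worker, so nothing session-like is ever shared between processes.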

Has anyone been able to pass a session (or clones of a session) to several processes?

Thanks.

1 Answer
成全新的幸福
Post #2 · 2019-01-22 05:14

I use Keras as a wrapper with TensorFlow as the backend, but the same general principle should apply.

If you try something like this:

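import keras  # imported once, in the parent process
from functools import partial
from multiprocessing import Pool

# partial(ModelFunc, SomeData) binds SomeData as the first argument
def ModelFunc(SomeData, i):
    YourModel = ...   # build or load your Keras model here
    ModelScore = ...  # evaluate YourModel on SomeData here
    return ModelScore

if __name__ == "__main__":
    SomeData = ...  # placeholder for your input data
    with Pool(processes=4) as pool:
        for i, Score in enumerate(pool.imap(partial(ModelFunc, SomeData), range(4))):
            print(Score)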

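It will fail, because keras (and TensorFlow with it) is loaded in the parent process before the worker processes are created. However, if you try something like this: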

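from functools import partial
from multiprocessing import Pool

# partial(ModelFunc, SomeData) binds SomeData as the first argument
def ModelFunc(SomeData, i):
    import keras  # imported inside the worker, so each process loads its own copy
    YourModel = ...   # build or load your Keras model here
    ModelScore = ...  # evaluate YourModel on SomeData here
    return ModelScore

if __name__ == "__main__":
    SomeData = ...  # placeholder for your input data
    with Pool(processes=4) as pool:
        for i, Score in enumerate(pool.imap(partial(ModelFunc, SomeData), range(4))):
            print(Score)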

It should work. In short, import TensorFlow separately inside each process instead of loading it in the parent.
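The same idea can be pushed a step further, sketched here under stated assumptions rather than as the answerer's code: a `spawn` context starts every worker as a fresh interpreter that inherits nothing from the parent, and `Pool`'s `initializer` argument runs the expensive import and model setup once per worker instead of once per task. `init_worker`, `score` and `_MODEL` are hypothetical names, and a trivial function stands in for a Keras/TensorFlow model so the example runs with the standard library alone.

```python
import multiprocessing
import os

# Hypothetical per-process model slot, filled in by init_worker() in each worker.
_MODEL = None


def init_worker(scale):
    """Hypothetical initializer: runs once in every worker when the pool starts."""
    global _MODEL
    # With Keras/TensorFlow, `import keras` (or `import tensorflow as tf`)
    # and the model/session construction would go here, so each worker
    # owns its own copy. A toy function stands in for the model.
    def model(x):
        return scale * x
    _MODEL = model


def score(i):
    """Hypothetical task: evaluates this worker's model, returns (pid, score)."""
    return os.getpid(), _MODEL(i)


if __name__ == "__main__":
    # "spawn" starts each worker as a fresh interpreter instead of forking
    # the parent, so no library state is inherited from the parent process.
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=4, initializer=init_worker, initargs=(10,)) as pool:
        for pid, result in pool.imap(score, range(4)):
            print(pid, result)
```

Because `_MODEL` is built once per worker, a worker that picks up several tasks reuses its model; the printed PIDs show which worker handled each task.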
