I have a server with multiple GPUs and want to make full use of them during model inference inside a Java app. By default, TensorFlow seizes all available GPUs but uses only the first one.
I can think of three options to overcome this issue:
1. Restrict device visibility at the process level, namely using the `CUDA_VISIBLE_DEVICES` environment variable. That would require me to run several instances of the Java app and distribute traffic among them. Not a tempting idea.
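For illustration, a minimal launcher sketch of this approach, assuming a hypothetical `PredictionServer` main class and one port per process (both invented for the example):

```java
import java.io.IOException;

public class PerGpuLauncher {
    public static void main(String[] args) throws IOException {
        int numDevices = 4; // assumption: 4 GPUs on the host
        for (int i = 0; i < numDevices; i++) {
            ProcessBuilder pb = new ProcessBuilder(
                    "java", "-cp", System.getProperty("java.class.path"),
                    "PredictionServer", "--port", String.valueOf(8500 + i));
            // Each child JVM sees exactly one GPU, which TensorFlow then
            // addresses as /gpu:0 inside that process.
            pb.environment().put("CUDA_VISIBLE_DEVICES", String.valueOf(i));
            pb.inheritIO().start();
        }
        // Some external load balancer still has to spread requests over
        // ports 8500..8503 — which is exactly the unappealing part.
    }
}
```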
2. Launch several sessions inside a single application and try to assign one device to each of them via `ConfigProto`:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;

public class DistributedPredictor {

    private Predictor[] nested;
    private int[] counters;

    // ...

    public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice)
            throws IOException {
        nested = new Predictor[numDevices];
        counters = new int[numDevices];

        for (int i = 0; i < nested.length; i++) {
            nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice);
        }
    }

    public Prediction predict(Data data) {
        int i = acquirePredictorIndex();
        try {
            return nested[i].predict(data);
        } finally {
            releasePredictorIndex(i); // release the slot even if predict() throws
        }
    }

    // Route the call to the currently least-loaded device.
    private synchronized int acquirePredictorIndex() {
        int i = argmin(counters);
        counters[i] += 1;
        return i;
    }

    private synchronized void releasePredictorIndex(int i) {
        counters[i] -= 1;
    }

    private static int argmin(int[] xs) {
        int best = 0;
        for (int i = 1; i < xs.length; i++) {
            if (xs[i] < xs[best]) best = i;
        }
        return best;
    }
}
```

```java
public class Predictor {

    private Session session;

    public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice)
            throws IOException {

        GPUOptions gpuOptions = GPUOptions.newBuilder()
                .setVisibleDeviceList("" + deviceIdx)
                .setAllowGrowth(true)
                .build();

        ConfigProto config = ConfigProto.newBuilder()
                .setGpuOptions(gpuOptions)
                .setInterOpParallelismThreads(numDevices * numThreadsPerDevice)
                .build();

        byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
        Graph graph = new Graph();
        graph.importGraphDef(graphDef);

        this.session = new Session(graph, config.toByteArray());
    }

    public Prediction predict(Data data) {
        // ...
    }
}
```
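For context, the intended use is a single shared dispatcher called from many request-handling threads (the model path and sizes below are placeholders; `Data` and `Prediction` are the app's own types):

```java
// Hypothetical wiring: 2 GPUs, 4 threads per device.
DistributedPredictor predictor =
        new DistributedPredictor("/path/to/graph.pb", 2, 4);

// Safe to call concurrently; the synchronized counters route each
// call to the least-loaded Predictor.
Prediction result = predictor.predict(data);
```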
This approach seems to work fine at a glance. However, sessions occasionally ignore the `setVisibleDeviceList` option and all go for the first device, causing an out-of-memory crash.

3. Build the model in a multi-tower fashion in Python using the `tf.device()` specification. On the Java side, give different `Predictor`s different towers inside a shared session. Feels cumbersome and idiomatically wrong to me.
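To make option 3 concrete, the Java side might look roughly like the sketch below. The tower naming scheme (`tower_0/input`, `tower_0/output`) is purely illustrative; it would have to match whatever names the Python export actually gives the towers:

```java
import org.tensorflow.Session;
import org.tensorflow.Tensor;

// One TowerPredictor per device, all sharing a single Session over the
// multi-tower graph.
public class TowerPredictor {

    private final Session session;
    private final String inputName;
    private final String outputName;

    public TowerPredictor(Session sharedSession, int deviceIdx) {
        this.session = sharedSession;
        this.inputName = String.format("tower_%d/input", deviceIdx);   // assumed name
        this.outputName = String.format("tower_%d/output", deviceIdx); // assumed name
    }

    public Tensor predict(Tensor input) {
        return session.runner()
                .feed(inputName, input)
                .fetch(outputName)
                .run()
                .get(0);
    }
}
```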
**UPDATE:** As @ash proposed, there's yet another option:

4. Assign an appropriate device to each operation of the existing graph by modifying its definition (the `graphDef`). To get it done, one could adapt the code from Method 2:
```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import com.google.protobuf.InvalidProtocolBufferException;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GraphDef;

public class Predictor {

    private Session session;

    public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice)
            throws IOException {

        byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
        graphDef = setGraphDefDevice(graphDef, deviceIdx);

        Graph graph = new Graph();
        graph.importGraphDef(graphDef);

        ConfigProto config = ConfigProto.newBuilder()
                .setAllowSoftPlacement(true)
                .build();

        this.session = new Session(graph, config.toByteArray());
    }

    // Rewrite the device field of every node in the serialized GraphDef.
    private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx)
            throws InvalidProtocolBufferException {
        String deviceString = String.format("/gpu:%d", deviceIdx);

        GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
        for (int i = 0; i < builder.getNodeCount(); i++) {
            builder.getNodeBuilder(i).setDevice(deviceString);
        }

        return builder.build().toByteArray();
    }

    public Prediction predict(Data data) {
        // ...
    }
}
```
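Note that `setAllowSoftPlacement(true)` is doing real work here: with every node hard-pinned to `/gpu:N`, soft placement lets TensorFlow fall back to the CPU for ops that have no GPU kernel instead of failing outright. To double-check that the rewrite actually took effect, the `ConfigProto` can also be built with device-placement logging enabled (a sketch; `setLogDevicePlacement` is the builder method for the standard `log_device_placement` field):

```java
ConfigProto config = ConfigProto.newBuilder()
        .setAllowSoftPlacement(true)   // CPU fallback for ops without GPU kernels
        .setLogDevicePlacement(true)   // print each op's final device on startup
        .build();
```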
Just like the other approaches mentioned, this one doesn't set me free from manually distributing data among devices. But at least it works stably and is comparatively easy to implement. Overall, this looks like an (almost) normal technique.
Is there an elegant way to do such a basic thing with the TensorFlow Java API? Any ideas would be appreciated.