I'd like to know the possible ways to implement batch normalization layers that synchronize batch statistics across GPUs when training with multiple GPUs.
Caffe
Maybe there are some variants of Caffe that could do this, like link. But for the BN layer, my understanding is that it still synchronizes only the outputs of layers, not the means and vars. Maybe MPI could synchronize the means and vars, but I think MPI is a little difficult to implement.
Torch
I've seen some comments here and here, which show that running_mean and running_var can be synchronized, but I think the batch mean and batch var cannot be synchronized, or at least are difficult to synchronize.
TensorFlow
Normally, it is the same as Caffe and Torch. The implementation of BN refers to this. I know TensorFlow can place an operation on any device specified by tf.device(). But the computation of the means and vars is in the middle of the BN layer, so if I gather the means and vars on the CPU, my code will look like this:
import tensorflow as tf

# cifar_input, _conv, _stride_arr, get_bn_variables and _relu are helpers
# defined elsewhere in my model code.
cpu_gather = []      # per-GPU activations of block1, gathered for the CPU
label_batches = []
for i in range(num_gpu):
    with tf.device('/gpu:%d' % i):
        with tf.variable_scope('block1', reuse=i > 0):
            image_batch, label_batch = cifar_input.build_input(
                'cifar10', train_data_path, batch_size, 'train')
            label_batches.append(label_batch)
            x = _conv('weights', image_batch, 3, 3, 16, _stride_arr(1))
            cpu_gather.append(x)

with tf.device('/cpu:0'):
    print(cpu_gather[0].get_shape())
    # concatenate all GPU activations and compute the batch statistics once
    x1 = tf.concat(cpu_gather, 0)
    mean, variance = tf.nn.moments(x1, [0, 1, 2], name='moments')

for i in range(num_gpu):
    with tf.device('/gpu:%d' % i):
        with tf.variable_scope('block2', reuse=i > 0):
            shape = cpu_gather[i].get_shape().as_list()
            assert len(shape) in [2, 4]
            n_out = shape[-1]
            beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, True, True)
            # normalize each GPU's slice with the shared mean/variance
            x = tf.nn.batch_normalization(
                cpu_gather[i], mean, variance, beta, gamma, 0.00001)
            x = _relu(x)
That is just for one BN layer. To gather the statistics on the CPU, I have to break the model code apart like this. If I have more than 100 BN layers, that will be cumbersome.
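One idea I have is to wrap the gather-and-normalize pattern in a helper so the boilerplate is written only once. This is just a sketch under my own assumptions, not an existing TensorFlow API: the name sync_batch_norm, the eps default, and the moment-averaging scheme are mine, it assumes every GPU gets the same batch size, and it omits the moving-average updates needed for inference.

def sync_batch_norm(xs, scope, eps=1e-5):
    """Hypothetical helper: normalize a list of per-GPU NHWC tensors `xs`
    with batch statistics computed over all GPUs."""
    with tf.variable_scope(scope):
        n_out = xs[0].get_shape().as_list()[-1]
        beta = tf.get_variable('beta', [n_out], initializer=tf.zeros_initializer())
        gamma = tf.get_variable('gamma', [n_out], initializer=tf.ones_initializer())

    # per-GPU first and second moments, computed where the data lives
    means, sq_means = [], []
    for x in xs:
        mean, var = tf.nn.moments(x, [0, 1, 2])
        means.append(mean)
        sq_means.append(var + tf.square(mean))  # E[x^2] = Var[x] + E[x]^2

    # average the moments on the CPU to get global batch statistics
    # (assumes equal per-GPU batch sizes)
    with tf.device('/cpu:0'):
        mean = tf.add_n(means) / len(means)
        variance = tf.add_n(sq_means) / len(sq_means) - tf.square(mean)

    # normalize each GPU's slice with the shared statistics
    outs = []
    for i, x in enumerate(xs):
        with tf.device('/gpu:%d' % i):
            outs.append(tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps))
    return outs

The model would still have to be written around lists of per-GPU tensors, but at least the cross-GPU statistics logic lives in one place instead of being repeated for every BN layer.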
I am not an expert in those libraries, so there may be some misunderstandings; feel free to point out my errors.
I do not care much about training speed. I am doing image segmentation, which consumes a lot of GPU memory, and BN needs a reasonable batch size (e.g. larger than 16) for stable statistics, so using multiple GPUs is unavoidable. In my opinion, TensorFlow might be the best choice, but I can't resolve the code-breaking problem. Solutions with other libraries are welcome too.