I'm trying to apply a mask (binary, only one channel) to an RGB image (3 channels, normalized to [0, 1]). My current solution is to split the RGB image into its channels, multiply each channel with the mask, and concatenate the channels again:
with tf.variable_scope('apply_mask') as scope:
    # Output mask is in range [-1, 1], bring to range [0, 1] first.
    zero_one_mask = (output_mask + 1) / 2
    # Apply mask to all channels.
    channels = tf.split(3, 3, output_img)
    channels = [tf.mul(c, zero_one_mask) for c in channels]
    output_img = tf.concat(3, channels)
However, this seems pretty inefficient, especially since, to my understanding, none of these computations are done in place. Is there a more efficient way to do this?
The tf.mul() operator supports numpy-style broadcasting, which would allow you to simplify and optimize the code slightly. Let's say that zero_one_mask is an m x n tensor, and output_img is a b x m x n x 3 tensor (where b is the batch size; I'm inferring this from the fact that you split output_img on dimension 3).* You can use tf.expand_dims() to make zero_one_mask broadcastable to channels, by reshaping it to be an m x n x 1 tensor:
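Something along these lines (a sketch reusing the variable names from your snippet, with the same pre-1.0 tf.mul()/tf.expand_dims() API as the question):

with tf.variable_scope('apply_mask') as scope:
    # Output mask is in range [-1, 1], bring to range [0, 1] first.
    zero_one_mask = (output_mask + 1) / 2
    # Reshape from m x n to m x n x 1 so it broadcasts against the
    # trailing dimensions of the b x m x n x 3 image.
    zero_one_mask = tf.expand_dims(zero_one_mask, 2)
    # One broadcast multiply replaces the split/mul/concat sequence.
    output_img = tf.mul(output_img, zero_one_mask)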
(* This would work equally well if output_img were a 4-D b x m x n x c tensor (for any number of channels c) or a 3-D m x n x c tensor, due to the way broadcasting works.)
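If you want to convince yourself of the shape arithmetic, a quick check (hypothetical sizes, using placeholders and get_shape() from the same pre-1.0 API):

import tensorflow as tf

output_img = tf.placeholder(tf.float32, [16, 64, 64, 3])  # b x m x n x 3
output_mask = tf.placeholder(tf.float32, [64, 64])        # m x n

# Shapes are aligned from the right, so m x n x 1 broadcasts
# against b x m x n x 3 without any explicit tiling.
mask = tf.expand_dims((output_mask + 1) / 2, 2)           # m x n x 1
masked = tf.mul(output_img, mask)

print(masked.get_shape())  # (16, 64, 64, 3)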