What is supported by broadcasting in TensorFlow?

2019-07-19 02:22发布

问题:

I raised an issue on GitHub at: https://github.com/tensorflow/tensorflow/issues/14924. Here are the details.

This is OK:

# Working case: two rank-5 float32 constants whose shapes disagree on
# several axes are multiplied, relying on implicit broadcasting.
import tensorflow as tf
sess = tf.InteractiveSession()
xx = tf.constant(1, shape=[32,1,4,4,1], dtype=tf.float32)
yy = tf.constant(1, shape=[1,32,1,4,4], dtype=tf.float32)
zz = xx * yy  # element-wise product of the two differently-shaped constants
sess.run([zz])  # reported to evaluate without error

However:

x2 = tf.constant(1, shape=[10,32,1,4,4,1])  # rank 6; no dtype given (the error below reports DT_INT32)
y2 = tf.constant(1, shape=[10,1,32,1,4,4])  # rank 6; shares only the leading axis size (10) with x2
z2 = x2 * y2  # creates the Mul op that the error below refers to as mul_1
sess.run(z2)  # fails with UnimplementedError: "Broadcast between ... is not supported yet"

Gives an error:

UnimplementedError (see above for traceback): Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet. [[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]

Log:

---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-2-eef82717f8d8> in <module>()
      2 y2 = tf.constant(1, shape=[10,1,32,1,4,4])
      3 z2 = x2 * y2
----> 4 sess.run(z2)

/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1119       results = self._do_run(handle, final_targets, final_fetches,
-> 1120                              feed_dict_tensor, options, run_metadata)
   1121     else:
   1122       results = []

/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1315     if handle is None:
   1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317                            options, run_metadata)
   1318     else:
   1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1334         except KeyError:
   1335           pass
-> 1336       raise type(e)(node_def, op, message)
   1337 
   1338   def _extend_graph(self):

UnimplementedError: Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet.
     [[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]

Caused by op u'mul_1', defined at:
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/runpy.py", line 174, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 474, in start
    ioloop.IOLoop.instance().start()
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/ioloop.py", line 887, in start
    handler_func(fd_obj, events)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 390, in execute_request
    user_expressions, allow_stdin)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/zmqshell.py", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-eef82717f8d8>", line 3, in <module>
    z2 = x2 * y2
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 894, in binary_op_wrapper
    return func(x, y, name=name)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 1117, in _mul_dispatch
    return gen_math_ops._mul(x, y, name=name)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2726, in _mul
    "Mul", x=x, y=y, name=name)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

UnimplementedError (see above for traceback): Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet.
     [[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]

An update:

I assume the cause is related to how the dimensions line up, rather than to the total number of dimensions or the number of mismatched dimensions, because the following script runs OK. Here x3's second-to-last dimension is changed from 4 to 1, which adds one more mismatched position.

# Same as the failing case, except that x3's second-to-last axis is 1
# instead of 4; this variant is reported to run OK.
x3 = tf.constant(1, shape=[10,32,1,4,1,1])
y3 = tf.constant(1, shape=[10,1,32,1,4,4])
z3 = x3 * y3  # element-wise product, broadcast across the mismatched axes
sess.run(z3)

回答1:

As you may have already observed, TensorFlow currently limits the number of mismatched dimensions that it will broadcast automatically.

To work around this, I have written my own broadcasting function, which broadcasts a variable number of tensors to one common shape. Note, however, that this function will not work if a tensor's shape is undefined or contains None.

def broadcast_tensors(tensors):
    """Explicitly broadcast any number of tensors to one common shape.

    Shapes are right-aligned: tensors of lower rank receive leading
    size-1 axes until every tensor has the same rank.  Each size-1 axis
    is then repeated up to the size the other tensors have on that axis.

    Args:
        tensors: Iterable of tensors whose static shapes are fully
            defined (known rank, no ``None`` dimension).  The iterable
            passed in is never modified.

    Returns:
        A new list holding one tensor per input, in the same order, all
        with the common broadcast shape.  Tensors that already have the
        target shape are returned as-is.  Empty input gives ``[]``.

    Raises:
        ValueError: If a tensor has a ``None`` dimension in its static
            shape.
        Exception: With the message "Broadcasting not possible" if, on
            some axis, two tensors have different sizes and neither of
            them is 1.
    """
    # Work on a private copy so the caller's list is left untouched and
    # tuples / generators are accepted as well.
    tensors = list(tensors)
    if not tensors:
        return []

    shapes = [t.get_shape().as_list() for t in tensors]
    if any(size is None for shape in shapes for size in shape):
        raise ValueError(
            "broadcast_tensors requires fully defined static shapes")

    # Rank-equalize: prepend size-1 axes to the lower-rank tensors.
    max_rank = max(len(shape) for shape in shapes)
    for index, shape in enumerate(shapes):
        missing = max_rank - len(shape)
        if not missing:
            continue
        tensor = tensors[index]
        for _ in range(missing):
            tensor = tf.expand_dims(tensor, axis=0)
        tensors[index] = tensor
        shapes[index] = [1] * missing + shape

    # Per axis, at most one distinct size other than 1 may occur; that
    # size (or 1 if every tensor has 1 there) is the broadcast size.
    # Picking the non-1 size instead of the maximum also gives the right
    # answer when a 0-sized axis meets a 1-sized axis.
    broadcast_shape = []
    for sizes in zip(*shapes):
        non_unit = {size for size in sizes if size != 1}
        if len(non_unit) > 1:
            raise Exception("Broadcasting not possible")
        broadcast_shape.append(non_unit.pop() if non_unit else 1)

    # Repeat every size-1 axis up to the broadcast size with a single
    # tile op per tensor.
    result = []
    for tensor, shape in zip(tensors, shapes):
        multiples = [target if size == 1 else 1
                     for size, target in zip(shape, broadcast_shape)]
        if any(multiple != 1 for multiple in multiples):
            tensor = tf.tile(tensor, multiples)
        result.append(tensor)

    return result

Output:

# Three int constants of rank 6, 5 and 4 (shapes shown right-aligned)
# are broadcast to one common shape in a single call.
x, y, z = broadcast_tensors([
    tf.constant(1, shape=[10, 32, 1, 4, 4, 1]),
    tf.constant(1, shape=    [1, 32, 1, 4, 1]),
    tf.constant(1, shape=       [32, 4, 1, 1]),
])
print(x.get_shape(), y.get_shape(), z.get_shape())
# (10, 32, 32, 4, 4, 1) (10, 32, 32, 4, 4, 1) (10, 32, 32, 4, 4, 1)

# Incompatible case: with the shapes right-aligned, some axes carry two
# different sizes that are both greater than 1, so the call raises.
x, y, z = broadcast_tensors([
    tf.constant(1, shape=[10, 32, 1, 4, 4, 1]),
    tf.constant(1, shape=    [1, 32, 3, 4, 2]),
    tf.constant(1, shape=       [32, 3, 1, 3]),
])
# Exception: Broadcasting not possible