I raised an issue on GitHub at: https://github.com/tensorflow/tensorflow/issues/14924. Here are the details.
This is OK:
import tensorflow as tf
sess = tf.InteractiveSession()
xx = tf.constant(1, shape=[32,1,4,4,1], dtype=tf.float32)
yy = tf.constant(1, shape=[1,32,1,4,4], dtype=tf.float32)
zz = xx * yy
sess.run([zz])
However:
x2 = tf.constant(1, shape=[10,32,1,4,4,1])
y2 = tf.constant(1, shape=[10,1,32,1,4,4])
z2 = x2 * y2
sess.run(z2)
Gives an error:
UnimplementedError (see above for traceback): Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet. [[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]
Log:
---------------------------------------------------------------------------
UnimplementedError Traceback (most recent call last)
<ipython-input-2-eef82717f8d8> in <module>()
2 y2 = tf.constant(1, shape=[10,1,32,1,4,4])
3 z2 = x2 * y2
----> 4 sess.run(z2)
/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 if final_fetches or final_targets or (handle and feed_dict_tensor):
1119 results = self._do_run(handle, final_targets, final_fetches,
-> 1120 feed_dict_tensor, options, run_metadata)
1121 else:
1122 results = []
/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1315 if handle is None:
1316 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1317 options, run_metadata)
1318 else:
1319 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
1334 except KeyError:
1335 pass
-> 1336 raise type(e)(node_def, op, message)
1337
1338 def _extend_graph(self):
UnimplementedError: Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet.
[[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]
Caused by op u'mul_1', defined at:
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/runpy.py", line 174, in _run_module_as_main
"__main__", fname, loader, pkg_name)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/runpy.py", line 72, in _run_code
exec code in run_globals
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/__main__.py", line 3, in <module>
app.launch_new_instance()
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
app.start()
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 474, in start
ioloop.IOLoop.instance().start()
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
super(ZMQIOLoop, self).start()
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/ioloop.py", line 887, in start
handler_func(fd_obj, events)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
return fn(*args, **kwargs)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
self._handle_recv()
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
self._run_callback(callback, msg)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
callback(*args, **kwargs)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tornado/stack_context.py", line 275, in null_wrapper
return fn(*args, **kwargs)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
return self.dispatch_shell(stream, msg)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
handler(stream, idents, msg)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 390, in execute_request
user_expressions, allow_stdin)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/ipykernel/zmqshell.py", line 501, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
if self.run_code(code, result):
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-eef82717f8d8>", line 3, in <module>
z2 = x2 * y2
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 894, in binary_op_wrapper
return func(x, y, name=name)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 1117, in _mul_dispatch
return gen_math_ops._mul(x, y, name=name)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2726, in _mul
"Mul", x=x, y=y, name=name)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/home/jetadmin/anaconda2/envs/ygtf/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
UnimplementedError (see above for traceback): Broadcast between [10,32,1,4,4,1] and [10,1,32,1,4,4] is not supported yet.
[[Node: mul_1 = Mul[T=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Const_2, Const_3)]]
An update:
I assume the cause is related to how the dimensions line up, rather than to the total number of dimensions or the number of mismatched dimensions, because the following script runs OK. In it, the second-to-last dimension of x3 is changed from 4 to 1, which adds one more mismatched position.
x3 = tf.constant(1, shape=[10,32,1,4,1,1])
y3 = tf.constant(1, shape=[10,1,32,1,4,4])
z3 = x3 * y3
sess.run(z3)
As you may have already observed, TensorFlow currently restricts the number of mismatched dimensions that it is able to broadcast.
To work around this, I have written my own broadcasting function, which broadcasts a variable number of tensors to one common shape. Note, however, that this function will not work if a tensor's shape is undefined or contains None.
Output: