I have a problem I've been struggling with. It is related to tf.matmul()
and its lack of broadcasting.
I am aware of a similar issue on https://github.com/tensorflow/tensorflow/issues/216, but tf.batch_matmul()
doesn't look like a solution for my case.
I need to encode my input data as a 4D tensor:
X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
The first dimension is the batch size; the second is the number of entries in the batch.
You can imagine each entry as a composition of a number of objects (third dimension). Finally, each object is described by a vector of 100 float values.
Note that I used None for the second and third dimensions because the actual sizes may change in each batch. However, for simplicity, let's shape the tensor with actual numbers:
X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
These are the steps of my computation:
1. compute a function of each vector of 100 float values (e.g., a linear function)

W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
Y = tf.matmul(X, W)

problem: no broadcasting for tf.matmul(), and no success using tf.batch_matmul()
expected shape of Y: (5, 10, 4, 50)

2. apply average pooling for each entry of the batch (over the objects of each entry):

Y_avg = tf.reduce_mean(Y, 2)

expected shape of Y_avg: (5, 10, 50)
I expected that tf.matmul() would support broadcasting. Then I found tf.batch_matmul(), but it still doesn't seem to apply to my case (e.g., W needs to have at least 3 dimensions, and it is not clear why).
BTW, above I used a simple linear function (the weights of which are stored in W). But in my model I have a deep network instead. So, the more general problem I have is automatically computing a function for each slice of a tensor. This is why I expected tf.matmul() to have broadcasting behavior (and if it did, maybe tf.batch_matmul() wouldn't even be necessary).
Looking forward to learning from you! Alessio
As the renamed title of the GitHub issue you linked suggests, you should use tf.tensordot(). It enables contraction of pairs of axes between two tensors, in line with NumPy's tensordot(). For your case:
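A minimal sketch of that, assuming the placeholder and weight shapes from the question (the set_shape call is optional and only re-attaches static shape information):

import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))

# Contract the last axis of X (size 100) with the first axis of W (size 100);
# the leading (batch, entries, objects) axes are kept and W's second axis is
# appended, giving shape (batch, entries, objects, 50).
Y = tf.tensordot(X, W, axes=[[3], [0]])
Y.set_shape([None, None, None, 50])  # optional: re-attach the static shape

# Average pooling over the objects of each entry -> (batch, entries, 50)
Y_avg = tf.reduce_mean(Y, 2)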
You could achieve that by reshaping X to shape [n, d], where d is the dimensionality of a single "instance" of the computation (100 in your example) and n is the number of such instances in your multi-dimensional object (5*10*4 = 200 in your example). After reshaping, you can use tf.matmul and then reshape back to the desired shape. The fact that the first three dimensions can vary makes that a little tricky, but you can use tf.shape to determine the actual shapes at run time. Finally, you can perform the second step of your computation, which is a simple tf.reduce_mean over the respective dimension. All in all, it would look like this:
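Here is a sketch of that approach with the shapes from the question (the intermediate names X_flat and Y_flat and the final set_shape call are just illustrative, not a quote of the original answer's code):

import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))

X_shape = tf.shape(X)                 # dynamic shape, known only at run time
X_flat = tf.reshape(X, [-1, 100])     # collapse (batch, entries, objects) into one axis: (n, 100)
Y_flat = tf.matmul(X_flat, W)         # plain 2-D matmul: (n, 100) x (100, 50) -> (n, 50)

# Restore the leading three (dynamic) axes and append the new feature size 50.
Y = tf.reshape(Y_flat, tf.concat([X_shape[:3], [50]], axis=0))
Y.set_shape([None, None, None, 50])   # optional: re-attach the static shape

# Second step: average over the objects of each entry -> (batch, entries, 50)
Y_avg = tf.reduce_mean(Y, 2)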