TensorFlow Grid LSTM RNN TypeError

Posted 2019-04-12 22:00

Question:

I'm trying to build an LSTM RNN that handles 3D data in TensorFlow. According to this paper, Grid LSTM RNNs can be n-dimensional. The idea for my network is to have a 3D input volume [depth, x, y], with the network shaped [depth, x, y, n_hidden], where n_hidden is the number of LSTM cell recursive calls. Each pixel gets its own "string" of LSTM recursive calls.

The output should be [depth, x, y, n_classes]. I'm doing a binary segmentation -- think foreground and background, so the number of classes is just 2.
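For reference, the target shape described above can be sketched in NumPy. This is only an illustration: the mask contents are made up, and the one-hot trick via `np.eye` is one possible way to build the labels, not something taken from the original code.

```python
import numpy as np

n_depth, n_input_x, n_input_y, n_classes = 5, 200, 200, 2

# Hypothetical binary mask: 0 = background, 1 = foreground.
mask = np.zeros((n_depth, n_input_x, n_input_y), dtype=np.int64)
mask[:, 50:150, 50:150] = 1

# Indexing an identity matrix with the mask appends a one-hot class axis,
# giving labels of shape [depth, x, y, n_classes].
labels = np.eye(n_classes, dtype=np.float32)[mask]

print(labels.shape)
```

Adding a leading batch axis to `labels` would give the shape expected by the `y` placeholder below.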

# Network Parameters
n_depth = 5
n_input_x = 200 # input size along x
n_input_y = 200
n_hidden = 128 # hidden layer num of features
n_classes = 2

# tf Graph input
x = tf.placeholder("float", [None, n_depth, n_input_x, n_input_y])
y = tf.placeholder("float", [None, n_depth, n_input_x, n_input_y, n_classes])

# Define weights
weights = {}
biases = {}

# Initialize weights
for i in xrange(n_depth * n_input_x * n_input_y):
    weights[i] = tf.Variable(tf.random_normal([n_hidden, n_classes]))
    biases[i] = tf.Variable(tf.random_normal([n_classes]))

def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_input_y, n_input_x)
    # Permuting batch_size and n_input_y
    x = tf.reshape(x, [-1, n_input_y, n_depth * n_input_x])
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_input_y*batch_size, n_input_x)

    x =  tf.reshape(x, [-1, n_input_x * n_depth])

    # Split to get a list of 'n_input_y' tensors of shape (batch_size, n_hidden)
    # This input shape is required by `rnn` function
    x = tf.split(0, n_depth * n_input_x * n_input_y, x)

    # Define a lstm cell with tensorflow
    lstm_cell = grid_rnn_cell.GridRNNCell(n_hidden, input_dims=[n_depth, n_input_x, n_input_y])
    # lstm_cell = rnn_cell.MultiRNNCell([lstm_cell] * 12, state_is_tuple=True)
    # lstm_cell = rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=0.8)
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
    # Linear activation, using rnn inner loop last output
    # pdb.set_trace()

    output = []
    for i in xrange(n_depth * n_input_x * n_input_y):
        #I'll need to do some sort of reshape here on outputs[i]
        output.append(tf.matmul(outputs[i], weights[i]) + biases[i])

    return output


pred = RNN(x, weights, biases)
pred = tf.transpose(tf.pack(pred),[1,0,2])
pred = tf.reshape(pred, [-1, n_depth, n_input_x, n_input_y, n_classes])
# pdb.set_trace()
temp_pred = tf.reshape(pred, [-1, n_classes])
n_input_y = tf.reshape(y, [-1, n_classes])

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(temp_pred, n_input_y))

Currently I'm getting the error: TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

It occurs after the RNN initialization: outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

x is, of course, of type float32.

I am unable to tell what type GridRNNCell returns. Any help here? This could be the issue. Should I be passing more arguments to it? input_dims makes sense, but what should output_dims be?

Is this a bug in the contrib code?

GridRNNCell is located in contrib/grid_rnn/python/ops/grid_rnn_cell.py

Answer 1:

Which version of the Grid LSTM cells are you using?

If you are using the one in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.py, I think you can try initializing 'feature_size' and 'frequency_skip'.

Also, I think there may be another bug: feeding a dynamic shape into this version may cause a TypeError.



Answer 2:

I was unsure about some of the implementation decisions in the code, so I decided to roll my own. One thing to keep in mind is that this is an implementation of just the cell. It is up to you to build the actual machinery that handles the locations and interactions of the h and m vectors; it isn't as simple as passing in your data and expecting it to traverse the dimensions properly.

So for example, if you are working in two dimensions, start with the top left block, take the incoming x and y vectors, concat them together, then use your cell to compute the output (which includes outgoing vectors for both x and y); and it is up to you to store the output for later use in neighboring blocks. Pass those outputs individually to each corresponding dimension, and in each of those neighboring blocks, concat the incoming vectors (again, for each dimension) and compute the output for the neighboring blocks. To do this, you'll need two for-loops, one for each dimension.
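The traversal described above can be sketched as follows. This is a minimal NumPy illustration, not a Grid LSTM: `toy_cell`, the weight matrix `W`, and the edge inputs are all hypothetical stand-ins, and the memory (m) vectors are left out to keep the bookkeeping visible.

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden = 4
rows, cols = 3, 5

# Toy stand-in for a grid cell (NOT a Grid LSTM): it maps the concatenated
# incoming vectors to one outgoing vector per dimension.
W = rng.normal(scale=0.1, size=(2 * n_hidden, 2 * n_hidden))

def toy_cell(h_from_above, h_from_left):
    joined = np.concatenate([h_from_above, h_from_left])
    out = np.tanh(joined @ W)
    return out[:n_hidden], out[n_hidden:]  # (going down, going right)

# Vectors entering along the top and left edges of the grid.
top_edge = rng.normal(size=(cols, n_hidden))
left_edge = rng.normal(size=(rows, n_hidden))

# Storage for the outgoing vectors of every block, reused by its neighbors.
down = np.zeros((rows, cols, n_hidden))
right = np.zeros((rows, cols, n_hidden))

# Two for-loops, one per dimension, starting at the top-left block.
for i in range(rows):
    for j in range(cols):
        from_above = top_edge[j] if i == 0 else down[i - 1, j]
        from_left = left_edge[i] if j == 0 else right[i, j - 1]
        down[i, j], right[i, j] = toy_cell(from_above, from_left)

print(down.shape, right.shape)
```

The row-major loop order guarantees that the blocks above and to the left have already been computed when a block is visited. A real cell would also carry the memory vectors along each dimension in the same way.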

Perhaps the version in contrib will work for this, but a couple problems I have with it (I could be wrong here, but as far as I can tell): 1) The vectors are handled using concat and slice rather than with tuples. This will likely result in slower performance. 2) It looks like the input is projected at each step, which doesn't sit well with me. In the paper they only project into the network for incoming blocks along the edge of the grid and not throughout.
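To make point 1 concrete, here is a small NumPy sketch of the two ways of carrying state. The variable names are illustrative only and it makes no claim about how the contrib code is actually written; it just shows what "concat and slice" versus "tuples" means.

```python
import numpy as np

n_hidden = 3
h = np.arange(3, dtype=np.float32)     # hidden vector
m = np.arange(3, 6, dtype=np.float32)  # memory vector

# Concat/slice style: one flat state vector that is built with a concat
# and split apart again with slices.
flat_state = np.concatenate([h, m])
h_sliced, m_sliced = flat_state[:n_hidden], flat_state[n_hidden:]

# Tuple style: the parts stay separate, so nothing is concatenated or sliced.
tuple_state = (h, m)
h_tuple, m_tuple = tuple_state

print(np.array_equal(h_sliced, h_tuple), np.array_equal(m_sliced, m_tuple))
```

Both styles carry the same values; the tuple form simply skips the extra concat and slice operations on every step.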

If you look at the code, it is actually very simple. Perhaps reading the paper and making adjustments to the code as needed, or rolling your own are your best bet. And remember that the cell is only good for performing the recurrence at each step, and not for managing the incoming and outgoing h and m vectors.



Answer 3:

Yes, dynamic shape was the cause. There is a PR to fix this: https://github.com/tensorflow/tensorflow/pull/4631
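As a rough illustration of how a dynamic shape leads to the message in the question, the plain-Python sketch below adds up the entries of a static shape that contains an unknown dimension. The `total_size` helper is hypothetical and is not the contrib code; it only reproduces the same kind of arithmetic on a `None` dimension.

```python
# Hypothetical static shape whose first dimension is unknown (dynamic),
# as happens with a placeholder declared with a None dimension.
static_shape = [None, 200]

def total_size(dims):
    total = 0
    for d in dims:
        total = total + d  # fails as soon as d is None
    return total

try:
    total_size(static_shape)
    message = ""
except TypeError as err:
    message = str(err)

print(message)
```

Giving the cell fully defined static sizes avoids this kind of arithmetic on `None`.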

@jstaker7: Thank you for trying it out. Re. problem 1, the above PR uses tuples for states and outputs, so hopefully it addresses the performance issue. GridRNNCell was created a while ago, and at that time all the LSTMCells in TensorFlow were using concat/slice instead of tuples.

Re. problem 2, GridRNNCell will not project the input if you pass None. A dimension can be both input and recurrent, and when there is no input (inputs = None), it uses the recurrent tensors for computation. You can also use 2 input dimensions by instantiating GridRNNCell directly.

Of course, writing a generic class for all cases makes the code look a bit convoluted, and I think it needs better documentation.

Anyway, it will be great if you could share your improvements, or any idea you might have to make it clearer/more useful. It is the nature of an open-source project anyway.