I'm training an LSTM network with TensorFlow in Python and wanted to switch to tf.contrib.cudnn_rnn.CudnnLSTM for faster training. What I did was replace
cells = tf.nn.rnn_cell.LSTMCell(self.num_hidden)
initial_state = cells.zero_state(self.batch_size, tf.float32)
rnn_outputs, _ = tf.nn.dynamic_rnn(cells, my_inputs, initial_state=initial_state)
with
lstm = tf.contrib.cudnn_rnn.CudnnLSTM(1, self.num_hidden)
rnn_outputs, _ = lstm(my_inputs)
I'm seeing a significant training speedup (more than 10x), but at the same time my performance metric goes down. AUC on a binary classification task is 0.741 when using LSTMCell and 0.705 when using CudnnLSTM. I'm wondering whether I'm doing something wrong, or whether it's a difference in implementation between the two, and if that's the case, how to get my performance back while still using CudnnLSTM.
The training dataset has 15,337 sequences of varying length (up to a few hundred elements) that are padded with zeros to the same length within each batch. All the code is the same, including the TF Dataset API pipeline and all evaluation metrics. I ran each version a few times, and in all cases it converges around those values.
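For reference, the padding is done with a tf.data pipeline roughly like the following (a simplified sketch, not my actual code; sequence_generator and embedding_size are placeholder names):

import tensorflow as tf

# Placeholder sketch: variable-length sequences are zero-padded
# to the longest sequence in each batch via padded_batch.
dataset = tf.data.Dataset.from_generator(
    sequence_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=((None, embedding_size), ()))
dataset = dataset.padded_batch(
    self.batch_size,
    padded_shapes=((None, embedding_size), ()))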
Moreover, I have a few other datasets that can be plugged into exactly the same model, and the problem persists on all of them.
In the TensorFlow code for cudnn_rnn I found a sentence saying:
Cudnn LSTM and GRU are mathematically different from their tf counterparts.
But there's no explanation of what those differences really are...
It seems tf.contrib.cudnn_rnn.CudnnLSTM is time-major, so it should be fed sequences of shape (seq_len, batch_size, embedding_size) instead of (batch_size, seq_len, embedding_size), which means you would have to transpose your input (I think; it's hard to be sure with the messy TensorFlow documentation, but you may want to test that; see the links below if you wish to check it). More information on the topic is here (that page contains another link pointing towards the mathematical differences), except one thing there seems to be wrong: not only GRU is time-major, LSTM is as well (as pointed out by this issue).
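A sketch of what that transpose might look like, based on the code in your question (untested, so treat it as an illustration rather than a verified fix):

import tensorflow as tf

# Transpose from (batch_size, seq_len, embedding_size) to time-major
# (seq_len, batch_size, embedding_size) before the Cudnn layer,
# then transpose the outputs back to batch-major.
lstm = tf.contrib.cudnn_rnn.CudnnLSTM(1, self.num_hidden)
time_major_inputs = tf.transpose(my_inputs, [1, 0, 2])
rnn_outputs, _ = lstm(time_major_inputs)
rnn_outputs = tf.transpose(rnn_outputs, [1, 0, 2])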
I would advise against using tf.contrib, as it's even messier (and will finally be left out of the TensorFlow 2.0 releases), and stick to keras if possible (as it will be the main front end of the upcoming TensorFlow 2.0) or tf.nn, as it's going to be part of the tf.Estimator API (though it's far less readable, IMO)... or consider using PyTorch to save yourself the hassle, where input shapes (and their meanings) are at least provided in the documentation.
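If you go the Keras route, a minimal sketch reusing my_inputs and num_hidden from your question (an illustration, not a drop-in equivalent of your full model) could look like this:

import tensorflow as tf

# tf.keras.layers.CuDNNLSTM expects batch-major input
# (batch_size, seq_len, embedding_size), so no transpose is needed.
# return_sequences=True keeps per-timestep outputs, like dynamic_rnn does.
rnn_outputs = tf.keras.layers.CuDNNLSTM(num_hidden, return_sequences=True)(my_inputs)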