TensorFlow: loss jumps up after restoring RNN net

2019-02-20 04:01发布

问题:

Environment info

  • Operating System: Windows 7 64-bit
  • Tensorflow installed from pre-built pip (no CUDA): 1.0.1
  • Python 3.5.2 64-bit

Problem

I have problems with restoring my net (RNN character base language model). Below is a simplified version with the same problem.

When I run it the first time, I get, for example, this.

    ...
    step 160: loss = 1.956 (perplexity = 7.069016620211226)
    step 180: loss = 1.837 (perplexity = 6.274748642468816)
    step 200: loss = 1.825 (perplexity = 6.202084762557817)

But on the second run, after restoring parameters, I get this.

    step 220: loss = 2.346 (perplexity = 10.446611983898903)
    step 240: loss = 2.346 (perplexity = 10.446709120339545)
    ...

All the tf variables seem to be correctly restored, including the state, which will be fed to RNN. Data position is also restored (from 'step').

I also made a similar program for MNIST recognition model, and this one works fine: the losses before and after the restoring are continuous.

Are there any other parameters or states that should be saved and restored?

    import argparse
    import os
    import tensorflow as tf
    import numpy as np
    import math

    B = 20  # batch size
    H = 200 # size of hidden layer of neurons
    T = 25  # number of time steps to unroll the RNN for
    data_file = 'ptb.train.txt' # any plain text file will do
    checkpoint_dir = "tmp"

    #----------------
    # prepare data
    #----------------
    data = open(data_file, 'r').read()
    chars = list(set(data))
    data_size, vocab_size = len(data), len(chars)
    print('data has {0} characters, {1} unique.'.format(data_size, vocab_size))
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }

    input_index_raw = np.array([char_to_ix[ch] for ch in data])
    input_index_raw = input_index_raw[0:len(input_index_raw) // T * T]
    input_index_raw_shift = np.append(input_index_raw[1:], input_index_raw[0])
    input_all = input_index_raw.reshape([-1, T])
    target_all = input_index_raw_shift.reshape([-1, T])
    num_packed_data = len(input_all)

    #----------------
    # build model
    #----------------
    class Model(object):
      def __init__(self):
        self.input_ph = tf.placeholder(tf.int32, [None, T], name="input_ph")
        self.target_ph = tf.placeholder(tf.int32, [None, T], name="target_ph")
        embedding = tf.get_variable("embedding", [vocab_size, H], initializer=tf.random_normal_initializer(), dtype=tf.float32)
        # input_ph is B x T.
        # input_embedded is B x T x H.
        input_embedded = tf.nn.embedding_lookup(embedding, self.input_ph)

        cell = tf.contrib.rnn.BasicRNNCell(H)

        self.state_ph = tf.placeholder(tf.float32, (None, cell.state_size), name="state_ph")

        # Make state variable so that it will be saved by the saver.
        self.state = tf.get_variable("state", (B, cell.state_size), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.float32)

        # Construct initial_state according to whether restoring or not.
        self.isRestore = tf.placeholder(tf.bool, shape=(), name="isRestore")
        zero_state = cell.zero_state(B, dtype=tf.float32)
        self.initial_state = tf.cond(self.isRestore, lambda: self.state, lambda: zero_state)

        # input_embedded : B x T x H
        # output: B x T x H
        # state : B x cell.state_size
        output, state_ = tf.nn.dynamic_rnn(cell, input_embedded, initial_state=self.state_ph)
        self.final_state = tf.assign(self.state, state_)

        # reshape to (B * T) x H.
        output_flat = tf.reshape(output, [-1, H])

        # Convert hidden layer's output to vector of logits for each vocabulary.
        softmax_w = tf.get_variable("softmax_w", [H, vocab_size], dtype=tf.float32)
        softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=tf.float32)
        logits = tf.matmul(output_flat, softmax_w) + softmax_b

        # cross_entropy is a vector of length B * T
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.reshape(self.target_ph, [-1]), logits=logits)
        self.loss = tf.reduce_mean(cross_entropy)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        self.global_step = tf.get_variable("global_step", (), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.int32)
        self.training_op = optimizer.minimize(cross_entropy, global_step=self.global_step)

      def train_batch(self, sess, input_batch, target_batch, initial_state):
        final_state_, _, final_loss = sess.run([self.final_state, self.training_op, self.loss], feed_dict={self.input_ph: input_batch, self.target_ph: target_batch, self.state_ph: initial_state})
        return final_state_, final_loss

    # main
    with tf.Session() as sess:
      if not tf.gfile.Exists(checkpoint_dir):
        tf.gfile.MakeDirs(checkpoint_dir)

      batch_stride = num_packed_data // B

      # make model
      model = Model()
      saver = tf.train.Saver()

      # always initialize
      init = tf.global_variables_initializer()
      init.run()

      # restore if necessary
      isRestore = False
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt:
        isRestore = True
        last_model = ckpt.model_checkpoint_path
        print("Loading " + last_model)
        saver.restore(sess, last_model)

      # set initial step
      step = tf.train.global_step(sess, model.global_step) + 1
      print("start step = {0}".format(step))

      # fetch initial state
      state =  sess.run(model.initial_state, feed_dict={model.isRestore: isRestore})
      print("Initial state: {0}".format(state))

      while True:
        # prepare batch data
        idx = [(step + x * batch_stride) % num_packed_data for x in range(0, B)]
        input_batch = input_all[idx]
        target_batch = target_all[idx]

        state, last_loss = model.train_batch(sess, input_batch, target_batch, state)

        if step % 20 == 0:
          print('step {0}: loss = {1:.3f} (perplexity = {2})'.format(step, last_loss, math.exp(last_loss)))

        if step % 200 == 0:
          saved_file = saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"), global_step=step)
          print("Saved to " + saved_file)
          print("Last state: {0}".format(model.state.eval()))
          break;

        step = step + 1

回答1:

The problem is solved. It had nothing to do with RNN nor TensorFlow.

I changed

chars = list(set(data))

to

chars = sorted(set(data))

and now it works.

This is because python uses a random hash function to build the set, and every time python restarted, 'chars' had a different ordering.