TensorFlow: train on an incomplete batch

Posted 2019-06-09 15:59

Question:

I'm trying to train in batches in TensorFlow. It partly works, since the first epoch runs in batches, but I currently have two problems with my code.
1. After the first epoch has finished, the second epoch immediately goes to the except tf.errors.OutOfRangeError branch, and the next epoch doesn't restart the batches from the top. How can I run another epoch that yields batches again?
2. I print batchnr and I notice that the last batch of the epoch reaches print(batchnr) but never reaches print("End", batchnr); it goes to the except branch and does not get trained. I guess this is because the number of rows left in the queue is smaller than the batch size. How can I still train on that last partial batch?

My train method and pipeline method

def input_pipeline(file, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs, shuffle=True)
  example, label = read_from_csv(filename_queue)
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * 2
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

def train():
    examples, labels = input_pipeline(training_data_file, batch_size, 1)
    saver = tf.train.Saver()
    prediction = neural_network_model(p_inputdata)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=p_known_labels))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    with tf.Session() as sess:
        sess.run(init) # initialize all variables **in** the session

        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(p_known_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        latest_cost_of_batch = None
        for e in range(epochs):
            epoch = e + 1
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                batchnr = 1
                while not coord.should_stop():
                    print(batchnr)
                    batch_data, batch_labels = sess.run([examples, labels])
                    batch_labels_output = get_output_values(batch_labels)
                    print("End", batchnr)
                    batchnr += 1

                    _, latest_cost_of_batch = sess.run([optimizer,cost], feed_dict={
                        p_inputdata: batch_data,
                        p_known_labels: batch_labels_output
                    })

            except tf.errors.OutOfRangeError:
                print('Done training, epoch reached')
                if (epoch) % print_each_x_number_of_epochs == 0 or epoch == 0:
                    print('Epoch', epoch, 'completed out of', epochs, "---", 'Loss', latest_cost_of_batch)
                if epoch % save_each_x_number_of_epochs == 0:
                    saver.save(sess, checkpoint_label)
            finally:
                coord.request_stop()
        coord.join(threads)

        print("Trained for ", epochs,"epochs. Saving variables...")
        saver.save(sess, checkpoint_label)
        print("Variables saved. Training finished.")
    end = time.time()
    seconds = end - start
    print("Total runtime:", str(datetime.timedelta(seconds=seconds)))

Debug Console

Start training
1
End 1
2
End 2
....
213
End 213
214
Done training, epoch reached
Epoch 1 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 2 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 3 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 4 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 5 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 6 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 7 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 8 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 9 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 10 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 11 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 12 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 13 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 14 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 15 completed out of 15 --- Loss 4.43414
Trained for  15 epochs. Saving variables...
Variables saved. Training finished.
Accuracy 0.935310311615 % after 15 epochs of training
Total runtime: 0:00:21.395917

EDIT
I changed the code based on the answer by Nicolas (I went with multiple epochs in the string_input_producer). My training code is now:

def train():
    """Trains the neural network  
    """
    examples, labels = input_pipeline(training_data_file, batch_size, epochs)
    start = time.time()
    saver = tf.train.Saver()
    prediction = neural_network_model(p_inputdata)
    first_no_loss = True
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=p_known_labels))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    with tf.Session() as sess:
        sess.run(init) # initialize all variables **in** the session
        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(p_known_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        print("Start training")
        latest_cost_of_batch = None

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        epoch_op = "input_producer/limit_epochs/epochs:0"
        try:
            batchnr = 1
            epochs_var = 0
            while not coord.should_stop():
                if (batchnr) % print_each_x_number_of_batches == 0:
                    print('Batch', batchnr, 'completed of epoch', epochs_var, "---", 'Loss', latest_cost_of_batch)

                if  batchnr > 3194:
                    print("GETTING BATCH", batchnr)
                epochs_var, batch_data, batch_labels = sess.run([epoch_op, examples, labels])
                batch_labels_output = get_output_values(batch_labels)
                if  batchnr > 3194:
                    print("GOT BATCH", batchnr)
                batchnr += 1
                _, latest_cost_of_batch = sess.run([optimizer,cost], feed_dict={
                    p_inputdata: batch_data,
                    p_known_labels: batch_labels_output
                })

        except tf.errors.OutOfRangeError:
            print('Done training, epoch reached')
        finally:
            coord.request_stop()

        coord.join(threads)

        print("Trained for ", epochs,"epochs. Saving variables...")
        saver.save(sess, checkpoint_label)
        print("Variables saved. Training finished.")
        labels, values, output = get_training_or_testdata(training_data_file)
        print('Accuracy', accuracy.eval(feed_dict={p_inputdata: values, p_known_labels: output}) * 100, '% after', epochs, 'epochs of training')
    end = time.time()
    seconds = end - start
    print("Total runtime:", str(datetime.timedelta(seconds=seconds)))

And my output looks like this:

Start training
Batch 100 completed of epoch 15 --- Loss 4.79351
Batch 200 completed of epoch 15 --- Loss 4.57468
Batch 300 completed of epoch 15 --- Loss 4.51134
Batch 400 completed of epoch 15 --- Loss 4.65865
Batch 500 completed of epoch 15 --- Loss 4.55456
Batch 600 completed of epoch 15 --- Loss 4.63549
Batch 700 completed of epoch 15 --- Loss 4.53037
Batch 800 completed of epoch 15 --- Loss 4.49263
Batch 900 completed of epoch 15 --- Loss 4.37
Batch 1000 completed of epoch 15 --- Loss 4.42719
Batch 1100 completed of epoch 15 --- Loss 4.4518
Batch 1200 completed of epoch 15 --- Loss 4.41053
Batch 1300 completed of epoch 15 --- Loss 4.43508
Batch 1400 completed of epoch 15 --- Loss 4.32173
Batch 1500 completed of epoch 15 --- Loss 4.36624
Batch 1600 completed of epoch 15 --- Loss 4.44027
Batch 1700 completed of epoch 15 --- Loss 4.37201
Batch 1800 completed of epoch 15 --- Loss 4.24956
Batch 1900 completed of epoch 15 --- Loss 4.40256
Batch 2000 completed of epoch 15 --- Loss 4.18391
Batch 2100 completed of epoch 15 --- Loss 4.30156
Batch 2200 completed of epoch 15 --- Loss 4.38423
Batch 2300 completed of epoch 15 --- Loss 4.23823
Batch 2400 completed of epoch 15 --- Loss 4.17783
Batch 2500 completed of epoch 15 --- Loss 4.31024
Batch 2600 completed of epoch 15 --- Loss 4.26312
Batch 2700 completed of epoch 15 --- Loss 4.26143
Batch 2800 completed of epoch 15 --- Loss 4.16691
Batch 2900 completed of epoch 15 --- Loss 4.48624
Batch 3000 completed of epoch 15 --- Loss 4.1347
Batch 3100 completed of epoch 15 --- Loss 4.20801
GETTING BATCH 3195
GOT BATCH 3195
GETTING BATCH 3196
GOT BATCH 3196
GETTING BATCH 3197
Done training, epoch reached
Trained for  15 epochs. Saving variables...
Variables saved. Training finished.
Accuracy 2.69019026309 % after 15 epochs of training
Total runtime: 0:03:07.577149

Two things I noticed: first, the last batch still doesn't get trained (GOT BATCH 3197 is never printed); second, my way of getting the current epoch isn't correct, since it is always 15. An answer to another SO question explains why the way I do it now is not the way to go, but it doesn't explain a proper way to get the current epoch. Any clues?

Answer 1:


EDIT: you might want to have a look at this answer, as it gives an example of the new API.

Here is an explanation of what you got.

  • The first time you go through the for e in range(epochs) loop, it dequeues everything from your data queue (until the data queue throws tf.errors.OutOfRangeError).

    This error is thrown when there are no more filenames in the filename queue. That happens after reading the file only once, because you called examples, labels = input_pipeline(training_data_file, batch_size, 1).

    If, for example, you had called examples, labels = input_pipeline(training_data_file, batch_size, 3), you would have gone through the file 3 times before moving on to e=1.

  • Then, when you move on to e>0, the filename queue remembers that all the filenames have already been dequeued, and since there is no further enqueue operation it throws tf.errors.OutOfRangeError right away.

    See the docstring:

    Note: if num_epochs is not None, this function creates local counter epochs. Use local_variables_initializer() to initialize local variables.
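The behaviour described above can be mimicked in plain Python. The sketch below is only an analogy, not TensorFlow code: limited_producer and run_outer_epochs are hypothetical names, and a generator stands in for the filename queue built with num_epochs=1. Once the generator is exhausted, every later pass over it ends immediately, much as every epoch after the first jumps straight to OutOfRangeError.

```python
def limited_producer(items, num_epochs):
    """Yield every item num_epochs times, then stop for good (like a closed queue)."""
    for _ in range(num_epochs):
        for item in items:
            yield item


def run_outer_epochs(producer, outer_epochs):
    """Drain the same producer once per outer epoch; return how many items each pass saw."""
    counts = []
    for _ in range(outer_epochs):
        seen = 0
        for _item in producer:  # ends at once if the producer is already exhausted
            seen += 1
        counts.append(seen)
    return counts


rows = ["row-%d" % i for i in range(5)]  # made-up data for illustration

# One inner epoch, three outer epochs: only the first pass receives data.
print(run_outer_epochs(limited_producer(rows, num_epochs=1), outer_epochs=3))  # [5, 0, 0]

# Asking the producer itself for three epochs delivers everything in one pass.
print(run_outer_epochs(limited_producer(rows, num_epochs=3), outer_epochs=1))  # [15]
```

The second call mirrors the remark about passing 3 instead of 1 to input_pipeline: all the data arrives before the outer loop ever advances.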

What can you do?

  1. Move the session context manager into the for e in range(epochs) loop:

    init_queue = tf.variables_initializer(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='input_producer'))
    with tf.Session() as sess:
        sess.run(init)
    for e in range(epochs):
        # note: variable values live in a session and are lost when it closes
        with tf.Session() as sess:
            sess.run(init_queue) # initialize all local variables **in** the input_producer scope
            epoch = e + 1
    

    This reinitializes all the local variables in the input_producer scope, so you need to be careful about what they are. You could also save your model and load it again at each step, or

  2. Rely on the num_epochs argument to run the right number of epochs and remove your for e in range(epochs) loop. Instead of printing information at the end of each epoch, you could print information every 100 or 1000 training steps (my favourite solution). If you really want to print information at the end of each epoch, you could try to access the hidden epochs variable, evaluate it, and print the information whenever its value changes (I wouldn't recommend this option).

For example:

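    epoch_op = "input_producer/limit_epochs/epochs:0"  # the hidden epochs counter
    last_epoch = None
    batchnr = 0
    while not coord.should_stop():
        # fetch the counter value under a new name so the tensor name is not overwritten
        current_epoch, _, _ = sess.run([epoch_op, examples, labels])
        if current_epoch != last_epoch:
            print("Epochs counter changed to", current_epoch)
            last_epoch = current_epoch
        batchnr += 1
        print("End", batchnr)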
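If progress is reported per training step, an approximate epoch can be derived from the step counter alone, provided the number of rows in the file is known. The sketch below is plain Python arithmetic, not TensorFlow code. The helper names are hypothetical, and 21310 rows with a batch size of 100 are made-up values chosen only so that the counts resemble the logs in the question. It also shows how many rows are left over when only full batches are produced. tf.train.shuffle_batch does take an allow_smaller_final_batch argument (False by default) that may be worth trying for those leftover rows, but that is not exercised here.

```python
def full_batches_and_leftover(num_rows, batch_size, num_epochs):
    """Return (number of full batches, rows left over) when all epochs share one stream."""
    total_rows = num_rows * num_epochs
    return total_rows // batch_size, total_rows % batch_size


def epoch_of_step(step, num_rows, batch_size):
    """Approximate 1-based epoch in which the 1-based training step starts."""
    rows_consumed_before = (step - 1) * batch_size
    return rows_consumed_before // num_rows + 1


NUM_ROWS = 21310   # assumption: size of the training file
BATCH_SIZE = 100   # assumption: batch size

# A single epoch yields 213 full batches and strands 10 rows.
print(full_batches_and_leftover(NUM_ROWS, BATCH_SIZE, 1))   # (213, 10)

# With 15 epochs in one stream the leftovers roll over into later batches.
print(full_batches_and_leftover(NUM_ROWS, BATCH_SIZE, 15))  # (3196, 50)

# Steps map to epochs without touching the hidden epochs variable.
for step in (1, 214, 215, 3196):
    print(step, epoch_of_step(step, NUM_ROWS, BATCH_SIZE))
```

Because the shuffle queue mixes rows from neighbouring epochs and batches can straddle an epoch boundary, the value is an estimate of progress rather than an exact label for each batch.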

Hope it helps!

REMARKS ON THE EDITED QUESTION:

Looking at the emphasised part of this quote from the answer you referred to, it seems to me that you have no way of knowing which epoch a dequeued batch belongs to.

When tf.start_queue_runners() is executed, all the epochs are enqueued together (in multiple stages if capacity is less than number of filenames). The local variable epochs:0 is used by tf.train.string_input_producer to maintain the epoch that is being enqueued. Once epochs:0 reaches num_epochs, it remains constant and no matter how many threads are dequeuing from the queue, it does not change.
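A toy model in plain Python can make the quoted behaviour concrete. It is a simplified sketch under stated assumptions, not TensorFlow's actual threading: observed_epoch_counter is a hypothetical function, a deque stands in for the filename queue, and a capacity of 32 is assumed for that queue. The producer fills the queue as far as it can and bumps the counter each time it starts enqueuing a new epoch, so with one file and 15 epochs the counter has already reached 15 before the first dequeue.

```python
from collections import deque


def observed_epoch_counter(num_files, num_epochs, capacity):
    """Return the counter value a consumer would read at each dequeue."""
    queue = deque()
    counter = 0    # stands in for the local variable epochs:0
    pending = []   # filenames of the current epoch that are not enqueued yet
    observed = []

    def fill():
        """Producer side: enqueue until the queue is full or the epoch limit is hit."""
        nonlocal counter, pending
        while len(queue) < capacity:
            if not pending:
                if counter == num_epochs:
                    return  # limit reached: nothing more will ever be enqueued
                counter += 1
                pending = list(range(num_files))
            queue.append(pending.pop(0))

    fill()
    while queue:
        queue.popleft()           # consumer side: one dequeue
        observed.append(counter)  # what reading the counter would show now
        fill()
    return observed


values = observed_epoch_counter(num_files=1, num_epochs=15, capacity=32)
print(len(values), set(values))  # 15 {15}
```

In this model every one of the 15 dequeues sees the same counter value, which matches the "epoch 15" printed on every line of the edited question's output.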