How to get misclassified files in TF-Slim's ev

Posted 2019-06-05 04:16

Question:

I'm using a script that comes with TF-Slim to validate my trained model. It works fine, but I'd like to get a list of the misclassified files.

The script makes use of https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/evaluation.py but even there I cannot find any options for printing the misclassified files.

How can I achieve that?

Answer 1:

At a high level, you need to do 3 things:

1) Get the filename from the data loader. If you are using a TF-Slim dataset built from TFRecords, the filenames are most likely not stored in the records, so this route may not work for you. However, if you are reading image files directly from the filesystem with tf.WholeFileReader, you can get the tensor of filenames where you form your batch:

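import tensorflow as tf  # TF 1.x API

def load_data():
    train_image_names = ... # list of filenames
    filename_queue = tf.train.string_input_producer(train_image_names)
    reader = tf.WholeFileReader()
    # For WholeFileReader, the key returned by read() is the filename
    image_filename, image_file = reader.read(filename_queue)
    image = tf.image.decode_jpeg(image_file, channels=3)
    # tf.train.batch needs fully defined shapes; 224x224 is an assumed input size
    image = tf.image.resize_images(image, [224, 224])

    .... # load your labels from somewhere

    return image_filename, image, label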


# in your eval code
image_fn, image, label = load_data()

filenames, images, labels = tf.train.batch(
    [image_fn, image, label],
    batch_size=32,
    num_threads=2,
    capacity=100,
    allow_smaller_final_batch=True)
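
Batching the filename together with the image and label is what keeps them aligned: each tuple passed to tf.train.batch contributes one entry at the same index of every output tensor. The pure-Python sketch below only illustrates that grouping on made-up data; batch_examples is a hypothetical helper, not a TensorFlow function, and it ignores threading and queue order.

```python
def batch_examples(examples, batch_size, allow_smaller_final_batch=True):
    """Group (filename, image, label) tuples into batches of parallel lists.

    Hypothetical helper that mimics the grouping done by tf.train.batch:
    entry i of every list in a batch comes from the same input tuple.
    """
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        if len(chunk) < batch_size and not allow_smaller_final_batch:
            break  # in this sketch, an incomplete final batch is dropped
        filenames, images, labels = (list(part) for part in zip(*chunk))
        batches.append((filenames, images, labels))
    return batches


examples = [("a.jpg", "img_a", 0), ("b.jpg", "img_b", 1), ("c.jpg", "img_c", 2)]
for filenames, images, labels in batch_examples(examples, batch_size=2):
    print(filenames, labels)
```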

2) After inference, mask the filename tensor with the prediction results:

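logits = my_network(images)
preds = tf.argmax(logits, 1)  # int64 by default
# tf.not_equal needs matching dtypes, so cast in case your labels are int32
mislabeled = tf.not_equal(preds, tf.cast(labels, tf.int64))
mislabeled_filenames = tf.boolean_mask(filenames, mislabeled)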
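
To see what these ops compute, here is a pure-Python sketch of the same logic on a toy batch: a row-wise argmax (like tf.argmax(logits, 1)), an element-wise comparison (like tf.not_equal) and boolean masking (like tf.boolean_mask). The logits, labels and filenames are made up, and misclassified_filenames is a hypothetical helper.

```python
def argmax_rows(logits):
    """Index of the largest score in each row, like tf.argmax(logits, 1)."""
    return [max(range(len(row)), key=lambda j: row[j]) for row in logits]


def misclassified_filenames(logits, labels, filenames):
    preds = argmax_rows(logits)
    # One flag per image, True where the prediction differs from the label
    mislabeled = [p != l for p, l in zip(preds, labels)]
    # Keep only the filenames whose flag is True (boolean masking)
    return [name for name, bad in zip(filenames, mislabeled) if bad]


logits = [[2.0, 0.1, 0.3],   # predicted class 0
          [0.2, 0.1, 1.5],   # predicted class 2
          [0.1, 3.0, 0.2]]   # predicted class 1
labels = [0, 1, 1]
filenames = ["a.jpg", "b.jpg", "c.jpg"]
print(misclassified_filenames(logits, labels, filenames))  # ['b.jpg']
```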

3) Attach the masked filenames to your eval_op with tf.Print, so they are printed (to standard error) every time the op runs:

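# With summarize=None, tf.Print shows at most 3 entries per tensor,
# so raise it to the batch size to see every misclassified file
eval_op = tf.Print(eval_op, [mislabeled_filenames],
                   message='Misclassified: ', summarize=32)

slim.evaluation.evaluate_once(
    .... # other options
    eval_op=eval_op,
    .... # other options
)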

I don't have a setup to test this, unfortunately. Let me know if it works!
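
One more thing to check: evaluate_once runs eval_op num_evals times, and num_evals defaults to 1, so unless you set it only the first batch is inspected. eval_image_classifier.py derives the number of batches from the dataset size. The sketch below shows that arithmetic, plus concatenating per-batch lists of misclassified filenames, assuming you have collected each step's list yourself; the per-batch lists in the example are invented and both helpers are hypothetical.

```python
import math


def num_eval_batches(num_samples, batch_size):
    """Batches needed to see every sample once; the last batch may be partial."""
    return int(math.ceil(num_samples / float(batch_size)))


def collect_misclassified(per_batch_filenames):
    """Flatten per-step lists of misclassified filenames into one list."""
    collected = []
    for names in per_batch_filenames:
        collected.extend(names)
    return collected


print(num_eval_batches(100, 32))  # 4 batches: 32 + 32 + 32 + 4 samples
print(collect_misclassified([["b.jpg"], [], ["x.jpg", "y.jpg"]]))
```

The computed count is what you would pass as num_evals to slim.evaluation.evaluate_once.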



Answer 2:

shadow chris pointed me in the right direction, so I'm sharing my solution, which makes this work with a TFRecords dataset.

For better understanding, I relate my code to the flowers example of TF-Slim.

1) Modify your dataset script (flowers.py in the example) to decode a filename feature from the TFRecords.

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
      'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
      'filename': slim.tfexample_decoder.Tensor('image/filename'),
  }

2) Add a filename parameter to the image_to_tfexample function in dataset_utils.py

It should then look like:

def image_to_tfexample(image_data, image_format, height, width, class_id, filename):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        # as_bytes accepts str or bytes; BytesList needs bytes under Python 3
        'image/filename': bytes_feature(tf.compat.as_bytes(filename)),
    }))
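
bytes_feature stores raw bytes, so under Python 3 a str filename has to be encoded before it is written, and the values of a tf.string tensor come back from a session as bytes. The two helpers below are a hypothetical, TensorFlow-free sketch of that round trip, assuming UTF-8 filenames; in TensorFlow code, tf.compat.as_bytes and tf.compat.as_text cover the same ground.

```python
def encode_filename(filename, encoding="utf-8"):
    """Return the filename as bytes, the form a BytesList can hold."""
    if isinstance(filename, bytes):
        return filename
    return filename.encode(encoding)


def decode_filename(raw, encoding="utf-8"):
    """Turn bytes read back from a string tensor into a str path."""
    if isinstance(raw, bytes):
        return raw.decode(encoding)
    return raw


stored = encode_filename("daisy/blüte_01.jpg")
print(decode_filename(stored))  # daisy/blüte_01.jpg
```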

3) Modify the download-and-convert script to save the filenames

Pass each image's filename when you build its example:

    # b'jpg': the format is written with bytes_feature, which needs bytes under Python 3
    example = dataset_utils.image_to_tfexample(
        image_data, b'jpg', height, width, class_id, filenames[i])
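
In the flowers conversion script, each image's class comes from its parent directory and filenames[i] is the full path. The sketch below prepares both values for one image. Storing the path relative to the dataset root is an assumption (the original script does not do this) that keeps the stored names short; record_fields and the example paths are hypothetical.

```python
import os


def record_fields(path, class_names_to_ids, dataset_root):
    """Return (class_id, filename_to_store) for one image path.

    The class name is the image's parent directory, as in the flowers
    example; the stored filename is the path relative to dataset_root.
    """
    class_name = os.path.basename(os.path.dirname(path))
    class_id = class_names_to_ids[class_name]
    filename_to_store = os.path.relpath(path, dataset_root)
    return class_id, filename_to_store


class_names_to_ids = {"daisy": 0, "roses": 1}
print(record_fields("/data/flower_photos/roses/123.jpg",
                    class_names_to_ids, "/data/flower_photos"))
# (1, 'roses/123.jpg') on POSIX systems
```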

4) In your evaluation script, map misclassified images to their filenames

I'm referring to eval_image_classifier.py.

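Get the filename from the data provider and batch it together with the image and label:

[image, label, filename] = provider.get(['image', 'label', 'filename'])
.... # image preprocessing as in eval_image_classifier.py

images, labels, filenames = tf.train.batch(
    [image, label, filename],
    batch_size=FLAGS.batch_size,
    num_threads=FLAGS.num_preprocessing_threads,
    capacity=5 * FLAGS.batch_size)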

Get the misclassified images and map them to their filenames:

predictions = tf.argmax(logits, 1)
labels = tf.squeeze(labels)
mislabeled = tf.not_equal(predictions, labels)
mislabeled_filenames = tf.boolean_mask(filenames, mislabeled)
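
tf.squeeze acts as a safeguard here. If labels arrived with shape [batch_size, 1], comparing them with the [batch_size] predictions would broadcast to a [batch_size, batch_size] grid instead of one flag per image, and the mask would no longer line up with filenames. A pure-Python sketch of both cases on toy values; the helpers are hypothetical stand-ins for the TensorFlow ops.

```python
def broadcast_not_equal(preds, labels_column):
    """[batch] vs [batch, 1] comparison under broadcasting: a batch x batch grid."""
    return [[p != row[0] for p in preds] for row in labels_column]


def squeeze_column(labels_column):
    """Drop the size-1 axis: [[l0], [l1], ...] -> [l0, l1, ...]."""
    return [row[0] for row in labels_column]


def elementwise_not_equal(preds, labels):
    """One flag per image."""
    return [p != l for p, l in zip(preds, labels)]


preds = [0, 2, 1]
labels_column = [[0], [1], [1]]
print(broadcast_not_equal(preds, labels_column))  # 3 x 3 grid, not a mask
print(elementwise_not_equal(preds, squeeze_column(labels_column)))  # [False, True, False]
```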

Print:

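# With summarize=None, tf.Print shows at most 3 entries per tensor,
# so raise it to the batch size to see every misclassified file
eval_op = tf.Print(eval_op, [mislabeled_filenames],
                   message='Misclassified: ', summarize=FLAGS.batch_size)

slim.evaluation.evaluate_once(
    .... # other options
    eval_op=eval_op,
    .... # other options
)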