Why are CIFAR-10 images not displayed properly using matplotlib?

2019-03-30 01:10发布

问题:

From the training set I took an image ('img') of size (3,32,32). I have used plt.imshow(img.T), but the image is not clear. What changes do I have to make to the image ('img') to make it more clearly visible? Thanks.

回答1:

The following prints a 5x5 grid of random CIFAR-10 images. It isn't blurry, though it isn't perfect either. Any suggestions are welcome.

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from six.moves import cPickle 

# Load one pickled CIFAR-10 training batch; the context manager closes the
# file even if unpickling fails.
with open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb') as f:
    datadict = cPickle.load(f, encoding='latin1')
X = datadict["data"]
Y = datadict['labels']
# Reshape every flat row to (3, 32, 32) and move the channel axis last so each
# image is (32, 32, 3), the layout imshow expects. -1 lets NumPy infer the
# number of images instead of hard-coding the batch size.
X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("uint8")
Y = np.array(Y)

# Visualizing CIFAR 10: a 5x5 grid of randomly chosen images
fig, axes1 = plt.subplots(5, 5, figsize=(3, 3))
for ax in axes1.flat:
    # draw a random image index in [0, len(X))
    i = np.random.randint(len(X))
    ax.set_axis_off()
    ax.imshow(X[i])


回答2:

The image is blurry due to interpolation. To prevent blurring in matplotlib, call imshow with keyword interpolation='nearest':

# interpolation='nearest' is passed so that imshow does not blur the pixels
plt.imshow(img.T, interpolation='nearest')

Also, it appears that your x and y axes are being swapped when you use the transpose so you may want to display like this instead:

# Reorder the axes of img from (0, 1, 2) to (1, 2, 0), i.e. move the first axis last
plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')


回答3:

I have used the following code to show all the CIFAR data as one big image. The code shows the image, but if you want to save it without it being blurry, I suggest using plt.savefig(fname, format='png', dpi=1000)

import numpy as np
import matplotlib.pyplot as plt

def reshape_and_print(self, cifar_data):
    """Tile all images in ``cifar_data`` into one square mosaic and show it.

    ``cifar_data`` is indexed as (number of images, 3 * height * width).
    The grid side and the image side are the truncated square roots of the
    image count and of the per-channel pixel count. ``self`` is not used.
    """
    # grid side: number of images per row and per column
    grid = np.sqrt(cifar_data.shape[0]).astype(np.int32)
    # image side; divide by 3 because of the 3 color channels
    side = np.sqrt(cifar_data.shape[1] // 3).astype(np.int32)
    # view as grid row X grid col X color channel X image height X image width
    tiles = cifar_data.reshape(grid, grid, 3, side, side)
    # reorder to grid row X image height X grid col X image width X channel
    mosaic = tiles.transpose(0, 3, 1, 4, 2)
    # merge to combined image height X combined image width X color channels
    mosaic = mosaic.reshape(grid * side, grid * side, 3)

    plt.imshow(mosaic)
    plt.show()

I made a quick data helper class that I used for a small test project. I hope it can be useful:

import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt


class DataSet(object):
    """Base class that loads a dataset and splits it into train/valid/test.

    Subclasses must override ``load_data``, ``reshape_for_print`` and
    ``show_all_imgs``; the base versions raise ``NotImplementedError``.
    """

    def __init__(self, seed=42, setsize=10000):
        """Seed NumPy's global RNG, load the data and split it.

        Args:
            seed: value passed to ``np.random.seed``.
            setsize: number of samples per split, forwarded to ``split_data``.
        """
        self.seed = seed
        # set the seed for reproducibility
        np.random.seed(seed)
        # load the data
        train_set, test_set = self.load_data()
        self.split_data(train_set, test_set, setsize)

    def split_data(self, data_set, test_set, split_size):
        """Store shuffled train/valid splits and a leading slice of the test set.

        ``self.train`` and ``self.valid`` are the first and second
        ``split_size`` rows of a random permutation of ``data_set``;
        ``self.test`` is the first ``split_size`` rows of ``test_set``.
        """
        permutation = np.random.permutation(data_set.shape[0])
        self.train = data_set[permutation[:split_size]]
        self.valid = data_set[permutation[split_size:split_size * 2]]
        self.test = test_set[:split_size]

    # NOTE: the abstract hooks raise NotImplementedError. Raising the
    # NotImplemented constant fails with a TypeError because it is not an
    # exception class.
    def reshape_for_print(self, data):
        """Return ``data`` arranged as a single displayable image (abstract)."""
        raise NotImplementedError

    def load_data(self):
        """Return a ``(train_set, test_set)`` pair (abstract)."""
        raise NotImplementedError

    def show_all_imgs(self, data):
        """Display every image in ``data`` (abstract)."""
        raise NotImplementedError


class CIFAR(DataSet):
    """DataSet backed by the pickled files under ``./data/cifar-100-python``."""

    def load_data(self):
        """Read the train and test pickles and scale their ``data`` to floats.

        Returns:
            ``(train_set, test_set)``, each the ``'data'`` entry of the
            corresponding pickle cast to float32 and divided by 255.
        """
        splits = []
        for path in ('./data/cifar-100-python/train',
                     './data/cifar-100-python/test'):
            with open(path, 'rb') as handle:
                batch = pickle.load(handle, encoding='latin1')
            splits.append(batch['data'].astype(np.float32) / 255.0)
        train_set, test_set = splits
        return train_set, test_set

    def reshape_for_print(self, data):
        """Tile the rows of ``data`` into one (height, width, 3) mosaic.

        ``data`` is indexed as (number of images, 3 * side * side). The grid
        side and image side are truncated square roots of those sizes.
        """
        grid = np.sqrt(data.shape[0]).astype(np.int32)
        side = np.sqrt(data.shape[1] // 3).astype(np.int32)
        # grid row X grid col X channel X image row X image col
        tiles = data.reshape(grid, grid, 3, side, side)
        # grid row X image row X grid col X image col X channel
        mosaic = tiles.transpose(0, 3, 1, 4, 2)
        return mosaic.reshape(grid * side, grid * side, 3)

    def show_all_imgs(self, data):
        """Show every image in ``data`` as one tiled picture."""
        plt.imshow(self.reshape_for_print(data))
        plt.show()


class MNIST(DataSet):
    """DataSet backed by the gzipped pickle ``./data/mnist.pkl.gz``."""

    def load_data(self):
        """Unpickle the three splits and return the first item of train and test."""
        with gzip.open('./data/mnist.pkl.gz', 'rb') as handle:
            splits = pickle.load(handle, encoding='latin1')
        # the pickle holds (train, valid, test); the valid split is not used here
        train_split, _valid_split, test_split = splits
        return train_split[0], test_split[0]

    def reshape_for_print(self, data):
        """Tile the rows of ``data`` into one 2-D mosaic.

        ``data`` is indexed as (number of images, side * side). The grid side
        and image side are truncated square roots of those sizes.
        """
        grid = np.sqrt(data.shape[0]).astype(np.int32)
        side = np.sqrt(data.shape[1]).astype(np.int32)
        # grid row X grid col X image row X image col
        tiles = data.reshape(grid, grid, side, side)
        # grid row X image row X grid col X image col
        mosaic = tiles.transpose(0, 2, 1, 3)
        return mosaic.reshape(grid * side, grid * side)

    def show_all_imgs(self, data):
        """Show every image in ``data`` as one tiled grayscale picture."""
        plt.imshow(self.reshape_for_print(data), cmap=plt.cm.gray)
        plt.show()


回答4:

try using

import matplotlib.pyplot as plt
# NOTE(review): scipy.misc.toimage is deprecated and absent from newer SciPy
# releases — confirm it exists in the installed version before relying on it.
from scipy.misc import toimage
# Convert the array with toimage() before handing it to imshow
plt.imshow(toimage(img))

I am not 100% sure how the code works, but I think that because the images are stored in floating-point NumPy arrays, the imshow() function has a difficult time mapping them to the right colors. By converting them with toimage(), you turn them into the proper image format that imshow() expects, i.e., not an array but an image encoded as .png or .jpg.

This code works for me every time I want to display images in python.



回答5:

I made a function to plot the RGB image from a row in the CIFAR-10 dataset. The image will be blurry at best, since the original size of the image is very small (32px x 32px).

def unpickle(file):
    """Open ``file`` in binary mode and return the object pickled in it.

    The pickle is read with ``encoding='bytes'``.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

# Read each of the five training batches exactly once.
batches = [unpickle('data/data_batch_' + str(i)) for i in range(1, 6)]

# Build the training frame in one concat: DataFrame.append is gone from
# pandas 2.x, and ignore_index gives every row a unique index so the label
# column lines up with the pixel rows.
pd_tr = pd.concat([pd.DataFrame(batch[b'data']) for batch in batches],
                  ignore_index=True)
pd_tr['labels'] = np.concatenate([batch[b'labels'] for batch in batches])

# The first 3072 columns are the pixel values; 'labels' is the last column.
tr_x = np.asarray(pd_tr.iloc[:, :3072])
tr_y = np.asarray(pd_tr['labels'])

# Unpickle the test batch once and reuse it for both data and labels.
test_batch = unpickle('data/test_batch')
ts_x = np.asarray(test_batch[b'data'])
ts_y = np.asarray(test_batch[b'labels'])
labels = unpickle('data/batches.meta')[b'label_names']

def plot_CIFAR(ind):
    """Plot training image ``ind`` with its category name as the title.

    Reads the module-level ``tr_x`` (flat pixel rows), ``tr_y`` (label
    indices) and ``labels`` (label names).

    Args:
        ind: row index into ``tr_x`` / ``tr_y``.
    """
    arr = tr_x[ind]
    # The row holds three consecutive 1024-value planes; scale each to [0, 1].
    R = arr[0:1024].reshape(32, 32) / 255.0
    G = arr[1024:2048].reshape(32, 32) / 255.0
    B = arr[2048:].reshape(32, 32) / 255.0

    img = np.dstack((R, G, B))
    # Label names are unpickled with encoding='bytes'. Decode them instead of
    # regex-stripping their repr, which also deleted every 'b' in the name
    # (e.g. b'bird' became "'ird'").
    raw_label = labels[tr_y[ind]]
    if isinstance(raw_label, bytes):
        title = raw_label.decode('utf-8')
    else:
        title = str(raw_label)
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_subplot(111)
    ax.imshow(img, interpolation='bicubic')
    ax.set_title('Category = ' + title, fontsize=15)


plot_CIFAR(4)


回答6:

This file reads the cifar10 dataset and plots individual images using matplotlib.

import _pickle as pickle
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt

# Directory holding the CIFAR-10 python batch files.
cifar10 = "./cifar-10-batches-py/"

# Command-line interface: one optional integer selecting the image to plot.
parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
parser.add_argument(
    "-i",
    "--image",
    default=0,
    type=int,
    help="Index of the image in cifar10. In range [0, 49999]",
)
args = parser.parse_args()


def unpickle(file):
    """Return the object stored in the pickle at path ``file``.

    The file is opened in binary mode and read with ``encoding='bytes'``.
    """
    with open(file, 'rb') as fo:
        loaded = pickle.load(fo, encoding='bytes')
    return loaded

def cifar10_plot(data, meta, im_idx=0):
    """Print shape, label and category of one image, then display it.

    Args:
        data: batch mapping with ``b'data'`` (2-D pixel array) and
            ``b'labels'`` entries.
        meta: mapping with a ``b'label_names'`` entry.
        im_idx: row of ``data[b'data']`` to plot.
    """
    row = data[b'data'][im_idx, :]

    # Cut the flat row into three 1024-value planes, reshape each to 32x32
    # and stack them along a new last axis.
    bounds = ((0, 1024), (1024, 2048), (2048, None))
    planes = [row[start:stop].reshape(32, 32) for start, stop in bounds]
    img = np.dstack(planes)

    print("shape: ", img.shape)
    label = data[b'labels'][im_idx]
    print("label: ", label)
    print("category:", meta[b'label_names'][label])

    plt.imshow(img)
    plt.show()


def main():
    """Load the batch that contains ``args.image`` and plot that image."""
    # Each batch file holds 10000 images; batch files are numbered from 1.
    batch_offset, idx = divmod(args.image, 10000)
    batch = batch_offset + 1

    batch_path = os.path.join(cifar10, "data_batch_" + str(batch))
    meta_path = os.path.join(cifar10, "batches.meta")
    data = unpickle(batch_path)
    meta = unpickle(meta_path)

    cifar10_plot(data, meta, im_idx=idx)


if __name__ == "__main__":
    main()


回答7:

Add 0.5:

# Move the first axis of img last, then shift every value up by 0.5
# NOTE(review): presumably img was normalized to about [-0.5, 0.5] — confirm
plt.imshow(np.transpose(img, (1, 2, 0)) + 0.5)