How do I turn a Pytorch Dataloader into a numpy ar

2019-05-21 04:06发布

问题:

I am new to Pytorch. I have been trying to learn how to view my input images before I begin training on my CNN. I am having a very hard time changing the images into a form that can be used with matplotlib.

So far I have tried this:

from multiprocessing import freeze_support

import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np
import PIL

num_classes = 5
batch_size = 100
num_of_workers = 5

DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'
DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToPImage(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
    ])

train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = dataiter.next()

    # show images
    imshow(images)
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

if __name__ == "__main__":
    main()

However, this throws and error:

  [[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
    0.27450982]
   [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
    0.40784314]
   [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
    0.34509805]
   ...
   [0.2784314  0.21960783 0.2352941  ... 0.5803922  0.46666667
    0.25882354]
   [0.26666668 0.16862744 0.23137254 ... 0.2901961  0.29803923
    0.2509804 ]
   [0.30980393 0.39607844 0.28627452 ... 0.1490196  0.10588235
    0.19607842]]

  [[0.2352941  0.06274509 0.15686274 ... 0.09411764 0.3019608
    0.19215685]
   [0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
    0.3019608 ]
   [0.20392156 0.13333333 0.1607843  ... 0.16862744 0.2117647
    0.22745097]
   ...
   [0.18039215 0.16862744 0.1490196  ... 0.45882353 0.36078432
    0.16470587]
   [0.1607843  0.10588235 0.14117646 ... 0.2117647  0.18039215
    0.10980392]
   [0.18039215 0.3019608  0.2117647  ... 0.11372548 0.06274509
    0.04705882]]]


 ...


 [[[0.8980392  0.8784314  0.8509804  ... 0.627451   0.627451
    0.627451  ]
   [0.8509804  0.8235294  0.7921569  ... 0.54901963 0.5568628
    0.56078434]
   [0.7921569  0.7529412  0.7176471  ... 0.47058824 0.48235294
    0.49411765]
   ...
   [0.3764706  0.38431373 0.3764706  ... 0.4509804  0.43137255
    0.39607844]
   [0.38431373 0.39607844 0.3882353  ... 0.4509804  0.43137255
    0.39607844]
   [0.3882353  0.4        0.39607844 ... 0.44313726 0.42352942
    0.39215687]]

  [[0.9254902  0.90588236 0.88235295 ... 0.60784316 0.6
    0.5921569 ]
   [0.88235295 0.85490197 0.8235294  ... 0.5411765  0.5372549
    0.53333336]
   [0.8235294  0.7882353  0.75686276 ... 0.47058824 0.47058824
    0.47058824]
   ...
   [0.50980395 0.5176471  0.5137255  ... 0.58431375 0.5647059
    0.53333336]
   [0.5137255  0.53333336 0.5254902  ... 0.58431375 0.5686275
    0.53333336]
   [0.5176471  0.53333336 0.5294118  ... 0.5764706  0.56078434
    0.5294118 ]]

  [[0.95686275 0.9372549  0.90588236 ... 0.18823528 0.19999999
    0.20784312]
   [0.9098039  0.8784314  0.8352941  ... 0.1607843  0.17254901
    0.18039215]
   [0.84313726 0.7921569  0.7490196  ... 0.1372549  0.14509803
    0.15294117]
   ...
   [0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
    0.02745098]
   [0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
    0.03529412]
   [0.05098039 0.0745098  0.07843137 ... 0.12549019 0.10196078
    0.04705882]]]


 [[[0.30588236 0.28627452 0.24313724 ... 0.2901961  0.26666668
    0.21568626]
   [0.8156863  0.6666667  0.5921569  ... 0.18039215 0.23921567
    0.21568626]
   [0.9019608  0.83137256 0.85490197 ... 0.21960783 0.36862746
    0.23921567]
   ...
   [0.7058824  0.83137256 0.85490197 ... 0.2627451  0.24313724
    0.20784312]
   [0.7137255  0.84313726 0.84705883 ... 0.26666668 0.29803923
    0.21568626]
   [0.7254902  0.8235294  0.8392157  ... 0.2509804  0.27058825
    0.2352941 ]]

  [[0.24705881 0.22745097 0.19215685 ... 0.2784314  0.25490198
    0.19607842]
   [0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
    0.20392156]
   [0.5921569  0.4509804  0.49803922 ... 0.20784312 0.3764706
    0.2352941 ]
   ...
   [0.42352942 0.4627451  0.42352942 ... 0.23921567 0.23137254
    0.19999999]
   [0.45882353 0.5176471  0.35686275 ... 0.23921567 0.26666668
    0.19607842]
   [0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
    0.21568626]]

  [[0.23137254 0.20784312 0.1490196  ... 0.30588236 0.28627452
    0.19607842]
   [0.61960787 0.3764706  0.26666668 ... 0.16470587 0.24313724
    0.21568626]
   [0.57254905 0.43137255 0.48235294 ... 0.2235294  0.40392157
    0.25882354]
   ...
   [0.4        0.42352942 0.37254903 ... 0.25490198 0.24705881
    0.21568626]
   [0.43137255 0.4509804  0.29411766 ... 0.25882354 0.28235295
    0.20392156]
   [0.38431373 0.3529412  0.25490198 ... 0.2352941  0.25490198
    0.23137254]]]


 [[[0.06274509 0.09019607 0.11372548 ... 0.5803922  0.5176471
    0.59607846]
   [0.09411764 0.14509803 0.1372549  ... 0.5294118  0.49803922
    0.5058824 ]
   [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
    0.38431373]
   ...
   [0.15294117 0.12941176 0.1607843  ... 0.85882354 0.8509804
    0.80784315]
   [0.14509803 0.10588235 0.1607843  ... 0.8666667  0.85882354
    0.8       ]
   [0.1490196  0.10588235 0.16470587 ... 0.827451   0.8156863
    0.7921569 ]]

  [[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
    0.6039216 ]
   [0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
    0.5372549 ]
   [0.03921568 0.0745098  0.09803921 ... 0.48235294 0.4392157
    0.4117647 ]
   ...
   [0.2117647  0.14509803 0.2784314  ... 0.43137255 0.3529412
    0.34117648]
   [0.2235294  0.11372548 0.2509804  ... 0.4509804  0.39607844
    0.2509804 ]
   [0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
    0.3254902 ]]

  [[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
    0.45490196]
   [0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
    0.3882353 ]
   [0.03921568 0.06666666 0.0862745  ... 0.3764706  0.33333334
    0.28235295]
   ...
   [0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
    0.09411764]
   [0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
    0.05490196]
   [0.12156862 0.11372548 0.17254901 ... 0.2352941  0.17254901
    0.14117646]]]]
Traceback (most recent call last):
  File "image_loader.py", line 51, in <module>
    main()
  File "image_loader.py", line 46, in main
    imshow(images)
  File "image_loader.py", line 38, in imshow
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose

I tried to print out the arrays to get the dimensions but I do not know what to make of this. It is very confusing.

Here is my direct question: How do I view the input images before training using the tensors in my DataLoader object?

回答1:

First of all, dataloader output 4 dimensional tensor - [batch, channel, height, width]. Matplotlib and other image processing libraries often requires [height, width, channel]. You are right about using the transpose, just not in the right way.

There will be a lot of images in your images so first you need to pick one (or write a for loop to save all of them). This will be simply images[i], typically I use i=0.

Then, your transpose should convert a now [channel, height, width] tensor to a [height, width, channel] one. To do this, use np.transpose(image.numpy(), (1, 2, 0)), very much like yours.

Putting them together, you should have

plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))

Sometimes you need to call .detach() (detach this part from the computational graph) and .cpu() (transfer data from GPU to CPU) depending on the use case, that will be

plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))