Pytorch transforms.RandomRotation() does not work

2020-04-21 01:00发布

Normally i was working on letter&digit recognition on my computer and I wanted to move my project to Colab but unfortunately there was an error (you can see the error below). after some debugging i found which line is giving me error.

transforms.RandomRotation(degrees=(90, -90))

below i wrote simple abstract code to show this error.This code does not work on colab but it works fine at my own computer environment.Problem might be about the different versions of pytorch library i have version 1.3.1 on my computer and colab uses version 1.4.0.

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt   
    transformOpt = transforms.Compose([
            transforms.RandomRotation(degrees=(90, -90)),
            transforms.ToTensor()
        ])

    train_set = datasets.MNIST(
        root='', train=True, transform=transformOpt, download=True)
    test_set = datasets.MNIST(
        root='', train=False, transform=transformOpt, download=True)


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=100,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=100,
        shuffle=False)

    images, labels = next(iter(train_loader))
    plt.imshow(images[0].view(28, 28), cmap="gray")
    plt.show()

The full error I got when I execute this sample code above on Google Colab.

TypeError                                 Traceback (most recent call last)

<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25 
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95 
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98 
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)    1001         angle = self.get_params(self.degrees)    1002 
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)    1004     1005     def
__repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728 
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730 
    731 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)    2003         w, h = nw, nh    2004 
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)    2006     2007     def save(self,    fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)    2297             raise ValueError("missing method data")    2298 
-> 2299         im = new(self.mode, size, fillcolor)    2300         if method == MESH:    2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)    2503         im.palette = ImagePalette.ImagePalette()    2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))    2506     2507 

TypeError: function takes exactly 1 argument (3 given)

1条回答
迷人小祖宗
2楼-- · 2020-04-21 01:26

You're absolutely correct. torchvision 0.5 has a bug in RandomRotation() in the fill argument probably due to incompatible Pillow version. This issue has now been fixed (PR#1760) and will be resolved in the next release.

Temporarily, you add fill=(0,) to RandomRotation transform to fix it.

transforms.RandomRotation(degrees=(90, -90), fill=(0,))
查看更多
登录 后发表回答