How do you load images into a PyTorch DataLoader?

Posted 2019-03-21 11:39

Question:

The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?

Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

My Data:

I have the MNIST dataset as PNGs in the following folder structure. (I know I can just use the built-in dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)

The folder name is the label and the images are 28x28 greyscale PNGs; no transformations are required.

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...

Answer 1:

If you're using MNIST, there's already a preset in PyTorch via torchvision.
You could do:

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd

transform = transforms.Compose(
    [transforms.ToTensor(),
     # MNIST images have a single channel, so Normalize takes one mean and one std
     transforms.Normalize((0.5,), (0.5,))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                      shuffle=True, num_workers=2)

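If you want to generalize to a directory of images listed in a label file (same imports as above, plus os and PIL), you could do:

import os
from PIL import Image

# Assumed definition: the original snippet used transformMnistm without defining it.
# MNIST-M images are RGB, hence three means and three stds.
transformMnistm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file (string): path to a text file with one "<image name> <label>" pair per line
            root_dir (string): directory with all train images
            transform (callable): applied to each PIL image
        """
        # Column 0 holds the file names, column 1 the labels.
        # header=None assumes the file has no header row.
        self.name_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=[0])
        self.label_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=[1])
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        sample = {'image': image, 'labels': labels}

        return sample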


mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16,
                                                shuffle=True, num_workers=2)

You can then iterate over it like:

for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
    print("training sample for mnist-m")
    print(i_batch,sample_batched['image'],sample_batched['labels'])

There are several ways to load image datasets in PyTorch; the one I know is subclassing torch.utils.data.Dataset.



Answer 2:

Here's what I did for PyTorch 0.4.1:

import torch
import torchvision

def load_dataset():
    data_path = 'data/train/'
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    # train network
    pass