The PyTorch tutorial on data loading and processing is quite specific to one example. Could someone help me with what the function should look like for a more generic, simple loading of images?
Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
My Data:
I have the MNIST dataset as PNG images in the following folder structure. (I know I can just use the built-in dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)
The folder name is the label and the images are 28x28 greyscale PNGs; no transformations are required.
data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...
If you're using MNIST, there's already a preset in PyTorch via torchvision.
You could do:
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
# MNIST images have a single channel, so Normalize takes one mean and one std
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                               shuffle=True, num_workers=2)
If you want to generalize to a directory of images (same imports as above, plus os and PIL), you could do:
import os
from PIL import Image

# 3-channel normalization, assuming RGB images such as MNIST-M
transformMnistm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

class mnistmTrainingDataset(torch.utils.data.Dataset):
    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file(string): path to text file with one "image_name label" pair per line
            root_dir(string): directory with all train images
            transform(callable): applied to each PIL image
        """
        # header=None assumes the text file has no header row
        self.name_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1))
        self.label_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1, 2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name)
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        sample = {'image': image, 'labels': labels}
        return sample

mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')
mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16, shuffle=True, num_workers=2)
You can then iterate over it like:
for i_batch, sample_batched in enumerate(mnistmTrainLoader, 0):
    print("training sample for mnist-m")
    print(i_batch, sample_batched['image'], sample_batched['labels'])
There are a bunch of ways to generalize image dataset loading in PyTorch; the method that I know of is subclassing torch.utils.data.Dataset.
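The Dataset above reads a text file with one "image_name label" pair per line, which the folder layout in the question does not have. As a rough sketch, such a file could be generated from a folder-per-label tree with the standard library alone. The helper name `write_label_file` and the accepted file extensions are assumptions made for illustration, and the space separator assumes file names contain no spaces.

```python
import os


def write_label_file(root_dir, text_file):
    """Write one "relative/path label" line per image found under root_dir.

    Assumes root_dir contains one sub-folder per class, named after the label
    (for example data/train/0 ... data/train/9). Returns the number of lines written.
    """
    count = 0
    with open(text_file, "w") as out:
        for label in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, label)
            if not os.path.isdir(class_dir):
                continue  # skip stray files sitting next to the class folders
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.lower().endswith((".png", ".jpg", ".jpeg")):
                    # Forward slashes keep the file readable on any OS
                    out.write(f"{label}/{file_name} {label}\n")
                    count += 1
    return count
```

Calling `write_label_file('data/train', 'train_labels.txt')` and then passing `text_file='train_labels.txt'` and `root_dir='data/train'` to the Dataset should line up, since the Dataset joins each name onto `root_dir`.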
Here's what I did for PyTorch 0.4.1:
import torch
import torchvision

def load_dataset():
    data_path = 'data/train/'
    # ImageFolder uses each sub-folder name as the class label
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader

for batch_idx, (data, target) in enumerate(load_dataset()):
    # train network
    pass