Data loading with variable batch size?

Posted 2020-08-04 04:34

Question:

I am currently working on patch-based super-resolution. Most papers divide an image into smaller patches and then use those patches as input to the model. I was able to create the patches with a custom Dataset class; the code is given below:

import torch.utils.data as data
from torchvision.transforms import CenterCrop, ToTensor, Compose, ToPILImage, Resize, RandomHorizontalFlip, RandomVerticalFlip
from os import listdir
from os.path import join
from PIL import Image
import random
import os
import numpy as np
import torch

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"])

class TrainDatasetFromFolder(data.Dataset):
    def __init__(self, dataset_dir, patch_size, is_gray, stride):
        super(TrainDatasetFromFolder, self).__init__()
        self.imageHrfilenames = []
        self.imageHrfilenames.extend(join(dataset_dir, x)
                                     for x in sorted(listdir(dataset_dir)) if is_image_file(x))
        self.is_gray = is_gray
        self.patchSize = patch_size
        self.stride = stride

    def _load_file(self, index):
        filename = self.imageHrfilenames[index]
        hr = Image.open(filename)
        # Downscale the HR image by a fixed factor (0.45 here) before augmentation.
        downsizes = (1, 0.7, 0.45)
        downsize = 2
        w_ = int(hr.width * downsizes[downsize])
        h_ = int(hr.height * downsizes[downsize])
        aug = Compose([Resize([h_, w_], interpolation=Image.BICUBIC),
                       RandomHorizontalFlip(),
                       RandomVerticalFlip()])

        hr = aug(hr)
        # Random rotation by a multiple of 90 degrees.
        rv = random.randint(0, 3)
        hr = hr.rotate(90 * rv, expand=1)
        filename = os.path.splitext(os.path.split(filename)[-1])[0]
        return hr, filename

    def _patching(self, img):
        img = ToTensor()(img)
        # LR_ produces the low-resolution counterpart of a patch (bicubic downscale by 2).
        LR_ = Compose([ToPILImage(), Resize(self.patchSize // 2, interpolation=Image.BICUBIC), ToTensor()])

        # Slide a window of size patchSize over the image with the given stride
        # and collect the HR patches together with their LR versions.
        HR_p, LR_p = [], []
        for i in range(0, img.shape[1] - self.patchSize, self.stride):
            for j in range(0, img.shape[2] - self.patchSize, self.stride):
                temp = img[:, i:i + self.patchSize, j:j + self.patchSize]
                HR_p += [temp]
                LR_p += [LR_(temp)]

        return torch.stack(LR_p), torch.stack(HR_p)

    def __getitem__(self, index):
        HR_, filename = self._load_file(index)
        LR_p, HR_p = self._patching(HR_)
        return LR_p, HR_p

    def __len__(self):
        return len(self.imageHrfilenames)

Suppose the batch size is 1: the loader takes one image and gives an output of size [x, 3, patchsize, patchsize]. When the batch size is 2, I get two outputs of that form with different x (for example, image 1 may give [50, 3, patchsize, patchsize] while image 2 gives [75, 3, patchsize, patchsize]). To handle this, a custom collate function is required that concatenates these outputs along dimension 0. The collate function is given below:

def my_collate(batch):
    # Concatenate the per-image patch stacks along dimension 0.
    data = torch.cat([item[0] for item in batch], dim=0)
    target = torch.cat([item[1] for item in batch], dim=0)
    return [data, target]
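
For reference, this is roughly how I wire the dataset and the collate function together (the directory path and the patch/stride values below are just placeholders):

from torch.utils.data import DataLoader

dataset = TrainDatasetFromFolder("path/to/hr_images", patch_size=64, is_gray=False, stride=32)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collate)

for lr_batch, hr_batch in loader:
    # lr_batch: [x, 3, 32, 32], hr_batch: [x, 3, 64, 64], where x is the total
    # number of patches extracted from the 2 images drawn for this batch.
    pass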

The collate function concatenates along x: from the above example, I finally get a tensor of size [125, 3, patchsize, patchsize]. For training, however, I need to use a minibatch size of, say, 25. Is there any method or function I can use to get an output of size [25, 3, patchsize, patchsize] directly from the dataloader, using as many images as needed as input?

Answer 1:

The following code snippet works for your purpose.

First, we define a ToyDataset which takes in a list of tensors (tensors) whose lengths vary along dimension 0. This mimics the samples returned by your dataset.

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler

class ToyDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

Secondly, we define a custom data loader. The usual PyTorch division of labour between datasets and data loaders is roughly the following: there is an indexed dataset, to which you can pass an index and which returns the associated sample; there is a sampler, which yields indices (different drawing strategies give rise to different samplers); the sampler is used by a batch_sampler to draw multiple indices at once (as many as specified by batch_size); and there is a data loader, which combines sampler and dataset to let you iterate over the dataset. Importantly, the data loader also owns a function (collate_fn) that specifies how the samples retrieved with the indices from the batch_sampler should be combined into one batch.

For your use case this standard pattern does not work well, because instead of drawing a fixed number of indices we need to keep drawing indices until the objects associated with them reach the cumulative size we want. That means we must inspect the loaded objects right away and use that information to decide whether to return a batch or keep drawing.
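
For illustration, here is a minimal sketch of that standard pipeline, using the ToyDataset defined above on some made-up data (the sizes and the lambda collate are only for demonstration; the collate simply concatenates, like your my_collate):

import torch
from torch.utils.data import DataLoader, RandomSampler, BatchSampler

# Toy data: 10 tensors with random sizes in dimension 0.
toy_sizes = torch.randint(1, 8, (10,))
toy_tensors = torch.split(torch.randn(int(toy_sizes.sum()), 3, 5, 5), toy_sizes.tolist())
toy_ds = ToyDataset(list(toy_tensors))

# Standard pipeline: sampler -> batch_sampler -> dataset lookups -> collate_fn.
sampler = RandomSampler(toy_ds)                                         # draws single indices
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)    # groups them into lists of 4
loader = DataLoader(toy_ds,
                    batch_sampler=batch_sampler,
                    collate_fn=lambda samples: torch.cat(samples, dim=0))

for batch in loader:
    # batch.size(0) varies from iteration to iteration, because exactly 4 indices
    # are drawn no matter how many rows their samples contribute.
    print(batch.size(0))

With that in mind, this is what the custom data loader below does: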

class CustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        batch = torch.Tensor()
        for idx in self.sampler:
            # Append the rows of the next sample to the buffer.
            batch = torch.cat([batch, self.ds[idx]])
            # Emit full batches for as long as the buffer holds enough rows
            # (a single sample may be larger than my_bsz, hence the while loop).
            while batch.size(0) >= self.my_bsz:
                if batch.size(0) == self.my_bsz:
                    yield batch
                    batch = torch.Tensor()
                else:
                    return_batch, batch = batch.split([self.my_bsz, batch.size(0) - self.my_bsz])
                    yield return_batch
        # Whatever is left over is an incomplete batch.
        if batch.size(0) > 0 and not self.drop_last:
            yield batch

Here we iterate over the dataset: after drawing an index and loading the associated object, we concatenate it to the tensors drawn before (batch). We keep doing this until the buffer reaches the desired size, at which point we cut out a batch, yield it, and retain in batch the rows we did not yield. Because a single instance may exceed the desired batch_size, we use a while loop.

You could modify this minimal CustomLoader to add more features in the style of PyTorch's DataLoader. There is also no need to use a RandomSampler to draw the indices; other samplers would work equally well. If your data is large, you could also avoid the repeated concatenations, for example by buffering the samples in a list and keeping track of the cumulative number of rows it holds; a sketch of that variant follows.
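
For instance, a buffered variant might look like the following sketch (the class name BufferedCustomLoader is just illustrative; it collects samples in a list and only concatenates when enough rows have accumulated):

class BufferedCustomLoader(object):
    """Like CustomLoader, but buffers samples in a list and concatenates
    only when a full batch can be emitted."""

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        buffer, buffered_rows = [], 0
        for idx in self.sampler:
            sample = self.ds[idx]
            buffer.append(sample)
            buffered_rows += sample.size(0)
            # Emit as many full batches as the buffer currently holds.
            while buffered_rows >= self.my_bsz:
                merged = torch.cat(buffer)
                batch, rest = merged[:self.my_bsz], merged[self.my_bsz:]
                yield batch
                buffer = [rest] if rest.size(0) > 0 else []
                buffered_rows = rest.size(0)
        if buffered_rows > 0 and not self.drop_last:
            yield torch.cat(buffer)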

Here is an example that demonstrates it works:

patch_size = 5
channels = 3
# 100 "images", each contributing between 1 and 99 patches.
dim0sizes = torch.LongTensor(100).random_(1, 100)
data = torch.randn(size=(int(dim0sizes.sum()), channels, patch_size, patch_size))
tensors = torch.split(data, dim0sizes.tolist())

ds = ToyDataset(tensors)
dl = CustomLoader(ds, my_bsz=250, drop_last=False)
for i in dl:
    print(i.size(0))


Answer 2:

(Related, but not exactly on topic)

For batch-size adaptation you can use the code exemplified in this repo. It was implemented for a different purpose (maximizing GPU memory usage), but it is not too hard to adapt to your problem.

The code does batch adaptation and batch spoofing.
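
"Batch spoofing" usually refers to simulating a larger batch by accumulating gradients over several smaller ones before each optimizer step. A generic sketch of that idea (model, loss_fn, optimizer and small_batch_loader are hypothetical placeholders for your own objects, not code from that repo):

accumulation_steps = 4  # pretend 4 small batches are one large batch
optimizer.zero_grad()
for step, (lr_patches, hr_patches) in enumerate(small_batch_loader):
    loss = loss_fn(model(lr_patches), hr_patches) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across iterations
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one update as if a 4x larger batch had been used
        optimizer.zero_grad()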