How to use `torch.utils.data.Dataset` and `torch.utils.data.DataLoader` on your own data (not just the `torchvision.datasets`)?

Is there a way to use the built-in `DataLoader`s that are used with the torchvision datasets on any dataset?
Yes, that is possible. Just create the objects yourself, e.g. wrap your `features` and `targets` tensors in a `torch.utils.data.TensorDataset` and pass that dataset to a `torch.utils.data.DataLoader`. In the simplest case, `features` is 2-D, i.e. a matrix where each row represents one training sample, and `targets` may be 1-D or 2-D, depending on whether you are trying to predict a scalar or a vector. Hope that helps!
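A minimal sketch of that setup, assuming randomly generated toy tensors (the sizes 100x5, the batch size of 10 and the names `train`/`train_loader` are illustrative choices, not taken from the original answer):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 100 samples with 5 features each (one row per sample)
features = torch.randn(100, 5)
# One scalar target per sample, so targets is 1-D
targets = torch.randn(100)

# TensorDataset pairs row i of features with entry i of targets
train = TensorDataset(features, targets)

# DataLoader handles batching and shuffling; batch_size=10 is an arbitrary choice
train_loader = DataLoader(train, batch_size=10, shuffle=True)

for batch_features, batch_targets in train_loader:
    # Each batch holds 10 samples
    print(batch_features.shape, batch_targets.shape)  # torch.Size([10, 5]) torch.Size([10])
    break
```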
EDIT: response to @sarthak's question
Basically yes. If you create an object of type `TensorDataset`, the constructor only checks whether the first dimensions of the feature tensor (which is actually called `data_tensor`) and the target tensor (called `target_tensor`) have the same length. (Newer PyTorch versions take any number of tensors, `TensorDataset(*tensors)`, and apply the same first-dimension check to all of them.)

However, if you want to feed these data into a neural network subsequently, you need to be careful. Convolution layers work on data like yours (note that `nn.Conv2d` expects the channels first, i.e. 5000x3xnxn, so you may need to `permute` your tensor), but (I think) the other common layer types, e.g. fully connected layers, expect each sample as a flat feature vector, i.e. the whole dataset in matrix form. So, if you run into an issue like this, an easy solution is to convert your 4-D dataset (given as some kind of tensor, e.g. a `FloatTensor`) into a matrix by using the method `view`. For your 5000xnxnx3 dataset, this would look like `features.view(5000, -1)`. (The value `-1` tells PyTorch to figure out the length of the second dimension automatically, here n·n·3.)

In addition to user3693922's answer and the accepted answer, which respectively link to the "quick" PyTorch documentation example for creating custom dataloaders for custom datasets and create a custom dataloader in the "simplest" case, there is a much more detailed, dedicated official PyTorch tutorial on how to create a custom dataloader with the associated preprocessing: the "Writing Custom Datasets, DataLoaders and Transforms" official PyTorch tutorial.
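To illustrate the EDIT above (the first-dimension check in `TensorDataset` and flattening with `view`), here is a sketch using random stand-in data shaped like the 5000xnxnx3 dataset, with a hypothetical n = 8 and hypothetical integer labels. It assumes a recent PyTorch, where `TensorDataset` rejects tensors with different first dimensions via an `AssertionError`:

```python
import torch
from torch.utils.data import TensorDataset

n = 8  # hypothetical image side length
# Hypothetical stand-in for the 5000 x n x n x 3 dataset from the question
images = torch.rand(5000, n, n, 3)
labels = torch.randint(0, 10, (5000,))

# Works: only the first dimension (5000) has to match, the rest can be anything
image_dataset = TensorDataset(images, labels)
print(image_dataset[0][0].shape)  # torch.Size([8, 8, 3])

# Fails: the first dimensions differ (5000 vs. 4999)
try:
    TensorDataset(images, labels[:-1])
    mismatch_rejected = False
except AssertionError:
    mismatch_rejected = True

# Flatten each sample into one row; -1 lets PyTorch infer n * n * 3 = 192
flat_images = images.view(5000, -1)
print(flat_images.shape)  # torch.Size([5000, 192])
```

If `view` complains about a non-contiguous tensor (e.g. after a `permute`), `reshape(5000, -1)` does the same flattening and only copies the data when necessary.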
You can easily do this by extending the `data.Dataset` class. According to the API, all you have to do is implement two functions: `__getitem__` and `__len__`. You can then wrap the dataset with a `DataLoader`, as shown in the API and in @pho7's answer.
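A minimal sketch of such a subclass. `SquaresDataset` and its toy data are hypothetical; the optional `transform` argument only mimics the convention used by torchvision datasets and the official tutorial mentioned above:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical dataset: sample i is the number i, its target is i squared."""

    def __init__(self, size, transform=None):
        self.size = size
        # Optional callable applied to each sample, in the spirit of torchvision transforms
        self.transform = transform

    def __len__(self):
        # DataLoader (via its default sampler) uses this to know which indices exist
        return self.size

    def __getitem__(self, idx):
        # Build one (sample, target) pair on demand for the given index
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        if self.transform is not None:
            x = self.transform(x)
        return x, y


# Scale inputs into [0, 1) with a simple lambda transform
dataset = SquaresDataset(size=20, transform=lambda x: x / 20.0)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks four (1,)-shaped samples into a (4, 1) batch
xs, ys = next(iter(loader))
print(xs.squeeze(1), ys.squeeze(1))
```

Because `__getitem__` builds each sample on demand, the same pattern works when samples are loaded lazily from disk instead of being held in memory.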
I think the `ImageFolder` class is a good reference implementation. See its code here.