How to create a custom PyTorch dataset when the order and the total number of training samples is not known in advance?

Posted 2019-05-18 14:38

Question:

I have a 42 GB jsonl file in which every line is a json object. I create training samples from every json object, but the number of samples extracted from each object can vary between 0 and 5. What is the best way to create a custom PyTorch dataset without reading the entire jsonl file into memory?

This is the dataset I am talking about - Google Natural Questions.

Answer 1:

You have a few options.

  1. The simplest option, if having lots of small files is not a problem, is to preprocess the jsonl file so that each training sample is written to its own file. Then you just read the one file that corresponds to the requested index (a sketch of this preprocessing step is given at the end of this answer). E.g.
    import numpy as np
    from torch.utils.data import Dataset

    class SingleFileDataset(Dataset):
        def __init__(self, list_of_file_paths):
            self.list_of_file_paths = list_of_file_paths

        def __len__(self):
            return len(self.list_of_file_paths)

        def __getitem__(self, index):
            return np.load(self.list_of_file_paths[index])  # or equivalent reading code for a single file
  2. You can also split the data into a constant number of files, and then calculate, given the index, which file the sample resides in. Then you load that file into memory and read the appropriate position, which trades disk access against memory usage. Assume you have n samples and split them evenly into c files during preprocessing, so each file holds samples_per_file = n / c samples. To read the sample at index i we would do
    import numpy as np
    from torch.utils.data import Dataset

    class SplitIntoFilesDataset(Dataset):
        def __init__(self, list_of_file_paths, samples_per_file):
            self.list_of_file_paths = list_of_file_paths
            self.samples_per_file = samples_per_file  # n / c samples stored in each file

        def __len__(self):
            return len(self.list_of_file_paths) * self.samples_per_file

        def __getitem__(self, index):
            # index // samples_per_file is the relevant file, and
            # index % samples_per_file is the index within that file
            file_to_load = self.list_of_file_paths[index // self.samples_per_file]
            file_contents = np.load(file_to_load)
            return file_contents[index % self.samples_per_file]
  3. Finally, you could use an HDF5 file, which allows row-wise access to data on disk. This is possibly the best solution if you have a lot of data, since the data will be stored close together on disk. There's an implementation here which I have copy-pasted below (a sketch for converting the jsonl file into HDF5 is given at the end of this answer):

    import h5py
    import torch
    import torch.utils.data as data

    class H5Dataset(data.Dataset):

        def __init__(self, file_path):
            super(H5Dataset, self).__init__()
            h5_file = h5py.File(file_path, 'r')   # open read-only
            self.data = h5_file.get('data')       # assumed to be a 4-D dataset
            self.target = h5_file.get('label')    # assumed to be a 4-D dataset

        def __getitem__(self, index):
            return (torch.from_numpy(self.data[index, :, :, :]).float(),
                    torch.from_numpy(self.target[index, :, :, :]).float())

        def __len__(self):
            return self.data.shape[0]
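
For option 1, the preprocessing pass could look roughly like the sketch below. It streams the 42 GB jsonl file line by line and writes one .npy file per extracted training sample, so nothing is ever fully loaded into memory. Here extract_samples is a hypothetical placeholder for whatever code turns one Natural Questions json object into its 0 to 5 feature arrays, and the paths are placeholders too.

    import json
    from pathlib import Path
    import numpy as np

    def extract_samples(obj):
        # Hypothetical helper: yields 0 to 5 numpy arrays for one json object.
        raise NotImplementedError

    def preprocess_to_files(jsonl_path, out_dir):
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        file_paths = []
        with open(jsonl_path) as f:               # iterating streams the file
            for line_no, line in enumerate(f):
                obj = json.loads(line)
                for k, sample in enumerate(extract_samples(obj)):
                    path = out_dir / f"{line_no}_{k}.npy"
                    np.save(path, sample)
                    file_paths.append(str(path))
        return file_paths                          # pass this list to SingleFileDataset

You would typically save the returned list of paths as well, so this expensive pass only has to run once.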
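For option 3, you first need to turn the jsonl file into an HDF5 file with 'data' and 'label' datasets matching what H5Dataset above expects. A minimal sketch, assuming every sample can be represented as a fixed-shape float array plus a fixed-shape label array, again using the hypothetical extract_samples helper (here assumed to yield (data, label) pairs):

    import json
    import h5py
    import numpy as np

    def jsonl_to_hdf5(jsonl_path, h5_path, sample_shape, label_shape):
        with h5py.File(h5_path, 'w') as h5, open(jsonl_path) as f:
            data = h5.create_dataset('data', shape=(0, *sample_shape),
                                     maxshape=(None, *sample_shape),
                                     dtype='float32', chunks=True)
            label = h5.create_dataset('label', shape=(0, *label_shape),
                                      maxshape=(None, *label_shape),
                                      dtype='float32', chunks=True)
            for line in f:
                for x, y in extract_samples(json.loads(line)):
                    n = data.shape[0]
                    data.resize(n + 1, axis=0)     # grow both datasets by one row
                    label.resize(n + 1, axis=0)
                    data[n], label[n] = x, y

In practice you would buffer samples and resize in larger blocks rather than one row at a time, but the append-one-row version keeps the sketch short.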
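Whichever dataset class you end up with, it plugs into a DataLoader as usual. A minimal usage sketch (the file name is a placeholder); num_workers is kept at 0 because the h5py file handle opened in __init__ does not always survive being shared with worker processes:

    from torch.utils.data import DataLoader

    dataset = H5Dataset('nq_train.h5')     # placeholder path
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

    for batch_data, batch_labels in loader:
        pass  # training step goes here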