I have a 42 GB JSONL file. Every line of the file is a JSON object, and I create training samples from each object. The number of samples I extract from one object varies from 0 to 5. What is the best way to create a custom PyTorch dataset without reading the entire JSONL file into memory?
This is the dataset I am talking about - Google Natural Questions.
You have a few options.
- The simplest option, if having lots of small files is not a problem, is to preprocess each JSON object into a single file. Then you can just read the file that corresponds to the requested index. For example:

    import numpy as np
    from torch.utils.data import Dataset

    class SingleFileDataset(Dataset):
        def __init__(self, list_of_file_paths):
            self.list_of_file_paths = list_of_file_paths

        def __len__(self):
            return len(self.list_of_file_paths)

        def __getitem__(self, index):
            # Or equivalent reading code for a single file
            return np.load(self.list_of_file_paths[index])
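The preprocessing step for this option could look roughly like the sketch below. It streams the JSONL file one line at a time, so the 42 GB file is never loaded at once. Because each JSON object yields 0 to 5 samples, this sketch writes one file per sample rather than one per object, so that a dataset index maps directly to a sample. `extract_samples` and the `"samples"` key are hypothetical placeholders for your own extraction logic, and the `.npy` format is an assumption.

```python
import json
import os

import numpy as np


def extract_samples(obj):
    """Hypothetical placeholder: turn one JSON object into 0 to 5 samples."""
    return [np.asarray(x, dtype=np.float32) for x in obj.get("samples", [])]


def preprocess_to_single_files(jsonl_path, out_dir):
    """Stream the JSONL file and save every extracted sample to its own .npy file."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:  # iterating the file object reads one line at a time
            line = line.strip()
            if not line:
                continue
            for sample in extract_samples(json.loads(line)):
                path = os.path.join(out_dir, f"sample_{len(paths):09d}.npy")
                np.save(path, sample)
                paths.append(path)
    return paths
```

The returned list of paths can be passed straight to a dataset that loads one file per index. Objects that produce no samples simply contribute no files.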
- You can also split the data into a constant number of files, and then calculate, given the index, which file the sample resides in. Then you need to load that file into memory and read the appropriate index. This gives a trade-off between disk access and memory usage. Assume you have n samples, and we split the samples into c files evenly during preprocessing. Now, to read the sample at index i, we would do:
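    class SplitIntoFilesDataset(Dataset):
        def __init__(self, list_of_file_paths, samples_per_file):
            self.list_of_file_paths = list_of_file_paths
            # samples_per_file is n / c: every file holds the same number of samples
            self.samples_per_file = samples_per_file

        def __len__(self):
            return self.samples_per_file * len(self.list_of_file_paths)

        def __getitem__(self, index):
            # index // samples_per_file is the relevant file, and
            # index % samples_per_file is the index within that file
            file_to_load = self.list_of_file_paths[index // self.samples_per_file]
            # Load the whole file into memory
            data = np.load(file_to_load)
            return data[index % self.samples_per_file]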
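As a worked example of the index arithmetic: with n = 8 samples split evenly into c = 2 files of 4 samples each, index 5 lives in file 5 // 4 = 1 at position 5 % 4 = 1. If the files end up holding different numbers of samples (which can happen when each JSON object yields 0 to 5 samples), the same lookup can be done with cumulative counts. The sketch below is one way to do that; `locate` and `samples_per_file` are hypothetical names, and it assumes you recorded the per-file counts during preprocessing.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(index, samples_per_file):
    """Map a global sample index to (file_index, index_within_file).

    samples_per_file is a list with the number of samples stored in each file.
    """
    ends = list(accumulate(samples_per_file))  # cumulative counts, e.g. [3, 6, 8]
    if not ends or not 0 <= index < ends[-1]:
        raise IndexError(f"sample index {index} is out of range")
    # First file whose cumulative count is strictly greater than index
    file_idx = bisect_right(ends, index)
    start = ends[file_idx - 1] if file_idx > 0 else 0
    return file_idx, index - start
```

For evenly sized files this gives the same result as the `//` and `%` arithmetic. In a real dataset you would compute the cumulative counts once in `__init__` rather than on every lookup.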
- Finally, you could use an HDF5 file, which allows access to individual rows on disk without loading the whole file. This is possibly the best solution if you have a lot of data, since the data will be stored close together on disk. There's an implementation here which I have copied below:
    import h5py
    import torch
    import torch.utils.data as data

    class H5Dataset(data.Dataset):
        def __init__(self, file_path):
            super(H5Dataset, self).__init__()
            # Open read-only; rows are read from disk on demand
            h5_file = h5py.File(file_path, 'r')
            self.data = h5_file.get('data')
            self.target = h5_file.get('label')

        def __getitem__(self, index):
            # Assumes 4-D 'data' and 'label' datasets indexed along the first axis
            return (torch.from_numpy(self.data[index,:,:,:]).float(),
                    torch.from_numpy(self.target[index,:,:,:]).float())

        def __len__(self):
            return self.data.shape[0]
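To build such an HDF5 file without holding all samples in memory, one possibility is to create a resizable dataset and append to it batch by batch while streaming the JSONL file. The following is a minimal sketch under assumptions: `write_h5_in_chunks` is a hypothetical helper, all samples are assumed to share one fixed shape and a float32 dtype, and only a `'data'` dataset is written (a `'label'` dataset could be handled the same way).

```python
import h5py
import numpy as np


def write_h5_in_chunks(batches, out_path, sample_shape):
    """Append batches of samples to a resizable HDF5 dataset named 'data'.

    batches is any iterable yielding arrays of shape (k, *sample_shape), so it
    can be a generator that reads the JSONL file lazily.
    """
    with h5py.File(out_path, "w") as f:
        dset = f.create_dataset(
            "data",
            shape=(0, *sample_shape),
            maxshape=(None, *sample_shape),  # None lets the first axis grow
            dtype="float32",
            chunks=True,
        )
        for batch in batches:
            batch = np.asarray(batch, dtype=np.float32)
            if batch.shape[0] == 0:
                continue  # a JSON object may yield no samples
            start = dset.shape[0]
            dset.resize(start + batch.shape[0], axis=0)
            dset[start:] = batch
    return out_path
```

Only one batch is in memory at a time, so peak memory depends on the batch size rather than on the size of the source file.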