Source code for utils.data

"""Default Dataset, DataLoader similar to `utils.data` in PyTorch.

You can also use those provided by PyTorch or huggingface/datasets."""

from typing import Union, Tuple, List, Sequence, TypeVar, Generic
import numpy as np
import random

T_co = TypeVar("T_co", covariant=True)


# TODO: consider using torch.utils.data directly in the future
class Dataset(Generic[T_co]):
    r"""Abstract base class for a map-style :class:`Dataset`.

    Any dataset that maps keys to data samples should derive from this class.

    Subclasses are expected to override :meth:`__getitem__` so that a sample
    can be fetched for a given key. They may also override :meth:`__len__`,
    which many :class:`~torch.utils.data.Sampler` implementations and the
    default options of :class:`~torch.utils.data.DataLoader` rely on to learn
    the size of the dataset.

    Subclasses may additionally provide :meth:`__getitems__` to speed up
    batched loading: it receives the list of sample indices in a batch and
    returns the corresponding list of samples.

    .. note::
        By default :class:`~torch.utils.data.DataLoader` builds an index
        sampler that yields integral indices. A map-style dataset whose
        indices/keys are not integral needs a custom sampler.
    """

    def __getitem__(self, index) -> T_co:
        # No default lookup exists; concrete datasets must supply one.
        raise NotImplementedError(
            "Subclasses of Dataset should implement __getitem__."
        )

    def __len__(self) -> int:
        # No default size exists; concrete datasets must supply one.
        raise NotImplementedError(
            "Subclasses of Dataset should implement __len__."
        )
class Subset(Dataset[T_co]):
    r"""View over a dataset restricted to a chosen sequence of indices.

    Positions in the subset are translated through ``indices`` before the
    wrapped dataset is accessed.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        """Fetch one sample, or forward a list of translated indices to the parent."""
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        # A list of positions is translated and handed to the parent as a list.
        return self.dataset[[self.indices[position] for position in idx]]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        """Fetch a batch of samples, using the parent's batched path when it has one."""
        # Batched sampling is delegated when the parent dataset supports it.
        # See torch.utils.data._utils.fetch._MapDatasetFetcher.
        batched_fetch = getattr(self.dataset, "__getitems__", None)
        if callable(batched_fetch):
            return batched_fetch([self.indices[position] for position in indices])  # type: ignore[attr-defined]
        return [self.dataset[self.indices[position]] for position in indices]

    def __len__(self):
        """Return the number of selected indices."""
        return len(self.indices)
class DataLoader:
    r"""A simplified version of PyTorch DataLoader.

    The biggest difference is not to handle tensors, but to handle any type of data.

    The loader is its own iterator: ``iter(loader)`` returns the loader itself
    after optionally shuffling and rewinding to the first sample.

    Batches are built by indexing ``dataset`` with plain ``int`` positions.
    When the first sample of a batch is a ``tuple``, the batch is transposed
    into a tuple of per-field tuples; otherwise the samples are packed with
    ``np.array``.

    Step budget: ``max_steps`` defaults to the number of batches in one pass
    (``len(loader)``) and can be changed with :meth:`set_max_steps`. The budget
    is only checked when a pass over the indices has been exhausted. If steps
    remain at that point the loader wraps around (reshuffling when ``shuffle``
    is set) and keeps producing batches. ``step_index`` counts every batch ever
    produced and is not reset by ``__iter__``.

    Args:
        dataset: Object supporting ``len()`` and integer ``__getitem__``.
        batch_size: Number of samples per batch; must be at least 1.
        shuffle: Whether to shuffle the index order at the start of every pass.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size: int = 4, shuffle: bool = True):
        # A zero batch size would divide by zero in __len__, and a negative
        # one would produce a negative length and empty slices.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}."
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Index order is fixed here; shuffling happens in __iter__/__next__.
        self.indices = np.arange(len(dataset))
        self.current_index = 0
        self.max_steps = len(self)
        self.step_index = 0

    def set_max_steps(self, max_steps: int):
        """Override the step budget checked at the end of each pass."""
        self.max_steps = max_steps

    def __iter__(self):
        """Start a new pass: shuffle if requested, rewind, and return ``self``."""
        if self.shuffle:
            np.random.shuffle(self.indices)
        self.current_index = 0
        return self

    def __len__(self):
        """Return the number of batches in one pass (last batch may be short)."""
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __next__(self) -> Union[np.ndarray, Tuple]:
        """Return the next batch.

        Raises:
            StopIteration: When the dataset is empty, or when a pass has been
                exhausted and ``step_index`` has reached ``max_steps``.
        """
        dataset_size = len(self.dataset)
        # An empty dataset can never yield a batch, whatever the step budget
        # is; without this guard ``batch_data[0]`` below raises IndexError.
        if dataset_size == 0:
            raise StopIteration

        if self.current_index >= dataset_size:
            # End of a pass: prepare the next one first, then decide whether
            # the step budget still allows wrapping around.
            if self.shuffle:
                np.random.shuffle(self.indices)  # Reshuffle for the new epoch
            self.current_index = 0
            if self.step_index >= self.max_steps:
                raise StopIteration

        batch_indices = self.indices[
            self.current_index : self.current_index + self.batch_size
        ]
        # Cast to int so datasets receive plain Python integers, not numpy scalars.
        batch_data = [self.dataset[int(i)] for i in batch_indices]

        if isinstance(batch_data[0], tuple):
            # Transpose [(a1, b1), (a2, b2)] into ((a1, a2), (b1, b2)).
            batch_data = tuple(zip(*batch_data))
        else:
            batch_data = np.array(batch_data)

        # Counters advance only after the batch was assembled successfully.
        self.current_index += self.batch_size
        self.step_index += 1
        return batch_data
def subset_dataset(dataset, num_samples: int):
    r"""Return a random :class:`Subset` of ``dataset`` with at most ``num_samples`` items.

    This function will be useful for testing and debugging purposes.

    Args:
        dataset: Object supporting ``len()``; it is wrapped, not copied.
        num_samples: Requested number of samples, capped at ``len(dataset)``.

    Returns:
        A :class:`Subset` over indices drawn without replacement via
        ``random.sample``.
    """
    population_size = len(dataset)
    # Never request more samples than the dataset holds.
    sample_count = min(num_samples, population_size)
    chosen_indices = random.sample(range(population_size), sample_count)
    return Subset(dataset, chosen_indices)