"""Default Dataset, DataLoader similar to `utils.data` in PyTorch.You can also use those provided by PyTorch or huggingface/datasets."""fromtypingimportUnion,Tuple,List,Sequence,TypeVar,GenericimportnumpyasnpimportrandomT_co=TypeVar("T_co",covariant=True)# TODO: consider directly use torch.utils.data in the future


class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass it.
    All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data
    sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`,
    which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options of
    :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement
    :meth:`__getitems__` to speed up batched sample loading. This method accepts a list
    of sample indices for a batch and returns the corresponding list of samples.

    .. note::
        :class:`~torch.utils.data.DataLoader` by default constructs an index sampler
        that yields integral indices. To make it work with a map-style dataset with
        non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
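

# Illustrative sketch only (not part of this module): ``QADataset`` is a
# hypothetical map-style dataset that stores (question, answer) pairs in memory
# and implements the two required methods, ``__getitem__`` and ``__len__``.
class QADataset(Dataset[Tuple[str, str]]):
    def __init__(self, pairs: List[Tuple[str, str]]):
        self.pairs = pairs

    def __getitem__(self, index) -> Tuple[str, str]:
        return self.pairs[index]

    def __len__(self) -> int:
        return len(self.pairs)


# Usage (doctest-style):
#     >>> ds = QADataset([("q1", "a1"), ("q2", "a2")])
#     >>> len(ds), ds[0]
#     (2, ('q1', 'a1'))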


class Subset(Dataset[T_co]):
    r"""Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        # Add batched sampling support when the parent dataset supports it.
        # See torch.utils.data._utils.fetch._MapDatasetFetcher.
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)
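

# Usage sketch (doctest-style, reusing the hypothetical ``QADataset`` above):
# a Subset re-indexes the parent dataset, so index 1 of the subset maps to
# index 2 of the original dataset here.
#     >>> ds = QADataset([("q1", "a1"), ("q2", "a2"), ("q3", "a3")])
#     >>> sub = Subset(ds, indices=[0, 2])
#     >>> len(sub), sub[1]
#     (2, ('q3', 'a3'))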


class DataLoader:
    __doc__ = r"""A simplified version of the PyTorch DataLoader.

    The biggest difference is that it does not operate on tensors; it batches data of any type."""

    def __init__(self, dataset, batch_size: int = 4, shuffle: bool = True, seed: int = 42):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.indices = np.arange(len(dataset))
        # Shuffling is applied at the start of each epoch in __iter__/__next__.
        self.current_index = 0
        self.max_steps = self.__len__()
        self.step_index = 0

    def __iter__(self):
        if self.shuffle:
            if self.seed is not None:
                np.random.seed(self.seed)  # use the provided seed
            np.random.shuffle(self.indices)
        self.current_index = 0
        return self

    def __len__(self):
        # Number of batches per epoch, rounding up for a final partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __next__(self) -> Union[np.ndarray, Tuple]:
        if self.current_index >= len(self.dataset):
            # End of the epoch: reshuffle, then stop once max_steps batches have been yielded.
            if self.shuffle:
                if self.seed is not None:
                    np.random.seed(self.seed)  # use the same seed for the reshuffle
                np.random.shuffle(self.indices)  # reshuffle for the new epoch
            self.current_index = 0
            if self.step_index >= self.max_steps:
                raise StopIteration

        batch_indices = self.indices[self.current_index : self.current_index + self.batch_size]
        batch_data = [self.dataset[int(i)] for i in batch_indices]
        if isinstance(batch_data[0], tuple):
            # Tuple samples are transposed into a tuple of per-field sequences.
            batch_data = tuple(zip(*batch_data))
        else:
            batch_data = np.array(batch_data)
        self.current_index += self.batch_size
        self.step_index += 1
        return batch_data
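

# Illustrative sketch only (not part of the library API). It shows one epoch of
# iteration over the hypothetical ``QADataset`` defined above: tuple samples are
# transposed, so each batch unpacks into (questions, answers).
def _example_dataloader_usage():
    ds = QADataset([(f"q{i}", f"a{i}") for i in range(10)])
    loader = DataLoader(ds, batch_size=4, shuffle=True, seed=0)
    batch_sizes = []
    for questions, answers in loader:
        # The last batch may be smaller than batch_size.
        batch_sizes.append(len(questions))
    return batch_sizes  # e.g. [4, 4, 2] for 10 samples with batch_size=4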


def subset_dataset(dataset, num_samples: int):
    r"""Return a :class:`Subset` of ``num_samples`` randomly sampled items.

    This function is useful for testing and debugging purposes."""
    num_samples = min(num_samples, len(dataset))
    random_subset_indices = random.sample(range(len(dataset)), num_samples)
    return Subset(dataset, random_subset_indices)
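

# Usage sketch (doctest-style; the selected indices vary because sampling is random,
# and ``ds`` stands for any Dataset instance):
#     >>> small = subset_dataset(ds, num_samples=3)
#     >>> len(small)
#     3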