Source code for optim.sampler

"""The sampler here is designed to sample examples in few-shots ICL.

It differs from PyTorch's Sampler at torch.utils.data.sampler, which is used to sample data for training.

Our sampler directly determines which few-shot examples are used, so the choice of sampler can lead to different performance in few-shot ICL.
"""

import random
from dataclasses import dataclass
import logging

from typing import (
    List,
    Sequence,
    Optional,
    Callable,
    Any,
    Dict,
    TypeVar,
    Generic,
    Union,
)
import math


# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Covariant type variable for the item type held by samplers and samples.
T_co = TypeVar("T_co", covariant=True)


@dataclass
class Sample(Generic[T_co]):
    r"""A single item produced by a sampler, paired with its position in the source dataset."""

    # Position of this item in the dataset it was drawn from.
    index: int
    # The raw item itself.
    data: T_co

    def to_dict(self) -> Dict:
        r"""Return this sample as a plain ``{"index": ..., "data": ...}`` dict.

        ``data`` is stored by reference, not copied.
        """
        return dict(index=self.index, data=self.data)
class Sampler(Generic[T_co]):
    r"""Abstract base for samplers that pick few-shot examples from a dataset.

    Subclasses override :meth:`call`; calling an instance forwards to it.
    """

    # Class-level default: no dataset attached until set_dataset is called.
    dataset: Sequence[object] = None

    def __init__(self, *args, **kwargs) -> None:
        r"""Accept and ignore any arguments; the base initializer does nothing."""

    def set_dataset(self, dataset: Sequence[T_co]):
        r"""Store ``dataset`` on the instance unchanged."""
        self.dataset = dataset

    def random_replace(self, *args, **kwargs):
        r"""Randomly replace some samples.

        The base implementation does nothing and returns None. Subclasses accept
        either (shots, samples) or (shots, samples, replace).
        """
        return None

    def __call__(self, *args: Any, **kwds: Any):
        r"""Forward all arguments to :meth:`call` and return its result."""
        return self.call(*args, **kwds)

    def call(self, *args, **kwargs) -> List[Sample[T_co]]:
        r"""Do the main sampling; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        cls_name = type(self).__name__
        raise NotImplementedError(f"call method is not implemented in {cls_name}")
class RandomSampler(Sampler, Generic[T_co]):
    r"""Simple random sampler to sample from the dataset.

    Every dataset item is wrapped in a :class:`Sample` that records its index,
    so sampled results can be traced back to their position in the dataset.
    """

    dataset: Union[tuple, list]

    def __init__(
        self,
        dataset: Optional[Sequence[T_co]] = None,
        default_num_shots: Optional[int] = None,
    ):
        super().__init__()
        # set_dataset also (re)builds self._id_to_index.
        self.set_dataset(dataset)
        self.default_num_shots = default_num_shots

    def set_dataset(self, dataset: Sequence[T_co]):
        r"""Wrap ``dataset`` in :class:`Sample` objects and rebuild the id lookup.

        ``self._id_to_index`` maps ``item.id`` to the item's position in the
        dataset (used to exclude samples in augmented demos). Items without an
        ``id`` attribute (e.g. plain dicts or ints) are left out of the map
        rather than raising ``AttributeError``. Rebuilding it here keeps it in
        sync when the dataset is replaced after construction.
        """
        # Sample keeps the index of each item in the dataset.
        self.dataset = (
            [Sample[T_co](index=i, data=x) for i, x in enumerate(dataset)]
            if dataset is not None
            else None
        )
        self._id_to_index = (
            {item.id: i for i, item in enumerate(dataset) if hasattr(item, "id")}
            if dataset is not None
            else {}
        )

    def random_replace(
        self,
        shots: int,
        samples: List[Sample[T_co]],
        replace: Optional[bool] = False,
    ) -> List[Sample[T_co]]:
        r"""Randomly replace ``shots`` of the given samples with other dataset items.

        Args:
            shots: number of positions in ``samples`` to replace; must not
                exceed ``len(samples)``.
            samples: current samples; the list passed in is not modified.
            replace: if True, each replacement is drawn with ``random.choice``
                and may duplicate other samples (duplicate checks are skipped).
                If False, replacements are distinct and exclude every dataset
                index already present in ``samples``.

        Returns:
            A new list with ``shots`` positions replaced.

        Raises:
            AssertionError: if ``shots > len(samples)``.
            ValueError: if ``replace`` is False and fewer than ``shots`` unused
                dataset items remain (raised by ``random.sample``).
        """
        assert shots <= len(
            samples
        ), f"num_shots {shots} is larger than the number of samples {len(samples)}"
        samples = samples.copy()
        indices_to_replace = random.sample(range(len(samples)), shots)
        if replace:
            # This can potentially introduce duplicates in the samples.
            for i in indices_to_replace:
                samples[i] = random.choice(self.dataset)
            return samples

        # Exclude dataset indices already in use, then draw distinct candidates.
        existing_indexes = {sample.index for sample in samples}
        choice_indexes = list(set(range(len(self.dataset))) - existing_indexes)
        candidate_indexes = random.sample(choice_indexes, shots)
        for i, j in zip(indices_to_replace, candidate_indexes):
            samples[i] = self.dataset[j]
        return samples

    def random_sample(
        self, shots: int, replace: Optional[bool] = False
    ) -> List[Sample]:
        r"""Randomly sample ``shots`` items from the dataset.

        If ``replace`` is True, sample with replacement, meaning the same item
        can be sampled multiple times; otherwise items are distinct.
        """
        if replace:
            return [random.choice(self.dataset) for _ in range(shots)]
        return random.sample(self.dataset, shots)

    def call(
        self, num_shots: Optional[int] = None, replace: Optional[bool] = False
    ) -> List[Sample]:
        r"""Sample ``num_shots`` items, falling back to ``default_num_shots``.

        Raises:
            ValueError: if neither ``num_shots`` nor ``default_num_shots`` is set.
        """
        if num_shots is None:
            num_shots = self.default_num_shots
        if num_shots is None:
            raise ValueError("num_shots is not set")
        return self.random_sample(num_shots, replace)
# TODO: only suited to classification tasks; needs further testing.
class ClassSampler(Sampler, Generic[T_co]):
    r"""Sample from the dataset based on the class labels.

    T_co can be any type of data, e.g., dict, list, etc., with get_data_key_fun
    used to extract the class label. The label is used directly as a list index
    into per-class buckets, so it is expected to be an int in
    ``[0, num_classes)``.

    Example:

        Initialize

        ```
        dataset = [{"coarse_label": i} for i in range(10)]
        sampler = ClassSampler[Dict](dataset, num_classes=10, get_data_key_fun=lambda x: x["coarse_label"])
        ```
    """

    def __init__(
        self,
        dataset: Sequence[T_co],
        num_classes: int,
        get_data_key_fun: Callable,
        default_num_shots: Optional[int] = None,
    ):
        super().__init__()
        # Wrap each item so sampled results remember their original dataset index.
        self.dataset: List[Sample[T_co]] = [
            Sample[T_co](index=i, data=x) for i, x in enumerate(dataset)
        ]
        self.num_classes = num_classes
        if get_data_key_fun is None:
            raise ValueError("get_data_key_fun must be provided")
        self.get_data_key_fun = get_data_key_fun
        # class_indexces[label] holds the dataset indices of every item with that label.
        self.class_indexces: List[List] = [[] for _ in range(num_classes)]
        for i, data in enumerate(dataset):
            self.class_indexces[self.get_data_key_fun(data)].append(i)
        self.default_num_shots = default_num_shots

    def _sample_one_class(
        self, num_samples: int, class_index: int, replace: Optional[bool] = False
    ) -> List[Sample[T_co]]:
        r"""Sample ``num_samples`` items from the class with ``class_index``.

        With ``replace`` the same item may be drawn more than once; otherwise
        the class must contain at least ``num_samples`` items.
        """
        if replace:
            # TODO: can allow different sample weights to be passed to each class based on the errors
            sampled_indexes = random.choices(
                self.class_indexces[class_index], k=num_samples
            )
        else:
            sampled_indexes = random.sample(
                self.class_indexces[class_index], num_samples
            )
        return [self.dataset[i] for i in sampled_indexes]

    @staticmethod
    def _weighted_positions_without_replacement(
        weights: List[float], k: int
    ) -> List[int]:
        r"""Draw ``k`` distinct positions from ``range(len(weights))``, weighted by ``weights``.

        Positions are drawn one at a time and removed from the pool once picked.
        If the weights of the remaining positions sum to zero before ``k``
        positions have been drawn, the rest are drawn uniformly from what remains.
        Assumes ``k <= len(weights)``.
        """
        remaining = list(range(len(weights)))
        chosen: List[int] = []
        while len(chosen) < k:
            remaining_weights = [weights[p] for p in remaining]
            if sum(remaining_weights) <= 0:
                chosen.extend(random.sample(remaining, k - len(chosen)))
                break
            picked = random.choices(
                range(len(remaining)), weights=remaining_weights, k=1
            )[0]
            chosen.append(remaining.pop(picked))
        return chosen

    def random_replace(
        self,
        shots: int,
        samples: List[Sample],
        replace: Optional[bool] = False,
        weights_per_class: Optional[List[float]] = None,
    ) -> Sequence[Sample[T_co]]:
        r"""Replace ``shots`` distinct positions in ``samples`` with other items of the same class.

        Args:
            shots: number of positions to replace; must not exceed ``len(samples)``.
            samples: current samples; the list passed in is not modified.
            replace: if True, replacements are drawn with replacement from the
                whole class, so they may duplicate each other or existing
                samples. If False, replacements are distinct and exclude every
                item of that class already present in ``samples``.
            weights_per_class: optional weight per class label. A position is
                picked with probability proportional to its class's weight, so
                e.g. classes with higher accuracy can be given less weight to be
                replaced less often. Positions are still distinct, so exactly
                ``shots`` samples are replaced.

        Returns:
            A new list with ``shots`` positions replaced.

        Raises:
            AssertionError: if ``shots > len(samples)``.
            ValueError: if ``replace`` is False and a class has too few unused
                items to supply its replacements (raised by ``random.sample``).
        """
        assert shots <= len(
            samples
        ), f"num_shots {shots} is larger than the number of samples {len(samples)}"
        samples = samples.copy()

        # Dataset indices already present in `samples`, grouped by class label.
        existing_indexces_by_class: Dict[Any, List[int]] = {}
        for sample in samples:
            key = self.get_data_key_fun(sample.data)
            existing_indexces_by_class.setdefault(key, []).append(sample.index)

        # Choose which positions to replace; always distinct positions.
        if weights_per_class is None:
            replace_sample_indexes = random.sample(range(len(samples)), shots)
        else:
            weights = [
                weights_per_class[self.get_data_key_fun(sample.data)]
                for sample in samples
            ]
            replace_sample_indexes = self._weighted_positions_without_replacement(
                weights, shots
            )

        # Group the chosen positions by class label.
        replace_indexces_by_class: Dict[Any, List[int]] = {}
        for i in replace_sample_indexes:
            key = self.get_data_key_fun(samples[i].data)
            replace_indexces_by_class.setdefault(key, []).append(i)

        # Draw replacements per class, excluding existing samples unless `replace`.
        for class_label, positions in replace_indexces_by_class.items():
            if replace:
                sampled_indexes = random.choices(
                    self.class_indexces[class_label], k=len(positions)
                )
            else:
                choice_indexces = list(
                    set(self.class_indexces[class_label])
                    - set(existing_indexces_by_class[class_label])
                )
                sampled_indexes = random.sample(choice_indexces, len(positions))
            for i, j in zip(positions, sampled_indexes):
                samples[i] = self.dataset[j]
        return samples

    def random_sample(
        self,
        num_shots: int,
        replace: Optional[bool] = False,
    ) -> List[Sample[T_co]]:
        r"""Randomly sample ``num_shots`` items, balanced across classes.

        ``ceil(num_shots / num_classes)`` items are drawn from every class, and
        the pooled result is randomly subsampled down to ``num_shots`` if it
        overshoots. If ``replace`` is False, each class needs at least that many
        items (``random.sample`` raises ``ValueError`` otherwise).
        """
        samples = []
        samples_per_class = math.ceil(num_shots / self.num_classes)
        for class_index in range(self.num_classes):
            samples.extend(
                self._sample_one_class(samples_per_class, class_index, replace)
            )
        if len(samples) > num_shots:
            # Rounding up per class can overshoot; trim randomly to num_shots.
            samples = random.sample(samples, num_shots)
        return samples

    def call(
        self,
        num_shots: Optional[int] = None,
        replace: Optional[bool] = False,
    ) -> List[Sample[T_co]]:
        r"""Sample ``num_shots`` items (``default_num_shots`` when None).

        If ``replace`` is True, sample with replacement.

        Raises:
            ValueError: if neither ``num_shots`` nor ``default_num_shots`` is set.
        """
        if num_shots is None:
            num_shots = self.default_num_shots
        if num_shots is None:
            raise ValueError("num_shots is not set")
        return self.random_sample(num_shots, replace)
if __name__ == "__main__":
    # Smoke test: wrap a small list of dicts in Sample objects and print them.
    dataset = [{"coarse_label": i} for i in range(10)]
    samples = [Sample[Dict](index=i, data=x) for i, x in enumerate(dataset)]
    print(samples)