"""Source code for datasets.gsm8k: GSM8K dataset loader built on huggingface datasets."""

import random
import os
from typing import Literal
import tqdm

from adalflow.utils.lazy_import import safe_import, OptionalPackages


from adalflow.utils.data import Dataset
from adalflow.utils.file_io import save_json, load_json
from adalflow.datasets.utils import prepare_dataset_path
from adalflow.core.base_data_class import DataClass
from adalflow.datasets.types import GSM8KData
from adalflow.utils import printc


class GSM8K(Dataset):
    __doc__ = r"""Use huggingface datasets to load the GSM8K dataset.

    Split sizes:

    - official_train: 7473
    - official_test: 1319
    - our train split: 3736 (first half, ``len * 50 // 100``, of the shuffled official train set)
    - our val split: 3737 (remaining half of the shuffled official train set)
    - our test split: 1319 (the shuffled official test set)

    You can use ``size`` to limit the number of examples to load.

    Example:

    .. code-block:: python

        dataset = GSM8K(split="train", size=10)
        print(f"example: {dataset[0]}")

    The output will be:

    .. code-block::

        GSM8KData(id='8fc791e6-ea1d-472c-a882-d00d0600d423', question="The result from the 40-item Statistics exam Marion and Ella took already came out. Ella got 4 incorrect answers while Marion got 6 more than half the score of Ella. What is Marion's score?", answer='24', gold_reasoning="Ella's score is 40 items - 4 items = <<40-4=36>>36 items. Half of Ella's score is 36 items / 2 = <<36/2=18>>18 items. So, Marion's score is 18 items + 6 items = <<18+6=24>>24 items.", reasoning=None)
    """

    def __init__(
        self,
        root: str = None,
        split: Literal["train", "val", "test"] = "train",
        size: int = None,
        **kwargs,
    ) -> None:
        """Load one split of GSM8K, downloading and splitting it on first use.

        Args:
            root: Root directory handed to ``prepare_dataset_path``; also used
                as the huggingface ``cache_dir`` when downloading.
            split: One of ``"train"``, ``"val"``, ``"test"``.
            size: If given, keep only the first ``size`` examples of the split.
            **kwargs: Accepted but unused.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        if split not in ["train", "val", "test"]:
            raise ValueError("Split must be one of 'train', 'val', 'test'")

        self.root = root
        self.task_name = "gsm8k"
        data_path = prepare_dataset_path(self.root, self.task_name)

        # Each split is stored as <data_path>/<split>.json (JSON, despite the
        # "csv" in the variable name).
        split_csv_path = os.path.join(data_path, f"{split}.json")
        print(f"split_csv_path: {split_csv_path}")
        self._check_or_download_dataset(split_csv_path, split)

        records = load_json(split_csv_path)
        if size is not None:
            records = records[:size]
        # convert raw dicts to dataclass instances
        self.data = [GSM8KData.from_dict(d) for d in records]

    @staticmethod
    def _parse_example(example) -> dict:
        """Convert one raw huggingface GSM8K row into a JSON-ready record.

        The raw ``answer`` text is split on whitespace; its last two tokens
        must be ``"####"`` and the final number. Everything before them is
        re-joined with single spaces into ``gold_reasoning`` (so newlines
        collapse to spaces). The final number has ``,`` separators removed,
        is parsed with ``int`` and stored back as a string in ``answer``.

        Raises:
            ValueError: If the answer does not end with ``#### <number>``.
        """
        question = example["question"]
        tokens = example["answer"].strip().split()
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(tokens) < 2 or tokens[-2] != "####":
            raise ValueError(
                f"Unexpected GSM8K answer format (missing '####'): {example['answer']!r}"
            )
        gold_reasoning = " ".join(tokens[:-2])
        answer = str(int(tokens[-1].replace(",", "")))
        return dict(question=question, gold_reasoning=gold_reasoning, answer=answer)

    def _check_or_download_dataset(
        self,
        data_path: str = None,
        split: str = "train",
    ):
        r"""Download GSM8K from huggingface datasets, split it and save it as JSON.

        If ``data_path`` already exists nothing happens and ``None`` is
        returned. Otherwise the data is prepared and written as follows:

        - The official train set is shuffled with ``random.Random(0)`` and cut
          in half into our train and val splits.
        - The official test set is shuffled with ``random.Random(0)`` and
          becomes our test split.
        - All three splits are written as ``train.json``, ``val.json`` and
          ``test.json`` in the directory of ``data_path``.

        Args:
            data_path (str): Path of the JSON file for the requested split
                (``<dir>/<split>.json``). Its directory receives all three
                split files.
            split (str): The dataset split, supports ``"train"`` (default),
                ``"val"`` and ``"test"``. Decides which split to return.

        Returns:
            list[dict] | None: The records of ``split`` when a download took
            place (any other ``split`` value falls back to the test split), or
            ``None`` when ``data_path`` already existed.

        Raises:
            ValueError: If ``data_path`` is None, or an answer is not in the
                ``#### <number>`` format.
        """
        if data_path is None:
            raise ValueError("data_path must be specified")

        if os.path.exists(data_path):
            return

        safe_import(
            OptionalPackages.DATASETS.value[0], OptionalPackages.DATASETS.value[1]
        )
        from datasets import load_dataset

        # use huggingface cache
        gsm8k_dataset = load_dataset("gsm8k", "main", cache_dir=self.root)

        official_train = [
            self._parse_example(example)
            for example in tqdm.tqdm(gsm8k_dataset["train"])
        ]
        official_test = [
            self._parse_example(example)
            for example in tqdm.tqdm(gsm8k_dataset["test"])
        ]

        # Fresh fixed-seed generators keep both shuffles (and the splits) reproducible.
        random.Random(0).shuffle(official_train)  # 7473 train
        random.Random(0).shuffle(official_test)  # 1319 test

        printc(f"official_train: {len(official_train)}")
        printc(f"official_test: {len(official_test)}")

        cut = len(official_train) * 50 // 100
        splits = {
            "train": official_train[:cut],
            "val": official_train[cut:],
            "test": official_test,
        }

        data_path_dir = os.path.dirname(data_path)
        # Use a distinct loop variable so ``split`` still names the requested
        # split when we return below.
        for split_name, examples in splits.items():
            target_path = os.path.join(data_path_dir, f"{split_name}.json")
            save_json(examples, f=target_path)

        return splits.get(split, official_test)

    def __getitem__(self, index) -> DataClass:
        """Return the loaded example at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of loaded examples (after any ``size`` truncation)."""
        return len(self.data)
if __name__ == "__main__":
    # Smoke run: load the first 10 training examples and show one of them.
    sample = GSM8K(split="train", size=10)
    print(f"len: {len(sample)}")
    first = sample[0]
    print(f"dataset[0]: {first}")