"""Ready to use trainer for LLM task pipeline"""
from typing import Literal, Optional, List, Dict, Any, Tuple, TYPE_CHECKING
import os
import logging
from tqdm import tqdm
import random
import numpy as np
import uuid
import time
from adalflow.core.component import Component
from adalflow.optim.optimizer import Optimizer, DemoOptimizer, TextOptimizer
if TYPE_CHECKING:
from adalflow.optim.parameter import Parameter
from adalflow.optim.types import (
PromptData,
TrainerResult,
ParameterType,
TrainerStepResult,
)
from adalflow.eval.base import EvaluationResult
from adalflow.optim.trainer.adal import AdalComponent
from adalflow.optim.text_grad.ops import sum_ops
from adalflow.utils import save_json, load_json
from adalflow.utils.cache import hash_text_sha1
from adalflow.utils.data import DataLoader
from adalflow.optim.types import TrainerValidateStats
log = logging.getLogger(__name__)
class Trainer(Component):
__doc__ = r"""Ready to use trainer for LLM task pipeline to optimize all types of parameters.
Training set: can be used for passing initial proposed prompt or for few-shot sampling.
Validation set: Will be used to select the final prompt or samples.
Test set: Will be used to evaluate the final prompt or samples.
Args:
adaltask: AdalComponent: AdalComponent instance
strategy: Literal["random", "constrained"]: Strategy to use for the optimizer
max_steps: int: Maximum number of steps to run the optimizer
num_workers: int: Number of workers to use for parallel processing
ckpt_path: str: Path to save the checkpoint files, default to ~/.adalflow/ckpt.
batch_val_score_threshold: Optional[float]: Skip a training batch when its accuracy is at or above this threshold
max_error_samples: Optional[int]: Maximum number of error samples to keep
max_correct_samples: Optional[int]: Maximum number of correct samples to keep
max_proposals_per_step: int: Maximum number of proposals to generate per step
train_loader: Any: DataLoader instance for training
train_dataset: Any: Training dataset
val_dataset: Any: Validation dataset
test_dataset: Any: Test dataset
few_shots_config: Optional[FewShotConfig]: Few shot configuration
save_traces: bool: Save traces for synthetic data generation or debugging
debug: bool: Run the trainer in debug mode. If debug is True, the text-gradient graph is saved under /ckpt/YourAdalComponentName/debug_text_grads for prompt parameters,
and the demo graph is saved under /ckpt/YourAdalComponentName/debug_demos for demo parameters.
Note:
In debug mode, you can use the get_logger API to show more detailed logs.
Example:
from adalflow.utils import get_logger
get_logger(level="DEBUG")
"""
adaltask: AdalComponent # task pipeline
train_batch_size: Optional[int] = 4
train_loader: Any
val_dataset = None
test_dataset = None
strategy: Literal["random", "constrained"]
optimization_order: Literal["sequential", "mix"] = (
"sequential" # zero-shot first, bootstrap second
)
max_steps: int
optimizer: Optimizer = None
ckpt_path: Optional[str] = None
ckpt_file: Optional[str] = None
num_workers: int = 4
max_proposals_per_step: int = 5
# moving batch to speed up the training
batch_val_score_threshold: Optional[float] = (
1.0 # when acc_score >= this threshold, skip this batch
)
max_error_samples: Optional[int] = 8
max_correct_samples: Optional[int] = 8
debug: bool = False
def __init__(
self,
adaltask: AdalComponent,
optimization_order: Literal["sequential", "mix"] = "sequential",
strategy: Literal["random", "constrained"] = "constrained", # search strategy
max_steps: int = 1000,
train_batch_size: Optional[int] = 4,
num_workers: int = 4,
ckpt_path: str = None,
batch_val_score_threshold: Optional[float] = 1.0,
max_error_samples: Optional[int] = 4,
max_correct_samples: Optional[int] = 4,
max_proposals_per_step: int = 5,
train_loader: Optional[Any] = None,
train_dataset: Optional[Any] = None,
val_dataset: Optional[Any] = None,
test_dataset: Optional[Any] = None,
# For demo optimizer
raw_shots: Optional[int] = None,
bootstrap_shots: Optional[int] = None,
weighted_sampling: bool = False,  # whether to use weighted sampling for few-shot demos
exclude_input_fields_from_bootstrap_demos: bool = False,
debug: bool = False,
save_traces: bool = False,  # save traces for the few-shot demos
*args,
**kwargs,
) -> None:
super().__init__()
if not isinstance(adaltask, AdalComponent):
raise ValueError("Task should be an instance of AdalComponent")
if strategy not in ["random", "constrained"]:
raise ValueError("Strategy should be either random or constrained")
self.optimization_order = optimization_order
self.strategy = strategy
self.max_steps = max_steps
self.ckpt_path = ckpt_path
self.adaltask = adaltask
self.num_workers = num_workers
self.train_loader = train_loader
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.batch_val_score_threshold = batch_val_score_threshold
self.max_error_samples = max_error_samples
self.max_correct_samples = max_correct_samples
self.max_proposals_per_step = max_proposals_per_step
self._subset_effect_count = {"pass": 0, "fail": 0}
self._fullset_effect_count = {"pass": 0, "fail": 0}
self._valset_effect_count = {"pass": 0, "fail": 0}
self._effective_measure = {
"subset": self._subset_effect_count,
"fullset": self._fullset_effect_count,
"valset": self._valset_effect_count,
}
self._raw_shots = raw_shots
self._bootstrap_shots = bootstrap_shots
self.demo_optimizers: List[DemoOptimizer] = []
self.text_optimizers: List[TextOptimizer] = []
self.save_traces = save_traces
self.train_batch_size = train_batch_size
self.weighted_sampling = weighted_sampling
self.debug = debug
self.exclude_input_fields_from_bootstrap_demos = (
exclude_input_fields_from_bootstrap_demos
)
# TODO: need to support checkpoint resume too!
def diagnose(self, dataset: Any, split: str = "train"):
"""Run an evaluation on the trainset to track all error response, and its raw response using AdaplComponent's default configure_callbacks
Args:
dataset: Any: Dataset to evaluate
split: str: Split name, defaults to "train"; it is also used as the directory name for saving the logs
Example:
.. code-block:: python
trainset, valset, testset = load_datasets(max_samples=10)
adaltask = TGDWithEvalFnLoss(
task_model_config=llama3_model,
backward_engine_model_config=llama3_model,
optimizer_model_config=llama3_model,
)
trainer = Trainer(adaltask=adaltask)
diagnose = trainer.diagnose(dataset=trainset)
print(diagnose)
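The per-generator logs are saved under <ckpt_path>/diagnose_<split>/; each log file gets a companion *_diagnose.json listing the samples with score < 0.5, and a stats.json with the total number of samples, the number of error samples, and the average score.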
"""
# 1. track all intermediate outputs
if not self.ckpt_path:
trainer_state = self.gather_trainer_states()
self.prep_ckpt_file_path(trainer_state)
save_path = os.path.join(self.ckpt_path, f"diagnose_{split}")
print(f"Save diagnose to {save_path}")
log_paths = self.adaltask.configure_callbacks(save_dir=save_path)
# 2. evaluate
acc = self.adaltask.validation_step(dataset, 0, self.num_workers)
acc_score = acc.avg_score
acc_per_item_scores = acc.per_item_scores
# 3. load all completion from the log paths
from adalflow.utils.file_io import load_jsonl, write_list_to_jsonl, save_json
sorted_indices = sorted(
range(len(acc_per_item_scores)), key=lambda i: acc_per_item_scores[i]
)
try:
sorted_ids = [dataset[i].id for i in sorted_indices]
except AttributeError:
raise ValueError(
"dataset should have an attribute id for tracking the samples"
)
print(f"sorted_indices: {sorted_indices}")
sorted_scores = [acc_per_item_scores[i] for i in sorted_indices]
print(f"sorted_scores: {sorted_scores}")
sorted_dataset = [dataset[i] for i in sorted_indices]
# reorder the samples based on the score
for log_path in log_paths:
file_name = os.path.basename(log_path)
print(f"Loading log file: {file_name}")
logs = load_jsonl(log_path)
try:
logs_dict = {log["output"]["id"]: log for log in logs}
except KeyError:
raise ValueError(
"Log file should have an output key with an id for tracking the samples. Ensure you have passed the data id to the Generator."
)
sorted_logs = [logs_dict[id] for id in sorted_ids]
for log, score in zip(sorted_logs, sorted_scores):
log["score"] = score
write_list_to_jsonl(log_path, sorted_logs)
log_dir = os.path.dirname(log_path)
diagnose_filename = file_name.replace(".jsonl", "_diagnose.json")
diagnose_file = os.path.join(log_dir, diagnose_filename)
diagnose_items = []
for i, log in enumerate(sorted_logs):
if log["score"] < 0.5:
diagnose_item = {
"id": log["output"]["id"] if "id" in log["output"] else None,
"score": log["score"],
"prompt_kwargs": log["prompt_kwargs"],
"raw_response": log["output"]["raw_response"],
"answer": log["output"]["data"],
"dataset_item": sorted_dataset[i],
"error": log["output"]["error"],
"time_stamp": log["time_stamp"],
}
diagnose_items.append(diagnose_item)
save_json(diagnose_items, diagnose_file)
# save the stats
stats = {
"total_samples": len(sorted_logs),
"total_error_samples": len(diagnose_items),
"avg_score": acc_score,
}
save_json(stats, os.path.join(log_dir, "stats.json"))
print(f"Total error samples: {len(diagnose_items)}")
print(f"Saved diagnose to {diagnose_file}")
return acc_score, acc_per_item_scores, log_paths
def debug_report(
self,
text_grad_debug_path: Optional[str] = None,
few_shot_demo_debug_path: Optional[str] = None,
):
import colorama
from colorama import Fore
# Initialize colorama
colorama.init(autoreset=True)
print(Fore.CYAN + "\n================== DEBUG REPORT ==================\n")
if text_grad_debug_path:
print(Fore.GREEN + f"✔ Text grad debug path: {text_grad_debug_path}")
else:
print(Fore.RED + "✘ Text grad debugging was not run.")
if few_shot_demo_debug_path:
print(
Fore.GREEN + f"✔ Few shot demo debug path: {few_shot_demo_debug_path}"
)
else:
print(Fore.RED + "✘ Few shot demo debugging was not run.")
print(Fore.GREEN + "\n✔ The debug has run successfully!")
print(
Fore.YELLOW
+ "You can visualize the complete computation graph at the paths shown above."
)
print(Fore.CYAN + "\n===================================================\n")
def fit(
self,
*,
adaltask: Optional[AdalComponent] = None,
train_loader: Optional[Any] = None,
train_dataset: Optional[Any] = None,
val_dataset: Optional[Any] = None,
test_dataset: Optional[Any] = None,
debug: bool = False,
save_traces: bool = False,
raw_shots: Optional[int] = None,
bootstrap_shots: Optional[int] = None,
resume_from_ckpt: Optional[
str
] = None, # TODO: have a more comprehensive ckpt loading in the future
):
r"""
train_loader: An iterable or collection of iterables specifying training samples.
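Arguments left as None fall back to the values passed to the Trainer constructor.
Example (a minimal sketch; the dataset variables are assumed to be prepared by the caller):
.. code-block:: python
trainer = Trainer(adaltask=adaltask, max_steps=12)
trainer.fit(
train_dataset=trainset,
val_dataset=valset,
test_dataset=testset,
)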
"""
start_time = time.time()
debug = debug or self.debug
# check task
adaltask = adaltask or self.adaltask
self.adaltask = adaltask
if not isinstance(adaltask, AdalComponent):
raise ValueError(
f"Task should be an instance of AdalComponent. Got {adaltask}"
)
raw_shots = raw_shots or self._raw_shots
bootstrap_shots = bootstrap_shots or self._bootstrap_shots
print(f"raw_shots: {raw_shots}, bootstrap_shots: {bootstrap_shots}")
self.save_traces = save_traces or self.save_traces
train_loader = train_loader or self.train_loader
train_dataset = train_dataset or self.train_dataset
if not train_loader and train_dataset:
batch_size = self.train_batch_size
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
val_dataset = val_dataset or self.val_dataset
test_dataset = test_dataset or self.test_dataset
# check train_loader and val_dataset and test_dataset, reject tuple
if train_loader:
exam_batch = next(iter(train_loader))
if isinstance(exam_batch, tuple):
raise ValueError(
"train_loader should return not be tuple, please use dict or a dataclass or with DataClass"
)
if val_dataset:
if isinstance(val_dataset, tuple):
raise ValueError(
"val_dataset should not be tuple, please use dict or a dataclass or with DataClass"
)
if test_dataset:
if isinstance(test_dataset, tuple):
raise ValueError(
"test_dataset should not be tuple, please use dict or a dataclass or with DataClass"
)
if train_dataset:
if isinstance(train_dataset, tuple):
raise ValueError(
"train_dataset should not be tuple, please use dict or a dataclass or with DataClass"
)
# prepare optimizers
self.optimizers: List[Optimizer] = self.adaltask.configure_optimizers()
self.text_optimizers = [
opt for opt in self.optimizers if isinstance(opt, TextOptimizer)
]
self.demo_optimizers = [
opt for opt in self.optimizers if isinstance(opt, DemoOptimizer)
]
# config optimizers
if len(self._get_trainable_demo_params()) > 0:
for opt in self.demo_optimizers:
opt.config_shots(raw_shots=raw_shots, bootstrap_shots=bootstrap_shots)
opt.use_weighted_sampling(weighted=self.weighted_sampling)
opt.exclude_input_fields_from_bootstrap_demos = (
self.exclude_input_fields_from_bootstrap_demos
)
self.adaltask.configure_teacher_generator()
print("Configured demo optimizers")
else:
print("No trainable demo params to optimize")
self.demo_optimizers = []
if len(self._get_trainable_text_params()) > 0:
if self.adaltask.backward_engine is None:
self.adaltask.configure_backward_engine()
else:
print("No trainable text params to optimize")
self.text_optimizers = []
if len(self.demo_optimizers) == 0 and len(self.text_optimizers) == 0:
print("No trainable parameters to optimize")
return None
trainer_results = None
starting_step = 0
if resume_from_ckpt:
self.ckpt_file = resume_from_ckpt
dict_data = load_json(self.ckpt_file)
trainer_results: TrainerResult = TrainerResult.from_dict(dict_data)
# restore the prompts to the adaltask
val_scores = []
test_scores = []
for step in trainer_results.step_results:
if step.val_score:
val_scores.append(step.val_score)
if step.test_score:
test_scores.append(step.test_score)
result_from_step = 0
if test_scores:
result_from_step = test_scores.index(max(test_scores))
elif val_scores:
result_from_step = val_scores.index(max(val_scores))
prompts: List[PromptData] = trainer_results.step_results[
result_from_step
].prompt
print(f"Restoring prompts: {prompts[0]}")
self.adaltask._set_param_values(prompts)
starting_step = len(trainer_results.steps) - 1
if debug:
print("Debugging mode")
text_grad_debug_path, few_shot_demo_debug_path = None, None
if len(self.text_optimizers) > 0:
text_grad_debug_path = self._fit_text_grads_one_step_for_debug(
train_loader
)
if len(self.demo_optimizers) > 0:
few_shot_demo_debug_path = self._fit_demos_one_step_for_debug(
train_loader, train_dataset, val_dataset, test_dataset
)
self.debug_report(text_grad_debug_path, few_shot_demo_debug_path)
return
######## Run text optimizers and demo optimizers (mixed or sequential order) ########
if (
self.optimization_order == "mix"
and len(self.demo_optimizers) > 0
and len(self.text_optimizers) > 0
):
if self.strategy == "random":
self._fit_text_grad_demo_mix_random(
train_loader,
train_dataset,
val_dataset,
test_dataset,
trainer_results,
starting_step=starting_step,
)
elif self.strategy == "constrained":
self._fit_text_grad_demo_mix_constrained(
train_loader,
train_dataset,
val_dataset,
test_dataset,
trainer_results,
starting_step=starting_step,
)
else:
raise ValueError(f"Strategy {self.strategy} not supported")
else: # sequential, text first and demo second
if len(self.text_optimizers) > 0:
if self.strategy == "random":
trainer_results = self._fit_text_grad_random(
train_loader,
val_dataset,
test_dataset,
trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
elif self.strategy == "constrained":
trainer_results = self._fit_text_grad_constraint(
train_loader,
val_dataset,
test_dataset,
trainer_results=trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
else:
raise ValueError(f"Strategy {self.strategy} not supported")
if len(self.demo_optimizers) > 0:
self.adaltask.configure_teacher_generator()  # attempt to use the newest teacher
self._fit_demos_random(
train_loader,
train_dataset,
val_dataset,
test_dataset,
trainer_results=trainer_results,
starting_step=starting_step,
)
end_time = time.time()
print(f"Training time: {end_time - start_time}s")
print(f"ckpt_file: {self.ckpt_file}")
@staticmethod
def _estimate_num_epochs(train_loader: Any, max_steps: int):
num_samples = len(train_loader)
return max_steps // num_samples + 1
def initial_validation(self, val_dataset: Any, test_dataset: Any):
val_output = self.adaltask.validation_step(val_dataset, 0, self.num_workers)
val_score = val_output.avg_score
test_score = None
if test_dataset is not None:
test_output = self.adaltask.validation_step(
test_dataset, 0, self.num_workers
)
test_score = test_output.avg_score
trainer_results = TrainerResult(
steps=[], val_scores=[], test_scores=[], step_results=[], prompts=[]
)
trainer_results.val_scores.append(val_score)
trainer_results.test_scores.append(test_score)
prompts = self.adaltask._get_param_values()
trainer_results.prompts.append(prompts)
trainer_results.steps.append(0)
print(f"Initial validation score: {val_score}")
print(f"Initial test score: {test_score}")
return trainer_results
def gather_trainer_states(self):
trainer_state = {}
trainer_state["strategy"] = self.strategy
trainer_state["demo_optimizers"] = self._get_trainable_demo_params()
trainer_state["text_optimizers"] = self._get_trainable_text_params()
trainer_state["max_steps"] = self.max_steps
trainer_state["num_workers"] = self.num_workers
trainer_state["raw_shots"] = self._raw_shots
trainer_state["bootstrap_shots"] = self._bootstrap_shots
trainer_state["weighted_sampling"] = self.weighted_sampling
trainer_state["exclude_input_fields_from_bootstrap_demos"] = (
self.exclude_input_fields_from_bootstrap_demos
)
trainer_state["batch_size"] = (
self.train_loader.batch_size if self.train_loader else None
)
trainer_state["train_size"] = (
len(self.train_loader.dataset) if self.train_loader else None
)
trainer_state["val_size"] = len(self.val_dataset) if self.val_dataset else None
trainer_state["test_size"] = (
len(self.test_dataset) if self.test_dataset else None
)
trainer_state["task_class"] = self.adaltask.__class__.__name__
from adalflow.utils.serialization import serialize
hash_key = hash_text_sha1(serialize(trainer_state))[0:5]
trainer_state["hash_key"] = hash_key
trainer_state["task_state_dict"] = self.adaltask.to_dict()
# restore_state = AdalComponent.from_dict(
# trainer_state["task_state_dict"]
# ) # tODO: add a test for adalcomponent
# print(
# f"restore_state: {str(restore_state.to_dict()) == str(self.adaltask.to_dict())}"
# )
# print(f"task_state_dict: {trainer_state['task_state_dict']}")
return trainer_state
def prep_ckpt_file_path(self, trainer_state: Dict[str, Any] = None):
r"""Prepare the checkpoint root path: ~/.adalflow/ckpt/task_name/.
It also generates a unique checkpoint file name based on the strategy, max_steps, and a unique hash key.
For multiple runs with the same AdalComponent and trainer setup, the run number is incremented.
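For example, with strategy="constrained" and max_steps=8, the checkpoint file may be named constrained_max_steps_8_1a2b3_run_1.json, where 1a2b3 stands for an illustrative 5-character hash key.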
"""
if self.ckpt_file:
return
from adalflow.utils.global_config import get_adalflow_default_root_path
if self.ckpt_path is None:
default_root_path = get_adalflow_default_root_path()
self.ckpt_path = os.path.join(
default_root_path, "ckpt", self.adaltask.__class__.__name__
)
print(f"Checkpoint path: {self.ckpt_path}")
os.makedirs(self.ckpt_path, exist_ok=True)
# list all existing checkpoints with the same file name prefix
hash_key = (
trainer_state["hash_key"]
if trainer_state and "hash_key" in trainer_state
else str(uuid.uuid4())
)
file_name_prefix = f"{self.strategy}_max_steps_{self.max_steps}_{hash_key}"
ckpt_files = [
f for f in os.listdir(self.ckpt_path) if f.startswith(file_name_prefix)
]
run: int = 1
if ckpt_files:
# Sort files based on last modification time
ckpt_files.sort(
key=lambda x: os.path.getmtime(os.path.join(self.ckpt_path, x)),
reverse=True,
)
latest_ckpt_file = ckpt_files[0]
# get the run number
run = int(latest_ckpt_file.split("_run_")[-1].split(".json")[0]) + 1
else:
latest_ckpt_file = None
self.ckpt_file = os.path.join(
self.ckpt_path, f"{file_name_prefix}_run_{run}.json"
)
def _pre_fit(self, val_dataset: Any, test_dataset: Any) -> TrainerResult:
# validate first (separate into another function where we can even save the outputs so that we can highlight error predictions)
trainer_state = self.gather_trainer_states()
trainer_results: TrainerResult = self.initial_validation(
val_dataset, test_dataset
)
self._add_history_text_optimizers(trainer_results.val_scores[-1])
trainer_results.trainer_state = trainer_state
self.prep_ckpt_file_path(trainer_state)
return trainer_results
# end of validation
def _fit_demos_one_step_for_debug(
self, train_loader, train_dataset: Any, val_dataset: Any, test_dataset: Any
) -> str:
# get_logger(level="DEBUG")
print("Fitting using Random Demo Optimizer")
self.prep_ckpt_file_path()
debug_path = os.path.join(self.ckpt_path, "debug_demos")
os.makedirs(debug_path, exist_ok=True)
print(f"save to {debug_path}")
self.adaltask.train()
self.adaltask.trace()
self._set_demo_optimizers_dataset(train_dataset)
# test teacher mode
self.adaltask.use_teacher()
train_loader.batch_size = 2
pred_teacher = set() # id of the teacher predictions
batch = next(iter(train_loader))
y_preds: List[Parameter] = self.adaltask.train_step(batch, 0, self.num_workers)
if len(y_preds) != 2:
raise ValueError("Expected 2 y_preds")
nodes: List[Parameter] = y_preds[0].trace_graph(y_preds[0])[0]
demo_params = [p for p in nodes if p.param_type == ParameterType.DEMOS]
if len(demo_params) == 0:
raise ValueError("No demo params found")
if len(demo_params[0]._traces) != 2:
raise ValueError(
f"Expected 2 traces, got {len(demo_params[0]._traces)}, traces: {demo_params[0]._traces}"
)
print(f"Teacher y_preds: {y_preds[0].to_dict()}")
y_preds_outputs = [p.full_response for p in y_preds]
batch_eval: EvaluationResult = self.adaltask.evaluate_samples(
batch, y_preds_outputs
)
batch_acc = batch_eval.avg_score
batch_per_item_scores = batch_eval.per_item_scores
print(
f"Validation accuracy: {batch_acc}, per item scores: {batch_per_item_scores}"
)
# test loss
losses: List[Parameter] = self.adaltask.loss_step(
batch, y_preds, 0, self.num_workers
)
print(f"Losses: {losses[0].to_dict()}")
self._demo_optimizers_add_scores(
[sample.id for sample in batch], batch_per_item_scores, is_teacher=True
)
losses[0].backward()
losses[1].backward()
pred_teacher.add(batch[0].id)
pred_teacher.add(batch[1].id)
graph_path = os.path.join(debug_path, "graph")
print(f"Graph saved to {graph_path}")
# check the score
for key, val in demo_params[0]._traces.items():
print(f"param: {key}, val: {val}")
score = val.score
if score is None:
raise ValueError("Score is None")
print(f"param: {key}, score: {score}")
print(f"Loss after backward: {losses[0].to_dict()}")
# track the bootstrapped samples so we won't repeat the same ones
for batch_idx, batch in enumerate(train_loader):
print(f"Training step: {batch_idx}")
if batch_idx > 0:
break
# eval_student_mode
self.adaltask.use_teacher(False)
y_preds_student = self.adaltask.train_step(
batch, batch_idx, self.num_workers
)
losses_student: List[Parameter] = self.adaltask.loss_step( # noqa F841
batch, y_preds_student, batch_idx, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch], batch_per_item_scores, is_teacher=False
)
# for loss in losses_student:
# loss.backward()
y_preds_outputs = [p.full_response for p in y_preds_student]
eval_result = self.adaltask.evaluate_samples(batch, y_preds_outputs)
print(f"Eval result: {eval_result.avg_score}")
eval_score_per_item = eval_result.per_item_scores
# bootstrap
batch_for_teacher = []
losses_teacher = []
for i, (sample, item_score) in enumerate(zip(batch, eval_score_per_item)):
# use teacher
if sample.id in pred_teacher:
continue
# if item_score < 0.5:
batch_for_teacher.append(sample)
pred_teacher.add(sample.id)
# run the teacher and use the teacher's output instead of the initial output (bootstrap)
if len(batch_for_teacher) > 0:
print(f"Using teacher for {len(batch_for_teacher)} samples")
self.adaltask.use_teacher()
y_preds_teacher = self.adaltask.train_step(
batch_for_teacher, batch_idx, self.num_workers
)
losses_teacher: List[Parameter] = self.adaltask.loss_step( # noqa F841
batch_for_teacher, y_preds_teacher, batch_idx, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch_for_teacher],
eval_score_per_item,
is_teacher=True,
)
# propose
self._demo_optimizers_propose()
graph_path = os.path.join(debug_path, "student_graph")
losses_student[0].draw_graph(filepath=graph_path)
# test step
self._demo_optimizers_step()
for opt in self.demo_optimizers:
if opt.proposing:
raise ValueError("Optimizer is still proposing")
# check demo params
opt_params = []
for opt in self.demo_optimizers:
opt_params.extend(opt.params)
print(f"Opt params: {opt_params}")
for name, param in self.adaltask.named_parameters():
if param.param_type == ParameterType.DEMOS:
print(f"Demo param: {name}, value: {param.data}, param: {param}")
if param.data is None:
raise ValueError("Demo param data is None")
if len(param._traces) == 0:
raise ValueError(f"No traces found, param_id: {param.id}")
if len(param._previous_demos) > 0:
raise ValueError(
f"Previous demos should be empty, param: {param.id}"
)
if len(param._demos) == 0:
raise ValueError(f"No demos found, param: {param}")
return debug_path
def _fit_text_grads_one_step_for_debug(self, train_loader: Any) -> str:
print("Debugging fitting one step with batch size 2 for text optimizer")
self.prep_ckpt_file_path()
debug_path = os.path.join(self.ckpt_path, "debug_text_grads")
os.makedirs(debug_path, exist_ok=True)
print(f"save to {debug_path}")
train_loader.batch_size = 2
train_loader.shuffle = True
self.adaltask.train() # this will turn everything to train mode
correct_loss = None
failed_loss = None
print("Finding one successful and one failed loss")
for batch in train_loader:
y_preds = self.adaltask.train_step(batch, 0, self.num_workers)
losses = self.adaltask.loss_step(batch, y_preds, 0, self.num_workers)
for loss in losses:
if loss.data > 0.5:
correct_loss = loss
else:
failed_loss = loss
if correct_loss is not None and failed_loss is not None:
print("Found correct and failed loss")
break
total_loss = sum_ops([correct_loss, failed_loss])
total_loss.backward()
# test optimizer
self._propose_text_optimizers()
total_loss.draw_graph(filepath=debug_path)
return debug_path
def _set_demo_optimizers_dataset(self, train_dataset: Any):
# init the dataset
for opt in self.demo_optimizers:
opt.set_dataset(train_dataset)
def _demo_optimizers_propose(self):
for opt in self.demo_optimizers:
opt.propose()
def _demo_optimizers_add_scores(
self, ids: List[str], scores: List[float], is_teacher: bool = True
):
for opt in self.demo_optimizers:
opt.add_scores(ids, scores, is_teacher)
def _demo_optimizers_revert(self):
for opt in self.demo_optimizers:
opt.revert()
def _demo_optimizers_step(self):
for opt in self.demo_optimizers:
opt.step()
def _init_demo_optimizers(self):
# init the few-shot demos
for opt in self.demo_optimizers:
opt.init_shots()
def _get_trainable_demo_params(self):
params = []
for opt in self.demo_optimizers:
params.extend([p for p in opt.params if p.requires_opt])
return params
def _zero_grad_text_optimizers(self):
for text_optimizer in self.text_optimizers:
text_optimizer.zero_grad()
def _propose_text_optimizers(self):
for text_optimizer in self.text_optimizers:
text_optimizer.propose()
def _get_trainable_text_params(self):
params = []
for opt in self.text_optimizers:
params.extend([p for p in opt.params if p.requires_opt])
return params
def _step_text_optimizers(self):
for text_optimizer in self.text_optimizers:
text_optimizer.step()
def _add_history_text_optimizers(self, val_score: float):
if not isinstance(val_score, float):
raise ValueError(
f"val_score should be a float, got {type(val_score)}, {val_score}"
)
for text_optimizer in self.text_optimizers:
text_optimizer.add_score_to_params(round(val_score, 4))
def _revert_text_optimizers(self):
for text_optimizer in self.text_optimizers:
text_optimizer.revert()
def _check_optimizer_proposal(self):
r"""Return True if all optimizers have proposed a new prompt"""
for text_optimizer in self.text_optimizers:
if not text_optimizer.proposing:
return False
return True
# TODO: in mix training, the teacher should be kept updated with the new prompt
def _fit_text_grad_demo_mix_constrained(
self,
train_loader: Any,
train_dataset: Any,
val_dataset: Any,
test_dataset: Any,
trainer_results: TrainerResult = None,
starting_step: int = 0,
):
from adalflow.optim.parameter import Parameter
log.info("Fitting using Textual Gradient Descent")
trainer_results = (
self._pre_fit(val_dataset, test_dataset)
if trainer_results is None
else trainer_results
)
print(f"save to {self.ckpt_file}")
if train_dataset is None:
raise ValueError("train_dataset is required")
self.adaltask.train()
self._zero_grad_text_optimizers()
self._set_demo_optimizers_dataset(train_dataset)
num_epochs = self._estimate_num_epochs(train_loader, self.max_steps)
total_steps = starting_step
teacher_losses_cache: Dict[str, Parameter] = {}
all_samples, all_losses, all_y_preds = [], [], []
for epoch in tqdm(range(num_epochs), desc="Epoch"):
for steps, batch in enumerate((pbar := tqdm(train_loader, position=0))):
total_steps += 1
if total_steps > self.max_steps + starting_step:
print("Reached max steps")
break
self._zero_grad_text_optimizers()
pbar.set_description(f"Training Step: {total_steps}")
self.adaltask.train() # this will turn everything to train mode
self.adaltask.trace() # NOTE: this needs to be turned on?
self.adaltask.use_teacher(False)
y_preds = self.adaltask.train_step(batch, steps, self.num_workers)
losses = self.adaltask.loss_step(
batch, y_preds, steps, self.num_workers
)
# moving batch
all_samples.extend(batch)
all_losses.extend(losses)
# extract the raw outputs (full_response) from the Parameter predictions
all_y_preds.extend(
[y.full_response for y in y_preds if isinstance(y, Parameter)]
)
# for loss in losses:
# loss.backward_engine_disabled = (
# True # temporary disable the backward engine
# )
# loss.backward()
# handle the demo
print(f"batch: {batch}")
self._demo_optimizers_add_scores(
[sample.id for sample in batch],
[float(loss.data) for loss in losses],
is_teacher=False,
)
# Trace the teacher run
self.adaltask.use_teacher(True)
self.adaltask.train()
self.adaltask.trace()
# filter by id
batch_for_teacher = []
for sample in batch:
if sample.id not in teacher_losses_cache:
batch_for_teacher.append(sample)
y_preds_teacher = self.adaltask.train_step(
batch_for_teacher, total_steps, self.num_workers
)
losses_teacher: List[Parameter] = self.adaltask.loss_step(
batch_for_teacher, y_preds_teacher, total_steps, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch_for_teacher],
[float(loss.data) for loss in losses_teacher],
is_teacher=True,
)
for idx, (sample, loss) in enumerate(
zip(batch_for_teacher, losses_teacher)
):
teacher_losses_cache[sample.id] = loss
all_samples, all_losses, all_y_preds = (
self._text_grad_constraint_propose_step(
steps=steps,
all_samples=all_samples,
all_losses=all_losses,
all_y_preds=all_y_preds,
include_demo_optimizers=True,
)
)
if not self._check_optimizer_proposal():
print(
"No proposal can improve the subset and full set, go to next step"
)
self._add_one_step_in_trainer_results(
trainer_results,
trainer_results.val_scores[-1],
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
)
continue
# validate on the full validation set
last_val_score = trainer_results.val_scores[-1]
val_output = self.adaltask.validation_step(
val_dataset,
total_steps,
self.num_workers,
minimum_score=last_val_score,
)
val_score = val_output.avg_score
self._add_history_text_optimizers(val_score)
if val_score > last_val_score:
print(f"Optimizer step: {val_score} > {last_val_score}")
# self.optimizer.step()
self._step_text_optimizers()
self._demo_optimizers_step()
# test the model
test_score = None
if test_dataset is not None:
test_output = self.adaltask.validation_step(
test_dataset, total_steps, self.num_workers
)
test_score = test_output.avg_score
new_prompts = self.adaltask._get_param_values()
self._add_one_step_in_trainer_results(
trainer_results,
val_score,
test_score,
new_prompts,
total_steps,
)
all_samples, all_losses, all_y_preds = [], [], []
else:
print(f"Optimizer revert: {val_score} <= {last_val_score}")
# self.optimizer.revert()
self._revert_text_optimizers()
self._demo_optimizers_revert()
# save the score, no change
self._add_one_step_in_trainer_results(
trainer_results,
last_val_score,
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
attempted_val_score=val_score,
)
print(f"Saving checkpoint to {self.ckpt_file}")
save_json(trainer_results.to_dict(), self.ckpt_file)
save_json(trainer_results.to_dict(), self.ckpt_file) # checkpoint
def _fit_text_grad_demo_mix_random(
self,
train_loader: Any,
train_dataset: Any,
val_dataset: Any,
test_dataset: Any,
train_results: TrainerResult = None,
starting_step: int = 0,
):
log.info("Fitting using Textual Gradient Descent")
trainer_results = (
self._pre_fit(val_dataset, test_dataset)
if train_results is None
else train_results
)
print(f"save to {self.ckpt_file}")
if train_dataset is None:
raise ValueError("train_dataset is required")
self.adaltask.train()
self._zero_grad_text_optimizers()
self._set_demo_optimizers_dataset(train_dataset)
num_epochs = self._estimate_num_epochs(train_loader, self.max_steps)
total_steps = starting_step
teacher_losses_cache: Dict[str, Parameter] = {}
for epoch in tqdm(range(num_epochs), desc="Epoch"):
for steps, batch in enumerate((pbar := tqdm(train_loader, position=0))):
total_steps += 1
if total_steps > self.max_steps + starting_step:
print("Reached max steps")
break
self._zero_grad_text_optimizers()
pbar.set_description(f"Training Step: {total_steps}")
self.adaltask.train() # this will turn everything to train mode
self.adaltask.trace() # NOTE: this needs to be turned on?
self.adaltask.use_teacher(False)
y_preds = self.adaltask.train_step(batch, steps, self.num_workers)
losses = self.adaltask.loss_step(
batch, y_preds, steps, self.num_workers
)
total_loss = sum_ops(losses)
print("Loss backward...")
total_loss.backward()
# for loss in losses:
# loss.backward_engine_disabled = (
# True # temporary disable the backward engine
# )
# loss.backward()
# handle the demo
self._demo_optimizers_add_scores(
[sample.id for sample in batch],
[float(loss.data) for loss in losses],
is_teacher=False,
)
# Trace the teacher run
self.adaltask.use_teacher(True)
self.adaltask.train()
self.adaltask.trace()
# filter by id
batch_for_teacher = []
for sample in batch:
if sample.id not in teacher_losses_cache:
batch_for_teacher.append(sample)
y_preds_teacher = self.adaltask.train_step(
batch_for_teacher, total_steps, self.num_workers
)
losses_teacher: List[Parameter] = self.adaltask.loss_step(
batch_for_teacher, y_preds_teacher, total_steps, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch_for_teacher],
[float(loss.data) for loss in losses_teacher],
is_teacher=True,
)
# for loss in losses_teacher:
# loss.backward_engine_disabled = (
# True # temporary disable the backward engine
# )
# loss.backward()
# save the teacher predictions; if the Generator is in cache mode, this also avoids re-running the teacher
for idx, (sample, loss) in enumerate(
zip(batch_for_teacher, losses_teacher)
):
teacher_losses_cache[sample.id] = loss
print("Optimizer propose...")
self._propose_text_optimizers()
self._demo_optimizers_propose()
new_prompts = self.adaltask._get_param_values()
print("New prompts: ", new_prompts)
# validate on the full validation set
last_val_score = trainer_results.val_scores[-1]
val_output = self.adaltask.validation_step(
val_dataset,
total_steps,
self.num_workers,
minimum_score=last_val_score,
)
val_score = val_output.avg_score
self._add_history_text_optimizers(val_score)
if val_score > last_val_score:
print(f"Optimizer step: {val_score} > {last_val_score}")
# self.optimizer.step()
self._step_text_optimizers()
self._demo_optimizers_step()
# test the model
test_output = self.adaltask.validation_step(
test_dataset, total_steps, self.num_workers
)
test_score = test_output.avg_score
self._add_one_step_in_trainer_results(
trainer_results,
val_score,
test_score,
new_prompts,
total_steps,
)
else:
print(f"Optimizer revert: {val_score} <= {last_val_score}")
# self.optimizer.revert()
self._revert_text_optimizers()
self._demo_optimizers_revert()
# save the score, no change
self._add_one_step_in_trainer_results(
trainer_results,
last_val_score,
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
attempted_val_score=val_score,
)
print(f"Saving checkpoint to {self.ckpt_file}")
save_json(trainer_results.to_dict(), self.ckpt_file)
save_json(trainer_results.to_dict(), self.ckpt_file) # checkpoint
def _fit_demos_random(
self,
train_loader,
train_dataset: Any,
val_dataset: Any,
test_dataset: Any,
trainer_results: TrainerResult,
starting_step: int,
):
log.info("Fitting using Random Demo Optimizer")
# self.adaltask.train()
trainer_results = (
self._pre_fit(val_dataset, test_dataset)
if trainer_results is None
else trainer_results
)
print(f"save to {self.ckpt_file}")
print(f"Starting step: {starting_step}")
self.adaltask.train()
self.adaltask.trace()
self._set_demo_optimizers_dataset(train_dataset)
# total_steps = 0
train_loader.set_max_steps(self.max_steps)
teacher_losses_cache: Dict[str, Parameter] = {}
pbar = tqdm(
zip(range(self.max_steps), train_loader), total=self.max_steps, desc="Step"
)
for step, batch in pbar:
step = step + starting_step + 1
print(f"Training Step: {step}")
pbar.set_description(f"Training Step: {step}")
# Trace the run in the demos
self.adaltask.train()
self.adaltask.trace()
self.adaltask.use_teacher(False)
y_preds = self.adaltask.train_step(batch, step, self.num_workers)
losses: List[Parameter] = self.adaltask.loss_step(
batch, y_preds, step, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch],
[float(loss.data) for loss in losses],
is_teacher=False,
)
for loss in losses:
loss.backward_engine_disabled = (
True  # temporarily disable the backward engine
)
loss.backward() # TODO: ensure no gradients in the backward, disable backward engine
# Trace the teacher run
self.adaltask.use_teacher(True)
self.adaltask.train()
self.adaltask.trace()
# filter by id
batch_for_teacher = []
for sample in batch:
if sample.id not in teacher_losses_cache:
batch_for_teacher.append(sample)
y_preds_teacher = self.adaltask.train_step(
batch_for_teacher, step, self.num_workers
)
losses_teacher: List[Parameter] = self.adaltask.loss_step(
batch_for_teacher, y_preds_teacher, step, self.num_workers
)
self._demo_optimizers_add_scores(
[sample.id for sample in batch_for_teacher],
[float(loss.data) for loss in losses_teacher],
is_teacher=True,
)
for loss in losses_teacher:
loss.backward_engine_disabled = (
True  # temporarily disable the backward engine
)
loss.backward()
# save the teacher predictions; if the Generator is in cache mode, this also avoids re-running the teacher
for idx, (sample, loss) in enumerate(
zip(batch_for_teacher, losses_teacher)
):
teacher_losses_cache[sample.id] = loss
# propose
self._demo_optimizers_propose()
new_prompts = self.adaltask._get_param_values()
print(f"New prompts: {new_prompts}")
# validate
if self.adaltask.validate_condition(step, total_steps=self.max_steps):
last_val_score = trainer_results.val_scores[-1]
val_output = self.adaltask.validation_step(
val_dataset,
step,
self.num_workers,
minimum_score=last_val_score,
)
val_score = val_output.avg_score
if val_score > last_val_score:
print(
f"Pass validation: {val_score} > {trainer_results.val_scores[-1]}"
)
self._demo_optimizers_step()
for opt in self.demo_optimizers:
if opt.proposing:
raise ValueError("Optimizer is still proposing")
# test the new prompts
test_score = None
if test_dataset is not None:
test_output = self.adaltask.validation_step(
test_dataset, step, self.num_workers
)
test_score = test_output.avg_score
self._add_one_step_in_trainer_results(
trainer_results,
val_score,
test_score=test_score,
prompts=new_prompts,
step=step,
attempted_val_score=val_score,
)
else:
print(f"Fail validation: {val_score} <= {last_val_score}, revert")
self._demo_optimizers_revert()
# ensure all demo optimizers are not proposing
for opt in self.demo_optimizers:
if opt.proposing:
raise ValueError("Optimizer is still proposing")
self._add_one_step_in_trainer_results(
trainer_results,
last_val_score,
test_score=trainer_results.test_scores[-1],
prompts=trainer_results.prompts[-1],
step=step,
attempted_val_score=val_score,
)
save_json(trainer_results.to_dict(), self.ckpt_file)
pbar.update(1)
self._compute_validate_stats(trainer_results)
save_json(trainer_results.to_dict(), self.ckpt_file)
if self.save_traces:
for i, demo_opt in enumerate(self.demo_optimizers):
for param in demo_opt.params:
teacher_traces = param._traces
student_traces = param._student_traces
trace_file = os.path.join(
self.ckpt_path,
f"opt_{i}_param_{param.name}_teacher_traces.json",
)
save_json(teacher_traces, trace_file)
trace_file = os.path.join(
self.ckpt_path,
f"opt_{i}_param_{param.name}_student_traces.json",
)
save_json(student_traces, trace_file)
# save demos
demo_file = os.path.join(
self.ckpt_path, f"opt_{i}_param_{param.name}_demos.json"
)
save_json(param._demos, demo_file)
print(f"Saved ckpt to {self.ckpt_file}")
return trainer_results
@staticmethod
def _compute_validate_stats(trainer_results: TrainerResult):
attempted_val_scores = [
(
step_result.attempted_val_score
if step_result.attempted_val_score is not None
else step_result.val_score
)
for step_result in trainer_results.step_results
]
array = np.array(attempted_val_scores)
mean = round(float(np.mean(array)), 4)
std = round(float(np.std(array)), 4)
max_score = round(float(np.max(array)), 4)
min_score = round(float(np.min(array)), 4)
trainer_results.validate_stats = TrainerValidateStats(
max_score=max_score,
min_score=min_score,
mean_of_score=mean,
std_of_score=std,
)
def _fit_text_grad_random(
self,
train_loader: Any,
val_dataset: Any,
test_dataset: Any,
trainer_results: TrainerResult = None,
starting_step: int = 0,
) -> TrainerResult:
log.info("Fitting using Textual Gradient Descent")
trainer_results = (
self._pre_fit(val_dataset, test_dataset)
if trainer_results is None
else trainer_results
)
print(f"save to {self.ckpt_file}")
self.adaltask.train()
# self.optimizer.zero_grad()
self._zero_grad_text_optimizers()
num_epochs = self._estimate_num_epochs(train_loader, self.max_steps)
total_steps = starting_step
for epoch in tqdm(range(num_epochs), desc="Epoch"):
for steps, batch in enumerate((pbar := tqdm(train_loader, position=0))):
total_steps += 1
if total_steps > self.max_steps + starting_step:
print("Reached max steps")
break
self._zero_grad_text_optimizers()
pbar.set_description(f"Training Step: {total_steps}")
self.adaltask.train() # this will turn everything to train mode
self.train()
y_preds = self.adaltask.train_step(batch, steps, self.num_workers)
losses = self.adaltask.loss_step(
batch, y_preds, steps, self.num_workers
)
total_loss = sum_ops(losses)
print("Loss backward...")
total_loss.backward()
print("Optimizer propose...")
self._propose_text_optimizers()
new_prompts = self.adaltask._get_param_values()
print("New prompts: ", new_prompts)
# validate on the full validation set
last_val_score = trainer_results.val_scores[-1]
val_output = self.adaltask.validation_step(
val_dataset,
total_steps,
self.num_workers,
minimum_score=last_val_score,
)
val_score = val_output.avg_score
self._add_history_text_optimizers(val_score)
if val_score > last_val_score:
print(f"Optimizer step: {val_score} > {last_val_score}")
# self.optimizer.step()
self._step_text_optimizers()
# test the model
test_output = self.adaltask.validation_step(
test_dataset, total_steps, self.num_workers
)
test_score = test_output.avg_score
self._add_one_step_in_trainer_results(
trainer_results,
val_score,
test_score,
new_prompts,
total_steps,
)
else:
print(f"Optimizer revert: {val_score} <= {last_val_score}")
# self.optimizer.revert()
self._revert_text_optimizers()
# save the score, no change
self._add_one_step_in_trainer_results(
trainer_results,
last_val_score,
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
attempted_val_score=val_score,
)
print(f"Saving checkpoint to {self.ckpt_file}")
save_json(trainer_results.to_dict(), self.ckpt_file)
save_json(trainer_results.to_dict(), self.ckpt_file) # checkpoint
return trainer_results
@staticmethod
def _add_one_step_in_trainer_results(
trainer_results: TrainerResult,
val_score: float,
test_score: float,
prompts: List[PromptData], # target prompts
step: int,
attempted_val_score: Optional[float] = None,
):
step_results = TrainerStepResult(
step=step,
val_score=val_score,
test_score=test_score,
prompt=prompts,
attempted_val_score=attempted_val_score,
)
trainer_results.step_results.append(step_results)
trainer_results.val_scores.append(val_score)
trainer_results.test_scores.append(test_score)
trainer_results.prompts.append(prompts)
trainer_results.steps.append(step)
def _downsample_move_batch(
self, all_samples, all_losses: List["Parameter"], all_y_preds, acc_score_list
):
"""Downsample the moving batch to a more balanced error and correct samples"""
from adalflow.optim.parameter import Parameter
if not all([score >= 0 and score <= 1 for score in acc_score_list]):
raise ValueError(
"acc_score_list should only contain values between 0 and 1"
)
for loss in all_losses:
if not isinstance(loss, Parameter):
raise ValueError("Loss should be a Parameter object")
max_moving_batch_size = 20
correct_indices = [i for i, score in enumerate(acc_score_list) if score > 0.5]
error_indices = [i for i, score in enumerate(acc_score_list) if score <= 0.5]
if (
len(error_indices) + len(correct_indices)
<= max_moving_batch_size
# and len(correct_indices) <= max_moving_batch_size
):
return all_samples, all_losses, all_y_preds, acc_score_list
# downsample from all samples
new_sample_indices = random.sample(
range(len(all_samples)), min(max_moving_batch_size, len(all_samples))
)
all_samples = [all_samples[i] for i in new_sample_indices]
all_losses = [all_losses[i] for i in new_sample_indices]
all_y_preds = [all_y_preds[i] for i in new_sample_indices]
acc_score_list = [acc_score_list[i] for i in new_sample_indices]
return all_samples, all_losses, all_y_preds, acc_score_list
def _moving_batch_sample(
self, acc_score_list: List[float]
) -> Tuple[float, List[int]]:
"""Sample from both correct and error samples according to max_error_samples and max_correct_samples"""
# ensure only 0 and 1 in the acc_score_list
import numpy as np
if not all([score in [0, 1] for score in acc_score_list]):
raise ValueError("acc_score_list should only contain 0 and 1")
correct_indices = [i for i, score in enumerate(acc_score_list) if score == 1]
error_indices = [i for i, score in enumerate(acc_score_list) if score == 0]
print(f"Moving batch correct size: {len(correct_indices)}")
print(f"Moving batch error size: {len(error_indices)}")
if len(error_indices) == 0:
raise ValueError("No error samples found")
sampled_error_indices = random.sample(
error_indices, min(self.max_error_samples, len(error_indices))
)
num_errors = len(sampled_error_indices)
# cap the correct samples at 2x the number of errors, bounded by max_correct_samples and the available correct samples
max_num_correct_samples = int(2 * num_errors)
sampled_correct_indices = random.sample(
correct_indices,
min(
self.max_correct_samples,
max_num_correct_samples,
len(correct_indices),
),
)
print(f"Subset Error size: {len(sampled_error_indices)}")
print(f"Subset Correct size: {len(sampled_correct_indices)}")
subset = sampled_error_indices + sampled_correct_indices
# subset_samples = samples[sampled_error_indices + sampled_correct_indices]
subset_score = np.mean(np.array(acc_score_list)[subset])
print(f"Subset score: {subset_score}")
return subset_score, subset
def _track_effectiveness(
self, stage: Literal["subset", "fullset", "valset"], pass_: bool
):
if stage == "subset":
if pass_:
self._subset_effect_count["pass"] += 1
else:
self._subset_effect_count["fail"] += 1
elif stage == "fullset":
if pass_:
self._fullset_effect_count["pass"] += 1
else:
self._fullset_effect_count["fail"] += 1
elif stage == "valset":
if pass_:
self._valset_effect_count["pass"] += 1
else:
self._valset_effect_count["fail"] += 1
def _text_grad_constraint_propose_step(
self,
steps: int,
all_samples,
all_losses: List["Parameter"],
all_y_preds,
include_demo_optimizers: bool = False,
):
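r"""Run one constrained proposal step on the moving batch.
The moving batch is scored first; if its average score is already at or above batch_val_score_threshold, the batch is skipped and reset. Otherwise the batch is downsampled, a balanced subset of error and correct samples is drawn, the subset loss is backpropagated, and up to max_proposals_per_step proposals are made. A proposal is kept only if it first beats the subset score and then matches or beats the moving-batch score; otherwise it is reverted and the next proposal is tried.
"""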
# compute the moving batch accuracy
from adalflow.optim.parameter import Parameter
for loss in all_losses:
if not isinstance(loss, Parameter):
raise ValueError("Loss should be a Parameter object")
self.adaltask.eval()
move_batch_eval = self.adaltask.evaluate_samples(all_samples, all_y_preds)
move_batch_score = move_batch_eval.avg_score
move_batch_acc_score_list = move_batch_eval.per_item_scores
if move_batch_score >= self.batch_val_score_threshold:
print(f"Skipping batch {steps} as acc: {move_batch_score}")
# reset the moving batch
all_samples, all_losses, all_y_preds = [], [], []
return all_samples, all_losses, all_y_preds
# downsample the moving batch
all_samples, all_losses, all_y_preds, move_batch_acc_score_list = (
self._downsample_move_batch(
all_samples, all_losses, all_y_preds, move_batch_acc_score_list
)
)
move_batch_score = np.mean(np.array(move_batch_acc_score_list))
print(f"Moving batch acc: {move_batch_score}")
# create a subset with a more balanced error and correct samples
subset_score, subset_indices = self._moving_batch_sample(
move_batch_acc_score_list
)
print(f"Subset batch acc: {subset_score}")
# compute the subset loss
subset_losses = [all_losses[i] for i in subset_indices]
subset_loss = sum_ops(subset_losses)
print("Subset loss backward...")
start_time = time.time()
subset_loss.backward()
print(f"Subset loss backward time: {time.time() - start_time}") # 12seconds
print("Optimizer propose...")
# mark the subset loss to be backpropagated
# TODO: make this a step
tqdm_loader = tqdm(range(self.max_proposals_per_step), desc="Proposing")
for i in tqdm_loader:
# print(f"Proposing step: {i}")
# self.optimizer.propose()
self._propose_text_optimizers() # new prompts
if include_demo_optimizers:
self._demo_optimizers_propose()
new_prompts = self.adaltask._get_param_values()
print("New prompts: ", new_prompts)
subset_samples = [all_samples[i] for i in subset_indices]
# validate the subset
val_output = self.adaltask.validation_step(
subset_samples, steps, self.num_workers
)
# check subset validation score
val_score = val_output.avg_score
if val_score > subset_score:
print(f"Pass subset check: {val_score} > {subset_score}")
self._track_effectiveness("subset", True)
else:
print(
f"Fail subset check, try next proposal: {val_score} <= {subset_score}"
)
self._track_effectiveness("subset", False)
self._revert_text_optimizers()
if include_demo_optimizers:
self._demo_optimizers_revert()
continue
# validate the full set
move_batch_result = self.adaltask.validation_step(
all_samples, steps, self.num_workers
)
new_move_batch_score = move_batch_result.avg_score
if new_move_batch_score >= move_batch_score:
print(f"Pass full check: {new_move_batch_score} >= {move_batch_score}")
self._track_effectiveness("fullset", True)
break
else:
print(
f"Fail full check, try next proposal: {new_move_batch_score} < {move_batch_score}"
)
self._track_effectiveness("fullset", False)
self._revert_text_optimizers()
if include_demo_optimizers:
self._demo_optimizers_revert()
continue
print("Done with proposals")
self.adaltask.train()
return all_samples, all_losses, all_y_preds
# def _fit_bootstrap_few_shot_random(
# self,
# train_loader: Any,
# val_dataset: Any,
# test_dataset: Any,
# optimizers: List[DemoOptimizer],
# ):
# log.info("Fitting using Bootstrap Few Shot only")
# trainer_results = self._pre_fit(val_dataset, test_dataset)
# print(f"save to {self.ckpt_file}")
# self.adaltask.train() #
# num_epochs = self._estimate_num_epochs(train_loader, self.max_steps)
# total_steps = 0
# for optimizer in optimizers:
# optimizer.init()
# for epoch in tqdm(range(num_epochs), desc="Epoch"):
# for steps, batch in enumerate((pbar := tqdm(train_loader, position=0))):
# total_steps += 1
# if total_steps > self.max_steps:
# print("Reached max steps")
# break
# pbar.set_description(f"Training Step: {total_steps}")
# self.adaltask.train()
def _fit_text_grad_constraint(
self,
train_loader: Any,
val_dataset: Any,
test_dataset: Any,
trainer_results: TrainerResult = None,
starting_step: int = 0,
) -> TrainerResult:
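r"""Fit the text optimizers with the constrained strategy: batches accumulate into a moving batch, each proposal must pass the subset and moving-batch checks in _text_grad_constraint_propose_step, and accepted proposals are then validated on the validation set (and on the test set when the validation score improves)."""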
from adalflow.optim.parameter import Parameter
log.info("Fitting using Textual Gradient Descent with constraints")
trainer_results = (
self._pre_fit(val_dataset, test_dataset)
if trainer_results is None
else trainer_results
)
print(f"save to {self.ckpt_file}")
self.adaltask.train()
self._zero_grad_text_optimizers()
num_epochs = self._estimate_num_epochs(train_loader, self.max_steps)
total_steps = starting_step
all_samples, all_losses, all_y_preds = [], [], []
for epoch in tqdm(range(num_epochs), desc="Epoch"):
for steps, batch in enumerate((pbar := tqdm(train_loader, position=0))):
total_steps += 1
if total_steps > self.max_steps + starting_step:
print("Reached max steps")
break
self._zero_grad_text_optimizers()
pbar.set_description(f"Training Step: {total_steps}")
self.adaltask.train() # this will turn everything to train mode
y_preds = self.adaltask.train_step(batch, steps, self.num_workers)
losses = self.adaltask.loss_step(
batch, y_preds, steps, self.num_workers
)
# moving batch
all_samples.extend(batch)
all_losses.extend(losses)
all_y_preds.extend(
[y.full_response for y in y_preds if isinstance(y, Parameter)]
)
all_samples, all_losses, all_y_preds = (
self._text_grad_constraint_propose_step(
steps=steps,
all_samples=all_samples,
all_losses=all_losses,
all_y_preds=all_y_preds,
)
)
# check optimizer stages to see if the proposal was accepted so far
if not self._check_optimizer_proposal():
print(
"No proposal can improve the subset and full set, go to next step"
)
self._add_one_step_in_trainer_results(
trainer_results,
trainer_results.val_scores[-1],
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
)
continue
# prune the correct sample size if it's too big, same for the error samples
# run the tests as any other optimizer
if self.adaltask.validate_condition(steps, total_steps):
# validate on the full validation set
last_val_score = trainer_results.val_scores[-1]
val_output = self.adaltask.validation_step(
val_dataset,
total_steps,
self.num_workers,
minimum_score=last_val_score,
)
val_score = val_output.avg_score
self._add_history_text_optimizers(val_score)
if val_score > last_val_score:
print(f"Optimizer step: {val_score} > {last_val_score}")
# self.optimizer.step()
self._step_text_optimizers()
# save the score
step_result = {
"val_score": val_score,
}
self._track_effectiveness("valset", True)
# test the model
if test_dataset is not None:
test_output = self.adaltask.validation_step(
test_dataset,
steps,
self.num_workers,
)
step_result["test_score"] = test_output.avg_score
else:
step_result["test_score"] = None
step_result["prompts"] = self.adaltask._get_param_values()
step_result["step"] = total_steps
self._add_one_step_in_trainer_results(
trainer_results,
**step_result,
)
all_samples, all_losses, all_y_preds = [], [], []
else:
print(f"Optimizer revert: {val_score} <= {last_val_score}")
self._revert_text_optimizers()
self._track_effectiveness("valset", False)
self._add_one_step_in_trainer_results(
trainer_results,
trainer_results.val_scores[-1],
trainer_results.test_scores[-1],
trainer_results.prompts[-1],
total_steps,
attempted_val_score=val_score,
)
trainer_results.effective_measure = self._effective_measure
save_json(trainer_results.to_dict(), self.ckpt_file)
save_json(trainer_results.to_dict(), self.ckpt_file)
return trainer_results