Source code for tracing.generator_state_logger
from typing import Dict, Any, List, Optional, TYPE_CHECKING
import os
import logging
from dataclasses import dataclass, field
from datetime import datetime
import json
if TYPE_CHECKING:
from adalflow.core.generator import Generator
from adalflow.core.base_data_class import DataClass
from adalflow.utils import serialize
log = logging.getLogger(__name__)
[docs]
# One snapshot of a generator's prompt states together with the time it was taken.
# Documented with comments (not a class docstring) so the dataclass-generated
# ``__doc__`` stays exactly as before.
@dataclass
class GeneratorStatesRecord(DataClass):
    # Prompt states captured for the generator (filled from ``prompt.to_dict()``
    # by the logger in this module).
    prompt_states: Dict[str, Any] = field(default_factory=dict)
    # Timestamp string of the snapshot; not taken into account by ``__eq__``.
    time_stamp: str = field(default_factory=str)

    def __eq__(self, other: Any):
        """Compare two records by their serialized prompt states only.

        The time stamp is deliberately ignored, so two snapshots taken at
        different moments are equal when the prompt states did not change.
        Returns ``NotImplemented`` for any non-``GeneratorStatesRecord`` operand.
        """
        if isinstance(other, GeneratorStatesRecord):
            own_serialized = serialize(self.prompt_states)
            other_serialized = serialize(other.prompt_states)
            return own_serialized == other_serialized
        return NotImplemented
[docs]
class GeneratorStateLogger:
    __doc__ = r"""Log the generator states, especially the prompt states update history, to a file.

    Each generator should have its unique and identifiable name to be logged.
    One file can log multiple generators' states.

    We use _trace_map to store the states and track any changes and updates and save it to a file.

    Args:
        save_dir(str, optional): The directory to save the trace file. Default is "./traces/"
        project_name(str, optional): The project name. Default is None.
        filename(str, optional): The name of the trace file. Default is "generator_state_trace.json"
    """

    # Class-level attribute: the set object is shared by every logger instance,
    # so a name added through one instance is visible through all of them.
    _generator_names: set = set()

    # TODO: create a logger base class to avoid code duplication
    def __init__(
        self,
        save_dir: Optional[str] = None,
        project_name: Optional[str] = None,
        filename: Optional[str] = None,
    ):
        """Resolve the trace file location and load previous records if present.

        The target directory (``save_dir`` joined with ``project_name`` when
        given) is created if it does not exist yet.
        """
        self.project_name = project_name

        directory = save_dir or "./traces/"
        if project_name:
            directory = os.path.join(directory, project_name)
        # TODO: make this a generator state instead of just the prompt as right now
        os.makedirs(directory, exist_ok=True)

        self.filename = filename or "generator_state_trace.json"
        self.filepath = os.path.join(directory, self.filename)

        # generator_name -> chronological list of recorded prompt states
        self._trace_map: Dict[str, List[GeneratorStatesRecord]] = {}

        # load previous records if the file exists
        if os.path.exists(self.filepath):
            self.load(self.filepath)

    def get_log_location(self) -> str:
        """Return the full path of the trace file."""
        return self.filepath

    @property
    def generator_names(self):
        """The set of generator names passed to ``log_prompt`` so far."""
        return self._generator_names

    def log_prompt(self, generator: "Generator", name: str):
        r"""Log the prompt states of the generator with the given name.

        A new record is stored (and the trace file rewritten) only when the
        generator has no record yet or when its prompt states differ from the
        most recent record.

        Args:
            generator: The generator whose ``prompt.to_dict()`` is recorded.
            name: The unique name under which the generator is tracked.

        Raises:
            Exception: If building, comparing or saving the record fails; the
                original error is chained as the cause.
        """
        self._generator_names.add(name)
        prompt_states: Dict = (
            generator.prompt.to_dict()
        )  # TODO: log all states of the generator instead of just the prompt
        try:
            new_record = GeneratorStatesRecord(
                prompt_states=prompt_states,
                time_stamp=datetime.now().isoformat(),
            )
            history = self._trace_map.get(name)
            if not history:
                # First record for this name. This also covers an empty history
                # list loaded from file, which would otherwise fail on ``[-1]``.
                self._trace_map[name] = [new_record]
            elif history[-1] != new_record:
                # Prompt states changed since the last record.
                history.append(new_record)
            else:
                # Nothing changed: keep the file untouched.
                return
            self.save(self.filepath)
        except Exception as e:
            raise Exception(f"Error logging prompt states for {name}") from e

    def save(self, filepath: str):
        """Serialize the whole trace map and write it to ``filepath``.

        The file is overwritten. It is written as UTF-8 so the result does not
        depend on the platform's default encoding.
        """
        serialized_obj = serialize(self._trace_map)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(serialized_obj)

    def load(self, filepath: str):
        """Load the trace map from ``filepath``, replacing the in-memory one.

        An empty (or whitespace-only) file is reported at INFO level and leaves
        the current trace map unchanged.

        Raises:
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        # Use a named logger instead of the root-level ``logging.info``: the
        # latter calls ``logging.basicConfig()`` when the root logger has no
        # handlers, silently installing a global handler. This is the same
        # logger object as the module-level ``log``.
        logger = logging.getLogger(__name__)

        if os.stat(filepath).st_size == 0:
            logger.info("File %s is empty.", filepath)
            return

        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read().strip()
        if not content:
            logger.info("File %s is empty after stripping.", filepath)
            return

        raw_map = json.loads(content)
        # Convert each dict record to a GeneratorStatesRecord. The new map is
        # assigned only once fully built, so a failing record cannot leave
        # ``_trace_map`` half-converted.
        self._trace_map = {
            name: [GeneratorStatesRecord.from_dict(record) for record in records]
            for name, records in raw_map.items()
        }