"""Functional data classes to support functional components like Generator, Retriever, and Assistant."""
from enum import Enum, auto
from typing import (
List,
Dict,
Any,
Optional,
Union,
Generic,
TypeVar,
Sequence,
Literal,
Callable,
Awaitable,
Generator,
AsyncGenerator,
AsyncIterator,
Coroutine,
Iterable,
Tuple,
)
from typing_extensions import TypeAlias
from collections import OrderedDict
from dataclasses import (
dataclass,
field,
InitVar,
)
from uuid import UUID
from datetime import datetime
import uuid
import logging
import json
import asyncio
from collections.abc import AsyncIterable
from adalflow.core.base_data_class import DataClass, required_field
from adalflow.core.tokenizer import Tokenizer
from adalflow.core.functional import (
is_normalized,
generate_function_call_expression_from_callable,
)
logger = logging.getLogger(__name__)
T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T") # invariant type
#######################################################################################
# Data modeling for ModelClient
######################################################################################
class ModelType(Enum):
__doc__ = r"""The type of the model, including Embedder, LLM, Reranker.
It helps ModelClient identify the model type required to correctly call the model."""
EMBEDDER = auto()
LLM = auto()
    LLM_REASONING = auto()  # reasoning models compatible with the openai.responses API
RERANKER = auto() # ranking model
IMAGE_GENERATION = auto() # image generation models like DALL-E
UNDEFINED = auto()
class _ModelClientTypeMeta(type):
    """Metaclass that resolves model clients lazily on attribute access
    (e.g. ``ModelClientType.OPENAI``). A plain ``__class_getattr__`` method is
    not a Python protocol and would never be invoked, so the lazy-import hook
    lives here as the metaclass's ``__getattr__``."""

    _clients_cache: Dict[str, Any] = {}

    def __getattr__(cls, name):
        """Dynamically import and return model clients on attribute access."""
        if name in cls._clients_cache:
            return cls._clients_cache[name]
        client_mapping = {
            "COHERE": ("adalflow.components.model_client", "CohereAPIClient"),
            "TRANSFORMERS": ("adalflow.components.model_client", "TransformersClient"),
            "ANTHROPIC": ("adalflow.components.model_client", "AnthropicAPIClient"),
            "GROQ": ("adalflow.components.model_client", "GroqAPIClient"),
            "OPENAI": ("adalflow.components.model_client", "OpenAIClient"),
            "GOOGLE_GENAI": ("adalflow.components.model_client", "GoogleGenAIClient"),
            "OLLAMA": ("adalflow.components.model_client", "OllamaClient"),
        }
        if name in client_mapping:
            module_name, class_name = client_mapping[name]
            import importlib

            module = importlib.import_module(module_name)
            client_class = getattr(module, class_name)
            cls._clients_cache[name] = client_class
            return client_class
        raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")


class ModelClientType(metaclass=_ModelClientTypeMeta):
__doc__ = r"""A quick way to access all model clients in the ModelClient module.
From this:
.. code-block:: python
from adalflow.components.model_client import CohereAPIClient, TransformersClient, AnthropicAPIClient, GroqAPIClient, OpenAIClient
model_client = OpenAIClient()
To this:
.. code-block:: python
from adalflow.core.types import ModelClientType
model_client = ModelClientType.OPENAI
"""
# TODO: define standard required outputs
def get_model_args(model_type: ModelType) -> List[str]:
r"""Get the required keys in model_kwargs for a specific model type.
    Note:
        If your model inference SDK uses different keys, convert them to these standard keys in its specific ModelClient.
Args:
model_type (ModelType): The model type
Returns:
List[str]: The required keys in model_kwargs
"""
if model_type == ModelType.EMBEDDER:
return ["model"]
elif model_type == ModelType.LLM:
return ["model"]
elif model_type == ModelType.RERANKER:
return ["model", "top_k", "documents", "query"]
else:
return []
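# A minimal usage sketch (illustrative, not part of the library API surface):
# look up the standard model_kwargs keys for a reranker-style call.
#
#   from adalflow.core.types import ModelType, get_model_args
#
#   get_model_args(ModelType.RERANKER)
#   # -> ["model", "top_k", "documents", "query"]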
#######################################################################################
# Data modeling for Embedder component
######################################################################################
@dataclass
class Embedding:
"""
Container for a single embedding.
    In sync with the OpenAI API spec, same as openai/types/embedding.py.
"""
embedding: List[float]
index: Optional[int] # match with the index of the input, in case some are missing
@dataclass
class Usage:
"""
    In sync with the OpenAI embedding API spec, same as openai/types/create_embedding_response.py.
"""
prompt_tokens: int
total_tokens: int
@dataclass
class EmbedderOutput(DataClass):
__doc__ = r"""Container to hold the response from an Embedder datacomponent for a single batch of input.
Data standard for Embedder model output to interact with other components.
Batch processing is often available, thus we need a list of Embedding objects.
"""
data: List[Embedding] = field(
default_factory=list, metadata={"desc": "List of embeddings"}
)
model: Optional[str] = field(default=None, metadata={"desc": "Model name"})
usage: Optional[Usage] = field(default=None, metadata={"desc": "Usage tracking"})
error: Optional[str] = field(default=None, metadata={"desc": "Error message"})
raw_response: Optional[Any] = field(
default=None, metadata={"desc": "Raw response"}
) # only used if error
input: Optional[List[str]] = field(default=None, metadata={"desc": "Input text"})
@property
def length(self) -> int:
return len(self.data) if self.data and isinstance(self.data, Sequence) else 0
@property
def embedding_dim(self) -> int:
r"""The dimension of the embedding, assuming all embeddings have the same dimension.
Returns:
int: The dimension of the embedding, -1 if no embedding is available
"""
return (
len(self.data[0].embedding) if self.data and self.data[0].embedding else -1
)
@property
def is_normalized(self) -> bool:
r"""Check if the embeddings are normalized to unit vectors.
Returns:
bool: True if the embeddings are normalized, False otherwise
"""
return (
is_normalized(self.data[0].embedding)
if self.data and self.data[0].embedding
else False
)
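# A minimal sketch (illustrative values): constructing an EmbedderOutput by hand
# and reading its convenience properties.
#
#   out = EmbedderOutput(
#       data=[Embedding(embedding=[0.6, 0.8], index=0)],
#       model="text-embedding-3-small",
#   )
#   out.length         # 1
#   out.embedding_dim  # 2
#   out.is_normalized  # True, since 0.6**2 + 0.8**2 == 1.0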
EmbedderInputType = Union[str, Sequence[str]]
EmbedderOutputType = EmbedderOutput
BatchEmbedderInputType = EmbedderInputType
BatchEmbedderOutputType = List[EmbedderOutputType]
#######################################################################################
# Data modeling for Generator component
######################################################################################
@dataclass
class TokenLogProb:
r"""similar to openai.ChatCompletionTokenLogprob"""
token: str
logprob: float
@dataclass
class CompletionUsage:
__doc__ = r"In sync with OpenAI completion usage api spec at openai/types/completion_usage.py"
completion_tokens: Optional[int] = field(
metadata={"desc": "Number of tokens in the generated completion"}, default=None
)
prompt_tokens: Optional[int] = field(
metadata={"desc": "Number of tokens in the prompt"}, default=None
)
total_tokens: Optional[int] = field(
metadata={
"desc": "Total number of tokens used in the request (prompt + completion)"
},
default=None,
)
@dataclass
class InputTokensDetails:
    __doc__ = r"Details about input tokens used in a response"
    # Mirrors InputTokensDetails in openai/types/response_usage.py; it is
    # referenced by ResponseUsage below but was missing from this module.
    cached_tokens: Optional[int] = field(
        metadata={"desc": "Number of input tokens served from the cache"}, default=0
    )


@dataclass
class OutputTokensDetails:
    __doc__ = r"Details about output tokens used in a response"
    reasoning_tokens: Optional[int] = field(
        metadata={"desc": "Number of tokens used for reasoning"}, default=0
    )
@dataclass
class ResponseUsage:
__doc__ = r"Usage information for a response, including token counts, in sync with OpenAI response usage api spec at openai/types/response_usage.py"
input_tokens: int = field(metadata={"desc": "Number of input tokens used"})
output_tokens: int = field(metadata={"desc": "Number of output tokens used"})
total_tokens: int = field(metadata={"desc": "Total number of tokens used"})
input_tokens_details: InputTokensDetails = field(
metadata={"desc": "Details about input tokens"},
default_factory=InputTokensDetails,
)
output_tokens_details: OutputTokensDetails = field(
metadata={"desc": "Details about output tokens"},
default_factory=OutputTokensDetails,
)
#######################################################################################
# Data modeling for Retriever component
######################################################################################
RetrieverQueryType = TypeVar("RetrieverQueryType", contravariant=True)
RetrieverStrQueryType = str
RetrieverQueriesType = Union[RetrieverQueryType, Sequence[RetrieverQueryType]]
RetrieverStrQueriesType = Union[str, Sequence[RetrieverStrQueryType]]
RetrieverDocumentType = TypeVar("RetrieverDocumentType", contravariant=True)
RetrieverStrDocumentType = str # for text retrieval
RetrieverDocumentsType = Sequence[RetrieverDocumentType]
@dataclass
class RetrieverOutput(DataClass):
__doc__ = r"""Save the output of a single query in retrievers.
It is up to the subclass of Retriever to specify the type of query and document.
"""
    id: Optional[str] = field(default=None, metadata={"desc": "The unique id of the output"})
    doc_indices: List[int] = field(
        default_factory=required_field, metadata={"desc": "List of document indices"}
    )
doc_scores: List[float] = field(
default=None, metadata={"desc": "List of document scores"}
)
query: RetrieverQueryType = field(
default=None, metadata={"desc": "The query used to retrieve the documents"}
)
documents: List[RetrieverDocumentType] = field(
default=None, metadata={"desc": "List of retrieved documents"}
)
RetrieverOutputType = Union[
    List[RetrieverOutput], RetrieverOutput
]  # to support multiple queries at once
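# A minimal sketch (illustrative values): the output a text retriever might
# produce for a single query, with the top-2 document indices and scores.
#
#   output = RetrieverOutput(
#       id="q-0",
#       doc_indices=[3, 7],
#       doc_scores=[0.92, 0.85],
#       query="What is AdalFlow?",
#   )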
#######################################################################################
# Data modeling for function calls
######################################################################################
AsyncCallable = Callable[..., Awaitable[Any]]
@dataclass
class FunctionDefinition(DataClass):
__doc__ = r"""The data modeling of a function definition, including the name, description, and parameters."""
# class_instance: Optional[Any] = field(
# default=None,
# metadata={"desc": "The instance of the class this function belongs to"},
# )
# NOTE: for class method: cls_name + "_" + name
    func_name: str = field(
        metadata={"desc": "The name of the tool"}, default_factory=required_field
    )
func_desc: Optional[str] = field(
default=None, metadata={"desc": "The description of the tool"}
)
func_parameters: Dict[str, object] = field(
default_factory=dict, metadata={"desc": "The schema of the parameters"}
)
def fn_schema_str(self, type: Literal["json", "yaml"] = "json") -> str:
r"""Get the function definition str to be used in the prompt.
        You can also directly use :meth:`to_json` and :meth:`to_yaml` to get the schema in JSON or YAML format.
"""
if type == "json":
return self.to_json()
elif type == "yaml":
return self.to_yaml()
else:
raise ValueError(f"Unsupported type: {type}")
@dataclass
class Function(DataClass):
__doc__ = r"""The data modeling of a function call, including the name and keyword arguments.
    You can use the `exclude` parameter in :meth:`to_json` and :meth:`to_yaml` to exclude the `thought` field if you do not want to use the chain-of-thought pattern.
Example:
.. code-block:: python
# assume the function is added in a context_map
# context_map = {"add": add}
def add(a, b):
return a + b
# call function add with arguments 1 and 2
fun = Function(name="add", kwargs={"a": 1, "b": 2})
# evaluate the function
result = context_map[fun.name](**fun.kwargs)
# or call with positional arguments
fun = Function(name="add", args=[1, 2])
result = context_map[fun.name](*fun.args)
"""
id: Optional[str] = field(
default=None, metadata={"desc": "The id of the function call"}
)
thought: Optional[str] = field(
default=None, metadata={"desc": "Your reasoning for this step. Be short for simple queries. For complex queries, provide a clear chain of thought."}
) # if the model itself is a thinking model, disable thought field
name: str = field(default="", metadata={"desc": "The name of the function"})
args: Optional[List[object]] = field(
default_factory=list,
metadata={"desc": "The positional arguments of the function"},
)
kwargs: Optional[Dict[str, object]] = field(
default_factory=dict,
metadata={"desc": "The keyword arguments of the function"},
)
_is_answer_final: Optional[bool] = field(
default=None,
metadata={"desc": "Whether this current output is the final answer"},
)
_answer: Optional[Any] = field(
default=None,
metadata={"desc": "The final answer if this is the final output."},
)
@classmethod
def from_function(
cls,
func: Union[Callable[..., Any], AsyncCallable],
thought: Optional[str] = None,
*args,
**kwargs,
) -> "Function":
r"""Create a Function object from a function.
        Args:
            func (Union[Callable[..., Any], AsyncCallable]): The function to be converted
        Returns:
            Function: The Function object
        Usage:
        1. Create a Function object from a function call.
        2. Use :meth:`to_json` or :meth:`to_yaml` to get the schema in JSON or YAML format.
        3. Use the result as an example in a prompt showing the LLM how to call the function.
Example:
.. code-block:: python
from adalflow.core.types import Function
def add(a, b):
return a + b
# create a function call object with positional arguments
            fun = Function.from_function(add, "Add two numbers", 1, 2)
print(fun)
# output
# Function(thought='Add two numbers', name='add', args=[1, 2])
"""
        return cls(
            thought=thought,
            name=func.__name__,
            args=list(args),
            kwargs=kwargs,
        )
__output_fields__ = ["thought", "name", "kwargs", "_is_answer_final", "_answer"]
_action_desc = """FuncName(<kwargs>) \
Valid function call expression. \
Example: "FuncName(a=1, b=2)" \
Follow the data type specified in the function parameters.\
e.g. for Type object with x,y properties, use "ObjectType(x=1, y=2)"."""
@dataclass
class QueueSentinel:
"""Special sentinel object to mark the end of a stream when using asyncio.Queue for stream processing."""
type: Literal["queue_sentinel"] = "queue_sentinel"
"""Type discriminator for the sentinel."""
@dataclass
class RawResponsesStreamEvent(DataClass):
"""Streaming event for storing the raw responses from the LLM. These are 'raw' events, i.e. they are directly passed through
from the LLM.
"""
input: Optional[Any] = None
"""The input to the LLM."""
data: Union[Any, None] = None
"""The raw responses streaming event from the LLM."""
type: Literal["raw_response_event"] = "raw_response_event"
"""The type of the event."""
error: Optional[str] = None
"""The error message if any."""
@dataclass
class GeneratorOutput(DataClass, Generic[T_co]):
__doc__ = r"""
The output data class for the Generator component.
We can not control its output 100%, so we use this to track the error_message and
allow the raw string output to be passed through.
(1) When model predict and output processors are both without error,
we have data as the final output, error as None.
(2) When either model predict or output processors have error,
we have data as None, error as the error message.
Raw_response will depends on the model predict.
"""
id: Optional[str] = field(
default=None, metadata={"desc": "The unique id of the output"}
)
input: Optional[Any] = field(
default=None,
metadata={"desc": "The input to the generator"}, # should use it to save the prompt
)
data: T_co = field(
default=None,
metadata={"desc": "The final output data potentially after output parsers"},
) # for reasoning model, this is only the text content/answer (raw_response)
# extend to support thinking and tool use
thinking: Optional[str] = field(
default=None, metadata={"desc": "The thinking of the model"}
)
tool_use: Optional[Function] = field(
default=None, metadata={"desc": "The tool use of the model"}
)
images: Optional[Union[str, List[str]]] = field(
default=None, metadata={"desc": "Generated images (base64 or URLs) from image generation tools"}
)
error: Optional[str] = field(
default=None,
metadata={"desc": "Error message if any"},
)
usage: Optional[CompletionUsage] = field(
default=None, metadata={"desc": "Usage tracking"}
)
# The caller expects the raw_response to follow the OpenAI API documentation for streams
raw_response: Optional[Union[str, AsyncIterable[T_co], Iterable[T_co]]] = field(
default=None, metadata={"desc": "Raw string chunk generator from the model"}
) # parsed from model client response
api_response: Optional[Any] = field(
default=None, metadata={"desc": "Raw response from the api/model client"}
)
metadata: Optional[Dict[str, object]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
async def stream_events(self) -> AsyncIterator[T_co]:
"""
        Stream raw events from the Generator's raw_response, which holds the processed version of api_response.
        If the raw_response has already been consumed, yield the final data field instead.
Returns:
AsyncIterator[T_co]: An async iterator that yields events stored in raw_response
"""
count = 0
        # Stream from raw_response when it is an async iterable
if isinstance(self.raw_response, AsyncIterable):
async for event in self.raw_response:
count += 1
yield event
# if the stream is already consumed and there is final data then just return the final data
if count == 0 and self.data:
yield self.data
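    # A minimal consumption sketch (assumes `generator` is an adalflow Generator
    # configured for streaming, e.g. model_kwargs={"stream": True}):
    #
    #   output = generator(prompt_kwargs={"input_str": "hi"})
    #   async for event in output.stream_events():
    #       print(event)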
def save_images(
self,
directory: str = ".",
prefix: str = "generated",
format: Literal["png", "jpg", "jpeg", "webp", "gif", "bmp"] = "png",
decode_base64: bool = True,
return_paths: bool = True
) -> Optional[List[str]]:
"""Save generated images to disk with automatic format conversion.
Args:
directory: Directory to save images to (default: current directory)
prefix: Filename prefix for saved images (default: "generated")
format: Image format to save as (png, jpg, jpeg, webp, gif, bmp)
decode_base64: Whether to decode base64 encoded images (default: True)
return_paths: Whether to return the saved file paths (default: True)
Returns:
If return_paths is True:
- List[str]: Paths to saved images (always returns a list, even for single image)
- None: If no images to save
Otherwise returns None
Examples:
>>> # Save single image as PNG (returns list with one element)
>>> response.save_images()
['generated_0.png']
>>> # Save multiple images as JPEG with custom prefix
>>> response.save_images(prefix="cat", format="jpg")
['cat_0.jpg', 'cat_1.jpg']
>>> # Save to specific directory
>>> response.save_images(directory="/tmp/images", format="webp")
['/tmp/images/generated_0.webp']
"""
if not self.images:
return None
import os
import base64
from pathlib import Path
# Create directory if it doesn't exist
Path(directory).mkdir(parents=True, exist_ok=True)
saved_paths = []
images_to_save = self.images if isinstance(self.images, list) else [self.images]
try:
# Try to import PIL for format conversion
from PIL import Image
import io
has_pil = True
except ImportError:
has_pil = False
if format.lower() not in ["png", "jpg", "jpeg"]:
raise ImportError(
f"PIL/Pillow is required for '{format}' format. "
"Install with: pip install Pillow"
)
for i, img_data in enumerate(images_to_save):
# Determine if this is base64 or a URL
is_base64 = False
if isinstance(img_data, str):
if img_data.startswith("data:"):
# Data URI format
is_base64 = True
# Extract base64 part from data URI
base64_data = img_data.split(",")[1] if "," in img_data else img_data
elif not img_data.startswith(("http://", "https://")):
# Assume it's raw base64 if not a URL
is_base64 = True
base64_data = img_data
# Construct filename
filename = f"{prefix}_{i}.{format}"
filepath = os.path.join(directory, filename)
if is_base64 and decode_base64:
# Decode base64 and save
img_bytes = base64.b64decode(base64_data)
if has_pil and format.lower() not in ["png"]:
# Use PIL to convert format
img = Image.open(io.BytesIO(img_bytes))
# Convert RGBA to RGB for JPEG
if format.lower() in ["jpg", "jpeg"] and img.mode == "RGBA":
rgb_img = Image.new("RGB", img.size, (255, 255, 255))
rgb_img.paste(img, mask=img.split()[3] if len(img.split()) == 4 else None)
img = rgb_img
# PIL expects 'JPEG' for jpg/jpeg formats
pil_format = "JPEG" if format.lower() in ["jpg", "jpeg"] else format.upper()
img.save(filepath, pil_format)
else:
# Save as-is (assuming PNG or no conversion needed)
with open(filepath, "wb") as f:
f.write(img_bytes)
else:
# For URLs or if not decoding, save the string as-is
with open(filepath + ".url", "w") as f:
f.write(img_data)
filepath = filepath + ".url"
saved_paths.append(filepath)
if return_paths:
return saved_paths # Always return a list
return None
GeneratorOutputType = GeneratorOutput[object]
@dataclass
class FunctionExpression(DataClass):
__doc__ = r"""The data modeling of a function expression for a call, including the name and arguments.
Example:
.. code-block:: python
def add(a, b):
return a + b
# call function add with positional arguments 1 and 2
fun_expr = FunctionExpression(action="add(1, 2)")
# evaluate the expression
result = eval(fun_expr.action)
print(result)
# Output: 3
# call function add with keyword arguments
fun_expr = FunctionExpression(action="add(a=1, b=2)")
result = eval(fun_expr.action)
print(result)
# Output: 3
    Why ask the LLM to generate a function expression (code snippet) for a function call?
    - It is a more efficient/compact way to express a call.
    - It is more flexible:
      (1) it supports the full range of Python expressions, including arithmetic operations, nested function calls, and more;
      (2) it allows passing variables as arguments.
    - It is easy to parse with the ``ast`` module.
    The net benefit is fewer failed function calls.
"""
# question: str = field(
# default=None, metadata={"desc": "The question to ask the LLM"}
# )
    thought: Optional[str] = field(default=None, metadata={"desc": "Why the function is called"})
action: str = field(
default_factory=required_field,
metadata={"desc": _action_desc},
)
@classmethod
def from_function(
cls,
func: Union[Callable[..., Any], AsyncCallable],
thought: Optional[str] = None,
*args,
**kwargs,
) -> "FunctionExpression":
r"""Create a FunctionExpression object from a function.
        Args:
            func (Union[Callable[..., Any], AsyncCallable]): The function to be converted
        Returns:
            FunctionExpression: The FunctionExpression object
        Usage:
        1. Create a FunctionExpression object from a function call.
        2. Use :meth:`to_json` or :meth:`to_yaml` to get the schema in JSON or YAML format.
        3. Use the result as an example in a prompt showing the LLM how to call the function.
Example:
.. code-block:: python
from adalflow.core.types import FunctionExpression
def add(a, b):
return a + b
# create an expression for the function call and using keyword arguments
fun_expr = FunctionExpression.from_function(
add, thought="Add two numbers", a=1, b=2
)
print(fun_expr)
# output
# FunctionExpression(thought='Add two numbers', action='add(a=1, b=2)')
"""
try:
action = generate_function_call_expression_from_callable(
func, *args, **kwargs
)
except Exception as e:
logger.error(f"Error generating function expression: {e}")
raise ValueError(f"Error generating function expression: {e}")
return cls(action=action, thought=thought)
FunctionOutputValueType = Union[
Any,
Generator[Any, Any, Any],
AsyncGenerator[Any, Any],
Coroutine[Any, Any, Any],
]
@dataclass
class FunctionOutput(DataClass):
__doc__ = (
r"""The output of a tool, which could be a function, a class, or a module."""
)
name: Optional[str] = field(
default=None, metadata={"desc": "The name of the function"}
)
input: Optional[Union[Function, FunctionExpression]] = field(
default=None, metadata={"desc": "The Function or FunctionExpression object"}
)
parsed_input: Optional[Function] = field(
default=None,
metadata={
"desc": "The parsed Function object if the input is FunctionExpression"
},
)
output: Optional[FunctionOutputValueType] = field(
default=None,
metadata={
"desc": "The output of the function execution - supports sync functions, sync generators, async functions, and async generators"
},
)
error: Optional[str] = field(
default=None, metadata={"desc": "The error message if any"}
)
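# A minimal sketch: executing a parsed Function against a context map of tools
# and wrapping the result. `add` is a hypothetical tool for illustration.
#
#   def add(a, b):
#       return a + b
#
#   fn = Function(name="add", kwargs={"a": 1, "b": 2})
#   result = FunctionOutput(name=fn.name, input=fn, output=add(**fn.kwargs))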
#######################################################################################
# Data modeling for component tool
######################################################################################
#######################################################################################
# Data modeling for agent component
######################################################################################
@dataclass
class StepOutput(DataClass, Generic[T]):
__doc__ = r"""The output of a single step in the agent. Suits for serial planning agent such as React"""
step: int = field(
default=0, metadata={"desc": "The order of the step in the agent"}
)
    # This action can be a Function, a FunctionExpression, or just a str;
    # it includes the thought and action already,
    # directly as output from the planner LLM.
planner_prompt: Optional[str] = field(
default=None, metadata={"desc": "The planner prompt for this step"}
)
action: T = field(
default=None, metadata={"desc": "The action the agent takes at this step"}
)
function: Optional[Function] = field(
default=None, metadata={"desc": "The parsed function from the action"}
)
observation: Optional[str] = field(
default=None, metadata={"desc": "The execution result shown for this action"}
)
ctx: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "The context of the step"}
)
def to_prompt_str(self) -> str:
output: Dict[str, Any] = {}
if self.action and isinstance(self.action, FunctionExpression):
if self.action.thought:
output["thought"] = self.action.thought
output["action"] = self.action.action if self.action else None
if self.observation:
output["observation"] = (
self.observation.to_dict()
if hasattr(self.observation, "to_dict")
else str(self.observation)
)
return json.dumps(output)
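# A minimal sketch: a single ReAct-style step rendered for the next prompt.
#
#   step = StepOutput(
#       step=0,
#       action=FunctionExpression(thought="Add the numbers", action="add(1, 2)"),
#       observation="3",
#   )
#   step.to_prompt_str()
#   # -> '{"thought": "Add the numbers", "action": "add(1, 2)", "observation": "3"}'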
#######################################################################################
# Data modeling for data processing pipelines such as text splitting and embedding
######################################################################################
@dataclass
class Document(DataClass):
__doc__ = r"""A text container with optional metadata and vector representation.
It is the data structure to support functions like Retriever, DocumentSplitter, and used with LocalDB.
"""
text: str = field(metadata={"desc": "The main text"})
meta_data: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Metadata for the document"}
)
# can save data for filtering at retrieval time too
vector: List[float] = field(
default_factory=list,
metadata={"desc": "The vector representation of the document"},
)
# the vector representation of the document
id: Optional[str] = field(
default_factory=lambda: str(uuid.uuid4()), metadata={"desc": "Unique id"}
) # unique id of the document
order: Optional[int] = field(
default=None,
metadata={"desc": "Order of the chunked document in the original document"},
)
score: Optional[float] = field(
default=None,
metadata={"desc": "Score of the document, likely used in retrieval output"},
)
parent_doc_id: Optional[Union[str, UUID]] = field(
default=None, metadata={"desc": "id of the Document where the chunk is from"}
)
estimated_num_tokens: Optional[int] = field(
default=None,
metadata={
"desc": "Estimated number of tokens in the text, useful for cost estimation"
},
)
def __post_init__(self):
if self.estimated_num_tokens is None and self.text:
tokenizer = Tokenizer()
self.estimated_num_tokens = tokenizer.count_tokens(self.text)
@classmethod
def from_dict(cls, doc: Dict):
"""Create a Document object from a dictionary.
Example:
.. code-block :: python
doc = Document.from_dict({
"id": "123",
"text": "Hello world",
"meta_data": {"title": "Greeting"}
})
"""
doc = doc.copy()
assert "meta_data" in doc, "meta_data is required"
assert "text" in doc, "text is required"
if "estimated_num_tokens" not in doc:
tokenizer = Tokenizer()
doc["estimated_num_tokens"] = tokenizer.count_tokens(doc["text"])
if "id" not in doc or not doc["id"]:
doc["id"] = uuid.uuid4()
return super().from_dict(doc)
def __repr__(self):
"""Custom repr method to truncate the text to 100 characters and vector to 10 floats."""
max_chars_to_show = 100
truncated_text = (
self.text[:max_chars_to_show] + "..."
if len(self.text) > max_chars_to_show
else self.text
)
truncated_vector = (
f"len: {len(self.vector)}" if len(self.vector) else self.vector
)
return (
f"Document(id={self.id}, text={truncated_text!r}, meta_data={self.meta_data}, "
f"vector={truncated_vector!r}, parent_doc_id={self.parent_doc_id}, order={self.order}, "
f"score={self.score})"
)
#######################################################################################
# Data modeling for dialog system
######################################################################################
@dataclass
class UserQuery:
query_str: str
metadata: Optional[Dict[str, Any]] = (
None # context or files can be used in the user queries
)
@dataclass
class AssistantResponse:
response_str: str
metadata: Optional[Dict[str, Any]] = None # for agent, we have step history
# There could be more roles in a multi-party conversation; we might consider this in the future.
@dataclass
class DialogTurn(DataClass):
__doc__ = r"""A turn consists of a user query and the assistant response.
    The data format is designed to fit into a relational database, where each turn is a row.
    Use `conversation_id` to group the turns into a dialog session, with the `order` field and
    `user_query_timestamp` / `assistant_response_timestamp` to order the turns.
Args:
id (str): The unique id of the turn.
user_id (str, optional): The unique id of the user.
        conversation_id (str, optional): The unique id of the conversation the turn belongs to.
order (int, optional): The order of the turn in the dialog session, starts from 0.
user_query (UserQuery, optional): The user query in the turn.
assistant_response (AssistantResponse, optional): The assistant response in the turn.
user_query_timestamp (datetime, optional): The timestamp of the user query.
assistant_response_timestamp (datetime, optional): The timestamp of the assistant response.
metadata (Dict[str, Any], optional): Additional metadata.
Examples:
- User: Hi, how are you?
- Assistant: Doing great!
DialogTurn(id=uuid4(), user_query=UserQuery("Hi, how are you?"), assistant_response=AssistantResponse("Doing great!"))
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4()),
metadata={"desc": "The unique id of the turn"},
)
user_id: Optional[str] = field(
default=None, metadata={"desc": "The unique id of the user"}
)
conversation_id: Optional[str] = field(
default=None,
metadata={"desc": "The unique id of the conversation it belongs to"},
)
order: Optional[int] = field(
default=None,
metadata={"desc": "The order of the turn in the Dialog Session, starts from 0"},
)
user_query: Optional[UserQuery] = field(
default=None, metadata={"desc": "The user query in the turn"}
)
assistant_response: Optional[AssistantResponse] = field(
default=None, metadata={"desc": "The assistant response in the turn"}
)
user_query_timestamp: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the user query"},
)
assistant_response_timestamp: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the assistant response"},
)
metadata: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
vector: Optional[List[float]] = field(
default=None,
metadata={"desc": "The vector representation of the dialog turn"},
)
def set_user_query(
self, user_query: UserQuery, user_query_timestamp: Optional[datetime] = None
):
self.user_query = user_query
if not user_query_timestamp:
user_query_timestamp = datetime.now()
self.user_query_timestamp = user_query_timestamp
def set_assistant_response(
self,
assistant_response: AssistantResponse,
assistant_response_timestamp: Optional[datetime] = None,
):
self.assistant_response = assistant_response
if not assistant_response_timestamp:
assistant_response_timestamp = datetime.now()
self.assistant_response_timestamp = assistant_response_timestamp
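# A minimal sketch: recording one turn of a dialog.
#
#   turn = DialogTurn(order=0)
#   turn.set_user_query(UserQuery(query_str="Hi, how are you?"))
#   turn.set_assistant_response(AssistantResponse(response_str="Doing great!"))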
# TODO: This part and the Memory class is still WIP, and will need more work in the future.
@dataclass
class Conversation:
__doc__ = r"""A conversation manages the dialog turns in a whole conversation as a session.
This class is mainly used in-memory for the dialog system/app to manage active conversations.
    You won't need this class for past conversations that have already been persisted in a database
    as records or history.
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4()),
metadata={"desc": "The id of the conversation"},
) # the id of the conversation
name: Optional[str] = field(
default=None, metadata={"desc": "The name of the conversation"}
)
user_id: Optional[str] = field(
default=None, metadata={"desc": "The id of the user"}
)
dialog_turns: OrderedDict[int, DialogTurn] = field(
default_factory=OrderedDict, metadata={"desc": "The dialog turns"}
)
# int is the order of the turn, starts from 0
metadata: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
created_at: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the conversation creation"},
)
# InitVar type annotation is used for parameters that are used in __post_init__
# but not meant to be fields in the dataclass.
dialog_turns_input: InitVar[
Optional[Union[OrderedDict[int, DialogTurn], List[DialogTurn]]]
] = None
def __post_init__(
self,
dialog_turns_input: Optional[
Union[OrderedDict[int, DialogTurn], List[DialogTurn]]
] = None,
):
if dialog_turns_input:
            if isinstance(dialog_turns_input, list):
                # Assume a list of DialogTurn objects; append each into the OrderedDict
                for dialog_turn in dialog_turns_input:
                    self.append_dialog_turn(dialog_turn)
elif isinstance(dialog_turns_input, OrderedDict):
self.dialog_turns = dialog_turns_input
else:
raise ValueError(
"dialog_turns should be a list of DialogTurn or an OrderedDict"
)
def get_next_order(self):
return len(self.dialog_turns)
def append_dialog_turn(self, dialog_turn: DialogTurn):
next_order = self.get_next_order()
if dialog_turn.order is None:
dialog_turn.order = next_order
else:
assert dialog_turn.order == next_order, f"order should be {next_order}"
self.dialog_turns[next_order] = dialog_turn
def get_dialog_turns(self) -> OrderedDict[int, DialogTurn]:
return self.dialog_turns
def get_chat_history_str(self) -> str:
chat_history_str = ""
        for dialog_turn in self.dialog_turns.values():
chat_history_str += f"User: {dialog_turn.user_query.query_str}\n"
chat_history_str += (
f"Assistant: {dialog_turn.assistant_response.response_str}\n"
)
return chat_history_str
def delete_dialog_turn(self, order: int):
self.dialog_turns.pop(order)
def update_dialog_turn(self, order: int, dialog_turn: DialogTurn):
self.dialog_turns[order] = dialog_turn
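# A minimal sketch: an in-memory session built turn by turn (continuing the
# DialogTurn sketch above).
#
#   conv = Conversation(user_id="u-1")
#   conv.append_dialog_turn(turn)
#   print(conv.get_chat_history_str())
#   # User: Hi, how are you?
#   # Assistant: Doing great!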
##############################
# Agent runner events
##############################
@dataclass
class RunItem(DataClass):
"""
Base class for streaming execution events in the Runner system.
RunItems represent discrete events that occur during the execution of an Agent
through the Runner. These items are used for streaming real-time updates about
the execution progress, allowing consumers to monitor and react to different
phases of agent execution.
Attributes:
id: Unique identifier for tracking this specific event instance
type: String identifier for the event type (used for event filtering/routing)
data: Optional generic data payload (deprecated, prefer specific fields in subclasses)
timestamp: When this event was created (for debugging and monitoring)
Usage:
This is an abstract base class. Use specific subclasses for different event types.
Example:
```python
# Don't instantiate directly - use subclasses
tool_call_event = ToolCallRunItem(function=my_function)
```
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4()),
metadata={"desc": "Unique identifier for this run item"},
)
type: str = field(
default="base",
metadata={
"desc": "Type of run item - used for event identification and routing"
},
)
data: Optional[Any] = field(
default=None,
metadata={
"desc": "Generic data payload (deprecated - use specific fields in subclasses)"
},
)
error: Optional[str] = field(
default=None,
metadata={"desc": "Error message if an error occurred"},
)
timestamp: datetime = field(
default_factory=datetime.now,
metadata={"desc": "Timestamp when this event was created"},
)
@dataclass
class FunctionRequest(DataClass):
"""
Event emitted when the Agent is about to execute a function/tool call.
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4())
) # tool call id for this request
    tool_name: Optional[str] = field(
        default=None, metadata={"desc": "Name of the tool to be called"}
    )
tool: Optional[Function] = field(
default=None,
metadata={"desc": "Function object containing the tool call to be executed"},
)
# send this to the frontend user to display the details of the confirmation
confirmation_details: Optional[Any] = field(
default=None,
metadata={"desc": "Confirmation details for the tool call"},
)
@dataclass
class StepRunItem(RunItem):
"""
Event emitted when a complete execution step has finished.
A "step" represents one complete cycle of: planning → tool selection → tool execution.
This event marks the completion of that cycle and contains the full step information
including the action taken and the observation (result).
Attributes:
data: Complete StepOutput containing step number, action, and observation
Event Flow Position:
ToolOutputRunItem → **StepRunItem** → (next step or completion)
Usage:
```python
# Track step completion
async for event in runner.astream(prompt_kwargs).stream_events():
if isinstance(event, RunItemStreamEvent) and event.name == "step_completed":
step_item = event.item
print(f"Completed step {step_item.data.step}")
```
"""
type: str = field(default="step", metadata={"desc": "Type of run item"})
data: Optional[StepOutput] = field(
default=None,
metadata={
"desc": "Complete step execution result including action and observation"
},
)
"""
Used to wrap the final response from the runner which holds key information about the execution such
as the answer, step history, and error.
"""
[docs]
@dataclass
class RunnerResult:
step_history: List[StepOutput] = field(
metadata={"desc": "The step history of the execution"},
default_factory=list,
)
answer: Optional[str] = field(
metadata={"desc": "The answer to the user's query"}, default=None
)
error: Optional[str] = field(
metadata={"desc": "The error message if the code execution failed"},
default=None,
)
ctx: Optional[Dict] = field(
metadata={"desc": "The context of the execution"},
default=None,
)
@dataclass
class FinalOutputItem(RunItem):
"""
Event emitted when the entire Runner execution has completed.
This event signals the end of the execution sequence and contains the final
processed result. It's emitted regardless of whether execution completed
successfully or with an error.
Attributes:
        data: The final RunnerResult containing the complete execution result
Event Flow Position:
Final step → **FinalOutputItem** (execution complete)
Usage:
```python
# Get final results
async for event in runner.astream(prompt_kwargs).stream_events():
if isinstance(event, RunItemStreamEvent) and event.name == "runner_finished":
final_item = event.item
if final_item.data.error:
print(f"Execution failed: {final_item.data.error}")
else:
print(f"Final answer: {final_item.data.answer}")
```
"""
type: str = field(default="final_output", metadata={"desc": "Type of run item"})
data: Optional[RunnerResult] = field(
default=None,
metadata={"desc": "Final processed output from the runner execution"},
)
@dataclass
class RunItemStreamEvent(DataClass):
"""
Wrapper for streaming RunItem events during Runner execution.
This class wraps RunItem instances with event metadata to create a streaming
event system. Each event has a name that indicates what type of execution
event occurred, and contains the associated RunItem with the event data.
The streaming system allows consumers to react to different phases of agent
execution in real-time, such as when tools are called, when steps complete,
or when execution finishes.
Attributes:
name: The specific event type that occurred (see event name literals)
item: The RunItem containing the event-specific data
type: Always "run_item_stream_event" for type discrimination
Event Types:
- "agent.final_output": Final output from the agent (FinalOutputItem)
- "agent.tool_permission_request": Tool permission request before execution
- "agent.tool_call_start": Agent is about to execute a tool (ToolCallRunItem)
- "agent.tool_call_activity": Agent is about to execute a tool (ToolCallActivityRunItem)
- "agent.tool_call_complete": Tool execution completed (ToolOutputRunItem)
- "agent.step_complete": Full execution step finished (StepRunItem)
- "agent.execution_complete": Entire execution completed (FinalOutputItem)
"""
name: Literal[
# Core agent execution events
"agent.tool_call_start", # Function/tool about to be executed
"agent.tool_call_activity", # Function/tool intermediate activity and progress updates
"agent.tool_call_complete", # Function/tool execution completed
"agent.step_complete", # Complete execution step finished
"agent.final_output", # Final processed output available
"agent.execution_complete", # Entire Runner execution completed
"agent.tool_permission_request", # Tool permission request before execution
] = field(
metadata={
"desc": "The name identifying the specific type of execution event that occurred"
}
)
"""The name identifying the specific type of execution event that occurred."""
# TODO: convert this to data to be consistent with other events
item: RunItem = field(
metadata={
"desc": "The RunItem instance containing the event-specific data and context"
}
)
"""The RunItem instance containing the event-specific data and context."""
type: Literal["run_item_stream_event"] = field(
default="run_item_stream_event",
metadata={"desc": "Type discriminator for the streaming event system"},
)
"""Type discriminator for the streaming event system."""
StreamEvent: TypeAlias = Union[RawResponsesStreamEvent, RunItemStreamEvent]
@dataclass
class QueueCompleteSentinel:
"""Sentinel to indicate queue completion."""
pass
class EventEncoder(json.JSONEncoder):
    """JSON encoder for stream events: datetimes become ISO strings; other
    objects fall back to ``__dict__`` and finally ``str()``."""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        # Every object defines __str__, so this is the final fallback;
        # json.JSONEncoder.default (which raises TypeError) is never reached.
        return str(obj)
@dataclass
class RunnerStreamingResult:
"""
Container for runner streaming results that provides access to the event queue
and allows users to consume streaming events.
"""
_event_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
_run_task: Optional[asyncio.Task] = field(default=None)
_exception: Optional[Exception] = field(default=None)
answer: Optional[Any] = field(default=None)
step_history: List[Any] = field(default_factory=list)
_is_complete: bool = field(default=False)
@property
def is_complete(self) -> bool:
"""Check if the workflow execution is complete."""
return self._is_complete
def set_exception(self, exc: Any) -> None:
"""Set an exception, ensuring it's a proper exception object."""
if exc is None:
self._exception = None
elif isinstance(exc, BaseException):
self._exception = exc
else:
# Convert non-exception to a proper exception
self._exception = RuntimeError(f"Non-exception error: {str(exc)}")
    def put_nowait(self, item: Union[StreamEvent, QueueCompleteSentinel]):
        # Only stream events and the completion sentinel may enter the queue
        if not isinstance(
            item, (RawResponsesStreamEvent, RunItemStreamEvent, QueueCompleteSentinel)
        ):
            raise ValueError(
                "Only RawResponsesStreamEvent, RunItemStreamEvent, and QueueCompleteSentinel can be put into the queue"
            )
self._event_queue.put_nowait(item)
async def stream_events(self) -> AsyncIterator[StreamEvent]:
"""
        Stream events from the runner execution.
Returns:
AsyncIterator[StreamEvent]: An async iterator that yields stream events
Example:
```python
result = runner.astream(prompt_kwargs)
async for event in result.stream_events():
if isinstance(event, RawResponsesStreamEvent):
print(f"Raw event: {event.data}")
elif isinstance(event, RunItemStreamEvent):
print(f"Run item: {event.name} - {event.item}")
```
"""
while True:
if self._exception:
# Ensure we're raising a proper exception
if isinstance(self._exception, BaseException):
raise self._exception
else:
# Convert non-exception to a proper exception
raise RuntimeError(str(self._exception))
try:
# Wait for an event from the queue
event = await self._event_queue.get()
# Check for completion sentinel or special completion events
if isinstance(event, QueueCompleteSentinel):
self._event_queue.task_done()
break
else:
# always yield event
yield event
# mark the task as done
self._event_queue.task_done()
# if the event is a RunItemStreamEvent and the name is agent.execution_complete then additionally break the loop
if (
isinstance(event, RunItemStreamEvent)
and event.name == "agent.execution_complete"
):
break
except asyncio.CancelledError:
# Clean up and re-raise to allow proper cancellation
self._is_complete = True
raise
except Exception as e:
# Store unexpected exceptions
self.set_exception(e)
raise
async def stream_to_json(
self, file_name: str = "agent_events_stream.json"
) -> AsyncIterator[StreamEvent]:
"""
Stream events to a JSON file in real-time while also yielding them.
This method writes events to a JSON file as they arrive, giving a live
streaming effect. The JSON file is updated incrementally.
Args:
file_name: The output file name for saving events
Yields:
StreamEvent: Each event as it arrives
Example:
```python
result = runner.astream(prompt_kwargs)
async for event in result.stream_to_json("live_events.json"):
# Process event while it's also being written to file
print(f"Event: {event}")
```
"""
event_count = 0
# Open file and write the opening bracket
with open(file_name, "w") as f:
f.write("[\n")
first_event = True
try:
async for event in self.stream_events():
event_count += 1
# Prepare event data
try:
if hasattr(event, "to_dict"):
event_data = event.to_dict()
else:
event_data = str(event)
except Exception as e:
# If serialization fails, use a fallback representation
event_data = f"<Error serializing event: {str(e)}>"
event_dict = {
"event_number": event_count,
"timestamp": datetime.now().isoformat(),
"event_type": type(event).__name__,
"event_data": event_data,
}
# Append to file in streaming fashion
try:
with open(file_name, "r+") as f:
# Seek to end of file
f.seek(0, 2)
if first_event:
# For first event, we're right after "[\n"
first_event = False
else:
# For subsequent events, go back to overwrite the previous "\n]"
f.seek(f.tell() - 2)
f.write(",\n")
# Write the event
json.dump(event_dict, f, indent=2, cls=EventEncoder)
# Write closing bracket
f.write("\n]")
except (IOError, OSError) as e:
# Log file write error but continue streaming
logger.warning(f"Failed to write event to {file_name}: {e}")
# Yield the event so caller can process it
yield event
except asyncio.CancelledError:
# Properly handle cancellation
raise
except Exception as e:
# Log unexpected errors
logger.error(f"Error in stream_to_json: {e}")
raise
print(f"\nStreamed {event_count} events to {file_name}")
def stream_to_json_sync(self, file_name: str = "agent_events_stream.json"):
"""
Synchronous wrapper for stream_to_json that returns an iterator.
This allows users to use the streaming JSON functionality in a sync context.
Args:
file_name: The output file name for saving events
Returns:
Iterator of events
Example:
```python
result = runner.astream(prompt_kwargs)
for event in result.stream_to_json_sync("live_events.json"):
print(f"Event: {event}")
```
"""
async def _collect_events():
events = []
async for event in self.stream_to_json(file_name):
events.append(event)
return events
        try:
            # If we're already inside a running event loop, asyncio.run cannot
            # be nested, so run the async collector in a separate thread
            asyncio.get_running_loop()
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(asyncio.run, _collect_events())
                return future.result()
except RuntimeError:
# No event loop running, we can use asyncio.run directly
return asyncio.run(_collect_events())
def cancel(self):
"""Cancel the running task."""
if self._run_task and not self._run_task.done():
self._run_task.cancel()
async def wait_for_completion(self):
"""Wait for the runner task to complete."""
if self._run_task:
await self._run_task
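# A minimal end-to-end sketch (assumes `runner` is an adalflow Runner and
# `prompt_kwargs` matches its prompt template):
#
#   result = runner.astream(prompt_kwargs={"input_str": "What is 2 + 2?"})
#   async for event in result.stream_events():
#       if isinstance(event, RunItemStreamEvent) and event.name == "agent.final_output":
#           print(event.item.data.answer)
#   await result.wait_for_completion()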