"""Functional data classes to support functional components like Generator, Retriever, and Assistant."""
from enum import Enum, auto
from typing import (
List,
Dict,
Any,
Optional,
Union,
Generic,
TypeVar,
Sequence,
Literal,
Callable,
Awaitable,
Type,
)
from collections import OrderedDict
from dataclasses import (
dataclass,
field,
InitVar,
)
from uuid import UUID
from datetime import datetime
import uuid
import logging
from adalflow.core.base_data_class import DataClass, required_field
from adalflow.core.tokenizer import Tokenizer
from adalflow.core.functional import (
is_normalized,
generate_function_call_expression_from_callable,
)
from adalflow.components.model_client import (
CohereAPIClient,
TransformersClient,
AnthropicAPIClient,
GroqAPIClient,
OpenAIClient,
GoogleGenAIClient,
)
logger = logging.getLogger(__name__)
T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T") # invariant type
#######################################################################################
# Data modeling for ModelClient
######################################################################################
class ModelType(Enum):
__doc__ = r"""The type of the model, including Embedder, LLM, Reranker.
It helps ModelClient identify the model type required to correctly call the model."""
EMBEDDER = auto()
LLM = auto()
RERANKER = auto() # ranking model
UNDEFINED = auto()
@dataclass
class ModelClientType:
__doc__ = r"""A quick way to access all model clients in the ModelClient module.
From this:
.. code-block:: python
from adalflow.components.model_client import CohereAPIClient, TransformersClient, AnthropicAPIClient, GroqAPIClient, OpenAIClient
model_client = OpenAIClient()
To this:
.. code-block:: python
from adalflow.core.types import ModelClientType
model_client = ModelClientType.OPENAI()
"""
COHERE = CohereAPIClient
TRANSFORMERS = TransformersClient
ANTHROPIC = AnthropicAPIClient
GROQ = GroqAPIClient
OPENAI = OpenAIClient
GOOGLE_GENAI = GoogleGenAIClient
# TODO: define standard required outputs
def get_model_args(model_type: ModelType) -> List[str]:
r"""Get the required keys in model_kwargs for a specific model type.
Note:
If your model inference SDK uses different keys, convert them to these standard keys in the specific ModelClient.
Args:
model_type (ModelType): The model type
Returns:
List[str]: The required keys in model_kwargs
"""
if model_type == ModelType.EMBEDDER:
return ["model"]
elif model_type == ModelType.LLM:
return ["model"]
elif model_type == ModelType.RERANKER:
return ["model", "top_k", "documents", "query"]
else:
return []
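# Example (sketch): reading the standard keys off the function above.
#
#     get_model_args(ModelType.RERANKER)
#     # -> ["model", "top_k", "documents", "query"]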
#######################################################################################
# Data modeling for Embedder component
######################################################################################
@dataclass
class Embedding:
"""
Container for a single embedding.
In sync with the OpenAI API spec, same as openai/types/embedding.py
"""
embedding: List[float]
index: Optional[int] # match with the index of the input, in case some are missing
@dataclass
class Usage:
"""
In sync with OpenAI embedding api spec, same as openai/types/create_embedding_response.py
"""
prompt_tokens: int
total_tokens: int
@dataclass
class EmbedderOutput(DataClass):
__doc__ = r"""Container to hold the response from an Embedder model. Only Per-batch.
Data standard for Embedder model output to interact with other components.
Batch processing is often available, thus we need a list of Embedding objects.
"""
data: List[Embedding] = field(
default_factory=list, metadata={"desc": "List of embeddings"}
)
model: Optional[str] = field(default=None, metadata={"desc": "Model name"})
usage: Optional[Usage] = field(default=None, metadata={"desc": "Usage tracking"})
error: Optional[str] = field(default=None, metadata={"desc": "Error message"})
raw_response: Optional[Any] = field(
default=None, metadata={"desc": "Raw response"}
) # only used if error
input: Optional[List[str]] = field(default=None, metadata={"desc": "Input text"})
@property
def length(self) -> int:
return len(self.data) if self.data and isinstance(self.data, Sequence) else 0
@property
def embedding_dim(self) -> int:
r"""The dimension of the embedding, assuming all embeddings have the same dimension.
Returns:
int: The dimension of the embedding, -1 if no embedding is available
"""
return (
len(self.data[0].embedding) if self.data and self.data[0].embedding else -1
)
@property
def is_normalized(self) -> bool:
r"""Check if the embeddings are normalized to unit vectors.
Returns:
bool: True if the embeddings are normalized, False otherwise
"""
return (
is_normalized(self.data[0].embedding)
if self.data and self.data[0].embedding
else False
)
EmbedderInputType = Union[str, Sequence[str]]
EmbedderOutputType = EmbedderOutput
BatchEmbedderInputType = EmbedderInputType
BatchEmbedderOutputType = List[EmbedderOutputType]
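# Example (a minimal sketch using the dataclasses above; the model name is hypothetical):
#
#     out = EmbedderOutput(
#         data=[Embedding(embedding=[0.6, 0.8], index=0)],
#         model="text-embedding-3-small",
#     )
#     out.length         # 1
#     out.embedding_dim  # 2
#     out.is_normalized  # True, since 0.6**2 + 0.8**2 == 1.0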
#######################################################################################
# Data modeling for Generator component
######################################################################################
@dataclass
class TokenLogProb:
r"""similar to openai.ChatCompletionTokenLogprob"""
token: str
logprob: float
@dataclass
class CompletionUsage:
__doc__ = r"In sync with OpenAI completion usage api spec at openai/types/completion_usage.py"
completion_tokens: Optional[int] = field(
metadata={"desc": "Number of tokens in the generated completion"}, default=None
)
prompt_tokens: Optional[int] = field(
metadata={"desc": "Number of tokens in the prompt"}, default=None
)
total_tokens: Optional[int] = field(
metadata={
"desc": "Total number of tokens used in the request (prompt + completion)"
},
default=None,
)
@dataclass
class GeneratorOutput(DataClass, Generic[T_co]):
__doc__ = r"""
The output data class for the Generator component.
We cannot fully control the model output, so we use this class to track the error message and
let the raw string output pass through.
(1) When both the model prediction and the output processors succeed,
data holds the final output and error is None.
(2) When either the model prediction or the output processors fail,
data is None and error holds the error message.
raw_response depends on the model prediction.
"""
id: Optional[str] = field(
default=None, metadata={"desc": "The unique id of the output"}
)
data: T_co = field(
default=None,
metadata={"desc": "The final output data potentially after output parsers"},
)
error: Optional[str] = field(
default=None,
metadata={"desc": "Error message if any"},
)
usage: Optional[CompletionUsage] = field(
default=None, metadata={"desc": "Usage tracking"}
)
raw_response: Optional[str] = field(
default=None, metadata={"desc": "Raw string response from the model"}
) # parsed from model client response
metadata: Optional[Dict[str, object]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
GeneratorOutputType = GeneratorOutput[object]
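# Example (sketch): the two cases described in the GeneratorOutput docstring above.
#
#     ok = GeneratorOutput(data={"answer": 3}, raw_response='{"answer": 3}')
#     failed = GeneratorOutput(data=None, error="JSON parsing failed", raw_response="not json")
#     if failed.error is not None:
#         ...  # fall back to failed.raw_response, or retry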
#######################################################################################
# Data modeling for Retriever component
######################################################################################
RetrieverQueryType = TypeVar("RetrieverQueryType", contravariant=True)
RetrieverStrQueryType = str
RetrieverQueriesType = Union[RetrieverQueryType, Sequence[RetrieverQueryType]]
RetrieverStrQueriesType = Union[str, Sequence[RetrieverStrQueryType]]
RetrieverDocumentType = TypeVar("RetrieverDocumentType", contravariant=True)
RetrieverStrDocumentType = str # for text retrieval
RetrieverDocumentsType = Sequence[RetrieverDocumentType]
@dataclass
class RetrieverOutput(DataClass):
__doc__ = r"""Save the output of a single query in retrievers.
It is up to the subclass of Retriever to specify the type of query and document.
"""
doc_indices: List[int] = field(metadata={"desc": "List of document indices"})
doc_scores: Optional[List[float]] = field(
default=None, metadata={"desc": "List of document scores"}
)
query: Optional[RetrieverQueryType] = field(
default=None, metadata={"desc": "The query used to retrieve the documents"}
)
documents: Optional[List[RetrieverDocumentType]] = field(
default=None, metadata={"desc": "List of retrieved documents"}
)
RetrieverOutputType = List[RetrieverOutput]  # to support multiple queries at once
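# Example (sketch): one RetrieverOutput per query; the indices point into the
# retriever's document collection. The query text and scores are illustrative.
#
#     output: RetrieverOutputType = [
#         RetrieverOutput(doc_indices=[2, 0], doc_scores=[0.92, 0.41], query="what is adalflow?"),
#     ]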
#######################################################################################
# Data modeling for function calls
######################################################################################
AsyncCallable = Callable[..., Awaitable[Any]]
@dataclass
class FunctionDefinition(DataClass):
__doc__ = r"""The data modeling of a function definition, including the name, description, and parameters."""
func_name: str = field(metadata={"desc": "The name of the tool"})
func_desc: Optional[str] = field(
default=None, metadata={"desc": "The description of the tool"}
)
func_parameters: Dict[str, object] = field(
default_factory=dict, metadata={"desc": "The schema of the parameters"}
)
def fn_schema_str(self, type: Literal["json", "yaml"] = "json") -> str:
r"""Get the function definition str to be used in the prompt.
You can also directly use :meth:`to_json` and :meth:`to_yaml` to get the schema in JSON or YAML format.
"""
if type == "json":
return self.to_json()
elif type == "yaml":
return self.to_yaml()
else:
raise ValueError(f"Unsupported type: {type}")
@dataclass
class Function(DataClass):
__doc__ = r"""The data modeling of a function call, including the name and keyword arguments.
You can use the exclude in :meth:`to_json` and :meth:`to_yaml` to exclude the `thought` field if you do not want to use chain-of-thought pattern.
Example:
.. code-block:: python
# assume the function is added in a context_map
# context_map = {"add": add}
def add(a, b):
return a + b
# call function add with arguments 1 and 2
fun = Function(name="add", kwargs={"a": 1, "b": 2})
# evaluate the function
result = context_map[fun.name](**fun.kwargs)
# or call with positional arguments
fun = Function(name="add", args=[1, 2])
result = context_map[fun.name](*fun.args)
"""
thought: Optional[str] = field(
default=None, metadata={"desc": "Why the function is called"}
)
name: str = field(default="", metadata={"desc": "The name of the function"})
args: Optional[List[object]] = field(
default_factory=list,
metadata={"desc": "The positional arguments of the function"},
)
kwargs: Optional[Dict[str, object]] = field(
default_factory=dict,
metadata={"desc": "The keyword arguments of the function"},
)
_action_desc = """FuncName(<kwargs>) \
Valid function call expression. \
Example: "FuncName(a=1, b=2)" \
Follow the data type specified in the function parameters.\
e.g. for Type object with x,y properties, use "ObjectType(x=1, y=2)"""
@dataclass
class FunctionExpression(DataClass):
__doc__ = r"""The data modeling of a function expression for a call, including the name and arguments.
Example:
.. code-block:: python
def add(a, b):
return a + b
# call function add with positional arguments 1 and 2
fun_expr = FunctionExpression(action="add(1, 2)")
# evaluate the expression
result = eval(fun_expr.action)
print(result)
# Output: 3
# call function add with keyword arguments
fun_expr = FunctionExpression(action="add(a=1, b=2)")
result = eval(fun_expr.action)
print(result)
# Output: 3
Why ask the LLM to generate a function expression (code snippet) for a function call?
- It is a more efficient/compact way to call a function.
- It is more flexible:
(1) it supports the full range of Python expressions, including arithmetic operations, nested function calls, and more;
(2) it allows passing variables as arguments.
- It is easy to parse using the ``ast`` module.
The benefit is fewer failed function calls.
"""
thought: Optional[str] = field(
default=None, metadata={"desc": "Why the function is called"}
)
action: str = field(
default_factory=required_field,
metadata={"desc": _action_desc},
)
@classmethod
def from_function(
cls,
func: Union[Callable[..., Any], AsyncCallable],
thought: Optional[str] = None,
*args,
**kwargs,
) -> "FunctionExpression":
r"""Create a FunctionExpression object from a function.
Args:
func (Union[Callable[..., Any], AsyncCallable]): The function to be converted
Returns:
FunctionExpression: The FunctionExpression object
Usage:
1. Create a FunctionExpression object from a function call.
2. Use :meth:`to_json` and :meth:`to_yaml` to get the schema in JSON or YAML format.
3. Use the result as an example in the prompt, showing the LLM how to call the function.
Example:
.. code-block:: python
from adalflow.core.types import FunctionExpression
def add(a, b):
return a + b
# create an expression for the function call and using keyword arguments
fun_expr = FunctionExpression.from_function(
add, thought="Add two numbers", a=1, b=2
)
print(fun_expr)
# output
# FunctionExpression(thought='Add two numbers', action='add(a=1, b=2)')
"""
try:
action = generate_function_call_expression_from_callable(
func, *args, **kwargs
)
except Exception as e:
logger.error(f"Error generating function expression: {e}")
raise ValueError(f"Error generating function expression: {e}")
return cls(action=action, thought=thought)
@dataclass
class FunctionOutput(DataClass):
__doc__ = (
r"""The output of a tool, which could be a function, a class, or a module."""
)
name: Optional[str] = field(
default=None, metadata={"desc": "The name of the function"}
)
input: Optional[Union[Function, FunctionExpression]] = field(
default=None, metadata={"desc": "The Function or FunctionExpression object"}
)
parsed_input: Optional[Function] = field(
default=None,
metadata={
"desc": "The parsed Function object if the input is FunctionExpression"
},
)
output: Optional[object] = field(
default=None, metadata={"desc": "The output of the function execution"}
)
error: Optional[str] = field(
default=None, metadata={"desc": "The error message if any"}
)
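# Example (sketch): wrapping the execution of a parsed Function in a FunctionOutput.
# The function name and arguments are illustrative.
#
#     fn = Function(name="add", kwargs={"a": 1, "b": 2})
#     result = FunctionOutput(name=fn.name, input=fn, output=3)
#     result.error  # None on success; holds the error message otherwise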
#######################################################################################
# Data modeling for agent component
######################################################################################
@dataclass
class StepOutput(DataClass, Generic[T]):
__doc__ = r"""The output of a single step in the agent. Suits for serial planning agent such as React"""
step: int = field(
default=0, metadata={"desc": "The order of the step in the agent"}
)
# The action can be a Function, a FunctionExpression, or just a str.
# It already includes the thought and the action,
# taken directly from the planner LLM's output.
action: T = field(
default=None, metadata={"desc": "The action the agent takes at this step"}
)
function: Optional[Function] = field(
default=None, metadata={"desc": "The parsed function from the action"}
)
observation: Optional[str] = field(
default=None, metadata={"desc": "The execution result shown for this action"}
)
@classmethod
def with_action_type(cls, action_type: Type[T]) -> Type["StepOutput[T]"]:
"""
Create a new StepOutput class with the specified action type.
Use this if you want to create a schema for StepOutput with a specific action type.
Args:
action_type (Type[T]): The type to set for the action attribute.
Returns:
Type[StepOutput[T]]: A new subclass of StepOutput with the specified action type.
Example:
.. code-block:: python
from adalflow.core.types import StepOutput, FunctionExpression
StepOutputWithFunctionExpression = StepOutput.with_action_type(FunctionExpression)
"""
# Create a new type variable map
type_var_map = {T: action_type}
# Create a new subclass with the updated type
new_cls = type(cls.__name__, (cls,), {"__type_var_map__": type_var_map})
# Update the __annotations__ to reflect the new type of action
new_cls.__annotations__["action"] = action_type
return new_cls
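# Example (sketch): one ReAct-style step where the planner emitted a FunctionExpression;
# the values are illustrative.
#
#     step = StepOutput(
#         step=0,
#         action=FunctionExpression(thought="Add the numbers", action="add(1, 2)"),
#         function=Function(name="add", kwargs={"a": 1, "b": 2}),
#         observation="3",
#     )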
#######################################################################################
# Data modeling for data processing pipelines such as text splitting and embedding
######################################################################################
@dataclass
class Document(DataClass):
__doc__ = r"""A text container with optional metadata and vector representation.
It is the data structure that supports components like Retriever and DocumentSplitter, and is used with LocalDB.
"""
text: str = field(metadata={"desc": "The main text"})
meta_data: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Metadata for the document"}
)
# can save data for filtering at retrieval time too
vector: List[float] = field(
default_factory=list,
metadata={"desc": "The vector representation of the document"},
)
# the vector representation of the document
id: Optional[str] = field(
default_factory=lambda: str(uuid.uuid4()), metadata={"desc": "Unique id"}
) # unique id of the document
order: Optional[int] = field(
default=None,
metadata={"desc": "Order of the chunked document in the original document"},
)
score: Optional[float] = field(
default=None,
metadata={"desc": "Score of the document, likely used in retrieval output"},
)
parent_doc_id: Optional[Union[str, UUID]] = field(
default=None, metadata={"desc": "id of the Document where the chunk is from"}
)
estimated_num_tokens: Optional[int] = field(
default=None,
metadata={
"desc": "Estimated number of tokens in the text, useful for cost estimation"
},
)
def __post_init__(self):
if self.estimated_num_tokens is None and self.text:
tokenizer = Tokenizer()
self.estimated_num_tokens = tokenizer.count_tokens(self.text)
@classmethod
def from_dict(cls, doc: Dict):
"""Create a Document object from a dictionary.
Example:
.. code-block :: python
doc = Document.from_dict({
"id": "123",
"text": "Hello world",
"meta_data": {"title": "Greeting"}
})
"""
doc = doc.copy()
assert "meta_data" in doc, "meta_data is required"
assert "text" in doc, "text is required"
if "estimated_num_tokens" not in doc:
tokenizer = Tokenizer()
doc["estimated_num_tokens"] = tokenizer.count_tokens(doc["text"])
if "id" not in doc or not doc["id"]:
doc["id"] = uuid.uuid4()
return super().from_dict(doc)
def __repr__(self):
"""Custom repr method to truncate the text to 100 characters and vector to 10 floats."""
max_chars_to_show = 100
truncated_text = (
self.text[:max_chars_to_show] + "..."
if len(self.text) > max_chars_to_show
else self.text
)
truncated_vector = (
f"len: {len(self.vector)}" if len(self.vector) else self.vector
)
return (
f"Document(id={self.id}, text={truncated_text!r}, meta_data={self.meta_data}, "
f"vector={truncated_vector!r}, parent_doc_id={self.parent_doc_id}, order={self.order}, "
f"score={self.score})"
)
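# Example (sketch): creating a Document directly and from a dict; estimated_num_tokens
# is filled in by __post_init__ / from_dict via the Tokenizer when not provided.
#
#     doc = Document(text="Hello world", meta_data={"title": "Greeting"})
#     doc2 = Document.from_dict({"id": "123", "text": "Hello world", "meta_data": {"title": "Greeting"}})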
#######################################################################################
# Data modeling for dialog system
######################################################################################
@dataclass
class UserQuery:
query_str: str
metadata: Optional[Dict[str, Any]] = None
@dataclass
class AssistantResponse:
response_str: str
metadata: Optional[Dict[str, Any]] = None
# There could be other roles in a multi-party conversation; we might consider them in the future.
@dataclass
class DialogTurn(DataClass):
__doc__ = r"""A turn consists of a user query and the assistant response.
The data format is designed to fit into a relational database, where each turn is a row.
Use `conversation_id` to group the turns into a dialog session, and the `order` field plus
`user_query_timestamp` and `assistant_response_timestamp` to order the turns.
Args:
id (str): The unique id of the turn.
user_id (str, optional): The unique id of the user.
conversation_id (str, optional): The unique id of the conversation the turn belongs to.
order (int, optional): The order of the turn in the dialog session, starts from 0.
user_query (UserQuery, optional): The user query in the turn.
assistant_response (AssistantResponse, optional): The assistant response in the turn.
user_query_timestamp (datetime, optional): The timestamp of the user query.
assistant_response_timestamp (datetime, optional): The timestamp of the assistant response.
metadata (Dict[str, Any], optional): Additional metadata.
Examples:
- User: Hi, how are you?
- Assistant: I'm fine, thank you!
DialogTurn(id=uuid4(), user_query=UserQuery("Hi, how are you?"), assistant_response=AssistantResponse("I'm fine, thank you!"))
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4()),
metadata={"desc": "The unique id of the turn"},
)
user_id: Optional[str] = field(
default=None, metadata={"desc": "The unique id of the user"}
)
conversation_id: Optional[str] = field(
default=None,
metadata={"desc": "The unique id of the conversation it belongs to"},
)
order: Optional[int] = field(
default=None,
metadata={"desc": "The order of the turn in the Dialog Session, starts from 0"},
)
user_query: Optional[UserQuery] = field(
default=None, metadata={"desc": "The user query in the turn"}
)
assistant_response: Optional[AssistantResponse] = field(
default=None, metadata={"desc": "The assistant response in the turn"}
)
user_query_timestamp: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the user query"},
)
assistant_response_timestamp: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the assistant response"},
)
metadata: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
vector: Optional[List[float]] = field(
default=None,
metadata={"desc": "The vector representation of the dialog turn"},
)
def set_user_query(
self, user_query: UserQuery, user_query_timestamp: Optional[datetime] = None
):
self.user_query = user_query
if not user_query_timestamp:
user_query_timestamp = datetime.now()
self.user_query_timestamp = user_query_timestamp
def set_assistant_response(
self,
assistant_response: AssistantResponse,
assistant_response_timestamp: Optional[datetime] = None,
):
self.assistant_response = assistant_response
if not assistant_response_timestamp:
assistant_response_timestamp = datetime.now()
self.assistant_response_timestamp = assistant_response_timestamp
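# Example (sketch): building a turn with the setters so the timestamps default to now;
# the ids are illustrative.
#
#     turn = DialogTurn(user_id="u1", conversation_id="c1", order=0)
#     turn.set_user_query(UserQuery("Hi, how are you?"))
#     turn.set_assistant_response(AssistantResponse("I'm fine, thank you!"))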
# TODO: This part and the Memory class are still WIP and will need more work in the future.
@dataclass
class Conversation:
__doc__ = r"""A conversation manages the dialog turns in a whole conversation as a session.
This class is mainly used in-memory for the dialog system/app to manage active conversations.
You won't need this class for past conversations that have already been persisted in a database as
records or history.
"""
id: str = field(
default_factory=lambda: str(uuid.uuid4()),
metadata={"desc": "The id of the conversation"},
) # the id of the conversation
name: Optional[str] = field(
default=None, metadata={"desc": "The name of the conversation"}
)
user_id: Optional[str] = field(
default=None, metadata={"desc": "The id of the user"}
)
dialog_turns: OrderedDict[int, DialogTurn] = field(
default_factory=OrderedDict, metadata={"desc": "The dialog turns"}
)
# int is the order of the turn, starts from 0
metadata: Optional[Dict[str, Any]] = field(
default=None, metadata={"desc": "Additional metadata"}
)
created_at: Optional[datetime] = field(
default_factory=datetime.now,
metadata={"desc": "The timestamp of the conversation creation"},
)
# InitVar type annotation is used for parameters that are used in __post_init__
# but not meant to be fields in the dataclass.
dialog_turns_input: InitVar[
Optional[Union[OrderedDict[int, DialogTurn], List[DialogTurn]]]
] = None
def __post_init__(
self,
dialog_turns_input: Optional[
Union[OrderedDict[int, DialogTurn], List[DialogTurn]]
] = None,
):
if dialog_turns_input:
if isinstance(dialog_turns_input, list):
# Assume the list is of DialogTurn objects and needs to be added to an OrderedDict
for order, dialog_turn in enumerate(dialog_turns_input):
self.append_dialog_turn(dialog_turn)
elif isinstance(dialog_turns_input, OrderedDict):
self.dialog_turns = dialog_turns_input
else:
raise ValueError(
"dialog_turns should be a list of DialogTurn or an OrderedDict"
)
def get_next_order(self):
return len(self.dialog_turns)
def append_dialog_turn(self, dialog_turn: DialogTurn):
next_order = self.get_next_order()
if dialog_turn.order is None:
dialog_turn.order = next_order
else:
assert dialog_turn.order == next_order, f"order should be {next_order}"
self.dialog_turns[next_order] = dialog_turn
def get_dialog_turns(self) -> OrderedDict[int, DialogTurn]:
return self.dialog_turns
def get_chat_history_str(self) -> str:
chat_history_str = ""
for order, dialog_turn in self.dialog_turns.items():
chat_history_str += f"User: {dialog_turn.user_query.query_str}\n"
chat_history_str += (
f"Assistant: {dialog_turn.assistant_response.response_str}\n"
)
return chat_history_str
def delete_dialog_turn(self, order: int):
self.dialog_turns.pop(order)
def update_dialog_turn(self, order: int, dialog_turn: DialogTurn):
self.dialog_turns[order] = dialog_turn
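# Example (sketch): an in-memory conversation; append_dialog_turn assigns the order
# automatically. The name and user_id are illustrative.
#
#     conv = Conversation(name="demo", user_id="u1")
#     conv.append_dialog_turn(
#         DialogTurn(
#             user_query=UserQuery("Hi, how are you?"),
#             assistant_response=AssistantResponse("I'm fine, thank you!"),
#         )
#     )
#     print(conv.get_chat_history_str())
#     # User: Hi, how are you?
#     # Assistant: I'm fine, thank you!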