"""Functional interface.
Core functions we use to build across the components.
Users can leverage these functions to customize their own components."""
from typing import (
Dict,
Any,
Callable,
Union,
List,
Tuple,
Optional,
Type,
get_type_hints,
get_origin,
get_args,
Set,
Sequence,
TypeVar,
)
import logging
import numpy as np
from enum import Enum
import re
import json
import yaml
import ast
import threading
from inspect import signature, Parameter
from dataclasses import fields, is_dataclass, MISSING, Field
log = logging.getLogger(__name__)
ExcludeType = Optional[Dict[str, List[str]]]
T_co = TypeVar("T_co", covariant=True)
########################################################################################
# For Dataclass base class and all schema related functions
########################################################################################
def custom_asdict(
obj, *, dict_factory=dict, exclude: ExcludeType = None
) -> Dict[str, Any]:
"""Equivalent to asdict() from dataclasses module but with exclude fields.
Return the fields of a dataclass instance as a new dictionary mapping
field names to field values, while allowing certain fields to be excluded.
If given, 'dict_factory' will be used instead of built-in dict.
The function applies recursively to field values that are
dataclass instances. This will also look into built-in containers:
tuples, lists, and dicts.
"""
if not is_dataclass_instance(obj):
raise TypeError("custom_asdict() should be called on dataclass instances")
return _asdict_inner(obj, dict_factory, exclude or {})
def _asdict_inner(obj, dict_factory, exclude):
if is_dataclass_instance(obj):
result = []
for f in fields(obj):
if f.name in exclude.get(obj.__class__.__name__, []):
continue
value = _asdict_inner(getattr(obj, f.name), dict_factory, exclude)
result.append((f.name, value))
return dict_factory(result)
elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
return type(obj)(*[_asdict_inner(v, dict_factory, exclude) for v in obj])
elif isinstance(obj, (list, tuple)):
return type(obj)(_asdict_inner(v, dict_factory, exclude) for v in obj)
elif isinstance(obj, dict):
return type(obj)(
(
_asdict_inner(k, dict_factory, exclude),
_asdict_inner(v, dict_factory, exclude),
)
for k, v in obj.items()
)
else:
return obj
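# Usage sketch for custom_asdict (the `Point` dataclass below is hypothetical,
# shown only for illustration):
# >>> from dataclasses import dataclass
# >>> @dataclass
# ... class Point:
# ...     x: int
# ...     y: int = 0
# >>> custom_asdict(Point(1, 2), exclude={"Point": ["y"]})
# {'x': 1}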
def validate_data(data: Dict[str, Any], fieldtypes: Dict[str, Any]) -> bool:
    r"""Check that every required field is present in ``data``."""
    required_fields = {
        name for name, f in fieldtypes.items() if _is_required_field(f)
    }
    return required_fields <= data.keys()
def is_potential_dataclass(t):
"""Check if the type is directly a dataclass or potentially a wrapped dataclass like Optional."""
origin = get_origin(t)
if origin is Union:
# This checks if any of the arguments in a Union (which is what Optional is) is a dataclass
return any(is_dataclass(arg) for arg in get_args(t) if arg is not type(None))
return is_dataclass(t)
def check_data_class_field_args_zero(cls):
"""Check if the field is a dataclass."""
return (
hasattr(cls, "__args__")
and len(cls.__args__) > 0
and cls.__args__[0]
and hasattr(cls.__args__[0], "__dataclass_fields__")
)
def check_if_class_field_args_zero_exists(cls):
"""Check if the field is a dataclass."""
return hasattr(cls, "__args__") and len(cls.__args__) > 0 and cls.__args__[0]
def check_data_class_field_args_one(cls):
"""Check if the field is a dataclass."""
return (
hasattr(cls, "__args__")
and len(cls.__args__) > 1
and cls.__args__[1]
and hasattr(cls.__args__[1], "__dataclass_fields__")
)
def check_if_class_field_args_one_exists(cls):
"""Check if the field is a dataclass."""
return hasattr(cls, "__args__") and len(cls.__args__) > 1 and cls.__args__[1]
def dataclass_obj_from_dict(cls: Type[object], data: Dict[str, object]) -> Any:
r"""Convert a dictionary to a dataclass object.
Supports nested dataclasses, lists, and dictionaries.
.. note::
If any required field is missing, it will raise an error.
Do not use the dict that has excluded required fields.
Example:
.. code-block:: python
from dataclasses import dataclass
from typing import List
@dataclass
class TrecData:
question: str
label: int
@dataclass
class TrecDataList:
data: List[TrecData]
name: str
trec_data_dict = {"data": [{"question": "What is the capital of France?", "label": 0}], "name": "trec_data_list"}
dataclass_obj_from_dict(TrecDataList, trec_data_dict)
# Output:
# TrecDataList(data=[TrecData(question='What is the capital of France?', label=0)], name='trec_data_list')
"""
log.debug(f"Dataclass: {cls}, Data: {data}")
if data is None:
return None
if is_dataclass(cls) or is_potential_dataclass(
cls
    ):  # e.g. for Optional[Address], is_dataclass is False but is_potential_dataclass is True
log.debug(
f"{is_dataclass(cls)} of {cls}, {is_potential_dataclass(cls)} of {cls}"
)
# Ensure the data is a dictionary
if not isinstance(data, dict):
raise ValueError(
f"Expected data of type dict for {cls}, but got {type(data).__name__}"
)
cls_type = extract_dataclass_type(cls)
fieldtypes = {f.name: f.type for f in cls_type.__dataclass_fields__.values()}
restored_data = cls_type(
**{
key: dataclass_obj_from_dict(fieldtypes[key], value)
for key, value in data.items()
}
)
return restored_data
elif isinstance(data, (list, tuple)):
log.debug(f"List or Tuple: {cls}, {data}")
restored_data = []
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
else:
restored_data.append(item)
return restored_data
elif isinstance(data, set):
log.debug(f"Set: {cls}, {data}")
restored_data = set()
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
else:
# Use the original data [Any]
restored_data.add(item)
return restored_data
elif isinstance(data, dict):
log.debug(f"Dict: {cls}, {data}")
for key, value in data.items():
if check_data_class_field_args_one(cls):
# restore the value to its dataclass type
data[key] = dataclass_obj_from_dict(cls.__args__[1], value)
elif check_if_class_field_args_one_exists(cls):
# Use the original data [Any]
data[key] = dataclass_obj_from_dict(cls.__args__[1], value)
else:
# Use the original data [Any]
data[key] = value
return data
# else normal data like int, str, float, etc.
else:
log.debug(f"Not datclass, or list, or dict: {cls}, use the original data.")
return data
# Custom representer for OrderedDict
def represent_ordereddict(dumper, data):
value = []
for item_key, item_value in data.items():
node_key = dumper.represent_data(item_key)
node_value = dumper.represent_data(item_value)
value.append((node_key, node_value))
return yaml.MappingNode("tag:yaml.org,2002:map", value)
def from_dict_to_json(data: Dict[str, Any], sort_keys: bool = False) -> str:
r"""Convert a dictionary to a JSON string."""
try:
return json.dumps(data, indent=4, sort_keys=sort_keys)
    except (TypeError, ValueError) as e:  # json.dumps raises TypeError/ValueError, not JSONDecodeError
raise ValueError(f"Failed to convert dict to JSON: {e}")
def from_dict_to_yaml(data: Dict[str, Any], sort_keys: bool = False) -> str:
r"""Convert a dictionary to a YAML string."""
try:
return yaml.dump(data, default_flow_style=False, sort_keys=sort_keys)
except yaml.YAMLError as e:
raise ValueError(f"Failed to convert dict to YAML: {e}")
def from_json_to_dict(json_str: str) -> Dict[str, Any]:
r"""Convert a JSON string to a dictionary."""
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to convert JSON to dict: {e}")
def from_yaml_to_dict(yaml_str: str) -> Dict[str, Any]:
r"""Convert a YAML string to a dictionary."""
try:
return yaml.safe_load(yaml_str)
except yaml.YAMLError as e:
raise ValueError(f"Failed to convert YAML to dict: {e}")
def is_dataclass_instance(obj):
return hasattr(obj, "__dataclass_fields__")
def get_type_schema(
type_obj,
exclude: ExcludeType = None,
type_var_map: Optional[Dict] = None,
) -> str:
"""Retrieve the type name, handling complex and nested types."""
origin = get_origin(type_obj)
type_var_map = type_var_map or {}
# Replace type variables with their actual types to support Generic[T/To]
if hasattr(type_obj, "__origin__") and type_obj.__origin__ is not None:
type_obj = type_var_map.get(type_obj.__origin__, type_obj)
else:
type_obj = type_var_map.get(type_obj, type_obj)
if origin is Union:
# Handle Optional[Type] and other unions
args = get_args(type_obj)
types = [
get_type_schema(arg, exclude, type_var_map)
for arg in args
if arg is not type(None)
]
return (
f"Optional[{types[0]}]" if len(types) == 1 else f"Union[{', '.join(types)}]"
)
elif origin in {List, list}:
args = get_args(type_obj)
if args:
inner_type = get_type_schema(args[0], exclude, type_var_map)
return f"List[{inner_type}]"
else:
return "List"
elif origin in {Dict, dict}:
args = get_args(type_obj)
if args and len(args) >= 2:
key_type = get_type_schema(args[0], exclude, type_var_map)
value_type = get_type_schema(args[1], exclude, type_var_map)
return f"Dict[{key_type}, {value_type}]"
else:
return "Dict"
elif origin in {Set, set}:
args = get_args(type_obj)
return (
f"Set[{get_type_schema(args[0],exclude, type_var_map)}]" if args else "Set"
)
elif origin is Sequence:
args = get_args(type_obj)
return (
f"Sequence[{get_type_schema(args[0], exclude,type_var_map)}]"
if args
else "Sequence"
)
elif origin in {Tuple, tuple}:
args = get_args(type_obj)
if args:
return f"Tuple[{', '.join(get_type_schema(arg,exclude,type_var_map) for arg in args)}]"
return "Tuple"
elif is_dataclass(type_obj):
if issubclass(type_obj, Enum):
# Handle Enum dataclass types
enum_members = ", ".join([f"{e.name}={e.value}" for e in type_obj])
return f"Enum[{type_obj.__name__}({enum_members})]"
# Recursively handle nested dataclasses
output = str(get_dataclass_schema(type_obj, exclude, type_var_map))
return output
elif isinstance(type_obj, type) and issubclass(type_obj, Enum):
# Handle Enum types
enum_members = ", ".join([f"{e.name}={e.value}" for e in type_obj])
return f"Enum[{type_obj.__name__}({enum_members})]"
return type_obj.__name__ if hasattr(type_obj, "__name__") else str(type_obj)
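# Example of get_type_schema on nested typing annotations (a minimal sketch):
# >>> get_type_schema(Optional[List[int]])
# 'Optional[List[int]]'
# >>> get_type_schema(Dict[str, float])
# 'Dict[str, float]'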
def get_enum_schema(enum_cls: Type[Enum]) -> Dict[str, object]:
return {
"type": "string",
"enum": [e.value for e in enum_cls],
"description": enum_cls.__doc__ if enum_cls.__doc__ else "",
}
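# Example of get_enum_schema (the `Color` enum is hypothetical):
# >>> class Color(Enum):
# ...     """Supported colors."""
# ...     RED = "red"
# ...     GREEN = "green"
# >>> get_enum_schema(Color)
# {'type': 'string', 'enum': ['red', 'green'], 'description': 'Supported colors.'}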
def get_dataclass_schema(
cls,
exclude: ExcludeType = None,
type_var_map: Optional[Dict] = None,
) -> Dict[str, Dict[str, object]]:
"""Generate a schema dictionary for a dataclass including nested structures.
1. Support customized dataclass with required_field function.
2. Support nested dataclasses, even with generics like List, Dict, etc.
3. Support metadata in the dataclass fields.
"""
if not is_dataclass(cls):
raise ValueError(
"Provided class is not a dataclass, please decorate your class with @dataclass"
)
type_var_map = type_var_map or {}
# TODO: Add support for having a description in the dataclass
schema = {
"type": cls.__name__,
"properties": {},
"required": [],
# "description": cls.__doc__ if cls.__doc__ else "",
}
# get the exclude list for the current class
current_exclude = exclude.get(cls.__name__, []) if exclude else []
# handle Combination of Enum and dataclass
if issubclass(cls, Enum):
schema["type"] = get_type_schema(cls, exclude, type_var_map)
return schema
for f in fields(cls):
if f.name in current_exclude:
continue
        # prepare the field schema; nested dataclasses are handled recursively
field_type = type_var_map.get(f.type, f.type)
field_schema = {"type": get_type_schema(field_type, exclude, type_var_map)}
# check required field
is_required = _is_required_field(f)
if is_required:
schema["required"].append(f.name)
# add metadata to the field schema
if f.metadata:
field_schema.update(f.metadata)
# handle nested dataclasses and complex types
schema["properties"][f.name] = field_schema
return schema
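# Example of get_dataclass_schema, reusing the hypothetical `Point` dataclass
# from the custom_asdict sketch above:
# >>> get_dataclass_schema(Point)
# {'type': 'Point', 'properties': {'x': {'type': 'int'}, 'y': {'type': 'int'}}, 'required': ['x']}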
def _is_required_field(f: Field) -> bool:
r"""Determine if the field of dataclass is required or optional.
Customized for required_field function."""
# Determine if the field is required or optional
# Using __name__ to check for function identity
if f.default is MISSING and (
f.default_factory is MISSING
or (
hasattr(f.default_factory, "__name__")
and f.default_factory.__name__ == "required_field"
)
):
return True
return False
def convert_schema_to_signature(schema: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
r"""Convert the value from get_data_class_schema to a string description."""
signature = {}
schema_to_use = schema.get("properties", {})
required_fields = schema.get("required", [])
for field_name, field_info in schema_to_use.items():
field_signature = field_info.get("desc", "")
# add type to the signature
if field_info["type"]:
field_signature += f" ({field_info['type']})"
# add required/optional to the signature
if field_name in required_fields:
field_signature += " (required)"
else:
field_signature += " (optional)"
signature[field_name] = field_signature
return signature
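# Example of convert_schema_to_signature chained with get_dataclass_schema,
# again using the hypothetical `Point` dataclass (note the leading space when
# a field has no "desc" metadata):
# >>> convert_schema_to_signature(get_dataclass_schema(Point))
# {'x': ' (int) (required)', 'y': ' (int) (optional)'}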
########################################################################################
# For FunctionTool component
# It uses get_type_schema and get_dataclass_schema to generate the schema of arguments.
########################################################################################
def get_fun_schema(name: str, func: Callable[..., object]) -> Dict[str, object]:
r"""Get the schema of a function.
    Supports dataclasses, Union, and built-in types such as int, str, float, list, dict, and set.
Examples:
def example_function(x: int, y: str = "default") -> int:
return x
schema = get_fun_schema("example_function", example_function)
print(json.dumps(schema, indent=4))
# Output:
{
"type": "object",
"properties": {
"x": {
"type": "int"
},
"y": {
"type": "str",
"default": "default"
}
},
"required": [
"x"
]
}
"""
sig = signature(func)
schema = {"type": "object", "properties": {}, "required": []}
type_hints = get_type_hints(func)
for param_name, parameter in sig.parameters.items():
param_type = type_hints.get(param_name, "Any")
if parameter.default == Parameter.empty:
schema["required"].append(param_name)
schema["properties"][param_name] = {"type": get_type_schema(param_type)}
else:
schema["properties"][param_name] = {
"type": get_type_schema(param_type),
"default": parameter.default,
}
return schema
# For parse function call for FunctionTool component
def evaluate_ast_node(node: ast.AST, context_map: Optional[Dict[str, Any]] = None):
"""
Recursively evaluates an AST node and returns the corresponding Python object.
Args:
node (ast.AST): The AST node to evaluate. This node can represent various parts of Python expressions,
such as literals, identifiers, lists, dictionaries, and function calls.
context_map (Dict[str, Any]): A dictionary that maps variable names to their respective values and functions.
This context is used to resolve names and execute functions.
Returns:
Any: The result of evaluating the node. The type of the returned object depends on the nature of the node:
- Constants return their literal value.
- Names are looked up in the context_map.
- Lists and tuples return their contained values as a list or tuple.
- Dictionaries return a dictionary with keys and values evaluated.
- Function calls invoke the function with evaluated arguments and return its result.
Raises:
ValueError: If the node type is unsupported, a ValueError is raised indicating the inability to evaluate the node.
"""
if isinstance(node, ast.Constant):
return node.value
    elif isinstance(node, ast.Dict):
        # propagate context_map so nested names and calls can still be resolved
        return {
            evaluate_ast_node(k, context_map): evaluate_ast_node(v, context_map)
            for k, v in zip(node.keys, node.values)
        }
    elif isinstance(node, ast.List):
        return [evaluate_ast_node(elem, context_map) for elem in node.elts]
    elif isinstance(node, ast.Tuple):
        return tuple(evaluate_ast_node(elem, context_map) for elem in node.elts)
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -evaluate_ast_node(node.operand, context_map) # unary minus
elif isinstance(
node, ast.BinOp
): # support "multiply(2024-2017, 12)", the "2024-2017" is a "BinOp" node
left = evaluate_ast_node(node.left, context_map)
right = evaluate_ast_node(node.right, context_map)
if isinstance(node.op, ast.Add):
return left + right
elif isinstance(node.op, ast.Sub):
return left - right
elif isinstance(node.op, ast.Mult):
return left * right
elif isinstance(node.op, ast.Div):
return left / right
elif isinstance(node.op, ast.Mod):
return left % right
elif isinstance(node.op, ast.Pow):
return left**right
else:
log.error(f"Unsupported binary operator: {type(node.op)}")
raise ValueError(f"Unsupported binary operator: {type(node.op)}")
elif isinstance(node, ast.Name): # variable name
try:
output_fun = context_map[node.id]
return output_fun
# TODO: raise the error back to the caller so that the llm can get the error message
except KeyError as e:
log.error(f"Error: {e}, {node.id} does not exist in the context_map.")
raise ValueError(
f"Error: {e}, {node.id} does not exist in the context_map."
)
elif isinstance(node, ast.Attribute): # e.g. math.pi
value = evaluate_ast_node(node.value, context_map)
return getattr(value, node.attr)
elif isinstance(
node, ast.Call
): # another fun or class as argument and value, e.g. add( multiply(4,5), 3)
func = evaluate_ast_node(node.func, context_map)
args = [evaluate_ast_node(arg, context_map) for arg in node.args]
kwargs = {
kw.arg: evaluate_ast_node(kw.value, context_map) for kw in node.keywords
}
output = func(*args, **kwargs)
if hasattr(output, "raw_output"):
return output.raw_output
return output
else:
# directly evaluate the node
# print(f"Unsupported AST node type: {type(node)}")
# return eval(compile(ast.Expression(node), filename="<ast>", mode="eval"))
log.error(f"Unsupported AST node type: {type(node)}")
raise ValueError(f"Unsupported AST node type: {type(node)}")
def parse_function_call_expr(
    function_expr: str, context_map: Optional[Dict[str, Any]] = None
) -> Tuple[str, List[Any], Dict[str, Any]]:
"""
Parse a string representing a function call into its components and ensure safe execution by only allowing function calls from a predefined context map.
Args:
        function_expr (str): The string representing the function call, e.g. "add(1, 2)".
context_map (Dict[str, Any]): A dictionary that maps variable names to their respective values and functions.
This context is used to resolve names and execute functions.
"""
function_expr = function_expr.strip()
# Parse the string into an AST
try:
function_expr = extract_function_expression(function_expr)
tree = ast.parse(function_expr, mode="eval")
if isinstance(tree.body, ast.Call):
# Extract the function name
func_name = (
tree.body.func.id if isinstance(tree.body.func, ast.Name) else None
)
# Prepare the list of arguments and keyword arguments
args = [evaluate_ast_node(arg, context_map) for arg in tree.body.args]
keywords = {
kw.arg: evaluate_ast_node(kw.value, context_map)
for kw in tree.body.keywords
}
return func_name, args, keywords
else:
log.error("Provided string is not a function call.")
raise ValueError("Provided string is not a function call.")
except Exception as e:
log.error(f"Error at parse_function_call_expr: {e}")
raise e
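# Usage sketch for parse_function_call_expr (assumes extract_function_expression,
# defined elsewhere in this module, returns the bare call; the `add` function in
# the context map is hypothetical):
# >>> parse_function_call_expr("add(1, 2)", context_map={"add": lambda a, b: a + b})
# ('add', [1, 2], {})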
def generate_function_call_expression_from_callable(
func: Callable[..., Any], *args: Any, **kwargs: Any
) -> str:
"""
Generate a function call expression string from a callable function and its arguments.
Args:
func (Callable[..., Any]): The callable function.
*args (Any): Positional arguments to be passed to the function.
**kwargs (Any): Keyword arguments to be passed to the function.
Returns:
str: The function call expression string.
"""
func_name = func.__name__
args_str = ", ".join(repr(arg) for arg in args)
kwargs_str = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items())
if args_str and kwargs_str:
full_args_str = f"{args_str}, {kwargs_str}"
else:
full_args_str = args_str or kwargs_str
return f"{func_name}({full_args_str})"
# Define a list of safe built-ins
SAFE_BUILTINS = {
"abs": abs,
"all": all,
"any": any,
"bin": bin,
"bool": bool,
"bytearray": bytearray,
"bytes": bytes,
"callable": callable,
"chr": chr,
"complex": complex,
"dict": dict,
"divmod": divmod,
"enumerate": enumerate,
"filter": filter,
"float": float,
"format": format,
"frozenset": frozenset,
"getattr": getattr,
"hasattr": hasattr,
"hash": hash,
"hex": hex,
"int": int,
"isinstance": isinstance,
"issubclass": issubclass,
"iter": iter,
"len": len,
"list": list,
"map": map,
"max": max,
"min": min,
"next": next,
"object": object,
"oct": oct,
"ord": ord,
"pow": pow,
"range": range,
"repr": repr,
"reversed": reversed,
"round": round,
"set": set,
"slice": slice,
"sorted": sorted,
"str": str,
"sum": sum,
"tuple": tuple,
"type": type,
"zip": zip,
}
def sandbox_exec(
code: str, context: Optional[Dict[str, object]] = None, timeout: int = 5
) -> Dict:
r"""Execute code in a sandboxed environment with a timeout.
    1. Works similarly to eval(), but with a timeout and a context, similar to parse_function_call_expr.
    2. More flexible: the code can define additional functions instead of being limited to a single function call.
Args:
code (str): The code to execute. Has to be output=... or similar so that the result can be captured.
context (Dict[str, Any]): The context to use for the execution.
timeout (int): The execution timeout in seconds.
"""
result = {"output": None, "error": None}
context = {**context, **SAFE_BUILTINS} if context else SAFE_BUILTINS
try:
compiled_code = compile(code, "<string>", "exec")
# Result dictionary to store execution results
# Define a target function for the thread
def target():
try:
# Execute the code
exec(compiled_code, context, result)
except Exception as e:
result["error"] = e
# Create a thread to execute the code
thread = threading.Thread(target=target)
thread.start()
thread.join(timeout)
# Check if the thread is still alive (timed out)
if thread.is_alive():
result["error"] = TimeoutError("Execution timed out")
raise TimeoutError("Execution timed out")
except Exception as e:
print(f"Errpr at sandbox_exec: {e}")
raise e
return result
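# Usage sketch for sandbox_exec: the executed code must assign to a name such as
# `output` so that the value is captured in the returned dict:
# >>> sandbox_exec("output = sum(range(5))")
# {'output': 10, 'error': None}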
########################################################################################
# For ** component
########################################################################################
def compose_model_kwargs(default_model_kwargs: Dict, model_kwargs: Dict) -> Dict:
r"""Add new arguments or overwrite the default arguments with the new arguments.
Example:
model_kwargs = {"temperature": 0.5, "model": "gpt-3.5-turbo"}
self.model_kwargs = {"model": "gpt-3.5"}
combine_kwargs(model_kwargs) => {"temperature": 0.5, "model": "gpt-3.5-turbo"}
"""
pass_model_kwargs = default_model_kwargs.copy()
if model_kwargs:
pass_model_kwargs.update(model_kwargs)
return pass_model_kwargs
########################################################################################
# For embedding and vector utilities
########################################################################################
VECTOR_TYPE = Union[List[float], np.ndarray]
def is_normalized(v: VECTOR_TYPE, tol=1e-4) -> bool:
if isinstance(v, list):
v = np.array(v)
# Compute the norm of the vector (assuming v is 1D)
norm = np.linalg.norm(v)
# Check if the norm is approximately 1
return np.abs(norm - 1) < tol
def normalize_np_array(v: np.ndarray) -> np.ndarray:
# Compute the norm of the vector (assuming v is 1D)
norm = np.linalg.norm(v)
# Normalize the vector
normalized_v = v / norm
# Return the normalized vector
return normalized_v
def normalize_vector(v: VECTOR_TYPE) -> List[float]:
if isinstance(v, list):
v = np.array(v)
# Compute the norm of the vector (assuming v is 1D)
norm = np.linalg.norm(v)
# Normalize the vector
normalized_v = v / norm
# Return the normalized vector as a list
return normalized_v.tolist()
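# Example of normalize_vector and is_normalized on a plain Python list:
# >>> normalize_vector([3.0, 4.0])
# [0.6, 0.8]
# >>> is_normalized(normalize_vector([3.0, 4.0]))
# True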
def get_top_k_indices_scores(
scores: Union[List[float], np.ndarray], top_k: int
) -> Tuple[List[int], List[float]]:
if isinstance(scores, list):
scores_np = np.array(scores)
else:
scores_np = scores
top_k_indices = np.argsort(scores_np)[-top_k:][::-1]
top_k_scores = scores_np[top_k_indices]
return top_k_indices.tolist(), top_k_scores.tolist()
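# Example of get_top_k_indices_scores: indices of the two highest scores,
# ordered from best to worst, together with the scores themselves:
# >>> get_top_k_indices_scores([0.1, 0.9, 0.5], top_k=2)
# ([1, 2], [0.9, 0.5])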
def generate_readable_key_for_function(fn: Callable) -> str:
module_name = fn.__module__
function_name = fn.__name__
return f"{module_name}.{function_name}"
########################################################################################
# For Parser components
########################################################################################
def fix_json_missing_commas(json_str: str) -> str:
# Example: adding missing commas, only after double quotes
# Regular expression to find missing commas
regex = r'(?<=[}\]"\'\d])(\s+)(?=[\{"\[])'
# Add commas where missing
fixed_json_str = re.sub(regex, r",\1", json_str)
return fixed_json_str
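# Example of fix_json_missing_commas: a comma is inserted where a value is
# followed only by whitespace and the next quoted key:
# >>> fix_json_missing_commas('{"a": 1 "b": 2}')
# '{"a": 1, "b": 2}'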
def fix_json_escaped_single_quotes(json_str: str) -> str:
# First, replace improperly escaped single quotes inside strings
# json_str = re.sub(r"(?<!\\)\'", '"', json_str)
# Fix escaped single quotes
json_str = json_str.replace(r"\'", "'")
return json_str
def parse_yaml_str_to_obj(yaml_str: str) -> Dict[str, Any]:
r"""Parse a YAML string to a Python object.
yaml_str: has to be a valid YAML string.
"""
yaml_str = yaml_str.strip()
try:
import yaml
yaml_obj = yaml.safe_load(yaml_str)
return yaml_obj
except yaml.YAMLError as e:
raise ValueError(
f"Got invalid YAML object. Error: {e}. Got YAML string: {yaml_str}"
)
except NameError as exc:
raise ImportError("Please pip install PyYAML.") from exc
def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
r"""Parse a varietry of json format string to Python object.
json_str: has to be a valid JSON string. Either {} or [].
"""
json_str = json_str.strip()
    # 1st attempt with json.loads
try:
json_obj = json.loads(json_str)
return json_obj
except json.JSONDecodeError as e:
log.info(
f"Got invalid JSON object with json.loads. Error: {e}. Got JSON string: {json_str}"
)
        # 2nd attempt after fixing the JSON string
try:
log.info("Trying to fix potential missing commas...")
json_str = fix_json_missing_commas(json_str)
log.info("Trying to fix scaped single quotes...")
json_str = fix_json_escaped_single_quotes(json_str)
log.info(f"Fixed JSON string: {json_str}")
json_obj = json.loads(json_str)
return json_obj
except json.JSONDecodeError:
            # 3rd attempt using yaml
try:
# NOTE: parsing again with pyyaml
# pyyaml is less strict, and allows for trailing commas
# right now we rely on this since guidance program generates
# trailing commas
log.info("Parsing JSON string with PyYAML...")
json_obj = yaml.safe_load(json_str)
return json_obj
except yaml.YAMLError as e:
raise ValueError(
f"Got invalid JSON object with yaml.safe_load. Error: {e}. Got JSON string: {json_str}"
)
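# Usage sketch for parse_json_str_to_obj: the YAML fallback makes it tolerant of
# trailing commas that json.loads rejects (exact tolerance depends on PyYAML):
# >>> parse_json_str_to_obj('{"a": 1, "b": [1, 2],}')
# {'a': 1, 'b': [1, 2]}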
########################################################################################
# For sampling
########################################################################################
def random_sample(
dataset: Sequence[T_co],
num_shots: int,
replace: Optional[bool] = False,
weights: Optional[List[float]] = None,
delta: float = 1e-5, # to avoid zero division
) -> List[T_co]:
r"""
Randomly sample num_shots from the dataset. If replace is True, sample with replacement.
"""
dataset_size = len(dataset)
if dataset_size == 0:
return []
if not replace and num_shots > dataset_size:
log.debug(
f"num_shots {num_shots} is larger than the dataset size {dataset_size}"
)
num_shots = dataset_size
if weights is not None:
weights = np.array(weights)
# Add a small delta to all weights to avoid zero probabilities
weights = weights + delta
if weights.sum() == 0:
raise ValueError("Sum of weights cannot be zero.")
# Normalize weights to sum to 1
weights = weights / weights.sum()
indices = np.random.choice(len(dataset), size=num_shots, replace=replace, p=weights)
return [dataset[i] for i in indices]
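# Usage sketch for random_sample: draw two distinct items; pass `weights` to bias
# the draw, or `replace=True` to sample with replacement:
# >>> random_sample(["a", "b", "c", "d"], num_shots=2)
# ['d', 'b']  # items and order vary from run to run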