Source code for core.prompt_builder
"""Class prompt builder for LightRAG system prompt."""
from typing import Dict, Any, Optional, List, TypeVar
import logging
from functools import lru_cache
from jinja2 import Template, Environment, StrictUndefined, meta
from adalflow.core.component import Component
from adalflow.core.default_prompt_template import DEFAULT_LIGHTRAG_SYSTEM_PROMPT
from adalflow.optim.parameter import Parameter
logger = logging.getLogger(__name__)
T = TypeVar("T")
[docs]
class Prompt(Component):
    __doc__ = r"""Renders a text string (prompt) from a Jinja2 template string.

    By default, the :ref:`DEFAULT_LIGHTRAG_SYSTEM_PROMPT<core-default_prompt_template>` is used as the template.

    Args:
        template (str, optional): The Jinja2 template string. ``None`` or an empty string
            falls back to DEFAULT_LIGHTRAG_SYSTEM_PROMPT.
        prompt_kwargs (Optional[Dict], optional): Preset values for the template variables.
            Values may be plain objects or ``Parameter`` instances (their ``data`` is used
            at render time). Defaults to None, which is treated as a fresh empty dict.

    Raises:
        ValueError: If the template string cannot be compiled by Jinja2.

    Examples:

        >>> from core.prompt_builder import Prompt

        >>> prompt = Prompt(prompt_kwargs={"task_desc_str": "You are a helpful assistant."})
        >>> print(prompt)
        >>> prompt.print_prompt_template()
        >>> prompt.print_prompt(context_str="This is a context string.")
        >>> prompt.call(context_str="This is a context string.")

        When examples_str itself is another template with variables, you can use another Prompt to render it.

        >>> EXAMPLES_TEMPLATE = r'''
        >>> {% if examples %}
        >>> {% for example in examples %}
        >>> {{loop.index}}. {{example}}
        >>> {% endfor %}
        >>> {% endif %}
        >>> '''
        >>> examples_prompt = Prompt(template=EXAMPLES_TEMPLATE)
        >>> examples_str = examples_prompt.call(examples=["Example 1", "Example 2"])
        >>> # pass it to the main prompt
        >>> prompt.print_prompt(examples_str=examples_str)
    """

    def __init__(
        self,
        template: Optional[str] = None,
        prompt_kwargs: Optional[Dict[str, Parameter]] = None,
    ):
        super().__init__()

        self.template = template or DEFAULT_LIGHTRAG_SYSTEM_PROMPT
        self.__create_jinja2_template()

        # Names the template references but does not define itself.
        self.prompt_variables: List[str] = list(
            self._find_template_variables(self.template)
        )
        logger.info(f"{__class__.__name__} has variables: {self.prompt_variables}")

        # A mutable `{}` default would be shared by every instance and then
        # mutated by update_prompt_kwargs(); create a fresh dict per instance.
        # A caller-supplied dict is kept by reference, as before.
        self.prompt_kwargs = prompt_kwargs if prompt_kwargs is not None else {}

    def __create_jinja2_template(self):
        r"""Compile ``self.template`` into ``self.jinja2_template``.

        Raises:
            ValueError: If Jinja2 fails to compile the template string.
        """
        try:
            self.jinja2_template: Template = get_jinja2_environment().from_string(
                self.template
            )
        except Exception as e:
            raise ValueError(f"Invalid Jinja2 template: {e}") from e

    def update_prompt_kwargs(self, **kwargs):
        r"""Update the initial prompt kwargs after Prompt is initialized."""
        self.prompt_kwargs.update(kwargs)

    def get_prompt_variables(self) -> List[str]:
        r"""Return the variable names found in the template."""
        return self.prompt_variables

    def is_key_in_template(self, key: str) -> bool:
        r"""Check if the key exists in the template."""
        return key in self.prompt_variables

    def _find_template_variables(self, template_str: str):
        """Return the set of undeclared variable names in ``template_str``."""
        parsed_content = self.jinja2_template.environment.parse(template_str)
        return meta.find_undeclared_variables(parsed_content)

    def compose_prompt_kwargs(self, **kwargs) -> Dict:
        r"""Compose the final prompt kwargs from preset and runtime values.

        Every template variable starts as ``None``; the preset ``prompt_kwargs``
        are layered on top, and the runtime ``kwargs`` override both.

        Returns:
            Dict: A new dict; ``self.prompt_kwargs`` is not modified.
        """
        composed_kwargs = {key: None for key in self.prompt_variables}
        if self.prompt_kwargs:
            composed_kwargs.update(self.prompt_kwargs)
        if kwargs:
            # Unknown keys are only logged, then still passed to the template.
            for key in kwargs:
                if key not in composed_kwargs:
                    logger.debug(f"Key {key} does not exist in the prompt_kwargs.")
            composed_kwargs.update(kwargs)
        return composed_kwargs

    def print_prompt_template(self):
        r"""Print the template string."""
        print("Template:")
        print("-------")
        print(f"{self.template}")
        print("-------")

    def print_prompt(self, **kwargs) -> str:
        r"""Render the prompt with the preset prompt_kwargs and the provided kwargs, print it, and return it.

        Raises:
            ValueError: If composing the kwargs or rendering the template fails.
        """
        try:
            pass_kwargs = self.compose_prompt_kwargs(**kwargs)
            pass_kwargs = _convert_prompt_kwargs_to_str(pass_kwargs)
            logger.debug(f"Prompt kwargs: {pass_kwargs}")

            prompt_str = self.jinja2_template.render(**pass_kwargs)
            print("Prompt:\n______________________")
            print(prompt_str)
            return prompt_str
        except Exception as e:
            raise ValueError(f"Error rendering Jinja2 template: {e}") from e

    def call(self, **kwargs) -> str:
        """
        Renders the prompt template with keyword arguments. Allow None values.

        Template variables that are supplied neither at init nor here are
        passed to the template as ``None``.

        Raises:
            ValueError: If composing the kwargs or rendering the template fails.
        """
        try:
            pass_kwargs = self.compose_prompt_kwargs(**kwargs)
            pass_kwargs = _convert_prompt_kwargs_to_str(pass_kwargs)
            return self.jinja2_template.render(**pass_kwargs)
        except Exception as e:
            raise ValueError(f"Error rendering Jinja2 template: {e}") from e

    def _extra_repr(self) -> str:
        """Build the repr detail string: template, preset kwargs, and variables."""
        s = f"template: {self.template}"
        # `or {}` guards against prompt_kwargs having been set to None externally.
        prompt_kwargs_str = _convert_prompt_kwargs_to_str(self.prompt_kwargs or {})
        if prompt_kwargs_str:
            s += f", prompt_kwargs: {prompt_kwargs_str}"
        if self.prompt_variables:
            s += f", prompt_variables: {self.prompt_variables}"
        return s

    @classmethod
    def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
        """Rebuild a Prompt from its dictionary form.

        The object is created by the parent class's ``from_dict``; the compiled
        Jinja2 template is excluded by ``to_dict``, so it is recompiled here
        from ``obj.template``.
        """
        obj = super().from_dict(data)
        # recreate the jinja2 template
        obj.jinja2_template = get_jinja2_environment().from_string(obj.template)
        return obj

    def to_dict(self) -> Dict[str, Any]:
        """
        Get the dictionary representation of the Prompt object's attributes.

        The compiled ``jinja2_template`` is excluded because it is not
        serializable. Everything else, including any ordering of keys or list
        elements, is delegated to the parent class's ``to_dict``.
        """
        exclude = ["jinja2_template"]  # unserializable object
        output = super().to_dict(exclude=exclude)
        return output
def _convert_prompt_kwargs_to_str(prompt_kwargs: Dict) -> Dict[str, str]:
r"""Convert the prompt_kwargs to a dictionary with string values."""
prompt_kwargs_str: Dict[str, str] = {}
for key, p in prompt_kwargs.items():
if isinstance(p, Parameter):
prompt_kwargs_str[key] = p.data
else:
prompt_kwargs_str[key] = p
return prompt_kwargs_str
[docs]
@lru_cache(maxsize=None)
def get_jinja2_environment():
    r"""Return the Jinja2 environment used by the Prompt component.

    The function takes no arguments and is cached, so the environment is
    built on the first call and the same object is returned afterwards.

    Raises:
        ValueError: If the environment cannot be constructed.
    """
    try:
        # Kept inside the try so any failure here surfaces as ValueError.
        settings = {
            "undefined": StrictUndefined,
            "trim_blocks": True,
            "keep_trailing_newline": True,
            "lstrip_blocks": True,
        }
        return Environment(**settings)
    except Exception as e:
        raise ValueError(f"Invalid Jinja2 environment: {e}")