"""LocalDB to perform in-memory storage and data persistence(pickle or any filesystem) for data models like documents and dialogturn."""
from typing import List, Optional, Callable, Dict, Any, TypeVar, Generic, overload
import logging
import os
from dataclasses import field, dataclass
import pickle
from adalflow.core.component import Component
from adalflow.utils.registry import EntityMapping
from adalflow.utils.global_config import get_adalflow_default_root_path
log = logging.getLogger(__name__)
T = TypeVar("T") # Allow any type as items
U = TypeVar("U") # U will be the type after transformation
# TODO: localDB does not need to be a component
# TODO: DB clarity can be further improved
[docs]
@dataclass
class LocalDB(Generic[T], Component):
    __doc__ = """LocalDB with in-memory CRUD operations, data transformation/processing pipelines, and persistence.

    LocalDB is highly flexible.
    1. It can store any type of data items in the `items` attribute.
    2. You can register and apply multiple transformers, and save the transformed data in the `transformed_items` attribute.
       This is highly useful to manage experiments with different data transformations.
    3. You can save the state of the LocalDB to a pickle file and load it back later. All states are restored.
       str(local_db.__dict__) == str(local_db_loaded.__dict__) should be True.

    .. note::
        The transformer should be of type Component. We made the effort in the library to make every component picklable.

    CRUD operations:
    1. Create a new db: ``db = LocalDB(name="my_db")``
    2. load: Load the db with data. ``db.load([{"text": "hello world"}, {"text": "hello world2"}])``
    3. extend: Extend the db with data. ``db.extend([{"text": "hello world3"}])``.
       In default, the transformer is applied and the transformed data is extended.
    4. add: Add a single item to the db. ``db.add({"text": "hello world4"})``.
       In default, the transformer is applied and the transformed data is added.
       Unless the transformed data keeps the same length as the original data, the insert operation does not mean insert after the last item.
    5. delete: Remove items by index. ``db.delete([0])``.
    6. reset: Remove all items. ``db.reset()``, including transformed_items and transformer_setups, and mapper_setups.

    Data transformation:
    1. Register a transformer first and apply it later

    .. code-block:: python

        db.register_transformer(transformer, key="test", map_fn=map_fn)
        # load data
        db.load([{"text": "hello world"}, {"text": "hello world2"}], apply_transformer=True)

        # or load data first and apply transformer by key
        db.load([{"text": "hello world"}, {"text": "hello world2"}], apply_transformer=False)
        db.apply_transformer("test")

    2. Add a version of transformed data to the db along with the transformer.

    .. code-block:: python

        db.transform(transformer, key="test", map_fn=map_fn)

    Data persistence:
    1. Save the state of the db to a pickle file.

    .. code-block:: python

        db.save_state("storage/local_item_db.pkl")

    2. Load the state of the db from a pickle file.

    .. code-block:: python

        db2 = LocalDB.load_state("storage/local_item_db.pkl")

    3. Check if the loaded and original db are the same.

    .. code-block:: python

        str(db.__dict__) == str(db2.__dict__) # expect True

    Args:
        items (List[T], optional): The original data items. Defaults to []. Can be any type such as Document, DialogTurn, dict, text, etc.
            The only requirement is that they should be picklable/serializable.
        transformed_items (Dict[str, List[U]], optional): Transformed data items by key. Defaults to {}.
            Transformer must be of type Component.
        transformer_setups (Dict[str, Component], optional): Transformer setup by key. Defaults to {}.
            It is used to save the transformer setup for later use.
        mapper_setups (Dict[str, Callable[[T], Any]], optional): Map function setup by key. Defaults to {}.
    """

    # Optional human-readable name; also used to build the default pickle path in save_state.
    name: Optional[str] = field(
        default=None, metadata={"description": "Name of the DB"}
    )
    # The raw, untransformed items (any picklable type).
    items: List[T] = field(
        default_factory=list, metadata={"description": "The original data items"}
    )
    # key -> items produced by the transformer registered under the same key.
    transformed_items: Dict[str, List[U]] = field(
        default_factory=dict, metadata={"description": "Transformed data items by key"}
    )
    # key -> transformer Component; re-applied on extend/add when apply_transformer=True.
    transformer_setups: Dict[str, Component] = field(
        default_factory=dict, metadata={"description": "Transformer setup by key"}
    )
    # key -> map function applied to each item before the transformer under the same key.
    mapper_setups: Dict[str, Callable[[T], Any]] = field(
        default_factory=dict, metadata={"description": "Map function setup by key"}
    )
    # Last filepath used by save_state (despite the default value's name, this is
    # the pickle path of the DB state, not a FAISS index file).
    index_path: Optional[str] = field(
        default="index.faiss", metadata={"description": "Path to the index file"}
    )
def __post_init__(self):
    # The generated dataclass __init__ does not call Component.__init__;
    # invoke it here so the Component base is initialized.
    super().__init__()
@property
def length(self):
    """Number of original (untransformed) items currently stored."""
    return len(self.items)
# TODO: combine this to fetch_transformed_items
def _get_transformer_name(self, transformer: Component) -> str:
    """Derive a descriptive key for *transformer*.

    The key is the transformer's class name followed by the names of all of
    its registered sub-components, underscore-joined, with a trailing
    underscore (matching the historical format).
    """
    parts = [transformer.__class__.__name__]
    parts.extend(sub_name for sub_name, _ in transformer.named_components())
    return "_".join(parts) + "_"
@overload
def transform(self, key: str) -> str:
    """Apply an already-registered transformer (looked up by *key*) to the data."""
    ...

@overload
def transform(
    self,
    transformer: Component,
    key: Optional[str] = None,
    map_fn: Optional[Callable[[T], Any]] = None,
) -> str:
    """Register *transformer* (optionally under *key*, with an optional
    per-item *map_fn*) and immediately apply it to the data."""
    ...
[docs]
def load(self, items: List[Any]):
    """Load the db with new items, replacing any existing ``items``.

    Note: the list is stored by reference (not copied), so mutating *items*
    afterwards is visible inside the DB.

    Args:
        items (List[Any]): The items to load.

    Examples:

    .. code-block:: python

        db = LocalDB()
        db.load([{"text": "hello world"}, {"text": "hello world2"}])
    """
    self.items = items
[docs]
def extend(
    self,
    items: List[Any],
    apply_transformer: bool = True,
):
    """Extend the db with new items; by default also run every registered
    transformer over them and extend the matching transformed lists.

    Args:
        items (List[Any]): New items appended to ``items``.
        apply_transformer (bool, optional): Whether to apply the registered
            transformers to the new items. Defaults to True.
    """
    self.items.extend(items)
    if not apply_transformer:
        return
    for key, transformer in self.transformer_setups.items():
        # Apply the registered map function (if any) before transforming.
        map_fn = self.mapper_setups.get(key)
        if map_fn is not None:
            transformed = transformer([map_fn(doc) for doc in items])
        else:
            transformed = transformer(items)
        # setdefault: a transformer may be registered but never applied yet,
        # in which case the key is absent (previously raised KeyError here).
        self.transformed_items.setdefault(key, []).extend(transformed)
[docs]
def delete(self, index: Optional[int] = None, remove_transformed: bool = True):
    """Remove the item at *index*, or pop the last item when *index* is None.
    Optionally remove the transformed data at the same position.

    Assumes each transformed item shares its index with the original item,
    which may not hold for transformers that change the item count.

    Args:
        index (Optional[int], optional): The index to remove. Defaults to
            None, which removes the last item. (The original code called
            ``list.pop(None)`` in that case, which raises TypeError.)
        remove_transformed (bool, optional): Whether to remove the
            transformed data as well. Defaults to True.
    """
    # Normalize: list.pop() with no argument pops the last element; -1 is
    # the equivalent explicit index.
    pos = -1 if index is None else index
    if remove_transformed:
        for transformed in self.transformed_items.values():
            transformed.pop(pos)
    self.items.pop(pos)
[docs]
def add(
    self, item: Any, index: Optional[int] = None, apply_transformer: bool = True
):
    """Add a single item at *index* (or append when None), optionally running
    every registered transformer on it.

    .. note::
        Inserting into the transformed lists at *index* is only meaningful
        when the transformer keeps a one-to-one correspondence between
        original and transformed items.

    Args:
        item (Any): The item to add.
        index (int, optional): The index to add the item at. Defaults to
            None, which appends to the end.
        apply_transformer (bool, optional): Whether to apply the registered
            transformers to the item. Defaults to True.
    """
    transformed_items: Dict[str, List] = {}
    if apply_transformer:
        for key, transformer in self.transformer_setups.items():
            # Apply the registered map function (if any) before transforming.
            map_fn = self.mapper_setups.get(key)
            payload = [map_fn(item)] if map_fn is not None else [item]
            transformed_items[key] = transformer(payload)
    if index is None:
        self.items.append(item)
        for key, transformed_docs in transformed_items.items():
            # setdefault: the key may be absent when the transformer was
            # registered but never applied (previously raised KeyError).
            self.transformed_items.setdefault(key, []).extend(transformed_docs)
    else:
        self.items.insert(index, item)
        for key, transformed_docs in transformed_items.items():
            target = self.transformed_items.setdefault(key, [])
            # Fix: inserting every doc at the same index reversed the order
            # of multi-doc transformer outputs; the offset keeps them ordered.
            for offset, doc in enumerate(transformed_docs):
                target.insert(index + offset, doc)
# TODO: rename it better to add the condition filter
[docs]
def fetch_items(self, condition: "Callable[[T], bool]") -> "List[T]":
    """Return the subset of ``items`` for which *condition* is truthy."""
    return list(filter(condition, self.items))
[docs]
def reset(self):
    r"""Reset all attributes to empty: ``items``, ``transformed_items``,
    ``transformer_setups``, and ``mapper_setups``.
    """
    self.items = []
    # Fix: the original assigned to a non-existent ``mapped_items``
    # attribute, so transformed data silently survived a reset.
    self.transformed_items = {}
    self.transformer_setups = {}
    self.mapper_setups = {}
[docs]
def save_state(self, filepath: str = None):
    """Pickle the whole DB (``self``) to *filepath* and remember the path in
    ``index_path``.

    When *filepath* is None, a default under the adalflow root directory is
    used: ``local_db/<name>.pkl``, or ``local_db/local_item_db.pkl`` when the
    DB has no name.

    Note:
        Transformer components are converted to a serializable form in
        ``__getstate__``; a transformer that cannot be serialized may be lost.
    """
    if filepath is None:
        # Fix: get_adalflow_default_root_path is a function and must be
        # called; the original joined the function object itself, which
        # raises TypeError inside os.path.join.
        default_name = (
            f"local_db/{self.name}.pkl" if self.name else "local_db/local_item_db.pkl"
        )
        filepath = os.path.join(get_adalflow_default_root_path(), default_name)
    self.index_path = filepath
    file_dir = os.path.dirname(filepath)
    # dirname is "" for a bare filename; os.makedirs("") would raise.
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    with open(filepath, "wb") as file:
        pickle.dump(self, file)
    print(f"Saved the state of the DB to {filepath}")
[docs]
@classmethod
def load_state(cls, filepath: str = None) -> "LocalDB":
    """Unpickle and return a DB previously written by :meth:`save_state`.

    When *filepath* is None, the default path under the adalflow root
    directory is used.

    Returns:
        The loaded LocalDB instance, or None when *filepath* does not exist
        (preserving the original best-effort contract).
    """
    if filepath is None:
        # Fix: call get_adalflow_default_root_path; the original joined
        # the function object itself, which raises TypeError.
        filepath = os.path.join(
            get_adalflow_default_root_path(), "local_db/local_item_db.pkl"
        )
    if os.path.exists(filepath):
        with open(filepath, "rb") as file:
            return pickle.load(file)
    return None
def __getstate__(self):
    """Pickle hook: drop the live transformer Components from the state and
    store their ``to_dict`` serialization plus class names instead, so
    ``__setstate__`` can rebuild them."""
    state = self.__dict__.copy()
    serialized = {
        key: transformer.to_dict()
        for key, transformer in self.transformer_setups.items()
    }
    class_names = {
        key: type(transformer).__name__
        for key, transformer in self.transformer_setups.items()
    }
    state["transformer_setups"] = {}
    state["_transformer_files"] = serialized
    state["_transformer_type_names"] = class_names
    return state
def __setstate__(self, state):
    """Unpickle hook: restore plain attributes, then rebuild each
    transformer Component from the serialized form produced by
    ``__getstate__``. Classes are resolved via EntityMapping first, falling
    back to this module's globals.
    """
    # Default to {} so states pickled without transformer metadata (e.g. by
    # an older version) still load instead of raising KeyError on pop.
    _transformer_files = state.pop("_transformer_files", {})
    _transformer_type_names = state.pop("_transformer_type_names", {})
    self.__dict__.update(state)
    for key, transformer_file in _transformer_files.items():
        class_type = (
            EntityMapping.get(_transformer_type_names[key])
            or globals()[_transformer_type_names[key]]
        )
        self.transformer_setups[key] = class_type.from_dict(transformer_file)