"""Base class for Autograd Components that can be called and backpropagated through."""fromtypingimportTYPE_CHECKINGimportuuidifTYPE_CHECKING:fromadalflow.core.generatorimportBackwardEnginefromadalflow.optim.parameterimportParameterfromadalflow.core.componentimportComponent# TODO: make it a subclass of GradComponent
class LossComponent(Component):
    __doc__ = """A base class to define a loss component.

    A loss component computes the textual gradients/feedback for each of its
    predecessors, using another LLM as the backward engine. Each predecessor
    should carry basic information, passed on to its next component, that
    describes its type (such as retriever or generator) and its role
    description.

    Compared with `Component`, `LossComponent` defines three important
    interfaces:

    - `forward`: the forward pass of the function, returns a `Parameter`
      object that can be traced and backpropagated.
    - `backward`: the backward pass of the function, updates the
      gradients/prediction score backpropagated from a "loss" parameter.
    - `set_backward_engine`: set the backward engine (a form of generator) to
      the component, which is used to backpropagate the gradients.

    In this base class all three raise `NotImplementedError` and must be
    implemented by subclasses.

    The `__call__` method does not check for training mode: it always
    delegates to `forward` with the same arguments and returns its result.
    """

    # Engine used to produce textual feedback; initialized to None in __init__.
    backward_engine: "BackwardEngine"
    # Component-type tag for this class.
    _component_type = "loss"
    # Class-level default; every instance gets its own uuid4 string in __init__.
    id = None
    # Set to True by `disable_backward_engine`; initialized to False in __init__.
    _disable_backward_engine: bool

    def __init__(self, *args, **kwargs):
        """Initialize with no backward engine, a fresh unique id, and the engine enabled.

        ``*args`` and ``**kwargs`` are accepted but ignored: they are NOT
        forwarded to ``Component.__init__``.
        """
        super().__init__()
        # Assign through the parent class's __setattr__, so any __setattr__
        # override defined on a LossComponent subclass is skipped here.
        super().__setattr__("backward_engine", None)
        super().__setattr__("id", str(uuid.uuid4()))
        super().__setattr__("_disable_backward_engine", False)

    def __call__(self, *args, **kwargs):
        """Run the forward pass; equivalent to ``self.forward(*args, **kwargs)``.

        No training-mode check is performed here.
        """
        return self.forward(*args, **kwargs)

    def set_backward_engine(self, backward_engine: "BackwardEngine", *args, **kwargs):
        """Attach the backward engine used to compute textual gradients.

        Must be implemented by subclasses.

        Args:
            backward_engine: The engine (a form of generator) to attach.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("set_backward_engine method is not implemented")

    def disable_backward_engine(self):
        r"""Disable gradient generation while still running backward.

        Only sets the ``_disable_backward_engine`` flag. The intent is that
        backward still runs to gain module context but skips generating
        gradients; honoring the flag is up to the ``backward`` implementation.
        """
        self._disable_backward_engine = True

    def forward(self, *args, **kwargs) -> "Parameter":
        r"""Run the forward pass and return a traceable ``Parameter``.

        Must be implemented by subclasses; this base class provides no default
        behavior.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("forward method is not implemented")

    def backward(self, *args, **kwargs):
        """Run the backward pass, producing feedback for predecessors.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("backward method is not implemented")