EureCA / dspy /predict /retry.py
tonneli's picture
Delete history
f5776d3
import copy
import dspy
import dsp
from .predict import Predict
class Retry(Predict):
    """Predictor wrapper that can re-run ``module`` with its past outputs and feedback.

    Ordinary calls are delegated to the wrapped ``module``. When
    ``dspy.settings.backtrack_to`` is this instance, the retry-aware
    :meth:`forward` is used instead: it feeds the previous attempt's outputs
    back in as ``past_<field>`` inputs, alongside a ``feedback`` input, through
    an extended field mapping (``self.new_signature``).
    """

    def __init__(self, module):
        """Wrap ``module``, remembering its signature and bound ``forward``.

        Args:
            module: Predictor to wrap. It must expose ``signature`` (whose
                ``.signature`` attribute provides ``input_fields()`` and
                ``output_fields()``) and a ``forward`` method.
        """
        super().__init__(module.signature)
        self.module = module
        self.original_signature = module.signature.signature
        # Captured once so that forward() can invoke the wrapped module's own
        # forward while passing it the extended signature.
        self.original_forward = module.forward
        self.new_signature = self._create_new_signature(self.original_signature)

    def _create_new_signature(self, original_signature):
        """Build the field mapping used when retrying.

        The returned dict relies on insertion order and contains, in order:
        the original input fields; one ``past_<name>`` input field per
        original output field; a ``feedback`` input field; then the original
        output fields. A later entry with the same name overwrites an earlier
        one.

        Args:
            original_signature: Object whose ``input_fields()`` and
                ``output_fields()`` return mappings of field name to field.
                Output fields must expose ``prefix`` and ``format``.

        Returns:
            dict: Field name -> field for the extended signature.
        """
        input_fields = original_signature.input_fields()
        output_fields = original_signature.output_fields()

        # Each previous output becomes an input, so the model can see what it
        # produced last time.
        past_output_fields = {
            f"past_{key}": dspy.InputField(
                prefix="Past " + value.prefix,
                desc="past output with errors",
                format=value.format,
            )
            for key, value in output_fields.items()
        }

        extended_signature = {}
        extended_signature.update(input_fields)
        extended_signature.update(past_output_fields)
        extended_signature["feedback"] = dspy.InputField(
            prefix="Instructions:",
            desc="Some instructions you must satisfy",
            format=str,
        )
        extended_signature.update(output_fields)
        return extended_signature

    def forward(self, *args, **kwargs):
        """Call the wrapped module's forward with past outputs as extra inputs.

        Args:
            *args: Accepted for interface compatibility but ignored; only
                keyword arguments are forwarded.
            **kwargs: Inputs for the wrapped module. Must include
                ``past_outputs``, a mapping of original output-field name to
                the value produced on the previous attempt. All other keyword
                arguments are passed through unchanged.

        Returns:
            Whatever the wrapped module's original ``forward`` returns.

        Raises:
            KeyError: If ``past_outputs`` is not supplied.
        """
        # Remove the control key *before* remapping. Otherwise an output field
        # named "outputs" (remapped to "past_outputs") would be deleted along
        # with it.
        past_outputs = kwargs.pop("past_outputs")
        for key, value in past_outputs.items():
            past_key = f"past_{key}"
            # Outputs with no matching ``past_`` field are silently dropped.
            if past_key in self.new_signature:
                kwargs[past_key] = value
        # NOTE(review): assumes the wrapped module's forward accepts a
        # ``new_signature`` keyword and uses it in place of its own
        # signature — confirm for each wrapped predictor type.
        kwargs["new_signature"] = self.new_signature
        return self.original_forward(**kwargs)

    def __call__(self, **kwargs):
        """Run one prediction step and record it in the active trace.

        When ``dspy.settings.backtrack_to`` is this instance, keyword
        arguments not already given are filled in from
        ``dspy.settings.backtrack_to_args``, and :meth:`forward` is used.
        Otherwise the call is delegated to the wrapped module.

        Args:
            **kwargs: Inputs for the prediction. ``demos`` defaults to
                ``self.demos``, or ``[]`` when that is ``None``.

        Returns:
            The prediction produced by ``forward`` or by the wrapped module.
        """
        # presumably stops the wrapped module from appending its own trace
        # entry (this wrapper appends one below) — confirm
        kwargs["_trace"] = False
        kwargs.setdefault("demos", self.demos if self.demos is not None else [])

        # perform backtracking
        if dspy.settings.backtrack_to == self:
            for key, value in dspy.settings.backtrack_to_args.items():
                kwargs.setdefault(key, value)
            pred = self.forward(**kwargs)
        else:
            pred = self.module(**kwargs)

        # Strip reserved/control keys so that only real inputs are recorded.
        # NOTE(shangyin) past_outputs seems not useful to include in demos,
        # therefore dropped
        reserved_keys = (
            "_trace",
            "demos",
            "signature",
            "new_signature",
            "config",
            "lm",
            "past_outputs",
        )
        for key in reserved_keys:
            kwargs.pop(key, None)

        trace = dsp.settings.trace
        if trace is not None:
            trace.append((self, {**kwargs}, pred))
        return pred