|
import copy |
|
import random |
|
|
|
import dsp |
|
import dspy |
|
|
|
from dspy.predict.parameter import Parameter |
|
from dspy.predict.predict import Predict |
|
from dspy.primitives.prediction import Prediction |
|
from dspy.signatures.field import InputField, OutputField |
|
from dspy.signatures.signature import infer_prefix |
|
|
|
from langchain_core.pydantic_v1 import Extra |
|
from langchain_core.runnables import Runnable |
|
|
|
|
|
# Signature used by LangChainPredict._build_signature to split a LangChain
# prompt template into three parts: task instructions, input variable names,
# and the output variable name.
# NOTE(review): the class docstring is presumably consumed by dspy as the
# instruction text sent to the LM, i.e. it is runtime prompt text rather than
# ordinary documentation -- confirm before rewording it. The only change made
# here is the enumeration label "(iv)" -> "(iii)", so the list matches the
# "three modular pieces" it announces.
class Template2Signature(dspy.Signature):
    """You are a processor for prompts. I will give you a prompt template (Python f-string) for an arbitrary task for other LMs.
Your job is to prepare three modular pieces: (i) any essential task instructions or guidelines, (ii) a list of variable names for inputs, (iii) the variable name for output."""

    # Input: the raw template, stripped and wrapped in a fenced block, followed
    # by a short cue line.
    template = dspy.InputField(format=lambda x: f"```\n\n{x.strip()}\n\n```\n\nLet's now prepare three modular pieces.")
    # Outputs: read back in _build_signature as `essential_instructions`,
    # `input_keys` (split on commas) and `output_key`.
    essential_instructions = dspy.OutputField()
    input_keys = dspy.OutputField(desc='comma-separated list of valid variable names')
    output_key = dspy.OutputField(desc='a valid variable name')
|
|
|
|
|
class ShallowCopyOnly:
    """Proxy whose deep copy is only a shallow copy of the wrapped object.

    Attribute reads that normal lookup cannot resolve on the wrapper are
    forwarded to the wrapped object, so the wrapper can stand in for it for
    attribute access and method calls. ``copy.deepcopy`` on the wrapper
    returns a new wrapper around ``copy.copy`` of the wrapped object.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, item):
        # __getattr__ only runs after normal lookup has failed. If "obj"
        # itself is what is missing (e.g. on the blank instance that
        # copy.copy / pickle create before restoring state), forwarding to
        # self.obj would re-enter this method without bound and raise
        # RecursionError instead of the AttributeError that hasattr() expects.
        if item == "obj":
            raise AttributeError(item)
        return getattr(self.obj, item)

    def __deepcopy__(self, memo):
        # Deliberately ignores `memo`: the wrapped object is shallow-copied,
        # never deep-copied.
        return ShallowCopyOnly(copy.copy(self.obj))
|
|
|
|
|
class LangChainPredict(Predict, Runnable):
    """A dspy ``Predict`` that can also be used as a LangChain ``Runnable``.

    On construction the LangChain prompt template is converted into a
    ``dsp.Template`` signature (see ``_build_signature``). ``forward`` renders
    that signature with the current demos, sends the prompt to the wrapped
    LangChain LLM, and records the result in
    ``dspy.settings.langchain_history`` and, if active, the dsp trace.
    """

    class Config:
        extra = Extra.allow

    def __init__(self, prompt, llm, **config):
        """Wrap ``llm`` and derive a signature from ``prompt``.

        Args:
            prompt: LangChain prompt; either exposes ``messages`` (each with
                ``.prompt.template``) or a single ``.template``.
            llm: LangChain LLM whose ``invoke`` is called in ``forward``.
            **config: Default keyword arguments passed to ``llm.invoke``.
        """
        # Both bases are initialised explicitly; Predict.__init__ is not called.
        Runnable.__init__(self)
        Parameter.__init__(self)

        # The wrapper makes deep copies of this module shallow-copy the LLM.
        self.langchain_llm = ShallowCopyOnly(llm)

        try:
            langchain_template = '\n'.join([msg.prompt.template for msg in prompt.messages])
        except AttributeError:
            # Not a multi-message prompt: fall back to the single template.
            langchain_template = prompt.template

        # Random 16-hex-character identifier for this predictor.
        self.stage = random.randbytes(8).hex()
        self.signature, self.output_field_key = self._build_signature(langchain_template)
        self.config = config
        self.reset()

    def reset(self):
        """Clear the LM override and all collected traces/train/demo data."""
        self.lm = None
        self.traces = []
        self.train = []
        self.demos = []

    def dump_state(self):
        """Return a dict with the current ``lm``, ``traces``, ``train`` and ``demos``."""
        state_keys = ["lm", "traces", "train", "demos"]
        return {k: getattr(self, k) for k in state_keys}

    def load_state(self, state):
        """Set every item of ``state`` as an attribute, then rebuild demos.

        After loading, each entry of ``self.demos`` is expanded into a
        ``dspy.Example``.
        """
        for name, value in state.items():
            setattr(self, name, value)

        self.demos = [dspy.Example(**x) for x in self.demos]

    def __call__(self, *arg, **kwargs):
        """Call ``forward``; a first positional mapping is merged under ``kwargs``."""
        if arg:
            # Explicit keyword arguments win over keys of the positional mapping.
            kwargs = {**arg[0], **kwargs}
        return self.forward(**kwargs)

    def _build_signature(self, template):
        """Ask an LM to split ``template`` and build a ``dsp.Template`` from it.

        Returns:
            A ``(dsp.Template, output_field_key)`` pair, where
            ``output_field_key`` is the last output name the LM returned.
        """
        # NOTE: this issues an LM call with a hard-coded model every time a
        # LangChainPredict is constructed.
        gpt4T = dspy.OpenAI(model='gpt-4-1106-preview', max_tokens=4000, model_type='chat')

        with dspy.context(lm=gpt4T):
            parts = dspy.Predict(Template2Signature)(template=template)

        inputs = {k.strip(): InputField() for k in parts.input_keys.split(',')}
        outputs = {k.strip(): OutputField() for k in parts.output_key.split(',')}

        for k, v in inputs.items():
            v.finalize(k, infer_prefix(k))

        # str.split always yields at least one item, so the loop runs at least
        # once and output_field_key ends up as the last output name.
        for k, v in outputs.items():
            output_field_key = k
            v.finalize(k, infer_prefix(k))

        return dsp.Template(parts.essential_instructions, **inputs, **outputs), output_field_key

    def forward(self, **kwargs):
        """Render the prompt, invoke the LangChain LLM and record the result.

        The privileged keys ``signature``, ``demos`` and ``config`` are popped
        from ``kwargs``; the remaining items are the signature's input values.

        Returns:
            The raw object returned by the LangChain LLM, not the dspy
            ``Prediction`` (which is only recorded in history and trace).
        """
        # Extract the three privileged keyword arguments.
        signature = kwargs.pop("signature", self.signature)
        demos = kwargs.pop("demos", self.demos)
        # Dict-literal merge: per-call config overrides constructor config.
        # dict(**a, **b) would raise TypeError on any key present in both.
        config = {**self.config, **kwargs.pop("config", {})}

        prompt = signature(dsp.Example(demos=demos, **kwargs))
        output = self.langchain_llm.invoke(prompt, **config)

        try:
            content = output.content
        except AttributeError:
            # No `.content` attribute: use the output object itself.
            content = output

        pred = Prediction.from_completions([{self.output_field_key: content}], signature=signature)

        dspy.settings.langchain_history.append((prompt, pred))

        if dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        # NOTE(review): returns the LLM output rather than `pred`, presumably
        # so the next runnable in a LangChain chain gets the usual LLM output
        # -- confirm.
        return output

    def invoke(self, d, *args, **kwargs):
        """LangChain entry point: call ``forward`` with the mapping ``d`` as keywords.

        Extra positional and keyword arguments are accepted and ignored.
        """
        return self.forward(**d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LangChainModule(dspy.Module):
    """dspy module wrapping a LangChain (LCEL) chain.

    The ``LangChainPredict`` nodes found in the chain's graph are kept in
    ``self.modules``; the chain itself is kept in ``self.chain``.
    """

    def __init__(self, lcel):
        """Store ``lcel`` and collect its ``LangChainPredict`` graph nodes."""
        super().__init__()

        # Only the node payloads are needed, so iterate the values directly.
        self.modules = [
            node.data
            for node in lcel.get_graph().nodes.values()
            if isinstance(node.data, LangChainPredict)
        ]
        self.chain = lcel

    def forward(self, **kwargs):
        """Invoke the chain with ``kwargs`` and wrap the result in a ``Prediction``.

        The result is stored under ``'output'`` and, when the chain contains
        at least one ``LangChainPredict``, also under the output field key of
        the last one collected.
        """
        output_keys = ['output']
        if self.modules:
            # A chain without any LangChainPredict has no field key to add;
            # indexing [-1] unguarded would raise IndexError.
            output_keys.append(self.modules[-1].output_field_key)

        output = self.chain.invoke(dict(**kwargs))

        try:
            output = output.content
        except AttributeError:
            # No `.content` attribute: keep the chain output unchanged. Same
            # handling as LangChainPredict.forward.
            pass

        return dspy.Prediction({k: output for k in output_keys})

    def invoke(self, d, *args, **kwargs):
        """Run ``forward`` with the mapping ``d`` and return its ``output`` value.

        Extra positional and keyword arguments are accepted and ignored.
        """
        return self.forward(**d).output
|
|
|
|