Spaces:
Runtime error
Runtime error
"""Implements Program-Aided Language Models. | |
As in https://arxiv.org/pdf/2211.10435.pdf. | |
""" | |
from __future__ import annotations | |
from typing import Any, Dict, List, Optional | |
from pydantic import BaseModel, Extra | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT | |
from langchain.chains.pal.math_prompt import MATH_PROMPT | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.python import PythonREPL | |
from langchain.schema import BaseLanguageModel | |
class PALChain(Chain, BaseModel):
    """Implements Program-Aided Language Models.

    The chain asks ``llm`` (through ``prompt``) to generate Python code,
    appends ``get_answer_expr`` to that code, runs the combined program in a
    ``PythonREPL`` and returns the stripped REPL output under ``output_key``.
    """

    llm: BaseLanguageModel
    prompt: BasePromptTemplate
    # Stop sequence handed to the LLM when generating the program.
    stop: str = "\n\n"
    # Statement appended after the generated code so the answer gets printed.
    get_answer_expr: str = "print(solution())"
    # Namespaces passed to the PythonREPL that executes the generated code.
    python_globals: Optional[Dict[str, Any]] = None
    python_locals: Optional[Dict[str, Any]] = None
    output_key: str = "result"  #: :meta private:
    # When True, the generated code is also returned as "intermediate_steps".
    return_intermediate_steps: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys, i.e. the prompt's input variables.

        :meta private:
        """
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        Always contains ``output_key``; ``"intermediate_steps"`` is added
        when ``return_intermediate_steps`` is set.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate a program with the LLM, run it and return its output.

        Args:
            inputs: Values for the prompt's input variables.

        Returns:
            A dict holding the stripped REPL output under ``output_key`` and,
            when ``return_intermediate_steps`` is set, the generated code
            under ``"intermediate_steps"``.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
        code = llm_chain.predict(stop=[self.stop], **inputs)
        self.callback_manager.on_text(
            code, color="green", end="\n", verbose=self.verbose
        )
        # NOTE(review): the model-generated code is handed to the REPL as-is;
        # presumably it is executed without sandboxing -- confirm, and only
        # use this chain with trusted models/inputs or an isolated runtime.
        repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
        # Append the answer expression so the program prints its result.
        res = repl.run(code + f"\n{self.get_answer_expr}")
        output = {self.output_key: res.strip()}
        if self.return_intermediate_steps:
            output["intermediate_steps"] = code
        return output

    @classmethod
    def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
        """Load PAL from math prompt.

        Args:
            llm: Language model used to generate the program.
            **kwargs: Extra fields forwarded to the chain constructor.
        """
        return cls(
            llm=llm,
            prompt=MATH_PROMPT,
            stop="\n\n",
            get_answer_expr="print(solution())",
            **kwargs,
        )

    @classmethod
    def from_colored_object_prompt(
        cls, llm: BaseLanguageModel, **kwargs: Any
    ) -> PALChain:
        """Load PAL from colored object prompt.

        Args:
            llm: Language model used to generate the program.
            **kwargs: Extra fields forwarded to the chain constructor.
        """
        return cls(
            llm=llm,
            prompt=COLORED_OBJECT_PROMPT,
            stop="\n\n\n",
            get_answer_expr="print(answer)",
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "pal_chain"