Spaces:
Runtime error
Runtime error
"""Chain that makes API calls and summarizes the responses to answer a question.""" | |
from __future__ import annotations | |
from typing import Any, Dict, List, Optional | |
from pydantic import BaseModel, Field, root_validator | |
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.prompts import BasePromptTemplate | |
from langchain.requests import RequestsWrapper | |
from langchain.schema import BaseLanguageModel | |
class APIChain(Chain, BaseModel): | |
"""Chain that makes API calls and summarizes the responses to answer a question.""" | |
api_request_chain: LLMChain | |
api_answer_chain: LLMChain | |
requests_wrapper: RequestsWrapper = Field(exclude=True) | |
api_docs: str | |
question_key: str = "question" #: :meta private: | |
output_key: str = "output" #: :meta private: | |
def input_keys(self) -> List[str]: | |
"""Expect input key. | |
:meta private: | |
""" | |
return [self.question_key] | |
def output_keys(self) -> List[str]: | |
"""Expect output key. | |
:meta private: | |
""" | |
return [self.output_key] | |
def validate_api_request_prompt(cls, values: Dict) -> Dict: | |
"""Check that api request prompt expects the right variables.""" | |
input_vars = values["api_request_chain"].prompt.input_variables | |
expected_vars = {"question", "api_docs"} | |
if set(input_vars) != expected_vars: | |
raise ValueError( | |
f"Input variables should be {expected_vars}, got {input_vars}" | |
) | |
return values | |
def validate_api_answer_prompt(cls, values: Dict) -> Dict: | |
"""Check that api answer prompt expects the right variables.""" | |
input_vars = values["api_answer_chain"].prompt.input_variables | |
expected_vars = {"question", "api_docs", "api_url", "api_response"} | |
if set(input_vars) != expected_vars: | |
raise ValueError( | |
f"Input variables should be {expected_vars}, got {input_vars}" | |
) | |
return values | |
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: | |
question = inputs[self.question_key] | |
api_url = self.api_request_chain.predict( | |
question=question, api_docs=self.api_docs | |
) | |
self.callback_manager.on_text( | |
api_url, color="green", end="\n", verbose=self.verbose | |
) | |
api_response = self.requests_wrapper.get(api_url) | |
self.callback_manager.on_text( | |
api_response, color="yellow", end="\n", verbose=self.verbose | |
) | |
answer = self.api_answer_chain.predict( | |
question=question, | |
api_docs=self.api_docs, | |
api_url=api_url, | |
api_response=api_response, | |
) | |
return {self.output_key: answer} | |
def from_llm_and_api_docs( | |
cls, | |
llm: BaseLanguageModel, | |
api_docs: str, | |
headers: Optional[dict] = None, | |
api_url_prompt: BasePromptTemplate = API_URL_PROMPT, | |
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT, | |
**kwargs: Any, | |
) -> APIChain: | |
"""Load chain from just an LLM and the api docs.""" | |
get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt) | |
requests_wrapper = RequestsWrapper(headers=headers) | |
get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt) | |
return cls( | |
api_request_chain=get_request_chain, | |
api_answer_chain=get_answer_chain, | |
requests_wrapper=requests_wrapper, | |
api_docs=api_docs, | |
**kwargs, | |
) | |
def _chain_type(self) -> str: | |
return "api_chain" | |