Spaces:
Runtime error
Runtime error
"""Chain that makes API calls and summarizes the responses to answer a question.""" | |
from __future__ import annotations | |
from typing import Any, Dict, List, Optional, Sequence, Tuple | |
from urllib.parse import urlparse | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.pydantic_v1 import Field, root_validator | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForChainRun, | |
CallbackManagerForChainRun, | |
) | |
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.utilities.requests import TextRequestsWrapper | |
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]: | |
"""Extract the scheme + domain from a given URL. | |
Args: | |
url (str): The input URL. | |
Returns: | |
return a 2-tuple of scheme and domain | |
""" | |
parsed_uri = urlparse(url) | |
return parsed_uri.scheme, parsed_uri.netloc | |
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Check if a URL is in the allowed domains.

    The URL matches an allowed entry only when both its scheme and its
    network location are exactly equal to those of the entry, so
    ``http://example.com`` does not match an allowed ``https://example.com``.

    Args:
        url: The input URL.
        limit_to_domains: The allowed domains, each given as a URL such as
            ``"https://www.example.com"``.

    Returns:
        True if the URL is in the allowed domains, False otherwise
        (including when ``limit_to_domains`` is empty).
    """
    target = _extract_scheme_and_domain(url)
    # Compare (scheme, netloc) pairs; a separate name is used for each allowed
    # entry so the loop variable is never rebound to its parsed form.
    return any(
        target == _extract_scheme_and_domain(allowed_url)
        for allowed_url in limit_to_domains
    )
class APIChain(Chain):
    """Chain that makes API calls and summarizes the responses to answer a question.

    The chain runs in three steps: an LLM turns the question and the API
    documentation into a URL, that URL is fetched with a GET request, and a
    second LLM call summarizes the response into an answer.

    *Security Note*: This API chain uses the requests toolkit to make
        requests to an API on behalf of the caller.

        Exercise care in who is allowed to use this chain. If exposing
        to end users, consider that users will be able to make arbitrary
        requests on behalf of the server hosting the code. For example,
        users could ask the server to make a request to a private API
        that is only accessible from the server.

        Control access to who can submit requests using this toolkit and
        what network access it has.

        See https://python.langchain.com/docs/security for more information.
    """

    # NOTE(review): the decorators on this class (`@property`,
    # `@root_validator(pre=True)`, `@classmethod`) were reconstructed from the
    # method signatures and the file's imports -- confirm against `Chain`.

    # LLM chain that produces the URL to call from the question and API docs.
    api_request_chain: LLMChain
    # LLM chain that turns the raw API response into the final answer.
    api_answer_chain: LLMChain
    # HTTP client used for the GET request; excluded from model export.
    requests_wrapper: TextRequestsWrapper = Field(exclude=True)
    # Documentation of the target API, passed verbatim to both prompts.
    api_docs: str
    question_key: str = "question"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    limit_to_domains: Optional[Sequence[str]]
    """Use to limit the domains that can be accessed by the API chain.

    * For example, to limit to just the domain `https://www.example.com`, set
        `limit_to_domains=["https://www.example.com"]`.

    * The default value is an empty tuple, which means that no domains are
      allowed by default. By design this will raise an error on instantiation.
    * Use a None if you want to allow all domains by default -- this is not
      recommended for security reasons, as it would allow malicious users to
      make requests to arbitrary URLS including internal APIs accessible from
      the server.
    """

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @root_validator(pre=True)
    def validate_api_request_prompt(cls, values: Dict) -> Dict:
        """Check that api request prompt expects the right variables.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``{"question", "api_docs"}``.
        """
        input_vars = values["api_request_chain"].prompt.input_variables
        expected_vars = {"question", "api_docs"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    @root_validator(pre=True)
    def validate_limit_to_domains(cls, values: Dict) -> Dict:
        """Check that allowed domains are valid.

        ``limit_to_domains`` must be passed explicitly: either a non-empty
        sequence of allowed URLs, or ``None`` to allow every domain.

        Raises:
            ValueError: If ``limit_to_domains`` is missing, or is empty
                without being ``None``.
        """
        # Runs before field validation so that an omitted key is still
        # visible as "missing" rather than already defaulted.
        if "limit_to_domains" not in values:
            raise ValueError(
                "You must specify a list of domains to limit access using "
                "`limit_to_domains`"
            )
        # An empty sequence allows nothing, which is rejected; None is the
        # explicit opt-in to "allow everything" and is let through.
        if not values["limit_to_domains"] and values["limit_to_domains"] is not None:
            raise ValueError(
                "Please provide a list of domains to limit access using "
                "`limit_to_domains`."
            )
        return values

    @root_validator(pre=True)
    def validate_api_answer_prompt(cls, values: Dict) -> Dict:
        """Check that api answer prompt expects the right variables.

        Raises:
            ValueError: If the prompt's input variables are not exactly
                ``{"question", "api_docs", "api_url", "api_response"}``.
        """
        input_vars = values["api_answer_chain"].prompt.input_variables
        expected_vars = {"question", "api_docs", "api_url", "api_response"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Answer the question in ``inputs`` by calling the API synchronously.

        Args:
            inputs: Mapping that contains the question under ``question_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the answer stored under ``output_key``.

        Raises:
            ValueError: If ``limit_to_domains`` is set and the generated URL
                is not in the allowed domains.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = self.api_request_chain.predict(
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
        api_url = api_url.strip()
        # A None `limit_to_domains` disables the check entirely.
        if self.limit_to_domains and not _check_in_allowed_domain(
            api_url, self.limit_to_domains
        ):
            raise ValueError(
                f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
            )
        api_response = self.requests_wrapper.get(api_url)
        _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = self.api_answer_chain.predict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Answer the question in ``inputs`` by calling the API asynchronously.

        Mirrors ``_call`` step for step, awaiting the LLM calls, the callback
        notifications and the HTTP request.

        Args:
            inputs: Mapping that contains the question under ``question_key``.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict with the answer stored under ``output_key``.

        Raises:
            ValueError: If ``limit_to_domains`` is set and the generated URL
                is not in the allowed domains.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.question_key]
        api_url = await self.api_request_chain.apredict(
            question=question,
            api_docs=self.api_docs,
            callbacks=_run_manager.get_child(),
        )
        await _run_manager.on_text(
            api_url, color="green", end="\n", verbose=self.verbose
        )
        api_url = api_url.strip()
        # A None `limit_to_domains` disables the check entirely.
        if self.limit_to_domains and not _check_in_allowed_domain(
            api_url, self.limit_to_domains
        ):
            raise ValueError(
                f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
            )
        api_response = await self.requests_wrapper.aget(api_url)
        await _run_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = await self.api_answer_chain.apredict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: answer}

    @classmethod
    def from_llm_and_api_docs(
        cls,
        llm: BaseLanguageModel,
        api_docs: str,
        headers: Optional[dict] = None,
        api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
        api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
        limit_to_domains: Optional[Sequence[str]] = tuple(),
        **kwargs: Any,
    ) -> APIChain:
        """Load chain from just an LLM and the api docs.

        Args:
            llm: Language model used for both the URL-generation chain and
                the answer chain.
            api_docs: Documentation of the API to query.
            headers: Optional headers passed to the requests wrapper.
            api_url_prompt: Prompt for the URL-generation chain.
            api_response_prompt: Prompt for the answer chain.
            limit_to_domains: Allowed domains. The default empty tuple is
                rejected by validation, so callers must pass either a
                non-empty sequence or ``None``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A configured ``APIChain`` (or subclass) instance.
        """
        get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
        requests_wrapper = TextRequestsWrapper(headers=headers)
        get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
        return cls(
            api_request_chain=get_request_chain,
            api_answer_chain=get_answer_chain,
            requests_wrapper=requests_wrapper,
            api_docs=api_docs,
            limit_to_domains=limit_to_domains,
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "api_chain"