Spaces:
Runtime error
Runtime error
from typing import Any, Dict, List, Mapping, Optional | |
import requests | |
from langchain_core.pydantic_v1 import Extra | |
from langchain.callbacks.manager import CallbackManagerForLLMRun | |
from langchain.llms.base import LLM | |
from langchain.llms.utils import enforce_stop_tokens | |
class ContentHandlerAmazonAPIGateway:
    """Adapter to prepare the inputs from Langchain to a format
    that LLM model expects.

    It also provides helper function to extract
    the generated text from the model response."""

    @classmethod
    def transform_input(
        cls, prompt: str, model_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the JSON request payload sent to the endpoint.

        Args:
            prompt: The prompt text.
            model_kwargs: Model parameters, forwarded unchanged under the
                ``"parameters"`` key.

        Returns:
            ``{"inputs": prompt, "parameters": model_kwargs}``.
        """
        return {"inputs": prompt, "parameters": model_kwargs}

    @classmethod
    def transform_output(cls, response: Any) -> str:
        """Extract the generated text from the endpoint's HTTP response.

        Expects ``response.json()`` to be a sequence whose first element is
        a mapping with a ``"generated_text"`` key; only that first element
        is read.

        Args:
            response: An object exposing ``.json()`` (e.g. a
                ``requests.Response``).

        Returns:
            The value stored under ``"generated_text"`` in the first element.

        Raises:
            Any exception raised by ``response.json()`` or by the
            indexing/key lookup (e.g. ``IndexError``, ``KeyError``) propagates.
        """
        return response.json()[0]["generated_text"]
class AmazonAPIGateway(LLM):
    """Amazon API Gateway to access LLM models hosted on AWS.

    The prompt is POSTed as JSON to ``api_url``; ``content_handler`` builds
    the request payload and extracts the generated text from the response.
    """

    api_url: str
    """API Gateway URL"""

    headers: Optional[Dict] = None
    """API Gateway HTTP Headers to send, e.g. for authentication"""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    content_handler: ContentHandlerAmazonAPIGateway = ContentHandlerAmazonAPIGateway()
    """The content handler class that provides an input and
    output transform functions to handle formats between LLM
    and the endpoint.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        Must be a property: the LangChain base class reads it as an
        attribute (e.g. ``dict(self._identifying_params)``).
        """
        return {
            "api_url": self.api_url,
            "headers": self.headers,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm.

        Must be a property: the LangChain base class reads it as an
        attribute, not as a method call.
        """
        return "amazon_api_gateway"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Amazon API Gateway model.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the HTTP request or the response extraction fails;
                the original exception is chained as ``__cause__``.

        Example:
            .. code-block:: python

                llm = AmazonAPIGateway(api_url="https://...")
                response = llm("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}
        payload = self.content_handler.transform_input(prompt, _model_kwargs)

        try:
            # NOTE(review): no ``timeout`` is passed, so requests will wait
            # indefinitely on an unresponsive endpoint — consider adding one.
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
            )
            text = self.content_handler.transform_output(response)
        except Exception as error:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Error raised by the service: {error}") from error

        if stop is not None:
            # Stop words are not sent to the endpoint (the payload carries only
            # the prompt and model_kwargs), so apply them client-side.
            text = enforce_stop_tokens(text, stop)
        return text