import asyncio
import json
import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional, cast

import requests
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class PaiEasChatEndpoint(BaseChatModel):
    """PAI-EAS LLM service chat model API.

    To use, you must have a chat LLM service deployed on Alibaba Cloud PAI-EAS.
    Set the environment variables ``EAS_SERVICE_URL`` and ``EAS_SERVICE_TOKEN``
    to your EAS service URL and service token, or pass them as constructor
    arguments.

    Example:
        .. code-block:: python

            from langchain.chat_models import PaiEasChatEndpoint

            eas_chat_endpoint = PaiEasChatEndpoint(
                eas_service_url="your_service_url",
                eas_service_token="your_service_token",
            )
    """
"""PAI-EAS Service URL""" | |
eas_service_url: str | |
"""PAI-EAS Service TOKEN""" | |
eas_service_token: str | |
"""PAI-EAS Service Infer Params""" | |
max_new_tokens: Optional[int] = 512 | |
temperature: Optional[float] = 0.8 | |
top_p: Optional[float] = 0.1 | |
top_k: Optional[int] = 10 | |
do_sample: Optional[bool] = False | |
use_cache: Optional[bool] = True | |
stop_sequences: Optional[List[str]] = None | |
"""Enable stream chat mode.""" | |
streaming: bool = False | |
"""Key/value arguments to pass to the model. Reserved for future use""" | |
model_kwargs: Optional[dict] = None | |
version: Optional[str] = "2.0" | |
timeout: Optional[int] = 5000 | |

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the service URL and token are passed in or set in the
        environment."""
        values["eas_service_url"] = get_from_dict_or_env(
            values, "eas_service_url", "EAS_SERVICE_URL"
        )
        values["eas_service_token"] = get_from_dict_or_env(
            values, "eas_service_token", "EAS_SERVICE_TOKEN"
        )
        return values
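
    # Illustrative note: credentials may come from the environment instead of
    # constructor arguments; the values below are placeholders, not a real
    # endpoint.
    #
    #   export EAS_SERVICE_URL="https://your-instance.example.com/api/predict/svc"
    #   export EAS_SERVICE_TOKEN="your_service_token"
    #
    #   chat = PaiEasChatEndpoint()  # picked up by validate_environment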

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            "eas_service_url": self.eas_service_url,
            "eas_service_token": self.eas_service_token,
            "model_kwargs": _model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "pai_eas_chat_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the EAS service."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "stop_sequences": [],
            "do_sample": self.do_sample,
            "use_cache": self.use_cache,
        }

    def _invocation_params(
        self, stop_sequences: Optional[List[str]], **kwargs: Any
    ) -> dict:
        params = self._default_params
        if self.model_kwargs:
            params.update(self.model_kwargs)
        if self.stop_sequences is not None and stop_sequences is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop_sequences is not None:
            params["stop"] = self.stop_sequences
        else:
            # Default to an empty list so downstream code can iterate safely.
            params["stop"] = stop_sequences or []
        return {**params, **kwargs}
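
    # Illustrative example: with the defaults above and a caller-supplied
    # stop=["\nUser:"], _invocation_params returns roughly:
    #
    #   {"max_new_tokens": 512, "temperature": 0.8, "top_k": 10, "top_p": 0.1,
    #    "stop_sequences": [], "do_sample": False, "use_cache": True,
    #    "stop": ["\nUser:"]}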

    def format_request_payload(
        self, messages: List[BaseMessage], **model_kwargs: Any
    ) -> dict:
        prompt: Dict[str, Any] = {}
        user_content: List[str] = []
        assistant_content: List[str] = []

        for message in messages:
            # Route each message's content according to its role.
            content = cast(str, message.content)
            if isinstance(message, HumanMessage):
                user_content.append(content)
            elif isinstance(message, AIMessage):
                assistant_content.append(content)
            elif isinstance(message, SystemMessage):
                prompt["system_prompt"] = content
            elif isinstance(message, ChatMessage) and message.role in [
                "user",
                "assistant",
                "system",
            ]:
                if message.role == "system":
                    prompt["system_prompt"] = content
                elif message.role == "user":
                    user_content.append(content)
                elif message.role == "assistant":
                    assistant_content.append(content)
            else:
                supported = ", ".join(["user", "assistant", "system"])
                raise ValueError(
                    f"Received unsupported role. "
                    f"Supported roles for this chat model: {supported}"
                )

        # The last user message becomes the prompt; earlier user/assistant
        # pairs form the conversation history.
        prompt["prompt"] = user_content[-1]
        prompt["history"] = list(zip(user_content[:-1], assistant_content))

        return {**prompt, **model_kwargs}
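
    # Illustrative example of the request payload: for the conversation
    #   [SystemMessage("You are helpful."), HumanMessage("Hi"),
    #    AIMessage("Hello!"), HumanMessage("How are you?")]
    # format_request_payload produces roughly:
    #
    #   {"system_prompt": "You are helpful.",
    #    "prompt": "How are you?",
    #    "history": [("Hi", "Hello!")],
    #    ...}  # plus the merged inference parameters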

    def _format_response_payload(
        self, output: Any, stop_sequences: Optional[List[str]]
    ) -> str:
        """Format the response body, truncating at any stop sequence."""
        try:
            text = json.loads(output)["response"]
            if stop_sequences:
                text = enforce_stop_tokens(text, stop_sequences)
            return text
        except json.decoder.JSONDecodeError:
            # Fall back to the raw body when it is not valid JSON.
            return output.decode("utf-8") if isinstance(output, bytes) else output

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        params = self._invocation_params(stop, **kwargs)
        request_payload = self.format_request_payload(messages, **params)
        response_payload = self._call_eas(request_payload)
        generated_text = self._format_response_payload(response_payload, params["stop"])
        if run_manager:
            run_manager.on_llm_new_token(generated_text)
        return generated_text

    def _call_eas(self, query_body: dict) -> Any:
        """Generate text from the eas service."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": self.eas_service_token,
        }

        # make request
        response = requests.post(
            self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
        )

        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}"
                f" and message {response.text}"
            )

        return response.text

    def _call_eas_stream(self, query_body: dict) -> Any:
        """Stream text from the eas service."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": self.eas_service_token,
        }

        # Make the request with stream=True so the body can be consumed
        # incrementally by iter_lines() instead of being buffered whole.
        response = requests.post(
            self.eas_service_url,
            headers=headers,
            json=query_body,
            timeout=self.timeout,
            stream=True,
        )

        if response.status_code != 200:
            raise Exception(
                f"Request failed with status code {response.status_code}"
                f" and message {response.text}"
            )

        return response

    def _convert_chunk_to_message_chunk(
        self,
        chunk: bytes,
    ) -> AIMessageChunk:
        # iter_lines(decode_unicode=False) yields bytes, which json.loads
        # accepts directly.
        data = json.loads(chunk)
        return AIMessageChunk(content=data.get("response", ""))
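
    # Wire-format note (inferred from the streaming code below): each chunk is
    # a NUL-delimited JSON object such as b'{"response": "Hello"}', so the
    # converter only needs to pull out the "response" field.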

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._invocation_params(stop, **kwargs)
        request_payload = self.format_request_payload(messages, **params)
        request_payload["use_stream_chat"] = True

        response = self._call_eas_stream(request_payload)
        for chunk in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if chunk:
                content = self._convert_chunk_to_message_chunk(chunk)

                # identify stop sequence in generated text, if any
                stop_seq_found: Optional[str] = None
                for stop_seq in params["stop"]:
                    if stop_seq in content.content:
                        stop_seq_found = stop_seq

                # truncate the chunk at the stop sequence, if one was found
                if stop_seq_found:
                    content.content = content.content[
                        : content.content.index(stop_seq_found)
                    ]

                # yield the (possibly truncated) chunk
                if run_manager:
                    await run_manager.on_llm_new_token(cast(str, content.content))
                yield ChatGenerationChunk(message=content)

                # stop streaming once a stop sequence has been seen
                if stop_seq_found:
                    break

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            generation: Optional[ChatGenerationChunk] = None
            async for chunk in self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                generation = chunk
            assert generation is not None
            return ChatResult(generations=[generation])

        # Run the synchronous _generate in an executor so it does not block
        # the event loop.
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_running_loop().run_in_executor(None, func)
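

# --- Minimal usage sketch, assuming a deployed PAI-EAS chat service and a
# langchain version exposing this era's chat-model call API. The URL and token
# are placeholders; swap in your own values before running.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    chat = PaiEasChatEndpoint(
        eas_service_url="https://your-instance.example.com/api/predict/your_service",
        eas_service_token="your_service_token",
    )

    # Synchronous, non-streaming call.
    reply = chat([HumanMessage(content="Say hello in one sentence.")])
    print(reply.content)

    # Async streaming via the private _astream helper defined above; prints
    # tokens as they arrive.
    async def demo_stream() -> None:
        async for chunk in chat._astream([HumanMessage(content="Tell me a joke.")]):
            print(chunk.message.content, end="", flush=True)

    asyncio.run(demo_stream())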