Spaces:
Runtime error
Runtime error
import logging | |
from typing import Any, Dict, List, Mapping, Optional, cast | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
FunctionMessage, | |
HumanMessage, | |
SystemMessage, | |
) | |
from langchain_core.outputs import ( | |
ChatGeneration, | |
ChatResult, | |
) | |
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForLLMRun, | |
CallbackManagerForLLMRun, | |
) | |
from langchain.chat_models.base import BaseChatModel | |
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class ChatParams(BaseModel):
    """Parameters for the `Javelin AI Gateway` LLM.

    Besides the declared fields, arbitrary extra keyword arguments are
    accepted (``extra`` is set to ``Extra.allow``) and are retained on the
    model instance.
    """

    temperature: float = 0.0
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None

    class Config:
        # Declaring the pydantic config on an inner ``Config`` class is
        # equivalent to passing ``extra=`` as a class keyword, but does not
        # need a ``type: ignore[call-arg]`` for ``__init_subclass__``.
        extra = Extra.allow
class ChatJavelinAIGateway(BaseChatModel):
    """`Javelin AI Gateway` chat models API.

    To use, you should have the ``javelin_sdk`` python package installed.
    For more information, see https://docs.getjavelin.io

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatJavelinAIGateway

            chat = ChatJavelinAIGateway(
                gateway_uri="<javelin-ai-gateway-uri>",
                route="<javelin-ai-gateway-chat-route>",
                params={
                    "temperature": 0.1
                }
            )
    """

    route: str
    """The route to use for the Javelin AI Gateway API."""

    gateway_uri: Optional[str] = None
    """The URI for the Javelin AI Gateway API."""

    params: Optional[ChatParams] = None
    """Parameters for the Javelin AI Gateway LLM."""

    client: Any
    """javelin client."""

    javelin_api_key: Optional[SecretStr] = None
    """The API key for the Javelin AI Gateway."""

    def __init__(self, **kwargs: Any):
        """Initialise the model and, if ``gateway_uri`` is set, a JavelinClient.

        When ``gateway_uri`` is not set, no client is constructed here and
        ``client`` keeps whatever value the pydantic initialiser assigned.

        Raises:
            ImportError: If the ``javelin_sdk`` package cannot be imported.
            ValueError: If constructing the client raises the SDK's
                ``UnauthorizedError``.
        """
        try:
            from javelin_sdk import (
                JavelinClient,
                UnauthorizedError,
            )
        except ImportError:
            raise ImportError(
                "Could not import javelin_sdk python package. "
                "Please install it with `pip install javelin_sdk`."
            )

        super().__init__(**kwargs)
        if self.gateway_uri:
            try:
                self.client = JavelinClient(
                    base_url=self.gateway_uri,
                    # ``javelin_api_key`` is optional, so forward ``None``
                    # instead of dereferencing a missing SecretStr.
                    # NOTE(review): confirm JavelinClient accepts api_key=None.
                    api_key=ChatJavelinAIGateway._secret_value(
                        self.javelin_api_key
                    ),
                )
            except UnauthorizedError as e:
                raise ValueError("Javelin: Incorrect API Key.") from e

    @staticmethod
    def _secret_value(secret: Optional[SecretStr]) -> Optional[str]:
        """Return the plain-text value of ``secret``, or ``None`` if unset."""
        return secret.get_secret_value() if secret is not None else None

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Connection settings merged with the fields of ``params``.

        NOTE(review): this mapping contains the plain-text API key and is
        returned by ``_identifying_params`` and merged into
        ``_get_invocation_params`` (documented as being for callbacks);
        consider masking it.
        """
        params: Dict[str, Any] = {
            "gateway_uri": self.gateway_uri,
            "javelin_api_key": ChatJavelinAIGateway._secret_value(
                self.javelin_api_key
            ),
            "route": self.route,
            **(self.params.dict() if self.params else {}),
        }
        return params

    def _build_query_body(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Build the ``query_body`` sent to the gateway route.

        The body holds the converted messages plus every field of ``params``.
        A per-call ``stop`` list, when given, replaces any ``stop`` value
        coming from ``params``.

        Raises:
            ValueError: Propagated from ``_convert_message_to_dict`` for
                unsupported messages.
        """
        data: Dict[str, Any] = {
            "messages": [
                ChatJavelinAIGateway._convert_message_to_dict(message)
                for message in messages
            ],
            **(self.params.dict() if self.params else {}),
        }
        if stop is not None:
            data["stop"] = stop
        return data

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Query the configured route synchronously and wrap the response.

        ``run_manager`` and extra ``kwargs`` are accepted for interface
        compatibility but are not used.
        """
        data = self._build_query_body(messages, stop)
        resp = self.client.query_route(self.route, query_body=data)
        return ChatJavelinAIGateway._create_chat_result(resp.dict())

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate`` using ``client.aquery_route``.

        ``run_manager`` and extra ``kwargs`` are accepted for interface
        compatibility but are not used.
        """
        data = self._build_query_body(messages, stop)
        resp = await self.client.aquery_route(self.route, query_body=data)
        return ChatJavelinAIGateway._create_chat_result(resp.dict())

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Same mapping as ``_default_params``."""
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "javelin-ai-gateway-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        """Map a gateway ``{"role", "content"}`` dict to a LangChain message.

        The roles ``user``, ``assistant`` and ``system`` map to the matching
        message classes; any other role becomes a ``ChatMessage``.
        """
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        """Always raise ``ValueError``: function calling is unsupported."""
        raise ValueError(
            "Function messages are not supported by the Javelin AI Gateway. Please"
            " create a feature request at https://docs.getjavelin.io"
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        """Map a LangChain message to a gateway ``{"role", "content"}`` dict.

        Any ``additional_kwargs`` other than ``function_call`` are dropped,
        and a warning is logged.

        Raises:
            ValueError: For ``FunctionMessage`` inputs, messages whose
                ``additional_kwargs`` contain ``function_call``, or message
                types not handled below.
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            # Always raises; shares the error text with the check below.
            ChatJavelinAIGateway._raise_functions_not_supported()
        else:
            raise ValueError(f"Got unknown message type: {message}")
        if "function_call" in message.additional_kwargs:
            ChatJavelinAIGateway._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Javelin AI Gateway "
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from a gateway response mapping.

        Each entry of ``response["llm_response"]["choices"]`` becomes one
        generation. The entry's ``metadata`` (if any) becomes that
        generation's ``generation_info``. The top-level ``metadata`` becomes
        ``llm_output``.
        """
        generations = [
            ChatGeneration(
                message=ChatJavelinAIGateway._convert_dict_to_message(
                    candidate["message"]
                ),
                # ``or {}`` also covers an explicit ``"metadata": None``,
                # which ``dict(None)`` would reject with TypeError.
                generation_info=dict(candidate.get("metadata") or {}),
            )
            for candidate in response["llm_response"]["choices"]
        ]
        response_metadata = response.get("metadata", {})
        return ChatResult(generations=generations, llm_output=response_metadata)