Spaces:
Runtime error
Runtime error
File size: 5,998 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import json
from typing import Any, Dict, List, Optional, cast
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr, validator
from langchain_core.utils import convert_to_secret_str
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
from langchain.utils import get_from_dict_or_env
class LlamaContentFormatter(ContentFormatterBase):
    """Content formatter for `LLaMA` chat models behind an AzureML endpoint.

    Turns a list of LangChain chat messages into the ``input_data`` JSON
    request body and extracts the ``output`` field from the response body.
    """

    SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> Dict:
        """Convert a chat message into a ``{"role": ..., "content": ...}`` dict.

        ``HumanMessage``, ``AIMessage`` and ``SystemMessage`` map to the
        ``user``, ``assistant`` and ``system`` roles. A ``ChatMessage`` is
        accepted when its ``role`` is one of ``SUPPORTED_ROLES``.

        Args:
            message: The message to convert.

        Returns:
            A dict with the role and the message content passed through
            ``ContentFormatterBase.escape_special_characters``.

        Raises:
            ValueError: If the message type or ``ChatMessage.role`` is not
                supported.
        """
        content = cast(str, message.content)
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, SystemMessage):
            role = "system"
        elif (
            isinstance(message, ChatMessage)
            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
        ):
            role = message.role
        else:
            supported = ",".join(LlamaContentFormatter.SUPPORTED_ROLES)
            # Name the offending role/type so the caller can see what was rejected.
            if isinstance(message, ChatMessage):
                detail = f"role {message.role!r}"
            else:
                detail = f"message type {type(message).__name__}"
            raise ValueError(
                f"Received unsupported role ({detail}). "
                f"Supported roles for the LLaMa Foundation Model: {supported}"
            )
        return {
            "role": role,
            "content": ContentFormatterBase.escape_special_characters(content),
        }

    def _format_request_payload(
        self, messages: List[BaseMessage], model_kwargs: Dict
    ) -> bytes:
        """Build the encoded request body for a list of chat messages.

        Args:
            messages: Conversation to send, in order.
            model_kwargs: Generation parameters placed under
                ``input_data.parameters``.

        Returns:
            The JSON document
            ``{"input_data": {"input_string": [...], "parameters": {...}}}``
            encoded to bytes.
        """
        chat_messages = [
            LlamaContentFormatter._convert_message_to_dict(message)
            for message in messages
        ]
        prompt = json.dumps(
            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
        )
        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Encode an already-serialized prompt string to bytes.

        ``model_kwargs`` is unused here; it is accepted to match the base
        formatter interface.
        """
        return str.encode(prompt)

    def format_response_payload(self, output: bytes) -> str:
        """Return the ``output`` field of the endpoint's JSON response."""
        return json.loads(output)["output"]
class AzureMLChatOnlineEndpoint(SimpleChatModel):
    """`AzureML` Chat models API.

    Example:
        .. code-block:: python

            azure_chat = AzureMLChatOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_key="my-api-key",
                content_formatter=LlamaContentFormatter(),
            )
    """

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
        env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
        env var `AZUREML_ENDPOINT_API_KEY`."""

    http_client: Any = None  #: :meta private:

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    @validator("http_client", always=True, allow_reuse=True)
    @classmethod
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Resolve endpoint credentials and build the HTTP client.

        The API key and URL come from the constructor arguments or, when those
        are empty or missing, from the ``AZUREML_ENDPOINT_API_KEY`` and
        ``AZUREML_ENDPOINT_URL`` environment variables. Both resolved values
        are written back into ``values``, so the model's fields reflect what
        the client actually uses.

        NOTE: ``field_value`` is ignored. Any ``http_client`` passed to the
        constructor is replaced by a freshly built client.
        """
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
        )
        # Write the resolved URL back as well (previously only the key was), so
        # ``self.endpoint_url`` is not left as "" when it came from the env.
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        return AzureMLEndpointClient(
            values["endpoint_url"], values["endpoint_api_key"].get_secret_value()
        )

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_kwargs": self.model_kwargs or {}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_chat_endpoint"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to an AzureML Managed Online endpoint.

        Args:
            messages: The messages in the conversation with the chat model.
            stop: Optional list of stop words. NOTE: not forwarded to the
                endpoint by this implementation; only ``model_kwargs`` are
                sent as parameters.
            run_manager: Callback manager for the run (unused here).
            **kwargs: Forwarded unchanged to ``http_client.call``.

        Returns:
            The string generated by the model, as extracted by
            ``content_formatter.format_response_payload``.

        Example:
            .. code-block:: python

                response = azure_chat.invoke([HumanMessage(content="Tell me a joke.")])
        """
        _model_kwargs = self.model_kwargs or {}
        # The formatter is expected to expose ``_format_request_payload``
        # (e.g. LlamaContentFormatter), which takes the message list directly.
        request_payload = self.content_formatter._format_request_payload(
            messages, _model_kwargs
        )
        response_payload = self.http_client.call(request_payload, **kwargs)
        return self.content_formatter.format_response_payload(response_payload)
|