# taskweaver/llm/azure_ml.py
from typing import Any, Generator, List, Optional
import requests
from injector import inject
from taskweaver.llm.base import CompletionService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType, format_chat_message
class AzureMLServiceConfig(LLMServiceConfig):
    def _configure(self) -> None:
        self._set_name("azure_ml")

        # Fall back to the shared LLM module api_base / api_key settings when
        # no azure_ml-specific values are configured.
        shared_api_base = self.llm_module_config.api_base
        self.api_base = self._get_str(
            "api_base",
            shared_api_base,
        )
        shared_api_key = self.llm_module_config.api_key
        self.api_key = self._get_str(
            "api_key",
            shared_api_key,
        )

        # When chat_mode is True, the message list is sent to the endpoint
        # as-is; otherwise the conversation is flattened into a single prompt.
        self.chat_mode = self._get_bool(
            "chat_mode",
            True,
        )
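
# Illustrative only: a sketch of how these settings might be supplied in a
# TaskWeaver project configuration, assuming the usual "llm.<name>.<key>"
# convention implied by _set_name("azure_ml") above. The endpoint URL and key
# are placeholders, not real values.
#
#   {
#       "llm.azure_ml.api_base": "https://my-endpoint.my-region.inference.ml.azure.com",
#       "llm.azure_ml.api_key": "<endpoint-api-key>",
#       "llm.azure_ml.chat_mode": true
#   }
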
class AzureMLService(CompletionService):
    @inject
    def __init__(self, config: AzureMLServiceConfig):
        self.config = config
    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        # Normalize the endpoint URL: drop a trailing slash and append the
        # default scoring route for Azure ML online endpoints.
        endpoint = self.config.api_base
        if endpoint.endswith("/"):
            endpoint = endpoint[:-1]
        if endpoint.endswith(".ml.azure.com"):
            endpoint += "/score"
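        # For example (placeholder values), an api_base of
        #   https://my-endpoint.my-region.inference.ml.azure.com
        # becomes https://my-endpoint.my-region.inference.ml.azure.com/score,
        # while an api_base with any other suffix is used unchanged.
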
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        # Generation parameters for the deployed model; forward caller-supplied
        # values when provided, otherwise keep conservative defaults.
        params: dict[str, Any] = {
            "max_new_tokens": max_tokens if max_tokens is not None else 100,
            "do_sample": True,
        }
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p

        if self.config.chat_mode:
            # Chat deployments accept the message list directly.
            prompt = messages
        else:
            # Completion deployments expect a single prompt string, so flatten
            # the conversation into "role: content" lines.
            prompt = ""
            for msg in messages:
                prompt += f"{msg['role']}: {msg['content']}\n\n"
            prompt = [prompt]
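
        # The request body below uses the input_data/input_string/parameters
        # schema accepted by Azure ML managed online endpoints for text/chat
        # generation deployments (e.g. Llama-style models); other deployments
        # may expect a different scoring schema.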
        data = {
            "input_data": {
                "input_string": prompt,
                "parameters": params,
            },
        }

        with requests.Session() as session:
            with session.post(
                endpoint,
                headers=headers,
                json=data,
            ) as response:
                if response.status_code != 200:
                    raise Exception(
                        f"status code {response.status_code}: {response.text}",
                    )
                response_json = response.json()
                if "output" not in response_json:
                    raise Exception(f"output is not in response: {response_json}")
                outputs = response_json["output"]
                generation = outputs[0]
        # close connection before yielding
        yield format_chat_message("assistant", generation)
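

# Illustrative usage sketch, not part of the original module: it assumes an
# AzureMLService instance has already been constructed (normally via
# TaskWeaver's dependency injection), and the message text is a placeholder.
# It only shows how callers consume the generator returned by chat_completion,
# which yields a single formatted assistant message for this backend.
def _example_usage(service: AzureMLService) -> None:
    messages = [format_chat_message("user", "Summarize the latest run.")]
    for msg in service.chat_completion(messages, stream=False):
        print(msg["content"])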