|
from typing import Any, Generator, List, Optional |
|
|
|
import requests |
|
from injector import inject |
|
|
|
from taskweaver.llm.base import CompletionService, LLMServiceConfig |
|
from taskweaver.llm.util import ChatMessageType, format_chat_message |
|
|
|
|
|
class AzureMLServiceConfig(LLMServiceConfig):
    """Configuration for the Azure ML online-endpoint LLM service.

    Values fall back to the shared ``llm`` module configuration when no
    service-specific override is provided.
    """

    def _configure(self) -> None:
        """Load the Azure ML settings from the config store."""
        self._set_name("azure_ml")

        # Endpoint URL; defaults to the shared llm api_base setting.
        self.api_base = self._get_str("api_base", self.llm_module_config.api_base)

        # Bearer token for the endpoint; defaults to the shared llm api_key.
        self.api_key = self._get_str("api_key", self.llm_module_config.api_key)

        # True -> send the messages list as structured chat input;
        # False -> flatten messages into a single role-prefixed prompt string.
        self.chat_mode = self._get_bool("chat_mode", True)
|
|
|
|
|
class AzureMLService(CompletionService):
    """CompletionService implementation backed by an Azure ML online endpoint.

    Sends a single scoring request per call (the ``stream`` flag is accepted
    for interface compatibility but the endpoint responds in one shot) and
    yields exactly one assistant message.
    """

    @inject
    def __init__(self, config: AzureMLServiceConfig):
        self.config = config

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Call the Azure ML scoring endpoint and yield the assistant reply.

        Args:
            messages: Chat history to send to the model.
            use_backup_engine: Accepted for interface compatibility; unused.
            stream: Accepted for interface compatibility; the endpoint is
                called synchronously regardless.
            temperature: Sampling temperature; forwarded only when set.
            max_tokens: Maximum new tokens to generate (defaults to 100,
                matching the previous hard-coded value).
            top_p: Nucleus-sampling cutoff; forwarded only when set.
            stop: Accepted for interface compatibility; unused.
            **kwargs: Ignored extra options.

        Yields:
            A single assistant ChatMessageType with the generated text.

        Raises:
            Exception: On a non-200 response, a response missing the
                ``output`` key, or an empty ``output`` list.
        """
        endpoint = self.config.api_base
        # Normalize a single trailing slash before appending the score path.
        if endpoint.endswith("/"):
            endpoint = endpoint[:-1]

        # Bare workspace endpoints need the scoring route appended.
        if endpoint.endswith(".ml.azure.com"):
            endpoint += "/score"

        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        params = {
            # Honor the caller's max_tokens; keep the historical default of
            # 100 new tokens when it is not specified.
            "max_new_tokens": max_tokens if max_tokens is not None else 100,
            "do_sample": True,
        }
        # Forward sampling knobs only when explicitly supplied, so the
        # request payload is unchanged for callers that never set them.
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p

        if self.config.chat_mode:
            # Chat deployments accept the structured message list directly.
            prompt = messages
        else:
            # Text-generation deployments take a single flattened prompt.
            prompt = ""
            for msg in messages:
                prompt += f"{msg['role']}: {msg['content']}\n\n"
            prompt = [prompt]

        data = {
            "input_data": {
                "input_string": prompt,
                "parameters": params,
            },
        }
        with requests.Session() as session:
            with session.post(
                endpoint,
                headers=headers,
                json=data,
            ) as response:
                if response.status_code != 200:
                    raise Exception(
                        f"status code {response.status_code}: {response.text}",
                    )
                response_json = response.json()
                if "output" not in response_json:
                    raise Exception(f"output is not in response: {response_json}")
                outputs = response_json["output"]
                if not outputs:
                    raise Exception(f"output is empty in response: {response_json}")
                generation = outputs[0]

        yield format_chat_message("assistant", generation)
|
|