File size: 2,947 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import Any, Generator, List, Optional

import requests
from injector import inject

from taskweaver.llm.base import CompletionService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType, format_chat_message


class AzureMLServiceConfig(LLMServiceConfig):
    """Configuration for the Azure ML online-endpoint completion service."""

    def _configure(self) -> None:
        self._set_name("azure_ml")

        # Service-specific values fall back to the shared LLM module settings
        # when not provided explicitly.
        self.api_base = self._get_str("api_base", self.llm_module_config.api_base)
        self.api_key = self._get_str("api_key", self.llm_module_config.api_key)

        # When True, the message list is sent as a chat transcript; otherwise
        # messages are flattened into a single prompt string before sending.
        self.chat_mode = self._get_bool("chat_mode", True)


class AzureMLService(CompletionService):
    """CompletionService backed by an Azure ML online endpoint.

    Posts the conversation to the endpoint's scoring URL and yields the
    generated text as a single assistant message. The response is returned
    in one payload regardless of the ``stream`` flag.
    """

    @inject
    def __init__(self, config: AzureMLServiceConfig):
        self.config = config

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Request a completion from the endpoint and yield one assistant message.

        Args:
            messages: Conversation history as role/content dicts.
            use_backup_engine: Unused; accepted for interface compatibility.
            stream: Unused; the endpoint responds in a single payload.
            temperature: Sampling temperature, forwarded to the endpoint when set.
            max_tokens: Generation budget (``max_new_tokens``); defaults to 100.
            top_p: Nucleus-sampling parameter, forwarded when set.
            stop: Unused; the scoring schema used here has no stop-sequence
                field — TODO confirm against the deployed endpoint.
            **kwargs: Ignored extra options.

        Raises:
            Exception: On a non-200 HTTP status, or when the response JSON
                lacks an "output" field.
        """
        endpoint = self.config.api_base
        if endpoint.endswith("/"):
            endpoint = endpoint[:-1]

        # A bare workspace endpoint needs the scoring route appended.
        if endpoint.endswith(".ml.azure.com"):
            endpoint += "/score"

        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }

        # Forward the caller's sampling options. The previous implementation
        # hard-coded these and silently ignored temperature/max_tokens/top_p.
        params: dict = {
            "max_new_tokens": max_tokens if max_tokens is not None else 100,
            "do_sample": True,
        }
        if temperature is not None:
            params["temperature"] = temperature
        if top_p is not None:
            params["top_p"] = top_p

        if self.config.chat_mode:
            prompt = messages
        else:
            # Non-chat models take a single flattened prompt string.
            prompt = [
                "".join(f"{msg['role']}: {msg['content']}\n\n" for msg in messages)
            ]

        data = {
            "input_data": {
                "input_string": prompt,
                "parameters": params,
            },
        }
        with requests.Session() as session:
            with session.post(
                endpoint,
                headers=headers,
                json=data,
            ) as response:
                if response.status_code != 200:
                    raise Exception(
                        f"status code {response.status_code}: {response.text}",
                    )
                response_json = response.json()
                if "output" not in response_json:
                    raise Exception(f"output is not in response: {response_json}")
                generation = response_json["output"][0]

        # Yield only after the session is closed so the connection is released
        # even if the caller never exhausts the generator.
        yield format_chat_message("assistant", generation)