File size: 3,429 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from http import HTTPStatus
from typing import Any, Generator, List, Optional

from injector import inject

from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType


class QWenServiceConfig(LLMServiceConfig):
    """Configuration for the QWen LLM service.

    Every setting is read through ``_get_str`` under its own key, with the
    matching value from the shared LLM module config passed as the default.
    """

    def _configure(self) -> None:
        """Populate ``api_key``, ``model``, ``backup_model`` and ``embedding_model``."""
        self._set_name("qwen")

        # API key: the shared key is handed over as the default unchanged.
        default_api_key = self.llm_module_config.api_key
        self.api_key = self._get_str("api_key", default_api_key)

        # Chat model: shared model if one is set, otherwise a fixed QWen model.
        default_model = self.llm_module_config.model
        if default_model is None:
            default_model = "qwen-max-1201"
        self.model = self._get_str("model", default_model)

        # Backup model: shared backup model, else the chat model resolved above.
        default_backup = self.llm_module_config.backup_model
        if default_backup is None:
            default_backup = self.model
        self.backup_model = self._get_str("backup_model", default_backup)

        # Embedding model: shared embedding model, else the chat model as well.
        default_embedding = self.llm_module_config.embedding_model
        if default_embedding is None:
            default_embedding = self.model
        self.embedding_model = self._get_str("embedding_model", default_embedding)


class QWenService(CompletionService, EmbeddingService):
    """Chat-completion and embedding service backed by the ``dashscope`` SDK.

    The ``dashscope`` package is imported lazily the first time an instance is
    created and then cached on the class, so it only has to be installed when
    this service is actually used.
    """

    # Cached ``dashscope`` module; filled in by the first ``__init__`` call.
    dashscope: Any = None

    @inject
    def __init__(self, config: QWenServiceConfig):
        """Store ``config`` and make sure the ``dashscope`` module is loaded.

        Args:
            config: QWen settings (API key and model names).

        Raises:
            Exception: If the ``dashscope`` package cannot be imported.
        """
        self.config = config

        if QWenService.dashscope is None:
            try:
                import dashscope
            except ImportError as exc:
                # Chain the original error so the real import failure stays visible.
                raise Exception(
                    "Package dashscope is required for using QWen API. ",
                ) from exc
            QWenService.dashscope = dashscope
        # The key is set on the cached module itself, so it is shared by every
        # QWenService instance in the process (the last one constructed wins).
        QWenService.dashscope.api_key = self.config.api_key

    @staticmethod
    def _api_error(resp: Any) -> Exception:
        """Build the exception describing a non-OK DashScope response."""
        detail = None
        # NOTE(review): DashScope documents `code`/`message` as the error
        # fields; `error` is kept as a fallback because the previous code read
        # it. Response objects may raise KeyError for unknown attributes, so
        # both lookup failures are tolerated.
        for field in ("message", "error"):
            try:
                detail = getattr(resp, field)
            except (AttributeError, KeyError):
                continue
            if detail:
                break
        return Exception(
            f"QWen API call failed with status code {resp.status_code} and error message {detail}",
        )

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Stream a chat completion, yielding one message per response chunk.

        Args:
            messages: Conversation history sent to the model.
            use_backup_engine: When true, use ``config.backup_model`` instead
                of ``config.model``.
            stream: Accepted for interface compatibility; the request is
                always issued in streaming mode with incremental output.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Token limit forwarded to the API.
            top_p: Nucleus-sampling value forwarded to the API.
            stop: Stop sequences forwarded to the API.
            **kwargs: Ignored.

        Yields:
            The ``message`` entry of the first choice of every chunk.

        Raises:
            Exception: If a chunk comes back with a non-OK status code.
        """
        model = self.config.backup_model if use_backup_engine else self.config.model

        response = QWenService.dashscope.Generation.call(
            model=model,
            messages=messages,
            result_format="message",  # set the result to be "message" format.
            max_tokens=max_tokens,
            top_p=top_p,
            temperature=temperature,
            stop=stop,
            stream=True,
            incremental_output=True,
        )

        for msg_chunk in response:
            if msg_chunk.status_code != HTTPStatus.OK:
                # Report the failing chunk: `response` is only the stream
                # iterator and carries no status information itself.
                raise self._api_error(msg_chunk)
            yield msg_chunk.output.choices[0]["message"]

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``strings``.

        Args:
            strings: Texts to embed with ``config.embedding_model``.

        Returns:
            The ``embedding`` field of every item in the response, in the
            order the API returned them. An empty input yields an empty list
            without calling the API.

        Raises:
            Exception: If the API responds with a non-OK status code.
        """
        if not strings:
            return []

        resp = QWenService.dashscope.TextEmbedding.call(
            model=self.config.embedding_model,
            input=strings,
        )
        if resp.status_code != HTTPStatus.OK:
            raise self._api_error(resp)
        return [emb["embedding"] for emb in resp["output"]["embeddings"]]