File size: 3,429 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from http import HTTPStatus
from typing import Any, Generator, List, Optional
from injector import inject
from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType
class QWenServiceConfig(LLMServiceConfig):
    """Configuration section for the QWen LLM service.

    Each setting is read through ``_get_str``; the value taken from the shared
    ``llm_module_config`` is passed as the second argument (presumably the
    fallback used when the QWen-specific key is not set -- confirm against
    ``LLMServiceConfig._get_str``).
    """

    def _configure(self) -> None:
        """Register the section name and populate the QWen settings."""
        self._set_name("qwen")

        module_cfg = self.llm_module_config

        # API key: the shared module-level key is handed over unchanged.
        self.api_key = self._get_str("api_key", module_cfg.api_key)

        # Completion model: use the hard-coded QWen model name when the
        # shared configuration does not name one.
        model_default = module_cfg.model
        if model_default is None:
            model_default = "qwen-max-1201"
        self.model = self._get_str("model", model_default)

        # Backup model: falls back to the completion model resolved above.
        backup_default = module_cfg.backup_model
        if backup_default is None:
            backup_default = self.model
        self.backup_model = self._get_str("backup_model", backup_default)

        # Embedding model: likewise falls back to the completion model.
        embedding_default = module_cfg.embedding_model
        if embedding_default is None:
            embedding_default = self.model
        self.embedding_model = self._get_str("embedding_model", embedding_default)
class QWenService(CompletionService, EmbeddingService):
    """Chat-completion and embedding service backed by the ``dashscope`` SDK."""

    # The dashscope module, imported on first instantiation and then shared by
    # every instance through this class attribute.
    dashscope = None

    @inject
    def __init__(self, config: QWenServiceConfig):
        """Store the configuration and make sure ``dashscope`` is loaded.

        Args:
            config: QWen settings (API key and model names).

        Raises:
            Exception: If the ``dashscope`` package cannot be imported.
        """
        self.config = config

        if QWenService.dashscope is None:
            try:
                import dashscope
            except ImportError as exc:
                # Keep the original cause attached so the real import failure
                # is visible in the traceback.
                raise Exception(
                    "Package dashscope is required for using QWen API. ",
                ) from exc
            QWenService.dashscope = dashscope

        # Set on every construction so the most recent configuration wins.
        QWenService.dashscope.api_key = self.config.api_key

    @staticmethod
    def _error_detail(resp: Any) -> Any:
        """Return the error text carried by a failed dashscope response.

        ``message`` is tried first and ``error`` second; the lookup tolerates
        both ``AttributeError`` and ``KeyError`` because the response object
        may resolve attributes through dict access. Returns ``None`` when
        neither field is available.
        """
        for field in ("message", "error"):
            try:
                detail = getattr(resp, field)
            except (AttributeError, KeyError):
                continue
            if detail is not None:
                return detail
        return None

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Stream a chat completion from the QWen API.

        Args:
            messages: Conversation history sent to the model.
            use_backup_engine: When true, ``config.backup_model`` is used
                instead of ``config.model``.
            stream: Accepted for interface compatibility; the request is
                always issued in streaming mode with incremental output.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Token limit forwarded to the API.
            top_p: Nucleus-sampling parameter forwarded to the API.
            stop: Stop sequences forwarded to the API.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            The ``message`` entry of the first choice of every successful
            response chunk.

        Raises:
            Exception: If a chunk reports a status other than HTTP 200.
        """
        if use_backup_engine:
            model = self.config.backup_model
        else:
            model = self.config.model

        response = QWenService.dashscope.Generation.call(
            model=model,
            messages=messages,
            result_format="message",  # set the result to be "message" format.
            max_tokens=max_tokens,
            top_p=top_p,
            temperature=temperature,
            stop=stop,
            stream=True,
            incremental_output=True,
        )

        for msg_chunk in response:
            if msg_chunk.status_code != HTTPStatus.OK:
                # Report the failing chunk itself: `response` is the stream
                # iterator and carries no status information.
                raise Exception(
                    f"QWen API call failed with status code {msg_chunk.status_code} "
                    f"and error message {self._error_detail(msg_chunk)}",
                )
            yield msg_chunk.output.choices[0]["message"]

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """Embed ``strings`` with the configured embedding model.

        Args:
            strings: Texts to embed.

        Returns:
            The ``embedding`` entry of each item in the response's
            ``output.embeddings`` list; an empty list when ``strings`` is
            empty (no API call is made in that case).

        Raises:
            Exception: If the API reports a status other than HTTP 200.
        """
        if not strings:
            return []

        resp = QWenService.dashscope.TextEmbedding.call(
            model=self.config.embedding_model,
            input=strings,
        )
        if resp.status_code != HTTPStatus.OK:
            raise Exception(
                f"QWen API call failed with status code {resp.status_code} "
                f"and error message {self._error_detail(resp)}",
            )
        return [emb["embedding"] for emb in resp["output"]["embeddings"]]
|