|
from http import HTTPStatus |
|
from typing import Any, Generator, List, Optional |
|
|
|
from injector import inject |
|
|
|
from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig |
|
from taskweaver.llm.util import ChatMessageType |
|
|
|
|
|
class QWenServiceConfig(LLMServiceConfig):
    """Configuration for the QWen LLM service.

    Each setting is read with ``_get_str``; the second argument passed to it is
    derived from the shared LLM module configuration (``llm_module_config``),
    which is presumably used as the default when no QWen-specific value is set.
    """

    def _configure(self) -> None:
        """Populate ``api_key``, ``model``, ``backup_model`` and ``embedding_model``."""
        self._set_name("qwen")

        # API key: the shared key is passed through unchanged (it may be None).
        key_fallback = self.llm_module_config.api_key
        self.api_key = self._get_str("api_key", key_fallback)

        # Chat model: fall back to a fixed model name when no shared model is set.
        model_fallback = self.llm_module_config.model
        if model_fallback is None:
            model_fallback = "qwen-max-1201"
        self.model = self._get_str("model", model_fallback)

        # Backup model: fall back to the chat model resolved above.
        backup_fallback = self.llm_module_config.backup_model
        if backup_fallback is None:
            backup_fallback = self.model
        self.backup_model = self._get_str("backup_model", backup_fallback)

        # Embedding model: also falls back to the chat model resolved above.
        embedding_fallback = self.llm_module_config.embedding_model
        if embedding_fallback is None:
            embedding_fallback = self.model
        self.embedding_model = self._get_str("embedding_model", embedding_fallback)
|
|
|
|
|
class QWenService(CompletionService, EmbeddingService):
    """Chat-completion and embedding service backed by the ``dashscope`` package.

    The ``dashscope`` module is imported lazily the first time an instance is
    constructed and is cached on the class, so all instances share it.
    """

    # Lazily imported ``dashscope`` module; stays ``None`` until the first
    # instance is constructed.
    dashscope: Any = None

    @inject
    def __init__(self, config: QWenServiceConfig):
        """Store ``config``, import ``dashscope`` if needed and set its API key.

        Raises:
            Exception: if the ``dashscope`` package cannot be imported. The
                original import error is attached as ``__cause__``.
        """
        self.config = config

        if QWenService.dashscope is None:
            try:
                import dashscope
            except Exception as exc:
                # Chain the cause so the real import failure stays visible.
                raise Exception(
                    "Package dashscope is required for using QWen API. ",
                ) from exc
            QWenService.dashscope = dashscope

        # The key is assigned on the shared module object on every construction,
        # so the most recently constructed instance determines the key in use.
        QWenService.dashscope.api_key = self.config.api_key

    @staticmethod
    def _read_field(response: Any, *names: str) -> Any:
        """Return the first non-``None`` field among ``names``, or ``None``.

        Lookups are guarded because this runs on an error path: a missing
        field must not mask the API failure being reported.
        """
        for name in names:
            try:
                value = getattr(response, name)
            except (AttributeError, KeyError):
                continue
            if value is not None:
                return value
        return None

    @classmethod
    def _api_error(cls, response: Any) -> Exception:
        """Build the exception describing a failed DashScope response."""
        status = cls._read_field(response, "status_code")
        # NOTE(review): DashScope's documented error fields are ``code`` and
        # ``message``; ``error`` (the attribute read previously) is kept as a
        # fallback -- confirm against the installed SDK version.
        detail = cls._read_field(response, "message", "error")
        return Exception(
            f"QWen API call failed with status code {status} and error message {detail}",
        )

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Stream a chat completion, yielding one message per received chunk.

        Args:
            messages: Conversation history sent to the model.
            use_backup_engine: When true, use ``config.backup_model`` instead
                of ``config.model``.
            stream: Accepted for interface compatibility; the request is
                always issued with streaming and incremental output enabled.
            temperature: Forwarded to the API unchanged.
            max_tokens: Forwarded to the API unchanged.
            top_p: Forwarded to the API unchanged.
            stop: Forwarded to the API unchanged.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            The ``"message"`` entry of the first choice of every chunk whose
            status code is ``HTTPStatus.OK``.

        Raises:
            Exception: on the first chunk whose status code is not OK.
        """
        model = self.config.backup_model if use_backup_engine else self.config.model

        response = QWenService.dashscope.Generation.call(
            model=model,
            messages=messages,
            result_format="message",
            max_tokens=max_tokens,
            top_p=top_p,
            temperature=temperature,
            stop=stop,
            stream=True,
            incremental_output=True,
        )

        for msg_chunk in response:
            if msg_chunk.status_code != HTTPStatus.OK:
                # Report the failing chunk: ``response`` is the chunk iterator
                # and carries no status information of its own.
                raise self._api_error(msg_chunk)
            yield msg_chunk.output.choices[0]["message"]

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """Return one embedding vector per input string.

        Args:
            strings: Texts to embed with ``config.embedding_model``.

        Returns:
            The ``"embedding"`` entry of each item in the response, in the
            order the API returned them. An empty input yields an empty list
            without calling the API.

        Raises:
            Exception: if the response status code is not ``HTTPStatus.OK``.
        """
        if not strings:
            # Nothing to embed, so skip the remote call entirely.
            return []

        resp = QWenService.dashscope.TextEmbedding.call(
            model=self.config.embedding_model,
            input=strings,
        )
        if resp.status_code != HTTPStatus.OK:
            raise self._api_error(resp)

        return [item["embedding"] for item in resp["output"]["embeddings"]]
|
|