TRaw's picture
Upload 297 files
3d3d712
from http import HTTPStatus
from typing import Any, Generator, List, Optional
from injector import inject
from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType
class QWenServiceConfig(LLMServiceConfig):
    """Configuration section for the QWen LLM service.

    Every setting is looked up under the "qwen" config name and falls back
    to the corresponding value of the shared LLM module configuration.
    """

    def _configure(self) -> None:
        """Populate ``api_key``, ``model``, ``backup_model`` and ``embedding_model``."""
        self._set_name("qwen")
        module_config = self.llm_module_config

        # The API key has no local default: the shared value is used as-is.
        self.api_key = self._get_str("api_key", module_config.api_key)

        # Chat model: shared model if one is set, otherwise "qwen-max-1201".
        model_fallback = module_config.model
        if model_fallback is None:
            model_fallback = "qwen-max-1201"
        self.model = self._get_str("model", model_fallback)

        # Backup model: shared backup model, otherwise the chat model above.
        backup_fallback = module_config.backup_model
        if backup_fallback is None:
            backup_fallback = self.model
        self.backup_model = self._get_str("backup_model", backup_fallback)

        # Embedding model: shared embedding model, otherwise the chat model.
        embedding_fallback = module_config.embedding_model
        if embedding_fallback is None:
            embedding_fallback = self.model
        self.embedding_model = self._get_str("embedding_model", embedding_fallback)
class QWenService(CompletionService, EmbeddingService):
    """Chat-completion and embedding service backed by the ``dashscope`` package."""

    # The dashscope module, imported on first instantiation and then shared
    # by every instance through this class attribute.
    dashscope = None

    @inject
    def __init__(self, config: QWenServiceConfig):
        """Store the config, import dashscope if needed and set its API key.

        Args:
            config: QWen service configuration (API key and model names).

        Raises:
            Exception: If the ``dashscope`` package cannot be imported.
        """
        self.config = config

        if QWenService.dashscope is None:
            try:
                import dashscope
            except Exception as exc:
                # Chain the original failure so the real cause stays visible.
                raise Exception(
                    "Package dashscope is required for using QWen API. ",
                ) from exc
            QWenService.dashscope = dashscope
        QWenService.dashscope.api_key = self.config.api_key

    @staticmethod
    def _error_detail(resp: Any) -> Any:
        """Return the error text carried by a failed dashscope response.

        Reads ``message`` first and falls back to ``error``; returns ``None``
        when neither is available.
        """
        for field in ("message", "error"):
            try:
                detail = getattr(resp, field)
            except (AttributeError, KeyError):
                # NOTE(review): dashscope responses appear to be dict-like, so
                # a missing attribute may surface as KeyError - confirm.
                continue
            if detail is not None:
                return detail
        return None

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Stream a chat completion from the QWen API.

        Args:
            messages: Conversation messages sent to the model.
            use_backup_engine: When true, call ``config.backup_model`` instead
                of ``config.model``.
            stream: Accepted for interface compatibility; the request is
                always issued with ``stream=True``.
            temperature: Forwarded to the API unchanged.
            max_tokens: Forwarded to the API unchanged.
            top_p: Forwarded to the API unchanged.
            stop: Forwarded to the API unchanged.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Yields:
            The ``message`` entry of the first choice of every response chunk.

        Raises:
            Exception: If a chunk reports a status code other than 200 OK.
        """
        model = self.config.backup_model if use_backup_engine else self.config.model

        response = QWenService.dashscope.Generation.call(
            model=model,
            messages=messages,
            result_format="message",  # set the result to be "message" format.
            max_tokens=max_tokens,
            top_p=top_p,
            temperature=temperature,
            stop=stop,
            stream=True,
            # Presumably makes each chunk carry only newly generated text
            # rather than the accumulated output - confirm against the SDK.
            incremental_output=True,
        )

        for msg_chunk in response:
            if msg_chunk.status_code != HTTPStatus.OK:
                # Report the failing chunk itself: `response` is the stream
                # iterator and carries no status or error information.
                raise Exception(
                    f"QWen API call failed with status code {msg_chunk.status_code} "
                    f"and error message {self._error_detail(msg_chunk)}",
                )
            yield msg_chunk.output.choices[0]["message"]

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """Embed ``strings`` with the configured embedding model.

        Args:
            strings: Texts to embed. An empty list yields an empty result
                without calling the API.

        Returns:
            One embedding vector per entry of the API response, in response
            order.

        Raises:
            Exception: If the API reports a status code other than 200 OK.
        """
        if not strings:
            return []

        resp = QWenService.dashscope.TextEmbedding.call(
            model=self.config.embedding_model,
            input=strings,
        )
        if resp.status_code != HTTPStatus.OK:
            raise Exception(
                f"QWen API call failed with status code {resp.status_code} "
                f"and error message {self._error_detail(resp)}",
            )
        return [emb["embedding"] for emb in resp["output"]["embeddings"]]