|
import abc |
|
from typing import Any, Generator, List, Optional |
|
|
|
from injector import inject |
|
|
|
from taskweaver.config.config_mgt import AppConfigSource |
|
from taskweaver.config.module_config import ModuleConfig |
|
from taskweaver.llm.util import ChatMessageType |
|
|
|
|
|
class LLMModuleConfig(ModuleConfig):
    """Settings read from the ``llm`` configuration section."""

    def _configure(self) -> None:
        """Register the module name and load every LLM-related setting.

        Values are fetched through the ``_get_str`` / ``_get_enum`` /
        ``_get_bool`` helpers inherited from ``ModuleConfig``. The second
        positional argument of each helper is presumably the fallback used
        when the key is not configured (confirm against ``ModuleConfig``).
        """
        self._set_name("llm")

        # Which backend serves chat completion and which serves embeddings.
        self.api_type = self._get_str("api_type", "openai")
        self.embedding_api_type = self._get_str(
            "embedding_api_type",
            "sentence_transformer",
        )

        # Endpoint and credential are optional (required=False, fallback None).
        self.api_base: Optional[str] = self._get_str(
            "api_base",
            None,
            required=False,
        )
        self.api_key: Optional[str] = self._get_str("api_key", None, required=False)

        # Model identifiers; all optional with a None fallback.
        self.model: Optional[str] = self._get_str(
            "model",
            None,
            required=False,
        )
        self.backup_model: Optional[str] = self._get_str("backup_model", None, required=False)
        self.embedding_model: Optional[str] = self._get_str("embedding_model", None, required=False)

        # Response format is restricted to the listed options.
        self.response_format: Optional[str] = self._get_enum(
            "response_format",
            options=["json_object", "text"],
            default="json_object",
        )

        # Flag read from "use_mock"; False unless configured otherwise.
        self.use_mock: bool = self._get_bool(
            "use_mock",
            False,
        )
|
|
|
|
|
class LLMServiceConfig(ModuleConfig):
    """Base configuration for one LLM service, namespaced under ``llm.``.

    The shared ``LLMModuleConfig`` is kept on the instance as
    ``llm_module_config``.
    """

    @inject
    def __init__(
        self,
        src: AppConfigSource,
        llm_module_config: LLMModuleConfig,
    ) -> None:
        """Store the shared LLM config, then run the base initializer.

        :param src: application configuration source forwarded to
            ``ModuleConfig.__init__``
        :param llm_module_config: shared LLM module configuration
        """
        # NOTE(review): assigned before super().__init__ — presumably the
        # base initializer triggers configuration that may read this
        # attribute; keep this order.
        self.llm_module_config = llm_module_config
        super().__init__(src)

    def _set_name(self, name: str) -> None:
        """Set ``self.name`` to *name* prefixed with ``llm.``.

        :param name: service-specific part of the section name
        """
        prefixed_name = f"llm.{name}"
        self.name = prefixed_name
|
|
|
|
|
class CompletionService(abc.ABC):
    """Abstract interface that every chat completion backend implements."""

    @abc.abstractmethod
    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """
        Chat completion API

        Concrete services must override this method; the base version
        only raises ``NotImplementedError``.

        :param messages: list of messages
        :param use_backup_engine: whether to use the backup engine
        :param stream: whether to stream the response
        :param temperature: temperature
        :param max_tokens: maximum number of tokens
        :param top_p: top p
        :param stop: optional list of stop strings (presumably sequences
            at which generation ends; confirm with the implementations)
        :param kwargs: other model specific keyword arguments
        :return: generator of messages
        """
        raise NotImplementedError
|
|
|
|
|
class EmbeddingService(abc.ABC):
    """Abstract interface that every embedding backend implements."""

    @abc.abstractmethod
    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """
        Embedding API

        Concrete services must override this method; the base version
        only raises ``NotImplementedError``.

        :param strings: list of strings to be embedded
        :return: list of embeddings
        """
        raise NotImplementedError
|
|