TRaw's picture
Upload 297 files
3d3d712
import abc
from typing import Any, Generator, List, Optional
from injector import inject
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.config.module_config import ModuleConfig
from taskweaver.llm.util import ChatMessageType
class LLMModuleConfig(ModuleConfig):
    """Top-level LLM settings, registered under the config name ``llm``."""

    def _configure(self) -> None:
        """
        Populate the LLM settings on this instance
        Values are fetched with the ``_get_str`` / ``_get_enum`` / ``_get_bool``
        helpers, which are not defined in this class (presumably inherited from
        ``ModuleConfig``).
        """
        self._set_name("llm")

        def optional_str(key: str) -> Optional[str]:
            # Shared shape of every optional setting: default None, not required.
            return self._get_str(key, None, required=False)

        # API type selectors for chat completion and for embeddings.
        self.api_type = self._get_str("api_type", "openai")
        self.embedding_api_type = self._get_str(
            "embedding_api_type",
            "sentence_transformer",
        )

        # Optional endpoint, credential and model names.
        self.api_base: Optional[str] = optional_str("api_base")
        self.api_key: Optional[str] = optional_str("api_key")
        self.model: Optional[str] = optional_str("model")
        self.backup_model: Optional[str] = optional_str("backup_model")
        self.embedding_model: Optional[str] = optional_str("embedding_model")

        # Response format is limited to the listed options.
        allowed_formats = ["json_object", "text"]
        self.response_format: Optional[str] = self._get_enum(
            "response_format",
            options=allowed_formats,
            default="json_object",
        )

        self.use_mock: bool = self._get_bool("use_mock", False)
class LLMServiceConfig(ModuleConfig):
    """Base config for an individual LLM service; its name is prefixed with ``llm.``."""

    @inject
    def __init__(self, src: AppConfigSource, llm_module_config: LLMModuleConfig) -> None:
        """
        Keep a reference to the shared LLM config, then run the parent initialiser
        :param src: config source forwarded to ``ModuleConfig.__init__``
        :param llm_module_config: the shared LLM module config
        """
        # NOTE(review): assigned before super().__init__ -- presumably so that
        # configuration triggered by the parent initialiser can already read
        # it; confirm against ModuleConfig.
        self.llm_module_config = llm_module_config
        super().__init__(src)

    def _set_name(self, name: str) -> None:
        """
        Store the config name with the ``llm.`` prefix
        :param name: service-specific name to place after the prefix
        """
        prefixed_name = f"llm.{name}"
        self.name = prefixed_name
class CompletionService(abc.ABC):
    """Abstract interface for chat-completion backends."""

    @abc.abstractmethod
    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """
        Chat completion API; concrete services must override this method
        :param messages: list of messages
        :param use_backup_engine: whether to use the backup engine
        :param stream: whether to stream the response
        :param temperature: temperature
        :param max_tokens: maximum number of tokens
        :param top_p: top p
        :param stop: optional list of stop strings (presumably sequences that
            end generation; confirm against the concrete implementations)
        :param kwargs: other model specific keyword arguments
        :return: generator of messages
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()
class EmbeddingService(abc.ABC):
    """Abstract interface for embedding backends."""

    @abc.abstractmethod
    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """
        Embedding API; concrete services must override this method
        :param strings: list of strings to be embedded
        :return: list of embeddings
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()