from typing import Any, Generator, List, Optional, Type
from injector import Injector, inject
from taskweaver.llm.azure_ml import AzureMLService
from taskweaver.llm.base import CompletionService, EmbeddingService, LLMModuleConfig
from taskweaver.llm.google_genai import GoogleGenAIService
from taskweaver.llm.mock import MockApiService
from taskweaver.llm.ollama import OllamaService
from taskweaver.llm.openai import OpenAIService
from taskweaver.llm.placeholder import PlaceholderEmbeddingService
from taskweaver.llm.sentence_transformer import SentenceTransformerService
from taskweaver.llm.qwen import QWenService
from taskweaver.llm.util import ChatMessageType, format_chat_message


class LLMApi(object):
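    """Facade over the configured completion and embedding services."""
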
@inject
def __init__(self, config: LLMModuleConfig, injector: Injector) -> None:
self.config = config
self.injector = injector
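
        # Select the completion backend based on the configured API type.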
if self.config.api_type in ["openai", "azure", "azure_ad"]:
self._set_completion_service(OpenAIService)
elif self.config.api_type == "ollama":
self._set_completion_service(OllamaService)
elif self.config.api_type == "azure_ml":
self._set_completion_service(AzureMLService)
elif self.config.api_type == "google_genai":
self._set_completion_service(GoogleGenAIService)
elif self.config.api_type == "qwen":
self._set_completion_service(QWenService)
else:
raise ValueError(f"API type {self.config.api_type} is not supported")
if self.config.embedding_api_type in ["openai", "azure", "azure_ad"]:
self._set_embedding_service(OpenAIService)
elif self.config.embedding_api_type == "ollama":
self._set_embedding_service(OllamaService)
elif self.config.embedding_api_type == "google_genai":
self._set_embedding_service(GoogleGenAIService)
elif self.config.embedding_api_type == "sentence_transformer":
self._set_embedding_service(SentenceTransformerService)
elif self.config.embedding_api_type == "qwen":
self._set_embedding_service(QWenService)
elif self.config.embedding_api_type == "azure_ml":
self.embedding_service = PlaceholderEmbeddingService(
"Azure ML does not support embeddings yet. Please configure a different embedding API.",
)
elif self.config.embedding_api_type == "qwen":
self.embedding_service = PlaceholderEmbeddingService(
"QWen does not support embeddings yet. Please configure a different embedding API.",
)
else:
raise ValueError(
f"Embedding API type {self.config.embedding_api_type} is not supported",
)

        if self.config.use_mock:
# add mock proxy layer to the completion and embedding services
base_completion_service = self.completion_service
base_embedding_service = self.embedding_service
mock = self.injector.get(MockApiService)
mock.set_base_completion_service(base_completion_service)
mock.set_base_embedding_service(base_embedding_service)
self._set_completion_service(MockApiService)
self._set_embedding_service(MockApiService)

    def _set_completion_service(self, svc: Type[CompletionService]) -> None:
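        # Resolve the service once and bind its class to that instance, so
        # later injections of the same service type reuse this object.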
self.completion_service: CompletionService = self.injector.get(svc)
self.injector.binder.bind(svc, to=self.completion_service)

    def _set_embedding_service(self, svc: Type[EmbeddingService]) -> None:
self.embedding_service: EmbeddingService = self.injector.get(svc)
self.injector.binder.bind(svc, to=self.embedding_service)

    def chat_completion(
self,
messages: List[ChatMessageType],
use_backup_engine: bool = False,
stream: bool = True,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> ChatMessageType:
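        """Run a chat completion, aggregating streamed chunks into one message."""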
msg: ChatMessageType = format_chat_message("assistant", "")
for msg_chunk in self.completion_service.chat_completion(
messages,
use_backup_engine,
stream,
temperature,
max_tokens,
top_p,
stop,
**kwargs,
):
msg["role"] = msg_chunk["role"]
msg["content"] += msg_chunk["content"]
if "name" in msg_chunk:
msg["name"] = msg_chunk["name"]
return msg

    def chat_completion_stream(
self,
messages: List[ChatMessageType],
use_backup_engine: bool = False,
stream: bool = True,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Generator[ChatMessageType, None, None]:
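        """Return the chunk generator from the underlying completion service."""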
return self.completion_service.chat_completion(
messages,
use_backup_engine,
stream,
temperature,
max_tokens,
top_p,
stop,
**kwargs,
)

    def get_embedding(self, string: str) -> List[float]:
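        """Embed a single string with the configured embedding service."""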
return self.embedding_service.get_embeddings([string])[0]

    def get_embedding_list(self, strings: List[str]) -> List[List[float]]:
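        """Embed a batch of strings with the configured embedding service."""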
return self.embedding_service.get_embeddings(strings)
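

# Example usage (a minimal sketch, kept as comments; `app_injector` is a
# hypothetical Injector that already binds LLMModuleConfig, e.g. the one
# built by the TaskWeaver app bootstrap):
#
#   llm_api = app_injector.get(LLMApi)
#   reply = llm_api.chat_completion(
#       [format_chat_message("user", "Hello!")],
#   )
#   print(reply["content"])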