tskwvr / taskweaver /llm /google_genai.py
TRaw's picture
Upload 297 files
3d3d712
from typing import Any, Generator, List, Optional
from injector import inject
from taskweaver.llm.base import CompletionService, EmbeddingService, LLMServiceConfig
from taskweaver.llm.util import ChatMessageType, format_chat_message
class GoogleGenAIServiceConfig(LLMServiceConfig):
    """Settings for the ``google_genai`` LLM service.

    String and enum settings fall back to the value on the shared
    ``llm_module_config`` when one is set there, and otherwise to a
    default chosen in ``_configure``.
    """

    def _configure(self) -> None:
        """Register the service name and resolve every setting."""
        self._set_name("google_genai")

        module_cfg = self.llm_module_config

        # API key: shared value if present, otherwise an empty string.
        fallback_api_key = module_cfg.api_key
        if fallback_api_key is None:
            fallback_api_key = ""
        self.api_key = self._get_str("api_key", fallback_api_key)

        # Primary model: shared value if present, otherwise "gemini-pro".
        fallback_model = module_cfg.model
        if fallback_model is None:
            fallback_model = "gemini-pro"
        self.model = self._get_str("model", fallback_model)

        # Backup model: defaults to the primary model resolved above.
        fallback_backup_model = module_cfg.backup_model
        if fallback_backup_model is None:
            fallback_backup_model = self.model
        self.backup_model = self._get_str("backup_model", fallback_backup_model)

        # Embedding model: also defaults to the primary model.
        fallback_embedding_model = module_cfg.embedding_model
        if fallback_embedding_model is None:
            fallback_embedding_model = self.model
        self.embedding_model = self._get_str(
            "embedding_model",
            fallback_embedding_model,
        )

        # Response format: restricted to the two listed options.
        fallback_response_format = module_cfg.response_format
        if fallback_response_format is None:
            fallback_response_format = "text"
        self.response_format = self._get_enum(
            "response_format",
            options=["json_object", "text"],
            default=fallback_response_format,
        )

        # Generation parameters; these have no shared-config fallback.
        self.temperature = self._get_float("temperature", 0.9)
        self.max_output_tokens = self._get_int("max_output_tokens", 1000)
        self.top_k = self._get_int("top_k", 1)
        self.top_p = self._get_float("top_p", 0)
class GoogleGenAIService(CompletionService, EmbeddingService):
    """Chat-completion and embedding service backed by ``google.generativeai``.

    The generative model is created once in ``__init__`` from the values in
    the injected ``GoogleGenAIServiceConfig``; the same model instance is
    reused for every completion request.
    """

    @inject
    def __init__(self, config: GoogleGenAIServiceConfig):
        """Configure the SDK with the API key and build the generative model.

        Args:
            config: Resolved service configuration (API key, model name and
                generation parameters).
        """
        self.config = config
        genai = self.import_genai_module()
        genai.configure(api_key=self.config.api_key)

        # Every harm category uses the same blocking threshold.
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_MEDIUM_AND_ABOVE",
            },
        ]

        self.model = genai.GenerativeModel(
            model_name=self.config.model,
            generation_config={
                "temperature": self.config.temperature,
                "top_p": self.config.top_p,
                "top_k": self.config.top_k,
                "max_output_tokens": self.config.max_output_tokens,
            },
            safety_settings=safety_settings,
        )

    def import_genai_module(self):
        """Import and return the ``google.generativeai`` module.

        The import is done lazily so the package is only required when this
        service is actually used.

        Raises:
            Exception: If the package cannot be imported. The original error
                is attached as ``__cause__`` to keep the root cause visible.
        """
        try:
            import google.generativeai as genai
        except Exception as exc:
            raise Exception(
                "Package google-generativeai is required for using Google Gemini API. "
                "Please install it manually by running: `pip install google-generativeai`",
            ) from exc
        return genai

    def chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Return a generator of assistant messages for ``messages``.

        Args:
            messages: Conversation history with ``role`` and ``content`` keys.
            use_backup_engine: Forwarded unchanged; not used by this service.
            stream: When ``False`` a single message with the full response is
                produced; otherwise one message per streamed chunk.
            temperature: Forwarded unchanged; not used by this service.
            max_tokens: Forwarded unchanged; not used by this service.
            top_p: Forwarded unchanged; not used by this service.
            stop: Forwarded unchanged; not used by this service.
            **kwargs: Forwarded unchanged; not used by this service.

        Returns:
            A generator yielding messages built by ``format_chat_message``
            with the ``"assistant"`` role.
        """
        # No try/except fallback here: ``_chat_completion`` is a generator
        # function, so calling it cannot raise. Any request or conversion
        # error surfaces only when the caller iterates the result.
        return self._chat_completion(
            messages=messages,
            use_backup_engine=use_backup_engine,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop,
            **kwargs,
        )

    @staticmethod
    def _to_genai_messages(messages: List[ChatMessageType]) -> List[Any]:
        """Convert chat messages into the ``role``/``parts`` dicts sent to the model.

        Raises:
            Exception: If a message has a role other than ``system``,
                ``user`` or ``assistant``.
        """
        genai_messages: List[Any] = []
        prev_role = ""
        for msg in messages:
            role = msg["role"]
            if role == "system":
                # A system prompt is sent as a user turn followed by a canned
                # model acknowledgement.
                genai_messages.append({"role": "user", "parts": [msg["content"]]})
                genai_messages.append(
                    {
                        "role": "model",
                        "parts": ["I understand your requirements, and I will assist you in the conversations."],
                    },
                )
                prev_role = "model"
            elif role == "user":
                if prev_role == "user":
                    # a placeholder to create alternating user and model messages
                    genai_messages.append({"role": "model", "parts": [" "]})
                genai_messages.append({"role": "user", "parts": [msg["content"]]})
                prev_role = "user"
            elif role == "assistant":
                genai_messages.append({"role": "model", "parts": [msg["content"]]})
                prev_role = "model"
            else:
                raise Exception(f"Invalid role: {msg['role']}")
        return genai_messages

    def _chat_completion(
        self,
        messages: List[ChatMessageType],
        use_backup_engine: bool = False,
        stream: bool = True,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Generator[ChatMessageType, None, None]:
        """Send the conversation to the model and yield assistant messages.

        Only ``messages`` and ``stream`` influence the request. The other
        arguments are accepted for signature compatibility and are not
        applied; generation parameters are the ones fixed in ``__init__``.

        Yields:
            One assistant message with the whole response text when
            ``stream`` is ``False``, otherwise one assistant message per
            streamed chunk.

        Raises:
            Exception: If a message has an unsupported role.
        """
        genai_messages = self._to_genai_messages(messages)

        if stream is False:
            response = self.model.generate_content(genai_messages, stream=False)
            yield format_chat_message("assistant", response.text)
            # Stop here: without this return the request would be issued a
            # second time in streaming mode and the answer yielded twice.
            return

        response = self.model.generate_content(genai_messages, stream=True)
        for chunk_obj in response:
            yield format_chat_message("assistant", chunk_obj.text)

    def get_embeddings(self, strings: List[str]) -> List[List[float]]:
        """Return the embeddings for ``strings`` from the configured model.

        Args:
            strings: Texts to embed.

        Returns:
            The ``"embedding"`` entry of the SDK result, or an empty list
            when ``strings`` is empty (no request is made in that case).
        """
        if not strings:
            return []

        genai = self.import_genai_module()
        embedding_results = genai.embed_content(
            model=self.config.embedding_model,
            content=strings,
            task_type="semantic_similarity",
        )
        return embedding_results["embedding"]