arena / model.py
Kang Suhyun
[#104] Display error message for the context window exceeded error (#105)
9e789e7 unverified
raw
history blame
3.43 kB
"""
This module contains functions to interact with the models.
"""
import json
import os
from typing import List
from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm
from credentials import get_credentials_json
# Project and secret identifiers are read from the environment; each is None
# when its variable is unset.
GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = os.environ.get("MODELS_SECRET")

# Build the Secret Manager client with service-account credentials created
# from whatever get_credentials_json() returns.
_service_account_credentials = (
    service_account.Credentials.from_service_account_info(
        get_credentials_json()))
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=_service_account_credentials)

# Read the "latest" version of the models secret and parse its UTF-8 payload
# as JSON.
_models_secret_path = secretmanager_client.secret_version_path(
    GOOGLE_CLOUD_PROJECT, MODELS_SECRET, "latest")
models_secret = secretmanager_client.access_secret_version(
    name=_models_secret_path)
decoded_secret = models_secret.payload.data.decode("UTF-8")
supported_models_json = json.loads(decoded_secret)
# Fallback instructions for models whose configuration does not supply its
# own. The translate template carries {source_lang} and {target_lang}
# placeholders.
DEFAULT_SUMMARIZE_INSTRUCTION = (
    "Summarize the following text, "
    "maintaining the language of the text.")
DEFAULT_TRANSLATE_INSTRUCTION = (
    "Translate the following text "
    "from {source_lang} to {target_lang}.")
class ContextWindowExceededError(Exception):
    """Raised when litellm reports that the context window was exceeded."""
class Model:
    """Settings for one model plus a helper to request completions from it."""

    def __init__(
            self,
            name: str,
            provider: str = None,
            # The keyword names below mirror the camelCase keys of the JSON
            # configuration so that a config dict can be unpacked straight
            # into this constructor.
            apiKey: str = None,  # pylint: disable=invalid-name
            apiBase: str = None,  # pylint: disable=invalid-name
            summarizeInstruction: str = None,  # pylint: disable=invalid-name
            translateInstruction: str = None):  # pylint: disable=invalid-name
        """Store the settings; falsy instructions fall back to the defaults.

        Args:
            name: Model name.
            provider: Optional provider prefix used when calling litellm.
            apiKey: Optional API key forwarded to litellm.
            apiBase: Optional API base forwarded to litellm.
            summarizeInstruction: Optional summarize instruction override.
            translateInstruction: Optional translate instruction override.
        """
        self.name = name
        self.provider = provider
        self.api_key = apiKey
        self.api_base = apiBase
        self.summarize_instruction = (summarizeInstruction
                                      if summarizeInstruction else
                                      DEFAULT_SUMMARIZE_INSTRUCTION)
        self.translate_instruction = (translateInstruction
                                      if translateInstruction else
                                      DEFAULT_TRANSLATE_INSTRUCTION)

    def completion(self, messages: List, max_tokens: float = None) -> str:
        """Request a completion and return the content of the first choice.

        Args:
            messages: Chat messages forwarded to litellm unchanged.
            max_tokens: Optional token limit forwarded to litellm.

        Returns:
            The message content of the first choice in the response.

        Raises:
            ContextWindowExceededError: If litellm raises its own
                ContextWindowExceededError; that error is chained as the
                cause.
        """
        # The model is addressed as "<provider>/<name>" when a provider is
        # set, and by its bare name otherwise.
        if self.provider:
            model_id = "/".join((self.provider, self.name))
        else:
            model_id = self.name

        try:
            response = litellm.completion(model=model_id,
                                          api_key=self.api_key,
                                          api_base=self.api_base,
                                          messages=messages,
                                          max_tokens=max_tokens)
        except litellm.ContextWindowExceededError as e:
            raise ContextWindowExceededError() from e
        return response.choices[0].message.content
def _models_from_config(configs) -> List[Model]:
    """Build one Model per (name, settings) entry of configs, in order."""
    models = []
    for name, settings in configs.items():
        models.append(Model(name=name, **settings))
    return models


# Each entry of the models secret becomes a Model: the key is the model name
# and the value supplies the remaining constructor keyword arguments.
supported_models: List[Model] = _models_from_config(supported_models_json)
def check_models(models: List[Model]):
    """Send a short test request to each model, stopping at the first failure.

    Progress is reported with print().

    Args:
        models: The models to check, in order.

    Raises:
        RuntimeError: If checking a model raises any exception; the original
            exception is chained as the cause.
    """
    for candidate in models:
        print(f"Checking model {candidate.name}...")
        try:
            # A fresh minimal conversation is built for every model.
            probe = [
                {
                    "role": "system",
                    "content": "You are a kind person."
                },
                {
                    "role": "user",
                    "content": "Hello."
                },
            ]
            candidate.completion(messages=probe, max_tokens=5)
            print(f"Model {candidate.name} is available.")
        # The point of this check is to confirm that the models work without
        # any issue at all, so every exception type is caught here.
        except Exception as error:  # pylint: disable=broad-except
            failure = f"Model {candidate.name} is not available: {error}"
            raise RuntimeError(failure) from error