Spaces:
Running
Running
Kang Suhyun
[#104] Display error message for the context window exceeded error (#105)
9e789e7
unverified
"""
This module contains functions to interact with the models. The supported
model configurations are loaded from Google Cloud Secret Manager at import time.
"""
import json
import os
from typing import List, Optional

from google.cloud import secretmanager
from google.oauth2 import service_account
import litellm

from credentials import get_credentials_json
def _require_env(var_name: str) -> str:
  """Returns the value of the environment variable `var_name`.

  Raises:
    RuntimeError: If the variable is unset or empty. Without this check the
      value would be None and the Secret Manager lookup below would be issued
      with a bogus resource path, failing with a far less obvious error.
  """
  value = os.environ.get(var_name)
  if not value:
    raise RuntimeError(f"Environment variable {var_name} is not set.")
  return value


GOOGLE_CLOUD_PROJECT = _require_env("GOOGLE_CLOUD_PROJECT")
MODELS_SECRET = _require_env("MODELS_SECRET")

# Fetch the latest version of the models secret, authenticating with the
# service-account info returned by get_credentials_json().
secretmanager_client = secretmanager.SecretManagerServiceClient(
    credentials=service_account.Credentials.from_service_account_info(
        get_credentials_json()))
models_secret = secretmanager_client.access_secret_version(
    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
                                                  MODELS_SECRET, "latest"))
decoded_secret = models_secret.payload.data.decode("UTF-8")
# Presumably a JSON object mapping model name -> Model keyword arguments;
# supported_models unpacks it that way. Confirm against the secret's contents.
supported_models_json = json.loads(decoded_secret)
# Fallback instructions, used by Model when a model's configuration leaves
# the corresponding instruction empty or missing.
DEFAULT_SUMMARIZE_INSTRUCTION = (
    "Summarize the following text, maintaining the language of the text.")
# `{source_lang}` and `{target_lang}` are str.format placeholders; presumably
# filled in by callers elsewhere in the project — confirm.
DEFAULT_TRANSLATE_INSTRUCTION = (
    "Translate the following text from {source_lang} to {target_lang}.")
class ContextWindowExceededError(Exception):
  """Raised when a request does not fit in the model's context window.

  Model.completion raises this in place of
  litellm.ContextWindowExceededError, chaining the original as the cause.
  """
class Model:
  """A chat model that is called through litellm.

  Instances are built from one entry of the models configuration: the entry
  key is the model name and the entry value supplies the keyword arguments.
  """

  def __init__(
      self,
      name: str,
      provider: Optional[str] = None,
      # The JSON keys are in camelCase. To unpack these keys into
      # Model attributes, we need to use the same camelCase names.
      apiKey: Optional[str] = None,  # pylint: disable=invalid-name
      apiBase: Optional[str] = None,  # pylint: disable=invalid-name
      summarizeInstruction: Optional[str] = None,  # pylint: disable=invalid-name
      translateInstruction: Optional[str] = None):  # pylint: disable=invalid-name
    """Stores the model's name and connection settings.

    Args:
      name: Model name. When `provider` is set, litellm is called with
        "<provider>/<name>".
      provider: Optional litellm provider prefix.
      apiKey: Forwarded to litellm as `api_key`.
      apiBase: Forwarded to litellm as `api_base`.
      summarizeInstruction: Stored as `summarize_instruction`; falls back to
        DEFAULT_SUMMARIZE_INSTRUCTION when empty or missing.
      translateInstruction: Stored as `translate_instruction`; falls back to
        DEFAULT_TRANSLATE_INSTRUCTION when empty or missing.
    """
    self.name = name
    self.provider = provider
    self.api_key = apiKey
    self.api_base = apiBase
    self.summarize_instruction = (summarizeInstruction or
                                  DEFAULT_SUMMARIZE_INSTRUCTION)
    self.translate_instruction = (translateInstruction or
                                  DEFAULT_TRANSLATE_INSTRUCTION)

  def completion(self,
                 messages: List,
                 max_tokens: Optional[int] = None) -> str:
    """Sends `messages` to the model and returns the first choice's content.

    Args:
      messages: Chat messages in the form accepted by `litellm.completion`,
        e.g. [{"role": "user", "content": "Hello."}].
      max_tokens: Optional limit on generated tokens, forwarded to litellm.

    Returns:
      The message content of the first choice in the response.

    Raises:
      ContextWindowExceededError: If litellm reports that the request exceeds
        the model's context window.
    """
    model = f"{self.provider}/{self.name}" if self.provider else self.name
    # Only the network call can raise the litellm error we translate.
    try:
      response = litellm.completion(model=model,
                                    api_key=self.api_key,
                                    api_base=self.api_base,
                                    messages=messages,
                                    max_tokens=max_tokens)
    except litellm.ContextWindowExceededError as e:
      raise ContextWindowExceededError() from e
    return response.choices[0].message.content
# One Model per configured entry: the key is the model name and the value
# holds the remaining keyword arguments (provider, apiKey, apiBase, ...).
supported_models: List[Model] = [
    Model(name=name, **config) for name, config in supported_models_json.items()
]
def check_models(models: List["Model"]) -> None:
  """Verifies that every model in `models` answers a tiny chat request.

  Each model is sent a short system + user conversation limited to 5 tokens.
  Checking stops at the first model that fails.

  Args:
    models: Models to check. Each needs a `name` attribute and a
      `completion(messages=..., max_tokens=...)` method.

  Raises:
    RuntimeError: If a model raises while answering. The original exception
      is chained as the cause.
  """
  for model in models:
    print(f"Checking model {model.name}...")
    try:
      # A fresh message list per model, in case completion mutates it.
      model.completion(messages=[{
          "role": "system",
          "content": "You are a kind person."
      }, {
          "role": "user",
          "content": "Hello."
      }],
                       max_tokens=5)
    # This check is designed to verify the availability of the models
    # without any issues. Therefore, we need to catch all exceptions.
    except Exception as e:  # pylint: disable=broad-except
      raise RuntimeError(f"Model {model.name} is not available: {e}") from e
    else:
      print(f"Model {model.name} is available.")