|
|
|
|
|
from app_configs import AVAILABLE_MODELS, UNSELECTED_MODEL_NAME |
|
|
|
|
|
def guess_model_provider(model_name: str) -> str:
    """Guess the provider of a model from its name.

    Matching is case-insensitive. Rules are checked in this order and the
    first match wins:

    * names starting with ``gpt-`` map to ``"OpenAI"``;
    * names containing ``sonnet``, ``claude`` or ``haiku`` map to
      ``"Anthropic"``;
    * names containing ``command`` map to ``"Cohere"``.

    Args:
        model_name: Model name to classify.

    Returns:
        The provider name: ``"OpenAI"``, ``"Anthropic"`` or ``"Cohere"``.

    Raises:
        ValueError: If no rule matches the name. The message quotes the
            name exactly as it was passed in.
    """
    # Match against a lowercased copy so the error below can still quote the
    # name exactly as the caller supplied it.
    normalized = model_name.lower()

    if normalized.startswith("gpt-"):
        return "OpenAI"
    if any(marker in normalized for marker in ("sonnet", "claude", "haiku")):
        return "Anthropic"
    if "command" in normalized:
        return "Cohere"

    raise ValueError(f"Model `{model_name}` not yet supported")
|
|
|
|
|
def get_model_and_provider(model_name: str):
    """Split a model identifier into ``(full_model_name, provider)``.

    Accepted forms:

    * ``UNSELECTED_MODEL_NAME`` -- returns ``("", "")``.
    * ``"<provider>/<model>"`` -- the text before the first ``/`` is taken
      as the provider verbatim; the remainder is the model name.
    * ``"<model>"`` -- the provider is guessed from the resolved model name
      via ``guess_model_provider``.

    In both non-sentinel forms the model name is looked up in
    ``AVAILABLE_MODELS`` and replaced by the mapped value when present;
    otherwise it is used unchanged.
    NOTE(review): presumably ``AVAILABLE_MODELS`` maps short aliases to full
    model names -- confirm against ``app_configs``.

    Args:
        model_name: Model identifier, with or without a provider prefix.

    Returns:
        A ``(full_model_name, provider)`` tuple of strings.

    Raises:
        ValueError: Propagated from ``guess_model_provider`` when no
            provider prefix is given and the provider cannot be guessed.
    """
    if model_name == UNSELECTED_MODEL_NAME:
        return "", ""

    # partition() splits on the first "/" only, the same as
    # split("/", maxsplit=1), and reports whether a separator was found.
    provider, separator, short_name = model_name.partition("/")

    if not separator:
        # No explicit provider: the whole string is the model name.
        full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
        return full_model_name, guess_model_provider(full_model_name)

    return AVAILABLE_MODELS.get(short_name, short_name), provider
|
|
|
|
|
def get_full_model_name(model_name: str, provider: str = ""):
    """Build the ``"<provider>/<model_name>"`` identifier for a model.

    Args:
        model_name: Model name. An empty string yields
            ``UNSELECTED_MODEL_NAME``.
        provider: Provider to prefix the name with. When empty, the
            provider is guessed with ``guess_model_provider``.

    Returns:
        ``UNSELECTED_MODEL_NAME`` for an empty ``model_name``; otherwise
        the provider and model name joined by ``/``.

    Raises:
        ValueError: Propagated from ``guess_model_provider`` when no
            provider is given and none can be guessed.
    """
    if model_name == "":
        return UNSELECTED_MODEL_NAME

    # Only fall back to guessing when the caller gave no provider.
    prefix = provider if provider else guess_model_provider(model_name)
    return f"{prefix}/{model_name}"
|
|