File size: 1,485 Bytes
193db9d
 
633b045
193db9d
 
 
 
 
 
 
973519b
193db9d
973519b
 
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Description: Utility functions for the model_step component.

from app_configs import AVAILABLE_MODELS, UNSELECTED_MODEL_NAME


def guess_model_provider(model_name: str) -> str:
    """Infer a model's provider from well-known substrings of its name.

    Matching is case-insensitive and checked in this order:
    a ``gpt-`` prefix -> ``"OpenAI"``; ``sonnet``/``claude``/``haiku``
    anywhere -> ``"Anthropic"``; ``command`` anywhere -> ``"Cohere"``.

    Args:
        model_name: Bare model name, e.g. ``"gpt-4o"`` or ``"claude-3-haiku"``.

    Returns:
        The provider name string.

    Raises:
        ValueError: If the name matches none of the known patterns.
    """
    # Work on a lowercased copy so the original spelling is still available
    # for the error message below.
    normalized = model_name.lower()
    if normalized.startswith("gpt-"):
        return "OpenAI"
    if any(marker in normalized for marker in ("sonnet", "claude", "haiku")):
        return "Anthropic"
    if "command" in normalized:
        return "Cohere"
    # Report the name exactly as the caller passed it, not the lowercased copy.
    raise ValueError(f"Model `{model_name}` not yet supported")


def get_model_and_provider(model_name: str):
    """Split a model identifier into ``(full_model_name, provider)``.

    Accepts either ``"<provider>/<model>"`` or a bare ``"<model>"``. In both
    forms the model part is looked up in ``AVAILABLE_MODELS`` and falls back
    to itself when it has no entry there. For a bare name, the provider is
    inferred from the looked-up name via ``guess_model_provider``. An explicit
    provider prefix is returned exactly as given.

    Args:
        model_name: The model identifier to parse.

    Returns:
        A ``(full_model_name, provider)`` tuple, or ``("", "")`` when
        ``model_name`` equals ``UNSELECTED_MODEL_NAME``.

    Raises:
        ValueError: Propagated from ``guess_model_provider`` when a bare
            name matches no known provider.
    """
    if model_name == UNSELECTED_MODEL_NAME:
        return "", ""
    # Split on the first "/" only; anything after it (including further "/")
    # is treated as the model name. There are exactly two cases: no separator
    # (a bare name) or a provider prefix.
    provider, separator, bare_name = model_name.partition("/")
    if not separator:
        full_model_name = AVAILABLE_MODELS.get(model_name, model_name)
        return full_model_name, guess_model_provider(full_model_name)
    return AVAILABLE_MODELS.get(bare_name, bare_name), provider


def get_full_model_name(model_name: str, provider: str = ""):
    """Compose a ``"<provider>/<model>"`` identifier for a model.

    Args:
        model_name: Bare model name. An empty string means no model was chosen.
        provider: Provider prefix to use. When falsy, the provider is inferred
            with ``guess_model_provider``.

    Returns:
        ``UNSELECTED_MODEL_NAME`` for an empty ``model_name``; otherwise the
        provider and model joined by ``"/"``.
    """
    if model_name != "":
        resolved_provider = provider or guess_model_provider(model_name)
        return f"{resolved_provider}/{model_name}"
    return UNSELECTED_MODEL_NAME