File size: 2,400 Bytes
55fc0a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d7c096
 
 
 
 
 
55fc0a1
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Type

from neollm.llm.llm.abstract_llm import AbstractLLM
from neollm.llm.model_name._abstract_model_name import AbstractModelName
from neollm.types import ClientSettings

from .platform import Platform


def get_llm(model_name: str, platform: str, client_settings: ClientSettings) -> AbstractLLM:
    """Resolve a (platform, model_name) pair to a concrete LLM instance.

    Args:
        model_name: Model identifier; must be a valid value of the selected
            platform's model-name enum (e.g. an ``OpenAIModelName`` value).
        platform: Platform identifier; must be a valid ``Platform`` enum value.
        client_settings: Client configuration forwarded to the LLM constructor.

    Returns:
        The ``AbstractLLM`` implementation built by the resolved model-name
        enum's ``to_llm``.

    Raises:
        ValueError: If ``platform`` is not a supported platform, or if
            ``model_name`` is not a supported model for that platform.
    """
    try:
        platform_enum = Platform(platform)
    except ValueError as e:
        raise ValueError(
            f"{str(e)}\n"
            f"{platform} is not supported. Supported platforms are {', '.join([member.value for member in Platform])}."
        ) from e

    # Model-name enums are imported lazily so that only the selected
    # platform's dependencies are loaded.
    model_name_class: Type[AbstractModelName]
    if platform_enum == Platform.AZURE:
        from neollm.llm.model_name.azure_model_name import AzureModelName

        model_name_class = AzureModelName

    elif platform_enum == Platform.OPENAI:
        from neollm.llm.model_name.openai_model_name import OpenAIModelName

        model_name_class = OpenAIModelName

    elif platform_enum == Platform.ANTHROPIC:
        from neollm.llm.model_name.anthropic_model_name import AnthropicModelName

        model_name_class = AnthropicModelName

    elif platform_enum == Platform.GCP:
        from neollm.llm.model_name.gcp_model_name import GCPModelName

        model_name_class = GCPModelName

    elif platform_enum == Platform.AWS:
        from neollm.llm.model_name.aws_model_name import AWSModelName

        model_name_class = AWSModelName

    elif platform_enum == Platform.LOCAL_VLLM:
        from neollm.llm.model_name.local_vllm_model_name import LocalvLLMModelName

        model_name_class = LocalvLLMModelName
    elif platform_enum == Platform.GOOGLE_GENERATIVEAI:
        from neollm.llm.model_name.google_generativeai_model_name import (
            GoogleGenerativeAIModelName,
        )

        model_name_class = GoogleGenerativeAIModelName
    else:
        raise ValueError(f"{platform} is not supported.")

    try:
        # TODO: This could be simplified if Platform exposed a method that
        # yields its model-name enum directly.
        model_name_enum = model_name_class(model_name)  # type: ignore[abstract]
    except ValueError as e:
        # BUG FIX: previously reported the *platform* as unsupported and
        # labeled the model-name list "platforms"; the failure here is an
        # unsupported model_name for this platform.
        raise ValueError(
            f"{str(e)}\n"
            f"{model_name} is not supported. Supported model names are {', '.join([member.value for member in model_name_class])}."
        ) from e
    return model_name_enum.to_llm(client_settings, model_name)