File size: 2,400 Bytes
55fc0a1 3d7c096 55fc0a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from typing import Type
from neollm.llm.llm.abstract_llm import AbstractLLM
from neollm.llm.model_name._abstract_model_name import AbstractModelName
from neollm.types import ClientSettings
from .platform import Platform
def get_llm(model_name: str, platform: str, client_settings: ClientSettings) -> AbstractLLM:
    """Resolve a platform and model name into an LLM instance.

    The ``platform`` string is converted to a ``Platform`` member. That member selects
    the platform-specific model-name class, and ``model_name`` is converted into a member
    of that class. The resolved model name then builds the LLM via ``to_llm``.

    Args:
        model_name: Model identifier; must be a valid value of the selected
            platform's model-name class.
        platform: Platform identifier; must be a valid value of ``Platform``.
        client_settings: Client settings forwarded unchanged to ``to_llm``.

    Returns:
        The object returned by ``model_name_enum.to_llm(client_settings, model_name)``.

    Raises:
        ValueError: If ``platform`` is not a ``Platform`` value, if the platform has no
            model-name class mapped below, or if ``model_name`` is not a valid value for
            the selected platform's model-name class.
    """
    try:
        platform_enum = Platform(platform)
    except ValueError as e:
        raise ValueError(
            f"{e}\n"
            f"{platform} is not supported. Supported platforms are {', '.join([member.value for member in Platform])}."
        ) from e

    # Each model-name class is imported lazily, so only the module for the selected
    # platform is loaded (presumably to avoid pulling in every provider's dependencies).
    model_name_class: Type[AbstractModelName]
    if platform_enum == Platform.AZURE:
        from neollm.llm.model_name.azure_model_name import AzureModelName

        model_name_class = AzureModelName
    elif platform_enum == Platform.OPENAI:
        from neollm.llm.model_name.openai_model_name import OpenAIModelName

        model_name_class = OpenAIModelName
    elif platform_enum == Platform.ANTHROPIC:
        from neollm.llm.model_name.anthropic_model_name import AnthropicModelName

        model_name_class = AnthropicModelName
    elif platform_enum == Platform.GCP:
        from neollm.llm.model_name.gcp_model_name import GCPModelName

        model_name_class = GCPModelName
    elif platform_enum == Platform.AWS:
        from neollm.llm.model_name.aws_model_name import AWSModelName

        model_name_class = AWSModelName
    elif platform_enum == Platform.LOCAL_VLLM:
        from neollm.llm.model_name.local_vllm_model_name import LocalvLLMModelName

        model_name_class = LocalvLLMModelName
    elif platform_enum == Platform.GOOGLE_GENERATIVEAI:
        from neollm.llm.model_name.google_generativeai_model_name import (
            GoogleGenerativeAIModelName,
        )

        model_name_class = GoogleGenerativeAIModelName
    else:
        # Defensive: reached only if a Platform member exists without a branch above.
        raise ValueError(f"{platform} is not supported.")

    try:
        # TODO: This could be simplified if Platform had a method that returns the
        # matching model-name class.
        model_name_enum = model_name_class(model_name)  # type: ignore[abstract]
    except ValueError as e:
        # Report the rejected *model name* and list the model names valid for this
        # platform (the previous message wrongly blamed the platform).
        raise ValueError(
            f"{e}\n"
            f"{model_name} is not supported on {platform}. "
            f"Supported model names are {', '.join([member.value for member in model_name_class])}."
        ) from e
    return model_name_enum.to_llm(client_settings, model_name)
|