|
import os |
|
|
|
from .ModelStrategy import ModelStrategy |
|
|
|
from langchain_openai import ChatOpenAI |
|
from langchain_mistralai.chat_models import ChatMistralAI |
|
from langchain_anthropic import ChatAnthropic |
|
|
|
from llamaapi import LlamaAPI |
|
from langchain_experimental.llms import ChatLlamaAPI |
|
|
|
class MistralModel(ModelStrategy):
    """Model strategy that produces Mistral chat models."""

    def get_model(self, model_name):
        """Build a ``ChatMistralAI`` client for the given model name.

        Args:
            model_name: Identifier forwarded as the ``model`` keyword of
                ``ChatMistralAI``.

        Returns:
            The newly constructed ``ChatMistralAI`` instance.
        """
        chat_model = ChatMistralAI(model=model_name)
        return chat_model
|
|
|
|
|
class OpenAIModel(ModelStrategy):
    """Model strategy that produces OpenAI chat models."""

    def get_model(self, model_name):
        """Build a ``ChatOpenAI`` client for the given model name.

        Args:
            model_name: Identifier forwarded as the ``model`` keyword of
                ``ChatOpenAI``.

        Returns:
            The newly constructed ``ChatOpenAI`` instance.
        """
        chat_model = ChatOpenAI(model=model_name)
        return chat_model
|
|
|
|
|
class AnthropicModel(ModelStrategy):
    """Model strategy that produces Anthropic chat models."""

    def get_model(self, model_name):
        """Build a ``ChatAnthropic`` client for the given model name.

        Args:
            model_name: Identifier forwarded as the ``model`` keyword of
                ``ChatAnthropic``.

        Returns:
            The newly constructed ``ChatAnthropic`` instance.
        """
        chat_model = ChatAnthropic(model=model_name)
        return chat_model
|
|
|
|
|
class LlamaAPIModel(ModelStrategy):
    """Model strategy that produces chat models backed by LlamaAPI."""

    def get_model(self, model_name):
        """Build a ``ChatLlamaAPI`` wrapper for the given model name.

        The API key is read from the ``LLAMA_API_KEY`` environment variable
        on every call; if the variable is unset, ``None`` is passed to
        ``LlamaAPI`` unchanged.

        Args:
            model_name: Identifier forwarded as the ``model`` keyword of
                ``ChatLlamaAPI``.

        Returns:
            The newly constructed ``ChatLlamaAPI`` instance.
        """
        api_key = os.environ.get("LLAMA_API_KEY")
        client = LlamaAPI(api_key)
        return ChatLlamaAPI(client=client, model=model_name)
|
|
|
class ModelManager:
    """Registry that maps a provider name to its model strategy.

    ``self.models`` is a plain dict keyed by provider name; each value is a
    strategy object exposing ``get_model(model_name)``.
    """

    def __init__(self):
        """Instantiate one strategy per built-in provider."""
        self.models = {
            "mistral": MistralModel(),
            "openai": OpenAIModel(),
            "anthropic": AnthropicModel(),
            "llama": LlamaAPIModel(),
        }

    def get_model(self, provider, model_name):
        """Return a model built by the strategy registered for ``provider``.

        Args:
            provider: Key into ``self.models`` (e.g. ``"openai"``).
            model_name: Model identifier passed through to the strategy's
                ``get_model``.

        Returns:
            Whatever the selected strategy's ``get_model`` returns.

        Raises:
            KeyError: If ``provider`` is not a registered key. The message
                names the offending provider and lists the registered ones.
        """
        try:
            strategy = self.models[provider]
        except KeyError:
            # Keep KeyError so existing callers' handlers still match, but
            # replace the bare key with an actionable message.
            supported = ", ".join(sorted(map(str, self.models)))
            raise KeyError(
                f"Unknown provider {provider!r}; supported providers: {supported}"
            ) from None
        return strategy.get_model(model_name)