File size: 1,181 Bytes
63e33f8 ea077e1 0c111b7 ea077e1 63e33f8 ea077e1 63e33f8 ea077e1 63e33f8 ea077e1 63e33f8 ea077e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import os
from .ModelStrategy import ModelStrategy
from langchain_openai import ChatOpenAI
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_anthropic import ChatAnthropic
from llamaapi import LlamaAPI
from langchain_experimental.llms import ChatLlamaAPI
class MistralModel(ModelStrategy):
    """Strategy that builds Mistral chat models."""

    def get_model(self, model_name) -> ChatMistralAI:
        """Return a ``ChatMistralAI`` instance configured for ``model_name``.

        Args:
            model_name: Model identifier, forwarded unchanged as the
                ``model`` keyword argument.
        """
        chat_model = ChatMistralAI(model=model_name)
        return chat_model
class OpenAIModel(ModelStrategy):
    """Strategy that builds OpenAI chat models."""

    def get_model(self, model_name) -> ChatOpenAI:
        """Return a ``ChatOpenAI`` instance configured for ``model_name``.

        Args:
            model_name: Model identifier, forwarded unchanged as the
                ``model`` keyword argument.
        """
        chat_model = ChatOpenAI(model=model_name)
        return chat_model
class AnthropicModel(ModelStrategy):
    """Strategy that builds Anthropic chat models."""

    def get_model(self, model_name) -> ChatAnthropic:
        """Return a ``ChatAnthropic`` instance configured for ``model_name``.

        Args:
            model_name: Model identifier, forwarded unchanged as the
                ``model`` keyword argument.
        """
        chat_model = ChatAnthropic(model=model_name)
        return chat_model
class LlamaAPIModel(ModelStrategy):
    """Strategy that builds Llama chat models backed by a ``LlamaAPI`` client."""

    def get_model(self, model_name) -> ChatLlamaAPI:
        """Return a ``ChatLlamaAPI`` instance configured for ``model_name``.

        The API token is read from the ``LLAMA_API_KEY`` environment
        variable each time this method is called.

        Args:
            model_name: Model identifier, forwarded unchanged as the
                ``model`` keyword argument.

        Raises:
            ValueError: If ``LLAMA_API_KEY`` is unset or empty.
        """
        api_key = os.environ.get("LLAMA_API_KEY")
        if not api_key:
            # Fail fast with an actionable message: without this check a
            # missing variable would hand ``None`` to LlamaAPI as the token.
            raise ValueError(
                "LLAMA_API_KEY environment variable is not set; "
                "it is required to use the 'llama' provider."
            )
        llama = LlamaAPI(api_key)
        return ChatLlamaAPI(client=llama, model=model_name)
class ModelManager:
    """Registry that maps provider names to model-building strategies."""

    def __init__(self):
        """Register one strategy instance for each supported provider."""
        # Keys are the provider names callers pass to ``get_model``.
        self.models = {
            "mistral": MistralModel(),
            "openai": OpenAIModel(),
            "anthropic": AnthropicModel(),
            "llama": LlamaAPIModel(),
        }

    def get_model(self, provider, model_name):
        """Build a chat model for ``model_name`` using ``provider``'s strategy.

        Args:
            provider: Key into ``self.models`` selecting the strategy.
            model_name: Model identifier passed to the strategy's
                ``get_model``.

        Returns:
            Whatever the selected strategy's ``get_model`` returns.

        Raises:
            KeyError: If ``provider`` is not a registered key. The message
                lists the registered provider names.
        """
        try:
            strategy = self.models[provider]
        except KeyError:
            supported = ", ".join(sorted(map(str, self.models)))
            # Same exception type as a plain dict lookup, so existing
            # ``except KeyError`` handlers keep working; only the message
            # is more helpful.
            raise KeyError(
                f"Unknown provider {provider!r}; supported providers: {supported}"
            ) from None
        return strategy.get_model(model_name)