# RegBot4.0 / models / llms.py
# hbui's picture
# llama-index-update (#1)
# 170741d verified
# raw
# history blame contribute delete
# No virus
# 2.06 kB
from llama_index.llms.huggingface import HuggingFaceLLM, HuggingFaceInferenceAPI
from llama_index.llms.openai import OpenAI
from llama_index.llms.replicate import Replicate
from dotenv import load_dotenv
import os
import streamlit as st
# Load key/value pairs from a local .env file (if one exists) into the
# process environment before any client below is configured.
load_dotenv()
# download the model from the Hugging Face Hub and run it locally
# llm_mixtral_8x7b = HuggingFaceLLM(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")
# llm_llama_2_7b_chat = HuggingFaceInferenceAPI(
# model_name="meta-llama/Llama-2-7b-chat-hf",
# token=os.getenv("HUGGINGFACE_API_TOKEN"),
# )
# Registry of the models this app can serve.
# Each entry maps a model name (key) to the provider that hosts it (value);
# load_llm() refuses any model name that is not listed here.
integrated_llms = {
    # Hosted by OpenAI
    "gpt-3.5-turbo-0125": "openai",
    # Hosted by Replicate
    "meta/llama-2-13b-chat": "replicate",
    # Served through the Hugging Face Inference API
    "mistralai/Mistral-7B-Instruct-v0.2": "huggingface",
    # Disabled Hugging Face candidates, kept for reference:
    # "mistralai/Mixtral-8x7B-v0.1": "huggingface",  # 93 GB model
    # "meta-llama/Meta-Llama-3-8B": "huggingface",  # too large (>10 GB) for LlamaIndex HF inference to load
}
def load_llm(model_name: str, source: str = "huggingface"):
    """Build a LlamaIndex LLM client for one of the integrated models.

    Credentials are read from ``st.session_state`` (``openai_api_key``,
    ``replicate_api_token`` or ``hf_token``, depending on the provider).

    Args:
        model_name: Model identifier; must be a key of ``integrated_llms``.
        source: Provider selector, matched by prefix: ``"openai"``,
            ``"replicate"`` or ``"huggingface"``.

    Returns:
        The constructed LLM client, or ``None`` when ``model_name`` is not
        registered, ``source`` matches no provider, or building the client
        fails (the error is printed rather than raised).
    """
    print("model_name: ", model_name, "source: ", source)
    if integrated_llms.get(model_name) is None:
        return None

    try:
        if source.startswith("openai"):
            return OpenAI(
                model=model_name,
                api_key=st.session_state.openai_api_key,
                temperature=0.0,
            )

        if source.startswith("replicate"):
            # `prompt_key` names the input field that carries the prompt text;
            # it is not a credential. Passing the token there would send the
            # secret as a payload field name and leave the request
            # unauthenticated. The replicate client authenticates through the
            # REPLICATE_API_TOKEN environment variable, so export it instead.
            # NOTE: os.environ is process-wide, i.e. shared by all sessions.
            replicate_token = st.session_state.replicate_api_token
            if replicate_token:
                os.environ["REPLICATE_API_TOKEN"] = replicate_token
            return Replicate(
                model=model_name,
                is_chat_model=True,
                additional_kwargs={"max_new_tokens": 250},
                temperature=0.0,
            )

        if source.startswith("huggingface"):
            return HuggingFaceInferenceAPI(
                model_name=model_name,
                token=st.session_state.hf_token,
            )
    except Exception as e:
        # Best effort: report the failure and let the caller handle `None`.
        print(e)

    # Unknown provider prefix, or client construction failed.
    return None