from llama_index.llms.huggingface import HuggingFaceLLM, HuggingFaceInferenceAPI
from llama_index.llms.openai import OpenAI
from llama_index.llms.replicate import Replicate

from dotenv import load_dotenv
import os
import streamlit as st

load_dotenv()
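
# Expected .env entries (names assumed: HUGGINGFACE_API_TOKEN matches the
# commented example below; the others follow each provider's convention and
# may differ in your setup):
#
# HUGGINGFACE_API_TOKEN=hf_...
# OPENAI_API_KEY=sk-...
# REPLICATE_API_TOKEN=r8_...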

# download the model from the Hugging Face Hub and run it locally
# llm_mixtral_8x7b = HuggingFaceLLM(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1")

# llm_llama_2_7b_chat = HuggingFaceInferenceAPI(
#     model_name="meta-llama/Llama-2-7b-chat-hf",
#     token=os.getenv("HUGGINGFACE_API_TOKEN"),
# )

# {model_name: source} -- which provider/loader to use for each model
integrated_llms = {
    "gpt-3.5-turbo-0125": "openai",
    "meta/llama-2-13b-chat": "replicate",
    "mistralai/Mistral-7B-Instruct-v0.2": "huggingface",
    # "mistralai/Mixtral-8x7B-v0.1": "huggingface", # 93 GB model
    # "meta-llama/Meta-Llama-3-8B": "huggingface", # too large >10G for llama index hf interference to load
}


def load_llm(model_name: str, source: str = "huggingface"):
    """Instantiate an LLM for a supported model; return None if it is unknown."""
    print("model_name:", model_name, "source:", source)
    if integrated_llms.get(model_name) is None:
        return None
    try:
        if source.startswith("openai"):
            llm_gpt_3_5_turbo_0125 = OpenAI(
                model=model_name,
                api_key=st.session_state.openai_api_key,
                temperature=0.0,
            )

            return llm_gpt_3_5_turbo_0125

        elif source.startswith("replicate"):
            llm_llama_13b_v2_replicate = Replicate(
                model=model_name,
                is_chat_model=True,
                additional_kwargs={"max_new_tokens": 250},
                prompt_key=st.session_state.replicate_api_token,
                temperature=0.0,
            )

            return llm_llama_13b_v2_replicate

        elif source.startswith("huggingface"):
            llm_mixtral_8x7b = HuggingFaceInferenceAPI(
                model_name=model_name,
                token=st.session_state.hf_token,
            )

            return llm_mixtral_8x7b

    except Exception as e:
        print("Failed to load", model_name, ":", e)
        return None
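

# A minimal usage sketch (illustrative only, assuming this runs inside a
# Streamlit app that has already stored the relevant API keys in
# st.session_state, e.g. via st.text_input widgets):
#
# model_name = st.selectbox("Model", list(integrated_llms.keys()))
# llm = load_llm(model_name, source=integrated_llms[model_name])
# if llm is not None:
#     response = llm.complete("What is retrieval-augmented generation?")
#     st.write(response.text)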