"""Streamlit helpers: load the local LLM, embeddings, and retriever for the app."""
import streamlit as st
from langchain.llms import LlamaCpp
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.retrievers.web_research import WebResearchRetriever
from langchain.utilities import GoogleSearchAPIWrapper
from dotenv import load_dotenv
import config
from langchain.callbacks.base import BaseCallbackHandler


class StreamHandler(BaseCallbackHandler):
    """Callback handler that streams LLM tokens into a Streamlit container."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each newly generated token and re-render the accumulated text.
        self.text += token
        self.container.markdown(self.text)

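# Illustrative use of StreamHandler (a sketch, not called anywhere in this
# module): attach it as a callback so tokens render live in a placeholder.
#
#     handler = StreamHandler(st.empty())
#     llm("Summarize the paper in two sentences.", callbacks=[handler])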

@st.cache_resource
def st_load_retriever(_llm, mode):
    """Build a retriever from the local Chroma store or from Google search.

    The leading underscore on `_llm` tells st.cache_resource not to hash
    that argument when computing the cache key.
    """
    model_kwargs = {"device": config.device}
    embeddings_model = HuggingFaceEmbeddings(
        model_name=config.embeddings_model,
        model_kwargs=model_kwargs,
    )

    vector_store = Chroma(
        "cs_paper_store",
        embeddings_model,
        persist_directory=config.vector_db_path,
    )

    if mode == "vectordb":
        # query the local vector store directly
        return vector_store.as_retriever()

    elif mode == "google search":
        # GoogleSearchAPIWrapper expects GOOGLE_API_KEY and GOOGLE_CSE_ID,
        # loaded here from a .env file.
        load_dotenv()
        search = GoogleSearchAPIWrapper()
        web_research_retriever = WebResearchRetriever.from_llm(
            vectorstore=vector_store, llm=_llm, search=search
        )
        return web_research_retriever

    else:
        raise ValueError(f"Unknown retrieval mode: {mode}")

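# Illustrative retriever call (a sketch; assumes the Chroma collection under
# config.vector_db_path has already been populated):
#
#     retriever = st_load_retriever(llm, mode="vectordb")
#     docs = retriever.get_relevant_documents("attention mechanisms")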

@st.cache_resource
def st_load_llm(
    temperature=config.temperature,
    max_tokens=config.max_tokens,
    top_p=config.top_p,
    llm_path=config.llm_path,
    context_length=config.context_length,
    n_gpu_layers=config.n_gpu_layers,
    n_batch=config.n_batch,
):
    """Load the local llama.cpp model once, cached across Streamlit reruns."""
    llm = LlamaCpp(
        model_path=llm_path,
        temperature=temperature,
        max_tokens=max_tokens,
        n_ctx=context_length,
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        top_p=top_p,
        verbose=False,
    )

    return llm
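

# Minimal wiring sketch showing how these helpers compose (an assumption about
# intended use; the function name `_example_qa_app` is hypothetical and the
# real UI lives elsewhere in the repo):
def _example_qa_app():
    from langchain.chains import RetrievalQA

    llm = st_load_llm()
    retriever = st_load_retriever(llm, mode="vectordb")
    chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)

    question = st.text_input("Ask a question about the indexed papers")
    if question:
        # Stream the answer token-by-token into the page.
        handler = StreamHandler(st.empty())
        chain.run(question, callbacks=[handler])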