"""Utilities for the AIxplorer Streamlit app: a streaming callback handler,
a retriever loader, and a LlamaCpp LLM loader."""

import streamlit as st
from dotenv import load_dotenv
from langchain.callbacks.base import BaseCallbackHandler
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.retrievers.web_research import WebResearchRetriever
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.vectorstores import Chroma

import config


class StreamHandler(BaseCallbackHandler):
    """LangChain callback that streams generated tokens into a Streamlit container."""

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each new token and re-render the accumulated text in place.
        self.text += token
        self.container.markdown(self.text)
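
# A minimal wiring sketch (assumed, not part of the original app): pass the
# handler to the LLM call so tokens render as they arrive. `prompt` stands in
# for a hypothetical user question.
#
#     handler = StreamHandler(st.empty())
#     llm = st_load_llm()
#     llm(prompt, callbacks=[handler])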


@st.cache_resource
def st_load_retriever(_llm, mode):
    """Load a document retriever. `_llm` is prefixed with an underscore so
    Streamlit's cache does not try to hash it."""
    # The embeddings model and Chroma vector store are needed by both modes.
    model_kwargs = {"device": config.device}
    embeddings_model = HuggingFaceEmbeddings(
        model_name=config.embeddings_model,
        model_kwargs=model_kwargs,
    )
    vector_store = Chroma(
        "cs_paper_store",
        embeddings_model,
        persist_directory=config.vector_db_path,
    )
    if mode == "vectordb":
        # Answer from the local vector store only.
        return vector_store.as_retriever()
    elif mode == "google search":
        # Load Google Search API credentials from the .env file.
        load_dotenv()
        search = GoogleSearchAPIWrapper()
        web_research_retriever = WebResearchRetriever.from_llm(
            vectorstore=vector_store, llm=_llm, search=search
        )
        return web_research_retriever
    else:
        raise ValueError(f"Unknown retrieval mode: {mode}")
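
# GoogleSearchAPIWrapper reads its credentials from the environment. A sketch
# of the expected `.env` file (values are placeholders):
#
#     GOOGLE_API_KEY=your-api-key
#     GOOGLE_CSE_ID=your-custom-search-engine-id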


@st.cache_resource
def st_load_llm(
    temperature=config.temperature,
    max_tokens=config.max_tokens,
    top_p=config.top_p,
    llm_path=config.llm_path,
    context_length=config.context_length,
    n_gpu_layers=config.n_gpu_layers,
    n_batch=config.n_batch,
):
    """Load a llama.cpp model, with sampling and hardware settings from config."""
    llm = LlamaCpp(
        model_path=llm_path,
        temperature=temperature,
        max_tokens=max_tokens,
        n_ctx=context_length,  # context window size
        n_gpu_layers=n_gpu_layers,  # layers offloaded to the GPU
        n_batch=n_batch,  # prompt tokens processed per batch
        top_p=top_p,
        verbose=False,
    )
    return llm
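

# A minimal smoke test (a sketch, not part of the original app). Streamlit's
# cache decorators run in "bare" mode outside `streamlit run`, possibly with a
# warning, so this can be invoked as `python utils.py`.
if __name__ == "__main__":
    llm = st_load_llm()
    retriever = st_load_retriever(llm, mode="vectordb")
    docs = retriever.get_relevant_documents("graph neural networks")
    print(f"retrieved {len(docs)} documents")
    print(llm("Summarise retrieval-augmented generation in one sentence."))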