Spaces:
Sleeping
Sleeping
import time | |
import streamlit as st | |
import torch | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
from langchain.llms import HuggingFacePipeline | |
from constants import EMBEDDING_MODEL_NAME, db_all, db_frankl, db_inst | |
from langchain.vectorstores import Chroma | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
GenerationConfig, | |
LlamaForCausalLM, | |
LlamaTokenizer, | |
pipeline, | |
) | |
def load_model(device_type, model_id):
    """Load *model_id* and wrap it as a LangChain text-generation LLM.

    Args:
        device_type: Target device name; compared case-insensitively
            against ``"cuda"``. Any other value takes the CPU path.
        model_id: Model identifier handed to every ``from_pretrained`` call.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``"text-generation"`` pipeline.
    """
    on_gpu = device_type.lower() == "cuda"

    if not on_gpu:
        # CPU path: Llama-specific classes with default loading options.
        tok = LlamaTokenizer.from_pretrained(model_id)
        net = LlamaForCausalLM.from_pretrained(model_id)
    else:
        # GPU path: half-precision weights, automatic device placement.
        tok = AutoTokenizer.from_pretrained(model_id)
        gpu_load_options = {
            "device_map": "auto",
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
        }
        net = AutoModelForCausalLM.from_pretrained(model_id, **gpu_load_options)
        net.tie_weights()

    gen_cfg = GenerationConfig.from_pretrained(model_id)

    text_gen = pipeline(
        "text-generation",
        model=net,
        tokenizer=tok,
        max_length=2048,
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.5,
        generation_config=gen_cfg,
    )
    return HuggingFacePipeline(pipeline=text_gen)
def get_embeddings():
    """Create the instruct-embedding model on CUDA when available, else CPU.

    Returns:
        A ``HuggingFaceInstructEmbeddings`` built from ``EMBEDDING_MODEL_NAME``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": target_device},
    )
def get_llm():
    """Load the ``psmathur/orca_mini_3b`` LLM on CUDA when available, else CPU.

    Returns:
        Whatever ``load_model`` returns for that device and model id.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    return load_model(target_device, model_id="psmathur/orca_mini_3b")
def load_qa(db_option): | |
db_instance = db_inst | |
match db_option: | |
case "All": | |
db_instance = db_all | |
print("loading db_all") | |
case "Frankl's Works": | |
db_instance = db_frankl | |
print("loading db_frankl") | |
case "Journal of Search for Meaning": | |
db_instance = db_inst | |
print("loading db_inst") | |
load_start = time.time() | |
embeddings = get_embeddings() | |
db = Chroma( | |
persist_directory=db_instance.get_directory(), | |
embedding_function=embeddings, | |
client_settings=db_instance.get_chroma_settings(), | |
) | |
retriever = db.as_retriever() | |
llm = get_llm() | |
qa = RetrievalQA.from_chain_type( | |
llm=llm, | |
chain_type="stuff", | |
retriever=retriever, | |
return_source_documents=True, | |
) | |
load_end = time.time() | |
print(f"\n> Completed Initial Load (took {round(load_end - load_start, 2)} s.)") | |
return qa | |