# logotherapyGPT/backend.py

import time

import streamlit as st
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Chroma
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)

from constants import EMBEDDING_MODEL_NAME, db_all, db_frankl, db_inst

@st.cache_resource(show_spinner=False)
def load_model(device_type, model_id):
    """Load the LLM and wrap it in a LangChain HuggingFacePipeline.

    Cached with st.cache_resource so the model is loaded once per process.
    """
    if device_type.lower() == "cuda":
        # GPU path: shard across available devices and use fp16 weights.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        model.tie_weights()
    else:
        # CPU path: orca_mini_3b is LLaMA-based, so the Llama classes apply.
        tokenizer = LlamaTokenizer.from_pretrained(model_id)
        model = LlamaForCausalLM.from_pretrained(model_id)

    generation_config = GenerationConfig.from_pretrained(model_id)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Sampling settings are merged with the model's generation config.
        max_length=2048,
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.5,
        generation_config=generation_config,
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm

@st.cache_resource(show_spinner=False)
def get_embeddings():
    """Load the instructor embeddings on GPU if available, else CPU.

    Cached with st.cache_resource (not st.cache_data) because the embeddings
    object wraps a model that should not be pickled or copied per rerun.
    """
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device_type},
    )
    return embeddings

def get_llm():
    """Build the local LLM, picking CUDA when a GPU is available."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "psmathur/orca_mini_3b"
    return load_model(device_type, model_id=model_id)

def load_qa(db_option):
    """Build a RetrievalQA chain over the Chroma collection chosen in the UI.

    db_option is one of "All", "Frankl's Works", or
    "Journal of Search for Meaning"; anything else falls back to db_inst.
    """
    db_instance = db_inst  # default collection
    match db_option:
        case "All":
            db_instance = db_all
            print("loading db_all")
        case "Frankl's Works":
            db_instance = db_frankl
            print("loading db_frankl")
        case "Journal of Search for Meaning":
            db_instance = db_inst
            print("loading db_inst")

    load_start = time.time()
    embeddings = get_embeddings()
    # Open the persisted Chroma store with the same embedding function that
    # was used at indexing time.
    db = Chroma(
        persist_directory=db_instance.get_directory(),
        embedding_function=embeddings,
        client_settings=db_instance.get_chroma_settings(),
    )
    retriever = db.as_retriever()
    llm = get_llm()
    # "stuff" packs all retrieved chunks into a single prompt.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    load_end = time.time()
    print(f"\n> Completed initial load in {round(load_end - load_start, 2)} s.")
    return qa
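
# --- Usage sketch -----------------------------------------------------------
# A minimal, illustrative example of how a front end might exercise load_qa.
# The __main__ guard and the sample question are assumptions, not part of the
# original app; the "query" / "result" / "source_documents" keys are the
# standard RetrievalQA input/output schema for this version of LangChain.
if __name__ == "__main__":
    qa = load_qa("Frankl's Works")
    response = qa({"query": "What does Frankl mean by the will to meaning?"})
    print(response["result"])
    for doc in response["source_documents"]:
        # Each retrieved source document carries metadata such as its origin.
        print(doc.metadata)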