# rag-lt-docs / app.py
import streamlit as st
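
# Build the RAG pipeline once per process; st.cache_resource keeps it alive across Streamlit reruns.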
@st.cache_resource
def load_resources():
    import torch
    from auto_gptq import AutoGPTQForCausalLM
    from langchain import HuggingFacePipeline, PromptTemplate
    from langchain.chains import RetrievalQA
    from langchain.document_loaders import PyPDFDirectoryLoader
    from langchain.embeddings import HuggingFaceBgeEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import Chroma
    from transformers import AutoTokenizer, TextStreamer, pipeline

    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
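
    # Load every PDF from the ./pdfs directory.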
    loader = PyPDFDirectoryLoader("pdfs")
    docs = loader.load()
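
    # BGE embeddings for the retriever, computed on the GPU when available.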
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en", model_kwargs={"device": DEVICE}
    )
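
    # Split the documents into overlapping chunks and index them in a persistent Chroma store.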
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(docs)
    db = Chroma.from_documents(texts, embeddings, persist_directory="db")
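
    # Load the 4-bit GPTQ-quantised Llama-2-13B chat model and its tokenizer.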
    model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
    # model_basename = "gptq_model-4bit-128g"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="main",
        # model_basename = model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )
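
    # Stream generated tokens to stdout and wrap the model in a text-generation pipeline for LangChain.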
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
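
    # Wrap the retrieved context and question in the Llama-2 chat prompt format ([INST] / <<SYS>> tags).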
    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )

    def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
        return f"""
[INST] <<SYS>>
{system_prompt}
<</SYS>>
{prompt} [/INST]
""".strip()
    template = generate_prompt(
        """
{context}
Question: {question}
""",
        system_prompt=SYSTEM_PROMPT,
    )
    # input_variables must be a list naming the placeholders used in the template.
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
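
    # RetrievalQA "stuff" chain: retrieve the top-2 chunks and stuff them into the prompt as context.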
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )
    return qa_chain
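
# --- Streamlit UI ---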
st.title("Ask a question about Lithuanian rules for foreigners")
qa_chain = load_resources()
question = st.text_input("Enter your question:")

if question:
    # Perform question answering; the retriever supplies the context from the indexed PDFs.
    result = qa_chain({"query": question})

    # Display the answer
    st.header("Answer:")
    st.write(result["result"])
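
# To launch the app (assuming the PDFs to index live in ./pdfs next to this file):
#   streamlit run app.py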