import streamlit as st


@st.cache_resource
def load_resources():
    import torch
    from auto_gptq import AutoGPTQForCausalLM
    from langchain import HuggingFacePipeline, PromptTemplate
    from langchain.chains import RetrievalQA
    from langchain.document_loaders import PyPDFDirectoryLoader
    from langchain.embeddings import HuggingFaceBgeEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import Chroma
    from transformers import AutoTokenizer, TextStreamer, pipeline

    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Load the PDFs and split them into overlapping chunks for retrieval.
    loader = PyPDFDirectoryLoader("pdfs")
    docs = loader.load()

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en",
        model_kwargs={"device": DEVICE},
    )

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(docs)

    # Build the Chroma vector store from the document chunks.
    db = Chroma.from_documents(texts, embeddings, persist_directory="db")

    # Load the 4-bit GPTQ-quantized Llama-2 13B chat model.
    model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
    # model_basename = "gptq_model-4bit-128g"

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="main",
        # model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )

    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})

    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )

    def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
        # Wrap the prompt in the Llama-2 chat format.
        return f"""
[INST] <<SYS>>
{system_prompt}
<</SYS>>

{prompt} [/INST]
""".strip()

    template = generate_prompt(
        """
{context}

Question: {question}
""",
        system_prompt=SYSTEM_PROMPT,
    )

    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )

    return qa_chain


st.title("Please ask your question on Lithuanian rules for foreigners.")

qa_chain = load_resources()

question = st.text_input("Enter your question:")

if question:
    # Perform question answering; the retriever supplies the context
    # from the indexed PDFs, so only the question is passed to the chain.
    result = qa_chain(question)

    # Display the answer
    st.header("Answer:")
    st.write(result["result"])
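# Usage note (assumption: this script is saved as app.py and the source PDFs are
# placed in a ./pdfs directory next to it):
#
#   streamlit run app.py
#
# The first run downloads the GPTQ weights and builds the Chroma index, so it can
# take a while; later interactions in the same server session reuse the resources
# cached by @st.cache_resource.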