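"""Streamlit question-answering app over local PDFs of Lithuanian rules for
foreigners: the documents are indexed with BGE embeddings in Chroma and
questions are answered by a GPTQ-quantized Llama 2 13B chat model through a
LangChain RetrievalQA chain.
"""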
import streamlit as st
@st.cache_resource
def load_resources():
  # Heavy imports and model setup are kept inside this function so that
  # @st.cache_resource runs them only once per session.
  import torch
  from auto_gptq import AutoGPTQForCausalLM
  from langchain import HuggingFacePipeline, PromptTemplate
  from langchain.chains import RetrievalQA
  from langchain.document_loaders import PyPDFDirectoryLoader
  from langchain.embeddings import HuggingFaceBgeEmbeddings
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.vectorstores import Chroma
  from transformers import AutoTokenizer, TextStreamer, pipeline

  # Prefer the GPU when available; the 13B GPTQ model is very slow on CPU.
  DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

  # Load every PDF under ./pdfs and split it into overlapping chunks.
  loader = PyPDFDirectoryLoader("pdfs")
  docs = loader.load()
  embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en", model_kwargs={"device": DEVICE}
  )
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
  texts = text_splitter.split_documents(docs)
  # Embed the chunks and persist them in a local Chroma vector store.
  db = Chroma.from_documents(texts, embeddings, persist_directory="db")
  # Load the GPTQ-quantized Llama 2 13B chat model and its tokenizer.
  model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
  model = AutoGPTQForCausalLM.from_quantized(
      model_name_or_path,
      revision="main",
      use_safetensors=True,
      trust_remote_code=True,
      inject_fused_attention=False,
      device=DEVICE,
      quantize_config=None,
  )
  # Stream generated tokens to stdout and wrap the HF pipeline for LangChain.
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
  text_pipeline = pipeline(
      "text-generation",
      model=model,
      tokenizer=tokenizer,
      max_new_tokens=1024,
      temperature=0,
      top_p=0.95,
      repetition_penalty=1.15,
      streamer=streamer,
  )
  llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
  SYSTEM_PROMPT = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
  )

  def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    # Wrap the content in the Llama 2 chat instruction format.
    return f"""
    [INST] <<SYS>>
    {system_prompt}
    <</SYS>>

    {prompt} [/INST]
    """.strip()

  # {context} is filled with the retrieved chunks, {question} with the query.
  template = generate_prompt(
      """
      {context}

      Question: {question}
      """
  )
  prompt = PromptTemplate(template=template, input_variables=["context", "question"])
  # "stuff" the top-2 retrieved chunks into the prompt's {context} slot.
  qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt},
    verbose=True,
  )
  return qa_chain


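# --- Streamlit UI -----------------------------------------------------------
# load_resources() is cached, so the index and model are built once and reused
# across script reruns.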
st.title("Please ask your question on Lithuanian rules for foreigners.")
qa_chain = load_resources()
context = st.text_area("Enter the context:")
question = st.text_input("Enter your question:")

if context and question:
    # Perform Question Answering
    answer = qa_chain(context=context, question=question)
    
    # Display the answer
    st.header("Answer:")
    st.write(answer)
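
# To try this out (a sketch, assuming the script is saved as app.py -- the file
# name is not fixed by the code): place the PDFs to index in a sibling "pdfs"
# directory and run
#   streamlit run app.py
# A CUDA GPU is effectively required; loading and running the 13B GPTQ
# checkpoint on CPU would be impractically slow.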