helloworld53 committed on
Commit
15fbc32
1 Parent(s): 7e48800

making rag file

Files changed (1)
app.py +89 -2
app.py CHANGED
@@ -1,4 +1,91 @@
  import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ @st.cache_resource
+ def load_resources():
+     import torch
+     from auto_gptq import AutoGPTQForCausalLM
+     from langchain import HuggingFacePipeline, PromptTemplate
+     from langchain.chains import RetrievalQA
+     from langchain.document_loaders import PyPDFDirectoryLoader
+     from langchain.embeddings import HuggingFaceBgeEmbeddings
+     from langchain.text_splitter import RecursiveCharacterTextSplitter
+     from langchain.vectorstores import Chroma
+     from pdf2image import convert_from_path
+     from transformers import AutoTokenizer, TextStreamer, pipeline
+
+     DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
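+     # Load every PDF from the local "pdfs" folder, split it into 1024-character
+     # chunks, embed the chunks with BGE, and index them in a persistent Chroma store.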
+     loader = PyPDFDirectoryLoader("pdfs")
+     docs = loader.load()
+     embeddings = HuggingFaceBgeEmbeddings(
+         model_name="BAAI/bge-base-en", model_kwargs={"device": DEVICE}
+     )
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+     texts = text_splitter.split_documents(docs)
+     db = Chroma.from_documents(texts, embeddings, persist_directory="db")
+
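+     # Download the GPTQ-quantized Llama-2 13B chat model and its tokenizer from the Hub.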
+     model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
+     # model_basename = "gptq_model-4bit-128g"
+     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+     model = AutoGPTQForCausalLM.from_quantized(
+         model_name_or_path,
+         revision="main",
+         # model_basename=model_basename,
+         use_safetensors=True,
+         trust_remote_code=True,
+         inject_fused_attention=False,
+         device=DEVICE,
+         quantize_config=None,
+     )
+
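+     # Wrap the model in a streaming text-generation pipeline and expose it to
+     # LangChain as an LLM.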
+     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     text_pipeline = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=1024,
+         temperature=0,
+         top_p=0.95,
+         repetition_penalty=1.15,
+         streamer=streamer,
+     )
+     llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
+
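+     # Build the Llama-2 chat prompt; {context} and {question} are filled in by
+     # the RetrievalQA chain at query time.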
+     SYSTEM_PROMPT = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
+
+     def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
+         return f"""
+ [INST] <<SYS>>
+ {system_prompt}
+ <</SYS>>
+
+ {prompt} [/INST]
+ """.strip()
+
+     template = generate_prompt(
+         """
+ {context}
+
+ Question: {question}
+ """,
+         system_prompt=SYSTEM_PROMPT,
+     )
+     prompt = PromptTemplate(template=template, input_variables=["context", "question"])
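+
+     # Assemble the RetrievalQA chain: retrieve the 2 most similar chunks and
+     # "stuff" them into the prompt as context.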
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=db.as_retriever(search_kwargs={"k": 2}),
+         return_source_documents=True,
+         chain_type_kwargs={"prompt": prompt},
+         verbose=True,
+     )
+     return qa_chain
+
+
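+ # Streamlit UI: build the chain once (cached) and answer questions about the PDFs.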
+ st.title("Please ask your question about Lithuanian rules for foreigners.")
+ qa_chain = load_resources()
+ question = st.text_input("Enter your question:")
+
+ if question:
+     # Perform question answering; the retriever supplies the {context} for the
+     # prompt from the indexed PDFs, so only the question is passed to the chain.
+     answer = qa_chain(question)
+
+     # Display the answer
+     st.header("Answer:")
+     st.write(answer["result"])