Mattral committed on
Commit 4c6bffd
1 Parent(s): 973fa76

Update app.py

Files changed (1)
  1. app.py +39 -25
app.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-from gradio_pdf import PDF
 from qdrant_client import models, QdrantClient
 from sentence_transformers import SentenceTransformer
 from PyPDF2 import PdfReader
@@ -24,6 +23,10 @@ llm = AutoModelForCausalLM.from_pretrained(
 )
 print("LLM loaded...")
 
+# Initialize QdrantClient
+client = QdrantClient(path="./db")
+print("DB created...")
+
 def get_chunks(text):
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=250,
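Review note: initializing QdrantClient once at module scope fixes the old pattern of re-opening `./db` inside each function, which Qdrant's embedded local mode handles poorly (the storage folder is locked by one client at a time). For throwaway experiments, qdrant-client also accepts an in-memory location; a minimal sketch under that assumption (the collection name and vector size are illustrative, not the app's):

    from qdrant_client import QdrantClient, models

    # Local mode entirely in RAM: nothing touches disk, handy for tests.
    scratch = QdrantClient(":memory:")
    scratch.recreate_collection(
        collection_name="smoke_test",  # hypothetical, not the app's "my_facts"
        vectors_config=models.VectorParams(
            size=384,  # assumed: all-MiniLM-L6-v2's dimension; use your encoder's
            distance=models.Distance.COSINE,
        ),
    )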
@@ -34,13 +37,16 @@ def get_chunks(text):
 
 def setup_database(files):
     all_chunks = []
+
     for file in files:
         reader = PdfReader(file)
         text = "".join(page.extract_text() for page in reader.pages)
         chunks = get_chunks(text)
         all_chunks.extend(chunks)
+
+    print(f"Total chunks: {len(all_chunks)}")
+    print("Chunks are ready...")
 
-    client = QdrantClient(path="./db")
     client.recreate_collection(
         collection_name="my_facts",
         vectors_config=models.VectorParams(
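Review note: get_chunks wraps LangChain's RecursiveCharacterTextSplitter with chunk_size=250; the splitter's remaining arguments fall outside this hunk, so the overlap below is an assumption. A self-contained sketch of the splitting step:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=250,    # same budget as in the diff above
        chunk_overlap=50,  # assumed value; the real one is outside this hunk
    )
    pieces = splitter.split_text("Some long PDF text. " * 100)
    print(len(pieces), "chunks")

Overlap keeps a sentence that straddles a boundary visible in both neighboring chunks, which matters at a chunk size as small as 250 characters.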
@@ -48,12 +54,13 @@ def setup_database(files):
             distance=models.Distance.COSINE,
         ),
     )
+    print("Collection created...")
 
     records = [
         models.Record(
             id=idx,
             vector=encoder.encode(chunk).tolist(),
-            payload={f"chunk_{idx}": chunk}
+            payload={"text": chunk}
        ) for idx, chunk in enumerate(all_chunks)
     ]
 
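This payload change is the substantive fix of the commit: the old per-record key chunk_{idx} could only be read back by reconstructing the key from the record id, whereas a fixed "text" key lets any consumer call hit.payload["text"] uniformly. A toy sketch of the resulting record shape, reusing the file's models and encoder (the data is invented for illustration):

    facts = ["Qdrant stores vectors.", "Gradio builds UIs."]  # toy chunks
    records = [
        models.Record(
            id=idx,
            vector=encoder.encode(chunk).tolist(),
            payload={"text": chunk},  # uniform key, independent of idx
        )
        for idx, chunk in enumerate(facts)
    ]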
@@ -61,16 +68,16 @@ def setup_database(files):
         collection_name="my_facts",
         records=records,
     )
+    print("Records uploaded...")
 
-def answer_question(question):
-    client = QdrantClient(path="./db")
+def answer(question):
     hits = client.search(
         collection_name="my_facts",
         query_vector=encoder.encode(question).tolist(),
         limit=3
     )
 
-    context = " ".join(hit.payload[f"chunk_{hit.id}"] for hit in hits)
+    context = " ".join(hit.payload["text"] for hit in hits)
 
     system_prompt = """You are a helpful co-worker, you will use the provided context to answer user questions.
     Read the given context before answering questions and think step by step. If you cannot answer a user question based on
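Retrieval encodes the question with the same SentenceTransformer used at indexing time and takes the three nearest chunks under cosine distance. A standalone sketch of the query path, assuming the module-level client and encoder from this file (the question text is illustrative):

    question = "What does the document say about deadlines?"  # illustrative
    hits = client.search(
        collection_name="my_facts",
        query_vector=encoder.encode(question).tolist(),
        limit=3,
    )
    for hit in hits:
        print(f"{hit.score:.3f}  {hit.payload['text'][:60]}")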
@@ -82,29 +89,36 @@ def answer_question(question):
     instruction = f"Context: {context}\nUser: {question}"
     prompt_template = f"{B_INST}{B_SYS}{system_prompt}{E_SYS}{instruction}{E_INST}"
 
-    response = llm(prompt_template)
-    return response
+    print(prompt_template)
+    result = llm(prompt_template)
+    return result
 
 def chat(messages, files):
     if files:
         setup_database(files)
-    if messages:
-        question = messages[-1]["text"]
-        answer = answer_question(question)
-        messages.append({"text": answer, "is_user": False})
+
+    if not messages:
+        return "Please upload PDF documents to initialize the database."
+
+    last_message = messages[-1]["content"]
+    response = answer(last_message)
+    messages.append({"role": "assistant", "content": response})
     return messages
 
-interface = gr.Interface(
-    fn=chat,
-    inputs=[
-        gr.Chatbot(label="Chat"),
-        gr.File(label="Upload PDFs", file_count="multiple")
-    ],
-    outputs=gr.Chatbot(label="Chat"),
-    title="Q&A with PDFs 👩🏻‍💻📓✍🏻💡",
-    description="This app facilitates a conversation with PDFs uploaded💡",
-    theme="soft",
-    live=True,
-)
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot()
+    file_input = gr.File(label="Upload PDFs", file_count="multiple")
+    with gr.Row():
+        with gr.Column(scale=0.85):
+            txt = gr.Textbox(show_label=False, placeholder="Enter your question here...").style(container=False)
+        with gr.Column(scale=0.15, min_width=0):
+            send_btn = gr.Button("Send")
+
+    def respond(messages, files, txt):
+        messages = chat(messages, files)
+        return messages, None, ""
+
+    send_btn.click(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])
+    txt.submit(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])
 
-interface.launch()
+demo.launch()
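The B_INST/B_SYS markers are defined earlier in app.py, outside every hunk shown here. Assuming they follow the standard Llama-2 chat convention, they would read:

    # Assumed Llama-2 chat delimiters; the actual definitions are not in this diff.
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"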
 
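One review caveat on the new Blocks UI: respond never appends the user's txt to the history, so answer() always sees the previous turn, and on the first send (empty history) chat returns a bare string that gr.Chatbot cannot render. Note also that Textbox.style() is Gradio 3.x API, while role/content dict histories need a newer Chatbot (type="messages"), so the two styles likely cannot coexist in one Gradio version. A hedged sketch of the handler with the user turn recorded first:

    def respond(messages, files, txt):
        messages = messages or []
        messages.append({"role": "user", "content": txt})  # record the question first
        messages = chat(messages, files)
        return messages, None, ""  # clear the file upload and the textbox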