Arpan Chatterjee committed on
Commit
e58d85a
1 Parent(s): 978182e

Added the streamlit app and the requirements.txt file

Browse files
Files changed (2) hide show
  1. app.py +97 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer, BartForConditionalGeneration
2
+ import torch
3
+ from tqdm.auto import tqdm
4
+ from sentence_transformers import SentenceTransformer
5
+ import streamlit as st
6
+ import pinecone
7
+
8
+
9
def connect_pinecone():
    """Initialise the Pinecone client from environment configuration.

    The API key is read from the ``PINECONE_API_KEY`` environment variable
    and the environment name from ``PINECONE_ENVIRONMENT`` (defaulting to
    ``"northamerica-northeast1-gcp"``).  Credentials must never be committed
    to source control; any key that was previously hard-coded here should be
    treated as leaked and revoked in the Pinecone console.

    Raises:
        RuntimeError: if ``PINECONE_API_KEY`` is unset or blank.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import os

    api_key = os.environ.get("PINECONE_API_KEY", "").strip()
    if not api_key:
        # Fail before contacting Pinecone so the misconfiguration is obvious.
        raise RuntimeError(
            "PINECONE_API_KEY is not set; export it (or add it as an app "
            "secret) before starting the application."
        )

    environment = os.environ.get(
        "PINECONE_ENVIRONMENT", "northamerica-northeast1-gcp"
    )

    # connect to pinecone environment
    pinecone.init(api_key=api_key, environment=environment)
15
+
16
+
17
def pinecone_create_index():
    """Return a handle to the "abstractive-question-answering" index.

    If no index with that name is listed by Pinecone, one is created first
    with 768 dimensions and the cosine metric.  Assumes ``pinecone.init`` has
    already been called.
    """
    name = "abstractive-question-answering"

    # Only create the index when Pinecone does not already list it.
    existing_indexes = pinecone.list_indexes()
    if name not in existing_indexes:
        pinecone.create_index(name, dimension=768, metric="cosine")

    # Connect to the (possibly just created) index.
    return pinecone.Index(name)
32
+
33
+
34
def query_pinecone(query, retriever, index, top_k):
    """Embed ``query`` and search ``index`` for the closest passages.

    Args:
        query: the question text to embed.
        retriever: object whose ``encode`` accepts a list of strings and
            returns something with a ``tolist()`` method.
        index: object exposing ``query(vector, top_k=..., include_metadata=...)``.
        top_k: number of matches to request.

    Returns:
        Whatever ``index.query`` returns, with metadata included.
    """
    # Encode as a one-element batch and convert to plain Python lists.
    query_vector = retriever.encode([query]).tolist()
    # Ask the index for the best-matching context passages.
    return index.query(query_vector, top_k=top_k, include_metadata=True)
40
+
41
+
42
def format_query(query, context):
    """Build the generator prompt from a question and retrieved matches.

    Args:
        query: the user's question.
        context: iterable of matches, each providing
            ``match["metadata"]["passage_text"]``.

    Returns:
        A string of the form ``"question: <query> context: <passages>"`` where
        every passage is prefixed with ``"<P> "`` and passages are separated
        by a single space.
    """
    # Tag each passage with <P> and concatenate them with spaces.
    passages = " ".join(
        "<P> " + str(match["metadata"]["passage_text"]) for match in context
    )
    # Concatenate the question and the tagged context passages.
    return "question: " + str(query) + " context: " + passages
50
+
51
def generate_answer(query, tokenizer, generator, device):
    """Generate an answer string for an already formatted prompt.

    Args:
        query: prompt text (question plus context) to feed the generator.
        tokenizer: callable tokenizer that also provides ``batch_decode``.
        generator: model exposing ``generate``.
        device: device the tokenized inputs are moved to.

    Returns:
        The first decoded sequence produced by the generator.
    """
    # Tokenize the prompt.  Truncation is requested explicitly: passing
    # ``max_length`` on its own relies on an implicit fallback that logs a
    # warning on every call.
    inputs = tokenizer(
        [query],
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    # Use the generator to predict output ids.
    ids = generator.generate(
        inputs["input_ids"], num_beams=2, min_length=20, max_length=50
    )
    # Decode the output ids back to text; only one prompt was supplied, so
    # the first decoded sequence is the answer.
    decoded = tokenizer.batch_decode(
        ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
59
+
60
+
61
+
62
def main():
    """Render the Streamlit question-answering page.

    Connects to Pinecone, shows a form with a question box and a submit
    button, and on submission retrieves the best-matching passage, builds the
    prompt and writes the generated answer to the page.
    """
    connect_pinecone()
    # The index is expected to be populated already; this call only creates
    # it when missing and returns a handle to it.
    index = pinecone_create_index()

    # Streamlit re-executes the whole script on every interaction, so the
    # heavyweight models must be cached or they are reloaded on each rerun.
    # Newer Streamlit exposes ``cache_resource``; older releases (such as the
    # pinned 1.16) expose the same facility as ``experimental_singleton``.
    cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton

    @cache_resource
    def load_models():
        """Load the retriever and the BART generator once per server process."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # load the retriever model from huggingface model hub
        retriever = SentenceTransformer(
            "flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device
        )
        # load bart tokenizer and model from huggingface
        tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
        generator = BartForConditionalGeneration.from_pretrained(
            'vblagoje/bart_lfqa'
        ).to(device)
        return device, retriever, tokenizer, generator

    # The text box lives inside the form so that typing does not trigger a
    # rerun; the question and the button are submitted together.
    with st.form("my_form"):
        user_input = st.text_input("Ask a question:")
        submit_button = st.form_submit_button(label='Get Answer')

    if not submit_button:
        return

    question = user_input.strip()
    if not question:
        st.warning("Please enter a question first.")
        return

    device, retriever, tokenizer, generator = load_models()

    result = query_pinecone(question, retriever, index, top_k=1)
    matches = result["matches"]
    if not matches:
        # Without a context passage the generator would answer from nothing.
        st.warning("No matching passage was found for this question.")
        return

    query = format_query(question, matches)
    print(query)
    ans = generate_answer(query, tokenizer, generator, device)
    st.write(ans)
91
+
92
+
93
+
94
+
95
+
96
# Script entry point: start the app only when this file is run directly.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pinecone_client==2.2.2
2
+ sentence_transformers==2.2.2
3
+ streamlit==1.16.0
4
+ torch==2.0.0
5
+ tqdm==4.65.0
6
+ transformers==4.27.4