Ismail Ashraq committed on
Commit
5559d98
1 Parent(s): 0977dc4

add app and requirements

Files changed (2):
  1. app.py +108 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,108 @@
+
+ import pinecone
+ import streamlit as st
+ from sentence_transformers import SentenceTransformer
+ from transformers import BartTokenizer, BartForConditionalGeneration
+
+
+ class BartGenerator:
+     def __init__(self, model_name):
+         self.tokenizer = BartTokenizer.from_pretrained(model_name)
+         self.generator = BartForConditionalGeneration.from_pretrained(model_name)
+
+     def tokenize(self, query, max_length=1024):
+         inputs = self.tokenizer([query], max_length=max_length, return_tensors="pt")
+         return inputs
+
+     def generate(self, query, min_length=20, max_length=40):
+         # encode the prompt, generate answer token ids, then decode them back to text
+         inputs = self.tokenize(query)
+         ids = self.generator.generate(inputs["input_ids"], num_beams=1, min_length=int(min_length), max_length=int(max_length))
+         answer = self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         return answer
+
+ @st.experimental_singleton
+ def init_models():
+     retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
+     generator = BartGenerator("vblagoje/bart_lfqa")
+     return retriever, generator
+
+ PINECONE_KEY = st.secrets["PINECONE_KEY"]
+
+ @st.experimental_singleton
+ def init_pinecone():
+     pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")  # get a free api key from app.pinecone.io
+     return pinecone.Index("abstractive-question-answering")
+
+ retriever, generator = init_models()
+ index = init_pinecone()
+
+ def display_answer(answer):
+     return st.markdown(f"""
+     <div class="container-fluid">
+         <div class="row align-items-start">
+             <div class="col-md-12 col-sm-12">
+                 <span style="color: #808080;">
+                     {answer}
+                 </span>
+             </div>
+         </div>
+     </div>
+     """, unsafe_allow_html=True)
+
+ def display_context(title, context, url):
+     return st.markdown(f"""
+     <div class="container-fluid">
+         <div class="row align-items-start">
+             <div class="col-md-12 col-sm-12">
+                 <a href={url}>{title}</a>
+                 <br>
+                 <span style="color: #808080;">
+                     <small>{context}</small>
+                 </span>
+             </div>
+         </div>
+     </div>
+     """, unsafe_allow_html=True)
+
+ st.write("""
+ # Abstractive Question Answering
+ Ask me a question!
+ """)
+
+ st.markdown("""
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
+ """, unsafe_allow_html=True)
+
+ def format_query(query, context):
+     # prefix each retrieved passage with the <P> token the LFQA generator expects
+     context = [f"<P> {m['metadata']['passage_text']}" for m in context]
+     context = " ".join(context)
+     query = f"question: {query} context: {context}"
+     return query
+
+ st.sidebar.subheader("Retriever parameters:")
+ top_k = st.sidebar.slider("Top K", min_value=1, max_value=10, value=5)
+
+ st.sidebar.subheader("Generator parameters:")
+ min_length = st.sidebar.slider("Minimum Length", min_value=1, max_value=50, value=20)
+ max_length = st.sidebar.slider("Maximum Length", min_value=1, max_value=100, value=50)
+
+ query = st.text_input("Search!", "")
+
+ if query != "":
+     with st.spinner(text="Fetching context passages 🚀🚀🚀"):
+         # embed the question and retrieve the top_k most similar passages from Pinecone
+         xq = retriever.encode([query]).tolist()
+         xc = index.query(xq, top_k=int(top_k), include_metadata=True)
+         query = format_query(query, xc["matches"])
+
+     with st.spinner(text="Generating answer ✍️✍️✍️"):
+         answer = generator.generate(query, min_length=min_length, max_length=max_length)
+
+     st.write("#### Generated answer:")
+     display_answer(answer)
+     st.write("#### Answer was generated based on the following passages:")
+
+     for m in xc["matches"]:
+         title = m["metadata"]["article_title"]
+         url = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")
+         context = m["metadata"]["passage_text"]
+         display_context(title, context, url)
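
The prompt built by format_query follows the "question: ... context: <P> passage <P> passage" convention used by the vblagoje/bart_lfqa checkpoint. A minimal standalone sketch of the generator step on such a prompt, with a made-up question and passages purely for illustration (the checkpoint is downloaded from the Hugging Face Hub on first use):

from transformers import BartTokenizer, BartForConditionalGeneration

# Load the same long-form QA checkpoint that app.py wraps in BartGenerator.
tokenizer = BartTokenizer.from_pretrained("vblagoje/bart_lfqa")
model = BartForConditionalGeneration.from_pretrained("vblagoje/bart_lfqa")

# Hand-built prompt mirroring format_query(): question plus <P>-prefixed passages.
passages = [
    "<P> The Berlin Wall fell on 9 November 1989.",
    "<P> Germany was formally reunified on 3 October 1990.",
]
prompt = f"question: When did the Berlin Wall fall? context: {' '.join(passages)}"

inputs = tokenizer([prompt], max_length=1024, truncation=True, return_tensors="pt")
ids = model.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=40)
print(tokenizer.batch_decode(ids, skip_special_tokens=True)[0])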
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ pinecone-client
+ sentence-transformers
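
Only these two packages are pinned: streamlit is supplied by the Space's Streamlit SDK, and transformers is pulled in as a dependency of sentence-transformers. For completeness, a rough sketch of the retrieval step on its own, using the same pinecone-client 2.x style calls as app.py; the API key is a placeholder and the abstractive-question-answering index is assumed to already hold encoded Wikipedia passages:

import pinecone
from sentence_transformers import SentenceTransformer

# Placeholder key; in the Space this value comes from st.secrets["PINECONE_KEY"].
pinecone.init(api_key="YOUR_API_KEY", environment="us-west1-gcp")
index = pinecone.Index("abstractive-question-answering")

# Same retriever model as app.py: embed the question into a dense vector.
retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
xq = retriever.encode("When did the Berlin Wall fall?").tolist()

# Fetch the top 5 passages and print their source article titles.
results = index.query(xq, top_k=5, include_metadata=True)
for match in results["matches"]:
    print(match["metadata"]["article_title"])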