Spaces:
Runtime error
Runtime error
Ismail Ashraq
committed on
Commit
•
5559d98
1
Parent(s):
0977dc4
add add and requirements
Browse files- app.py +108 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import pinecone
|
3 |
+
import streamlit as st
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from transformers import BartTokenizer, BartForConditionalGeneration
|
6 |
+
|
7 |
+
|
8 |
+
class BartGenerator:
    """Wraps a BART seq2seq model for long-form question answering.

    Loads the tokenizer and model once, then turns a formatted
    ``question: ... context: ...`` prompt into a generated answer string.
    """

    def __init__(self, model_name):
        """Load the tokenizer and model weights for *model_name* from the HF hub."""
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        self.generator = BartForConditionalGeneration.from_pretrained(model_name)

    def tokenize(self, query, max_length=1024):
        """Encode *query* into PyTorch tensors, truncated to *max_length* tokens.

        Truncation matters here: the prompt embeds all retrieved passages,
        which can exceed BART's 1024-token input limit.
        """
        inputs = self.tokenizer(
            [query], max_length=max_length, truncation=True, return_tensors="pt"
        )
        return inputs

    def generate(self, query, min_length=20, max_length=40):
        """Generate an answer string for the formatted *query* prompt.

        BUG FIX: the original called ``self.generator.predict(...)``, but
        ``BartForConditionalGeneration`` has no ``predict`` method — text
        generation is done via ``.generate()`` — so every query raised
        AttributeError at runtime.
        """
        inputs = self.tokenize(query)
        ids = self.generator.generate(
            inputs["input_ids"],
            num_beams=1,
            min_length=int(min_length),
            max_length=int(max_length),
        )
        answer = self.tokenizer.batch_decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return answer
|
22 |
+
|
23 |
+
@st.experimental_singleton
def init_models():
    """Load the retriever and generator once per session (cached singleton)."""
    encoder = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
    answer_model = BartGenerator("vblagoje/bart_lfqa")
    return encoder, answer_model
|
28 |
+
|
29 |
+
# Pinecone API key, read from Streamlit secrets (configured outside this file).
PINECONE_KEY = st.secrets["PINECONE_KEY"]
|
30 |
+
|
31 |
+
@st.experimental_singleton
def init_pinecone():
    """Connect to Pinecone and return the QA passage index (cached singleton)."""
    # get a free api key from app.pinecone.io
    pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")
    index_name = "abstractive-question-answering"
    return pinecone.Index(index_name)
|
35 |
+
|
36 |
+
# Instantiate the models and the Pinecone index once at script start;
# both factories are cached singletons, so Streamlit reruns reuse them.
retriever, generator = init_models()
index = init_pinecone()
|
38 |
+
|
39 |
+
def display_answer(answer):
    """Render the generated answer as a grey Bootstrap card via raw HTML."""
    card = f"""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
<span style="color: #808080;">
{answer}
</span>
</div>
</div>
</div>
"""
    return st.markdown(card, unsafe_allow_html=True)
|
51 |
+
|
52 |
+
def display_context(title, context, url):
    """Render one retrieved passage: a linked title plus the grey passage text.

    FIX: the anchor's ``href`` attribute was unquoted (``<a href={url}>``);
    quoting it keeps the link intact for URLs containing characters that
    would terminate an unquoted HTML attribute.
    """
    return st.markdown(f"""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
<a href="{url}">{title}</a>
<br>
<span style="color: #808080;">
<small>{context}</small>
</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
|
66 |
+
|
67 |
+
# Page title and prompt text.
st.write("""
# Abstractive Question Answering
Ask me a question!
""")

# Inject Bootstrap's CSS so the raw-HTML cards rendered by
# display_answer/display_context pick up its grid and utility classes.
st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
|
75 |
+
|
76 |
+
def format_query(query, context):
    """Build the LFQA prompt: the question followed by all retrieved passages.

    Each match's ``passage_text`` is prefixed with ``<P>`` and the pieces are
    space-joined into ``question: ... context: ...``.
    """
    passages = " ".join(
        f"<P> {match['metadata']['passage_text']}" for match in context
    )
    return f"question: {query} context: {passages}"
|
81 |
+
|
82 |
+
# --- Sidebar controls -------------------------------------------------
st.sidebar.subheader("Retriever parameters:")
# How many passages to pull back from the Pinecone index.
top_k = st.sidebar.slider("Top K", min_value=1, max_value=10, value=5)

st.sidebar.subheader("Generator parameters:")
# Bounds (in tokens) on the generated answer's length.
min_length = st.sidebar.slider("Minimum Length", min_value=1, max_value=50, value=20)
max_length = st.sidebar.slider("Maximum Length", min_value=1, max_value=100, value=50)

# --- Main interaction -------------------------------------------------
query = st.text_input("Search!", "")

if query != "":
    with st.spinner(text="Fetching context passages 🚀🚀🚀"):
        # Embed the question and fetch the top_k most similar passages.
        xq = retriever.encode([query]).tolist()
        xc = index.query(xq, top_k=int(top_k), include_metadata=True)
        # NOTE: `query` is rebound here to the full generator prompt.
        query = format_query(query, xc["matches"])

    with st.spinner(text="Generating answer ✍️✍️✍️"):
        answer = generator.generate(query, min_length=min_length, max_length=max_length)

    st.write("#### Generated answer:")
    display_answer(answer)
    st.write("#### Answer was generated based on the following passages:")

    # Show each supporting passage with a link to its Wikipedia article.
    for m in xc["matches"]:
        title = m["metadata"]["article_title"]
        url = "https://en.wikipedia.org/wiki/" + title.replace(" ", "_")
        context = m["metadata"]["passage_text"]
        display_context(title, context, url)
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
pinecone-client
|
2 |
+
sentence-transformers
|