BillBojangeles2000 committed
Commit abb2086
1 Parent(s): 147a264

Update app.py
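
This update wraps the Pinecone setup and the question-answering run in Streamlit Submit buttons, so each step executes only after the user has entered the corresponding text. A minimal sketch of the gating pattern used here, with st.write standing in as a placeholder for the gated work:

    import streamlit as st

    value = st.text_area('Enter a value:')
    # st.button returns True only on the script rerun triggered by the click,
    # so this branch runs once per click
    if st.button('Submit'):
        st.write(f"Processing: {value}")  # placeholder for the gated work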

Files changed (1)
  1. app.py +66 -63
app.py CHANGED
@@ -2,67 +2,70 @@ import pinecone
  import streamlit as st

  API = st.text_area('Enter API key:')
-
- # connect to pinecone environment
- pinecone.init(
-     api_key=API,  # use the key entered above, not the literal string "API"
-     environment="us-central1-gcp"  # find next to API key in console
- )
-
- index_name = "abstractive-question-answering"
-
- # check if the abstractive-question-answering index exists
- if index_name not in pinecone.list_indexes():
-     # create the index if it does not exist
-     pinecone.create_index(
-         index_name,
-         dimension=768,
-         metric="cosine"
+ res = st.button('Submit')
+ if res:  # run the Pinecone setup only after the user submits a key
+     # connect to pinecone environment
+     pinecone.init(
+         api_key=API,  # use the key entered above, not the literal string "API"
+         environment="us-central1-gcp"  # find next to API key in console
  )
-
- # connect to the abstractive-question-answering index we created
- index = pinecone.Index(index_name)
-
- import torch
- from sentence_transformers import SentenceTransformer
-
- # set device to GPU if available
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- # load the retriever model from huggingface model hub
- retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
-
- from transformers import BartTokenizer, BartForConditionalGeneration
-
- # load bart tokenizer and model from huggingface
- tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
- generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
-
- def query_pinecone(query, top_k):
-     # generate embeddings for the query
-     xq = retriever.encode([query]).tolist()
-     # search the pinecone index for context passages with the answer
-     xc = index.query(xq, top_k=top_k, include_metadata=True)
-     return xc
-
- def format_query(query, context):
-     # extract passage_text from the Pinecone search result and add the <P> tag
-     context = [f"<P> {m['metadata']['text']}" for m in context]
-     # concatenate all context passages
-     context = " ".join(context)
-     # concatenate the query and context passages
-     query = f"question: {query} context: {context}"
-     return query
-
- def generate_answer(query):
-     # tokenize the query to get input_ids
-     inputs = tokenizer([query], truncation=True, max_length=1024, return_tensors="pt")
-     # use generator to predict output ids
-     ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
-     # use tokenizer to decode the output ids
-     answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-     return answer
-
- query = st.text_area('Enter your question:')
- context = query_pinecone(query, top_k=5)
- query = format_query(query, context["matches"])
- st.write(generate_answer(query))  # display the generated answer in the app
+
+     index_name = "abstractive-question-answering"
+
+     # check if the abstractive-question-answering index exists
+     if index_name not in pinecone.list_indexes():
+         # create the index if it does not exist
+         pinecone.create_index(
+             index_name,
+             dimension=768,
+             metric="cosine"
+         )
+
+     # connect to the abstractive-question-answering index we created
+     index = pinecone.Index(index_name)
+
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+ # set device to GPU if available
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ # load the retriever model from huggingface model hub
+ retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base", device=device)
+
+ from transformers import BartTokenizer, BartForConditionalGeneration
+
+ # load bart tokenizer and model from huggingface
+ tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
+ generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa').to('cpu')
+
+ def query_pinecone(query, top_k):
+     # generate embeddings for the query
+     xq = retriever.encode([query]).tolist()
+     # search the pinecone index for context passages with the answer
+     xc = index.query(xq, top_k=top_k, include_metadata=True)
+     return xc
+
+ def format_query(query, context):
+     # extract passage_text from the Pinecone search result and add the <P> tag
+     context = [f"<P> {m['metadata']['text']}" for m in context]
+     # concatenate all context passages
+     context = " ".join(context)
+     # concatenate the query and context passages
+     query = f"question: {query} context: {context}"
+     return query
+
+ def generate_answer(query):
+     # tokenize the query to get input_ids
+     inputs = tokenizer([query], truncation=True, max_length=1024, return_tensors="pt")
+     # use generator to predict output ids
+     ids = generator.generate(inputs["input_ids"], num_beams=2, min_length=20, max_length=64)
+     # use tokenizer to decode the output ids
+     answer = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     return answer
+
+ query = st.text_area('Enter your question:')
+ s = st.button('Submit', key='run')  # key avoids a duplicate-widget clash with the first Submit button
+ if s:  # run the pipeline only after the user submits a question
+     context = query_pinecone(query, top_k=5)
+     query = format_query(query, context["matches"])
+     st.write(generate_answer(query))  # display the generated answer in the app
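
To try the change locally, run the app with: streamlit run app.py. One design caveat worth noting: st.button values do not persist across Streamlit reruns, so the Pinecone connection opened under the first Submit is gone by the time the second Submit triggers a rerun. A sketch of one common remedy (not part of this commit) is to stash a flag, or the client handle itself, in st.session_state:

    import streamlit as st

    if st.button('Connect'):
        # persist a value across reruns; in this app it could hold the pinecone.Index handle
        st.session_state['connected'] = True
    if st.session_state.get('connected'):
        st.write('Connection state survives this and later reruns.')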