ShynBui commited on
Commit
922ff42
1 Parent(s): ef19859

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -63
app.py CHANGED
@@ -15,77 +15,19 @@ ensemble_retriever = EnsembleRetriever(
15
  retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5]
16
  )
17
 
18
- tokenizer = AutoTokenizer.from_pretrained("ShynBui/vie_qa", token=os.environ.get("HF_TOKEN"))
19
- model = AutoModelForQuestionAnswering.from_pretrained("ShynBui/vie_qa", token=os.environ.get("HF_TOKEN"))
20
-
21
- headers = {
22
- "Accept": "application/json",
23
- "Authorization": "Bearer " + os.environ.get("HF_TOKEN"),
24
- "Content-Type": "application/json"
25
- }
26
-
27
-
28
- def query(payload):
29
- response = requests.post(os.environ.get("API_URL"), headers=headers, json=payload)
30
- return response.json()
31
-
32
-
33
- def greet(quote):
34
- sources = []
35
- answers = []
36
- scores = []
37
- ids = []
38
-
39
- docs = ensemble_retriever.get_relevant_documents(quote)
40
-
41
- for i in docs:
42
- context = ViTokenizer.tokenize(i.page_content)
43
- question = ViTokenizer.tokenize(quote)
44
- print("source:", i.metadata['source'])
45
- sources.append(i.metadata['source'])
46
- output = query({
47
- "inputs": {
48
- "question": question,
49
- "context": context[:256]
50
- },
51
- })
52
- while "error" in output:
53
- # print('fail')
54
- time.sleep(1)
55
- output = query({
56
- "inputs": {
57
- "question": question,
58
- "context": context[:256]
59
- },
60
- })
61
-
62
- answers.append(output['answer'])
63
- return answers
64
 
65
 
66
  def greet2(quote):
67
- answers = []
68
- docs = ensemble_retriever.get_relevant_documents(quote)
69
-
70
- return docs
71
-
72
- for i in docs:
73
- context = ViTokenizer.tokenize(i.page_content)
74
- question = ViTokenizer.tokenize(quote)
75
-
76
- inputs = tokenizer(question, context, return_tensors="pt")
77
 
78
- outputs = model(**inputs)
79
 
80
- start_index = torch.argmax(outputs.start_logits)
81
- end_index = torch.argmax(outputs.end_logits) + 1
82
 
83
- answer = tokenizer.convert_tokens_to_string(
84
- tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_index:end_index]))
85
 
86
- answers.append(answer)
87
 
88
- return answers
89
 
90
 
91
  if __name__ == "__main__":
 
15
  retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5]
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def greet2(quote):
 
 
 
 
 
 
 
 
 
 
21
 
22
+ qa_chain = get_qachain(retriever=ensemble_retriever)
23
 
24
+ prompt = os.environ['PROMPT']
 
25
 
26
+ qa_chain.combine_documents_chain.llm_chain.prompt.messages[0].prompt.template = prompt
 
27
 
28
+ llm_response = qa_chain(quote)
29
 
30
+ return llm_response['result']
31
 
32
 
33
  if __name__ == "__main__":