ShynBui committed on
Commit
c84cd95
1 Parent(s): b84746e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -4
app.py CHANGED
@@ -1,10 +1,92 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
  if __name__ == "__main__":
7
 
8
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- iface.launch()
10
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ from langchain.retrievers import EnsembleRetriever
4
+ from utils import *
5
+ import requests
6
+ from pyvi import ViTokenizer, ViPosTagger
7
+ import time
8
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
9
+ import torch
10
+
11
# Module-level setup: retrievers, the extractive QA model, and the HTTP
# headers used by `query`. Everything here runs once at import time.

# Read the Hugging Face token a single time; it may legitimately be unset
# (e.g. when running locally), in which case `os.environ.get` returns None.
_hf_token = os.environ.get("HF_TOKEN")

# Dense (embedding) and sparse (BM25) retrievers, each configured with k=3.
# Both loaders are expected to come from `utils` via the star import.
retriever = load_the_embedding_retrieve(is_ready=True, k=3)
bm25_retriever = load_the_bm25_retrieve(k=3)

# Combine both retrievers with equal weight.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever],
    weights=[0.5, 0.5],
)

# Local extractive question-answering model used by `greet2`.
tokenizer = AutoTokenizer.from_pretrained("ShynBui/vie_qa", token=_hf_token)
model = AutoModelForQuestionAnswering.from_pretrained("ShynBui/vie_qa", token=_hf_token)

# Headers for the hosted endpoint called by `query`.
headers = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
# Concatenating "Bearer " with None would raise TypeError during import, so
# the Authorization header is only added when a token is actually available.
if _hf_token:
    headers["Authorization"] = "Bearer " + _hf_token
26
+
27
+
28
def query(payload, timeout=60):
    """POST *payload* as JSON to the inference endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout `requests` waits indefinitely, which would hang the
            calling request if the endpoint stalls.

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    # NOTE(review): API_URL is not defined in the visible part of this file;
    # presumably it is provided by `from utils import *` - confirm.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
31
+
32
+
33
def greet(quote, max_retries=30, retry_delay=1.0):
    """Answer *quote* by sending each retrieved document to the remote QA endpoint.

    For every document returned by the ensemble retriever, the word-segmented
    question and the first 256 characters of the word-segmented document are
    posted through `query`, and the ``answer`` field of the reply is collected.

    Args:
        quote: The user's question.
        max_retries: Maximum number of re-sends per document while the reply
            contains an ``"error"`` entry. Previously the loop was unbounded,
            so a permanent error (for example bad credentials) never returned.
        retry_delay: Seconds to sleep between retries.

    Returns:
        A list with one answer per retrieved document.

    Raises:
        RuntimeError: If the endpoint still reports an error after
            *max_retries* retries.
    """
    # The question is identical for every document, so segment it once.
    question = ViTokenizer.tokenize(quote)
    answers = []

    for doc in ensemble_retriever.get_relevant_documents(quote):
        print("source:", doc.metadata['source'])

        context = ViTokenizer.tokenize(doc.page_content)
        payload = {
            "inputs": {
                "question": question,
                # Only the first 256 characters of the context are sent.
                "context": context[:256],
            },
        }

        output = query(payload)
        retries = 0
        # Replies carrying an "error" entry are retried after a short pause
        # (presumably the endpoint is still warming up - TODO confirm).
        while "error" in output:
            if retries >= max_retries:
                raise RuntimeError(
                    f"QA endpoint still failing after {max_retries} retries: {output}"
                )
            retries += 1
            time.sleep(retry_delay)
            output = query(payload)

        answers.append(output['answer'])

    return answers
64
+
65
def greet2(quote):
    """Answer *quote* with the local QA model, once per retrieved document.

    Each document returned by the ensemble retriever is word-segmented, paired
    with the word-segmented question, and passed through the model. The answer
    is the token span running from the highest-scoring start position to the
    highest-scoring end position (inclusive), decoded back to text.

    Args:
        quote: The user's question.

    Returns:
        A list with one answer string per retrieved document. A string is
        empty when the predicted end position precedes the start position.
    """
    # The question is identical for every document, so segment it once.
    question = ViTokenizer.tokenize(quote)
    answers = []

    for doc in ensemble_retriever.get_relevant_documents(quote):
        context = ViTokenizer.tokenize(doc.page_content)

        # Truncate only the context (second sequence) so an overlong document
        # cannot exceed the model's maximum input length; the question is kept
        # intact. An answer located past the cut-off can no longer be found.
        inputs = tokenizer(
            question,
            context,
            return_tensors="pt",
            truncation="only_second",
        )

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)

        start_index = torch.argmax(outputs.start_logits)
        # +1 because the slice end is exclusive.
        end_index = torch.argmax(outputs.end_logits) + 1

        answer_ids = inputs["input_ids"][0][start_index:end_index]
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(answer_ids)
        )
        answers.append(answer)

    return answers
86
 
 
 
87
 
88
  if __name__ == "__main__":
89
 
 
 
90
 
91
+ iface = gr.Interface(fn=greet2, inputs="text", outputs="text")
92
+ iface.launch()