chansung committed on
Commit
56d5504
1 Parent(s): 9e25881

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -7
app.py CHANGED
@@ -1,18 +1,126 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  with gr.Blocks() as demo:
4
  gr.Markdown("# E5 Large V2 Demo")
5
 
6
- q_txt = gr.Textbox(placeholder="Enter your query")
7
 
8
- p_txt1 = gr.Textbox(placeholder="Enter passage 1")
9
- p_txt2 = gr.Textbox(placeholder="Enter passage 2")
10
- p_txt3 = gr.Textbox(placeholder="Enter passage 3")
11
- p_txt4 = gr.Textbox(placeholder="Enter passage 4")
12
- p_txt5 = gr.Textbox(placeholder="Enter passage 5")
13
- p_txt6 = gr.Textbox(placeholder="Enter passage 6")
14
 
15
  submit = gr.Button("Submit")
 
 
 
 
 
16
 
17
  o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False)
18
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from torch import Tensor
6
+ from transformers import AutoTokenizer, AutoModel
7
+
8
  import gradio as gr
9
 
10
def get_model(base_name='intfloat/e5-large-v2'):
    """Load the tokenizer and the encoder for a Hugging Face checkpoint.

    Args:
        base_name: Model identifier (or local path) passed unchanged to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    model = AutoModel.from_pretrained(base_name)

    return tokenizer, model
16
+
17
def get_scores(model, tokenizer, input_texts, max_length=512, how_many_q=1):
    """Score every passage against every query.

    The first ``how_many_q`` entries of ``input_texts`` are treated as
    queries and all remaining entries as passages. The returned scores do
    not include query-vs-query comparisons.

    Args:
        model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Callable that turns ``input_texts`` into a batch dict
            containing at least ``attention_mask``.
        input_texts: Queries followed by passages.
        max_length: Truncation length handed to the tokenizer.
        how_many_q: Number of leading entries in ``input_texts`` that are
            queries.

    Returns:
        A tensor with one row per query and one column per passage. Each
        entry is the dot product of the L2-normalized embeddings (i.e. the
        cosine similarity) multiplied by 100.
    """
    # Tokenize the input texts as one padded batch.
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Inference only: disable autograd so no computation graph is built
    # and kept alive for every request.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask']
        )

        # Unit-length embeddings turn the matrix product into cosine
        # similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        scores = (embeddings[:how_many_q] @ embeddings[how_many_q:].T) * 100

    return scores
40
+
41
def get_top(scores, top_k=None):
    """Rank the columns of ``scores`` per row, best first.

    Only indices and values are returned; the associated texts are not
    looked up here.

    Args:
        scores: 2-D tensor; each row is sorted independently along dim 1.
        top_k: Keep only the first ``top_k`` columns of the ranking. Any
            falsy value (``None`` or ``0``) keeps the full ranking.

    Returns:
        A ``(top_indices, top_values)`` tuple of tensors, both ordered by
        descending score.
    """
    result = torch.sort(scores, descending=True, dim=1)
    top_indices = result.indices
    top_values = result.values

    if top_k:
        top_indices = top_indices[:, :top_k]
        top_values = top_values[:, :top_k]

    return top_indices, top_values
53
+
54
def get_human_readable_top(scores, input_texts, top_k=None):
    """Rank passages per query and attach the passage text to each hit.

    ``scores`` has one row per query, so the first ``scores.shape[0]``
    entries of ``input_texts`` are taken as the queries and the remainder
    as the passages. Splitting by position (rather than by searching for
    the substring "query:") keeps column indices aligned with the passage
    list even when a passage happens to contain that substring.

    Args:
        scores: 2-D tensor with one row per query and one column per
            passage.
        input_texts: Queries followed by passages, in the same order that
            produced ``scores``.
        top_k: Forwarded to ``get_top``; limits the number of hits kept
            per query.

    Returns:
        A dict mapping each query text to a list of
        ``{"idx", "val", "text"}`` dicts ordered by descending score, where
        ``idx`` is the passage's position among the passages and ``val`` is
        the score rounded to three decimals. Identical query texts share a
        single key, so the last one wins.
    """
    n_queries = scores.shape[0]
    queries = input_texts[:n_queries]
    passages = input_texts[n_queries:]

    top_indices, top_values = get_top(scores, top_k)

    result = {}
    for query, indices, values in zip(queries, top_indices, top_values):
        result[query] = [
            {
                "idx": idx,
                "val": round(val, 3),
                "text": passages[idx]
            }
            for idx, val in zip(indices.tolist(), values.tolist())
        ]

    return result
75
+
76
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over the positions the mask marks as real.

    Args:
        last_hidden_states: Per-token hidden states; the token axis is
            dim 1 and the trailing axis holds the features.
        attention_mask: Mask over the leading axes of
            ``last_hidden_states``; non-zero marks a token to include.

    Returns:
        The sum over dim 1 of the unmasked token states divided by the
        number of unmasked tokens in each row.
    """
    # Zero out padded positions so they contribute nothing to the sum.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
80
+
81
def get_result(q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
    """Rank the non-empty passages against the query and return JSON.

    The query is prefixed with ``"query: "`` and each passage with
    ``"passage: "`` before scoring. Passage boxes left empty (or holding
    only whitespace) are skipped.

    Args:
        q_txt: The query text.
        p_txt1, p_txt2, p_txt3, p_txt4, p_txt5: Candidate passage texts.

    Returns:
        A JSON string (4-space indent) mapping the prefixed query to its
        ranked passages; the list is empty when no passage was supplied.
    """
    input_texts = [
        f"query: {q_txt}"
    ]

    for passage in (p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
        if passage and passage.strip():
            input_texts.append(f"passage: {passage}")

    # Nothing to rank: answer directly instead of running the model on the
    # query alone.
    if len(input_texts) == 1:
        return json.dumps({input_texts[0]: []}, indent=4)

    scores = get_scores(model, tokenizer, input_texts)
    result = get_human_readable_top(scores, input_texts)
    return json.dumps(result, indent=4)
104
+
105
# Load the tokenizer and model once at import time so every request
# reuses them.
tokenizer, model = get_model('intfloat/e5-large-v2')

# UI: one query box, five passage boxes, a submit button and a read-only
# output box that shows the JSON ranking.
with gr.Blocks() as demo:
    gr.Markdown("# E5 Large V2 Demo")

    q_txt = gr.Textbox(placeholder="Enter your query", info="Query")

    p_txt1 = gr.Textbox(placeholder="Enter passage 1", info="Passage 1")
    p_txt2 = gr.Textbox(placeholder="Enter passage 2", info="Passage 2")
    p_txt3 = gr.Textbox(placeholder="Enter passage 3", info="Passage 3")
    p_txt4 = gr.Textbox(placeholder="Enter passage 4", info="Passage 4")
    p_txt5 = gr.Textbox(placeholder="Enter passage 5", info="Passage 5")

    submit = gr.Button("Submit")

    # The output box must exist before the click handler names it as its
    # output; creating it right after the button keeps it below the button.
    o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False)

    submit.click(
        get_result,
        [q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5],
        o_txt
    )
126