File size: 3,883 Bytes
56d5504
 
 
 
 
 
 
8c50326
 
56d5504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ce0c52
9e25881
 
56d5504
9e25881
56d5504
 
 
 
 
9e25881
 
56d5504
 
 
 
 
9e25881
 
8c50326
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel

import gradio as gr

# instantiate tokenizer and model
def get_model(base_name='intfloat/e5-large-v2'):
    """Load and return the (tokenizer, model) pair for the given HF checkpoint.

    base_name: Hugging Face model identifier; defaults to E5-large-v2.
    """
    return (
        AutoTokenizer.from_pretrained(base_name),
        AutoModel.from_pretrained(base_name),
    )

# get normalized scores on input_texts, the final scores are
# reported without queries, and the number of queries should
# be denoted as in how_many_q
def get_scores(model, tokenizer, input_texts, max_length=512, how_many_q=1):
    """Return a (how_many_q, n_passages) similarity matrix scaled by 100.

    The first `how_many_q` entries of input_texts are treated as queries,
    the remainder as passages. Embeddings are L2-normalized, so the matrix
    product yields cosine similarities.
    """
    # Tokenize the whole batch at once (queries + passages).
    encoded = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    hidden = model(**encoded).last_hidden_state
    pooled = average_pool(hidden, encoded['attention_mask'])

    # Normalize so the dot product below is cosine similarity.
    pooled = F.normalize(pooled, p=2, dim=1)

    # Queries x passages, scaled to roughly [0, 100].
    return (pooled[:how_many_q] @ pooled[how_many_q:].T) * 100

# get top n results out of the scores. This
# function only returns the scores and indices
def get_top(scores, top_k=None):
    """Sort each row of `scores` in descending order.

    Returns (indices, values); when top_k is truthy, only the first
    top_k columns of each are kept.
    """
    values, indices = torch.sort(scores, descending=True, dim=1)

    if top_k:
        return indices[:, :top_k], values[:, :top_k]

    return indices, values

# get top n results out of the scores. This function
# returns scores and indices along with the associated text
def get_human_readable_top(scores, input_texts, top_k=None):
    """Map each query to its ranked passages with scores.

    scores      : (n_queries, n_passages) tensor from get_scores.
    input_texts : the full list passed to get_scores — "query: ..." entries
                  followed by "passage: ..." entries.
    top_k       : optionally truncate each ranking to the top_k passages.

    Returns {query_text: [{"idx", "val", "text"}, ...]} in rank order.
    """
    # BUG FIX: the original filtered queries out of input_texts and then
    # used that filtered list for the result key, so the key was the first
    # *passage* instead of the query. Keep queries and passages separate:
    # score row i belongs to queries[i]; score column j maps to passages[j].
    queries = [t for t in input_texts if "query:" in t]
    passages = [t for t in input_texts if "query:" not in t]

    top_indices, top_values = get_top(scores, top_k)

    result = {}
    for q_idx, (indices, values) in enumerate(zip(top_indices, top_values)):
        result[queries[q_idx]] = [
            {
                "idx": idx,
                "val": round(val, 3),
                "text": passages[idx],
            }
            for idx, val in zip(indices.tolist(), values.tolist())
        ]

    return result

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Positions where attention_mask is 0 are zeroed before summing, and the
    sum is divided by the number of real (non-padded) tokens per sequence.
    """
    mask = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~mask, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return summed / token_counts

def get_result(q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
    """Rank up to five passages against a query; return a JSON string.

    Empty passage boxes are skipped. Uses the module-level `model` and
    `tokenizer`. Returns json.dumps of get_human_readable_top's result,
    pretty-printed with indent=4.
    """
    input_texts = [
        f"query: {q_txt}"
    ]

    # BUG FIX: the original appended to `input_txt` (undefined name), so any
    # non-empty passage raised NameError. Also collapses five copy-pasted
    # if-blocks into one loop.
    for passage in (p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
        if passage != '':
            input_texts.append(f"passage: {passage}")

    scores = get_scores(model, tokenizer, input_texts)
    result = get_human_readable_top(scores, input_texts)
    return json.dumps(result, indent=4)

tokenizer, model = get_model('intfloat/e5-large-v2')

# Build the Gradio UI: one query box, five passage boxes, a submit button,
# and a read-only output box showing the JSON ranking.
with gr.Blocks() as demo:
    gr.Markdown("# E5 Large V2 Demo")

    q_txt = gr.Textbox(placeholder="Enter your query", info="Query")

    p_txt1 = gr.Textbox(placeholder="Enter passage 1", info="Passage 1")
    p_txt2 = gr.Textbox(placeholder="Enter passage 2", info="Passage 2")
    p_txt3 = gr.Textbox(placeholder="Enter passage 3", info="Passage 3")
    p_txt4 = gr.Textbox(placeholder="Enter passage 4", info="Passage 4")
    p_txt5 = gr.Textbox(placeholder="Enter passage 5", info="Passage 5")

    # BUG FIX: o_txt must be created BEFORE it is passed to submit.click;
    # the original defined it afterwards, raising NameError at startup.
    o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False)

    submit = gr.Button("Submit")
    submit.click(
        get_result,
        [q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5],
        o_txt
    )

demo.launch()