|
import json |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from torch import Tensor |
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
import gradio as gr |
|
|
|
|
|
def get_model(base_name='intfloat/e5-large-v2'):
    """Load the tokenizer and the encoder for a pretrained checkpoint.

    Args:
        base_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(base_name),
        AutoModel.from_pretrained(base_name),
    )
|
|
|
|
|
|
|
|
|
def get_scores(model, tokenizer, input_texts, max_length=512, how_many_q=1):
    """Score every passage against every query by embedding similarity.

    The first ``how_many_q`` entries of ``input_texts`` are treated as
    queries and all remaining entries as passages.

    Args:
        model: Encoder called as ``model(**batch)``; its output must expose
            ``last_hidden_state``.
        tokenizer: Callable that tokenizes a list of strings into a mapping
            containing at least ``attention_mask``.
        input_texts: Texts to embed, queries first.
        max_length: Truncation length passed to the tokenizer.
        how_many_q: Number of leading entries that are queries.

    Returns:
        A tensor with one row per query and one column per passage, holding
        the cosine similarity of the pooled embeddings multiplied by 100.
    """
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: disabling autograd avoids building a graph for the
    # forward pass and keeps the returned scores free of grad history.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict['attention_mask']
        )

        # L2-normalise so the dot product below is a cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        scores = (embeddings[:how_many_q] @ embeddings[how_many_q:].T) * 100

    return scores
|
|
|
|
|
|
|
def get_top(scores, top_k=None):
    """Rank the columns of each score row from highest to lowest.

    Args:
        scores: Tensor sorted along dimension 1 in descending order.
        top_k: When truthy, keep only the first ``top_k`` columns of the
            ranking; ``None`` or ``0`` keeps every column.

    Returns:
        An ``(indices, values)`` tuple: the column indices in ranked order
        and the matching score values.
    """
    ranked_values, ranked_indices = torch.sort(scores, dim=1, descending=True)

    if not top_k:
        return ranked_indices, ranked_values

    return ranked_indices[:, :top_k], ranked_values[:, :top_k]
|
|
|
|
|
|
|
def get_human_readable_top(scores, input_texts, top_k=None):
    """Build a JSON-serialisable passage ranking for each query.

    Args:
        scores: Tensor with one row per query and one column per passage;
            rows and columns are expected to follow the order in which the
            queries and passages appear in ``input_texts``.
        input_texts: The texts that were scored. Entries starting with
            ``"query:"`` are queries; every other entry is a passage.
        top_k: Forwarded to ``get_top`` to limit the ranking length.

    Returns:
        A dict mapping each query text to a list of
        ``{"idx", "val", "text"}`` dicts ordered best-first, where ``idx``
        is the passage's position among the passages, ``val`` the score
        rounded to three decimals and ``text`` the passage text.
    """
    # Partition by prefix rather than substring so a passage that merely
    # mentions "query:" is not dropped (which would shift passage indices).
    queries = []
    passages = []
    for text in input_texts:
        if text.startswith("query:"):
            queries.append(text)
        else:
            passages.append(text)

    top_indices, top_values = get_top(scores, top_k)

    result = {}
    for row, (indices, values) in enumerate(zip(top_indices, top_values)):
        # Key by the query itself; score rows are per query, columns per
        # passage, so only the inner indices address the passage list.
        result[queries[row]] = [
            {
                "idx": idx,
                "val": round(val, 3),
                "text": passages[idx],
            }
            for idx, val in zip(indices.tolist(), values.tolist())
        ]

    return result
|
|
|
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Average hidden states over dimension 1, ignoring masked positions.

    Positions where ``attention_mask`` is zero are zeroed out before the
    sum, and the sum is divided by the per-row total of the mask.

    Args:
        last_hidden_states: Hidden states to pool.
        attention_mask: Mask whose non-zero entries mark positions to keep.

    Returns:
        The pooled tensor with dimension 1 reduced away.
    """
    is_padding = attention_mask.unsqueeze(-1).bool().logical_not()
    summed = last_hidden_states.masked_fill(is_padding, 0.0).sum(dim=1)
    kept_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / kept_counts
|
|
|
def get_result(q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
    """Rank the supplied passages against the query and return JSON text.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        q_txt: Query text; prefixed with ``"query: "`` before scoring.
        p_txt1: First passage; skipped when empty.
        p_txt2: Second passage; skipped when empty.
        p_txt3: Third passage; skipped when empty.
        p_txt4: Fourth passage; skipped when empty.
        p_txt5: Fifth passage; skipped when empty.

    Returns:
        A JSON string (4-space indent) with the ranking produced by
        ``get_human_readable_top``, or the query mapped to an empty list
        when no passage was provided.
    """
    query = f"query: {q_txt}"
    # Truthiness skips both empty strings and None values.
    passages = [
        f"passage: {p_txt}"
        for p_txt in (p_txt1, p_txt2, p_txt3, p_txt4, p_txt5)
        if p_txt
    ]

    # Nothing to rank: answer directly instead of running the model on a
    # query with no passages.
    if not passages:
        return json.dumps({query: []}, indent=4)

    input_texts = [query, *passages]

    scores = get_scores(model, tokenizer, input_texts)
    result = get_human_readable_top(scores, input_texts)
    return json.dumps(result, indent=4)
|
|
|
# Load the checkpoint once at import time; get_result reads these globals.
tokenizer, model = get_model('intfloat/e5-large-v2')

# UI: one query box, five passage boxes, a submit button and a read-only
# output box. Components are created in display order.
with gr.Blocks() as demo:
    gr.Markdown("# E5 Large V2 Demo")

    q_txt = gr.Textbox(placeholder="Enter your query", info="Query")

    p_txt1 = gr.Textbox(placeholder="Enter passage 1", info="Passage 1")
    p_txt2 = gr.Textbox(placeholder="Enter passage 2", info="Passage 2")
    p_txt3 = gr.Textbox(placeholder="Enter passage 3", info="Passage 3")
    p_txt4 = gr.Textbox(placeholder="Enter passage 4", info="Passage 4")
    p_txt5 = gr.Textbox(placeholder="Enter passage 5", info="Passage 5")
    submit = gr.Button("Submit")
    o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False)

    # Clicking submit sends the six text values to get_result and shows
    # its return value in the output box.
    submit.click(
        fn=get_result,
        inputs=[q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5],
        outputs=o_txt,
    )

demo.launch()