import json

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import gradio as gr


# Instantiate the tokenizer and model from a Hugging Face checkpoint
def get_model(base_name='intfloat/e5-large-v2'):
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    model = AutoModel.from_pretrained(base_name)
    return tokenizer, model


# Mean-pool the token embeddings, ignoring padding positions
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# Score queries against passages. The first `how_many_q` entries of
# `input_texts` are treated as queries and the rest as passages; the
# return value is a (how_many_q, num_passages) tensor of cosine
# similarities scaled by 100.
def get_scores(model, tokenizer, input_texts, max_length=512, how_many_q=1):
    # Tokenize the input texts
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = model(**batch_dict)

    embeddings = average_pool(
        outputs.last_hidden_state,
        batch_dict['attention_mask']
    )

    # Normalize the embeddings so the dot product below is a cosine similarity
    embeddings = F.normalize(embeddings, p=2, dim=1)
    scores = (embeddings[:how_many_q] @ embeddings[how_many_q:].T) * 100
    return scores


# Sort the scores per query and return the top-k values and their
# indices. This function only returns scores and indices.
def get_top(scores, top_k=None):
    result = torch.sort(scores, descending=True, dim=1)
    top_indices = result.indices
    top_values = result.values
    if top_k is not None:
        top_indices = top_indices[:, :top_k]
        top_values = top_values[:, :top_k]
    return top_indices, top_values


# Like get_top, but maps each query to its top-k passages along with
# the associated passage text
def get_human_readable_top(scores, input_texts, top_k=None):
    # Queries come first in input_texts, so score column `idx`
    # corresponds to passages[idx]
    queries = [text for text in input_texts if text.startswith("query:")]
    passages = [text for text in input_texts if not text.startswith("query:")]

    top_indices, top_values = get_top(scores, top_k)

    result = {}
    for q_idx, (indices, values) in enumerate(zip(top_indices, top_values)):
        hits = []
        for idx, val in zip(indices.tolist(), values.tolist()):
            hits.append({
                "idx": idx,
                "val": round(val, 3),
                "text": passages[idx]
            })
        result[queries[q_idx]] = hits
    return result


# Gradio callback: build the query/passage list, score it, and
# pretty-print the result as JSON
def get_result(q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
    input_texts = [f"query: {q_txt}"]
    for p_txt in (p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
        if p_txt != '':
            input_texts.append(f"passage: {p_txt}")

    scores = get_scores(model, tokenizer, input_texts)
    result = get_human_readable_top(scores, input_texts)
    return json.dumps(result, indent=4)


tokenizer, model = get_model('intfloat/e5-large-v2')

with gr.Blocks() as demo:
    gr.Markdown("# E5 Large V2 Demo")

    q_txt = gr.Textbox(placeholder="Enter your query", label="Query")
    p_txt1 = gr.Textbox(placeholder="Enter passage 1", label="Passage 1")
    p_txt2 = gr.Textbox(placeholder="Enter passage 2", label="Passage 2")
    p_txt3 = gr.Textbox(placeholder="Enter passage 3", label="Passage 3")
    p_txt4 = gr.Textbox(placeholder="Enter passage 4", label="Passage 4")
    p_txt5 = gr.Textbox(placeholder="Enter passage 5", label="Passage 5")

    submit = gr.Button("Submit")

    # Output area for the JSON result; defined before the click handler
    # that writes to it
    o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False)

    submit.click(
        get_result,
        [q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5],
        o_txt
    )

demo.launch()
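
# Usage sketch: driving the scoring pipeline without the Gradio UI.
# The query/passage texts below are illustrative placeholders, and the
# shapes in the comments assume one query and two passages. Left
# commented out because demo.launch() above blocks until shutdown.
#
#   input_texts = [
#       "query: how much protein should a female eat",
#       "passage: As a general guideline, about 46 grams of protein "
#       "per day is recommended for adult women.",
#       "passage: Mount Everest is the highest mountain above sea level.",
#   ]
#   scores = get_scores(model, tokenizer, input_texts)  # shape: (1, 2)
#   print(get_human_readable_top(scores, input_texts, top_k=2))
#
# The on-topic passage should receive a noticeably higher score than
# the off-topic one; exact values depend on the model weights.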