File size: 3,434 Bytes
f61780d cee8f80 81d489e f61780d b9334e6 f61780d 8c7622b b9334e6 f61780d 7fc9b1c f61780d 2d61b6b f61780d b9334e6 f61780d 7fc9b1c f61780d 2d61b6b f61780d 7fc9b1c c0e5a43 23b16d0 c0e5a43 cee8f80 c0e5a43 23b16d0 c0e5a43 7fc9b1c 0569929 7fc9b1c 2e56f36 7fc9b1c 2466cd5 2e56f36 7fc9b1c f61780d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Average the token vectors that the attention mask marks as real.

    Positions where the mask is 0 are zeroed out before summing over dim 1,
    and the sum is divided by the per-row count of unmasked tokens.

    Args:
        last_hidden_states: token embeddings; expected to be
            (batch, seq, hidden) — the mask is broadcast over the last dim.
        attention_mask: 1 for real tokens, 0 for padding; expected to be
            (batch, seq).

    Returns:
        The pooled embeddings, one vector per batch row.
    """
    padding = attention_mask.unsqueeze(-1).bool().logical_not()
    summed = last_hidden_states.masked_fill(padding, 0.0).sum(dim=1)
    # Number of real tokens per row, with a trailing axis so it broadcasts
    # against the summed embeddings.
    counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / counts
# Article corpus (the app reads its 'title' and 'text' columns) and the
# precomputed article embeddings. Row i of the .npy file is assumed to be the
# embedding of row i of the CSV — TODO confirm the two files stay aligned.
df = pd.read_csv('wiki.csv')
data_embeddings = np.load("wiki-embeddings.npy")

print("loading the model...")
_MODEL_ID = 'intfloat/multilingual-e5-large'
# Tokenizer and encoder must come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModel.from_pretrained(_MODEL_ID)
# Number of hits listed per search. The result list invites the user to type
# back a number in 1..TOP_K, and `respond` treats only those as selections.
TOP_K = 10
# Section headings such as "== History ==" (2-4 '=' on each side); they are
# padded with blank lines so they stand out in the chat window.
_HEADING_RE = re.compile(r'(===?=?[A-Z ].+?===?=?)')

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="semantic search for 230k+ wikipedia articles")
    msg = gr.Textbox(label="simple wikipedia semantic search query", placeholder="for example, \"medieval battles\"")
    clear = gr.ClearButton([msg, chatbot])

    def _search(message, chat_history):
        """Embed `message` as an E5 query and list the TOP_K closest articles.

        Appends (message, reply) to `chat_history`. The reply's first line
        starts with "results"; each following line is
        "<rank>. <title> // <row position in df>". `_retrieve` parses this
        exact format back, so the two must stay in sync.

        Returns ("", chat_history) to clear the textbox and refresh the chat.
        """
        # E5 checkpoints expect the "query: " prefix on search queries.
        batch_dict = tokenizer(["query: " + message], max_length=512, padding=True,
                               truncation=True, return_tensors='pt')
        # Inference only: skip autograd bookkeeping (memory and time).
        with torch.no_grad():
            outputs = model(**batch_dict)
            input_embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            input_embedding = F.normalize(input_embedding, p=2, dim=1)
        query = input_embedding.cpu().numpy()  # one row: the query embedding

        cos_similarities = cosine_similarity(data_embeddings, query).flatten()
        # Positions of the TOP_K highest similarities, most similar first.
        top_k_idx = cos_similarities.argsort()[-TOP_K:][::-1]
        top_k_text = df['title'].iloc[top_k_idx].tolist()

        bot_message = "\n".join(
            f"{rank}. {title} // {df_idx}"
            for rank, (title, df_idx) in enumerate(zip(top_k_text, top_k_idx), start=1)
        )
        chat_history.append((message, f"results (you can enter article number 1-{TOP_K} to see its contents):\n" + bot_message))
        return "", chat_history

    def _retrieve(message, chat_history):
        """Append the full text of article number `message` to the chat.

        Scans `chat_history` newest-first for bot replies starting with
        "results" (written by `_search`) and uses the first line whose rank
        equals `message`. If none matches, appends "🤔 article not found".

        Raises:
            ValueError: if `message` is not an integer string.
        """
        wanted = str(int(message))
        for _, bot_reply in reversed(chat_history):
            if not isinstance(bot_reply, str) or not bot_reply.startswith("results"):
                continue
            for line in bot_reply.split("\n")[1:]:
                if line.split(".")[0] != wanted:
                    continue
                # The df row position is the last " // "-separated field, so
                # titles that contain " // " still parse correctly.
                df_idx = int(line.split(" // ")[-1])
                article = df.iloc[df_idx]['text']
                if not isinstance(article, str):  # empty CSV cells load as NaN
                    article = ""
                article = _HEADING_RE.sub(r'\n\n\1\n', article)
                chat_history.append((message, f"contents of {line}:\n{article}"))
                return "", chat_history
        print("nothing found")
        chat_history.append((message, "🤔 article not found"))
        return "", chat_history

    def respond(message, chat_history):
        """Route one submitted textbox value.

        An integer in 1..TOP_K selects an article from the latest result list.
        Any other non-blank text, including numbers such as years ("1984"),
        is run as a semantic search. Blank input is ignored.

        Returns ("", chat_history) for the (textbox, chatbot) outputs.
        """
        print(f"received input '{message}'")
        if chat_history is None:
            chat_history = []
        if not message or not message.strip():
            return "", chat_history
        try:
            number = int(message)
        except ValueError:
            number = None
        if number is not None and 1 <= number <= TOP_K:
            print(f"retrieving #{message}")
            return _retrieve(message, chat_history)
        print(f"searching for {message}")
        return _search(message, chat_history)

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
# Start the Gradio server for the UI defined above.
demo.launch()