import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

# Mean-pool the last hidden states over non-padding tokens (the pooling
# recommended for the E5 family of embedding models).
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# article metadata (titles and full text) plus one precomputed embedding per row
df = pd.read_csv('wiki.csv')
data_embeddings = np.load("wiki-embeddings.npy")

print("loading the model...")
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
model = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
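
# "wiki-embeddings.npy" is assumed to hold L2-normalized e5 embeddings of the
# articles, precomputed offline with this same model along these lines (a
# sketch; note the "passage: " prefix E5 expects on documents, mirroring the
# "query: " prefix used for search queries below):
#
#     batch = tokenizer(["passage: " + t for t in df['text']], max_length=512,
#                       padding=True, truncation=True, return_tensors='pt')
#     with torch.no_grad():
#         emb = average_pool(model(**batch).last_hidden_state,
#                            batch['attention_mask'])
#     np.save("wiki-embeddings.npy", F.normalize(emb, p=2, dim=1).numpy())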

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="semantic search for 230k+ wikipedia articles")
    msg = gr.Textbox(label="simple wikipedia semantic search query", placeholder="for example, \"medieval battles\"")
    clear = gr.ClearButton([msg, chatbot])
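
    # Two interaction modes share one textbox: free text runs a semantic
    # search, while a bare result number fetches that article's full text.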

    def _search(message, chat_history):
        # E5 expects search queries to be prefixed with "query: "
        batch_dict = tokenizer(["query: " + message], max_length=512, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            outputs = model(**batch_dict)
            input_embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

        # normalize the embedding and convert it to a (1, dim) numpy array
        input_embedding = F.normalize(input_embedding, p=2, dim=1).numpy()

        # cosine similarity between the query and every article embedding
        cos_similarities = cosine_similarity(data_embeddings, input_embedding).flatten()

        # indices of the k most similar articles, best match first
        k = 10  # number of results to show
        top_k_idx = cos_similarities.argsort()[-k:][::-1]

        # corresponding article titles
        top_k_text = df['title'].iloc[top_k_idx].tolist()

        # each result line carries its dataframe index after " // " so that
        # _retrieve can find the article later
        bot_message = "\n".join(f"{i+1}. {top_k_text[i]} // {top_k_idx[i]}" for i in range(len(top_k_text)))

        chat_history.append((message, f"results (you can enter article number 1-{k} to see its contents):\n" + bot_message))
        return "", chat_history

    def _retrieve(message, chat_history):
        idx = int(message)
        # walk the chat history from newest to oldest to find the latest
        # results list, then match the requested result number
        for _, m in chat_history[::-1]:
            if m.startswith("results"):
                for n in m.split("\n")[1:]:
                    print(n)
                    if str(idx) == n.split(".")[0]:
                        # recover the dataframe index stored after " // "
                        df_idx = int(n.split(" // ")[-1])
                        print(df_idx)
                        article = df.iloc[df_idx]['text']
                        # put wiki-style "== Section ==" headers on their own lines
                        article = re.sub(r'(===?=?[A-Z ].+?===?=?)', r'\n\n\1\n', article)
                        chat_history.append((message, f"contents of {n}:\n{article}"))
                        return "", chat_history
        print("nothing found")
        chat_history.append((message, "🤔 article not found"))
        return "", chat_history
    
    def respond(message, chat_history):
        print(f"received input '{message}'")
        try:
            # a bare number selects a result from the previous search;
            # anything else is treated as a new search query
            int(message)
            print(f"retrieving #{message}")
            return _retrieve(message, chat_history)
        except ValueError:
            print(f"searching for {message}")
            return _search(message, chat_history)

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()