File size: 2,693 Bytes
b38ebdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ff1651
 
b38ebdd
3ff1651
 
b38ebdd
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions of each sequence.

    Args:
        model_output: model output whose first element holds the per-token
            embeddings (presumably ``(batch, seq, hidden)`` — the mask is
            broadcast to this shape, so it must be compatible).
        attention_mask: ``(batch, seq)`` mask; 1 marks real tokens, 0 padding.

    Returns:
        ``(batch, hidden)`` tensor of masked mean embeddings. A row whose mask
        is all zeros yields zeros instead of dividing by zero.
    """
    hidden = model_output[0]
    # Broadcast the mask across the hidden dimension so padded tokens contribute 0.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    # Clamp the token count to keep the division finite for empty sequences.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class Matcher:
    """Match each query sentence to its most similar target sentence.

    Sentences are embedded with a Hugging Face encoder, mean-pooled over real
    tokens and L2-normalised, so the dot product of two embeddings is their
    cosine similarity.
    """

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """Load the tokenizer and encoder model.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``; defaults to the checkpoint this demo uses.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def _encoder(self, text: list[str]):
        """Return L2-normalised sentence embeddings, one row per input string."""
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Unit-length rows turn the later matrix product into cosine similarity.
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def __call__(self, textA: list[str], textB: list[str]):
        """Find, for every string in ``textA``, the closest string in ``textB``.

        Returns:
            ``(match_inds, match_conf)``: two lists with one entry per string in
            ``textA`` — the index into ``textB`` of its best match (first index
            on ties) and the cosine similarity of that match.

        Raises:
            ValueError: if ``textA`` is non-empty but ``textB`` is empty, since
                there is nothing to match against.
        """
        if not textA:
            return [], []
        if not textB:
            raise ValueError("textB must contain at least one candidate string")
        embeddings_a = self._encoder(textA)
        embeddings_b = self._encoder(textB)
        # (len(textA), len(textB)) cosine-similarity matrix.
        sim = embeddings_a @ embeddings_b.T
        # A single reduction yields both the best score and its index.
        match_conf, match_inds = sim.max(dim=1)
        return match_inds.tolist(), match_conf.tolist()


# Lazily-created Matcher shared by every request, so the model is loaded once
# per process instead of on every button click.
_MATCHER = None


def _get_matcher():
    """Return the shared Matcher, constructing it on first use."""
    global _MATCHER
    if _MATCHER is None:
        _MATCHER = Matcher()
    return _MATCHER


def _non_blank_lines(text):
    """Split *text* into stripped lines, dropping blank/whitespace-only ones."""
    return [line.strip() for line in (text or "").splitlines() if line.strip()]


def run_match(source_text, destination_text):
    """Match each query line to its most similar target line.

    Args:
        source_text: newline-separated query sentences.
        destination_text: newline-separated candidate sentences.

    Returns:
        One ``"<query> -> <best target> (<similarity>)"`` line per non-blank
        query line, joined with newlines; ``""`` if either side has no
        non-blank lines.
    """
    # splitlines() also handles "\r\n" pastes; blank lines are not sentences.
    sources = _non_blank_lines(source_text)
    destinations = _non_blank_lines(destination_text)
    if not sources or not destinations:
        return ""
    match_inds, match_conf = _get_matcher()(sources, destinations)
    return "\n".join(
        f"{src} -> {destinations[idx]} ({conf:.2f})"
        for src, idx, conf in zip(sources, match_inds, match_conf)
    )


# Three side-by-side text boxes (queries, targets, results) plus a button that
# runs run_match on the first two and writes into the third.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Gradio takes a Textbox's initial text via `value`.
            source_text = gr.Textbox(lines=10, label="Query Text",
                                     value="diavola with extra chillies\nseafood\nmargherita")
        with gr.Column():
            dest_text = gr.Textbox(lines=10, label="Target Text",
                                   value="cheese pizza\nhot and spicy pizza\ntuna, prawn and onion pizza")
        with gr.Column():
            matches = gr.Textbox(lines=10, label="Matches")
    with gr.Row():
        # A Button's caption is its `value` (first positional argument).
        match_btn = gr.Button("Match")
        match_btn.click(fn=run_match, inputs=[source_text, dest_text], outputs=matches)

# Launch only after the Blocks context has exited and the layout is complete.
demo.launch()