File size: 1,877 Bytes
609199b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd4dded
609199b
cd4dded
 
609199b
 
 
 
 
 
 
 
 
 
 
cd4dded
609199b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select, for every row of the batch, the hidden state of its last real token.

    Args:
        last_hidden_states: Model output indexed as ``[row, position, ...]``
            (presumably ``(batch, seq_len, hidden)`` — TODO confirm).
        attention_mask: 2-D mask whose row sums give each row's count of
            non-padding tokens.

    Returns:
        One hidden-state vector per batch row.
    """
    # If the final column is unmasked for every row, the batch is left-padded
    # and the last position already holds each row's final real token.
    every_row_ends_real = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if every_row_ends_real:
        return last_hidden_states[:, -1]

    # Right padding: a row's last real token sits at (number of real tokens - 1).
    last_index = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_index]

def get_similarity_scores(queries: list, passages: list, model, tokenizer,
                          max_length: int = 4096):
    """Score every query against every passage by embedding cosine similarity.

    Queries and passages are embedded together in one batch, pooled with
    ``last_token_pool``, L2-normalised, and compared with a matrix product.

    Args:
        queries: Query strings; one result row each.
        passages: Passage strings; one result column each.
        model: Encoder called as ``model(**batch)`` whose output exposes
            ``last_hidden_state``.
        tokenizer: Tokenizer matching ``model``. Side effect: its
            ``add_eos_token`` attribute is set to True.
        max_length: Token budget per text; texts are truncated to
            ``max_length - 1`` tokens.

    Returns:
        A ``len(queries)`` x ``len(passages)`` nested list of cosine
        similarities scaled by 100.
    """
    # Nothing to compare: return the empty-shaped matrix directly rather than
    # asking the tokenizer/model to process an empty or one-sided batch.
    if not queries or not passages:
        return [[] for _ in queries]

    # Ask the tokenizer to append an EOS token to each text, so the last-token
    # pooling below reads that EOS position.
    tokenizer.add_eos_token = True

    input_texts = queries + passages
    # max_length - 1 is kept from the original; presumably headroom for EOS.
    batch_dict = tokenizer(input_texts, max_length=max_length - 1, padding=True,
                           truncation=True, return_tensors="pt")
    # Keep inputs on the same device as the model (no-op for a CPU model).
    device = getattr(model, "device", None)
    if device is not None:
        batch_dict = batch_dict.to(device)

    # Inference only: skip building the autograd graph, which would otherwise
    # hold every intermediate activation in memory.
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    embeddings = F.normalize(embeddings, p=2, dim=1)
    n_queries = len(queries)
    scores = (embeddings[:n_queries] @ embeddings[n_queries:].T) * 100
    return scores.tolist()

def _as_items(value) -> list:
    """Normalise one handler input into a list of items.

    ``None`` becomes an empty list. A string is split on newlines and commas,
    each piece is stripped, and empty pieces are dropped. Any other iterable
    is returned as a list unchanged.
    """
    if value is None:
        return []
    if isinstance(value, str):
        pieces = value.replace(',', '\n').splitlines()
        return [piece.strip() for piece in pieces if piece.strip()]
    return list(value)

def similarity_ui(keyNames, fields):
    """Gradio handler: score each key name against each field.

    Gradio ``"text"`` inputs deliver plain strings, so each argument may be a
    list of items or a string of items separated by newlines or commas. Each
    key name is wrapped in the instruction prompt
    ``Instruct: <task>\\nQuery: <key name>`` before embedding; fields are
    embedded as-is.

    Returns:
        ``{'Similarity Scores': matrix}`` where ``matrix[i][j]`` is the score of
        key name ``i`` against field ``j``.
    """
    task = 'Given a keyName, find similarity score against provided fields'
    # The task string was previously computed but unused; the embedding model
    # expects queries in this instruction format.
    queries = [f'Instruct: {task}\nQuery: {name}' for name in _as_items(keyNames)]
    passages = _as_items(fields)

    # Avoid running the model on an empty side; the result shape is known.
    if not queries or not passages:
        return {'Similarity Scores': [[] for _ in queries]}

    scores = get_similarity_scores(queries, passages, model, tokenizer)
    return {'Similarity Scores': scores}

# Load model and tokenizer once at import time; similarity_ui reads these
# module-level names when it is called.
tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral')

# Create Gradio Interface.
# Multiple inputs must be passed as a list; they map positionally onto
# similarity_ui(keyNames, fields). The handler returns a dict, which the JSON
# output component renders directly instead of as a Python repr string.
demo = gr.Interface(
    fn=similarity_ui,
    inputs=[
        gr.Textbox(label="Key Names"),
        gr.Textbox(label="Fields"),
    ],
    outputs="json",
    title="Similarity Score Calculator",
    description="Enter a Key Name and 3 Fields to find similarity scores",
)

if __name__ == "__main__":
    demo.launch()