File size: 11,428 Bytes
8f6035e
 
 
 
385c295
 
d74ddfc
b3758b8
5fda074
 
 
163f1eb
8f6035e
5fda074
 
8f6035e
5fda074
 
 
3400476
ad0d74d
5171d49
 
 
 
 
 
 
 
 
 
 
 
 
 
5fda074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5171d49
 
5fda074
 
385c295
5fda074
 
 
 
 
385c295
 
 
5fda074
 
e71614a
8f6035e
 
 
 
 
 
 
 
ace4204
5fda074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ace4204
 
 
 
 
 
5fda074
385c295
ace4204
cf7c00e
fcbecda
 
 
ace4204
fcbecda
ace4204
 
 
385c295
ace4204
 
 
 
 
 
 
 
5fda074
ace4204
fcbecda
 
 
 
cf7c00e
fcbecda
 
 
cf7c00e
426a66e
385c295
4cd22b7
5fda074
fcbecda
ace4204
 
 
385c295
ace4204
200153d
8f6035e
 
 
 
5fda074
cf7c00e
 
 
 
 
 
 
 
 
36cbf30
5fda074
cf7c00e
ace4204
 
cf7c00e
 
e71614a
8f6035e
3400476
5fda074
 
54f1ed7
 
b2157fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import threading
import queue
import gradio as gr
import os
import json
import numpy as np


# Markdown rendered at the top of the Gradio page (passed to gr.Markdown in
# app_interface). The description must be well-formed: the "Duplicate Space"
# badge is a plain <a><img></a> link with no surrounding heading tag.
title = """
# 👋🏻Welcome to 🙋🏻‍♂️Tonic's 📽️Nvidia 🛌🏻Embed V-1 !"""

description = """
You can use this Space to test out the current model [nvidia/NV-Embed-v1](https://huggingface.co/nvidia/NV-Embed-v1). 🐣a generalist embedding model that ranks No. 1 on the Massive Text Embedding Benchmark (MTEB benchmark)(as of May 24, 2024), with 56 tasks, encompassing retrieval, reranking, classification, clustering, and semantic textual similarity tasks.
You can also use 📽️Nvidia 🛌🏻Embed V-1 by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/NV-Embed?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻  [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [MultiTonic](https://github.com/MultiTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

# Maps the task name shown in the UI dropdown to the instruction that
# compute_embeddings prepends as "Instruct: <instruction>\nQuery: <text>".
# Keys must be unique: in a dict literal a repeated key silently overwrites the
# earlier entry, so variant instructions for the same dataset get a suffix.
tasks = {
        'ArguAna': 'Given a claim, find documents that refute the claim',
        'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
        'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
        'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
        'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
        'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
        'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query',
        'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
        'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
        'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
        'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
        'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
        'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
        # The un-suffixed names keep the instruction they previously resolved
        # to (the later duplicate won); the shadowed variants are now reachable.
        'Natural Language Inference' : 'Given a premise, retrieve a hypothesis that is entailed by the premise',
        'Natural Language Inference (Similarity)' : 'Retrieve semantically similar text',
        'PAQ, MSMARCO' : 'Given a question, retrieve passages that answer the question',
        'PAQ, MSMARCO (Web Search)' : 'Given a web search query, retrieve relevant passages that answer the query',
        'SQUAD' : 'Given a question, retrieve Wikipedia passages that answer the question' ,
        'StackExchange' : 'Given a question paragraph at StackExchange, retrieve a question duplicated paragraph',
        'Natural Question' : 'Given a question, retrieve Wikipedia passages that answer the question',
        'BioASQ' : 'Given a question, retrieve detailed question descriptions that are duplicates to the given question',
        'STS12, STS22, STSBenchmark' : 'Retrieve semantically similar text.',
        'AmazonCounterfactualClassification' : 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual' , 
        'AmazonReviewsClassification' : 'Classify the given Amazon review into its appropriate rating category' , 
        'Banking77Classification' : 'Given a online banking query, find the corresponding intents',
        'EmotionClassification' : 'Classify the emotion expressed in the given Twitter message into one of the six emotions:anger, fear, joy, love, sadness, and surprise',
        'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
        'MTOPIntentClassification' : 'Classify the intent of the given utterance in task-oriented conversation',
        'ToxicConversationsClassification' : 'Classify the given comments as either toxic or not toxic',
        'TweetSentimentExtractionClassification' : 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
        'ArxivClusteringP2P' : 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
        'ArxivClusteringS2S' : 'Identify the main and secondary category of Arxiv papers based on the titles',
        'BiorxivClusteringP2P' : 'Identify the main category of Biorxiv papers based on the titles and abstracts' ,
        'BiorxivClusteringS2S' : 'Identify the main category of Biorxiv papers based on the titles',
        'MedrxivClusteringP2P' : 'Identify the main category of Medrxiv papers based on the titles and abstracts',
        'MedrxivClusteringS2S' : 'Identify the main category of Medrxiv papers based on the titles',
        'TwentyNewsgroupsClustering' : 'Identify the topic or theme of the given news articles'
}

# CUDA caching-allocator tuning; set here, before the model is moved to the GPU.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared, module-level tokenizer and model used by every request.
# NOTE(review): trust_remote_code=True runs Python code shipped with the model
# repository; consider pinning a revision.
_MODEL_ID = 'nvidia/NV-Embed-v1'
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = model.to(device)

# Request queue (items: (task_name, text) or None) and response queue read by
# embedding_worker.
embedding_request_queue, embedding_response_queue = queue.Queue(), queue.Queue()

def clear_cuda_cache():
    """Release unused memory held by PyTorch's CUDA caching allocator.

    Thin wrapper around ``torch.cuda.empty_cache()``; returns ``None``.
    """
    cuda_api = torch.cuda
    cuda_api.empty_cache()

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Select, for every row, the hidden state at its last attended position.

    Args:
        last_hidden_states: Per-token hidden states, presumably shaped
            (batch, seq_len, hidden) — TODO confirm against the model output.
        attention_mask: 1 for real tokens, 0 for padding, shape (batch, seq_len).

    Returns:
        One hidden-state vector per row.
    """
    final_column = attention_mask[:, -1]
    if bool(final_column.sum() == attention_mask.shape[0]):
        # Every row attends its final position (left padding or no padding),
        # so the last column is the last real token for all rows.
        return last_hidden_states[:, -1]
    # Right padding: assumes the 1s form a prefix, so the last real token sits
    # at index (number of attended tokens - 1).
    last_index = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_index]

def format_response(embeddings, model_name="nvidia/NV-Embed-v1", prompt_tokens=17):
    """Wrap *embeddings* in an OpenAI-style "list of embeddings" response dict.

    Args:
        embeddings: The value returned by ``compute_embeddings`` (a list of
            vectors, or an ``"Error: ..."`` string). It is stored unchanged
            under ``data[0]["embedding"]``.
        model_name: Reported in the ``"model"`` field. Defaults to the
            checkpoint this app loads (it previously reported "e5-mistral").
        prompt_tokens: Reported as both ``prompt_tokens`` and ``total_tokens``.
            The default of 17 is a fixed placeholder kept for backward
            compatibility; it is not a measured token count.

    Returns:
        The response dict.
    """
    return {
        "data": [
            {
                "embedding": embeddings,
                "index": 0,
                "object": "embedding"
            }
        ],
        "model": model_name,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": prompt_tokens
        }
    }

def embedding_worker():
    """Serve requests from ``embedding_request_queue`` until a ``None`` sentinel.

    Each non-sentinel item is expected to be a ``(selected_task, input_text)``
    tuple. Its embedding is computed, wrapped with ``format_response`` and put
    on ``embedding_response_queue``.

    ``task_done()`` is called for every item taken off the request queue,
    including the sentinel and failed requests, so ``join()`` on that queue
    cannot hang. A failed request still yields a response carrying an
    ``"Error: ..."`` string, the same convention ``compute_embeddings`` uses
    for unknown tasks. A consumer waiting on the response queue is therefore
    not left blocked, and the worker thread keeps running.
    """
    while True:
        # Block until an item is available.
        item = embedding_request_queue.get()
        try:
            if item is None:
                break
            try:
                selected_task, input_text = item
                embeddings = compute_embeddings(selected_task, input_text)
            except Exception as exc:  # thread boundary: report instead of dying
                print(f"Embedding request failed: {exc}")
                embeddings = f"Error: {exc}"
            embedding_response_queue.put(format_response(embeddings))
        finally:
            embedding_request_queue.task_done()
            clear_cuda_cache()

def compute_embeddings(selected_task, input_text, max_length=2048):
    """Embed *input_text* using the instruction registered for *selected_task*.

    The text is wrapped as ``"Instruct: <instruction>\\nQuery: <text>"``,
    tokenized with truncation to ``max_length - 1`` tokens, and followed by an
    appended EOS token. It is then run through ``model``, reduced with
    ``last_token_pool`` and L2-normalised.

    Args:
        selected_task: A key of the module-level ``tasks`` dict.
        input_text: The text to embed.
        max_length: Maximum sequence length, including the appended EOS token.

    Returns:
        A list holding one embedding (a list of floats), or an
        ``"Error: ..."`` string when *selected_task* is not in ``tasks``.
    """
    try:
        task_description = tasks[selected_task]
    except KeyError:
        print(f"Selected task not found: {selected_task}")
        return f"Error: Task '{selected_task}' not found. Please select a valid task."

    processed_texts = [f'Instruct: {task_description}\nQuery: {input_text}']

    # Leave room for the EOS token appended below; pooling reads the last token.
    batch_dict = tokenizer(processed_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
    batch_dict['input_ids'] = [ids + [tokenizer.eos_token_id] for ids in batch_dict['input_ids']]
    batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

    # Inference only. Without no_grad the forward pass records an autograd
    # graph and keeps the model's activations alive until the outputs are freed.
    with torch.no_grad():
        outputs = model(**batch_dict)
        # NOTE(review): this pooling follows the e5-mistral recipe. NV-Embed-v1's
        # remote code may return a dict of pooled "sentence_embeddings" rather
        # than an object with `last_hidden_state`; confirm against the model repo.
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    # Tensor.tolist() yields plain Python floats for any dtype (incl. bfloat16).
    embeddings_list = embeddings.cpu().tolist()
    clear_cuda_cache()
    return embeddings_list

def compute_similarity(selected_task, sentence1, sentence2, extra_sentence1, extra_sentence2):
    """Score three sentences against a focus sentence under one task instruction.

    Args:
        selected_task: A key of the module-level ``tasks`` dict.
        sentence1: The focus sentence every other sentence is compared with.
        sentence2, extra_sentence1, extra_sentence2: The sentences to compare.

    Returns:
        ``{"Similarity 1-2": float, "Similarity 1-3": float,
        "Similarity 1-4": float}`` with cosine similarities, or an
        ``"Error: ..."`` string if *selected_task* is unknown.
    """
    if selected_task not in tasks:
        print(f"Selected task not found: {selected_task}")
        return f"Error: Task '{selected_task}' not found. Please select a valid task."

    # Embed all four sentences first, then compare the focus one with the rest.
    focus_embedding = compute_embeddings(selected_task, sentence1)
    other_embeddings = [
        compute_embeddings(selected_task, sentence)
        for sentence in (sentence2, extra_sentence1, extra_sentence2)
    ]

    similarity_scores = {
        f"Similarity 1-{position}": compute_cosine_similarity(focus_embedding, embedding)
        for position, embedding in enumerate(other_embeddings, start=2)
    }
    clear_cuda_cache()
    return similarity_scores

def compute_cosine_similarity(emb1, emb2):
    """Return the cosine similarity of two embeddings as a Python float.

    Args:
        emb1, emb2: Embeddings as (nested) lists of numbers. Either a single
            row such as the ``[[...]]`` returned by ``compute_embeddings``, or a
            flat vector.

    The comparison runs in float32 on the CPU. Half precision would limit the
    score to roughly three significant digits, and a single dot product gains
    nothing from a GPU round trip.
    """
    tensor1 = torch.tensor(emb1, dtype=torch.float32)
    tensor2 = torch.tensor(emb2, dtype=torch.float32)
    # dim=-1 compares along the embedding axis for both (D,) and (1, D) inputs.
    return F.cosine_similarity(tensor1, tensor2, dim=-1).item()

def app_interface():
    """Build the Gradio UI: a task selector plus a sentence-similarity tab.

    Returns:
        The ``gr.Blocks`` instance (not yet queued or launched).
    """
    task_names = list(tasks.keys())
    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            task_dropdown = gr.Dropdown(task_names, label="Select a Task", value=task_names[0])

        with gr.Tab("Sentence Similarity"):
            sentence1_box = gr.Textbox(label="'Focus Sentence' - The 'Subject'")
            sentence2_box = gr.Textbox(label="'Input Sentence' - 1")
            extra_sentence1_box = gr.Textbox(label="'Input Sentence' - 2")
            extra_sentence2_box = gr.Textbox(label="'Input Sentence' - 3")
            similarity_button = gr.Button("Compute Similarity")
            # Label names the model actually loaded (nvidia/NV-Embed-v1).
            similarity_output = gr.Textbox(label="🐣NV-Embed🛌🏻 Similarity Scores")

            similarity_button.click(
                fn=compute_similarity,
                inputs=[task_dropdown, sentence1_box, sentence2_box, extra_sentence1_box, extra_sentence2_box],
                outputs=similarity_output
            )

    return demo

# Background thread serving items placed on embedding_request_queue; daemon so
# it does not keep the process alive on shutdown.
embedding_worker_thread = threading.Thread(target=embedding_worker, daemon=True)
embedding_worker_thread.start()

# Build the UI once: queue() must configure the same Blocks object that is
# launched. Calling app_interface() twice would queue one instance and launch
# a different, un-queued one.
demo = app_interface()
demo.queue()
demo.launch(share=True)