# Hugging Face Spaces page header captured with this source ("Spaces: Sleeping"); not part of the program.
import gradio as gr | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Checkpoint id of the sentence-embedding variant of swissBERT, resolved by
# the transformers `from_pretrained` loaders below.
model_name = "jgrosjean-mathesis/sentence-swissbert"
# Tokenizer and encoder are loaded from the same checkpoint id so their
# vocabularies agree.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def generate_sentence_embedding(sentence, language): | |
if "de" in language: | |
model.set_default_language("de_CH") | |
if "fr" in language: | |
model.set_default_language("fr_CH") | |
if "it" in language: | |
model.set_default_language("it_CH") | |
if "rm" in language: | |
model.set_default_language("rm_CH") | |
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512) | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
token_embeddings = outputs.last_hidden_state | |
attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float() | |
sum_embeddings = torch.sum(token_embeddings * attention_mask, 1) | |
sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9) | |
embedding = sum_embeddings / sum_mask | |
return embedding | |
def calculate_cosine_similarities(
    source_sentence,
    source_language,
    target_sentence_1,
    target_language_1,
    target_sentence_2,
    target_language_2,
    target_sentence_3,
    target_language_3,
):
    """Score up to three target sentences against a source sentence.

    Each sentence is embedded with ``generate_sentence_embedding`` in its own
    language. Each target is then compared to the source with cosine
    similarity.

    Args:
        source_sentence, source_language: The reference text and its code.
        target_sentence_N, target_language_N: Optional targets. Slots whose
            sentence is empty or whitespace-only are skipped, together with
            their language.

    Returns:
        str: One ``"<sentence>: <score>"`` line per non-blank target, sorted
        by descending score (ties keep input order). Every line is
        newline-terminated, and the top line is wrapped in ``**`` markers.
        Returns ``""`` when every target slot is blank.
    """
    # The UI advertises "up to three" targets, so skip blank slots instead of
    # failing on their unselected (None) language.
    targets = [
        (sentence, language)
        for sentence, language in (
            (target_sentence_1, target_language_1),
            (target_sentence_2, target_language_2),
            (target_sentence_3, target_language_3),
        )
        if sentence and sentence.strip()
    ]
    if not targets:
        return ""

    source_embedding = generate_sentence_embedding(source_sentence, source_language)

    # Keep (sentence, score) pairs instead of a dict keyed by sentence, so
    # identical target sentences each keep their own line and score.
    scored = []
    for sentence, language in targets:
        target_embedding = generate_sentence_embedding(sentence, language)
        score = cosine_similarity(source_embedding, target_embedding)[0][0]
        scored.append((sentence, score))

    # sorted() is stable even with reverse=True, so equal scores keep input order.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)

    lines = [sentence + ": " + str(score) for sentence, score in ranked]
    # Highlight the best match with Markdown bold markers.
    # NOTE(review): the app's gr.Textbox output shows these markers literally;
    # a Markdown output component would render them.
    lines[0] = "**" + lines[0] + "**"
    return "".join(line + "\n" for line in lines)
def main():
    """Assemble the Gradio similarity demo and launch it with a share link.

    Inputs are a source sentence and language, followed by three
    (sentence, language) target pairs. These feed
    ``calculate_cosine_similarities``, whose text result is shown in a
    three-line textbox.
    """

    def language_dropdown(label):
        # A fresh choices list per dropdown, exactly as the literal-per-call
        # construction did.
        return gr.Dropdown(["de", "fr", "it", "rm"], label=label)

    input_components = [
        gr.Textbox(lines=1, placeholder="Enter source sentence", label="Source Sentence"),
        language_dropdown("Source Language"),
    ]
    for slot in range(1, 4):
        input_components.append(
            gr.Textbox(
                lines=1,
                placeholder=f"Enter target sentence {slot}",
                label=f"Target Sentence {slot}",
            )
        )
        input_components.append(language_dropdown(f"Target Language {slot}"))

    # One example row: a sentence in each of the four supported languages.
    example_row = [
        "Der Zug fährt um 9 Uhr in Zürich ab.", "de",
        "Le train arrive à Lausanne à 11 heures.", "fr",
        "Alla stazione di Lugano ci sono diversi binari.", "it",
        "A Cuera van biars trens ellas muntognas.", "rm",
    ]

    demo = gr.Interface(
        fn=calculate_cosine_similarities,
        inputs=input_components,
        outputs=gr.Textbox(label="Cosine Similarity Scores", type="text", lines=3),
        title="Sentence Similarity Calculator",
        description="Enter a source sentence and up to three target sentences to calculate their cosine similarity.",
        examples=[example_row],
    )
    demo.launch(share=True)
if __name__ == "__main__":
    # Start the web demo only when executed as a script, not when imported.
    main()