import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
from gradio_client import Client
from functools import lru_cache
# Cache the model and tokenizer using lru_cache
@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    model_name = "./all-MiniLM-L6-v2"  # Replace with your Space and model path
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model
# Load the model and tokenizer
tokenizer, model = load_model_and_tokenizer()
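# Because of lru_cache, any later call returns the same cached
# (tokenizer, model) pair instead of reloading the weights from disk:
# assert load_model_and_tokenizer() is load_model_and_tokenizer()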
# Context labels (their embeddings are precomputed below)
labels = [
    "aerospace", "anatomy", "anthropology", "art",
    "automotive", "blockchain", "biology", "chemistry",
    "cryptocurrency", "data science", "design", "e-commerce",
    "education", "engineering", "entertainment", "environment",
    "fashion", "finance", "food commerce", "gaming",
    "healthcare", "history", "information technology",
    "legal", "machine learning", "marketing", "medicine",
    "music", "philosophy", "physics", "politics", "real estate", "retail",
    "robotics", "social media", "sports", "technical",
    "tourism", "travel"
]
# Tone labels (not yet used below)
tones = [
    "formal", "positive", "negative", "poetic", "polite", "subtle", "casual", "neutral",
    "informal", "pompous", "sustained", "rude",
    "sophisticated", "playful", "serious", "friendly"
]
styles = [
    "poetry", "novel", "theater", "slang", "speech", "keywords", "html", "programming"
]
gender_number = [
    "masculine singular", "masculine plural", "feminine singular", "feminine plural"
]
def precompute_label_embeddings():
    inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy()  # Mean pooling for embeddings
label_embeddings = precompute_label_embeddings()
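# Optional refinement (a sketch, not used by the pipeline above): plain
# .mean(dim=1) also averages over padding positions. Sentence-transformer
# checkpoints like all-MiniLM-L6-v2 are typically pooled with the attention
# mask so that padding tokens are excluded from the average.
def masked_mean_pooling(last_hidden_state, attention_mask):
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # zero out padded positions
    counts = mask.sum(dim=1).clamp(min=1e-9)         # number of real tokens
    return (summed / counts).numpy()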
# Softmax function to convert scores to probabilities
def softmax(x):
    exp_x = np.exp(x - np.max(x))  # Subtract max for numerical stability
    return exp_x / exp_x.sum()
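# Quick sanity check with made-up similarity scores: the output is a
# probability distribution (non-negative, sums to 1), e.g.
# softmax(np.array([0.2, 0.5, 0.1])) -> approx. [0.307, 0.415, 0.278]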
# Function to detect context
def detect_context(input_text, threshold=0.03):
    # Encode the input text
    inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    input_embedding = outputs.last_hidden_state.mean(dim=1).numpy()  # Mean pooling for embedding
    # Compute cosine similarities
    similarities = cosine_similarity(input_embedding, label_embeddings)[0]
    # Apply softmax to convert similarities to probabilities
    probabilities = softmax(similarities)
    # Pair each label with its probability
    label_probabilities = list(zip(labels, probabilities))
    # Filter contexts with confidence >= threshold
    high_confidence_contexts = [(label, score) for label, score in label_probabilities if score >= threshold]
    # If no contexts meet the threshold, default to "general"
    if not high_confidence_contexts:
        high_confidence_contexts = [("general", 1.0)]  # Assign a default score of 1.0 for "general"
    return high_confidence_contexts
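# Example call (labels and scores below are illustrative, not from a real run):
# detect_context("Train a neural network to classify images")
# -> [("machine learning", 0.041), ("data science", 0.035)]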
# Mock translation clients for different contexts
def get_translation_client(context):
    """
    Returns the appropriate Hugging Face Space client for the given context.
    For now, all contexts use the same mock Space.
    """
    return Client("Frenchizer/space_18")  # Replace with actual Space paths for each context
def translate_text(input_text, context):
    """
    Translates the input text using the appropriate model for the given context.
    """
    client = get_translation_client(context)
    return client.predict(input_text)
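# Note: depending on how the target Space exposes its endpoints, gradio_client
# may need an explicit endpoint name, e.g.
#     client.predict(input_text, api_name="/predict")
# "/predict" is the common default route for gr.Interface apps, but whether
# this particular Space uses it is an assumption.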
def process_request(input_text):
    # Step 1: Detect context
    context_results = detect_context(input_text)
    # Step 2: Translate the text for each context
    translations = {}
    for context, score in context_results:
        translations[context] = translate_text(input_text, context)
    # Step 3: Log the high-confidence contexts and translations
    print("High-confidence contexts (score >= threshold):", context_results)
    print("Translations:", translations)
    # Return the translations and contexts
    return translations, context_results
def gradio_interface(input_text):
    # Call process_request to get translations and context results
    translations, contexts = process_request(input_text)
    # Extract only the translation values from the dictionary
    translation_values = list(translations.values())
    # Join the translations into a single string with line breaks
    output = "\n".join(translation_values)
    return output.strip()
# Create the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with context detection and threshold."
)
interface.launch()
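# When running locally rather than on Spaces, launch() accepts extra options,
# e.g. interface.launch(share=True) to expose a temporary public URL.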