# space_17 / app.py — Frenchizer (commit dcd32ae, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
from gradio_client import Client
from functools import lru_cache
# Cache the (tokenizer, model) pair so every caller shares one instance.
@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    """Load the local sentence-embedding tokenizer and model.

    Returns:
        tuple: (tokenizer, model) loaded from ./all-MiniLM-L6-v2.
    """
    path = "./all-MiniLM-L6-v2"  # Replace with your Space and model path
    return AutoTokenizer.from_pretrained(path), AutoModel.from_pretrained(path)

# Load once at import time; lru_cache makes repeat calls free.
tokenizer, model = load_model_and_tokenizer()
# Precompute label embeddings
# Domain/context labels: the input text is matched against these via
# cosine similarity of mean-pooled MiniLM embeddings (see detect_context).
labels = [
    "aerospace", "anatomy", "anthropology", "art",
    "automotive", "blockchain", "biology", "chemistry",
    "cryptocurrency", "data science", "design", "e-commerce",
    "education", "engineering", "entertainment", "environment",
    "fashion", "finance", "food commerce", "gaming",
    "healthcare", "history", "information technology",
    "legal", "machine learning", "marketing", "medicine",
    "music", "philosophy", "physics", "politics", "real estate", "retail",
    "robotics", "social media", "sports", "technical",
    "tourism", "travel"
]
# Supported tone labels for translation requests.
# Fix: the original list contained "sustained" twice; the duplicate is removed.
tones = [
    "formal", "positive", "negative", "poetic", "polite", "subtle", "casual",
    "neutral", "informal", "pompous", "sustained", "rude",
    "sophisticated", "playful", "serious", "friendly",
]
# Writing/output styles a translation request may target.
# NOTE(review): not referenced elsewhere in this chunk — presumably consumed
# by another part of the app; verify against callers.
styles = [
    "poetry", "novel", "theater", "slang", "speech", "keywords", "html", "programming"
]
# French grammatical agreement options (gender x number) for translations.
# NOTE(review): not referenced elsewhere in this chunk — verify against callers.
gender_number = [
    "masculine singular", "masculine plural", "feminine singular", "feminine plural"
]
@lru_cache(maxsize=1)
def precompute_label_embeddings():
    """Embed every context label once.

    Returns:
        numpy.ndarray: one mean-pooled embedding row per entry in `labels`.
    """
    encoded = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Mean pooling over the token axis -> a single vector per label.
    return hidden.mean(dim=1).numpy()

label_embeddings = precompute_label_embeddings()
# Turn raw similarity scores into a probability distribution.
def softmax(x):
    """Numerically stable softmax over a 1-D score array."""
    shifted = np.exp(x - np.max(x))  # subtract the max so exp cannot overflow
    return shifted / np.sum(shifted)
# Classify the input text against the precomputed context labels.
def detect_context(input_text, threshold=0.03):
    """Return (label, probability) pairs whose softmax score meets `threshold`.

    Falls back to [("general", 1.0)] when no label clears the threshold.
    """
    # Embed the input the same way the labels were embedded: mean pooling.
    encoded = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    text_embedding = hidden.mean(dim=1).numpy()

    # Cosine similarity against every label, then normalize to probabilities.
    scores = cosine_similarity(text_embedding, label_embeddings)[0]
    probabilities = softmax(scores)

    confident = [
        (label, prob)
        for label, prob in zip(labels, probabilities)
        if prob >= threshold
    ]
    # Default bucket (score 1.0) when nothing is confident enough.
    return confident if confident else [("general", 1.0)]
# Mock translation clients for different contexts
def get_translation_client(context):
    """
    Returns the appropriate Hugging Face Space client for the given context.
    For now, all contexts use the same mock space.

    NOTE(review): `context` is currently unused — every context routes to the
    same Space; presumably a per-context mapping is planned. A new Client is
    constructed on every call; consider caching if call volume grows.
    """
    return Client("Frenchizer/space_18")  # Replace with actual Space paths for each context
def translate_text(input_text, context):
    """
    Translates the input text using the appropriate model for the given context.
    """
    return get_translation_client(context).predict(input_text)
def process_request(input_text):
    """Detect contexts for `input_text` and translate it once per context.

    Returns:
        tuple: (translations, context_results) where translations maps each
        detected context to its translated text and context_results is the
        list of (context, score) pairs from detect_context.
    """
    # Step 1: Detect high-confidence contexts.
    context_results = detect_context(input_text)
    # Step 2: Translate the text once for each detected context.
    translations = {
        context: translate_text(input_text, context)
        for context, _score in context_results
    }
    # Step 3: Log the detected contexts and translations.
    # Bug fix: the old message hardcoded "score >= 0.022", which contradicted
    # detect_context's actual default threshold of 0.03.
    print("High-confidence contexts:", context_results)
    print("Translations:", translations)
    return translations, context_results
def gradio_interface(input_text):
    """Gradio entry point: translate `input_text` and join all results."""
    translations, _contexts = process_request(input_text)
    # One translation per line, with surrounding whitespace trimmed.
    return "\n".join(translations.values()).strip()
# Create the Gradio interface
# Single text-in / text-out UI wired to gradio_interface above.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with context detection and threshold."
)
# Start the app (blocking call on a Hugging Face Space).
interface.launch()