# translation-app / app.py
# Hugging Face Space by Avanish3412 -- commit "Update app.py" (54c896d, verified)
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
from functools import lru_cache
import time
from typing import List, Optional
import os
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="English to Telugu Translation API",
    description="Ultra-high-performance translation service",
    version="2.0.0",
)

# CORS: every origin, method and header is allowed. Credentials are switched
# off because wildcard origins must not be combined with credentialed
# requests, and no endpoint in this file reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state; assigned by the startup handler (load_models).
translator = None
device = None
executor = None
# Pydantic models
class TranslationRequest(BaseModel):
    # Request body of POST /translate/.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays as it was.)

    # Text to translate: between 1 and 5000 characters.
    text: str = Field(..., max_length=5000, min_length=1)

    # Number of sentences sent to the model per batch (1-64, default 32).
    batch_size: Optional[int] = Field(32, ge=1, le=64)

    # Requested length limit (1-256, default 128); the endpoint and the
    # translator apply their own lower caps on top of this value.
    max_length: Optional[int] = Field(128, ge=1, le=256)
class TranslationResponse(BaseModel):
    # Response body of POST /translate/.

    # The text exactly as it was submitted.
    original_text: str

    # The translation (or an error message when translation failed).
    translated_text: str

    # Seconds spent translating; 0.0 for blank input and on errors.
    processing_time: float

    # Label describing which model (or "none" / "error") produced the result.
    model_used: str
# Regex-based sentence splitting (standard library only)
@lru_cache(maxsize=10000)
def ultra_fast_sentence_split(text: str) -> tuple:
    """Split *text* into a tuple of cleaned sentences.

    A boundary is either whitespace that follows ``.``, ``!`` or ``?`` and
    precedes an ASCII capital letter, or a run of newlines that follows one
    of those marks. Fragments of one character or less are discarded, and a
    full stop is appended to any fragment that does not already end in
    ``.``, ``!``, ``?``, ``:`` or ``;``.

    Results are memoised per input string, which is why a tuple (immutable)
    is returned instead of a list.
    """
    boundary = r'(?<=[.!?])\s+(?=[A-Z])|(?<=[.!?])\s*\n+\s*'
    closing_marks = ('.', '!', '?', ':', ';')

    fragments = (piece.strip() for piece in re.split(boundary, text.strip()))
    return tuple(
        fragment if fragment.endswith(closing_marks) else fragment + '.'
        for fragment in fragments
        if len(fragment) > 1
    )
class LightningFastTranslator:
    """Batch translator wrapping a seq2seq model and its tokenizer.

    The constructor moves the model to ``device``, puts it in eval mode,
    applies the CUDA-only optimisations (FP16, ``torch.compile``) on a
    best-effort basis and then runs a short warm-up translation.
    """

    # Upper bound applied to both the tokenizer truncation length and the
    # generation length, whatever ``max_length`` the caller asks for.
    _TOKEN_CAP = 64

    def __init__(self, model, tokenizer, device):
        """Store the tokenizer/device and prepare ``model`` for inference.

        Args:
            model: model exposing ``to``, ``eval`` and ``generate``.
            tokenizer: callable tokenizer exposing ``batch_decode``,
                ``pad_token_id`` and ``eos_token_id``.
            device: ``torch.device`` the model and inputs are placed on.
        """
        self.tokenizer = tokenizer
        self.device = device

        # Move model to device and switch to inference mode
        self.model = model.to(device).eval()

        # GPU optimizations. Each step is optional: a failure is logged and
        # the model is used as it was before that step.
        # NOTE(review): the indentation of the original file was lost, so the
        # torch.compile step is assumed to belong to the CUDA branch together
        # with the FP16 conversion -- confirm against the deployed version.
        if device.type == 'cuda':
            try:
                self.model = self.model.half()
                logger.info("✅ Model converted to FP16 (half precision)")
            except Exception as e:
                logger.warning(f"Half precision failed: {e}")

            try:
                # Requires PyTorch 2.0+
                self.model = torch.compile(self.model, mode="max-autotune")
                logger.info("✅ Model compiled with torch.compile")
            except Exception as e:
                logger.info(f"Torch compile not available: {e}")

        # Run one small batch so the first real request does not pay the
        # first-call cost.
        self._warmup_cache()

    def _warmup_cache(self):
        """Translate two short sentences once; failures are only logged."""
        warmup_texts = ["Hello world.", "This is a test sentence for warmup."]
        try:
            self.translate_batch_lightning(warmup_texts, max_length=64)
            logger.info("✅ Model warmed up successfully")
        except Exception as e:
            logger.warning(f"Warmup failed: {e}")

    def translate_batch_lightning(self, sentences: List[str], max_length: int = 128) -> List[str]:
        """Translate ``sentences`` in a single greedy-decoding batch.

        Args:
            sentences: input strings. Entries that have at most one character
                after stripping are not sent to the model.
            max_length: requested token limit; the effective limit for both
                tokenization and generation is ``min(max_length, 64)``.

        Returns:
            A list with the same length as ``sentences``. Skipped entries map
            to ``""``. If tokenization, generation or decoding raises, the
            error is logged and the untranslated inputs are returned instead.
        """
        if not sentences:
            return []

        # Keep (original position, stripped text) for every usable sentence
        # so the translations can be written back to the right slots.
        kept = [
            (position, stripped)
            for position, stripped in enumerate(s.strip() for s in sentences)
            if len(stripped) > 1
        ]
        if not kept:
            return [""] * len(sentences)

        positions = [position for position, _ in kept]
        valid_sentences = [stripped for _, stripped in kept]
        token_cap = min(max_length, self._TOKEN_CAP)
        use_amp = self.device.type == 'cuda'

        try:
            inputs = self.tokenizer(
                valid_sentences,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=token_cap,
                return_attention_mask=True,
            )
            inputs = {k: v.to(self.device, non_blocking=True) for k, v in inputs.items()}

            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # mixed precision stays enabled on CUDA only.
            with torch.no_grad(), torch.autocast(device_type=self.device.type, enabled=use_amp):
                outputs = self.model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_length=token_cap,
                    min_length=5,
                    # Greedy search. `early_stopping` is a beam-search-only
                    # option, so it is not passed here.
                    num_beams=1,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                    num_return_sequences=1,
                )

            translations = self.tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )

            # Map results back to original positions
            result = [""] * len(sentences)
            for position, translation in zip(positions, translations):
                result[position] = translation.strip()
            return result
        except Exception as e:
            # Best-effort fallback: keep the traceback in the log and hand
            # back the untranslated text (as a copy, so the caller's list is
            # never aliased).
            logger.exception("Lightning translation error: %s", e)
            return list(sentences)
@app.on_event("startup")
async def load_models():
    """Startup hook: choose a device, create the worker pool, load the model.

    Assigns the module-level ``device``, ``executor`` and ``translator``.
    Any failure while loading the tokenizer/model or building the translator
    is logged with its traceback and re-raised.
    """
    global translator, device, executor

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    logger.info("🚀 Loading ultra-fast translation models...")

    # Detect best device
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f"📱 Using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device("cpu")
        logger.info("📱 Using CPU (consider GPU for better performance)")

    # Thread pool the /translate/ endpoint hands its blocking work to
    executor = ThreadPoolExecutor(max_workers=6, thread_name_prefix="translation")

    try:
        model_name = "aryaumesh/english-to-telugu"
        logger.info(f"📦 Loading model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            model_max_length=128,
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,
            low_cpu_mem_usage=True,
            device_map=None,  # the translator moves the model to `device` itself
        )

        translator = LightningFastTranslator(model, tokenizer, device)

        load_time = time.perf_counter() - start_time
        logger.info(f"✅ Models loaded and optimized in {load_time:.2f} seconds")
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("❌ Error loading models: %s", e)
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook: wait for the worker pool to finish, then log completion."""
    global executor
    logger.info("🔄 Shutting down...")
    if executor:
        # wait=True lets in-flight translations complete before exit.
        executor.shutdown(wait=True)
    logger.info("✅ Shutdown complete")
def process_translation_lightning(text: str, batch_size: int = 32, max_length: int = 128) -> tuple:
    """Translate ``text`` line by line and time the work.

    Each input line is stripped, split into sentences, translated in batches
    of ``batch_size`` by the module-level ``translator`` and re-joined with
    single spaces; empty translations are dropped. Lines are re-joined with
    newlines, so blank lines survive as blank lines.

    Returns:
        ``(translated_text, seconds)``; ``("", 0.0)`` when ``text`` is blank.
    """
    started = time.perf_counter()

    if not text.strip():
        return "", 0.0

    output_lines = []
    for raw_line in text.split('\n'):
        content = raw_line.strip()
        # A blank line, or one that yields no sentences, stays an empty line.
        pieces = list(ultra_fast_sentence_split(content)) if content else []
        if not pieces:
            output_lines.append("")
            continue

        translated = []
        for offset in range(0, len(pieces), batch_size):
            translated += translator.translate_batch_lightning(
                pieces[offset:offset + batch_size], max_length
            )

        output_lines.append(" ".join(part for part in translated if part))

    joined = "\n".join(output_lines)
    elapsed = time.perf_counter() - started
    return joined, elapsed
@app.post("/translate/", response_model=TranslationResponse)
async def translate_text(request: TranslationRequest):
    """Lightning-fast translation endpoint"""
    # Blank input short-circuits without touching the model.
    if not request.text.strip():
        return TranslationResponse(
            original_text=request.text,
            translated_text="",
            processing_time=0.0,
            model_used="none",
        )

    try:
        # Inside a coroutine the running loop is the one to use;
        # get_event_loop() is the legacy spelling.
        loop = asyncio.get_running_loop()
        # The translation is blocking work, so it runs on the module-level
        # thread pool to keep the event loop responsive.
        translation, processing_time = await loop.run_in_executor(
            executor,
            process_translation_lightning,
            request.text,
            min(request.batch_size or 32, 64),    # Cap batch size
            min(request.max_length or 128, 128),  # Cap length for speed
        )
        return TranslationResponse(
            original_text=request.text,
            translated_text=translation,
            processing_time=processing_time,
            model_used="aryaumesh/english-to-telugu-lightning",
        )
    except Exception as e:
        # Failures are still reported in a normal 200 response, with the
        # message in `translated_text` and model_used="error"; the traceback
        # now goes to the log.
        logger.exception("Translation error: %s", e)
        return TranslationResponse(
            original_text=request.text,
            translated_text=f"Translation error: {str(e)}",
            processing_time=0.0,
            model_used="error",
        )
# Static single-page UI served at "/". It never changes, so it is built once
# at import time rather than on every request. The page posts JSON to
# /translate/ and renders the returned fields.
_INDEX_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>⚡ Lightning-Fast English to Telugu Translation</title>
  <style>
    * { margin: 0; padding: 0; box-sizing: border-box; }
    body {
      font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
      background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 50%, #fecfef 100%);
      min-height: 100vh;
      padding: 20px;
    }
    .container {
      max-width: 900px;
      margin: 0 auto;
      background: white;
      border-radius: 20px;
      box-shadow: 0 25px 50px rgba(0,0,0,0.15);
      overflow: hidden;
    }
    .header {
      background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
      color: white;
      padding: 40px;
      text-align: center;
      position: relative;
    }
    .header::before {
      content: '⚡';
      position: absolute;
      top: 20px;
      left: 30px;
      font-size: 3em;
      opacity: 0.3;
    }
    .header h1 { font-size: 2.8em; margin-bottom: 15px; font-weight: 700; }
    .header p { font-size: 1.2em; opacity: 0.95; margin-bottom: 15px; }
    .speed-badge {
      display: inline-block;
      background: rgba(255,255,255,0.25);
      padding: 8px 20px;
      border-radius: 25px;
      font-weight: bold;
      font-size: 1.1em;
      backdrop-filter: blur(10px);
    }
    .content { padding: 50px; }
    .form-group { margin-bottom: 30px; }
    label {
      display: block;
      margin-bottom: 12px;
      font-weight: 700;
      color: #333;
      font-size: 1.2em;
    }
    textarea {
      width: 100%;
      height: 150px;
      padding: 20px;
      border: 3px solid #e0e0e0;
      border-radius: 15px;
      font-size: 16px;
      font-family: inherit;
      resize: vertical;
      transition: all 0.3s ease;
      background: #fafafa;
    }
    textarea:focus {
      outline: none;
      border-color: #667eea;
      box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.15);
      background: white;
    }
    .controls {
      display: flex;
      gap: 20px;
      align-items: center;
      flex-wrap: wrap;
      margin-bottom: 30px;
      padding: 25px;
      background: #f8f9fa;
      border-radius: 15px;
    }
    .control-group {
      display: flex;
      flex-direction: column;
      gap: 8px;
    }
    .control-group label {
      font-size: 1em;
      margin-bottom: 0;
      font-weight: 600;
    }
    .control-group input {
      padding: 12px 16px;
      border: 2px solid #e0e0e0;
      border-radius: 10px;
      width: 120px;
      font-size: 16px;
    }
    button {
      background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
      color: white;
      padding: 18px 35px;
      border: none;
      border-radius: 12px;
      cursor: pointer;
      font-size: 1.2em;
      font-weight: 700;
      transition: all 0.3s ease;
      min-width: 200px;
      box-shadow: 0 5px 15px rgba(102, 126, 234, 0.3);
    }
    button:hover {
      transform: translateY(-3px);
      box-shadow: 0 15px 30px rgba(102, 126, 234, 0.4);
    }
    button:disabled {
      opacity: 0.7;
      cursor: not-allowed;
      transform: none;
    }
    .result {
      margin-top: 40px;
      padding: 30px;
      background: linear-gradient(135deg, #f8f9ff 0%, #e8f4fd 100%);
      border-radius: 15px;
      border: 2px solid #667eea;
    }
    .result h3 {
      color: #333;
      margin-bottom: 20px;
      font-size: 1.5em;
      font-weight: 700;
    }
    .translated-text {
      background: white;
      padding: 25px;
      border-radius: 12px;
      border: 1px solid #e0e0e0;
      font-size: 1.15em;
      line-height: 1.7;
      color: #333;
      white-space: pre-wrap;
      box-shadow: 0 2px 10px rgba(0,0,0,0.05);
    }
    .stats {
      margin-top: 20px;
      display: flex;
      gap: 25px;
      flex-wrap: wrap;
    }
    .stat {
      background: white;
      padding: 15px 20px;
      border-radius: 10px;
      border: 1px solid #e0e0e0;
      font-size: 1em;
      box-shadow: 0 2px 8px rgba(0,0,0,0.05);
      min-width: 150px;
    }
    .stat strong { color: #667eea; font-size: 1.1em; }
    .loading {
      display: none;
      text-align: center;
      margin: 30px 0;
    }
    .spinner {
      display: inline-block;
      width: 50px;
      height: 50px;
      border: 5px solid #f3f3f3;
      border-top: 5px solid #667eea;
      border-radius: 50%;
      animation: spin 1s linear infinite;
    }
    @keyframes spin {
      0% { transform: rotate(0deg); }
      100% { transform: rotate(360deg); }
    }
    .speed-indicator {
      font-weight: bold;
      font-size: 1.1em;
    }
    .ultra-fast { color: #27ae60; }
    .fast { color: #f39c12; }
    .slow { color: #e74c3c; }
    .performance-tip {
      background: #e8f5e8;
      border: 2px solid #27ae60;
      border-radius: 10px;
      padding: 15px;
      margin-bottom: 20px;
      font-size: 0.95em;
    }
    .performance-tip strong { color: #27ae60; }
    @media (max-width: 768px) {
      .container { margin: 10px; }
      .header { padding: 30px 20px; }
      .header h1 { font-size: 2.2em; }
      .content { padding: 30px 20px; }
      .controls { flex-direction: column; align-items: stretch; }
      .control-group input { width: 100%; }
      .stats { flex-direction: column; }
    }
  </style>
</head>
<body>
  <div class="container">
    <div class="header">
      <h1>⚡ Lightning-Fast Translation</h1>
      <p>English to Telugu • Optimized for Maximum Speed</p>
      <div class="speed-badge">🎯 Target: &lt;0.5 seconds</div>
    </div>
    <div class="content">
      <div class="performance-tip">
        <strong>💡 Pro Tip:</strong> For ultra-fast results, keep sentences under 20 words and use batch size 32+
      </div>
      <div class="form-group">
        <label for="inputText">📝 Enter English Text:</label>
        <textarea id="inputText" placeholder="Enter your English text here...&#10;&#10;Optimized for lightning-fast processing!&#10;Shorter sentences = faster results"></textarea>
      </div>
      <div class="controls">
        <div class="control-group">
          <label for="batchSize">Batch Size:</label>
          <input type="number" id="batchSize" value="32" min="1" max="64">
        </div>
        <div class="control-group">
          <label for="maxLength">Max Length:</label>
          <input type="number" id="maxLength" value="128" min="1" max="256">
        </div>
        <button onclick="translateText()">⚡ Lightning Translate</button>
      </div>
      <div class="loading" id="loading">
        <div class="spinner"></div>
        <p style="margin-top: 15px; font-size: 1.1em;">Processing at lightning speed...</p>
      </div>
      <div id="result" class="result" style="display: none;">
        <h3>📖 Translation Result:</h3>
        <div id="translatedText" class="translated-text"></div>
        <div class="stats">
          <div class="stat">
            <strong>Processing Time:</strong><br>
            <span id="processingTime" class="speed-indicator">-</span> seconds
          </div>
          <div class="stat">
            <strong>Model Used:</strong><br>
            <span id="modelUsed">-</span>
          </div>
          <div class="stat">
            <strong>Input Length:</strong><br>
            <span id="charCount">-</span> characters
          </div>
        </div>
      </div>
    </div>
  </div>
  <script>
    async function translateText() {
      const inputText = document.getElementById('inputText').value;
      const batchSize = parseInt(document.getElementById('batchSize').value) || 32;
      const maxLength = parseInt(document.getElementById('maxLength').value) || 128;
      if (!inputText.trim()) {
        alert('⚠️ Please enter some text to translate');
        return;
      }
      const button = document.querySelector('button');
      const loading = document.getElementById('loading');
      const result = document.getElementById('result');
      // Show loading state
      button.textContent = '⚡ Processing...';
      button.disabled = true;
      loading.style.display = 'block';
      result.style.display = 'none';
      const startTime = performance.now();
      try {
        const response = await fetch('/translate/', {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            text: inputText,
            batch_size: batchSize,
            max_length: maxLength
          })
        });
        if (!response.ok) {
          throw new Error(`HTTP error! status: ${response.status}`);
        }
        const data = await response.json();
        // Display results
        document.getElementById('translatedText').textContent = data.translated_text;
        const processingTimeElement = document.getElementById('processingTime');
        const time = data.processing_time;
        processingTimeElement.textContent = time.toFixed(3);
        // Color code based on speed
        processingTimeElement.className = 'speed-indicator ';
        if (time < 0.5) {
          processingTimeElement.className += 'ultra-fast';
          processingTimeElement.textContent += ' ⚡';
        } else if (time < 2) {
          processingTimeElement.className += 'fast';
          processingTimeElement.textContent += ' 🚀';
        } else {
          processingTimeElement.className += 'slow';
          processingTimeElement.textContent += ' 🐌';
        }
        document.getElementById('modelUsed').textContent = data.model_used;
        document.getElementById('charCount').textContent = data.original_text.length;
        loading.style.display = 'none';
        result.style.display = 'block';
        // Smooth scroll to results
        result.scrollIntoView({ behavior: 'smooth' });
      } catch (error) {
        console.error('Translation error:', error);
        document.getElementById('translatedText').textContent = 'Error: ' + error.message;
        document.getElementById('processingTime').textContent = '-';
        document.getElementById('modelUsed').textContent = 'Error';
        document.getElementById('charCount').textContent = '-';
        loading.style.display = 'none';
        result.style.display = 'block';
      } finally {
        button.textContent = '⚡ Lightning Translate';
        button.disabled = false;
      }
    }
    // Keyboard shortcuts
    document.getElementById('inputText').addEventListener('keydown', function(e) {
      if (e.key === 'Enter' && e.ctrlKey) {
        e.preventDefault();
        translateText();
      }
    });
    // Auto-resize textarea
    document.getElementById('inputText').addEventListener('input', function() {
      this.style.height = 'auto';
      this.style.height = Math.max(150, this.scrollHeight) + 'px';
    });
  </script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Root endpoint with lightning-fast interface"""
    return HTMLResponse(content=_INDEX_HTML)
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Reports module state only: "status" is always "healthy"; whether the
    # model is actually available is conveyed by "model_loaded".
    device_label = "not_initialized" if not device else str(device)
    model_ready = translator is not None
    now = time.time()

    payload = {
        "status": "healthy",
        "device": device_label,
        "model_loaded": model_ready,
        "optimization_level": "lightning_fast",
        "dependencies": "minimal (no spacy required)",
        "timestamp": now,
    }
    return payload
@app.get("/api/info")
async def api_info():
    """API information endpoint"""
    # Static, hand-maintained description of the service; nothing here is
    # derived from the running model.
    optimizations = [
        "regex_sentence_splitting",
        "aggressive_caching",
        "gpu_half_precision",
        "torch_compile",
        "greedy_decoding",
        "large_batch_processing",
        "minimal_dependencies",
    ]
    endpoints = {
        "translate": "/translate/",
        "health": "/health",
        "docs": "/docs",
    }

    info = {
        "title": "Lightning-Fast English to Telugu Translation API",
        "version": "2.0.0",
        "model": "aryaumesh/english-to-telugu",
    }
    info["optimizations"] = optimizations
    info["target_processing_time"] = "< 0.5 seconds"
    info["no_external_deps"] = True
    info["endpoints"] = endpoints
    return info
if __name__ == "__main__":
    # Script entry point: serve on every interface with one worker, on the
    # port named by $PORT (7860 when the variable is unset).
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)