T5-model / app.py
MominRuaf's picture
Upload 2 files
35208ed verified
#!/usr/bin/env python3
"""
T5 Detoxification API for Hugging Face Spaces
FastAPI service that can be called from external WebSocket servers
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import logging
import time
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="T5 Detoxification API", version="1.0.0")
class TextRequest(BaseModel):
    """Request payload for POST /detoxify."""
    # Raw input text to detoxify.
    text: str
    # Upper bound used both for tokenizer truncation and generation length.
    max_length: int = 256
class TextResponse(BaseModel):
    """Response payload for POST /detoxify."""
    # Input text exactly as received.
    original_text: str
    # Model output; may equal the input if detoxification failed or was a no-op.
    detoxified_text: str
    # Wall-clock seconds spent handling the request (rounded by the endpoint).
    processing_time: float
    # Device the model ran on, e.g. "cuda" or "cpu".
    device: str
class T5Service:
    """Loads and serves the s-nlp/t5-paranmt-detox seq2seq model.

    All failures degrade gracefully instead of raising: loading errors leave
    ``loaded`` False, and inference errors make :meth:`detoxify_text` return
    its input unchanged, so HTTP handlers can keep responding.
    """

    # Single source of truth for the checkpoint name.
    _MODEL_ID = 's-nlp/t5-paranmt-detox'

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Prefer GPU when available; fp16 is selected on CUDA in load_model().
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loaded = False
        self.load_model()

    def load_model(self):
        """Load tokenizer and model, setting ``self.loaded`` to reflect success."""
        try:
            logger.info(f"Loading T5 model on {self.device}...")
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_ID)
            logger.info("Tokenizer loaded")
            # fp16 halves memory/bandwidth on GPU; CPU stays fp32 for accuracy.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self._MODEL_ID,
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                low_cpu_mem_usage=True,
            )
            # Move to device and switch to eval mode (disables dropout etc.).
            self.model = self.model.to(self.device)
            self.model.eval()
            # Best-effort torch.compile on PyTorch >= 2. Compare the major
            # version numerically: startswith("2") would also match "20.x".
            try:
                if int(torch.__version__.split(".")[0]) >= 2:
                    self.model = torch.compile(self.model, mode="reduce-overhead")
                    logger.info("Model compiled with torch.compile()")
            except Exception as e:
                logger.warning(f"torch.compile failed: {e}")
            self.loaded = True
            logger.info(f"T5 model loaded successfully on {self.device}")
        except Exception:
            # logger.exception records the full traceback for diagnosis.
            logger.exception("Failed to load model")
            self.loaded = False

    def detoxify_text(self, text: str, max_length: int = 256) -> str:
        """Return a detoxified version of ``text``.

        Falls back to returning ``text`` unchanged when the model is not
        loaded, the input is empty/whitespace, generation produces an empty
        string, or any inference error occurs.
        """
        if not self.loaded or not text.strip():
            return text
        try:
            # Tokenize with truncation so the encoder input is bounded.
            inputs = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            inputs = inputs.to(self.device)
            # inference_mode() disables autograd bookkeeping entirely.
            with torch.inference_mode():
                # Greedy decoding. Note: early_stopping was removed — it only
                # applies to beam search (num_beams > 1) and merely triggered
                # a transformers warning here.
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_beams=1,
                    do_sample=False,
                )
            detoxified = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
            ).strip()
            # Never return an empty string; fall back to the original input.
            return detoxified if detoxified else text
        except Exception:
            logger.exception("Error in detoxification")
            return text
# Initialize the service (module-level singleton shared by all endpoints;
# the model is loaded eagerly at import time).
t5_service = T5Service()
@app.get("/")
async def root():
    """Root health-check: reports service liveness and model state."""
    return dict(
        message="T5 Detoxification API",
        status="running",
        model_loaded=t5_service.loaded,
        device=str(t5_service.device),
    )
@app.get("/health")
async def health_check():
    """Detailed health check: model readiness, device, and server time."""
    is_ready = t5_service.loaded
    payload = {
        "status": "healthy" if is_ready else "unhealthy",
        "model_loaded": is_ready,
        "device": str(t5_service.device),
        "timestamp": time.time(),
    }
    return payload
@app.post("/detoxify", response_model=TextResponse)
async def detoxify_text(request: TextRequest):
    """Detoxify text using the T5 model.

    Raises:
        HTTPException 400: input is empty or whitespace-only.
        HTTPException 503: the model failed to load.
        HTTPException 500: unexpected inference error.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if not t5_service.loaded:
        raise HTTPException(status_code=503, detail="T5 model not loaded")
    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted (or go negative) by system clock adjustments, unlike time().
    start_time = time.perf_counter()
    try:
        detoxified_text = t5_service.detoxify_text(
            request.text,
            request.max_length,
        )
        processing_time = time.perf_counter() - start_time
        return TextResponse(
            original_text=request.text,
            detoxified_text=detoxified_text,
            processing_time=round(processing_time, 3),
            device=str(t5_service.device),
        )
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain
        # for FastAPI's error reporting.
        logger.exception("Error processing request")
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Captured at import time so /status can report an actual uptime.
_SERVICE_START_TIME = time.time()


@app.get("/status")
async def get_status():
    """Get service status.

    Fix: `uptime` previously returned the current epoch timestamp
    (`time.time()`); it now reports seconds elapsed since module import.
    """
    return {
        "model_loaded": t5_service.loaded,
        "device": str(t5_service.device),
        "uptime": round(time.time() - _SERVICE_START_TIME, 3),
    }
if __name__ == "__main__":
    import uvicorn
    # Hugging Face Spaces injects PORT; 7860 is the Spaces default.
    port = int(os.getenv("PORT", 7860))
    # Bind to all interfaces so the container's port mapping works.
    uvicorn.run(app, host="0.0.0.0", port=port)