import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
import pandas as pd
import warnings
import os
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

warnings.filterwarnings('ignore')
class ArabicProfanityTester:
    def __init__(self, model_name='Speccco/arabic_profanity_filter'):
        """Initialize the tester with the model from the Hugging Face Hub"""
        print(f"🔄 Loading model from Hugging Face Hub: {model_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.eval()
            print("✅ Model loaded successfully from Hugging Face Hub!")
            print("📊 Model configuration:")
            print(f"   - Model type: {type(self.model).__name__}")
            print(f"   - Number of labels: {self.model.config.num_labels}")
            print(f"   - Max position embeddings: {self.model.config.max_position_embeddings}")
        except Exception as e:
            print(f"❌ Failed to load model from Hub: {e}")
            print("🔄 Falling back to base AraBERT model...")
            # Fall back to the base (not fine-tuned) AraBERT checkpoint
            base_model = "aubmindlab/bert-base-arabertv02"
            self.tokenizer = AutoTokenizer.from_pretrained(base_model)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                base_model,
                num_labels=2
            )
            self.model.eval()
            print("⚠️ Using base AraBERT model (not fine-tuned)")
    def preprocess_text(self, text):
        """Simple text preprocessing"""
        if pd.isna(text):
            return ""
        text = str(text)
        # Remove URLs, mentions, hashtags
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        text = re.sub(r'@\w+|#\w+', '', text)
        # Remove emojis and other Unicode symbols
        emoji_pattern = re.compile(
            "["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
            u"\U00002702-\U000027B0"  # dingbats
            u"\U000024C2-\U0001F251"  # enclosed characters
            u"\U0001F900-\U0001F9FF"  # supplemental symbols
            u"\U0001FA00-\U0001FAFF"  # extended symbols
            u"\u2600-\u26FF"          # miscellaneous symbols
            u"\u2700-\u27BF"          # dingbats
            u"\uFE00-\uFE0F"          # variation selectors
            u"\u200D"                 # zero width joiner
            "]+", flags=re.UNICODE)
        text = emoji_pattern.sub(r'', text)
        # Remove English letters
        text = re.sub(r'[a-zA-Z]', '', text)
        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        return text
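    # Illustrative effect of preprocess_text (the sample string is a made-up
    # placeholder, not from the project's data):
    #     preprocess_text("شاهد http://t.co/abc 😀 hello")  ->  "شاهد"
    # The URL, emoji, and Latin letters are stripped and whitespace is collapsed.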
    def check_bad_words(self, text):
        """Check if the text contains explicit bad Arabic/Egyptian words"""
        bad_words = [
            'شرموطة', 'خرا', 'زفت', 'أمك', 'يلعن دينك', 'متناك',
            'منيك', 'نايك', 'طيز', 'عرص', 'قواد', 'وسخة', 'كسك',
            'يا دين أمي', 'ابن وسخة'
        ]
        text_lower = text.lower()
        found_words = []
        for bad_word in bad_words:
            if bad_word.lower() in text_lower:
                found_words.append(bad_word)
        return len(found_words) > 0, found_words
    def predict(self, text, show_details=True):
        """Predict whether the text is offensive, with a bad-word override"""
        # Preprocess text
        processed_text = self.preprocess_text(text)
        # Check for explicit bad words first
        has_bad_words, found_bad_words = self.check_bad_words(text)
        # Tokenize
        inputs = self.tokenizer(
            processed_text,
            return_tensors='pt',
            truncation=True,
            max_length=256,
            padding=True
        )
        # Get model prediction
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            model_predicted_class = torch.argmax(probabilities, dim=-1).item()
            model_confidence = probabilities[0][model_predicted_class].item()
        # Final decision: explicit bad words override the model prediction
        if has_bad_words:
            final_prediction = "Bad"
            final_class = 1  # Offensive
            override_reason = f"Contains explicit bad words: {', '.join(found_bad_words)}"
        else:
            final_prediction = "Good" if model_predicted_class == 0 else "Bad"
            final_class = model_predicted_class
            override_reason = None
        # Prepare result
        result = {
            'original_text': text,
            'processed_text': processed_text,
            'model_prediction': 'Offensive' if model_predicted_class == 1 else 'Non-Offensive',
            'model_confidence': model_confidence,
            'final_prediction': final_prediction,
            'final_class': final_class,
            'has_bad_words': has_bad_words,
            'found_bad_words': found_bad_words,
            'override_reason': override_reason,
            'probabilities': {
                'non_offensive': probabilities[0][0].item(),
                'offensive': probabilities[0][1].item()
            }
        }
        return result
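# Illustrative direct use of the tester outside the API layer (a sketch; the
# sample string below is an arbitrary placeholder, not from the project):
#     tester = ArabicProfanityTester()
#     result = tester.predict("مرحبا بكم")
#     print(result['final_prediction'], result['model_confidence'])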
class ProfanityRequest(BaseModel):
    text: str


class BatchProfanityRequest(BaseModel):
    texts: list[str]


app = FastAPI(
    title="Arabic Profanity Filter API",
    description="An API to detect profanity in Arabic text using a fine-tuned AraBERT model with a rule-based override.",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Initialize the tester globally
tester = None
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    global tester
    try:
        tester = ArabicProfanityTester()
        print("🚀 Arabic Profanity Filter API is ready!")
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        raise e
@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Arabic Profanity Filter API",
        "description": "Detects profanity in Arabic text using an AraBERT model with a rule-based override",
        "endpoints": {
            "predict": "/predict - Single text prediction",
            "batch": "/batch - Batch text prediction",
            "health": "/health - Health check",
            "docs": "/docs - API documentation"
        }
    }
@app.get("/health")
def health_check():
    """Health check endpoint"""
    if tester is None:
        return {"status": "unhealthy", "message": "Model not loaded"}
    return {"status": "healthy", "message": "API is running"}
@app.post("/predict")
async def predict_profanity(request: ProfanityRequest):
    """
    Predicts whether the given Arabic text contains profanity.

    - **text**: The Arabic text to analyze.

    Returns:
    - original_text: The input text
    - processed_text: Text after preprocessing
    - model_prediction: The model's prediction (Offensive/Non-Offensive)
    - model_confidence: The model's confidence score
    - final_prediction: Final result (Good/Bad) after the rule-based override
    - has_bad_words: Whether explicit bad words were found
    - found_bad_words: List of bad words found
    - probabilities: Detailed probability scores
    """
    if tester is None:
        return {"error": "Model not loaded"}
    try:
        result = tester.predict(request.text, show_details=False)
        return result
    except Exception as e:
        return {"error": f"Prediction failed: {str(e)}"}
@app.post("/batch")
async def predict_batch_profanity(request: BatchProfanityRequest):
    """
    Predicts profanity for multiple Arabic texts.

    - **texts**: List of Arabic texts to analyze.

    Returns a list of prediction results, one per text, plus a summary.
    """
    if tester is None:
        return {"error": "Model not loaded"}
    try:
        results = []
        for text in request.texts:
            result = tester.predict(text, show_details=False)
            results.append(result)
        return {
            "predictions": results,
            "summary": {
                "total": len(results),
                "bad_count": sum(1 for r in results if r['final_prediction'] == 'Bad'),
                "good_count": sum(1 for r in results if r['final_prediction'] == 'Good'),
                "explicit_bad_words_count": sum(1 for r in results if r['has_bad_words'])
            }
        }
    except Exception as e:
        return {"error": f"Batch prediction failed: {str(e)}"}
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
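# Minimal client sketch for a running instance (assumes the default local port
# 7860 configured above; the sample texts are arbitrary placeholders):
#     import requests
#     single = requests.post("http://localhost:7860/predict", json={"text": "مرحبا بكم"})
#     print(single.json()["final_prediction"], single.json()["probabilities"])
#     batch = requests.post("http://localhost:7860/batch", json={"texts": ["مرحبا", "نص آخر"]})
#     print(batch.json()["summary"])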