# (Scrape artifact: the original capture included the Hugging Face "Spaces:
# Running" page banner here — it is not part of the program.)
| import os | |
| import json | |
| import numpy as np | |
| import textwrap | |
| from tokenizers import Tokenizer | |
| import onnxruntime as ort | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
class ONNXInferencePipeline:
    """Content-moderation pipeline: keyword pre-filter + ONNX text classifier.

    At construction time this downloads the model, tokenizer and
    hyperparameter artifacts from the Hugging Face Hub; predictions are then
    served via :meth:`predict`, which first applies a cheap rule-based
    banned-keyword filter and only runs the ONNX model when that passes.
    """

    def __init__(self, repo_id, repo_type="model"):
        """Download artifacts from the Hub and build the ONNX session.

        Args:
            repo_id: Hub repository holding ``model.onnx``,
                ``train_bpe_tokenizer.json`` and ``hyperparameters.json``.
            repo_type: Hub repository type (default ``"model"``).
        """
        # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
        hf_token = os.getenv("HF_TOKEN")

        # Rule-based filter list, consulted before the model in predict().
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")

        # Download artifacts. Newer huggingface_hub uses token=, not use_auth_token=.
        self.onnx_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.onnx",
            token=hf_token,
            repo_type=repo_type,
        )
        self.tokenizer_path = hf_hub_download(
            repo_id=repo_id,
            filename="train_bpe_tokenizer.json",
            token=hf_token,
            repo_type=repo_type,
        )
        self.config_path = hf_hub_download(
            repo_id=repo_id,
            filename="hyperparameters.json",
            token=hf_token,
            repo_type=repo_type,
        )

        # Load configuration (plain JSON of hyperparameters).
        with open(self.config_path, "r") as f:
            self.config = json.load(f)

        # Tokenizer and sequence-length budget.
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = int(self.config.get("max_len", 256))

        # Prefer CUDA when present; Spaces CPU runtime typically only has
        # CPUExecutionProvider.
        providers = ort.get_available_providers()
        if "CUDAExecutionProvider" in providers:
            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            use_providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        # Reduce memory and improve cold start a bit.
        sess_options.enable_mem_pattern = False
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(
            self.onnx_path, sess_options=sess_options, providers=use_providers
        )
        # Cache model input name to avoid mismatches like input vs input_ids.
        self.input_name = self.session.get_inputs()[0].name
        print(f"ONNX model input name detected: {self.input_name}")

        # Label order can be overridden from the config file.
        self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])

    def load_banned_keywords(self):
        """Load banned keywords from the environment variable ``banned``.

        Supports two formats:
          1) a JSON array of strings (tried first), or
          2) a Python code snippet whose body returns a list (legacy format).

        Returns:
            list[str]: the keywords, or an empty list when the variable is
            unset or malformed (best-effort: errors are logged, not raised).
        """
        code_str = os.getenv("banned")
        if not code_str:
            print("Environment variable 'banned' is not set. Using empty list.")
            return []
        # Try JSON first
        try:
            parsed = json.loads(code_str)
            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return parsed
        except Exception:
            pass
        # Fallback: executable code snippet that returns a list.
        # SECURITY: exec() on an environment variable runs arbitrary code.
        # Tolerable only because Space secrets are operator-controlled —
        # never route user-supplied input into this variable.
        local_vars = {}
        wrapped_code = f"""
def get_banned_keywords():
{textwrap.indent(code_str, '    ')}
"""
        try:
            exec(wrapped_code, {}, local_vars)
            result = local_vars["get_banned_keywords"]()
            if isinstance(result, list):
                return [str(x) for x in result]
            print("Loaded banned keywords code did not return a list. Using empty list.")
            return []
        except Exception as e:
            print(f"Error loading banned keywords from code: {e}")
            return []

    def contains_banned_keyword(self, text):
        """Return True if *text* contains any banned keyword as whole word(s).

        BUG FIX: the original compared raw lowercased keywords against a set
        of punctuation-stripped words, so multi-word or punctuated keywords
        could never match. Keywords are now normalized exactly like the text,
        and multi-word keywords are matched as contiguous word sequences.
        """
        def _tokens(s):
            # Lowercase, then split on every non-alphanumeric character.
            return "".join(c if c.isalnum() else " " for c in s.lower()).split()

        words = _tokens(text)
        word_set = set(words)
        # Space-padded flat form so phrase matches respect word boundaries.
        flat = " " + " ".join(words) + " "

        for keyword in self.banned_keywords:
            kw_tokens = _tokens(str(keyword))
            if not kw_tokens:
                continue
            if len(kw_tokens) == 1:
                hit = kw_tokens[0] in word_set
            else:
                hit = (" " + " ".join(kw_tokens) + " ") in flat
            if hit:
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed. No inappropriate keywords found")
        return False

    def preprocess(self, text):
        """Tokenize *text*, truncate to ``max_len`` and right-pad.

        Returns a (1, max_len) int64 array ready for the ONNX session.
        NOTE(review): assumes the tokenizer's pad id is 0 — confirm against
        train_bpe_tokenizer.json.
        """
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[: self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    @staticmethod
    def softmax(logits):
        """Numerically stable row-wise softmax over a (batch, classes) array.

        BUG FIX: the original ``def softmax(logits):`` was a plain instance
        method missing ``self``, so the call ``self.softmax(logits)`` bound
        the instance to ``logits`` and crashed at inference time; it is now a
        staticmethod, leaving the call site unchanged.
        """
        x = logits - np.max(logits, axis=1, keepdims=True)
        e = np.exp(x)
        return e / np.sum(e, axis=1, keepdims=True)

    def predict(self, text):
        """Classify *text*.

        Returns:
            dict: ``{"label": str, "probabilities": list[float]}``. When the
            keyword filter fires, label is ``class_labels[0]`` with forced
            full confidence; otherwise the ONNX model's softmax output.
        """
        snippet = text[:50].replace("\n", " ")
        print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")
        # First rule based filter
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                "label": self.class_labels[0],
                "probabilities": [1.0, 0.0] if len(self.class_labels) == 2 else [1.0] * len(self.class_labels),
            }
        # Preprocess
        input_array = self.preprocess(text)
        # Run inference. Use detected input name
        outputs = self.session.run(None, {self.input_name: input_array})
        # Post process
        logits = outputs[0]
        probs = self.softmax(logits)
        pred_idx = int(np.argmax(probs))  # batch size is always 1 here
        label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)
        print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
        return {"label": label, "probabilities": probs[0].tolist()}
# Gradio glue
def gradio_predict(text):
    """Gradio callback: run the shared pipeline on *text*, return the label line."""
    label = PIPELINE.predict(text)["label"]
    return "Prediction: " + label + "\n"
# Create pipeline at import so the Space is ready (warm) before the first
# request arrives. NOTE(review): this performs network I/O (Hub downloads)
# at import time — the module fails to import if the Hub is unreachable.
print("Initializing content filter pipeline...")
PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
print("Pipeline initialized successfully")
if __name__ == "__main__":
    # Required in Spaces. PORT is injected by the platform; bind 0.0.0.0 so
    # the server is reachable from outside the container.
    port = int(os.getenv("PORT", 7860))

    # Demo inputs shown under the textbox: offensive, neutral, then mixed.
    demo_examples = [
        # Explicitly offensive examples
        "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
        "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
        "Your mother should have done better raising such a useless idiot.",
        # Neutral or appropriate examples
        "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
        "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
        "The weather today is lovely, and I am looking forward to a productive day at work.",
        # Mixed
        "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying.",
    ]

    demo_description = (
        "Abuse detector identifies inappropriate content in text. "
        "It analyzes input for Australian slang and abusive language. "
        "It is trained on a compact dataset. It may not catch highly nuanced language, "
        "but it detects common day to day offensive language."
    )

    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=demo_description,
        examples=demo_examples,
    )
    print("Launching Gradio interface...")
    iface.launch(server_name="0.0.0.0", server_port=port)