# Abuse-Detection Space — app.py
# (Hugging Face Space by iimran; commit 0c698cb, verified)
import os
import json
import numpy as np
import textwrap
from tokenizers import Tokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import gradio as gr
class ONNXInferencePipeline:
    """Two-stage text-moderation pipeline.

    Stage 1 is a rule-based whole-word filter driven by the ``banned`` env var;
    stage 2 is an ONNX sequence classifier whose artifacts (model, BPE
    tokenizer, hyperparameters) are downloaded from the Hugging Face Hub when
    the pipeline is constructed.
    """

    def __init__(self, repo_id, repo_type="model"):
        """Download artifacts from ``repo_id`` and build the inference session.

        Performs network I/O (three Hub downloads) at construction time, so the
        first run is slow; later runs hit the local hub cache.
        """
        # Read token from env. In a Space, HF_TOKEN can be set in the Secrets panel.
        hf_token = os.getenv("HF_TOKEN")
        # Rule-based filter data is loaded before any network access.
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")
        # Download artifacts. Newer huggingface_hub uses token=, not the
        # deprecated use_auth_token=.
        self.onnx_path = hf_hub_download(
            repo_id=repo_id,
            filename="model.onnx",
            token=hf_token,
            repo_type=repo_type,
        )
        self.tokenizer_path = hf_hub_download(
            repo_id=repo_id,
            filename="train_bpe_tokenizer.json",
            token=hf_token,
            repo_type=repo_type,
        )
        self.config_path = hf_hub_download(
            repo_id=repo_id,
            filename="hyperparameters.json",
            token=hf_token,
            repo_type=repo_type,
        )
        # Load configuration
        with open(self.config_path, "r") as f:
            self.config = json.load(f)
        # Initialize tokenizer
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = int(self.config.get("max_len", 256))
        # Initialize ONNX runtime session. Prefer CUDA when present; the
        # Spaces CPU runtime typically exposes only CPUExecutionProvider.
        providers = ort.get_available_providers()
        if "CUDAExecutionProvider" in providers:
            use_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            use_providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        # Disabling mem-pattern reduces memory and improves cold start a bit
        # on small CPU containers.
        sess_options.enable_mem_pattern = False
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(
            self.onnx_path, sess_options=sess_options, providers=use_providers
        )
        # Cache the model's input name to avoid mismatches like input vs input_ids.
        self.input_name = self.session.get_inputs()[0].name
        print(f"ONNX model input name detected: {self.input_name}")
        # Label order can come from config. Index 0 is treated as the
        # "inappropriate" class by the keyword-filter fallback in predict(),
        # so config order must stay consistent with that.
        self.class_labels = self.config.get("class_labels", ["Inappropriate Content", "Appropriate"])

    def load_banned_keywords(self):
        """Load banned keywords from the env var named ``banned``.

        Supports two formats:
          1) a JSON array of strings, e.g. '["foo", "bar"]';
          2) a Python snippet (function body) that returns a list.
        Returns an empty list when the var is unset or on any parse failure.
        """
        code_str = os.getenv("banned")
        if not code_str:
            print("Environment variable 'banned' is not set. Using empty list.")
            return []
        # Try JSON first -- the safe, preferred format.
        try:
            parsed = json.loads(code_str)
            if isinstance(parsed, list) and all(isinstance(x, str) for x in parsed):
                return parsed
        except Exception:
            pass
        # SECURITY NOTE: exec() on env-var content runs arbitrary code. This
        # is acceptable only because the Space owner controls the secret;
        # never route untrusted input through this path.
        local_vars = {}
        wrapped_code = f"""
def get_banned_keywords():
{textwrap.indent(code_str, '    ')}
"""
        try:
            exec(wrapped_code, {}, local_vars)
            result = local_vars["get_banned_keywords"]()
            if isinstance(result, list):
                return [str(x) for x in result]
            print("Loaded banned keywords code did not return a list. Using empty list.")
            return []
        except Exception as e:
            print(f"Error loading banned keywords from code: {e}")
            return []

    def contains_banned_keyword(self, text):
        """Return True if ``text`` contains any banned keyword as a whole word.

        Matching is case-insensitive and punctuation-insensitive. NOTE:
        multi-word phrases in the banned list can never match, because the
        comparison is against individual words only.
        """
        text_lower = text.lower()
        # Split on every non-alphanumeric character so punctuation cannot
        # hide a keyword ("badword!" still matches "badword").
        words = "".join(c if c.isalnum() else " " for c in text_lower).split()
        word_set = set(words)
        for keyword in self.banned_keywords:
            kw = str(keyword).lower().strip()
            if not kw:
                continue
            if kw in word_set:
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed. No inappropriate keywords found")
        return False

    def preprocess(self, text):
        """Tokenize ``text`` into a fixed-length (1, max_len) int64 id array.

        Truncates to max_len and right-pads with id 0 (presumably the
        tokenizer's pad token — TODO confirm against the tokenizer config).
        """
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[: self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    @staticmethod
    def softmax(logits):
        """Numerically stable row-wise softmax for a (batch, classes) array."""
        x = logits - np.max(logits, axis=1, keepdims=True)
        e = np.exp(x)
        return e / np.sum(e, axis=1, keepdims=True)

    def predict(self, text):
        """Classify ``text``.

        Returns a dict: {"label": str, "probabilities": list[float]}.
        The keyword filter runs first and short-circuits the model with a hard
        "inappropriate" verdict (one-hot on class 0).
        """
        snippet = text[:50].replace("\n", " ")
        print(f"\nProcessing input: '{snippet}...' ({len(text)} characters)")
        # First rule based filter
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            # FIX: always emit a valid one-hot distribution. The previous code
            # returned [1.0] * n for n != 2 labels, which does not sum to 1.
            return {
                "label": self.class_labels[0],
                "probabilities": [1.0] + [0.0] * (len(self.class_labels) - 1),
            }
        # Preprocess
        input_array = self.preprocess(text)
        # Run inference, using the detected input name so the feed dict always
        # matches the graph.
        outputs = self.session.run(None, {self.input_name: input_array})
        # Post process
        logits = outputs[0]
        probs = self.softmax(logits)
        # FIX: take the argmax per row. The previous flat argmax over the 2-D
        # array was only correct because the batch size happened to be 1.
        pred_idx = int(np.argmax(probs, axis=1)[0])
        label = self.class_labels[pred_idx] if pred_idx < len(self.class_labels) else str(pred_idx)
        print(f"Model Passed. Result: {label} (Confidence: {probs[0][pred_idx]:.2%})")
        return {"label": label, "probabilities": probs[0].tolist()}
# Gradio glue
def gradio_predict(text):
    """Gradio callback: run the module-level pipeline on *text* and format the label."""
    outcome = PIPELINE.predict(text)
    label = outcome["label"]
    return "Prediction: " + label + "\n"
# Create pipeline at import so the Space is ready
# NOTE(review): this performs network I/O (three Hub downloads) at import
# time; any failure here aborts the whole app before the UI can start.
print("Initializing content filter pipeline...")
PIPELINE = ONNXInferencePipeline(repo_id="iimran/abuse-detector", repo_type="model")
print("Pipeline initialized successfully")
if __name__ == "__main__":
    # Required in Spaces. PORT is injected by the platform; bind to 0.0.0.0
    # so the container is reachable from outside.
    demo_description = (
        "Abuse detector identifies inappropriate content in text. "
        "It analyzes input for Australian slang and abusive language. "
        "It is trained on a compact dataset. It may not catch highly nuanced language, "
        "but it detects common day to day offensive language."
    )
    demo_examples = [
        # Explicitly offensive examples
        "Congrats, you fuckbrain arsehole, you have outdone yourself in stupidity. A real cock up of a human. Should we clap for your bollocks faced greatness or just pity you?",
        "You are a mad bastard, but I would still grab a beer with you. Mess around all you like, you cockheaded legend. Your arsehole antics are bloody brilliant.",
        "Your mother should have done better raising such a useless idiot.",
        # Neutral or appropriate examples
        "Hello HR, I hope this message finds you well. I am writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
        "Thank you for your time and consideration. Please reach out if you need anything. I would be happy to discuss further.",
        "The weather today is lovely, and I am looking forward to a productive day at work.",
        # Mixed
        "I appreciate your help, but honestly, you are such a clueless idiot sometimes. Still, thanks for trying.",
    ]
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=demo_description,
        examples=demo_examples,
    )
    print("Launching Gradio interface...")
    serve_port = int(os.getenv("PORT", 7860))
    iface.launch(server_name="0.0.0.0", server_port=serve_port)