import os
import json
import textwrap

import numpy as np
import onnxruntime as ort
import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download


class ONNXInferencePipeline:
    def __init__(self, repo_id):
        # Retrieve the Hugging Face token from the environment variable
        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            raise ValueError("HF_TOKEN environment variable is not set.")

        # Load the banned-keywords list from the 'banned' environment variable
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")

        # Download the model artifacts from the Hugging Face Hub using the token
        self.onnx_path = hf_hub_download(repo_id=repo_id, filename="model.onnx", token=hf_token)
        self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json", token=hf_token)
        self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json", token=hf_token)

        # Load configuration
        with open(self.config_path) as f:
            self.config = json.load(f)

        # Initialize tokenizer
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = 256

        # Initialize the ONNX Runtime session, preferring CUDA when available
        # (CUDAExecutionProvider is only present with the onnxruntime-gpu build)
        self.providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider']
        self.session = ort.InferenceSession(self.onnx_path, providers=self.providers)

    def load_banned_keywords(self):
        # The 'banned' environment variable holds a small Python snippet that
        # returns the keyword list, keeping the list itself out of the source
        code_str = os.getenv("banned")
        if not code_str:
            raise ValueError("Environment variable 'banned' is not set. Please set it with your banned keywords list.")
        # Local namespace for the executed code
        local_vars = {}
        # Wrap the snippet in a function so it can use return statements
        wrapped_code = f"""
def get_banned_keywords():
{textwrap.indent(code_str, '    ')}
"""
        try:
            # Execute the wrapped code, then call the function it defines
            exec(wrapped_code, globals(), local_vars)
            return local_vars['get_banned_keywords']()
        except Exception as e:
            print(f"Error loading banned keywords: {e}")
            # Fall back to an empty list on error
            return []
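
    # A sketch of the expected 'banned' value (the keywords below are
    # placeholders, not the real list):
    #
    #   banned="return ['keyword1', 'keyword2', 'keyword3']"
    #
    # i.e. the variable holds Python code ending in a return statement, which
    # the wrapper above turns into the body of get_banned_keywords().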

    def contains_banned_keyword(self, text):
        """Check if the input text contains any banned keyword as a whole word."""
        text_lower = text.lower()
        # Replace non-alphanumeric characters with spaces and split into words;
        # a set makes the per-keyword membership test O(1)
        words = set(''.join(c if c.isalnum() else ' ' for c in text_lower).split())
        for keyword in self.banned_keywords:
            # Whole-word match only: multi-word phrases are split apart by the
            # normalization above and will not match
            if keyword.lower() in words:
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed - No inappropriate keywords found")
        return False

    def preprocess(self, text):
        # Tokenize, truncate to max_len, and right-pad with zeros
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[:self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)
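
    # Shape sanity check (a sketch): preprocess() always returns a (1, 256)
    # int64 array, which matches the single 'input' tensor that predict()
    # feeds to the ONNX session, e.g.:
    #
    #   pipeline.preprocess("hello").shape  # -> (1, 256)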

    def predict(self, text):
        print(f"\nProcessing input: '{text[:50]}...' ({len(text)} characters)")
        # First check whether the text contains any banned keyword
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                'label': 'Inappropriate Content',
                'probabilities': [1.0, 0.0]  # [inappropriate, appropriate]
            }

        # No banned keywords found, so proceed with the model prediction
        print("Running ML model for classification...")
        input_array = self.preprocess(text)

        # Run inference
        results = self.session.run(None, {'input': input_array})

        # Post-process: softmax over the logits, shifting by the row max for
        # numerical stability
        logits = results[0]
        exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        predicted_class = int(np.argmax(probabilities))

        # Map to labels
        class_labels = ['Inappropriate Content', 'Appropriate']
        print(f"Model Passed - Result: {class_labels[predicted_class]} (Confidence: {probabilities[0][predicted_class]:.2%})")
        return {
            'label': class_labels[predicted_class],
            'probabilities': probabilities[0].tolist()
        }


# Example usage
if __name__ == "__main__":
    # Initialize the pipeline with the Hugging Face repository ID
    print("Initializing content filter pipeline...")
    pipeline = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
    print("Pipeline initialized successfully")

    # Example texts for testing
    example_texts = [
        "You're a worthless piece of garbage who should die",
        "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team."
    ]
    for text in example_texts:
        result = pipeline.predict(text)
        print(f"Input: {text[:50]}...")
        print(f"Prediction: {result['label']}")
        print("-" * 80)

    # Define a function for Gradio to use
    def gradio_predict(text):
        result = pipeline.predict(text)
        return f"Prediction: {result['label']}"

    # Create a Gradio interface
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=(
            "Abuse Detector is designed to identify inappropriate content in text. "
            "It analyzes input for Australian slang and abusive language. "
            "While it is trained on a compact dataset and may not catch highly nuanced or sophisticated language, "
            "it effectively detects the day-to-day offensive language commonly used in conversations."
        ),
        examples=[
            # Explicitly offensive examples
            "Congrats, you fuckbrain arsehole, you've outdone yourself in stupidity. A real cock-up of a human—should we clap for your bollocks-faced greatness or just pity you?",
            "You're a mad bastard, but I'd still grab a beer with you! Fuck around all you like, you cockheaded legend—your arsehole antics are bloody brilliant.",
            "Your mother should have done better raising such a useless idiot.",
            # Neutral or appropriate examples
            "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
            "Thank you for your time and consideration. Please don't hesitate to reach out if you need additional information—I'd be happy to discuss further. Looking forward to hearing from you soon!",
            "The weather today is lovely, and I'm looking forward to a productive day at work.",
            # Mixed examples (some offensive, some neutral)
            "I appreciate your help, but honestly, you're such a clueless idiot sometimes. Still, thanks for trying."
        ]
    )

    # Launch the Gradio app
    print("Launching Gradio interface...")
    iface.launch()
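
# How to run this locally, as a sketch (the token and keyword values are
# placeholders, and the filename app.py is the usual Spaces convention, an
# assumption here):
#
#   export HF_TOKEN="hf_..."
#   export banned="return ['keyword1', 'keyword2']"
#   python app.py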