import os
import json
import numpy as np
import textwrap
from tokenizers import Tokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import gradio as gr
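# Two-stage abuse detector: a keyword pre-filter (loaded from the "banned"
# environment variable) followed by an ONNX text classifier downloaded from
# the Hugging Face Hub, served through a Gradio interface.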
class ONNXInferencePipeline:
    def __init__(self, repo_id):
        # Retrieve the Hugging Face token from the environment variable
        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            raise ValueError("HF_TOKEN environment variable is not set.")
        # Load the banned keywords list
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")
        # Download files from the Hugging Face Hub using the token
        self.onnx_path = hf_hub_download(repo_id=repo_id, filename="model.onnx", token=hf_token)
        self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json", token=hf_token)
        self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json", token=hf_token)
        # Load configuration
        with open(self.config_path) as f:
            self.config = json.load(f)
        # Initialize the tokenizer
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = 256
        # Initialize the ONNX Runtime session, using CUDA if available
        self.providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider']
        self.session = ort.InferenceSession(self.onnx_path, providers=self.providers)
    def load_banned_keywords(self):
        # For testing purposes, use a small list
        # In production, load your full list
        code_str = os.getenv("banned")
        if not code_str:
            raise Exception("Environment variable 'banned' is not set. Please set it with your banned keywords list.")
        # Create a local namespace to execute the code
        local_vars = {}
        # Wrap the code in a function so that it can use return statements
        wrapped_code = f"""
def get_banned_keywords():
{textwrap.indent(code_str, '    ')}
"""
        try:
            # Execute the wrapped code
            exec(wrapped_code, globals(), local_vars)
            # Call the generated function to get the banned keywords
            return local_vars['get_banned_keywords']()
        except Exception as e:
            print(f"Error loading banned keywords: {e}")
            # Fall back to an empty list if there's an error
            return []
    def contains_banned_keyword(self, text):
        """Check if the input text contains any banned keywords."""
        text_lower = text.lower()
        # Split the text into words
        words = ''.join(c if c.isalnum() else ' ' for c in text_lower).split()
        for keyword in self.banned_keywords:
            keyword_lower = keyword.lower()
            # Check if the keyword appears as a whole word in the text
            if keyword_lower in words:
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed - No inappropriate keywords found")
        return False
    def preprocess(self, text):
        # Tokenize, truncate to max_len, and right-pad with zeros
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[:self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)
    def predict(self, text):
        print(f"\nProcessing input: '{text[:50]}...' ({len(text)} characters)")
        # First check if the text contains any banned keywords
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                'label': 'Inappropriate Content',
                'probabilities': [1.0, 0.0]  # [Inappropriate Content, Appropriate]
            }
        # If no banned keywords are found, proceed with the model prediction
        print("Running ML model for classification...")
        # Preprocess
        input_array = self.preprocess(text)
        # Run inference
        results = self.session.run(
            None,
            {'input': input_array}
        )
        # Post-process: softmax over the logits
        logits = results[0]
        probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
        predicted_class = int(np.argmax(probabilities))
        # Map to labels
        class_labels = ['Inappropriate Content', 'Appropriate']
        # Log the model result
        print(f"Model Passed - Result: {class_labels[predicted_class]} (Confidence: {probabilities[0][predicted_class]:.2%})")
        return {
            'label': class_labels[predicted_class],
            'probabilities': probabilities[0].tolist()
        }
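# When run directly, build the pipeline, check it on a couple of sample texts,
# and launch the Gradio demo.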
# Example usage
if __name__ == "__main__":
    # Initialize the pipeline with the Hugging Face repository ID
    print("Initializing content filter pipeline...")
    pipeline = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
    print("Pipeline initialized successfully")
    # Example texts for testing
    example_texts = [
        "You're a worthless piece of garbage who should die",
        "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team."
    ]
    for text in example_texts:
        result = pipeline.predict(text)
        print(f"Input: {text[:50]}...")
        print(f"Prediction: {result['label']}")
        print("-" * 80)
    # Define a function for Gradio to use
    def gradio_predict(text):
        result = pipeline.predict(text)
        return (
            f"Prediction: {result['label']}\n"
        )
    # Create the Gradio interface
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=(
            "Abuse Detector is designed to identify inappropriate content in text. "
            "It analyzes input for Australian slang and abusive language. "
            "While it's trained on a compact dataset and may not catch highly nuanced or sophisticated language, "
            "it effectively detects day-to-day offensive language commonly used in conversations."
        ),
        examples=[
            # Explicitly offensive examples
            "Congrats, you fuckbrain arsehole, you've outdone yourself in stupidity. A real cock-up of a human—should we clap for your bollocks-faced greatness or just pity you?",
            "You're a mad bastard, but I'd still grab a beer with you! Fuck around all you like, you cockheaded legend—your arsehole antics are bloody brilliant.",
            "Your mother should have done better raising such a useless idiot.",
            # Neutral or appropriate examples
            "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
            "Thank you for your time and consideration. Please don't hesitate to reach out if you need additional information—I'd be happy to discuss further. Looking forward to hearing from you soon!",
            "The weather today is lovely, and I'm looking forward to a productive day at work.",
            # Mixed examples (some offensive, some neutral)
            "I appreciate your help, but honestly, you're such a clueless idiot sometimes. Still, thanks for trying."
        ]
    )
    # Launch the Gradio app
    print("Launching Gradio interface...")
    iface.launch()