import re

import torch
import gradio as gr
from transformers import pipeline, AutoTokenizer

# RecursiveCharacterTextSplitter moved to the standalone langchain-text-splitters
# package in newer LangChain releases; fall back to the legacy import path.
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    from langchain.text_splitter import RecursiveCharacterTextSplitter


class AbuseHateProfanityDetector:
    def __init__(self):
        # Device configuration (CPU or GPU)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize detection models
        self.Abuse_detector = pipeline("text-classification", model="Hate-speech-CNERG/english-abusive-MuRIL", device=self.device)
        self.Hate_speech_detector = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-hate-latest", device=self.device)
        self.Profanity_detector = pipeline("text-classification", model="tarekziade/pardonmyai", device=self.device)

        # Load tokenizers
        self.abuse_tokenizer = AutoTokenizer.from_pretrained('Hate-speech-CNERG/english-abusive-MuRIL')
        self.hate_speech_tokenizer = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-hate-latest')
        self.profanity_tokenizer = AutoTokenizer.from_pretrained('tarekziade/pardonmyai')

        # Define max token sizes for each model
        self.Abuse_max_context_size = 512
        self.HateSpeech_max_context_size = 512
        self.Profanity_max_context_size = 512

    def preprocess_and_clean_text(self, text: str) -> str:
        """
        Preprocesses and cleans the text: removes stammering repeats, the filler
        word "um", and stray punctuation/whitespace.
        """
        # Collapse immediate word repetitions such as "I, I, I" -> "I"
        stammering_pattern = r'\b(\w+)(\s*[,;]+\s*\1\b)+'
        passage_without_stammering = re.sub(stammering_pattern, r'\1', text, flags=re.IGNORECASE)
        # Drop the filler word "um" (any casing)
        passage_without_um = re.sub(r'\bum\b', ' ', passage_without_stammering, flags=re.IGNORECASE)
        # Normalize comma spacing
        modified_text = re.sub(r'\s*,+\s*', ', ', passage_without_um)
        # Remove whitespace before punctuation and squeeze repeated whitespace
        processed_text = re.sub(r'\s+([^\w\s])', r'\1', modified_text)
        processed_text = re.sub(r'\s+', ' ', processed_text)
        # Collapse runs of periods ("...", ". . .") into a single "."
        cleaned_text = re.sub(r'(?:\.\s*){2,}', '. ', processed_text)
        return cleaned_text.strip()

    def token_length(self, text, tokenizer):
        """
        Computes the token length of a text.
        """
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)

    def create_token_length_wrapper(self, tokenizer):
        """
        Creates a closure to calculate token length using the tokenizer.
        """
        def token_length_wrapper(text):
            return self.token_length(text, tokenizer)
        return token_length_wrapper

    def chunk_text(self, text, tokenizer, max_length):
        """
        Chunks the input text based on the max token length and cleans the text.
        """
        text = self.preprocess_and_clean_text(text)
        token_length_wrapper = self.create_token_length_wrapper(tokenizer)
        # Leave two tokens of headroom for the special tokens the classifier's
        # tokenizer adds around each chunk.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=max_length - 2, length_function=token_length_wrapper)
        chunks = text_splitter.split_text(text)
        return chunks

    def classify_text(self, text: str):
        """
        Classifies text for abuse, hate speech, and profanity using the respective models.
        """
        # Split text into chunks sized for each classification model
        abuse_chunks = self.chunk_text(text, self.abuse_tokenizer, self.Abuse_max_context_size)
        hate_speech_chunks = self.chunk_text(text, self.hate_speech_tokenizer, self.HateSpeech_max_context_size)
        profanity_chunks = self.chunk_text(text, self.profanity_tokenizer, self.Profanity_max_context_size)

        # Initialize flags
        abusive_flag = False
        hatespeech_flag = False
        profanity_flag = False

        # Detect abuse
        for chunk in abuse_chunks:
            result = self.Abuse_detector(chunk)
            if result[0]['label'] == 'LABEL_1':  # Assuming LABEL_1 is abusive content
                abusive_flag = True
                break  # One flagged chunk is enough; skip the rest

        # Detect hate speech
        for chunk in hate_speech_chunks:
            result = self.Hate_speech_detector(chunk)
            if result[0]['label'] == 'HATE':  # Assuming HATE label indicates hate speech
                hatespeech_flag = True
                break

        # Detect profanity
        for chunk in profanity_chunks:
            result = self.Profanity_detector(chunk)
            if result[0]['label'] == 'OFFENSIVE':  # Assuming OFFENSIVE label indicates profanity
                profanity_flag = True
                break

        # Return classification results
        return {
            "abusive_flag": abusive_flag,
            "hatespeech_flag": hatespeech_flag,
            "profanity_flag": profanity_flag
        }
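
    # Example (illustrative): for a benign input every detector stays negative and
    # classify_text returns
    # {"abusive_flag": False, "hatespeech_flag": False, "profanity_flag": False}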

    def extract_speaker_text(self, transcript, client_label="Client", care_provider_label="Care Provider"):
        """
        Extracts text spoken by the client and the care provider from the transcript.
        """
        client_text = []
        care_provider_text = []
        lines = transcript.split("\n")
        for line in lines:
            if line.startswith(client_label + ":"):
                client_text.append(line[len(client_label) + 1:].strip())
            elif line.startswith(care_provider_label + ":"):
                care_provider_text.append(line[len(care_provider_label) + 1:].strip())
        return " ".join(client_text), " ".join(care_provider_text)


# Gradio interface for the web app
detector = AbuseHateProfanityDetector()

interface = gr.Interface(
    fn=detector.classify_text,
    inputs=[gr.Textbox(label="Enter text")],
    outputs="json",
    title="Abuse, Hate Speech, and Profanity Detection",
    description="Enter text to detect whether it contains abusive, hateful, or offensive content."
)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch(share=True)