# Hugging Face Space status at time of capture: Runtime error
import functools
import os

# NOTE(review): installs openai-whisper at import time, but no visible code in
# this file imports `whisper` (transcription uses a transformers pipeline);
# consider moving this to requirements.txt or dropping it.
os.system("pip install git+https://github.com/openai/whisper.git")

import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
from speechbrain.pretrained.interfaces import foreign_class
import torch
from transformers import (
    AutoModelForSequenceClassification,
    pipeline,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    AutoTokenizer,
)
# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model
# Building prediction function for gradio
# Short emotion codes -> display names. The keys presumably match the labels
# emitted by the SpeechBrain IEMOCAP emotion model used below (verify).
emotion_dict = dict(
    sad='Sad',
    hap='Happy',
    ang='Anger',
    neu='Neutral',
)

# Shared speech-to-text pipeline. No model is named, so transformers falls back
# to its default checkpoint for this task.
pipe = pipeline(task="automatic-speech-recognition")
# Create a Gradio interface with audio file and text inputs | |
def classify_toxicity(audio_file, text_input, classify_anxiety): | |
# Transcribe the audio file using Whisper ASR | |
if audio_file != None: | |
'''whisper_model = WhisperModel.from_pretrained("openai/whisper-base") | |
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") | |
transcription_results = whisper_model.compute(uploaded=audio_file) | |
audio = whisper.load_audio(audio_file) | |
mel = whisper.log_mel_spectrogram(audio).to(model.device) | |
_, probs = model.detect_language(mel) | |
options = whisper.DecodingOptions(fp16 = False) | |
result = whisper.decode(model, mel, options) | |
# Extract the transcribed text | |
# transcribed_text = transcription_results["transcription"] | |
''' | |
# model = whisper.load_model("base") | |
# transcribed_text = model.transcribe(audio_file) | |
transcribed_text = pipe(audio_file)["text"] | |
#### Emotion classification #### | |
emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier") | |
out_prob, score, index, text_lab = emotion_classifier.classify_file(audio_file.name) | |
else: | |
transcribed_text = text_input | |
#### Toxicity Classifier #### | |
toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target") | |
#toxicity_module = evaluate.load("toxicity", 'DaNLP/da-electra-hatespeech-detection', module_type="measurement") | |
toxicity_results = toxicity_module.compute(predictions=[transcribed_text]) | |
toxicity_score = toxicity_results["toxicity"][0] | |
print(toxicity_score) | |
#### Text classification ##### | |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
text_classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") | |
sequence_to_classify = transcribed_text | |
candidate_labels = classify_anxiety | |
# classification_output = classifier(sequence_to_classify, candidate_labels, multi_label=False) | |
classification_output = text_classifier(sequence_to_classify, candidate_labels, multi_label=False) | |
print(classification_output) | |
#### Emotion classification #### | |
emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier") | |
out_prob, score, index, text_lab = emotion_classifier.classify_file(audio_file.name) | |
return toxicity_score, classification_output, emo_dict[text_lab[0]], transcribed_text | |
# return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}" | |
# Gradio UI: pick a category, supply audio or text, and view one output per
# value returned by classify_toxicity.
with gr.Blocks() as iface:
    with gr.Column():
        classify = gr.Radio(
            ["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"],
            label="Classification category",
        )
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        # A Button's caption is set with `value`; it has no `label` argument.
        submit_btn = gr.Button(value="Run")
    with gr.Column():
        # classify_toxicity returns a 4-tuple, so wire one component to each
        # element instead of dumping the whole tuple into a single textbox.
        out_toxicity = gr.Number(label="Toxicity score")
        out_classification = gr.JSON(label="Classification")
        out_emotion = gr.Textbox(label="Detected emotion")
        out_text = gr.Textbox(label="Transcription / input text")
    submit_btn.click(
        fn=classify_toxicity,
        inputs=[aud_input, text, classify],
        outputs=[out_toxicity, out_classification, out_emotion, out_text],
    )

iface.launch()