# app.py - Hugging Face Space by mskov (commit fd26334: "Update app.py").
import os
os.system("pip install git+https://github.com/openai/whisper.git")
import whisper
import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
import torch
import pandas as pd
import classify
import replace_explitives
from whisper.model import Whisper
from whisper.tokenizer import get_tokenizer
from speechbrain.pretrained.interfaces import foreign_class
from transformers import AutoModelForSequenceClassification, pipeline, WhisperTokenizer, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
# TODO roadmap:
# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model
# Shared name -> loaded-model cache, for reusing expensive models across requests.
model_cache = {}

# Maps the short labels emitted by the SpeechBrain IEMOCAP emotion model
# (see classify_emotion) to display names.
emo_dict = {
    'sad': 'Sad',
    'hap': 'Happy',
    'ang': 'Anger',
    'neu': 'Neutral'
}

# Candidate labels for each anxiety class, keyed by the values of the anxiety
# radio button. Static for now, but it would be best to let the user select
# from multiple label sets and to enter their own.
class_options = {
    "racism": ["racism", "hate speech", "bigotry", "racially targeted", "racial slur",
               "ethnic slur", "ethnic hate", "pro-white nationalism"],
    "LGBTQ+ hate": ["gay slur", "trans slur", "homophobic slur", "transphobia",
                    "anti-LGBTQ+", "hate speech"],
    "sexually explicit": ["sexually explicit", "sexually coercive", "sexual exploitation",
                          "vulgar", "raunchy", "sexist", "sexually demeaning",
                          "sexual violence", "victim blaming"],
    "misophonia": ["chewing", "breathing", "mouthsounds", "popping", "sneezing",
                   "yawning", "smacking", "sniffling", "panting"]
}
# Whisper-large speech-to-text pipeline used to transcribe uploaded audio;
# built once at import time and shared by every request.
pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-large")
def classify_emotion(audio):
    """Return the display name of the emotion detected in an audio file.

    Uses the SpeechBrain wav2vec2 IEMOCAP emotion-recognition model. The model
    is loaded on first use and stored in ``model_cache`` so later calls reuse
    it instead of reloading it for every request.

    Args:
        audio: Path to the audio file to classify.

    Returns:
        str: The ``emo_dict`` display name for the predicted label (e.g.
        "Anger"). A label missing from ``emo_dict`` is returned unchanged.
    """
    emotion_classifier = model_cache.get("emotion_classifier")
    if emotion_classifier is None:
        emotion_classifier = foreign_class(
            source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
            pymodule_file="custom_interface.py",
            classname="CustomEncoderWav2vec2Classifier",
        )
        model_cache["emotion_classifier"] = emotion_classifier
    # classify_file yields (out_prob, score, index, text_lab); only the label is needed.
    _out_prob, _score, _index, text_lab = emotion_classifier.classify_file(audio)
    label = text_lab[0]
    return emo_dict.get(label, label)
def slider_logic(slider):
    """Map the 1-5 sensitivity slider to a toxicity-score threshold.

    Level 1 intervenes only in extreme cases (highest threshold). Level 5
    intervenes at every opportunity (lowest threshold).

    Args:
        slider: Slider position. Non-integer values are rounded to the nearest
            level. Values outside 1-5 are clamped into that range. None is
            treated as level 1.

    Returns:
        float: The threshold a toxicity score must exceed to trigger.
    """
    thresholds = {1: 0.98, 2: 0.88, 3: 0.77, 4: 0.66, 5: 0.55}
    level = 1 if slider is None else round(float(slider))
    # Always yield a float: the caller compares the toxicity score to it with `>`.
    level = min(max(level, 1), 5)
    return thresholds[level]
# Prediction logic wired to the Gradio interface below (audio file or text input)
def _cached_model(key, loader):
    """Return ``model_cache[key]``, building it with ``loader()`` on first use.

    The toxicity metric, zero-shot classifier and Whisper model are slow to
    load, so each is created once per process and reused across requests.
    """
    if key not in model_cache:
        model_cache[key] = loader()
    return model_cache[key]


def _classify_text_anxiety(transcribed_text, audio_file, classify_anxiety, emo_class,
                           explitive_selection, slider):
    """Score text for toxicity and zero-shot classify it against the selected class.

    Returns:
        tuple: ``(toxicity_score, classification_output, transcribed_text)``.
        ``classification_output`` is the zero-shot pipeline result, or None
        when the selected class has no candidate labels (e.g. nothing chosen).
        ``transcribed_text`` has expletives substituted when that applies.
    """
    print("emo_class ", emo_class, "explitive select", explitive_selection)
    threshold = slider_logic(slider)

    # Substitute expletives only when the user picked a word group and did not
    # opt into emotion-dependent handling.
    if explitive_selection is not None and emo_class is None:
        transcribed_text = replace_explitives.sub_explitives(transcribed_text, explitive_selection)

    toxicity_module = _cached_model(
        "toxicity",
        lambda: evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target"),
    )
    toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print(toxicity_score)

    # Emotion recognition needs the raw audio, so it is skipped for typed text.
    if emo_class is not None and audio_file is not None:
        print("emotion ", classify_emotion(audio_file))

    classification_output = None
    candidate_labels = class_options.get(classify_anxiety, [])
    # The zero-shot pipeline rejects an empty label list, so only run it with labels.
    if candidate_labels:
        text_classifier = _cached_model(
            "zero-shot",
            lambda: pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                # GPU 0 when CUDA is available, otherwise CPU (-1).
                device=0 if torch.cuda.is_available() else -1,
            ),
        )
        classification_output = text_classifier(transcribed_text, candidate_labels, multi_label=True)

    if toxicity_score > threshold:
        print("threshold exceeded!!")
    return toxicity_score, classification_output, transcribed_text


def _classify_misophonia(audio_file, transcribed_text, classify_anxiety):
    """Classify trigger sounds in the audio with Whisper-based zero-shot scoring.

    Each candidate label is scored in three steps, using the project's
    ``classify`` helpers:
    1. Compute the label's average log-probability given the audio.
    2. Subtract its internal-language-model average log-probability.
    3. Softmax the results across all labels.

    Returns:
        tuple: ``(top_score, {label: probability}, transcribed_text)``.

    Raises:
        gr.Error: If no audio file was uploaded.
    """
    if audio_file is None:
        raise gr.Error("Misophonia classification requires an uploaded audio file.")
    model = _cached_model("whisper-large", lambda: whisper.load_model("large"))
    tokenizer = get_tokenizer(model.is_multilingual)
    # Copy the labels so the shared class_options list is never mutated.
    class_names = list(class_options.get(classify_anxiety, []))
    print("class names ", class_names, "classify_anxiety ", classify_anxiety)

    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_file, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    # Remove the audio-independent (internal LM) term before normalizing.
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    class_scores = dict(zip(class_names, scores))
    return max(scores, default=None), class_scores, transcribed_text


def classify_toxicity(audio_file, text_input, classify_anxiety, emo_class, explitive_selection, slider):
    """Gradio callback: classify uploaded audio or typed text for the chosen anxiety class.

    Args:
        audio_file: Path to an uploaded audio file, or None.
        text_input: Typed text, used when no audio file is given.
        classify_anxiety: Selected key of ``class_options``, or None.
        emo_class: Emotion-option radio value, or None when unselected.
        explitive_selection: Expletive word-group radio value, or None.
        slider: Sensitivity slider value (1-5), mapped via ``slider_logic``.

    Returns:
        tuple: Three values for the score, classification and text outputs.
        For text-based classes this is ``(toxicity_score, zero_shot_output, text)``.
        For "misophonia" it is ``(top_score, {label: probability}, text)``.

    Raises:
        gr.Error: If "misophonia" is selected without an audio file.
    """
    # Prefer uploaded audio (transcribed with Whisper ASR) over typed text.
    if audio_file is not None:
        transcribed_text = pipe(audio_file)["text"]
    else:
        transcribed_text = text_input

    if classify_anxiety == "misophonia":
        return _classify_misophonia(audio_file, transcribed_text, classify_anxiety)
    return _classify_text_anxiety(
        transcribed_text, audio_file, classify_anxiety, emo_class, explitive_selection, slider
    )
with gr.Blocks() as iface:
    # Inputs controlling what to screen for and how sensitively.
    with gr.Column():
        anxiety_class = gr.Radio(
            ["racism", "LGBTQ+ hate", "sexually explicit", "misophonia"],
            label="Anxiety class",
        )
        # Choice values are passed verbatim to replace_explitives.sub_explitives;
        # keep them unchanged.
        explit_preference = gr.Radio(
            choices=["N-Word", "B-Word", "All Explitives"],
            label="Words to omit from general anxiety classes",
            info="Certain words may be acceptable within certain contexts for given groups of people, and some people may be unbothered by expletives broadly speaking.",
        )
        emo_class = gr.Radio(
            choices=["negative emotionality"],
            label="Negative emotionality",
            info="Select if you would like expletives to be considered anxiety-inducing in the case of anger/negative emotionality.",
        )
        # step=1 keeps the value on the five discrete levels slider_logic maps.
        sense_slider = gr.Slider(
            minimum=1,
            maximum=5,
            step=1,
            label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity",
        )
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        # gr.Button takes its caption as the first (`value`) argument.
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_val = gr.Textbox(label="Score")
        out_class = gr.Textbox(label="Classification")
        out_text = gr.Textbox(label="Transcribed text")
    submit_btn.click(
        fn=classify_toxicity,
        inputs=[aud_input, text, anxiety_class, emo_class, explit_preference, sense_slider],
        outputs=[out_val, out_class, out_text],
    )
iface.launch()