# SSL_demo / app.py
# Andrei-Iulian SĂCELEANU
# added examples sections
# eaee399
import functools
import os
import re

import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer, ViTImageProcessor
from unidecode import unidecode

from models import *
# Identifiers passed to `from_pretrained` for the shared preprocessing components.
_TEXT_TOKENIZER_ID = "readerbench/RoBERT-base"
_IMAGE_PROCESSOR_ID = "google/vit-base-patch16-224"

# Loaded once at import time and reused by every prediction call.
tok = AutoTokenizer.from_pretrained(_TEXT_TOKENIZER_ID)
processor = ViTImageProcessor.from_pretrained(_IMAGE_PROCESSOR_ID)
def preprocess(x):
    """Normalize a raw input string before tokenization.

    Steps, applied in order:
      1. transliterate to ASCII (``unidecode``) and lower-case;
      2. drop bracketed lowercase tags such as ``[tag]`` and every ``*``;
      3. replace each run of non-alphanumeric characters with one space;
      4. squeeze repeated spaces;
      5. collapse any run of the same character into a single occurrence
         (note: this also shortens legitimate double letters/digits).

    Leading/trailing spaces are not stripped.
    """
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )
    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Text (RO-Offense) classes, paired positionally with the model's score vector.
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]
# Audio (VocalSound) classes, paired positionally with the model's score vector.
audio_label_names = [
    "Laughter",
    "Sigh",
    "Cough",
    "Throat clearing",
    "Sneeze",
    "Sniff",
]
@functools.lru_cache(maxsize=2)
def _load_text_model(model_type):
    """Build the text classifier for ``model_type`` and restore its trained weights.

    Memoized (at most two models kept alive) so repeated requests reuse an
    already-built model instead of re-instantiating the encoder and re-reading
    the checkpoint on every click.

    Args:
        model_type: one of "fixmatch", "freematch", "mixmatch",
            "contrastive_reg", "label_propagation".

    Returns:
        ``(model, returns_pair)``. ``returns_pair`` is True when calling the
        model yields a two-element ``(predictions, other)`` result rather than
        bare predictions.

    Raises:
        gr.Error: if ``model_type`` is missing or not a supported method.
    """
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
        return model, True
    if model_type == "freematch":
        # Only the classification head is restored from the local checkpoint;
        # the encoder is whatever `encoder_name` provides.
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        return model, True
    if model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
        return model, False
    if model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/contrastive")
        return model, True
    if model_type == "label_propagation":
        model = LPModel()
        model.load_weights("./checkpoints/label_prop")
        return model, False
    # Previously an unknown/missing choice left `preds = None` and crashed with
    # an opaque TypeError; surface a readable message in the UI instead.
    raise gr.Error(
        f"Unsupported training method: {model_type!r}. Please select one from the list."
    )


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the text model trained by ``model_type``.

    Args:
        in_text: raw user text; ``None`` is treated as an empty string.
        model_type: training method selected in the UI dropdown.

    Returns:
        dict mapping each name in ``label_names`` to the model's score for it
        (the value format consumed by the ``gr.Label`` output).

    Raises:
        gr.Error: if ``model_type`` is missing or unsupported.
    """
    model, returns_pair = _load_text_model(model_type)
    toks = tok(
        preprocess(in_text or ""),
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    outputs = model([toks["input_ids"], toks["attention_mask"]], training=False)
    if returns_pair:
        preds, _ = outputs
    else:
        preds = outputs
    # A single string was tokenized, so the batch has one row: take its scores.
    probs = preds[0].numpy()
    return {name: float(score) for name, score in zip(label_names, probs)}
@functools.lru_cache(maxsize=2)
def _load_audio_model(model_type):
    """Build the audio classifier for ``model_type`` and restore its head weights.

    Memoized (at most two models kept alive) so repeated requests reuse an
    already-built model instead of rebuilding it on every click.

    Args:
        model_type: one of "fixmatch", "freematch", "mixmatch".

    Returns:
        ``(model, returns_pair)``. ``returns_pair`` is True when calling the
        model yields a two-element ``(predictions, other)`` result rather than
        bare predictions.

    Raises:
        gr.Error: if ``model_type`` is missing or not a supported method.
    """
    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
        return model, True
    if model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
        return model, True
    if model_type == "mixmatch":
        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
        return model, False
    raise gr.Error(
        f"Unsupported training method: {model_type!r}. Please select one from the list."
    )


def ssl_predict2(audio_file, model_type):
    """Classify an uploaded audio clip with the model trained by ``model_type``.

    The clip is resampled to 16 kHz and zero-padded or truncated to 5 seconds. It is
    then turned into a 128-band log-mel spectrogram, min-max scaled to [0, 1],
    replicated to 3 channels and passed through the ViT image processor.

    Args:
        audio_file: value of the ``gr.File`` input (an object exposing
            ``.name`` or a path string).
        model_type: training method selected in the UI dropdown.

    Returns:
        dict mapping each name in ``audio_label_names`` to the model's score.

    Raises:
        gr.Error: if no file was provided or ``model_type`` is unsupported.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    model, returns_pair = _load_audio_model(model_type)

    # The original code reads `.name` (temp-file wrapper); fall back to the
    # value itself in case the Gradio version in use passes a plain path.
    path = getattr(audio_file, "name", audio_file)
    sample_rate = 16000
    clip_seconds = 5
    signal, sr = librosa.load(path, sr=sample_rate)
    length = clip_seconds * sample_rate
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), "constant")
    else:
        signal = signal[:length]

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)
    spec_min = spectrogram.min()
    spec_range = spectrogram.max() - spec_min
    if spec_range > 0:
        spectrogram = (spectrogram - spec_min) / spec_range
    else:
        # Constant spectrogram (e.g. an all-silent clip): the min-max scaling
        # would be 0/0 -> NaN and poison the predictions; use zeros instead.
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")

    # Add batch and channel axes, then repeat the single channel 3 times.
    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )
    outputs = model(inputs["pixel_values"], training=False)
    if returns_pair:
        preds, _ = outputs
    else:
        preds = outputs
    probs = preds[0].numpy()
    return {name: float(score) for name, score in zip(audio_label_names, probs)}
# Textbox `type` indexed by the "Safe view" checkbox: False -> "text", True -> "password".
text_types = ["text", "password"]

# Text examples, one per line; the part before "##" is the text shown to the user.
with open(file="examples.txt", mode="r", encoding="UTF-8") as fin:
    # rstrip("\n") instead of [:-1]: a last line without a trailing newline
    # would otherwise lose its final character.
    lines = [line.rstrip("\n") for line in fin]

# Audio examples: one file name per line, resolved against ./audio_data.
DATA_DIR = os.path.abspath("./audio_data")
with open(file="audio_examples.txt", mode="r", encoding="UTF-8") as fin:
    lines2 = [os.path.join(DATA_DIR, line.strip()) for line in fin]
with gr.Blocks() as ssl_interface:
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                # Starts masked ("password") because "Safe view" starts checked.
                in_text = gr.Textbox(label="Input text", type="password")
                safe_view = gr.Checkbox(value=True, label="Safe view")
                model_list = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
                    # Preselect a method so Submit works without touching the dropdown.
                    value="fixmatch",
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
                # Example picker; with type="index" a click yields the row index.
                ds = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "ABUSE"], ["2", "INSULT"], ["3", "PROFANITY"], ["4", "OTHER"]],
                    type="index",
                )
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")

        # Toggle masking: checked -> "password", unchecked -> "text".
        safe_view.change(
            fn=lambda checked: gr.update(type=text_types[int(checked)]),
            inputs=safe_view,
            outputs=in_text,
        )
        # Load the text part (before "##") of the clicked example.
        ds.click(
            fn=lambda idx: gr.update(value=lines[idx].split("##")[0]),
            inputs=ds,
            outputs=in_text,
        )
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=lambda: (None, None),
            inputs=None,
            outputs=[in_text, out_field],
            queue=False,
        )

    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"],
                )
                model_list2 = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch"],
                    value="fixmatch",
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
                # "Throat clearing" matches the class name used for predictions.
                ds2 = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "Laughter"], ["2", "Cough"], ["3", "Sneeze"], ["4", "Throat clearing"]],
                    type="index",
                )
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")

        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=lambda: (None, None),
            inputs=None,
            outputs=[audio_file, out_field2],
            queue=False,
        )
        # Put the clicked example's absolute file path into the file input.
        ds2.click(
            fn=lambda idx: gr.update(value=lines2[idx]),
            inputs=ds2,
            outputs=audio_file,
        )

ssl_interface.launch(server_name="0.0.0.0", server_port=7860)