ksang's picture
Upload 3 files
448fd25
raw
history blame contribute delete
No virus
2.21 kB
# %%
import gradio as gr
import torchaudio
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import torch
# %%
def dump_pickle(file_path: str, file, mode: str = "wb"):
    """Serialize an object to disk with pickle.

    Args:
        file_path: Destination path; it is opened with ``mode``.
        file: The object to pickle (despite the name, not a file handle).
        mode: Mode passed to ``open``; defaults to binary write.
    """
    import pickle

    with open(file_path, mode=mode) as handle:
        writer = pickle.Pickler(handle)
        writer.dump(file)
def load_pickle(file_path: str, mode: str = "rb", encoding=""):
    """Load and return a pickled object from disk.

    Args:
        file_path: Path of the pickle file; it is opened with ``mode``.
        mode: Mode passed to ``open``; defaults to binary read.
        encoding: Codec used to decode 8-bit strings stored by Python 2
            pickles. An empty value falls back to ``"ASCII"``, which is
            pickle's own default.

    Returns:
        The unpickled object.

    Note:
        Unpickling can execute arbitrary code, so only load files that come
        from a trusted source.
    """
    import pickle

    # "" is not a codec name: handing it to pickle makes any legacy 8-bit
    # string in the file fail with LookupError, so substitute the default.
    effective_encoding = encoding or "ASCII"
    with open(file_path, mode=mode) as f:
        return pickle.load(f, encoding=effective_encoding)
# %%
# Label lookup tables, unpickled from files expected next to this script.
label2id, id2label = map(load_pickle, ('label2id.pkl', 'id2label.pkl'))
# %%
# Audio-classification model configured with one output per known label.
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
)
# %%
# Feature extractor from the same pretrained checkpoint name as the model.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
# %%
# Local weights file, mapped onto the CPU when deserialized.
checkpoint = torch.load(
    'pytorch_model.bin',
    map_location=torch.device('cpu'),
)
# %%
# Replace the model's parameters with the ones from the local checkpoint.
model.load_state_dict(checkpoint)
# %%
def predict(input):
    """Classify an audio recording and return the predicted label name.

    Args:
        input: Path of the audio file to classify, or ``None`` when nothing
            was supplied. The name shadows the builtin but is kept so that
            existing callers are unaffected.

    Returns:
        The label name looked up in ``id2label`` for the highest-scoring
        class, or a human-readable message when no usable input was given.
    """
    # Identity test: ``== None`` would defer to whatever ``__eq__`` the
    # argument happens to define.
    if input is None:
        return "Please input a valid file or record yourself by clicking the microphone"
    if not input:
        return "File is not valid"

    # Resample to the very rate that is reported to the feature extractor
    # below, so the two values cannot drift apart.
    target_sr = feature_extractor.sampling_rate

    samples, source_sr = librosa.load(input)
    # Add a leading dimension before resampling.
    waveform = torch.from_numpy(samples).unsqueeze(0)
    waveform = torchaudio.transforms.Resample(source_sr, target_sr)(waveform)

    # NOTE(review): max_length/truncation are applied to a 2-D tensor here;
    # confirm the extractor truncates along the sample axis as intended.
    inputs = feature_extractor(
        waveform,
        sampling_rate=target_sr,
        max_length=16000,
        truncation=True,
    )
    batch = torch.tensor(inputs['input_values'][0])

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output = model(batch)

    logits = output['logits'][0]
    label_id = torch.argmax(logits).item()
    # id2label is indexed with the stringified class index.
    return id2label[str(label_id)]
# %%
# Gradio UI: one audio input wired to `predict`, with a plain-text result.
demo = gr.Interface(
    fn=predict,
    # record audio, save in temp file to feed to inference func
    inputs=gr.Audio(
        source="microphone",
        type="filepath",
        optional=False,
        label="Speak to classify your voice!",
    ),
    outputs="text",
    title="Audio Gender Classification",
    description="Record your voice or upload an audio file to see what gender our model classifies it as",
    examples=[["male.mp3"], ["female.mp3"]],
)
# %%
# Start serving the interface.
demo.launch()
# %%