"""Gradio demo for zero-shot audio classification using Whisper."""

from typing import Dict, List

import gradio as gr
import whisper
from whisper.tokenizer import get_tokenizer

import classify

# Whisper models loaded so far, keyed by model name, so that each model is
# loaded at most once per process instead of on every request.
model_cache = {}


def _parse_class_names(class_names: str) -> List[str]:
    """Split a comma-separated string into unique, non-empty class names.

    Surrounding whitespace is stripped (``"dog, cat"`` yields ``"cat"``, not
    ``" cat"``), empty entries are dropped, and duplicates are removed while
    keeping first-occurrence order.

    Raises:
        ValueError: if no class name remains after cleaning.
    """
    stripped = (name.strip() for name in class_names.split(","))
    # dict.fromkeys de-duplicates while preserving order; duplicate names would
    # otherwise collide as keys of the returned score dict.
    names = list(dict.fromkeys(name for name in stripped if name))
    if not names:
        raise ValueError("Please enter at least one class name.")
    return names


def _get_model(model_name: str):
    """Return the Whisper model ``model_name``, loading and caching it on first use."""
    model = model_cache.get(model_name)
    if model is None:
        model = whisper.load_model(model_name)
        model_cache[model_name] = model
    return model


def zero_shot_classify(
    audio_path: str, class_names: str, model_name: str
) -> Dict[str, float]:
    """Score each candidate class name against an audio file.

    Args:
        audio_path: Path to the audio file to classify. May be ``None``/empty
            if the user submitted without uploading a file.
        class_names: Comma-separated candidate class names.
        model_name: Whisper model name (e.g. ``"small"``). Names containing
            ``".en"`` select the English-only tokenizer.

    Returns:
        Mapping from each (cleaned) class name to its softmax score over all
        candidates.

    Raises:
        ValueError: if no audio file is given or no class name is supplied.
    """
    # Validate cheap inputs before doing any expensive model loading.
    if not audio_path:
        raise ValueError("Please upload an audio file.")
    names = _parse_class_names(class_names)

    tokenizer = get_tokenizer(multilingual=".en" not in model_name)
    model = _get_model(model_name)

    # Computed without any audio features, i.e. the decoder's text-only score
    # for each class name.
    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=names,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_path, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=names,
        tokenizer=tokenizer,
    )
    # Subtract the audio-free score, presumably so that ranking reflects the
    # audio rather than how likely each class name is as plain text.
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    return dict(zip(names, scores))


def main() -> None:
    """Build the Gradio interface, pre-filled with ESC-50 examples, and launch it."""
    default_model = "small"
    example_class_names = ",".join(
        [
            "[dog barking]",
            "[helicopter whirring]",
            "[laughing]",
            "[birds chirping]",
            "[clock ticking]",
        ]
    )
    example_audio_paths = [
        "./data/(dog)1-100032-A-0.wav",
        "./data/(helicopter)1-181071-A-40.wav",
        "./data/(laughing)1-1791-A-26.wav",
        "./data/(chirping_birds)1-34495-A-14.wav",
        "./data/(clock_tick)1-21934-A-38.wav",
    ]
    examples = [
        [audio_path, example_class_names, default_model]
        for audio_path in example_audio_paths
    ]

    # Gradio renders the description as Markdown/HTML.
    description = (
        "<p>This demo allows you to try out zero-shot audio classification using "
        "Whisper.</p>"
        "<p>Github: "
        '<a href="https://github.com/jumon/zac">https://github.com/jumon/zac</a></p>'
        "<p>Example audio files are from the ESC-50"
        " dataset (CC BY-NC 3.0).</p>"
    )

    demo = gr.Interface(
        fn=zero_shot_classify,
        inputs=[
            gr.Audio(source="upload", type="filepath", label="Audio File"),
            gr.Textbox(lines=1, label="Candidate class names (comma-separated)"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large"],
                value=default_model,
                label="Model Name",
            ),
        ],
        outputs="label",
        examples=examples,
        title="Zero-shot Audio Classification using Whisper",
        description=description,
    )
    demo.launch()


if __name__ == "__main__":
    main()