import gradio as gr
import os
import allin1
import time
import json
import torch
import librosa
import numpy as np
from pathlib import Path
HEADER = """ | |
<header style="text-align: center;"> | |
<h1> | |
All-In-One Music Structure Analyzer 🔮 | |
</h1> | |
<p> | |
<a href="https://github.com/mir-aidj/all-in-one">[Python Package]</a> | |
<a href="https://arxiv.org/abs/2307.16425">[Paper]</a> | |
<a href="https://taejun.kim/music-dissector/">[Visual Demo]</a> | |
</p> | |
</header> | |
<main | |
style="display: flex; justify-content: center;" | |
> | |
<div | |
style="display: inline-block;" | |
> | |
<p> | |
      This Space demonstrates the music structure analyzer, which predicts:
      <ul
        style="padding-left: 1rem;"
      >
        <li>BPM</li>
        <li>Beats</li>
        <li>Downbeats</li>
        <li>Functional segment boundaries</li>
        <li>Functional segment labels (e.g. intro, verse, chorus, bridge, outro)</li>
      </ul>
    </p>
    <p>
      For more information, please visit the links above ✨🧸
    </p>
  </div>
</main>
"""
CACHE_EXAMPLES = os.getenv('CACHE_EXAMPLES', '1') == '1'

base_dir = "/tmp/gradio/"

# Sample rate used for voice activity detection (must be a multiple of 8 kHz)
SAMPLING_RATE = 32000
torch.set_num_threads(1)
# Load the Silero VAD model and its helper utilities for voice detection
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad:v4.0',
                              model='silero_vad',
                              force_reload=True)
(get_speech_timestamps,
 save_audio,
 read_audio,
 VADIterator,
 collect_chunks) = utils
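
# Note: the bundled `get_speech_timestamps` helper could compute voiced regions
# in one call (e.g. `get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)`),
# but this app computes per-window probabilities manually in add_voice_label()
# below so it can apply its own threshold and window size.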

def analyze(path):
    # Start the inference timer
    start = time.time()

    string_path = path
    path = Path(path)
    result = allin1.analyze(
        path,
        out_dir='./struct',
        multiprocess=False,
        keep_byproducts=True,  # TODO: remove this
    )

    # Locate the JSON analysis file that allin1 wrote to ./struct
    json_structure_output = None
    for root, dirs, files in os.walk("./struct"):
        for file_path in files:
            json_structure_output = os.path.join(root, file_path)
            print(json_structure_output)

    add_voice_label(json_structure_output, string_path)
    fig = allin1.visualize(
        result,
        multiprocess=False,
    )
    fig.set_dpi(300)

    #allin1.sonify(
    #  result,
    #  out_dir='./sonif',
    #  multiprocess=False,
    #)
    #sonif_path = Path(f'./sonif/{path.stem}.sonif{path.suffix}').resolve().as_posix()

    # Stop the inference timer
    end = time.time()
    elapsed_time = end - start

    # Get the base name of the file and strip its extension
    file_name = os.path.basename(path)
    file_name_without_extension = os.path.splitext(file_name)[0]
    print(file_name_without_extension)
    # Collect the separated stems that Demucs wrote (kept via keep_byproducts=True)
    bass_path, drums_path, other_path, vocals_path = None, None, None, None
    for root, dirs, files in os.walk(f"./demix/htdemucs/{file_name_without_extension}"):
        for file_path in files:
            file_path = os.path.join(root, file_path)
            print(file_path)
            if "bass.wav" in file_path:
                bass_path = file_path
            if "vocals.wav" in file_path:
                vocals_path = file_path
            if "other.wav" in file_path:
                other_path = file_path
            if "drums.wav" in file_path:
                drums_path = file_path

    #return result.bpm, fig, sonif_path, elapsed_time
    return result.bpm, fig, elapsed_time, json_structure_output, bass_path, drums_path, other_path, vocals_path
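
# Illustrative call (hypothetical file path): analyze('/tmp/gradio/song.mp3')
# returns the BPM, a matplotlib figure, the elapsed time in seconds, the path
# to the structure JSON, and the four stem paths (bass, drums, other, vocals).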

def aggregate_vocal_times(vocal_times):
    """
    Merges neighbouring vocal segments into single, longer segments. This is
    done because the raw segments are usually very short (<3 seconds) sections
    of audio.
    """
    # Aggregation hyperparameter: segments are merged as long as the next
    # segment starts no more than NEXT_SEGMENT_SECONDS after the end of the
    # previous one.
    NEXT_SEGMENT_SECONDS = 5
    compressed_vocal_times = []
    try:
        start_time = 0.0
        end_time = 0.0
        begin_seq = True
        for vocal_time in vocal_times:
            if begin_seq:
                start_time = vocal_time['start_time']
                end_time = vocal_time['end_time']
                begin_seq = False
                continue
            if float(vocal_time['start_time']) < float(end_time) + NEXT_SEGMENT_SECONDS:
                # The next segment starts soon enough: extend the current one
                end_time = vocal_time['end_time']
            else:
                # Gap too large: close the current segment and start a new one
                print(start_time, end_time)
                compressed_vocal_times.append({
                    "start_time": f"{start_time}",
                    "end_time": f"{end_time}"
                })
                start_time = vocal_time['start_time']
                end_time = vocal_time['end_time']
        compressed_vocal_times.append({
            "start_time": f"{start_time}",
            "end_time": f"{end_time}"
        })
    except Exception as e:
        print(f"An exception occurred: {e}")
    return compressed_vocal_times
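
# Worked example with made-up values: given
#   [{'start_time': '10.00', 'end_time': '12.00'},
#    {'start_time': '14.50', 'end_time': '20.00'}]
# the second segment starts 2.5 s after the first one ends, which is within
# NEXT_SEGMENT_SECONDS, so aggregate_vocal_times returns
#   [{'start_time': '10.00', 'end_time': '20.00'}]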

def add_voice_label(json_file, audio_path):
    # Hyperparameter of the model: the probability threshold above which a
    # window is considered to contain voice
    THRESHOLD_PROBABILITY = 0.75

    # Load the JSON file
    with open(json_file, 'r') as f:
        data = json.load(f)
    # Create the VAD iterator and read the input audio file
    vad_iterator = VADIterator(model)
    wav, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True)

    speech_probs = []
    # Size of the window the speech probability is computed on. This is a
    # detection hyperparameter and can be changed to obtain different results;
    # I found this value (a quarter of a second) to work well.
    window_size_samples = int(SAMPLING_RATE / 4)
    for i in range(0, len(wav), window_size_samples):
        chunk = torch.from_numpy(wav[i: i + window_size_samples])
        if len(chunk) < window_size_samples:
            break
        speech_prob = model(chunk, SAMPLING_RATE).item()
        speech_probs.append(speech_prob)
    vad_iterator.reset_states()  # reset model states after each audio

    # Indices of the windows whose speech probability clears the threshold
    voice_idxs = np.where(np.array(speech_probs) >= THRESHOLD_PROBABILITY)[0]
    print(len(voice_idxs))
    if len(voice_idxs) == 0:
        print("NO VOICE SEGMENTS DETECTED!")
    try:
        begin_seq = True
        start_idx = 0
        vocal_times = []
        # Convert runs of consecutive voiced windows into (start, end) times
        for i in range(len(voice_idxs)):
            if begin_seq:
                start_idx = voice_idxs[i]
                begin_seq = False
            # Keep extending the run while the next window is also voiced
            if i + 1 < len(voice_idxs) and voice_idxs[i + 1] == voice_idxs[i] + 1:
                continue
            start_time = float((start_idx * window_size_samples) / SAMPLING_RATE)
            end_time = float((voice_idxs[i] * window_size_samples) / SAMPLING_RATE)
            vocal_times.append({
                "start_time": f"{start_time:.2f}",
                "end_time": f"{end_time:.2f}"
            })
            begin_seq = True
        vocal_times = aggregate_vocal_times(vocal_times)
        data['vocal_times'] = vocal_times
    except Exception as e:
        print(f"An exception occurred: {e}")

    with open(json_file, 'w') as f:
        print("writing_to_json...")
        json.dump(data, f, indent=4)
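
# After add_voice_label runs, the structure JSON gains a 'vocal_times' key.
# Illustrative shape (values are made up):
#   "vocal_times": [
#     {"start_time": "10.00", "end_time": "20.00"},
#     {"start_time": "45.25", "end_time": "78.50"}
#   ]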

with gr.Blocks() as demo:
    gr.HTML(HEADER)
    input_audio_path = gr.Audio(
        label='Input',
        type='filepath',
        format='mp3',
        show_download_button=False,
    )
    button = gr.Button('Analyze', variant='primary')
    output_viz = gr.Plot(label='Visualization')
    with gr.Row():
        output_bpm = gr.Textbox(label='BPM', scale=1)
        #output_sonif = gr.Audio(
        #  label='Sonification',
        #  type='filepath',
        #  format='mp3',
        #  show_download_button=False,
        #  scale=9,
        #)
        elapsed_time = gr.Textbox(label='Overall inference time', scale=1)
    json_structure_output = gr.File(label="Json structure")
    with gr.Column():
        bass = gr.Audio(label='bass', show_share_button=False)
        vocals = gr.Audio(label='vocals', show_share_button=False)
        other = gr.Audio(label='other', show_share_button=False)
        drums = gr.Audio(label='drums', show_share_button=False)
        #bass_path = gr.Textbox(label='bass_path', scale=1)
        #drums_path = gr.Textbox(label='drums_path', scale=1)
        #other_path = gr.Textbox(label='other_path', scale=1)
        #vocals_path = gr.Textbox(label='vocals_path', scale=1)
    #gr.Examples(
    #  examples=[
    #    './assets/NewJeans - Super Shy.mp3',
    #    './assets/Bruno Mars - 24k Magic.mp3'
    #  ],
    #  inputs=input_audio_path,
    #  outputs=[output_bpm, output_viz, output_sonif],
    #  fn=analyze,
    #  cache_examples=CACHE_EXAMPLES,
    #)
    button.click(
        fn=analyze,
        inputs=input_audio_path,
        #outputs=[output_bpm, output_viz, output_sonif, elapsed_time],
        outputs=[output_bpm, output_viz, elapsed_time, json_structure_output, bass, drums, other, vocals],
        api_name='analyze',
    )

if __name__ == '__main__':
    demo.launch()
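
# To run locally (assuming this file is saved as app.py and the gradio/allin1
# dependencies are installed):
#   python app.py
# Gradio serves the demo on http://localhost:7860 by default.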