Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py | |
# also released under the MIT license. | |
import argparse | |
from concurrent.futures import ProcessPoolExecutor | |
import os | |
from pathlib import Path | |
import subprocess as sp | |
from tempfile import NamedTemporaryFile | |
import time | |
import typing as tp | |
import warnings | |
import glob | |
import csv | |
import torch | |
import gradio as gr | |
import numpy as np | |
import shutil | |
from audiocraft.data.audio_utils import convert_audio | |
from audiocraft.data.audio import audio_write, audio_read | |
from audiocraft.models import MusicGen | |
from demucs import pretrained | |
from demucs.apply import apply_model | |
from demucs.audio import convert_audio | |
from gradio_client import Client | |
import pretty_midi | |
import huggingface_hub | |
from huggingface_hub import Repository | |
from datetime import datetime | |
# Runtime switches.
# LOCAL: when True, load_model() loads MusicGen in-process and the UI is wired
# to predict_full(); otherwise the UI calls run_remote_model().
LOCAL = False
# USE_MIDI: when True, select_new_melody() draws from midi_files and renders
# the pick with render_midi(); otherwise it draws from melody_files.
USE_MIDI = True
# LOGS
# Ratings are appended to a CSV inside a local clone of this dataset repo and
# pushed back by store_message().
DATASET_REPO_URL = "https://huggingface.co/datasets/soundsauce/soundsauce-logs"
DATA_FILENAME = "ratings.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
AUDIO_DIR = os.path.join("data", "audio")
HF_TOKEN = os.environ.get("HF_TOKEN")
# Logs only whether the token is missing, never the token itself.
print("is none?", HF_TOKEN is None)
# Import-time side effect: sets up the dataset repo in the local "data" directory.
repo = Repository(
    local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
print("hfh", huggingface_hub.__version__)
MODEL = None  # Last used model
DEMUCS_MODEL = None  # Set by load_model()
# NOTE(review): MAX_BATCH_SIZE is not read anywhere in the visible source.
MAX_BATCH_SIZE = 12
INTERRUPTING = False  # Set by interrupt(), reset and polled by predict_full()
client = None  # Set by connect_to_endpoint()
# Keep a handle on the original subprocess.call: it is wrapped further down to
# keep the logs clean when using gr.make_waveform.
_old_call = sp.call
stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
# Stems that are kept and summed after separation (everything except drums).
stem_idx = torch.LongTensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']])
melody_files = list(glob.glob('clips/**/*.wav', recursive=True))
midi_files = list(glob.glob('clips/**/*.mid', recursive=True))
# (start, end) pairs; crop_melody() multiplies them by the sample rate, so
# they are in seconds.
crops = [(0, 5), (0, 10), (0, 15)]
# State of the most recent selection/generation, read by rating_callback()
# and store_message().
selected_melody = ""
selected_crop = None
selected_text = ""
output_file = ""
def store_message(message: dict):
    """Append a rating row to the ratings CSV, archive the rated audio, and push.

    Nothing happens when ``message`` is empty or when no audio has been
    generated yet (the module-level ``output_file`` is empty).

    Args:
        message: Column name -> value mapping for the CSV row. It must contain
            a ``"TIME"`` entry, which is also used as the stem of the archived
            audio file name.
    """
    if not (message and output_file):
        return
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(AUDIO_DIR, exist_ok=True)
    # Sync with the remote first so the row is appended to the latest file.
    repo.git_pull()
    # The csv module expects files opened with newline=""; without it, embedded
    # newlines are mangled and extra "\r" appear on some platforms. An explicit
    # encoding keeps non-ASCII prompts from depending on the locale.
    with open(DATA_FILE, "a", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(message.keys()))
        writer.writerow(message)
    # The CSV handle is closed (and therefore flushed) before anything is pushed.
    filepath = os.path.join(AUDIO_DIR, message["TIME"]) + ".mp3"
    shutil.copy(output_file, filepath)
    commit_url = repo.push_to_hub()
    print("Commited to", commit_url)
def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout and stderr discarded.

    Returns:
        Whatever the wrapped call returns (the process return code), so code
        that inspects the result of ``sp.call`` keeps working once this
        wrapper is installed.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


# Install the quiet wrapper in place of subprocess.call.
sp.call = _call_nostderr
# Preallocating the pool of processes.
# NOTE(review): nothing in the visible source submits work to this pool or
# shuts it down — confirm it is still needed.
pool = ProcessPoolExecutor(4)
# Entered by hand instead of via a `with` block, so no matching __exit__ runs here.
pool.__enter__()
def interrupt():
    """Ask the running generation to stop.

    Raises the module-level ``INTERRUPTING`` flag; the progress callback
    installed by ``predict_full`` checks it and aborts with a ``gr.Error``.
    """
    globals()["INTERRUPTING"] = True
class FileCleaner:
    """Track temporary files and delete them once they outlive a lifetime.

    Entries are stored in insertion order, so the oldest one is always first
    and a cleanup pass can stop at the first entry that has not expired.
    """

    def __init__(self, file_lifetime: float = 3600):
        """
        Args:
            file_lifetime: Age in seconds after which a registered file is
                deleted by the next cleanup pass.
        """
        self.file_lifetime = file_lifetime
        # (time_added, path) tuples, oldest first.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Register ``path`` for later deletion, purging expired entries first."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Delete every expired file and drop its entry from ``self.files``."""
        now = time.time()
        expired = 0
        for time_added, path in self.files:
            if now - time_added <= self.file_lifetime:
                # Everything after this entry is newer, so nothing else expired.
                break
            try:
                path.unlink()
            except FileNotFoundError:
                # Already gone (never written, or removed by someone else);
                # checking exists() first would leave a race window.
                pass
            expired += 1
        # One slice delete instead of repeated pop(0), which shifts the list
        # once per removed entry.
        del self.files[:expired]


# 10 minutes
file_cleaner = FileCleaner(600)
def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings silenced and print the elapsed time.

    All positional and keyword arguments are forwarded unchanged; the value
    returned by ``gr.make_waveform`` is returned as is.
    """
    started = time.time()
    with warnings.catch_warnings():
        # Further remove some warnings.
        warnings.simplefilter('ignore')
        result = gr.make_waveform(*args, **kwargs)
    elapsed = time.time() - started
    print("Make a video took", elapsed)
    return result
def load_model(version='melody'):
    """Make sure the models needed by the app are loaded into the module globals.

    MusicGen is (re)loaded only when ``LOCAL`` is set and the cached ``MODEL``
    is missing or has a different name than ``version``. The ``htdemucs``
    separator is loaded once, regardless of ``LOCAL``.

    Args:
        version: Name of the pretrained MusicGen model to load.
    """
    global MODEL, DEMUCS_MODEL
    # If gpu is not available, we'll use cpu.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    needs_musicgen = LOCAL and (MODEL is None or MODEL.name != version)
    if needs_musicgen:
        print("Loading model", version)
        MODEL = MusicGen.get_pretrained(version, device=device)

    if DEMUCS_MODEL is None:
        separator = pretrained.get_model('htdemucs')
        DEMUCS_MODEL = separator.to(device)
def connect_to_endpoint(url="https://facebook-musicgen--44zzp.hf.space/"):
    """Create the Gradio client used by ``run_remote_model``.

    The client is stored in the module-level ``client``.

    Args:
        url: Address of the remote Gradio app. Defaults to the endpoint that
            was previously hard-coded, so existing zero-argument calls behave
            exactly as before.
    """
    global client
    client = Client(url)
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate audio with the local MusicGen model, drop the drums, and save it.

    Requires ``MODEL`` and ``DEMUCS_MODEL`` to be loaded. The path of the last
    file written is also stored in the module-level ``output_file``.

    Args:
        texts: Text descriptions passed to the model.
        melodies: One entry per text; each is ``None`` or a pair whose first
            item is a sample rate and whose second item is a numpy array
            (presumably laid out as (samples,) or (samples, channels), since
            it is transposed before use — confirm against the Audio component).
        duration: Length in seconds, used both as generation parameter and to
            crop the melody.
        progress: Forwarded to ``MODEL.generate_with_chroma``.
        **gen_kwargs: Forwarded to ``MODEL.set_generation_params``.

    Returns:
        List with the path of every file written, in output order.
    """
    global output_file
    MODEL.set_generation_params(duration=duration, cfg_coef=5, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    target_sr = 32000
    target_ac = 1

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        sr = melody[0]
        samples = torch.from_numpy(melody[1]).to(MODEL.device).float().t()
        if samples.dim() == 1:
            # Mono input: add the channel axis.
            samples = samples[None]
        # After the transpose the last axis is time; keep at most `duration` seconds.
        samples = samples[..., :int(sr * duration)]
        # convert_audio is imported twice at the top of the file; the name
        # resolves to whichever import ran last.
        samples = convert_audio(samples, sr, target_sr, target_ac)
        processed_melodies.append(samples)

    outputs = MODEL.generate_with_chroma(
        descriptions=texts,
        melody_wavs=processed_melodies,
        melody_sample_rate=target_sr,
        progress=progress,
    )
    outputs = outputs.detach().float()

    out_files = []
    for output in outputs:
        # Demucs
        print("Running demucs")
        wav = convert_audio(output, MODEL.sample_rate, DEMUCS_MODEL.samplerate, DEMUCS_MODEL.audio_channels)
        stems = apply_model(DEMUCS_MODEL, wav.unsqueeze(0))
        stems = stems[:, stem_idx]  # extract stem
        stems = stems.sum(1)  # merge extracted stems
        stems = convert_audio(stems, DEMUCS_MODEL.samplerate, MODEL.sample_rate, 1)
        demucs_output = stems[0].cpu()

        # Naming: files are named after the first prompt. The prompt is user
        # input, so anything that could act as a path separator or otherwise
        # be unsafe in a file name is replaced before it reaches the path.
        prefix = "".join(
            ch if ch.isalnum() or ch in " -_" else "_" for ch in texts[0][:10])
        d_filename = f"temp/{prefix}.wav"
        # If path exists, add number. If number exists, update number.
        i = 1
        while Path(d_filename).exists():
            d_filename = f"temp/{prefix}_{i}.wav"
            i += 1
        # NOTE(review): the data is written with format="mp3" under a ".wav"
        # name (add_suffix=False); store_message later copies it to ".mp3".
        audio_write(
            d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False, format="mp3")
        out_files.append(d_filename)
        # Register once for delayed deletion (previously every file was
        # registered twice, which doubled the bookkeeping and the count below).
        file_cleaner.add(d_filename)
        output_file = d_filename

    print("batch finished", len(texts), time.time() - be)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return out_files
def predict_full(text, melody, progress=gr.Progress()):
    """Generate one 10-second clip with the local model and return its path.

    Clears the ``INTERRUPTING`` flag, installs a progress callback on
    ``MODEL`` that reports to ``progress`` and aborts with ``gr.Error`` once
    the flag is raised, then delegates to ``_do_predictions``. On success the
    prompt is remembered in the module-level ``selected_text``.

    Args:
        text: Prompt describing the music.
        melody: Melody input handed through to ``_do_predictions``.
        progress: Gradio progress tracker.

    Returns:
        Path of the first generated file.
    """
    global selected_text, INTERRUPTING
    INTERRUPTING = False
    print("Running local model")

    def _report(generated, to_generate):
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_report)
    generated_files = _do_predictions([text], [melody], duration=10, progress=True)
    selected_text = text
    first_file = generated_files[0]
    return first_file  # , gr.File.update(value=first_file, visible=True)
def select_new_melody():
    """Pick a random melody, crop it into a fresh temp ``.wav``, and return its path.

    With ``USE_MIDI`` set the pick comes from ``midi_files`` and is rendered
    to audio first; otherwise it comes from ``melody_files``. The pick is
    remembered in the module-level ``selected_melody`` and the temp file is
    registered with ``file_cleaner``.
    """
    global selected_melody
    candidates = midi_files if USE_MIDI else melody_files
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as tmp:
        target = tmp.name
        chosen = np.random.choice(candidates)
        selected_melody = chosen
        # MIDI has to be rendered to audio before it can be cropped.
        source = render_midi(chosen, fname=target) if USE_MIDI else chosen
        crop_melody(source, fname=target)
        file_cleaner.add(target)
        return target
def render_midi(midi_file, fname):
    """Synthesize ``midi_file`` at 32 kHz and write the audio to ``fname``.

    Returns:
        ``fname``, unchanged.
    """
    sample_rate = 32000
    # sonify midi as sine wave
    midi = pretty_midi.PrettyMIDI(midi_file)
    waveform = torch.from_numpy(midi.synthesize(fs=sample_rate))
    audio_write(
        fname, waveform, sample_rate, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
    return fname
def crop_melody(melody_file, fname):
    """Cut a randomly chosen window out of ``melody_file`` and write it to ``fname``.

    The window is one of the ``crops`` entries (start and end in seconds); the
    chosen entry is remembered in the module-level ``selected_crop``.
    """
    global selected_crop
    window = crops[np.random.choice(len(crops))]
    selected_crop = window
    audio, sr = audio_read(melody_file)
    first_sample = window[0] * sr
    last_sample = window[1] * sr
    audio = audio[:, first_sample:last_sample]
    audio_write(
        fname, audio, sr, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
def run_remote_model(text, melody, num_retries=3):
    """Generate a clip through the remote endpoint, drop the drums, save as mp3.

    Requires ``client`` (see ``connect_to_endpoint``) and ``DEMUCS_MODEL`` to
    be set. On success the prompt is stored in ``selected_text`` and the
    written path in ``output_file``.

    Args:
        text: Prompt describing the music.
        melody: Path (or URL) of the melody audio file sent to the endpoint.
        num_retries: How many additional attempts are made after a failed one.

    Returns:
        Path of the written mp3 file.

    Raises:
        gr.Error: When no attempt produced a readable audio file. Previously a
            ``(tensor, 32000)`` pair was returned in that case, which the
            Audio output wired to this function cannot display.
    """
    global selected_text, output_file
    melody_name = os.path.basename(melody) if melody else melody
    print("Running Audiocraft API model with text", text, "and melody", melody_name)

    # The prompt is user input: replace anything that could act as a path
    # separator or otherwise be unsafe in a file name. A "/" in the prompt
    # used to make ffmpeg write into a missing directory on every attempt.
    prefix = "".join(ch if ch.isalnum() or ch in " -_" else "_" for ch in text[:10])

    for _attempt in range(max(num_retries, 0) + 1):
        result = client.predict(
            text,  # str in 'Describe your music' Textbox component
            melody,  # str (filepath or URL to file) in 'File' Audio component
            fn_index=0
        )
        # Naming
        d_filename = os.path.join("temp", f"{prefix}.wav")
        # If path exists, add number. If number exists, update number.
        i = 1
        while Path(d_filename).exists():
            d_filename = os.path.join("temp", f"{prefix}_{i}.wav")
            i += 1
        # Convert mp4 to wav, using ffmpeg
        # ffmpeg -i input.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 output.wav
        sp.run(["ffmpeg", "-i", result, "-vn", "-acodec", "pcm_s16le", "-ar", "32000", "-ac", "1", d_filename])
        # Load wav file, if there is an issue with audiocraft, file will not exist
        try:
            output, sr = audio_read(d_filename)
        except RuntimeError:
            print("Audiocraft API failed, trying again...")
            continue
        break
    else:
        # Every attempt failed: report it to the user instead of handing the
        # UI a value it cannot render.
        print("Audiocraft API failed, giving up.")
        raise gr.Error("Audiocraft API failed, please try again.")

    # Crop to 10 seconds
    output = output[:, :10*sr]
    # Demucs
    print("Running demucs")
    wav = convert_audio(output, sr, DEMUCS_MODEL.samplerate, DEMUCS_MODEL.audio_channels)
    stems = apply_model(DEMUCS_MODEL, wav.unsqueeze(0))
    stems = stems[:, stem_idx]  # extract stem
    stems = stems.sum(1)  # merge extracted stems
    stems = convert_audio(stems, DEMUCS_MODEL.samplerate, 32000, 1)
    demucs_output = stems[0].cpu()

    # The intermediate wav is only needed for reading; let the cleaner remove it.
    file_cleaner.add(d_filename)
    # Swap only the extension; a plain str.replace could also hit ".wav"
    # occurring elsewhere in the path.
    d_filename = os.path.splitext(d_filename)[0] + ".mp3"
    audio_write(
        d_filename, demucs_output, 32000, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False, format="mp3")
    file_cleaner.add(d_filename)
    selected_text = text
    print("Finished", text)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    output_file = d_filename
    return d_filename  # , gr.File.update(value=d_filename, visible=True)
def rating_callback(rating: int):
    """Bundle the rating with the current selection state and persist it.

    The record combines the last prompt, melody and crop (module-level
    ``selected_*`` values), the rating, whether the local or API model is in
    use, and the current timestamp; it is printed and passed to
    ``store_message``.

    Args:
        rating: Value of the rating slider.
    """
    # Key order matters: store_message uses it as the CSV column order.
    rating_data = {
        "TEXT": selected_text,
        "MELODY": selected_melody,
        "CROP": selected_crop,
        "RATING": rating,
    }
    rating_data["VERSION"] = "local" if LOCAL else "api"
    rating_data["TIME"] = str(datetime.now())
    print(rating_data)
    store_message(rating_data)
def ui_full(launch_kwargs):
    """Build the Gradio interface, wire its callbacks, and launch it.

    Layout: a left column with the prompt box, the (hidden) melody input and
    its "change" button, and the submit button; a right column with the
    output player, the rating slider/button, and a collapsed accordion of
    example ratings.

    Args:
        launch_kwargs: Keyword arguments forwarded to ``launch()``.
    """
    with gr.Blocks() as interface:
        gr.Markdown(
            """
# Soundsauce Melody Playground
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        # previously, type="numpy"
                        # The local model consumes the melody as a numpy
                        # pair, the remote client as a file path.
                        if LOCAL:
                            audio_type="numpy"
                        else:
                            audio_type="filepath"
                        # Hidden component; select_new_melody() runs once here,
                        # at build time, to provide the initial value.
                        melody = gr.Audio(type=audio_type, label="File", source="upload",
                                          interactive=True, elem_id="melody-input", value=select_new_melody(), visible=False)
                        new_melody = gr.Button("Change input melody", interactive=True)
                # with gr.Row():
                submit = gr.Button("Submit")
                # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
            with gr.Column():
                output_without_drum = gr.Audio(label="Output")
                with gr.Row():
                    slider = gr.Slider(label="Rating", minimum=0, maximum=10, step=1, value=0, scale=2)
                    submit_button = gr.Button("Submit Rating", scale=1)
                with gr.Accordion("Show Example Ratings", open=False):
                    gr.Markdown("""
## Example Ratings
""")
                    gr.Audio(label="Rating = 0", value="examples/0-rating.mp3")
                    gr.Audio(label="Rating = 1", value="examples/1-rating.mp3")
                    gr.Audio(label="Rating = 2", value="examples/2-rating.mp3")
                    gr.Audio(label="Rating = 3", value="examples/3-rating.mp3")
                    gr.Audio(label="Rating = 4", value="examples/4-rating.mp3")
                    gr.Audio(label="Rating = 5", value="examples/5-rating.mp3")
                # file_download_no_drum = gr.File(label="Download", visible=False)
                # gr.Markdown(
                # """
                # Note that the files will be deleted after 10 minutes, so make sure to download!
                # """
                # )
        # Generation runs in-process or through the remote endpoint,
        # depending on LOCAL.
        if LOCAL:
            submit.click(predict_full,
                         inputs=[text, melody],
                         outputs=[output_without_drum])#, file_download_no_drum])
        else:
            submit.click(run_remote_model, inputs=[text, melody], outputs=[output_without_drum])#, file_download_no_drum])
        new_melody.click(select_new_melody, outputs=[melody])
        # Button callbacks
        submit_button.click(rating_callback, inputs=[slider])
        # NOTE(review): fn is predict_full even when LOCAL is False, and only
        # `text` is passed although predict_full also takes a melody —
        # confirm the examples are only ever used to fill the text box.
        gr.Examples(
            fn=predict_full,
            examples=[
                ["Enchanting Flute Trills amidst Misty String Section"],
                ["Gliding Mellotron Strings over Vibrant Phrases"],
                ["Synth Brass Melody Floating over Airy Wind Chimes"],
                ["Rhythmic Acoustic Guitar Licks with Echoing Layers"],
                ["Whimsical Flute Flourishes in a Mystical Forest Glade"],
                ["Airy Piccolo Trills accompanied by Floating Harp Arpeggios"],
                ["Dreamy Harp Glissandos accompanied by Distant Celesta"],
                ["Hypnotic Synth Pads layered with Enigmatic Guitar Progressions"],
                ["Enchanting Kalimba Melodies atop Mystical Atmosphere"],
            ],
            inputs=[text],
            label="Example Inputs",
            outputs=[output_without_drum]#, file_download_no_drum]
        )
    interface.queue().launch(**launch_kwargs)
if __name__ == "__main__":
    # Script entry point: parse the listen address, load the models, make
    # sure the scratch directory exists, then start the UI.
    default_host = '0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1'
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default=default_host,
        help='IP to listen on for connections to Gradio',
    )
    args = parser.parse_args()
    launch_kwargs = {'server_name': args.listen}
    print("Using midi:", USE_MIDI)

    # Load melody model
    load_model()
    if not LOCAL:
        connect_to_endpoint()

    temp_dir = "temp"
    if not os.path.exists(temp_dir):
        os.mkdir(temp_dir)

    # Show the interface
    ui_full(launch_kwargs)