# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
# also released under the MIT license.

import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import typing as tp
import warnings
import glob
import csv

import torch
import gradio as gr
import numpy as np
import shutil

# NOTE: audiocraft's convert_audio is imported under an alias. Previously both
# it and the demucs convert_audio (imported below) were bound to the same name,
# so the demucs one silently won; the alias makes that binding explicit.
from audiocraft.data.audio_utils import convert_audio as audiocraft_convert_audio
from audiocraft.data.audio import audio_write, audio_read
from audiocraft.models import MusicGen
from demucs import pretrained
from demucs.apply import apply_model
from demucs.audio import convert_audio
from gradio_client import Client
import pretty_midi
import huggingface_hub
from huggingface_hub import Repository
from datetime import datetime

# When True, MusicGen runs in-process (predict_full); otherwise generation goes
# through the hosted Gradio endpoint (run_remote_model).
LOCAL = False
# When True, input melodies are rendered from MIDI clips instead of .wav clips.
USE_MIDI = True

# LOGS
# Ratings and generated audio are logged to this Hugging Face dataset repo,
# which is cloned into ./data at import time.
DATASET_REPO_URL = "https://huggingface.co/datasets/soundsauce/soundsauce-logs"
DATA_FILENAME = "ratings.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
AUDIO_DIR = os.path.join("data", "audio")

HF_TOKEN = os.environ.get("HF_TOKEN")
print("is none?", HF_TOKEN is None)

repo = Repository(
    local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
print("hfh", huggingface_hub.__version__)

MODEL = None  # Last used model
DEMUCS_MODEL = None
MAX_BATCH_SIZE = 12
INTERRUPTING = False
client = None

# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call

# Index of each Demucs source; presumably matches the htdemucs source order
# (drums, bass, other, vocals) -- TODO confirm against DEMUCS_MODEL.sources.
stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
# Stems kept after separation: everything except drums.
stem_idx = torch.tensor([stem2idx['vocal'], stem2idx['other'], stem2idx['bass']], dtype=torch.long)

# glob.glob already returns a list.
melody_files = glob.glob('clips/**/*.wav', recursive=True)
midi_files = glob.glob('clips/**/*.mid', recursive=True)
# (start, end) crop windows, in seconds, applied to the input melody.
crops = [(0, 5), (0, 10), (0, 15)]
# Metadata about the most recent generation, recorded alongside each rating.
# NOTE(review): these are process-wide globals shared by every Gradio session,
# so concurrent users can overwrite each other's values; per-session gr.State
# would be needed for reliable attribution under concurrency.
selected_melody = ""
selected_crop = None
selected_text = ""
output_file = ""


def store_message(message: dict):
    """Append a rating row to the dataset CSV, copy the audio, and push to the Hub.

    Does nothing when ``message`` is empty or no audio has been generated yet
    (``output_file`` is empty).

    Args:
        message: Row to append. Its keys are used as the CSV column names, and
            it must contain a ``"TIME"`` entry, used to name the copied audio.
    """
    if not (message and output_file):
        return
    os.makedirs(AUDIO_DIR, exist_ok=True)
    # Sync with the remote dataset first so the push below does not conflict.
    repo.git_pull()
    # Only a brand-new (missing or empty) CSV gets a header row.
    write_header = not os.path.exists(DATA_FILE) or os.path.getsize(DATA_FILE) == 0
    # newline="" is what the csv module requires to avoid spurious blank lines.
    with open(DATA_FILE, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=message.keys())
        if write_header:
            writer.writeheader()
        writer.writerow(message)
    filepath = os.path.join(AUDIO_DIR, message["TIME"]) + ".mp3"
    shutil.copy(output_file, filepath)
    commit_url = repo.push_to_hub()
    print("Commited to", commit_url)


def _call_nostderr(*args, **kwargs):
    """Drop-in replacement for ``subprocess.call`` with stdout/stderr silenced.

    Avoids ffmpeg flooding the logs. Returns the child's exit code, as
    ``subprocess.call`` does.
    """
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Ask a running local generation to stop (checked by predict_full's progress callback)."""
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Delete registered files once they are older than ``file_lifetime`` seconds.

    Cleanup is lazy: expired files are removed only when a new file is added.
    Entries are appended with the current time, so the list is oldest-first and
    the scan stops at the first entry that has not expired.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []  # (time added, Path) pairs, oldest first

    def add(self, path: tp.Union[str, Path]):
        """Register ``path`` for later deletion, first evicting expired files."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                # missing_ok: the file may already be gone; no exists/unlink race.
                path.unlink(missing_ok=True)
                self.files.pop(0)
            else:
                break


# 10 minutes
file_cleaner = FileCleaner(600)


def make_waveform(*args, **kwargs):
    """Call gr.make_waveform with warnings suppressed and log how long it took."""
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
        print("Make a video took", time.time() - be)
        return out


def load_model(version='melody'):
    """Load the models the app needs into module globals.

    The MusicGen ``version`` is loaded only when LOCAL is set (reloaded if a
    different version was loaded before). The htdemucs separation model is
    loaded once regardless, since both generation paths use it to strip drums.
    Models are placed on CUDA when available, otherwise on CPU.
    """
    global MODEL, DEMUCS_MODEL
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if LOCAL and (MODEL is None or MODEL.name != version):
        print("Loading model", version)
        MODEL = MusicGen.get_pretrained(version, device=device)
    if DEMUCS_MODEL is None:
        DEMUCS_MODEL = pretrained.get_model('htdemucs').to(device)


def connect_to_endpoint():
    """Create the gradio_client Client used by run_remote_model."""
    global client
    client = Client("https://facebook-musicgen--44zzp.hf.space/")


def _unique_temp_path(text, suffix):
    """Return a not-yet-existing path under ``temp/`` named after ``text``.

    The stem is the first 10 characters of ``text``. Anything other than
    letters, digits, spaces, '-' and '_' is replaced by '_', so user text such
    as "a/b" cannot point outside ``temp/`` or at a missing directory. An
    empty stem falls back to "output". ``_1``, ``_2``, ... is appended until
    the path is free.
    """
    stem = "".join(c if c.isalnum() or c in " -_" else "_" for c in text[:10]) or "output"
    candidate = os.path.join("temp", f"{stem}{suffix}")
    i = 1
    while Path(candidate).exists():
        candidate = os.path.join("temp", f"{stem}_{i}{suffix}")
        i += 1
    return candidate


def _remove_drums(wav, sample_rate, target_sr):
    """Run Demucs on ``wav`` and return the vocal+other+bass stems as mono audio.

    Args:
        wav: Audio tensor at ``sample_rate``; presumably (channels, time) -- TODO confirm.
        sample_rate: Sample rate of ``wav``.
        target_sr: Sample rate of the returned audio.

    Returns:
        CPU tensor holding the summed non-drum stems, resampled to ``target_sr``
        and converted to one channel.
    """
    print("Running demucs")
    wav = convert_audio(wav, sample_rate, DEMUCS_MODEL.samplerate, DEMUCS_MODEL.audio_channels)
    stems = apply_model(DEMUCS_MODEL, wav.unsqueeze(0))
    stems = stems[:, stem_idx]  # extract stem (everything but drums)
    stems = stems.sum(1)  # merge extracted stems
    stems = convert_audio(stems, DEMUCS_MODEL.samplerate, target_sr, 1)
    return stems[0].cpu()


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate clips locally with MusicGen, strip their drums, and save them as mp3.

    Args:
        texts: Text descriptions, one per output clip.
        melodies: Per-item melody as a (sample_rate, numpy array) pair, as
            produced by a numpy-type gr.Audio, or None for no melody.
        duration: Generation length in seconds; melodies are truncated to it.
        progress: Forwarded to MODEL.generate_with_chroma.
        **gen_kwargs: Extra parameters for MODEL.set_generation_params.

    Returns:
        List of written .mp3 paths, one per generated clip. The last one is
        also stored in the global ``output_file`` for rating uploads.
    """
    global output_file
    MODEL.set_generation_params(duration=duration, cfg_coef=5, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
            continue
        # .t() presumably turns (time, channels) into (channels, time) -- TODO confirm.
        sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
        if melody.dim() == 1:
            melody = melody[None]
        # Truncate the melody to the requested generation length.
        melody = melody[..., :int(sr * duration)]
        melody = convert_audio(melody, sr, target_sr, target_ac)
        processed_melodies.append(melody)

    outputs = MODEL.generate_with_chroma(
        descriptions=texts,
        melody_wavs=processed_melodies,
        melody_sample_rate=target_sr,
        progress=progress,
    )
    outputs = outputs.detach().float()

    out_files = []
    # Name each output after its own description.
    for text, output in zip(texts, outputs):
        demucs_output = _remove_drums(output, MODEL.sample_rate, MODEL.sample_rate)
        # The data is MP3-encoded, so give the file a matching extension.
        d_filename = _unique_temp_path(text, ".mp3")
        audio_write(
            d_filename, demucs_output, MODEL.sample_rate, strategy="loudness",
            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False, format="mp3")
        out_files.append(d_filename)
        file_cleaner.add(d_filename)  # registered exactly once
        output_file = d_filename
    print("batch finished", len(texts), time.time() - be)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return out_files


def predict_full(text, melody, progress=gr.Progress()):
    """Generate a 10 s clip locally for ``text`` conditioned on ``melody``.

    Resets the interrupt flag and installs a progress callback. The callback
    reports to the Gradio progress bar and raises gr.Error once interrupt() has
    been called. Returns the path of the generated audio file.
    """
    global selected_text, INTERRUPTING
    INTERRUPTING = False
    print("Running local model")

    def _progress(generated, to_generate):
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)
    outs = _do_predictions([text], [melody], duration=10, progress=True)
    selected_text = text
    return outs[0]


def select_new_melody():
    """Pick a random clip, crop it, and return the path of the resulting .wav.

    With USE_MIDI, a random .mid file from ``clips/`` is first rendered to audio
    with render_midi. Otherwise a random .wav clip is used. The chosen source
    file is recorded in ``selected_melody`` for the rating log.
    """
    global selected_melody
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
        source = np.random.choice(midi_files if USE_MIDI else melody_files)
        selected_melody = source
        new_melody_file = render_midi(source, fname=file.name) if USE_MIDI else source
        crop_melody(new_melody_file, fname=file.name)
        file_cleaner.add(file.name)
    return file.name


def render_midi(midi_file, fname):
    """Sonify ``midi_file`` as sine waves at 32 kHz, write it to ``fname``, and return ``fname``."""
    pm = pretty_midi.PrettyMIDI(midi_file)
    sine_waves = pm.synthesize(fs=32000)
    audio_write(fname, torch.from_numpy(sine_waves), 32000, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
    return fname


def crop_melody(melody_file, fname):
    """Crop ``melody_file`` to a random window from ``crops`` and write it to ``fname``.

    The chosen (start, end) window, in seconds, is recorded in ``selected_crop``.
    """
    global selected_crop
    crop = crops[np.random.choice(len(crops))]
    selected_crop = crop
    melody, sr = audio_read(melody_file)
    melody = melody[:, crop[0] * sr:crop[1] * sr]
    audio_write(fname, melody, sr, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)


def run_remote_model(text, melody, num_retries=3):
    """Generate music through the hosted MusicGen Space and strip its drums.

    The Space's result is converted with ffmpeg to 32 kHz mono .wav and cropped
    to 10 s. It is then passed through Demucs and written as .mp3 under temp/.

    Args:
        text: Text description sent to the Space.
        melody: Path of the melody file sent as conditioning.
        num_retries: How many more attempts to make if no readable audio comes back.

    Returns:
        Path of the written .mp3 (also stored in the global ``output_file``).

    Raises:
        gr.Error: If the API still fails after all retries.
    """
    global selected_text, output_file
    print("Running Audiocraft API model with text", text, "and melody", os.path.basename(melody))
    result = client.predict(
        text,  # str in 'Describe your music' Textbox component
        melody,  # str (filepath or URL to file) in 'File' Audio component
        fn_index=0
    )
    d_filename = _unique_temp_path(text, ".wav")
    # Convert the Space's output (mp4) to 32 kHz mono 16-bit wav.
    sp.run(["ffmpeg", "-i", result, "-vn", "-acodec", "pcm_s16le", "-ar", "32000", "-ac", "1", d_filename])

    # If the API call failed, ffmpeg will not have produced a readable file.
    loaded = None
    if Path(d_filename).exists():
        try:
            loaded = audio_read(d_filename)
        except RuntimeError:
            loaded = None
    if loaded is None:
        print("Audiocraft API failed, trying again...")
        if num_retries <= 0:
            print("Audiocraft API failed, giving up.")
            # Prevent a later rating from uploading the previous clip's audio.
            output_file = ""
            raise gr.Error("Audiocraft API failed, please try again.")
        return run_remote_model(text, melody, num_retries=num_retries - 1)
    output, sr = loaded

    # Crop to 10 seconds
    output = output[:, :10 * sr]
    demucs_output = _remove_drums(output, sr, 32000)

    file_cleaner.add(d_filename)
    # Swap only the extension (str.replace could also hit earlier path parts).
    d_filename = os.path.splitext(d_filename)[0] + ".mp3"
    audio_write(
        d_filename, demucs_output, 32000, strategy="loudness",
        loudness_headroom_db=16, loudness_compressor=True, add_suffix=False, format="mp3")
    file_cleaner.add(d_filename)
    selected_text = text
    print("Finished", text)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    output_file = d_filename
    return d_filename


def rating_callback(rating: int):
    """Log ``rating`` together with the last generation's metadata via store_message."""
    timestamp = str(datetime.now())
    rating_data = {
        "TEXT": selected_text,
        "MELODY": selected_melody,
        "CROP": selected_crop,
        "RATING": rating,
        "VERSION": "local" if LOCAL else "api",
        "TIME": timestamp
    }
    print(rating_data)
    store_message(rating_data)


def ui_full(launch_kwargs):
    """Build the Gradio interface and launch it with ``launch_kwargs``."""
    with gr.Blocks() as interface:
        gr.Markdown(
            """
            # Soundsauce Melody Playground
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        # The local model consumes numpy audio; the remote API takes a file path.
                        audio_type = "numpy" if LOCAL else "filepath"
                        melody = gr.Audio(type=audio_type, label="File", source="upload", interactive=True,
                                          elem_id="melody-input", value=select_new_melody(), visible=False)
                        new_melody = gr.Button("Change input melody", interactive=True)
                submit = gr.Button("Submit")
                # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
            with gr.Column():
                output_without_drum = gr.Audio(label="Output")
                with gr.Row():
                    slider = gr.Slider(label="Rating", minimum=0, maximum=10, step=1, value=0, scale=2)
                    submit_button = gr.Button("Submit Rating", scale=1)
                with gr.Accordion("Show Example Ratings", open=False):
                    gr.Markdown("""
                    ## Example Ratings
                    """)
                    for score in range(6):
                        gr.Audio(label=f"Rating = {score}", value=f"examples/{score}-rating.mp3")

        generate_fn = predict_full if LOCAL else run_remote_model
        submit.click(generate_fn, inputs=[text, melody], outputs=[output_without_drum])
        new_melody.click(select_new_melody, outputs=[melody])
        # Button callbacks
        submit_button.click(rating_callback, inputs=[slider])
        gr.Examples(
            fn=predict_full,
            examples=[
                ["Enchanting Flute Trills amidst Misty String Section"],
                ["Gliding Mellotron Strings over Vibrant Phrases"],
                ["Synth Brass Melody Floating over Airy Wind Chimes"],
                ["Rhythmic Acoustic Guitar Licks with Echoing Layers"],
                ["Whimsical Flute Flourishes in a Mystical Forest Glade"],
                ["Airy Piccolo Trills accompanied by Floating Harp Arpeggios"],
                ["Dreamy Harp Glissandos accompanied by Distant Celesta"],
                ["Hypnotic Synth Pads layered with Enigmatic Guitar Progressions"],
                ["Enchanting Kalimba Melodies atop Mystical Atmosphere"],
            ],
            inputs=[text],
            label="Example Inputs",
            outputs=[output_without_drum]
        )

    interface.queue().launch(**launch_kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    args = parser.parse_args()

    launch_kwargs = {'server_name': args.listen}

    print("Using midi:", USE_MIDI)
    # Load melody model
    load_model()
    if not LOCAL:
        connect_to_endpoint()
    # Generated clips are written under temp/.
    os.makedirs("temp", exist_ok=True)
    # Show the interface
    ui_full(launch_kwargs)