from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import gradio as gr
import numpy as np
import warnings

# from google.cloud import storage


class MusicGenHandler():

    def __init__(self, init_model_path='createsafe/grimes-stem-model', generation_duration=30.0):
        self.model_path = init_model_path
        self.generation_duration = generation_duration
        self._setup_model()

    def _setup_model(self):
        self.model = MusicGen.get_pretrained(self.model_path)
        self.model.set_generation_params(duration=self.generation_duration)

    def inference(self, prompts):
        """Turn a prompt or list of prompts into audio."""
        if not isinstance(prompts, list):
            # Wrap a single prompt; list(str) would split a string into characters.
            prompts = [prompts]
        # Move to CPU before converting, in case the model runs on GPU.
        return self.model.generate(prompts).cpu().numpy()

    def update_model(self, new_model_path):
        if new_model_path != self.model_path:
            old_model_path = self.model_path
            try:
                self.model_path = new_model_path
                self._setup_model()
            except Exception:
                # Roll back so self.model_path still matches the loaded model.
                self.model_path = old_model_path
                warnings.warn(f"could not set up model located at {new_model_path}")


model = MusicGenHandler()

# Descriptive words for each slider position (sliders use step 0.1).
SLIDER_TEXT = {
    0.0: "none",
    0.1: "minimal",
    0.2: "little",
    0.3: "not much",
    0.4: "just below mean",
    0.5: "mean",
    0.6: "just above mean",
    0.7: "sufficient",
    0.8: "ample",
    0.9: "great",
    1.0: "maximal",
}


def slider_val_to_text(val):
    # Round to one decimal so float noise (e.g. 0.30000000000000004) still matches.
    return SLIDER_TEXT[round(val, 1)]


def text_to_music(text, instrument, brightness, percussiveness, business, variance,
                  temperature, bass, mids, highs, tempo, noisiness):
    dsp_feature_string = ""
    if text:
        dsp_feature_string += text + ". "
    if instrument:
        dsp_feature_string += instrument + ". "
    # Zero-valued sliders are omitted from the prompt.
    features = [
        ('brightness', brightness), ('percussiveness', percussiveness),
        ('business', business), ('variance', variance),
        ('temperature', temperature), ('bass', bass), ('mids', mids),
        ('highs', highs), ('tempo', tempo), ('noisiness', noisiness),
    ]
    dsp_feature_string += ', '.join(
        f'{name} {slider_val_to_text(val)}' for name, val in features if val)

    if instrument == "all-stems":
        model.update_model(new_model_path='./cs-pretrained/stem_model')
    elif instrument == "drums":
        model.update_model(new_model_path='./cs-pretrained/drums_model')
    # "keys" and "bass" fall through to whichever model is currently loaded.

    audio = model.inference(prompts=[dsp_feature_string])
    # Take the first batch item and first channel: Gradio expects a 1-D array.
    audio = audio[0, 0]

    # Convert to 16-bit PCM.
    peak = np.max(np.abs(audio))
    if peak > 0.0:
        audio /= peak
    audio = (audio * 32767).astype(np.int16)
    # 32000 Hz is MusicGen's native sample rate.
    return (32000, audio)


def run():
    iface = gr.Interface(
        fn=text_to_music,
        inputs=[
            gr.Textbox(label="Text prompt"),
            gr.Dropdown(["all-stems", "drums", "keys", "bass"], label="Instrument"),
            gr.Slider(0, 1, step=0.1, label="Brightness"),
            gr.Slider(0, 1, step=0.1, label="Percussiveness"),
            gr.Slider(0, 1, step=0.1, label="Business"),
            gr.Slider(0, 1, step=0.1, label="Variance"),
            gr.Slider(0, 1, step=0.1, label="Temperature"),
            gr.Slider(0, 1, step=0.1, label="Bass"),
            gr.Slider(0, 1, step=0.1, label="Mids"),
            gr.Slider(0, 1, step=0.1, label="Highs"),
            gr.Slider(0, 1, step=0.1, label="Tempo"),
            gr.Slider(0, 1, step=0.1, label="Noisiness"),
        ],
        outputs="audio",
    )
    iface.launch()
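
# A minimal sketch of persisting generations to disk with audiocraft's
# audio_write (imported above but unused by the Gradio path), following the
# usage documented in the audiocraft README. The helper name `save_generations`
# and the `out_stem` filename stem are hypothetical, not part of the original app.
def save_generations(prompts, out_stem="generation"):
    """Generate audio for each prompt and write one wav file per prompt."""
    if not isinstance(prompts, list):
        prompts = [prompts]
    wav = model.model.generate(prompts)  # torch tensor, shape [batch, channels, samples]
    for idx, one_wav in enumerate(wav):
        # audio_write expects a (channels, samples) tensor and the model's sample
        # rate; it appends the .wav extension and applies loudness normalization.
        audio_write(f"{out_stem}_{idx}", one_wav.cpu(), model.model.sample_rate,
                    strategy="loudness")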

if __name__ == "__main__":
    run()