# Gradio Space entry point: text-to-music generation with MusicGen stem models.
import warnings

import gradio as gr
import numpy as np

from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
class MusicGenHandler():
    """Wraps a pretrained MusicGen checkpoint and turns text prompts into audio.

    Holds one model instance; `update_model` swaps checkpoints on demand.
    """

    def __init__(self, init_model_path='./cs-pretrained/stem_model', generation_duration=30.0):
        # Checkpoint directory and requested clip length in seconds.
        self.model_path = init_model_path
        self.generation_duration = generation_duration
        self._setup_model()

    def _setup_model(self):
        # (Re)load the checkpoint at self.model_path and apply the duration.
        self.model = MusicGen.get_pretrained(self.model_path)
        self.model.set_generation_params(duration=self.generation_duration)

    def inference(self, prompts):
        """Generate audio for a prompt or list of prompts.

        Returns the model output converted to a numpy array
        (one generation per prompt).
        """
        if not isinstance(prompts, list):
            # BUGFIX: list("a prompt") would explode a single string into a
            # list of characters; wrap it in a one-element list instead.
            prompts = [prompts]
        return self.model.generate(prompts).numpy()

    def update_model(self, new_model_path):
        """Switch to a different checkpoint; keep the old one on failure."""
        if new_model_path == self.model_path:
            return
        previous_path = self.model_path
        try:
            self.model_path = new_model_path
            self._setup_model()
        except Exception:
            # Restore the last working path so self.model_path always
            # reflects the model that is actually loaded.
            self.model_path = previous_path
            warnings.warn(f"could not setup model located at {new_model_path}")
# Module-level handler shared by all requests; loads the default stem
# checkpoint once at import time.
model = MusicGenHandler()
def slider_val_to_text(val):
    """Map a 0..1 slider value (0.1 steps) to its verbal intensity label.

    BUGFIX: the original chained `==` comparisons against float literals,
    which fails for values like 0.7000000000000001 that sliders can emit,
    and then implicitly returned None (crashing the string-building caller).
    Rounding to the nearest tenth and indexing a table is exact, and
    clamping guarantees a string is always returned.
    """
    labels = (
        "none", "minimal", "little", "not much", "just below mean",
        "mean", "just above mean", "sufficient", "ample", "great", "maximal",
    )
    # Nearest tenth -> table index 0..10, clamped into range.
    idx = int(round(float(val) * 10))
    return labels[min(max(idx, 0), 10)]
def text_to_music(text, instrument, brightness, percusiveness, business, variance, temperature, bass, mids, highs, tempo, noisiness):
    """Build a text prompt from the UI controls, generate audio, and return
    a (sample_rate_hz, int16_samples) tuple for the Gradio audio output.
    """
    dsp_feature_string = ""
    if text:
        dsp_feature_string += text + ". "
    if instrument:
        dsp_feature_string += instrument + ". "
    # Each slider contributes "<name> <level>, " only when non-zero
    # (0.0 is falsy, so "none" settings are omitted from the prompt).
    if brightness:
        dsp_feature_string += 'brightness ' + slider_val_to_text(brightness) + ', '
    if percusiveness:
        dsp_feature_string += 'percusiveness ' + slider_val_to_text(percusiveness) + ', '
    if business:
        dsp_feature_string += 'business ' + slider_val_to_text(business) + ', '
    if variance:
        dsp_feature_string += 'variance ' + slider_val_to_text(variance) + ', '
    if temperature:
        dsp_feature_string += 'temperature ' + slider_val_to_text(temperature) + ', '
    if bass:
        dsp_feature_string += 'bass ' + slider_val_to_text(bass) + ', '
    if mids:
        dsp_feature_string += 'mids ' + slider_val_to_text(mids) + ', '
    if highs:
        dsp_feature_string += 'highs ' + slider_val_to_text(highs) + ', '
    if tempo:
        dsp_feature_string += 'tempo ' + slider_val_to_text(tempo) + ', '
    if noisiness:
        dsp_feature_string += 'noisiness ' + slider_val_to_text(noisiness)
    # Swap checkpoints when the instrument selection requires it.
    # NOTE(review): "keys" and "bass" have no checkpoint mapping here, so they
    # fall through to whichever model is currently loaded — confirm intended.
    if instrument == "all-stems":
        model.update_model(new_model_path='./cs-pretrained/stem_model')
    elif instrument == "drums":
        model.update_model(new_model_path='./cs-pretrained/drums_model')
    audio = model.inference(prompts=[dsp_feature_string])
    # Convert to 16-bit PCM: peak-normalize, scale to the int16 range, cast.
    # BUGFIX: astype(int) produced 64-bit samples, which is not 16-bit PCM
    # despite the original comment; np.int16 is what the audio output expects.
    peak = np.max(np.abs(audio))
    if peak > 0.0:
        audio = audio / peak
    audio = (audio * 32767).astype(np.int16)
    return (32000, audio)
def run():
    """Assemble the Gradio interface and launch the app."""
    slider_names = [
        "Brightness", "Percussiveness", "Business", "Variance", "Temperature",
        "Bass", "Mids", "Highs", "Tempo", "Noisiness",
    ]
    controls = [
        gr.Textbox(label="Text prompt"),
        gr.Dropdown(["all-stems", "drums", "keys", "bass"], label="Instrument"),
    ]
    # One 0..1 slider (0.1 steps) per DSP feature, in the order text_to_music
    # expects its arguments.
    controls.extend(gr.Slider(0, 1, step=0.1, label=name) for name in slider_names)
    iface = gr.Interface(fn=text_to_music, inputs=controls, outputs="audio")
    iface.launch()
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    run()