# app.py -- Gradio demo: text-prompted stem generation with MusicGen
# (CreateSafe / grimes-stem-model checkpoints).
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import gradio as gr
import numpy as np
import warnings
# from google.cloud import storage
class MusicGenHandler():
    """Wrap a MusicGen checkpoint: load it once, run prompt-to-audio
    inference, and hot-swap to a different checkpoint on demand."""

    def __init__(self, init_model_path='createsafe/grimes-stem-model', generation_duration=30.0):
        # Path/id of the currently loaded checkpoint and clip length in seconds.
        self.model_path = init_model_path
        self.generation_duration = generation_duration
        self._setup_model()

    def _setup_model(self):
        # (Re)load the checkpoint at self.model_path and pin the output duration.
        self.model = MusicGen.get_pretrained(self.model_path)
        self.model.set_generation_params(duration=self.generation_duration)

    def inference(self, prompts):
        """Turn a prompt string, or a list of prompt strings, into audio.

        Returns the generated batch as a numpy array.
        """
        if not isinstance(prompts, list):
            # Wrap a single prompt in a list. The original `list(prompts)`
            # split a plain string into individual characters, turning one
            # prompt into len(prompts) one-character prompts.
            prompts = [prompts]
        return self.model.generate(prompts).numpy()

    def update_model(self, new_model_path):
        """Swap to a different checkpoint; on failure keep the previous one."""
        if new_model_path == self.model_path:
            return
        previous_path = self.model_path
        try:
            self.model_path = new_model_path
            self._setup_model()
        except Exception:
            # Roll back so model_path always reflects the model actually
            # loaded (the original left them inconsistent on failure).
            # `except Exception` (not bare `except:`) so Ctrl-C still works.
            self.model_path = previous_path
            warnings.warn(f"could not setup model located at {new_model_path}")
# Module-level singleton handler shared by every Gradio request; loads the
# default checkpoint at import time.
model = MusicGenHandler()
def slider_val_to_text(val):
    """Map a slider value in [0, 1] (step 0.1) to a descriptive intensity word.

    Uses round() so float noise from slider arithmetic (e.g.
    0.1 + 0.2 == 0.30000000000000004) still maps to the intended step --
    the original exact `==` comparisons fell through and returned None,
    which later crashed string concatenation in text_to_music.
    Returns None for values outside [0, 1], matching the original
    fall-through behavior.
    """
    labels = (
        "none",             # 0.0
        "minimal",          # 0.1
        "little",           # 0.2
        "not much",         # 0.3
        "just below mean",  # 0.4
        "mean",             # 0.5
        "just above mean",  # 0.6
        "sufficient",       # 0.7
        "ample",            # 0.8
        "great",            # 0.9
        "maximal",          # 1.0
    )
    idx = int(round(val * 10))
    if 0 <= idx < len(labels):
        return labels[idx]
    return None
def text_to_music(text, instrument, brightness, percusiveness, business, variance, temperature, bass, mids, highs, tempo, noisiness):
    """Build a text prompt from the UI controls, run inference, and return
    (sample_rate, samples) in the form gradio's audio output expects.

    A slider value of 0 (and an empty text/instrument field) is falsy and is
    omitted from the prompt entirely.
    """
    dsp_feature_string = ""
    if text:
        dsp_feature_string += text + ". "
    if instrument:
        dsp_feature_string += instrument + ". "
    # NOTE: the "percusiveness"/"business" spellings are kept verbatim -- they
    # are runtime prompt text the model has seen, not identifiers to correct.
    slider_features = [
        ("brightness", brightness),
        ("percusiveness", percusiveness),
        ("business", business),
        ("variance", variance),
        ("temperature", temperature),
        ("bass", bass),
        ("mids", mids),
        ("highs", highs),
        ("tempo", tempo),
        ("noisiness", noisiness),
    ]
    for name, value in slider_features:
        if value:
            # Every feature but the final one ("noisiness") is followed by
            # ", ", reproducing the original prompt format exactly.
            separator = "" if name == "noisiness" else ", "
            dsp_feature_string += f"{name} {slider_val_to_text(value)}{separator}"
    # Swap checkpoints for the instruments that have dedicated models.
    # NOTE(review): the "keys" and "bass" dropdown choices fall through and
    # reuse whatever model is currently loaded -- confirm that is intended.
    if instrument == "all-stems":
        model.update_model(new_model_path='./cs-pretrained/stem_model')
    elif instrument == "drums":
        model.update_model(new_model_path='./cs-pretrained/drums_model')
    audio = model.inference(prompts=[dsp_feature_string])
    # Peak-normalize, then convert to 16-bit PCM. np.int16 is required here:
    # the original astype(int) produced int64, which does not match the
    # "16 bit PCM" intent for the (sample_rate, array) audio return.
    peak = np.max(np.abs(audio))
    if peak > 0.0:
        audio /= peak
    audio *= 32767
    audio = audio.astype(np.int16)
    return (32000, audio)
def run():
    """Assemble the Gradio interface (text prompt, instrument dropdown, ten
    0-1 feature sliders) and launch the app."""
    slider_labels = [
        "Brightness",
        "Percussiveness",
        "Business",
        "Variance",
        "Temperature",
        "Bass",
        "Mids",
        "Highs",
        "Tempo",
        "Noisiness",
    ]
    ui_inputs = [
        gr.Textbox(label="Text prompt"),
        gr.Dropdown(["all-stems", "drums", "keys", "bass"], label="Instrument"),
    ]
    # All ten feature sliders share the same 0-1 range and 0.1 step.
    ui_inputs.extend(
        gr.Slider(0, 1, step=0.1, label=label) for label in slider_labels
    )
    demo = gr.Interface(fn=text_to_music, inputs=ui_inputs, outputs="audio")
    demo.launch()
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    run()