Spaces:
Build error
Build error
import os
import secrets

import soundfile as sf
import torchaudio
from audiocraft.models import musicgen
from flask import Flask, render_template
from flask_socketio import SocketIO, emit

# from flask_cors import CORS
# Serve the compiled front end from ../build at the site root.
app = Flask(__name__, static_folder="../build", static_url_path="/")
# Debug mode stays on by default but can be switched off for deployments with
# FLASK_DEBUG=0; the interactive debugger must never be exposed publicly.
app.debug = os.environ.get("FLASK_DEBUG", "1") != "0"
# Do not ship a hard-coded signing key: read it from the environment and fall
# back to a random per-process key (signed sessions then reset on restart).
app.secret_key = os.environ.get("SECRET_KEY") or secrets.token_hex(32)
# NOTE(review): "*" accepts Socket.IO connections from any origin; narrow this
# once the front-end origin is known.
socketio = SocketIO(app, cors_allowed_origins="*")

# Load the model once at import time so every handler reuses the same instance.
print("Loading model...")
model = musicgen.MusicGen.get_pretrained("melody")
# Each generate call produces 8 seconds of audio.
model.set_generation_params(duration=8)
def index():
    """Render and return the front-end entry page ``index.html``.

    NOTE(review): no ``@app.route`` registration is visible for this view;
    confirm it is mapped (presumably to "/").
    """
    page = "index.html"
    print("HI")
    return render_template(page)
def connect():
    """Handle a new Socket.IO client by generating and streaming a clip to it.

    NOTE(review): no ``@socketio.on("connect")`` registration is visible
    here; confirm the handler is actually wired up.
    """
    print("Client connected")
    # stream_audio declares a positional ``data`` parameter and there is no
    # event payload on connect, so pass None explicitly rather than relying
    # on that parameter having a default.
    stream_audio(None)
def disconnect():
    """Log to stdout that a Socket.IO client has gone away.

    NOTE(review): no ``@socketio.on("disconnect")`` registration is visible
    here; confirm the handler is actually wired up.
    """
    message = "Client disconnected"
    print(message)
def stream_audio(data=None):
    """Generate a melody-conditioned clip and stream it to the client in chunks.

    Runs MusicGen on a fixed text prompt, conditioned on the melody in
    ``./asitwas_vocals.wav``. The result is saved to ``output.wav`` and then
    emitted as a series of ``"audio_chunk"`` events. Each event carries a
    nested list ``[channel][sample]`` of at most 1024 samples per channel.

    Args:
        data: Payload of the triggering event. It is currently unused and
            defaults to ``None``, so the function can also be called directly
            without an argument.
    """
    descriptions = ["Film score epic moment"]
    melody, sr = torchaudio.load("./asitwas_vocals.wav")
    print("Running inference...")
    # torchaudio.load yields (channels, samples); add the batch dimension
    # so there is one melody for the single description.
    wav = model.generate_with_chroma(descriptions, melody[None], sr)

    # The output is (batch, channels, samples). The model may run on a GPU,
    # so detach and move to the CPU before converting to NumPy.
    audio = wav[0].detach().cpu()
    # soundfile expects (frames, channels), hence the transpose. Use the
    # model's own output rate instead of a hard-coded value.
    sf.write("output.wav", audio.numpy().T, model.sample_rate)

    # Slice along the sample (last) axis: len(audio) would be the channel
    # count, not the clip length, and would send everything in one event.
    chunk_size = 1024
    num_samples = audio.shape[-1]
    for start in range(0, num_samples, chunk_size):
        chunk = audio[:, start : start + chunk_size]
        emit("audio_chunk", chunk.tolist())
if __name__ == "__main__":
    # Flask-SocketIO apps are started through socketio.run() rather than
    # app.run(), so the Socket.IO transport is served as well.
    socketio.run(app=app)