# api / app.py
# soiz's picture
# Update app.py
# eabf505 verified
from queue import Queue
from threading import Thread
from typing import Optional
import numpy as np
import torch
from flask import Flask, request, jsonify, send_file
from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer
import io
import soundfile as sf
# Both the generation model and its processor are loaded once, at import time,
# from the same pretrained checkpoint.
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
model = MusicgenForConditionalGeneration.from_pretrained(_MUSICGEN_CHECKPOINT)
processor = MusicgenProcessor.from_pretrained(_MUSICGEN_CHECKPOINT)
class MusicgenStreamer(BaseStreamer):
    """Streamer that converts generated MusicGen tokens into audio chunks.

    An instance is meant to be passed as ``streamer=`` to ``model.generate``.
    Tokens handed to :meth:`put` are accumulated in a cache; every
    ``play_steps`` cached steps the whole cache is decoded to a waveform and
    the part that has not been emitted yet (minus a trailing ``stride`` of
    samples) is pushed onto an internal queue. Iterating over the streamer
    yields those NumPy chunks until the stop signal is received.
    """

    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """Set up the streamer from a loaded MusicGen model.

        Args:
            model: Model whose ``decoder``, ``audio_encoder`` and
                ``generation_config`` are used to decode tokens.
            device: Stored on the instance; defaults to ``model.device``.
            play_steps: Number of cached token steps between two decodes.
                Must be a positive number.
            stride: Number of trailing audio samples withheld from each
                intermediate chunk. When omitted it is derived from the audio
                encoder's upsampling ratios, ``play_steps`` and the decoder's
                number of codebooks.
            timeout: Timeout (seconds) used for every queue ``put``/``get``;
                ``None`` blocks indefinitely.

        Raises:
            ValueError: If ``play_steps`` is ``None`` or smaller than 1.
        """
        # ``put`` computes ``cache_len % play_steps``; reject values that could
        # only fail later, inside the generation thread.
        if play_steps is None or play_steps < 1:
            raise ValueError("play_steps must be a positive integer")

        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # Variables used in the streaming process.
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0

        # Variables used to hand audio over to the consuming thread.
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        """Decode the cached token ids into a NumPy waveform.

        Args:
            input_ids: Cached token tensor accumulated by :meth:`put`.

        Returns:
            ``audio_values[0, 0]`` of the audio encoder's decode output, moved
            to CPU and converted to a float NumPy array.
        """
        # Build the delay pattern mask from the first column of the cache.
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # Apply the pattern mask to the cached ids.
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # Drop every pad token and restore the (1, num_codebooks, steps) layout.
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # Add a leading dimension before handing the codes to the audio decoder.
        input_ids = input_ids[None, ...]

        # Decode on whatever device the audio encoder lives on.
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        """Cache new tokens and emit audio every ``play_steps`` cached steps.

        Args:
            value: Token tensor. The first call stores it as the cache; later
                calls append ``value[:, None]`` along the last axis.

        Raises:
            ValueError: If ``value.shape[0]`` implies a batch size above 1.
        """
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            total = len(audio_values)
            # Use an explicit end index instead of ``[: -stride]``: with a
            # stride of 0 the negative form yields an empty slice and the audio
            # would be skipped. Clamping keeps the cursor monotonic and inside
            # the decoded waveform for oversized or negative strides as well.
            chunk_end = min(max(total - self.stride, self.to_yield), total)
            self.on_finalized_audio(audio_values[self.to_yield : chunk_end])
            self.to_yield = chunk_end

    def end(self):
        """Flush whatever audio has not been emitted yet and queue the stop signal."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            # Nothing was ever cached: the slice below is then empty.
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Queue an audio chunk, followed by the stop signal when ``stream_end`` is set."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next queued chunk; stop iterating on the stop signal."""
        value = self.audio_queue.get(timeout=self.timeout)
        # Audio chunks are always ndarrays, so the non-array stop signal is
        # recognised without an element-wise comparison.
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        return value
# Audio codec settings read from the loaded model's audio encoder config:
# `sampling_rate` is used when writing the WAV response and `frame_rate`
# converts durations in seconds into token counts.
_audio_encoder_config = model.audio_encoder.config
sampling_rate = _audio_encoder_config.sampling_rate
frame_rate = _audio_encoder_config.frame_rate

# Flask application on which the HTTP route is registered.
app = Flask(__name__)
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Generate music from a text prompt and return it as a WAV attachment.

    Expects a JSON object body with the optional keys:

    * ``text_prompt`` -- prompt passed to the processor
      (default ``'80s pop track with synth and instrumentals'``).
    * ``audio_length_in_s`` -- requested duration in seconds (default 10.0).
    * ``play_steps_in_s`` -- interval, in seconds, between two intermediate
      decodes of the streamer (default 2.0).
    * ``seed`` -- integer seed passed to ``set_seed`` (default 0).

    Returns:
        The generated audio as ``generated_music.wav`` (``audio/wav``), a JSON
        error with status 400 for invalid parameters, or a JSON error with
        status 500 when generation fails.
    """
    data = request.json
    if not isinstance(data, dict):
        # e.g. a body of ``null`` or a JSON list: ``.get`` would fail below.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    try:
        text_prompt = data.get('text_prompt', '80s pop track with synth and instrumentals')
        audio_length_in_s = float(data.get('audio_length_in_s', 10.0))
        play_steps_in_s = float(data.get('play_steps_in_s', 2.0))
        seed = int(data.get('seed', 0))
        max_new_tokens = int(frame_rate * audio_length_in_s)
        play_steps = int(frame_rate * play_steps_in_s)
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "audio_length_in_s, play_steps_in_s and seed must be finite numbers"}), 400

    if max_new_tokens < 1:
        return jsonify({"error": "audio_length_in_s is too small"}), 400
    # The streamer's default stride is proportional to
    # (play_steps - num_codebooks); it must be positive, and play_steps must be
    # non-zero because the streamer takes a modulo by it.
    if play_steps <= model.decoder.num_codebooks:
        return jsonify({"error": "play_steps_in_s is too small"}), 400

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare torch.device objects; a plain string never equals a torch.device.
    if torch.device(device) != model.device:
        model.to(device)
    if device == "cuda:0":
        model.half()

    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )
    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
    generation_kwargs = dict(
        **inputs.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    failures = []

    def _run_generation():
        """Run generation; on failure record the error and unblock the consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: reported to the client below
            failures.append(exc)
            # Without a stop signal the loop reading the streamer would block forever.
            streamer.on_finalized_audio(np.zeros(0, dtype=np.float32), stream_end=True)

    # Seed before the worker starts so the seed is guaranteed to apply to this run.
    set_seed(seed)
    thread = Thread(target=_run_generation)
    thread.start()

    generated_audio = [new_audio for new_audio in streamer]
    thread.join()

    if failures:
        app.logger.error("Music generation failed", exc_info=failures[0])
        return jsonify({"error": "Audio generation failed"}), 500

    # Concatenate the audio chunks
    final_audio = np.concatenate(generated_audio)

    # Save the audio to a buffer and send it as a response
    buffer = io.BytesIO()
    sf.write(buffer, final_audio, sampling_rate, format="wav")
    buffer.seek(0)
    return send_file(buffer, mimetype="audio/wav", as_attachment=True, download_name="generated_music.wav")
# Script entry point: serve the Flask app on all interfaces, port 7860.
if __name__ == "__main__":
    app.run(
        host="0.0.0.0",
        port=7860,
    )