sanchit-gandhi (HF staff) committed
Commit d4bad2d
Parent: 57d878d

Create app.py

Files changed (1):
  app.py (+183, -0)
app.py ADDED
@@ -0,0 +1,183 @@
from queue import Queue
from threading import Thread
from typing import Optional

import numpy as np
import torch

from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer

import gradio as gr


device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")

# run in half-precision on GPU for faster generation
if device == "cuda:0":
    model.to(device).half()

class MusicgenStreamer(BaseStreamer):
    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an
        iterator. This is useful for applications that benefit from accessing the generated audio in a
        non-blocking way (e.g. in an interactive Gradio demo).

        Parameters:
            model (`MusicgenForConditionalGeneration`):
                The MusicGen model used to generate the audio waveform.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, defaults to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps after which a chunk of generated audio is returned. Using fewer
                steps means the first chunk is ready faster, but requires more codec decoding steps overall.
                This value should be tuned to your device and latency requirements.
            stride (`int`, *optional*):
                The stride between adjacent audio chunks. Using a stride between adjacent chunks softens the
                hard boundary between them, giving smoother playback. If `None`, defaults to a value equivalent
                to `play_steps // 6` in the audio space.
            timeout (`float`, *optional*):
                The timeout for the audio queue. If `None`, the queue blocks indefinitely. Useful for handling
                exceptions in `.generate()` when it is called in a separate thread.
        """
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # variables used in the streaming process
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0

        # variables used in the thread process
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

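    # Worked example of the default stride, assuming facebook/musicgen-small (whose EnCodec config has
    # upsampling_ratios [8, 5, 4, 4], hence hop_length = 640 samples per token, i.e. 50 tokens/s at 32 kHz):
    # with play_steps = 100 (~2 s of audio) and num_codebooks = 4, the default stride is
    # 640 * (100 - 4) // 6 = 10240 samples (~0.32 s) of overlap held back between consecutive chunks.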
    def apply_delay_pattern_mask(self, input_ids):
        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # apply the pattern mask to the input ids
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # revert the delay pattern mask by filtering out the pad token id
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # append the frame dimension back to the audio codes
        input_ids = input_ids[None, ...]

        # send the input_ids to the correct device
        input_ids = input_ids.to(self.audio_encoder.device)

        # decode the audio codes to a waveform with the audio encoder
        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        # cache the incoming tokens: the first call carries the decoder prompt, later calls one new token per codebook
        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        if self.token_cache.shape[-1] % self.play_steps == 0:
            # re-decode the full cache, then emit only the samples not yet yielded, holding back
            # `stride` samples at the boundary so the next decode can smooth the transition
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flushes any remaining cache and appends the stop symbol."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value


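# Note: `model.generate()` drives the streamer itself, calling `put()` once per decoding step and `end()`
# when sampling finishes, so downstream code only touches the iterator side (`for chunk in streamer: ...`).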
sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate
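# For facebook/musicgen-small these should resolve to sampling_rate = 32000 Hz and frame_rate = 50 Hz,
# i.e. 50 generated tokens per second of audio (values read from the EnCodec config, not hard-coded here).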


def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )

    # convert the target audio length and streaming interval from seconds to generation steps
    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)

    # run generation in a background thread so we can iterate over the streamer as chunks become ready
    generation_kwargs = dict(
        **inputs.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    set_seed(0)
    for new_audio in streamer:
        # a streaming gr.Audio output expects (sampling_rate, waveform) tuples
        yield sampling_rate, new_audio


demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
        gr.Slider(10, 30, value=15, step=5, label="Audio length in s"),
        gr.Slider(2, 10, value=2, step=2, label="Streaming interval in s"),
    ],
    outputs=[
        gr.Audio(label="Generated Music", type="numpy", streaming=True, autoplay=True)
    ],
)

demo.queue().launch()
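
# To try the demo locally, save this file as app.py and run `python app.py`
# (assumes the torch, transformers and gradio dependencies are installed).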