awacke1 committed on
Commit 21a8717
1 Parent(s): f829bc9

Update app.py

Files changed (1):
  app.py +207 -133
app.py CHANGED
@@ -1,137 +1,211 @@
- import gradio as gr
- import note_seq
  import numpy as np
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("TristanBehrens/js-fakes-4bars")
- model = AutoModelForCausalLM.from_pretrained("TristanBehrens/js-fakes-4bars")
-
- NOTE_LENGTH_16TH_120BPM = 0.25 * 60 / 120
- BAR_LENGTH_120BPM = 4.0 * 60 / 120
- SAMPLE_RATE = 44100
-
- def token_sequence_to_note_sequence(token_sequence, use_program=True, use_drums=True, instrument_mapper=None, only_piano=False):
-     if isinstance(token_sequence, str):
-         token_sequence = token_sequence.split()
-     note_sequence = empty_note_sequence()
-
-     # Render all notes.
-     current_program = 1
-     current_is_drum = False
-     current_instrument = 0
-     track_count = 0
-     for token_index, token in enumerate(token_sequence):
-
-         if token == "PIECE_START":
-             pass
-         elif token == "PIECE_END":
-             print("The end.")
-             break
-         elif token == "TRACK_START":
-             current_bar_index = 0
-             track_count += 1
-             pass
-         elif token == "TRACK_END":
-             pass
-         elif token == "KEYS_START":
-             pass
-         elif token == "KEYS_END":
-             pass
-         elif token.startswith("KEY="):
-             pass
-         elif token.startswith("INST"):
-             instrument = token.split("=")[-1]
-             if instrument != "DRUMS" and use_program:
-                 if instrument_mapper is not None:
-                     if instrument in instrument_mapper:
-                         instrument = instrument_mapper[instrument]
-                 current_program = int(instrument)
-                 current_instrument = track_count
-                 current_is_drum = False
-             if instrument == "DRUMS" and use_drums:
-                 current_instrument = 0
-                 current_program = 0
-                 current_is_drum = True
-         elif token == "BAR_START":
-             current_time = current_bar_index * BAR_LENGTH_120BPM
-             current_notes = {}
-         elif token == "BAR_END":
-             current_bar_index += 1
-             pass
-         elif token.startswith("NOTE_ON"):
-             pitch = int(token.split("=")[-1])
-             note = note_sequence.notes.add()
-             note.start_time = current_time
-             note.end_time = current_time + 4 * NOTE_LENGTH_16TH_120BPM
-             note.pitch = pitch
-             note.instrument = current_instrument
-             note.program = current_program
-             note.velocity = 80
-             note.is_drum = current_is_drum
-             current_notes[pitch] = note
-         elif token.startswith("NOTE_OFF"):
-             pitch = int(token.split("=")[-1])
-             if pitch in current_notes:
-                 note = current_notes[pitch]
-                 note.end_time = current_time
-         elif token.startswith("TIME_DELTA"):
-             delta = float(token.split("=")[-1]) * NOTE_LENGTH_16TH_120BPM
-             current_time += delta
-         elif token.startswith("DENSITY="):
-             pass
-         elif token == "[PAD]":
-             pass
          else:
-             #print(f"Ignored token {token}.")
-             pass
-
-     # Make the instruments right.
-     instruments_drums = []
-     for note in note_sequence.notes:
-         pair = [note.program, note.is_drum]
-         if pair not in instruments_drums:
-             instruments_drums += [pair]
-         note.instrument = instruments_drums.index(pair)
-
-     if only_piano:
-         for note in note_sequence.notes:
-             if not note.is_drum:
-                 note.instrument = 0
-                 note.program = 0
-
-     return note_sequence
-
- def empty_note_sequence(qpm=120.0, total_time=0.0):
-     note_sequence = note_seq.protobuf.music_pb2.NoteSequence()
-     note_sequence.tempos.add().qpm = qpm
-     note_sequence.ticks_per_quarter = note_seq.constants.STANDARD_PPQ
-     note_sequence.total_time = total_time
-     return note_sequence
-
- def process(text):
-     input_ids = tokenizer.encode(text, return_tensors="pt")
-     generated_ids = model.generate(input_ids, max_length=500)
-     generated_sequence = tokenizer.decode(generated_ids[0])
-
-     # Convert text of notes to audio
-     note_sequence = token_sequence_to_note_sequence(generated_sequence)
-     synth = note_seq.midi_synth.synthesize
-     array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
-     note_plot = note_seq.plot_sequence(note_sequence, False)
-     array_of_floats /= 1.414
-     array_of_floats *= 32767
-     int16_data = array_of_floats.astype(np.int16)
-     return SAMPLE_RATE, int16_data
-
- title = "Music generation with GPT-2"
-
- iface = gr.Interface(
-     fn=process,
-     inputs=[gr.inputs.Textbox(default="PIECE_START")],
-     outputs=['audio'],
-     title=title,
-     examples=[["PIECE_START"], ["PIECE_START STYLE=JSFAKES GENRE=JSFAKES TRACK_START INST=48 BAR_START NOTE_ON=61"]],
-     article="This demo is inspired by the notebook from https://huggingface.co/TristanBehrens/js-fakes-4bars"
  )

- iface.launch(debug=True)
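For orientation, here is a minimal sketch of how the removed pipeline above fits together. It only reuses names defined in the removed code (token_sequence_to_note_sequence, note_seq, SAMPLE_RATE); the token string is an illustrative value, not real model output:

    tokens = "PIECE_START TRACK_START INST=48 BAR_START NOTE_ON=61 TIME_DELTA=4 NOTE_OFF=61 BAR_END TRACK_END"
    ns = token_sequence_to_note_sequence(tokens)                          # build a note_seq NoteSequence from the token string
    audio = note_seq.midi_synth.synthesize(ns, sample_rate=SAMPLE_RATE)   # float waveform at 44.1 kHz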
 
 
 
  import numpy as np
+ import torch
+ import gradio as gr
+ import spaces
+ from queue import Queue
+ from threading import Thread
+ from typing import Optional
+ from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
+ from transformers.generation.streamers import BaseStreamer
+
+ model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+ processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
+
+ title = "9🌍MusicHub - Text to Music Stream Generator"
+ description = """Facebook MusicGen-Small - generate and stream music with the model at https://huggingface.co/facebook/musicgen-small"""
+ article = """
+ ## How It Works:
+ MusicGen is an auto-regressive transformer-based model, meaning it generates audio codes (tokens) in a causal fashion.
+ At each decoding step, the model generates a new set of audio codes, conditional on the text input and all previous audio codes.
+ The generated codes are then decoded to an audio waveform by the [EnCodec model](https://huggingface.co/facebook/encodec_32khz), whose frame rate determines how much audio each decoding step contributes.
+ """
+
+
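The streamer class below extends the standard MusicGen generation flow that the article text describes. As a point of reference, a blocking (non-streaming) call looks roughly like this; a hedged sketch that reuses the model and processor loaded above, with max_new_tokens picked only for illustration:

    inputs = processor(text=["80s pop track with synth"], padding=True, return_tensors="pt")
    audio_values = model.generate(**inputs, max_new_tokens=256)   # decoded waveform, shape (batch, channels, samples)
    waveform = audio_values[0, 0].cpu().numpy()                   # ~5 s of mono audio, assuming musicgen-small's 50 Hz frame rate
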
+ class MusicgenStreamer(BaseStreamer):
+     def __init__(
+         self,
+         model: MusicgenForConditionalGeneration,
+         device: Optional[str] = None,
+         play_steps: Optional[int] = 10,
+         stride: Optional[int] = None,
+         timeout: Optional[float] = None,
+     ):
+         """
+         Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
+         useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
+         Gradio demo).
+         Parameters:
+             model (`MusicgenForConditionalGeneration`):
+                 The MusicGen model used to generate the audio waveform.
+             device (`str`, *optional*):
+                 The torch device on which to run the computation. If `None`, will default to the device of the model.
+             play_steps (`int`, *optional*, defaults to 10):
+                 The number of generation steps with which to return the generated audio array. Using fewer steps will
+                 mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
+                 should be tuned to your device and latency requirements.
+             stride (`int`, *optional*):
+                 The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
+                 the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
+                 `play_steps // 6` in the audio space.
+             timeout (`int`, *optional*):
+                 The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+                 in `.generate()`, when it is called in a separate thread.
+         """
+         self.decoder = model.decoder
+         self.audio_encoder = model.audio_encoder
+         self.generation_config = model.generation_config
+         self.device = device if device is not None else model.device
+
+         # variables used in the streaming process
+         self.play_steps = play_steps
+         if stride is not None:
+             self.stride = stride
+         else:
+             hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
+             self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
+         self.token_cache = None
+         self.to_yield = 0
+
+         # variables used in the thread process
+         self.audio_queue = Queue()
+         self.stop_signal = None
+         self.timeout = timeout
+
+     def apply_delay_pattern_mask(self, input_ids):
+         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+         _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+             input_ids[:, :1],
+             pad_token_id=self.generation_config.decoder_start_token_id,
+             max_length=input_ids.shape[-1],
+         )
+         # apply the pattern mask to the input ids
+         input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)
+
+         # revert the pattern delay mask by filtering the pad token id
+         input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
+             1, self.decoder.num_codebooks, -1
+         )
+
+         # append the frame dimension back to the audio codes
+         input_ids = input_ids[None, ...]
+
+         # send the input_ids to the correct device
+         input_ids = input_ids.to(self.audio_encoder.device)
+
+         output_values = self.audio_encoder.decode(
+             input_ids,
+             audio_scales=[None],
+         )
+         audio_values = output_values.audio_values[0, 0]
+         return audio_values.cpu().float().numpy()
+
+     def put(self, value):
+         batch_size = value.shape[0] // self.decoder.num_codebooks
+         if batch_size > 1:
+             raise ValueError("MusicgenStreamer only supports batch size 1")
+
+         if self.token_cache is None:
+             self.token_cache = value
+         else:
+             self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
+
+         if self.token_cache.shape[-1] % self.play_steps == 0:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
+             self.to_yield += len(audio_values) - self.to_yield - self.stride
+
+     def end(self):
+         """Flushes any remaining cache and appends the stop symbol."""
+         if self.token_cache is not None:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
          else:
+             audio_values = np.zeros(self.to_yield)
+
+         self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
+
+     def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
+         """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
+         self.audio_queue.put(audio, timeout=self.timeout)
+         if stream_end:
+             self.audio_queue.put(self.stop_signal, timeout=self.timeout)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.audio_queue.get(timeout=self.timeout)
+         if not isinstance(value, np.ndarray) and value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
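For intuition about how play_steps and stride translate into audio, assume musicgen-small's audio encoder reports a 32 kHz sampling rate, a 50 Hz frame rate, a 640-sample hop length and 4 codebooks (the code below reads these values from the config at runtime; the numbers here are assumptions for a back-of-envelope estimate). With the default streaming interval of 2 seconds:

    play_steps = int(50 * 2.0)             # = 100 decoder steps per yielded chunk, i.e. roughly 2 s of audio
    stride = 640 * (play_steps - 4) // 6   # = 10240 waveform samples (~0.32 s) trimmed to smooth chunk boundaries
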
+ sampling_rate = model.audio_encoder.config.sampling_rate
+ frame_rate = model.audio_encoder.config.frame_rate
+
+ target_dtype = np.int16
+ max_range = np.iinfo(target_dtype).max
+
+
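+ # max_range is 32767 for int16: generate_audio below scales the float waveform (roughly in [-1, 1])
+ # by this value and casts it to 16-bit PCM, the format the streaming gr.Audio output expects.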
+ @spaces.GPU
+ def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
+     max_new_tokens = int(frame_rate * audio_length_in_s)
+     play_steps = int(frame_rate * play_steps_in_s)
+
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     if device != model.device:
+         model.to(device)
+         if device == "cuda:0":
+             model.half()
+
+     inputs = processor(
+         text=text_prompt,
+         padding=True,
+         return_tensors="pt",
+     )
+
+     streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
+
+     generation_kwargs = dict(
+         **inputs.to(device),
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+     )
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     set_seed(seed)
+     for new_audio in streamer:
+         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
+         new_audio = (new_audio * max_range).astype(np.int16)
+         yield (sampling_rate, new_audio)
+
+
+ demo = gr.Interface(
+     fn=generate_audio,
+     inputs=[
+         gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
+         gr.Slider(10, 30, value=15, step=5, label="Audio length in seconds"),
+         gr.Slider(0.5, 2.5, value=0.5, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps"),
+         gr.Slider(0, 10, value=5, step=1, label="Seed for random generations"),
+     ],
+     outputs=[
+         gr.Audio(label="Generated Music", streaming=True, autoplay=True)
+     ],
+     examples=[
+         ["Country acoustic guitar fast line dance singer like Kenny Chesney and Garth Brooks and Luke Combs and Chris Stapleton. bpm: 100", 30, 0.5, 5],
+         ["Electronic Dance track with pulsating bass and high energy synths. bpm: 126", 30, 0.5, 5],
+         ["Rap Beats with deep bass and snappy snares. bpm: 80", 30, 0.5, 5],
+         ["Lo-Fi track with smooth beats and chill vibes. bpm: 100", 30, 0.5, 5],
+         ["Global Groove track with international instruments and dance rhythms. bpm: 128", 30, 0.5, 5],
+         ["Relaxing Meditation music with ambient pads and soothing melodies. bpm: 80", 30, 0.5, 5],
+         ["Rave Dance track with hard-hitting beats and euphoric synths. bpm: 128", 30, 0.5, 5],
+     ],
+     title=title,
+     description=description,
+     article=article,
+     cache_examples=False,
  )

+ demo.queue().launch()
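
Since generate_audio is a plain Python generator yielding (sampling_rate, int16_chunk) tuples, it can also be consumed outside Gradio. A minimal sketch, assuming the optional soundfile package is installed (it is not a dependency of this Space) and that the @spaces.GPU decorator passes through when not running on a ZeroGPU Space:

    import soundfile as sf

    for i, (sr, chunk) in enumerate(generate_audio("lo-fi chill beat with soft piano", audio_length_in_s=10.0)):
        sf.write(f"chunk_{i:03d}.wav", chunk, sr)   # write each streamed chunk to disk as it arrives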