Surn committed on
Commit
1dda6b6
1 Parent(s): aef5578

Update Overlap Action in Melody

Browse files
app.py CHANGED
@@ -100,6 +100,8 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
100
  temperature=temperature,
101
  cfg_coef=cfg_coef,
102
  duration=segment_duration,
 
 
103
  )
104
 
105
  if melody:
@@ -201,7 +203,7 @@ def ui(**kwargs):
201
  include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
202
  with gr.Row():
203
  title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
204
- settings_font = gr.Text(label="Settings Font", value="arial.ttf", interactive=True)
205
  settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
206
  with gr.Row():
207
  model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
@@ -212,8 +214,8 @@ def ui(**kwargs):
212
  with gr.Row():
213
  topk = gr.Number(label="Top-k", value=250, interactive=True)
214
  topp = gr.Number(label="Top-p", value=0, interactive=True)
215
- temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
216
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.0, precision=2, interactive=True)
217
  with gr.Row():
218
  seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
219
  gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
 
100
  temperature=temperature,
101
  cfg_coef=cfg_coef,
102
  duration=segment_duration,
103
+ two_step_cfg=False,
104
+ rep_penalty=0.5
105
  )
106
 
107
  if melody:
 
203
  include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
204
  with gr.Row():
205
  title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
206
+ settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
207
  settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
208
  with gr.Row():
209
  model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
 
214
  with gr.Row():
215
  topk = gr.Number(label="Top-k", value=250, interactive=True)
216
  topp = gr.Number(label="Top-p", value=0, interactive=True)
217
+ temperature = gr.Number(label="Randomness Temperature", value=0.75, precision=None, interactive=True)
218
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.5, precision=None, interactive=True)
219
  with gr.Row():
220
  seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
221
  gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
audiocraft/models/musicgen.py CHANGED
@@ -97,7 +97,7 @@ class MusicGen:
97
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
98
  top_p: float = 0.0, temperature: float = 1.0,
99
  duration: float = 30.0, cfg_coef: float = 3.0,
100
- two_step_cfg: bool = False):
101
  """Set the generation parameters for MusicGen.
102
 
103
  Args:
@@ -110,6 +110,7 @@ class MusicGen:
110
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
111
  instead of batching together the two. This has some impact on how things
112
  are padded but seems to have little impact in practice.
 
113
  """
114
  assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
115
  self.generation_params = {
@@ -119,7 +120,7 @@ class MusicGen:
119
  'top_k': top_k,
120
  'top_p': top_p,
121
  'cfg_coef': cfg_coef,
122
- 'two_step_cfg': two_step_cfg,
123
  }
124
 
125
  def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
@@ -177,6 +178,58 @@ class MusicGen:
177
  assert prompt_tokens is None
178
  return self._generate_tokens(attributes, prompt_tokens, progress)
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
181
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
182
  progress: bool = False) -> torch.Tensor:
 
97
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
98
  top_p: float = 0.0, temperature: float = 1.0,
99
  duration: float = 30.0, cfg_coef: float = 3.0,
100
+ two_step_cfg: bool = False, rep_penalty: float = None):
101
  """Set the generation parameters for MusicGen.
102
 
103
  Args:
 
110
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
111
  instead of batching together the two. This has some impact on how things
112
  are padded but seems to have little impact in practice.
113
+ rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
114
  """
115
  assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
116
  self.generation_params = {
 
120
  'top_k': top_k,
121
  'top_p': top_p,
122
  'cfg_coef': cfg_coef,
123
+ 'two_step_cfg': two_step_cfg,
124
  }
125
 
126
  def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
 
178
  assert prompt_tokens is None
179
  return self._generate_tokens(attributes, prompt_tokens, progress)
180
 
181
+ def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
182
+ sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
183
+ """Generate samples conditioned on text and melody and audio prompts.
184
+ Args:
185
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
186
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
187
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
188
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
189
+ a list of [C, T] tensors.
190
+ sample_rate: (int): Sample rate of the melody waveforms.
191
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
192
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
193
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
194
+ """
195
+ if isinstance(melody_wavs, torch.Tensor):
196
+ if melody_wavs.dim() == 2:
197
+ melody_wavs = melody_wavs[None]
198
+ if melody_wavs.dim() != 3:
199
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
200
+ melody_wavs = list(melody_wavs)
201
+ else:
202
+ for melody in melody_wavs:
203
+ if melody is not None:
204
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
205
+
206
+ melody_wavs = [
207
+ convert_audio(wav, sample_rate, self.sample_rate, self.audio_channels)
208
+ if wav is not None else None
209
+ for wav in melody_wavs]
210
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
211
+ melody_wavs=melody_wavs)
212
+
213
+ if prompt is not None:
214
+ if prompt.dim() == 2:
215
+ prompt = prompt[None]
216
+ if prompt.dim() != 3:
217
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
218
+ prompt = convert_audio(prompt, sample_rate, self.sample_rate, self.audio_channels)
219
+ if descriptions is None:
220
+ descriptions = [None] * len(prompt)
221
+
222
+ if prompt is not None:
223
+ attributes_gen, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
224
+
225
+ #attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=prompt,
226
+ # melody_wavs=melody_wavs)
227
+ if prompt is not None:
228
+ assert prompt_tokens is not None
229
+ else:
230
+ assert prompt_tokens is None
231
+ return self._generate_tokens(attributes, prompt_tokens, progress)
232
+
233
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
234
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
235
  progress: bool = False) -> torch.Tensor:
audiocraft/utils/extend.py CHANGED
@@ -22,12 +22,15 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
22
  start_sample = 0
23
 
24
  while total_samples >= segment_samples:
 
 
 
25
  end_sample = start_sample + segment_samples
26
  segment = audio_data[start_sample:end_sample]
27
  segments.append((sr, segment))
28
 
29
  start_sample += segment_samples - overlap_samples
30
- total_samples -= segment_samples - overlap_samples
31
 
32
  # Collect the final segment
33
  if total_samples > 0:
@@ -38,17 +41,16 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
38
 
39
  def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
40
  # generate audio segments
41
- melody_segments = separate_audio_segments(melody, segment_duration, overlap)
42
 
43
  # Create a list to store the melody tensors for each segment
44
  melodys = []
45
  output_segments = []
 
 
46
 
47
  # Calculate the total number of segments
48
  total_segments = max(math.ceil(duration / segment_duration),1)
49
- # account for overlap
50
- duration = duration + (max((total_segments - 1),0) * overlap)
51
- total_segments = max(math.ceil(duration / segment_duration),1)
52
  #calc excess duration
53
  excess_duration = segment_duration - (total_segments * segment_duration - duration)
54
  print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration}")
@@ -76,11 +78,15 @@ def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:
76
  torch.manual_seed(seed)
77
  for idx, verse in enumerate(melodys):
78
  print(f"Generating New Melody Segment {idx + 1}: {text}\r")
79
- output = MODEL.generate_with_chroma(
 
 
 
80
  descriptions=[text],
81
  melody_wavs=verse,
82
- melody_sample_rate=sr,
83
- progress=True
 
84
  )
85
 
86
  # Append the generated output to the list of segments
@@ -151,24 +157,31 @@ def load_font(font_name, font_size=16):
151
  Example:
152
  font = load_font("Arial.ttf", font_size=20)
153
  """
154
-
155
- try:
156
- font = ImageFont.truetype(font_name, font_size)
157
- except (FileNotFoundError, OSError):
158
  try:
159
  font = ImageFont.truetype(font_name, font_size)
160
- print("Font not found. Downloading from Hugging Face model hub...\n")
161
- except:
 
162
  try:
163
- req = requests.get(font_name)
164
- font = ImageFont.truetype(BytesIO(req.content), font_size)
165
- print("Font not found. Downloading from URL...\n")
166
- except:
167
- try:
168
- font = ImageFont.truetype(hf_hub_download("/assets", font_name), encoding="UTF-8")
169
- print(f"Font not found: {font_name} Using default font\n")
170
- except:
171
- font = ImageFont.load_default()
 
 
 
 
 
 
 
 
172
  return font
173
 
174
 
 
22
  start_sample = 0
23
 
24
  while total_samples >= segment_samples:
25
+ # Collect the segment
26
+ # the end sample is the start sample plus the segment samples,
27
+ # the start sample, after 0, is minus the overlap samples to account for the overlap
28
  end_sample = start_sample + segment_samples
29
  segment = audio_data[start_sample:end_sample]
30
  segments.append((sr, segment))
31
 
32
  start_sample += segment_samples - overlap_samples
33
+ total_samples -= segment_samples
34
 
35
  # Collect the final segment
36
  if total_samples > 0:
 
41
 
42
  def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
43
  # generate audio segments
44
+ melody_segments = separate_audio_segments(melody, segment_duration, 0)
45
 
46
  # Create a list to store the melody tensors for each segment
47
  melodys = []
48
  output_segments = []
49
+ last_chunk = []
50
+ text += ", seed=" + str(seed)
51
 
52
  # Calculate the total number of segments
53
  total_segments = max(math.ceil(duration / segment_duration),1)
 
 
 
54
  #calc excess duration
55
  excess_duration = segment_duration - (total_segments * segment_duration - duration)
56
  print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration}")
 
78
  torch.manual_seed(seed)
79
  for idx, verse in enumerate(melodys):
80
  print(f"Generating New Melody Segment {idx + 1}: {text}\r")
81
+ if output_segments:
82
+ # If this isn't the first segment, use the last chunk of the previous segment as the input
83
+ last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
84
+ output = MODEL.generate_with_all(
85
  descriptions=[text],
86
  melody_wavs=verse,
87
+ sample_rate=sr,
88
+ progress=True,
89
+ prompt=last_chunk if len(last_chunk) > 0 else None,
90
  )
91
 
92
  # Append the generated output to the list of segments
 
157
  Example:
158
  font = load_font("Arial.ttf", font_size=20)
159
  """
160
+ font = None
161
+ if not "http" in font_name:
 
 
162
  try:
163
  font = ImageFont.truetype(font_name, font_size)
164
+ except (FileNotFoundError, OSError):
165
+ print("Font not found. Trying to download from local assets folder...\n")
166
+ if font is None:
167
  try:
168
+ font = ImageFont.truetype("assets/" + font_name, font_size)
169
+ except (FileNotFoundError, OSError):
170
+ print("Font not found. Trying to download from URL...\n")
171
+
172
+ if font is None:
173
+ try:
174
+ req = requests.get(font_name)
175
+ font = ImageFont.truetype(BytesIO(req.content), font_size)
176
+ except (FileNotFoundError, OSError):
177
+ print(f"Font found: {font_name} Using Hugging Face download font\n")
178
+
179
+ if font is None:
180
+ try:
181
+ font = ImageFont.truetype(hf_hub_download("assets", font_name), encoding="UTF-8")
182
+ except (FileNotFoundError, OSError):
183
+ font = ImageFont.load_default()
184
+ print(f"Font not found: {font_name} Using default font\n")
185
  return font
186
 
187