Surn commited on
Commit
aef7fad
·
1 Parent(s): 5d66b58

Testing Seed Values

Browse files

Allow loading from file

app.py CHANGED
@@ -15,6 +15,7 @@ from audiocraft.models import MusicGen
15
  from audiocraft.data.audio import audio_write
16
  from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, sanitize_file_name
17
  import numpy as np
 
18
 
19
  MODEL = None
20
  IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
@@ -25,7 +26,7 @@ def load_model(version):
25
  return MusicGen.get_pretrained(version)
26
 
27
 
28
- def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color):
29
  global MODEL
30
  output_segments = None
31
  topk = int(topk)
@@ -36,6 +37,10 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
36
  segment_duration = MODEL.lm.cfg.dataset.segment_duration
37
  else:
38
  segment_duration = duration
 
 
 
 
39
  MODEL.set_generation_params(
40
  use_sampling=True,
41
  top_k=topk,
@@ -47,7 +52,7 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
47
 
48
  if melody:
49
  if duration > MODEL.lm.cfg.dataset.segment_duration:
50
- output_segments = generate_music_segments(text, melody, MODEL, duration, MODEL.lm.cfg.dataset.segment_duration)
51
  else:
52
  # pure original code
53
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
@@ -76,14 +81,13 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
76
  output = output.detach().cpu().float()[0]
77
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
78
  if include_settings:
79
- video_description = f"{text}\n Duration: {str(duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef}"
80
  background = add_settings_to_image(title, video_description, background_path=background, font=settings_font, font_color=settings_font_color)
81
- #filename = sanitize_file_name(title) if title != "" else file.name
82
  audio_write(
83
  file.name, output, MODEL.sample_rate, strategy="loudness",
84
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
85
  waveform_video = gr.make_waveform(file.name,bg_image=background, bar_count=40)
86
- return waveform_video
87
 
88
 
89
  def ui(**kwargs):
@@ -121,15 +125,23 @@ def ui(**kwargs):
121
  model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
122
  with gr.Row():
123
  duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
 
124
  dimension = gr.Slider(minimum=-2, maximum=1, value=1, step=1, label="Dimension", info="determines which direction to add new segements of audio. (0 = stack tracks, 1 = lengthen, -1 = ?)", interactive=True)
125
  with gr.Row():
126
  topk = gr.Number(label="Top-k", value=250, interactive=True)
127
  topp = gr.Number(label="Top-p", value=0, interactive=True)
128
  temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
129
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, precision=2, interactive=True)
130
- with gr.Column():
 
 
 
 
131
  output = gr.Video(label="Generated Music")
132
- submit.click(predict, inputs=[model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color], outputs=[output])
 
 
 
133
  gr.Examples(
134
  fn=predict,
135
  examples=[
 
15
  from audiocraft.data.audio import audio_write
16
  from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, sanitize_file_name
17
  import numpy as np
18
+ import random
19
 
20
  MODEL = None
21
  IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
 
26
  return MusicGen.get_pretrained(version)
27
 
28
 
29
+ def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap=1):
30
  global MODEL
31
  output_segments = None
32
  topk = int(topk)
 
37
  segment_duration = MODEL.lm.cfg.dataset.segment_duration
38
  else:
39
  segment_duration = duration
40
+ # implement seed
41
+ if seed < 0:
42
+ seed = random.randint(0, 0xffff_ffff_ffff)
43
+ torch.manual_seed(seed)
44
  MODEL.set_generation_params(
45
  use_sampling=True,
46
  top_k=topk,
 
52
 
53
  if melody:
54
  if duration > MODEL.lm.cfg.dataset.segment_duration:
55
+ output_segments = generate_music_segments(text, melody, MODEL, seed, duration, overlap, MODEL.lm.cfg.dataset.segment_duration)
56
  else:
57
  # pure original code
58
  sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
 
81
  output = output.detach().cpu().float()[0]
82
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
83
  if include_settings:
84
+ video_description = f"{text}\n Duration: {str(duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}"
85
  background = add_settings_to_image(title, video_description, background_path=background, font=settings_font, font_color=settings_font_color)
 
86
  audio_write(
87
  file.name, output, MODEL.sample_rate, strategy="loudness",
88
  loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
89
  waveform_video = gr.make_waveform(file.name,bg_image=background, bar_count=40)
90
+ return waveform_video, seed
91
 
92
 
93
  def ui(**kwargs):
 
125
  model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
126
  with gr.Row():
127
  duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
128
+ overlap = gr.Slider(minimum=1, maximum=29, value=5, step=1, label="Overlap", interactive=True)
129
  dimension = gr.Slider(minimum=-2, maximum=1, value=1, step=1, label="Dimension", info="determines which direction to add new segements of audio. (0 = stack tracks, 1 = lengthen, -1 = ?)", interactive=True)
130
  with gr.Row():
131
  topk = gr.Number(label="Top-k", value=250, interactive=True)
132
  topp = gr.Number(label="Top-p", value=0, interactive=True)
133
  temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
134
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, precision=2, interactive=True)
135
+ with gr.Row():
136
+ seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
137
+ gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
138
+ reuse_seed = gr.Button('\u267b\ufe0f').style(full_width=False)
139
+ with gr.Column() as c:
140
  output = gr.Video(label="Generated Music")
141
+ seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
142
+
143
+ reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
144
+ submit.click(predict, inputs=[model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap], outputs=[output, seed_used])
145
  gr.Examples(
146
  fn=predict,
147
  examples=[
audiocraft/models/loaders.py CHANGED
@@ -50,6 +50,10 @@ def _get_state_dict(
50
 
51
  if os.path.isfile(file_or_url_or_id):
52
  return torch.load(file_or_url_or_id, map_location=device)
 
 
 
 
53
 
54
  elif file_or_url_or_id.startswith('https://'):
55
  return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
 
50
 
51
  if os.path.isfile(file_or_url_or_id):
52
  return torch.load(file_or_url_or_id, map_location=device)
53
+
54
+ if os.path.isdir(file_or_url_or_id):
55
+ file = f"{file_or_url_or_id}/{filename}"
56
+ return torch.load(file, map_location=device)
57
 
58
  elif file_or_url_or_id.startswith('https://'):
59
  return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
audiocraft/models/musicgen.py CHANGED
@@ -80,10 +80,11 @@ class MusicGen:
80
  return MusicGen(name, compression_model, lm)
81
 
82
  if name not in HF_MODEL_CHECKPOINTS_MAP:
83
- raise ValueError(
84
- f"{name} is not a valid checkpoint name. "
85
- f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
86
- )
 
87
 
88
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
89
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
 
80
  return MusicGen(name, compression_model, lm)
81
 
82
  if name not in HF_MODEL_CHECKPOINTS_MAP:
83
+ if not os.path.isfile(name) and not os.path.isdir(name):
84
+ raise ValueError(
85
+ f"{name} is not a valid checkpoint name. "
86
+ f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
87
+ )
88
 
89
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
90
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
audiocraft/utils/extend.py CHANGED
@@ -8,29 +8,34 @@ import tempfile
8
  import os
9
  import textwrap
10
 
11
- def separate_audio_segments(audio, segment_duration=30):
12
  sr, audio_data = audio[0], audio[1]
13
-
14
  total_samples = len(audio_data)
15
  segment_samples = sr * segment_duration
16
-
17
- total_segments = math.ceil(total_samples / segment_samples)
18
-
19
  segments = []
20
-
21
- for segment_idx in range(total_segments):
22
- print(f"Audio Input segment {segment_idx + 1} / {total_segments + 1} \r")
23
- start_sample = segment_idx * segment_samples
24
- end_sample = (segment_idx + 1) * segment_samples
25
-
26
  segment = audio_data[start_sample:end_sample]
27
  segments.append((sr, segment))
28
-
 
 
 
 
 
 
 
 
29
  return segments
30
 
31
- def generate_music_segments(text, melody, MODEL, duration:int=10, segment_duration:int=30):
32
  # generate audio segments
33
- melody_segments = separate_audio_segments(melody, segment_duration)
34
 
35
  # Create a list to store the melody tensors for each segment
36
  melodys = []
@@ -40,7 +45,7 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
40
  total_segments = max(math.ceil(duration / segment_duration),1)
41
  print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds")
42
 
43
- # If melody_segments is shorter than total_segments, repeat the segments until the total number of segments is reached
44
  if len(melody_segments) < total_segments:
45
  for i in range(total_segments - len(melody_segments)):
46
  segment = melody_segments[i]
@@ -59,6 +64,7 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
59
  # Append the segment to the melodys list
60
  melodys.append(verse)
61
 
 
62
  for idx, verse in enumerate(melodys):
63
  print(f"Generating New Melody Segment {idx + 1}: {text}\r")
64
  output = MODEL.generate_with_chroma(
@@ -74,42 +80,6 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
74
  print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
75
  return output_segments
76
 
77
- #def generate_music_segments(text, melody, duration, MODEL, segment_duration=30):
78
- # sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
79
-
80
- # # Create a list to store the melody tensors for each segment
81
- # melodys = []
82
-
83
- # # Calculate the total number of segments
84
- # total_segments = math.ceil(melody.shape[1] / (sr * segment_duration))
85
-
86
- # # Iterate over the segments
87
- # for segment_idx in range(total_segments):
88
- # print(f"segment {segment_idx + 1} / {total_segments + 1} \r")
89
- # start_frame = segment_idx * sr * segment_duration
90
- # end_frame = (segment_idx + 1) * sr * segment_duration
91
-
92
- # # Extract the segment from the melody tensor
93
- # segment = melody[:, start_frame:end_frame]
94
-
95
- # # Append the segment to the melodys list
96
- # melodys.append(segment)
97
-
98
- # output_segments = []
99
-
100
- # for segment in melodys:
101
- # output = MODEL.generate_with_chroma(
102
- # descriptions=[text],
103
- # melody_wavs=segment,
104
- # melody_sample_rate=sr,
105
- # progress=False
106
- # )
107
-
108
- # # Append the generated output to the list of segments
109
- # output_segments.append(output[:, :segment_duration])
110
-
111
- # return output_segments
112
-
113
  def save_image(image):
114
  """
115
  Saves a PIL image to a temporary file and returns the file path.
@@ -184,13 +154,4 @@ def add_settings_to_image(title: str = "title", description: str = "", width: in
184
  background.paste(image, offset, mask=image)
185
 
186
  # Save the image and return the file path
187
- return save_image(background)
188
-
189
-
190
- def sanitize_file_name(filename):
191
- valid_chars = "-_.() " + string.ascii_letters + string.digits
192
- sanitized_filename = ''.join(c for c in filename if c in valid_chars)
193
- return sanitized_filename
194
-
195
-
196
-
 
8
  import os
9
  import textwrap
10
 
11
+ def separate_audio_segments(audio, segment_duration=30, overlap=1):
12
  sr, audio_data = audio[0], audio[1]
13
+
14
  total_samples = len(audio_data)
15
  segment_samples = sr * segment_duration
16
+ overlap_samples = sr * overlap
17
+
 
18
  segments = []
19
+ start_sample = 0
20
+
21
+ while total_samples >= segment_samples:
22
+ end_sample = start_sample + segment_samples
 
 
23
  segment = audio_data[start_sample:end_sample]
24
  segments.append((sr, segment))
25
+
26
+ start_sample += segment_samples - overlap_samples
27
+ total_samples -= segment_samples - overlap_samples
28
+
29
+ # Collect the final segment
30
+ if total_samples > 0:
31
+ segment = audio_data[-segment_samples:]
32
+ segments.append((sr, segment))
33
+
34
  return segments
35
 
36
+ def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
37
  # generate audio segments
38
+ melody_segments = separate_audio_segments(melody, segment_duration, overlap)
39
 
40
  # Create a list to store the melody tensors for each segment
41
  melodys = []
 
45
  total_segments = max(math.ceil(duration / segment_duration),1)
46
  print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds")
47
 
48
+ # If melody_segments is shorter than total_segments, repeat the segments until the total_segments is reached
49
  if len(melody_segments) < total_segments:
50
  for i in range(total_segments - len(melody_segments)):
51
  segment = melody_segments[i]
 
64
  # Append the segment to the melodys list
65
  melodys.append(verse)
66
 
67
+ torch.manual_seed(seed)
68
  for idx, verse in enumerate(melodys):
69
  print(f"Generating New Melody Segment {idx + 1}: {text}\r")
70
  output = MODEL.generate_with_chroma(
 
80
  print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
81
  return output_segments
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def save_image(image):
84
  """
85
  Saves a PIL image to a temporary file and returns the file path.
 
154
  background.paste(image, offset, mask=image)
155
 
156
  # Save the image and return the file path
157
+ return save_image(background)