Surn committed on
Commit 14af4d8
1 Parent(s): 6d70065

Process longer Audio

Files changed (4)
  1. .gitignore +1 -0
  2. app.py +33 -16
  3. audiocraft/utils/extend.py +111 -0
  4. web-ui.bat +1 -0
.gitignore CHANGED
@@ -53,3 +53,4 @@ ENV/
 /notebooks
 /local_scripts
 /notes
+/.vs
app.py CHANGED
@@ -13,6 +13,8 @@ import gradio as gr
 import os
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
+from audiocraft.utils.extend import generate_music_segments
+import numpy as np
 
 MODEL = None
 IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
@@ -30,32 +32,47 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
         MODEL = load_model(model)
 
     if duration > MODEL.lm.cfg.dataset.segment_duration:
-        raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
+        segment_duration = MODEL.lm.cfg.dataset.segment_duration
+    else:
+        segment_duration = duration
     MODEL.set_generation_params(
         use_sampling=True,
         top_k=topk,
         top_p=topp,
         temperature=temperature,
         cfg_coef=cfg_coef,
-        duration=duration,
+        duration=segment_duration,
     )
 
     if melody:
-        sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
-        print(melody.shape)
-        if melody.dim() == 2:
-            melody = melody[None]
-        melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
-        output = MODEL.generate_with_chroma(
-            descriptions=[text],
-            melody_wavs=melody,
-            melody_sample_rate=sr,
-            progress=False
-        )
+        if duration > MODEL.lm.cfg.dataset.segment_duration:
+            output_segments = generate_music_segments(text, melody, MODEL, duration, MODEL.lm.cfg.dataset.segment_duration)
+        else:
+            # pure original code
+            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
+            print(melody.shape)
+            if melody.dim() == 2:
+                melody = melody[None]
+            melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
+            output = MODEL.generate_with_chroma(
+                descriptions=[text],
+                melody_wavs=melody,
+                melody_sample_rate=sr,
+                progress=True
+            )
     else:
         output = MODEL.generate(descriptions=[text], progress=False)
 
-    output = output.detach().cpu().float()[0]
+    if output_segments:
+        try:
+            # Combine the output segments into one long audio file
+            output_segments = [segment.detach().cpu().float()[0] for segment in output_segments]
+            output = torch.cat(output_segments, dim=2)
+        except Exception as e:
+            print(f"error combining segments: {e}. Using first segment only")
+            output = output_segments[0].detach().cpu().float()[0]
+    else:
+        output = output.detach().cpu().float()[0]
     with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         audio_write(
             file.name, output, MODEL.sample_rate, strategy="loudness",
@@ -91,7 +108,7 @@ def ui(**kwargs):
     with gr.Row():
         model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
     with gr.Row():
-        duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
+        duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
     with gr.Row():
         topk = gr.Number(label="Top-k", value=250, interactive=True)
         topp = gr.Number(label="Top-p", value=0, interactive=True)
@@ -194,7 +211,7 @@ if __name__ == "__main__":
     parser.add_argument(
         '--server_port',
         type=int,
-        default=0,
+        default=7859,
         help='Port to run the server listener on',
     )
     parser.add_argument(
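
Note on the new predict flow above: output_segments is only assigned on the long-duration melody path, so the later "if output_segments:" check can raise a NameError for text-only prompts or short requests unless the name is initialized to an empty list first. Also, each entry has already been reduced to a 2-D [channels, samples] tensor by the time it is concatenated, so torch.cat(..., dim=2) is out of range for a 2-D tensor and the except fallback always runs. A minimal sketch of the intended stitching, joining on the sample axis instead (our illustration with stand-in tensors, not part of this commit):

import torch

# stand-ins for two generated segments, each [batch=1, channels=1, samples]
output_segments = [torch.zeros(1, 1, 32000), torch.zeros(1, 1, 32000)]

# drop the batch dim, then concatenate along the sample axis (dim=1 of [channels, samples])
segments = [seg.detach().cpu().float()[0] for seg in output_segments]
output = torch.cat(segments, dim=1)
print(output.shape)  # torch.Size([1, 64000])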
audiocraft/utils/extend.py ADDED
@@ -0,0 +1,111 @@
+import torch
+import math
+from audiocraft.models import MusicGen
+import numpy as np
+
+
+def separate_audio_segments(audio, segment_duration=30):
+    sr, audio_data = audio[0], audio[1]
+
+    total_samples = len(audio_data)
+    segment_samples = sr * segment_duration
+
+    total_segments = math.ceil(total_samples / segment_samples)
+
+    segments = []
+
+    for segment_idx in range(total_segments):
+        print(f"Audio Input segment {segment_idx + 1} / {total_segments} \r")
+        start_sample = segment_idx * segment_samples
+        end_sample = (segment_idx + 1) * segment_samples
+
+        segment = audio_data[start_sample:end_sample]
+        segments.append((sr, segment))
+
+    return segments
+
+def generate_music_segments(text, melody, MODEL, duration: int = 10, segment_duration: int = 30):
+    # Split the input audio into segment_duration-sized chunks
+    melody_segments = separate_audio_segments(melody, segment_duration)
+
+    # Create a list to store the melody tensors for each segment
+    melodys = []
+    output_segments = []
+
+    # Calculate the total number of segments
+    total_segments = max(math.ceil(duration / segment_duration), 1)
+    print(f"Total segments to generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds")
+
+    # If melody_segments is shorter than total_segments, repeat segments until the total is reached
+    if len(melody_segments) < total_segments:
+        for i in range(total_segments - len(melody_segments)):
+            segment = melody_segments[i]
+            melody_segments.append(segment)
+        print(f"melody_segments: {len(melody_segments)} fixed")
+
+    # Iterate over the segments to create the list of melody tensors
+    for segment_idx in range(total_segments):
+        print(f"segment {segment_idx + 1} of {total_segments} \r")
+        sr, verse = melody_segments[segment_idx][0], torch.from_numpy(melody_segments[segment_idx][1]).to(MODEL.device).float().t().unsqueeze(0)
+
+        print(f"shape:{verse.shape} dim:{verse.dim()}")
+        if verse.dim() == 2:
+            verse = verse[None]
+        verse = verse[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
+        # Append the segment to the melodys list
+        melodys.append(verse)
+
+    for idx, verse in enumerate(melodys):
+        print(f"Generating New Melody Segment {idx + 1}: {text}\r")
+        output = MODEL.generate_with_chroma(
+            descriptions=[text],
+            melody_wavs=verse,
+            melody_sample_rate=sr,
+            progress=True
+        )
+
+        # Append the generated output to the list of segments
+        #output_segments.append(output[:, :segment_duration])
+        output_segments.append(output)
+        print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
+    return output_segments
+
+# Earlier draft, kept for reference:
+#def generate_music_segments(text, melody, duration, MODEL, segment_duration=30):
+#    sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
+
+#    # Create a list to store the melody tensors for each segment
+#    melodys = []
+
+#    # Calculate the total number of segments
+#    total_segments = math.ceil(melody.shape[1] / (sr * segment_duration))
+
+#    # Iterate over the segments
+#    for segment_idx in range(total_segments):
+#        print(f"segment {segment_idx + 1} / {total_segments} \r")
+#        start_frame = segment_idx * sr * segment_duration
+#        end_frame = (segment_idx + 1) * sr * segment_duration
+
+#        # Extract the segment from the melody tensor
+#        segment = melody[:, start_frame:end_frame]
+
+#        # Append the segment to the melodys list
+#        melodys.append(segment)
+
+#    output_segments = []
+
+#    for segment in melodys:
+#        output = MODEL.generate_with_chroma(
+#            descriptions=[text],
+#            melody_wavs=segment,
+#            melody_sample_rate=sr,
+#            progress=False
+#        )
+
+#        # Append the generated output to the list of segments
+#        output_segments.append(output[:, :segment_duration])
+
+#    return output_segments
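
For scale: at 32000 Hz with segment_duration=30, each chunk is 960000 samples, so a 75-second input yields ceil(75 / 30) = 3 segments. A minimal usage sketch of the new helper (the "melody" checkpoint name and the silent stand-in clip are assumptions for illustration; in the app the melody tuple comes from the Gradio audio input):

import numpy as np
from audiocraft.models import MusicGen
from audiocraft.utils.extend import generate_music_segments

model = MusicGen.get_pretrained("melody")
model.set_generation_params(duration=30)  # app.py sets this to the per-segment duration

sr = 32000
melody = (sr, np.zeros((sr * 60, 1), dtype=np.float32))  # 60 s silent stand-in for an uploaded clip

# a 60-second request, generated as two 30-second chroma-conditioned segments
segments = generate_music_segments("lofi piano", melody, model, duration=60, segment_duration=30)
print(len(segments))  # 2 tensors, each [batch, channels, samples]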
web-ui.bat ADDED
@@ -0,0 +1 @@
+py -m app
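
The new launcher runs the app as a module through the Windows py launcher; combined with the new --server_port default of 7859, the UI should come up on http://localhost:7859. An equivalent manual invocation from the repository root, with the port stated explicitly:

py -m app --server_port 7859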