Localsong commited on
Commit
12bbde9
·
verified ·
1 Parent(s): 510d544

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +9 -0
  2. gradio_app.py +15 -7
  3. merge_lora.py +196 -0
  4. train_lora.py +273 -0
  5. train_lora_encode_latents.py +103 -0
README.md CHANGED
@@ -44,6 +44,15 @@ The first generation will be slower due to torch.compile, then speed will increa
44
 
45
  The model was trained on vocals but not lyrics. Vocals will not have recognizable words.
46
 
 
 
 
 
 
 
 
 
 
47
  ## Credits
48
 
49
  This project builds upon the following open-source projects:
 
44
 
45
  The model was trained on vocals but not lyrics. Vocals will not have recognizable words.
46
 
47
+ ## LoRA Training
48
+
49
+ - Prepare folder of .mp3 files
50
+ - Run python train_lora_encode_latents.py --audio-dir=/path/to/your/mp3s --output-dir=latents to save the latents
51
+ - Run python train_lora.py --latents_dir=latents to train the LoRA. You may need to adjust learning rate, steps or batch size depending on your dataset etc.
52
+ - Run python merge_lora.py --lora-checkpoint=lora_step1000.safetensors --output-checkpoint=merged.safetensors to merge the LoRA checkpoint into the base model for inference
53
+ - Run python gradio_app.py --checkpoint=merged.safetensors to run the merged checkpoint for inference
54
+ - Test inference with tag "soundtrack"; Lora training uses this tag. Additional tags may work.
55
+
56
  ## Credits
57
 
58
  This project builds upon the following open-source projects:
gradio_app.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
3
  from typing import List, Tuple
4
  import uuid
5
  import json
 
6
  import gradio as gr
7
  import torch
8
  import torchaudio
@@ -83,7 +84,7 @@ rf_sampler: RF | None = None
83
  device: torch.device | None = None
84
  _available_tags: List[str] | None = None
85
 
86
- def load_resources() -> List[str]:
87
 
88
  torch.set_float32_matmul_precision('high')
89
 
@@ -107,7 +108,6 @@ def load_resources() -> List[str]:
107
  max_tags=8,
108
  ).to(device)
109
 
110
- checkpoint_path = "checkpoints/checkpoint_461260.safetensors"
111
  print(f"Loading checkpoint: {checkpoint_path}")
112
 
113
  state_dict = load_file(checkpoint_path, device=str(device))
@@ -139,7 +139,6 @@ def generate_audio(
139
  sample_steps: int,
140
  ) -> Tuple[Tuple[int, object], str]:
141
 
142
- load_resources()
143
  assert model is not None and vae is not None and rf_sampler is not None and device is not None
144
 
145
  if not tags:
@@ -178,8 +177,8 @@ def generate_audio(
178
 
179
  return (sr, audio_numpy), str(output_path)
180
 
181
- def build_interface() -> gr.Blocks:
182
- available_tags = load_resources()
183
 
184
  # Define preset tag combinations
185
  presets = [
@@ -259,7 +258,16 @@ def build_interface() -> gr.Blocks:
259
 
260
  return demo
261
 
262
- demo = build_interface()
263
-
264
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
265
  demo.launch()
 
3
  from typing import List, Tuple
4
  import uuid
5
  import json
6
+ import argparse
7
  import gradio as gr
8
  import torch
9
  import torchaudio
 
84
  device: torch.device | None = None
85
  _available_tags: List[str] | None = None
86
 
87
+ def load_resources(checkpoint_path) -> List[str]:
88
 
89
  torch.set_float32_matmul_precision('high')
90
 
 
108
  max_tags=8,
109
  ).to(device)
110
 
 
111
  print(f"Loading checkpoint: {checkpoint_path}")
112
 
113
  state_dict = load_file(checkpoint_path, device=str(device))
 
139
  sample_steps: int,
140
  ) -> Tuple[Tuple[int, object], str]:
141
 
 
142
  assert model is not None and vae is not None and rf_sampler is not None and device is not None
143
 
144
  if not tags:
 
177
 
178
  return (sr, audio_numpy), str(output_path)
179
 
180
+ def build_interface(checkpoint_path) -> gr.Blocks:
181
+ available_tags = load_resources(checkpoint_path)
182
 
183
  # Define preset tag combinations
184
  presets = [
 
258
 
259
  return demo
260
 
 
 
261
  if __name__ == "__main__":
262
+ parser = argparse.ArgumentParser(description="LocalSong Gradio Interface")
263
+ parser.add_argument(
264
+ "--checkpoint",
265
+ type=str,
266
+ default="checkpoints/checkpoint_461260.safetensors",
267
+ help="Path to the model checkpoint"
268
+ )
269
+ args = parser.parse_args()
270
+
271
+ demo = build_interface(args.checkpoint)
272
+
273
  demo.launch()
merge_lora.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import argparse
4
+ from safetensors.torch import load_file, save_file
5
+ from model import LocalSongModel
6
+ from pathlib import Path
7
+
8
+ class LoRALinear(nn.Module):
9
+ def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
10
+ super().__init__()
11
+ self.original_linear = original_linear
12
+ self.rank = rank
13
+ self.alpha = alpha
14
+ self.scaling = alpha / rank
15
+
16
+ self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
17
+ self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
18
+
19
+ nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
20
+ nn.init.zeros_(self.lora_B)
21
+
22
+ self.original_linear.weight.requires_grad = False
23
+ if self.original_linear.bias is not None:
24
+ self.original_linear.bias.requires_grad = False
25
+
26
+ def forward(self, x):
27
+ result = self.original_linear(x)
28
+ lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
29
+ return result + lora_out
30
+
31
+ def inject_lora(model, rank=8, alpha=16.0, target_modules=['qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'], device=None):
32
+ if device is None:
33
+ device = next(model.parameters()).device
34
+
35
+ for name, module in model.named_modules():
36
+ if isinstance(module, nn.Linear):
37
+ if any(target in name for target in target_modules):
38
+ *parent_path, attr_name = name.split('.')
39
+ parent = model
40
+ for p in parent_path:
41
+ parent = getattr(parent, p)
42
+
43
+ lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
44
+ lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
45
+ lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
46
+ setattr(parent, attr_name, lora_layer)
47
+
48
+ return model
49
+
50
+ def load_lora_weights(model, lora_path, device):
51
+ print(f"Loading LoRA from {lora_path}")
52
+ lora_state_dict = load_file(lora_path, device=str(device))
53
+
54
+ loaded_count = 0
55
+ for name, module in model.named_modules():
56
+ if isinstance(module, LoRALinear):
57
+ lora_a_key = f"{name}.lora_A"
58
+ lora_b_key = f"{name}.lora_B"
59
+ if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
60
+ module.lora_A.data = lora_state_dict[lora_a_key].to(device)
61
+ module.lora_B.data = lora_state_dict[lora_b_key].to(device)
62
+ loaded_count += 2
63
+
64
+ print(f"Loaded {loaded_count} LoRA parameters")
65
+
66
+ def merge_lora_into_model(model):
67
+ """
68
+ Merge LoRA weights into the base model weights.
69
+ For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B) * scaling
70
+ """
71
+ print("\nMerging LoRA weights into base model...")
72
+ merged_count = 0
73
+
74
+ for name, module in model.named_modules():
75
+ if isinstance(module, LoRALinear):
76
+ lora_delta = (module.lora_A @ module.lora_B) * module.scaling
77
+
78
+ with torch.no_grad():
79
+ module.original_linear.weight.data += lora_delta.T
80
+
81
+ merged_count += 1
82
+
83
+ print(f"Merged {merged_count} LoRA layers into base weights")
84
+
85
+ def extract_base_weights(model):
86
+ """
87
+ Extract the merged weights from LoRALinear modules back into a regular state dict.
88
+ """
89
+ print("\nExtracting merged weights...")
90
+ new_state_dict = {}
91
+
92
+ for name, module in model.named_modules():
93
+ if isinstance(module, LoRALinear):
94
+ original_name_weight = f"{name}.weight"
95
+ original_name_bias = f"{name}.bias"
96
+
97
+ new_state_dict[original_name_weight] = module.original_linear.weight.data
98
+ if module.original_linear.bias is not None:
99
+ new_state_dict[original_name_bias] = module.original_linear.bias.data
100
+
101
+ # Copy over all non-LoRA parameters
102
+ for name, param in model.named_parameters():
103
+ if 'lora_A' not in name and 'lora_B' not in name and 'original_linear' not in name:
104
+ new_state_dict[name] = param.data
105
+
106
+ print(f"Extracted {len(new_state_dict)} parameters")
107
+ return new_state_dict
108
+
109
+ def main():
110
+ parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
111
+ parser.add_argument(
112
+ "--base-checkpoint",
113
+ type=str,
114
+ default="checkpoints/checkpoint_461260.safetensors",
115
+ help="Path to the base model checkpoint"
116
+ )
117
+ parser.add_argument(
118
+ "--lora-checkpoint",
119
+ type=str,
120
+ default="lora.safetensors",
121
+ help="Path to the LoRA checkpoint"
122
+ )
123
+ parser.add_argument(
124
+ "--output-checkpoint",
125
+ type=str,
126
+ default="checkpoints/checkpoint_461260_merged_lora.safetensors",
127
+ help="Path to save the merged checkpoint"
128
+ )
129
+ args = parser.parse_args()
130
+
131
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
132
+ print(f"Using device: {device}")
133
+
134
+ # Configuration
135
+ base_checkpoint = args.base_checkpoint
136
+ lora_checkpoint = args.lora_checkpoint
137
+ output_checkpoint = args.output_checkpoint
138
+
139
+ lora_rank = 16
140
+ lora_alpha = 16.0
141
+
142
+ print(f"\nBase checkpoint: {base_checkpoint}")
143
+ print(f"LoRA checkpoint: {lora_checkpoint}")
144
+ print(f"Output checkpoint: {output_checkpoint}")
145
+ print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")
146
+
147
+ # Load base model
148
+ print("\nLoading base model...")
149
+ model = LocalSongModel(
150
+ in_channels=8,
151
+ num_groups=16,
152
+ hidden_size=1024,
153
+ decoder_hidden_size=2048,
154
+ num_blocks=36,
155
+ patch_size=(16, 1),
156
+ num_classes=2304,
157
+ max_tags=8,
158
+ ).to(device)
159
+
160
+ state_dict = load_file(base_checkpoint, device=str(device))
161
+ model.load_state_dict(state_dict, strict=True)
162
+ print("Base model loaded")
163
+
164
+ print("\nInjecting LoRA layers...")
165
+ model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)
166
+
167
+ load_lora_weights(model, lora_checkpoint, device)
168
+
169
+ merge_lora_into_model(model)
170
+
171
+ merged_state_dict = extract_base_weights(model)
172
+
173
+ print(f"\nSaving merged checkpoint to {output_checkpoint}...")
174
+ save_file(merged_state_dict, output_checkpoint)
175
+ print("✓ Merged checkpoint saved successfully!")
176
+
177
+ print("\nVerifying merged checkpoint...")
178
+ test_model = LocalSongModel(
179
+ in_channels=8,
180
+ num_groups=16,
181
+ hidden_size=1024,
182
+ decoder_hidden_size=2048,
183
+ num_blocks=36,
184
+ patch_size=(16, 1),
185
+ num_classes=2304,
186
+ max_tags=8,
187
+ ).to(device)
188
+
189
+ merged_loaded = load_file(output_checkpoint, device=str(device))
190
+ test_model.load_state_dict(merged_loaded, strict=True)
191
+ print("✓ Merged checkpoint verified successfully!")
192
+
193
+ print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")
194
+
195
+ if __name__ == '__main__':
196
+ main()
train_lora.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from pathlib import Path
6
+ import argparse
7
+ from tqdm import tqdm
8
+ from safetensors.torch import save_file, load_file
9
+ from collections import deque
10
+ from model import LocalSongModel
11
+
12
+ HARDCODED_TAGS = [1908]
13
+ torch.set_float32_matmul_precision('high')
14
+
15
+ class LoRALinear(nn.Module):
16
+ def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
17
+ super().__init__()
18
+ self.original_linear = original_linear
19
+ self.rank = rank
20
+ self.alpha = alpha
21
+ self.scaling = alpha / rank
22
+
23
+ self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
24
+ self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
25
+
26
+ nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
27
+ nn.init.zeros_(self.lora_B)
28
+
29
+ self.original_linear.weight.requires_grad = False
30
+ if self.original_linear.bias is not None:
31
+ self.original_linear.bias.requires_grad = False
32
+
33
+ def forward(self, x):
34
+ result = self.original_linear(x)
35
+ lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
36
+ return result + lora_out
37
+
38
+ def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0, target_modules=['qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'], device=None):
39
+ """Inject LoRA layers into the model."""
40
+
41
+ lora_modules = []
42
+
43
+ if device is None:
44
+ device = next(model.parameters()).device
45
+
46
+ for name, module in model.named_modules():
47
+
48
+ if isinstance(module, nn.Linear):
49
+
50
+ if any(target in name for target in target_modules):
51
+
52
+ *parent_path, attr_name = name.split('.')
53
+ parent = model
54
+ for p in parent_path:
55
+ parent = getattr(parent, p)
56
+
57
+ lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
58
+
59
+ lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
60
+ lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
61
+ setattr(parent, attr_name, lora_layer)
62
+ lora_modules.append(name)
63
+
64
+ print(f"Injected LoRA into {len(lora_modules)} layers:")
65
+ for name in lora_modules[:5]:
66
+ print(f" - {name}")
67
+ if len(lora_modules) > 5:
68
+ print(f" ... and {len(lora_modules) - 5} more")
69
+
70
+ return model
71
+
72
+ def get_lora_parameters(model):
73
+ """Extract only LoRA parameters for optimization."""
74
+ lora_params = []
75
+ for module in model.modules():
76
+ if isinstance(module, LoRALinear):
77
+ lora_params.extend([module.lora_A, module.lora_B])
78
+ return lora_params
79
+
80
+ def save_lora_weights(model, output_path):
81
+ """Save LoRA weights to a safetensors file."""
82
+ lora_state_dict = {}
83
+
84
+ for name, module in model.named_modules():
85
+ if isinstance(module, LoRALinear):
86
+ lora_state_dict[f"{name}.lora_A"] = module.lora_A
87
+ lora_state_dict[f"{name}.lora_B"] = module.lora_B
88
+
89
+ save_file(lora_state_dict, output_path)
90
+ print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
91
+
92
+ class LatentDataset(Dataset):
93
+ """Dataset for pre-encoded latents."""
94
+
95
+ def __init__(self, latents_dir: str):
96
+ self.latents_dir = Path(latents_dir)
97
+
98
+ self.latent_files = sorted(list(self.latents_dir.glob("*.pt")))
99
+
100
+ if len(self.latent_files) == 0:
101
+ raise ValueError(f"No .pt files found in {latents_dir}")
102
+
103
+ print(f"Found {len(self.latent_files)} latent files")
104
+
105
+ def __len__(self):
106
+ return len(self.latent_files)
107
+
108
+ def __getitem__(self, idx):
109
+ latent = torch.load(self.latent_files[idx])
110
+
111
+ if latent.ndim == 3:
112
+ latent = latent.unsqueeze(0)
113
+
114
+ return latent
115
+
116
+ class RectifiedFlow:
117
+ """Simplified rectified flow matching."""
118
+
119
+ def __init__(self, model):
120
+ self.model = model
121
+
122
+ def forward(self, x, cond):
123
+ """Compute flow matching loss."""
124
+ b = x.size(0)
125
+
126
+ nt = torch.randn((b,), device=x.device)
127
+ t = torch.sigmoid(nt)
128
+
129
+ texp = t.view([b, *([1] * len(x.shape[1:]))])
130
+ z1 = torch.randn_like(x)
131
+ zt = (1 - texp) * x + texp * z1
132
+
133
+ vtheta = self.model(zt, t, cond)
134
+
135
+ target = z1 - x
136
+ loss = ((vtheta - target) ** 2).mean()
137
+
138
+ return loss
139
+
140
+ def collate_fn(batch, subsection_length=1024):
141
+ """Custom collate function to sample random subsections."""
142
+ sampled_latents = []
143
+
144
+ for latent in batch:
145
+ if latent.ndim == 3:
146
+ latent = latent.unsqueeze(0)
147
+
148
+ _, _, _, width = latent.shape
149
+
150
+ if width < subsection_length:
151
+ # Pad if too short
152
+ pad_amount = subsection_length - width
153
+ latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
154
+ else:
155
+ # Randomly sample subsection
156
+ max_start = width - subsection_length
157
+ start_idx = torch.randint(0, max_start + 1, (1,)).item()
158
+ latent = latent[:, :, :, start_idx:start_idx + subsection_length]
159
+
160
+ sampled_latents.append(latent.squeeze(0))
161
+
162
+ batch_latents = torch.stack(sampled_latents)
163
+
164
+ batch_tags = [HARDCODED_TAGS] * len(batch)
165
+
166
+ return batch_latents, batch_tags
167
+
168
+ def main():
169
+ parser = argparse.ArgumentParser(description='LoRA training for LocalSong model with embedding training')
170
+
171
+ parser.add_argument('--latents_dir', type=str, required=True,
172
+ help='Directory containing VAE-encoded latents (.pt files)')
173
+
174
+ parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
175
+ help='Path to base model checkpoint')
176
+ parser.add_argument('--lora_rank', type=int, default=16,
177
+ help='LoRA rank')
178
+ parser.add_argument('--lora_alpha', type=float, default=16,
179
+ help='LoRA alpha (scaling factor)')
180
+ parser.add_argument('--batch_size', type=int, default=16,
181
+ help='Batch size')
182
+ parser.add_argument('--lr', type=float, default=2e-4,
183
+ help='Learning rate')
184
+ parser.add_argument('--steps', type=int, default=1500,
185
+ help='Number of training steps')
186
+ parser.add_argument('--subsection_length', type=int, default=512,
187
+ help='Latent subsection length')
188
+ parser.add_argument('--output', type=str, default='lora.safetensors',
189
+ help='Output path for LoRA weights')
190
+ parser.add_argument('--save_every', type=int, default=500,
191
+ help='Save checkpoint every N steps')
192
+
193
+ args = parser.parse_args()
194
+
195
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
196
+ print(f"Using device: {device}")
197
+
198
+ print(f"Using hardcoded tags: {HARDCODED_TAGS}")
199
+
200
+ print(f"Loading base model from {args.checkpoint}")
201
+ model = LocalSongModel(
202
+ in_channels=8,
203
+ num_groups=16,
204
+ hidden_size=1024,
205
+ decoder_hidden_size=2048,
206
+ num_blocks=36,
207
+ patch_size=(16, 1),
208
+ num_classes=2304,
209
+ max_tags=8,
210
+ )
211
+
212
+ print(f"Loading checkpoint from {args.checkpoint}")
213
+ state_dict = load_file(args.checkpoint)
214
+ model.load_state_dict(state_dict, strict=True)
215
+ print("Base model loaded")
216
+
217
+ model = model.to(device)
218
+ model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)
219
+
220
+ model.train()
221
+
222
+ lora_params = get_lora_parameters(model)
223
+ optimizer = optim.Adam(lora_params, lr=args.lr)
224
+ print(f"Training {len(lora_params)} LoRA parameters")
225
+
226
+ dataset = LatentDataset(args.latents_dir)
227
+ dataloader = DataLoader(
228
+ dataset,
229
+ batch_size=args.batch_size,
230
+ shuffle=True,
231
+ num_workers=0,
232
+ collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
233
+ )
234
+
235
+ rf = RectifiedFlow(model)
236
+
237
+ print("\nStarting training...")
238
+ step = 0
239
+ pbar = tqdm(total=args.steps, desc="Training")
240
+
241
+ loss_history = deque(maxlen=50)
242
+
243
+ while step < args.steps:
244
+ for batch_latents, batch_tags in dataloader:
245
+ batch_latents = batch_latents.to(device)
246
+
247
+ optimizer.zero_grad()
248
+ loss = rf.forward(batch_latents, batch_tags)
249
+
250
+ loss.backward()
251
+ torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
252
+ optimizer.step()
253
+
254
+ # Track loss and compute average
255
+ loss_history.append(loss.item())
256
+ avg_loss = sum(loss_history) / len(loss_history)
257
+
258
+ pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
259
+ pbar.update(1)
260
+ step += 1
261
+
262
+ if step % args.save_every == 0:
263
+ save_path = args.output.replace('.safetensors', f'_step{step}.safetensors')
264
+ save_lora_weights(model, save_path)
265
+
266
+ if step >= args.steps:
267
+ break
268
+
269
+ save_lora_weights(model, args.output)
270
+ print(f"\nTraining complete! LoRA weights saved to {args.output}")
271
+
272
+ if __name__ == '__main__':
273
+ main()
train_lora_encode_latents.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from pathlib import Path
4
+ import argparse
5
+ from tqdm import tqdm
6
+ from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
7
+
8
+ class AudioVAE:
9
+ def __init__(self, device: torch.device):
10
+ self.model = MusicDCAE().to(device)
11
+ self.model.eval()
12
+ self.device = device
13
+ self.latent_mean = torch.tensor(
14
+ [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
15
+ device=device,
16
+ ).view(1, -1, 1, 1)
17
+ self.latent_std = torch.tensor(
18
+ [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
19
+ device=device,
20
+ ).view(1, -1, 1, 1)
21
+
22
+ def encode(self, audio):
23
+
24
+ with torch.no_grad():
25
+ audio_lengths = torch.tensor([audio.shape[2]] * audio.shape[0]).to(self.device)
26
+ latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
27
+ latents = (latents - self.latent_mean) / self.latent_std
28
+ return latents
29
+
30
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
31
+ with torch.no_grad():
32
+ latents = latents * self.latent_std + self.latent_mean
33
+ _, audio_list = self.model.decode(latents, sr=48000)
34
+ audio_batch = torch.stack(audio_list).to(self.device)
35
+ return audio_batch
36
+
37
+ def load_audio(audio_path, target_sr=48000):
38
+ """Load and preprocess audio file."""
39
+ audio, sr = torchaudio.load(audio_path)
40
+
41
+ if audio.shape[0] == 1:
42
+ audio = audio.repeat(2, 1)
43
+ elif audio.shape[0] > 2:
44
+ audio = audio[:2]
45
+
46
+ if sr != target_sr:
47
+ resampler = torchaudio.transforms.Resample(sr, target_sr)
48
+ audio = resampler(audio)
49
+
50
+ return audio
51
+
52
+
53
+ def main():
54
+ parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')
55
+
56
+ parser.add_argument('--audio-dir', type=str, required=True,
57
+ help='Directory containing audio files')
58
+ parser.add_argument('--output-dir', type=str, default="latents",
59
+ help='Directory to save encoded latents')
60
+
61
+ args = parser.parse_args()
62
+
63
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64
+ print(f"Using device: {device}")
65
+
66
+ output_dir = Path(args.output_dir)
67
+ output_dir.mkdir(parents=True, exist_ok=True)
68
+
69
+ audio_dir = Path(args.audio_dir)
70
+ audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.ogg', '*.m4a']
71
+ audio_files = []
72
+ for ext in audio_extensions:
73
+ audio_files.extend(list(audio_dir.glob(ext)))
74
+ audio_files = sorted(audio_files)
75
+
76
+ if len(audio_files) == 0:
77
+ raise ValueError(f"No audio files found in {args.audio_dir}")
78
+
79
+ print(f"Found {len(audio_files)} audio files")
80
+
81
+ vae = AudioVAE(device)
82
+ print("VAE loaded")
83
+
84
+ # Encode each audio file
85
+ print("\nEncoding audio files...")
86
+ for audio_path in tqdm(audio_files, desc="Encoding"):
87
+ try:
88
+ audio = load_audio(audio_path)
89
+ audio = audio.unsqueeze(0).to(device)
90
+ latents = vae.encode(audio)
91
+ latents = latents.squeeze(0)
92
+
93
+ output_path = output_dir / f"{audio_path.stem}.pt"
94
+ torch.save(latents.cpu(), output_path)
95
+
96
+ except Exception as e:
97
+ print(f"\nError encoding {audio_path.name}: {e}")
98
+ continue
99
+
100
+ print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")
101
+
102
+ if __name__ == '__main__':
103
+ main()