Upload 5 files
Browse files- README.md +9 -0
- gradio_app.py +15 -7
- merge_lora.py +196 -0
- train_lora.py +273 -0
- train_lora_encode_latents.py +103 -0
README.md
CHANGED
|
@@ -44,6 +44,15 @@ The first generation will be slower due to torch.compile, then speed will increa
|
|
| 44 |
|
| 45 |
The model was trained on vocals but not lyrics. Vocals will not have recognizable words.
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
## Credits
|
| 48 |
|
| 49 |
This project builds upon the following open-source projects:
|
|
|
|
| 44 |
|
| 45 |
The model was trained on vocals but not lyrics. Vocals will not have recognizable words.
|
| 46 |
|
| 47 |
+
## LoRA Training
|
| 48 |
+
|
| 49 |
+
- Prepare folder of .mp3 files
|
| 50 |
+
- Run python train_lora_encode_latents.py --audio-dir=/path/to/your/mp3s --output-dir=latents to save the latents
|
| 51 |
+
- Run python train_lora.py --latents_dir=latents to train the LoRA. You may need to adjust learning rate, steps or batch size depending on your dataset etc.
|
| 52 |
+
- Run python merge_lora.py --lora-checkpoint=lora_step1000.safetensors --output-checkpoint=merged.safetensors to merge the LoRA checkpoint into the base model for inference
|
| 53 |
+
- Run python gradio_app.py --checkpoint=merged.safetensors to run the merged checkpoint for inference
|
| 54 |
+
- Test inference with tag "soundtrack"; Lora training uses this tag. Additional tags may work.
|
| 55 |
+
|
| 56 |
## Credits
|
| 57 |
|
| 58 |
This project builds upon the following open-source projects:
|
gradio_app.py
CHANGED
|
@@ -3,6 +3,7 @@ from pathlib import Path
|
|
| 3 |
from typing import List, Tuple
|
| 4 |
import uuid
|
| 5 |
import json
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import torch
|
| 8 |
import torchaudio
|
|
@@ -83,7 +84,7 @@ rf_sampler: RF | None = None
|
|
| 83 |
device: torch.device | None = None
|
| 84 |
_available_tags: List[str] | None = None
|
| 85 |
|
| 86 |
-
def load_resources() -> List[str]:
|
| 87 |
|
| 88 |
torch.set_float32_matmul_precision('high')
|
| 89 |
|
|
@@ -107,7 +108,6 @@ def load_resources() -> List[str]:
|
|
| 107 |
max_tags=8,
|
| 108 |
).to(device)
|
| 109 |
|
| 110 |
-
checkpoint_path = "checkpoints/checkpoint_461260.safetensors"
|
| 111 |
print(f"Loading checkpoint: {checkpoint_path}")
|
| 112 |
|
| 113 |
state_dict = load_file(checkpoint_path, device=str(device))
|
|
@@ -139,7 +139,6 @@ def generate_audio(
|
|
| 139 |
sample_steps: int,
|
| 140 |
) -> Tuple[Tuple[int, object], str]:
|
| 141 |
|
| 142 |
-
load_resources()
|
| 143 |
assert model is not None and vae is not None and rf_sampler is not None and device is not None
|
| 144 |
|
| 145 |
if not tags:
|
|
@@ -178,8 +177,8 @@ def generate_audio(
|
|
| 178 |
|
| 179 |
return (sr, audio_numpy), str(output_path)
|
| 180 |
|
| 181 |
-
def build_interface() -> gr.Blocks:
|
| 182 |
-
available_tags = load_resources()
|
| 183 |
|
| 184 |
# Define preset tag combinations
|
| 185 |
presets = [
|
|
@@ -259,7 +258,16 @@ def build_interface() -> gr.Blocks:
|
|
| 259 |
|
| 260 |
return demo
|
| 261 |
|
| 262 |
-
demo = build_interface()
|
| 263 |
-
|
| 264 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
demo.launch()
|
|
|
|
| 3 |
from typing import List, Tuple
|
| 4 |
import uuid
|
| 5 |
import json
|
| 6 |
+
import argparse
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
| 9 |
import torchaudio
|
|
|
|
| 84 |
device: torch.device | None = None
|
| 85 |
_available_tags: List[str] | None = None
|
| 86 |
|
| 87 |
+
def load_resources(checkpoint_path) -> List[str]:
|
| 88 |
|
| 89 |
torch.set_float32_matmul_precision('high')
|
| 90 |
|
|
|
|
| 108 |
max_tags=8,
|
| 109 |
).to(device)
|
| 110 |
|
|
|
|
| 111 |
print(f"Loading checkpoint: {checkpoint_path}")
|
| 112 |
|
| 113 |
state_dict = load_file(checkpoint_path, device=str(device))
|
|
|
|
| 139 |
sample_steps: int,
|
| 140 |
) -> Tuple[Tuple[int, object], str]:
|
| 141 |
|
|
|
|
| 142 |
assert model is not None and vae is not None and rf_sampler is not None and device is not None
|
| 143 |
|
| 144 |
if not tags:
|
|
|
|
| 177 |
|
| 178 |
return (sr, audio_numpy), str(output_path)
|
| 179 |
|
| 180 |
+
def build_interface(checkpoint_path) -> gr.Blocks:
|
| 181 |
+
available_tags = load_resources(checkpoint_path)
|
| 182 |
|
| 183 |
# Define preset tag combinations
|
| 184 |
presets = [
|
|
|
|
| 258 |
|
| 259 |
return demo
|
| 260 |
|
|
|
|
|
|
|
| 261 |
if __name__ == "__main__":
|
| 262 |
+
parser = argparse.ArgumentParser(description="LocalSong Gradio Interface")
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--checkpoint",
|
| 265 |
+
type=str,
|
| 266 |
+
default="checkpoints/checkpoint_461260.safetensors",
|
| 267 |
+
help="Path to the model checkpoint"
|
| 268 |
+
)
|
| 269 |
+
args = parser.parse_args()
|
| 270 |
+
|
| 271 |
+
demo = build_interface(args.checkpoint)
|
| 272 |
+
|
| 273 |
demo.launch()
|
merge_lora.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import argparse
|
| 4 |
+
from safetensors.torch import load_file, save_file
|
| 5 |
+
from model import LocalSongModel
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
class LoRALinear(nn.Module):
|
| 9 |
+
def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.original_linear = original_linear
|
| 12 |
+
self.rank = rank
|
| 13 |
+
self.alpha = alpha
|
| 14 |
+
self.scaling = alpha / rank
|
| 15 |
+
|
| 16 |
+
self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
|
| 17 |
+
self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
|
| 18 |
+
|
| 19 |
+
nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
|
| 20 |
+
nn.init.zeros_(self.lora_B)
|
| 21 |
+
|
| 22 |
+
self.original_linear.weight.requires_grad = False
|
| 23 |
+
if self.original_linear.bias is not None:
|
| 24 |
+
self.original_linear.bias.requires_grad = False
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
result = self.original_linear(x)
|
| 28 |
+
lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
|
| 29 |
+
return result + lora_out
|
| 30 |
+
|
| 31 |
+
def inject_lora(model, rank=8, alpha=16.0, target_modules=['qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'], device=None):
|
| 32 |
+
if device is None:
|
| 33 |
+
device = next(model.parameters()).device
|
| 34 |
+
|
| 35 |
+
for name, module in model.named_modules():
|
| 36 |
+
if isinstance(module, nn.Linear):
|
| 37 |
+
if any(target in name for target in target_modules):
|
| 38 |
+
*parent_path, attr_name = name.split('.')
|
| 39 |
+
parent = model
|
| 40 |
+
for p in parent_path:
|
| 41 |
+
parent = getattr(parent, p)
|
| 42 |
+
|
| 43 |
+
lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
|
| 44 |
+
lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
|
| 45 |
+
lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
|
| 46 |
+
setattr(parent, attr_name, lora_layer)
|
| 47 |
+
|
| 48 |
+
return model
|
| 49 |
+
|
| 50 |
+
def load_lora_weights(model, lora_path, device):
|
| 51 |
+
print(f"Loading LoRA from {lora_path}")
|
| 52 |
+
lora_state_dict = load_file(lora_path, device=str(device))
|
| 53 |
+
|
| 54 |
+
loaded_count = 0
|
| 55 |
+
for name, module in model.named_modules():
|
| 56 |
+
if isinstance(module, LoRALinear):
|
| 57 |
+
lora_a_key = f"{name}.lora_A"
|
| 58 |
+
lora_b_key = f"{name}.lora_B"
|
| 59 |
+
if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
|
| 60 |
+
module.lora_A.data = lora_state_dict[lora_a_key].to(device)
|
| 61 |
+
module.lora_B.data = lora_state_dict[lora_b_key].to(device)
|
| 62 |
+
loaded_count += 2
|
| 63 |
+
|
| 64 |
+
print(f"Loaded {loaded_count} LoRA parameters")
|
| 65 |
+
|
| 66 |
+
def merge_lora_into_model(model):
|
| 67 |
+
"""
|
| 68 |
+
Merge LoRA weights into the base model weights.
|
| 69 |
+
For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B) * scaling
|
| 70 |
+
"""
|
| 71 |
+
print("\nMerging LoRA weights into base model...")
|
| 72 |
+
merged_count = 0
|
| 73 |
+
|
| 74 |
+
for name, module in model.named_modules():
|
| 75 |
+
if isinstance(module, LoRALinear):
|
| 76 |
+
lora_delta = (module.lora_A @ module.lora_B) * module.scaling
|
| 77 |
+
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
module.original_linear.weight.data += lora_delta.T
|
| 80 |
+
|
| 81 |
+
merged_count += 1
|
| 82 |
+
|
| 83 |
+
print(f"Merged {merged_count} LoRA layers into base weights")
|
| 84 |
+
|
| 85 |
+
def extract_base_weights(model):
|
| 86 |
+
"""
|
| 87 |
+
Extract the merged weights from LoRALinear modules back into a regular state dict.
|
| 88 |
+
"""
|
| 89 |
+
print("\nExtracting merged weights...")
|
| 90 |
+
new_state_dict = {}
|
| 91 |
+
|
| 92 |
+
for name, module in model.named_modules():
|
| 93 |
+
if isinstance(module, LoRALinear):
|
| 94 |
+
original_name_weight = f"{name}.weight"
|
| 95 |
+
original_name_bias = f"{name}.bias"
|
| 96 |
+
|
| 97 |
+
new_state_dict[original_name_weight] = module.original_linear.weight.data
|
| 98 |
+
if module.original_linear.bias is not None:
|
| 99 |
+
new_state_dict[original_name_bias] = module.original_linear.bias.data
|
| 100 |
+
|
| 101 |
+
# Copy over all non-LoRA parameters
|
| 102 |
+
for name, param in model.named_parameters():
|
| 103 |
+
if 'lora_A' not in name and 'lora_B' not in name and 'original_linear' not in name:
|
| 104 |
+
new_state_dict[name] = param.data
|
| 105 |
+
|
| 106 |
+
print(f"Extracted {len(new_state_dict)} parameters")
|
| 107 |
+
return new_state_dict
|
| 108 |
+
|
| 109 |
+
def main():
|
| 110 |
+
parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--base-checkpoint",
|
| 113 |
+
type=str,
|
| 114 |
+
default="checkpoints/checkpoint_461260.safetensors",
|
| 115 |
+
help="Path to the base model checkpoint"
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--lora-checkpoint",
|
| 119 |
+
type=str,
|
| 120 |
+
default="lora.safetensors",
|
| 121 |
+
help="Path to the LoRA checkpoint"
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--output-checkpoint",
|
| 125 |
+
type=str,
|
| 126 |
+
default="checkpoints/checkpoint_461260_merged_lora.safetensors",
|
| 127 |
+
help="Path to save the merged checkpoint"
|
| 128 |
+
)
|
| 129 |
+
args = parser.parse_args()
|
| 130 |
+
|
| 131 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 132 |
+
print(f"Using device: {device}")
|
| 133 |
+
|
| 134 |
+
# Configuration
|
| 135 |
+
base_checkpoint = args.base_checkpoint
|
| 136 |
+
lora_checkpoint = args.lora_checkpoint
|
| 137 |
+
output_checkpoint = args.output_checkpoint
|
| 138 |
+
|
| 139 |
+
lora_rank = 16
|
| 140 |
+
lora_alpha = 16.0
|
| 141 |
+
|
| 142 |
+
print(f"\nBase checkpoint: {base_checkpoint}")
|
| 143 |
+
print(f"LoRA checkpoint: {lora_checkpoint}")
|
| 144 |
+
print(f"Output checkpoint: {output_checkpoint}")
|
| 145 |
+
print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")
|
| 146 |
+
|
| 147 |
+
# Load base model
|
| 148 |
+
print("\nLoading base model...")
|
| 149 |
+
model = LocalSongModel(
|
| 150 |
+
in_channels=8,
|
| 151 |
+
num_groups=16,
|
| 152 |
+
hidden_size=1024,
|
| 153 |
+
decoder_hidden_size=2048,
|
| 154 |
+
num_blocks=36,
|
| 155 |
+
patch_size=(16, 1),
|
| 156 |
+
num_classes=2304,
|
| 157 |
+
max_tags=8,
|
| 158 |
+
).to(device)
|
| 159 |
+
|
| 160 |
+
state_dict = load_file(base_checkpoint, device=str(device))
|
| 161 |
+
model.load_state_dict(state_dict, strict=True)
|
| 162 |
+
print("Base model loaded")
|
| 163 |
+
|
| 164 |
+
print("\nInjecting LoRA layers...")
|
| 165 |
+
model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)
|
| 166 |
+
|
| 167 |
+
load_lora_weights(model, lora_checkpoint, device)
|
| 168 |
+
|
| 169 |
+
merge_lora_into_model(model)
|
| 170 |
+
|
| 171 |
+
merged_state_dict = extract_base_weights(model)
|
| 172 |
+
|
| 173 |
+
print(f"\nSaving merged checkpoint to {output_checkpoint}...")
|
| 174 |
+
save_file(merged_state_dict, output_checkpoint)
|
| 175 |
+
print("✓ Merged checkpoint saved successfully!")
|
| 176 |
+
|
| 177 |
+
print("\nVerifying merged checkpoint...")
|
| 178 |
+
test_model = LocalSongModel(
|
| 179 |
+
in_channels=8,
|
| 180 |
+
num_groups=16,
|
| 181 |
+
hidden_size=1024,
|
| 182 |
+
decoder_hidden_size=2048,
|
| 183 |
+
num_blocks=36,
|
| 184 |
+
patch_size=(16, 1),
|
| 185 |
+
num_classes=2304,
|
| 186 |
+
max_tags=8,
|
| 187 |
+
).to(device)
|
| 188 |
+
|
| 189 |
+
merged_loaded = load_file(output_checkpoint, device=str(device))
|
| 190 |
+
test_model.load_state_dict(merged_loaded, strict=True)
|
| 191 |
+
print("✓ Merged checkpoint verified successfully!")
|
| 192 |
+
|
| 193 |
+
print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")
|
| 194 |
+
|
| 195 |
+
if __name__ == '__main__':
|
| 196 |
+
main()
|
train_lora.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import argparse
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from safetensors.torch import save_file, load_file
|
| 9 |
+
from collections import deque
|
| 10 |
+
from model import LocalSongModel
|
| 11 |
+
|
| 12 |
+
HARDCODED_TAGS = [1908]
|
| 13 |
+
torch.set_float32_matmul_precision('high')
|
| 14 |
+
|
| 15 |
+
class LoRALinear(nn.Module):
|
| 16 |
+
def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.original_linear = original_linear
|
| 19 |
+
self.rank = rank
|
| 20 |
+
self.alpha = alpha
|
| 21 |
+
self.scaling = alpha / rank
|
| 22 |
+
|
| 23 |
+
self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
|
| 24 |
+
self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
|
| 25 |
+
|
| 26 |
+
nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
|
| 27 |
+
nn.init.zeros_(self.lora_B)
|
| 28 |
+
|
| 29 |
+
self.original_linear.weight.requires_grad = False
|
| 30 |
+
if self.original_linear.bias is not None:
|
| 31 |
+
self.original_linear.bias.requires_grad = False
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
result = self.original_linear(x)
|
| 35 |
+
lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
|
| 36 |
+
return result + lora_out
|
| 37 |
+
|
| 38 |
+
def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0, target_modules=['qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'], device=None):
|
| 39 |
+
"""Inject LoRA layers into the model."""
|
| 40 |
+
|
| 41 |
+
lora_modules = []
|
| 42 |
+
|
| 43 |
+
if device is None:
|
| 44 |
+
device = next(model.parameters()).device
|
| 45 |
+
|
| 46 |
+
for name, module in model.named_modules():
|
| 47 |
+
|
| 48 |
+
if isinstance(module, nn.Linear):
|
| 49 |
+
|
| 50 |
+
if any(target in name for target in target_modules):
|
| 51 |
+
|
| 52 |
+
*parent_path, attr_name = name.split('.')
|
| 53 |
+
parent = model
|
| 54 |
+
for p in parent_path:
|
| 55 |
+
parent = getattr(parent, p)
|
| 56 |
+
|
| 57 |
+
lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
|
| 58 |
+
|
| 59 |
+
lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
|
| 60 |
+
lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
|
| 61 |
+
setattr(parent, attr_name, lora_layer)
|
| 62 |
+
lora_modules.append(name)
|
| 63 |
+
|
| 64 |
+
print(f"Injected LoRA into {len(lora_modules)} layers:")
|
| 65 |
+
for name in lora_modules[:5]:
|
| 66 |
+
print(f" - {name}")
|
| 67 |
+
if len(lora_modules) > 5:
|
| 68 |
+
print(f" ... and {len(lora_modules) - 5} more")
|
| 69 |
+
|
| 70 |
+
return model
|
| 71 |
+
|
| 72 |
+
def get_lora_parameters(model):
|
| 73 |
+
"""Extract only LoRA parameters for optimization."""
|
| 74 |
+
lora_params = []
|
| 75 |
+
for module in model.modules():
|
| 76 |
+
if isinstance(module, LoRALinear):
|
| 77 |
+
lora_params.extend([module.lora_A, module.lora_B])
|
| 78 |
+
return lora_params
|
| 79 |
+
|
| 80 |
+
def save_lora_weights(model, output_path):
|
| 81 |
+
"""Save LoRA weights to a safetensors file."""
|
| 82 |
+
lora_state_dict = {}
|
| 83 |
+
|
| 84 |
+
for name, module in model.named_modules():
|
| 85 |
+
if isinstance(module, LoRALinear):
|
| 86 |
+
lora_state_dict[f"{name}.lora_A"] = module.lora_A
|
| 87 |
+
lora_state_dict[f"{name}.lora_B"] = module.lora_B
|
| 88 |
+
|
| 89 |
+
save_file(lora_state_dict, output_path)
|
| 90 |
+
print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
|
| 91 |
+
|
| 92 |
+
class LatentDataset(Dataset):
|
| 93 |
+
"""Dataset for pre-encoded latents."""
|
| 94 |
+
|
| 95 |
+
def __init__(self, latents_dir: str):
|
| 96 |
+
self.latents_dir = Path(latents_dir)
|
| 97 |
+
|
| 98 |
+
self.latent_files = sorted(list(self.latents_dir.glob("*.pt")))
|
| 99 |
+
|
| 100 |
+
if len(self.latent_files) == 0:
|
| 101 |
+
raise ValueError(f"No .pt files found in {latents_dir}")
|
| 102 |
+
|
| 103 |
+
print(f"Found {len(self.latent_files)} latent files")
|
| 104 |
+
|
| 105 |
+
def __len__(self):
|
| 106 |
+
return len(self.latent_files)
|
| 107 |
+
|
| 108 |
+
def __getitem__(self, idx):
|
| 109 |
+
latent = torch.load(self.latent_files[idx])
|
| 110 |
+
|
| 111 |
+
if latent.ndim == 3:
|
| 112 |
+
latent = latent.unsqueeze(0)
|
| 113 |
+
|
| 114 |
+
return latent
|
| 115 |
+
|
| 116 |
+
class RectifiedFlow:
|
| 117 |
+
"""Simplified rectified flow matching."""
|
| 118 |
+
|
| 119 |
+
def __init__(self, model):
|
| 120 |
+
self.model = model
|
| 121 |
+
|
| 122 |
+
def forward(self, x, cond):
|
| 123 |
+
"""Compute flow matching loss."""
|
| 124 |
+
b = x.size(0)
|
| 125 |
+
|
| 126 |
+
nt = torch.randn((b,), device=x.device)
|
| 127 |
+
t = torch.sigmoid(nt)
|
| 128 |
+
|
| 129 |
+
texp = t.view([b, *([1] * len(x.shape[1:]))])
|
| 130 |
+
z1 = torch.randn_like(x)
|
| 131 |
+
zt = (1 - texp) * x + texp * z1
|
| 132 |
+
|
| 133 |
+
vtheta = self.model(zt, t, cond)
|
| 134 |
+
|
| 135 |
+
target = z1 - x
|
| 136 |
+
loss = ((vtheta - target) ** 2).mean()
|
| 137 |
+
|
| 138 |
+
return loss
|
| 139 |
+
|
| 140 |
+
def collate_fn(batch, subsection_length=1024):
|
| 141 |
+
"""Custom collate function to sample random subsections."""
|
| 142 |
+
sampled_latents = []
|
| 143 |
+
|
| 144 |
+
for latent in batch:
|
| 145 |
+
if latent.ndim == 3:
|
| 146 |
+
latent = latent.unsqueeze(0)
|
| 147 |
+
|
| 148 |
+
_, _, _, width = latent.shape
|
| 149 |
+
|
| 150 |
+
if width < subsection_length:
|
| 151 |
+
# Pad if too short
|
| 152 |
+
pad_amount = subsection_length - width
|
| 153 |
+
latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
|
| 154 |
+
else:
|
| 155 |
+
# Randomly sample subsection
|
| 156 |
+
max_start = width - subsection_length
|
| 157 |
+
start_idx = torch.randint(0, max_start + 1, (1,)).item()
|
| 158 |
+
latent = latent[:, :, :, start_idx:start_idx + subsection_length]
|
| 159 |
+
|
| 160 |
+
sampled_latents.append(latent.squeeze(0))
|
| 161 |
+
|
| 162 |
+
batch_latents = torch.stack(sampled_latents)
|
| 163 |
+
|
| 164 |
+
batch_tags = [HARDCODED_TAGS] * len(batch)
|
| 165 |
+
|
| 166 |
+
return batch_latents, batch_tags
|
| 167 |
+
|
| 168 |
+
def main():
|
| 169 |
+
parser = argparse.ArgumentParser(description='LoRA training for LocalSong model with embedding training')
|
| 170 |
+
|
| 171 |
+
parser.add_argument('--latents_dir', type=str, required=True,
|
| 172 |
+
help='Directory containing VAE-encoded latents (.pt files)')
|
| 173 |
+
|
| 174 |
+
parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
|
| 175 |
+
help='Path to base model checkpoint')
|
| 176 |
+
parser.add_argument('--lora_rank', type=int, default=16,
|
| 177 |
+
help='LoRA rank')
|
| 178 |
+
parser.add_argument('--lora_alpha', type=float, default=16,
|
| 179 |
+
help='LoRA alpha (scaling factor)')
|
| 180 |
+
parser.add_argument('--batch_size', type=int, default=16,
|
| 181 |
+
help='Batch size')
|
| 182 |
+
parser.add_argument('--lr', type=float, default=2e-4,
|
| 183 |
+
help='Learning rate')
|
| 184 |
+
parser.add_argument('--steps', type=int, default=1500,
|
| 185 |
+
help='Number of training steps')
|
| 186 |
+
parser.add_argument('--subsection_length', type=int, default=512,
|
| 187 |
+
help='Latent subsection length')
|
| 188 |
+
parser.add_argument('--output', type=str, default='lora.safetensors',
|
| 189 |
+
help='Output path for LoRA weights')
|
| 190 |
+
parser.add_argument('--save_every', type=int, default=500,
|
| 191 |
+
help='Save checkpoint every N steps')
|
| 192 |
+
|
| 193 |
+
args = parser.parse_args()
|
| 194 |
+
|
| 195 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 196 |
+
print(f"Using device: {device}")
|
| 197 |
+
|
| 198 |
+
print(f"Using hardcoded tags: {HARDCODED_TAGS}")
|
| 199 |
+
|
| 200 |
+
print(f"Loading base model from {args.checkpoint}")
|
| 201 |
+
model = LocalSongModel(
|
| 202 |
+
in_channels=8,
|
| 203 |
+
num_groups=16,
|
| 204 |
+
hidden_size=1024,
|
| 205 |
+
decoder_hidden_size=2048,
|
| 206 |
+
num_blocks=36,
|
| 207 |
+
patch_size=(16, 1),
|
| 208 |
+
num_classes=2304,
|
| 209 |
+
max_tags=8,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
print(f"Loading checkpoint from {args.checkpoint}")
|
| 213 |
+
state_dict = load_file(args.checkpoint)
|
| 214 |
+
model.load_state_dict(state_dict, strict=True)
|
| 215 |
+
print("Base model loaded")
|
| 216 |
+
|
| 217 |
+
model = model.to(device)
|
| 218 |
+
model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)
|
| 219 |
+
|
| 220 |
+
model.train()
|
| 221 |
+
|
| 222 |
+
lora_params = get_lora_parameters(model)
|
| 223 |
+
optimizer = optim.Adam(lora_params, lr=args.lr)
|
| 224 |
+
print(f"Training {len(lora_params)} LoRA parameters")
|
| 225 |
+
|
| 226 |
+
dataset = LatentDataset(args.latents_dir)
|
| 227 |
+
dataloader = DataLoader(
|
| 228 |
+
dataset,
|
| 229 |
+
batch_size=args.batch_size,
|
| 230 |
+
shuffle=True,
|
| 231 |
+
num_workers=0,
|
| 232 |
+
collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
rf = RectifiedFlow(model)
|
| 236 |
+
|
| 237 |
+
print("\nStarting training...")
|
| 238 |
+
step = 0
|
| 239 |
+
pbar = tqdm(total=args.steps, desc="Training")
|
| 240 |
+
|
| 241 |
+
loss_history = deque(maxlen=50)
|
| 242 |
+
|
| 243 |
+
while step < args.steps:
|
| 244 |
+
for batch_latents, batch_tags in dataloader:
|
| 245 |
+
batch_latents = batch_latents.to(device)
|
| 246 |
+
|
| 247 |
+
optimizer.zero_grad()
|
| 248 |
+
loss = rf.forward(batch_latents, batch_tags)
|
| 249 |
+
|
| 250 |
+
loss.backward()
|
| 251 |
+
torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
|
| 252 |
+
optimizer.step()
|
| 253 |
+
|
| 254 |
+
# Track loss and compute average
|
| 255 |
+
loss_history.append(loss.item())
|
| 256 |
+
avg_loss = sum(loss_history) / len(loss_history)
|
| 257 |
+
|
| 258 |
+
pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
|
| 259 |
+
pbar.update(1)
|
| 260 |
+
step += 1
|
| 261 |
+
|
| 262 |
+
if step % args.save_every == 0:
|
| 263 |
+
save_path = args.output.replace('.safetensors', f'_step{step}.safetensors')
|
| 264 |
+
save_lora_weights(model, save_path)
|
| 265 |
+
|
| 266 |
+
if step >= args.steps:
|
| 267 |
+
break
|
| 268 |
+
|
| 269 |
+
save_lora_weights(model, args.output)
|
| 270 |
+
print(f"\nTraining complete! LoRA weights saved to {args.output}")
|
| 271 |
+
|
| 272 |
+
if __name__ == '__main__':
|
| 273 |
+
main()
|
train_lora_encode_latents.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import argparse
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
|
| 7 |
+
|
| 8 |
+
class AudioVAE:
|
| 9 |
+
def __init__(self, device: torch.device):
|
| 10 |
+
self.model = MusicDCAE().to(device)
|
| 11 |
+
self.model.eval()
|
| 12 |
+
self.device = device
|
| 13 |
+
self.latent_mean = torch.tensor(
|
| 14 |
+
[0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
|
| 15 |
+
device=device,
|
| 16 |
+
).view(1, -1, 1, 1)
|
| 17 |
+
self.latent_std = torch.tensor(
|
| 18 |
+
[0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
|
| 19 |
+
device=device,
|
| 20 |
+
).view(1, -1, 1, 1)
|
| 21 |
+
|
| 22 |
+
def encode(self, audio):
|
| 23 |
+
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
audio_lengths = torch.tensor([audio.shape[2]] * audio.shape[0]).to(self.device)
|
| 26 |
+
latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
|
| 27 |
+
latents = (latents - self.latent_mean) / self.latent_std
|
| 28 |
+
return latents
|
| 29 |
+
|
| 30 |
+
def decode(self, latents: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
with torch.no_grad():
|
| 32 |
+
latents = latents * self.latent_std + self.latent_mean
|
| 33 |
+
_, audio_list = self.model.decode(latents, sr=48000)
|
| 34 |
+
audio_batch = torch.stack(audio_list).to(self.device)
|
| 35 |
+
return audio_batch
|
| 36 |
+
|
| 37 |
+
def load_audio(audio_path, target_sr=48000):
|
| 38 |
+
"""Load and preprocess audio file."""
|
| 39 |
+
audio, sr = torchaudio.load(audio_path)
|
| 40 |
+
|
| 41 |
+
if audio.shape[0] == 1:
|
| 42 |
+
audio = audio.repeat(2, 1)
|
| 43 |
+
elif audio.shape[0] > 2:
|
| 44 |
+
audio = audio[:2]
|
| 45 |
+
|
| 46 |
+
if sr != target_sr:
|
| 47 |
+
resampler = torchaudio.transforms.Resample(sr, target_sr)
|
| 48 |
+
audio = resampler(audio)
|
| 49 |
+
|
| 50 |
+
return audio
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
|
| 54 |
+
parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')
|
| 55 |
+
|
| 56 |
+
parser.add_argument('--audio-dir', type=str, required=True,
|
| 57 |
+
help='Directory containing audio files')
|
| 58 |
+
parser.add_argument('--output-dir', type=str, default="latents",
|
| 59 |
+
help='Directory to save encoded latents')
|
| 60 |
+
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
|
| 63 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 64 |
+
print(f"Using device: {device}")
|
| 65 |
+
|
| 66 |
+
output_dir = Path(args.output_dir)
|
| 67 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 68 |
+
|
| 69 |
+
audio_dir = Path(args.audio_dir)
|
| 70 |
+
audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.ogg', '*.m4a']
|
| 71 |
+
audio_files = []
|
| 72 |
+
for ext in audio_extensions:
|
| 73 |
+
audio_files.extend(list(audio_dir.glob(ext)))
|
| 74 |
+
audio_files = sorted(audio_files)
|
| 75 |
+
|
| 76 |
+
if len(audio_files) == 0:
|
| 77 |
+
raise ValueError(f"No audio files found in {args.audio_dir}")
|
| 78 |
+
|
| 79 |
+
print(f"Found {len(audio_files)} audio files")
|
| 80 |
+
|
| 81 |
+
vae = AudioVAE(device)
|
| 82 |
+
print("VAE loaded")
|
| 83 |
+
|
| 84 |
+
# Encode each audio file
|
| 85 |
+
print("\nEncoding audio files...")
|
| 86 |
+
for audio_path in tqdm(audio_files, desc="Encoding"):
|
| 87 |
+
try:
|
| 88 |
+
audio = load_audio(audio_path)
|
| 89 |
+
audio = audio.unsqueeze(0).to(device)
|
| 90 |
+
latents = vae.encode(audio)
|
| 91 |
+
latents = latents.squeeze(0)
|
| 92 |
+
|
| 93 |
+
output_path = output_dir / f"{audio_path.stem}.pt"
|
| 94 |
+
torch.save(latents.cpu(), output_path)
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
print(f"\nError encoding {audio_path.name}: {e}")
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
main()
|