astralite-heart committed
Commit 820ba22 · verified · 1 Parent(s): 9a23325

Create convert_simpletuner_lora.py

Files changed (1): lora/convert_simpletuner_lora.py (+483, −0)
lora/convert_simpletuner_lora.py ADDED
@@ -0,0 +1,483 @@
+ #!/usr/bin/env python3
+ """
+ Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow.
+
+ This script converts LoRA weights saved by SimpleTuner into a format that can be
+ directly loaded by diffusers' load_lora_weights() method.
+
+ Usage:
+     python convert_simpletuner_lora.py <input_lora.safetensors> <output_lora.safetensors>
+
+ Example:
+     python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors
+ """
+
+ import argparse
+ import sys
+ from pathlib import Path
+ from typing import Dict
+
+ import safetensors.torch
+ import torch
+
+
+ def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str:
+     """
+     Detect the format of the LoRA state dict.
+
+     Returns:
+         "peft" if already in PEFT/diffusers format
+         "mixed" if mixed format (some lora_A/B, some lora.down/up)
+         "simpletuner_transformer" if in SimpleTuner format with transformer prefix
+         "simpletuner_auraflow" if in SimpleTuner AuraFlow format
+         "kohya" if in Kohya format
+         "unknown" otherwise
+     """
+     keys = list(state_dict.keys())
+
+     # Check the actual weight naming convention (lora_A/lora_B vs lora_down/lora_up)
+     has_lora_a_b = any((".lora_A." in k or ".lora_B." in k) for k in keys)
+     has_lora_down_up = any((".lora_down." in k or ".lora_up." in k) for k in keys)
+     has_lora_dot_down_up = any((".lora.down." in k or ".lora.up." in k) for k in keys)
+
+     # Check prefixes
+     has_transformer_prefix = any(k.startswith("transformer.") for k in keys)
+     has_lora_transformer_prefix = any(k.startswith("lora_transformer_") for k in keys)
+     has_lora_unet_prefix = any(k.startswith("lora_unet_") for k in keys)
+
+     # Mixed format: has both lora_A/B AND lora.down/up (SimpleTuner hybrid)
+     if has_transformer_prefix and has_lora_a_b and (has_lora_down_up or has_lora_dot_down_up):
+         return "mixed"
+
+     # Pure PEFT format: transformer.* with ONLY lora_A/lora_B
+     if has_transformer_prefix and has_lora_a_b and not has_lora_down_up and not has_lora_dot_down_up:
+         return "peft"
+
+     # SimpleTuner with transformer prefix but old naming: transformer.* with lora_down/lora_up
+     if has_transformer_prefix and (has_lora_down_up or has_lora_dot_down_up):
+         return "simpletuner_transformer"
+
+     # SimpleTuner AuraFlow format: lora_transformer_* with lora_down/lora_up
+     if has_lora_transformer_prefix and has_lora_down_up:
+         return "simpletuner_auraflow"
+
+     # Traditional Kohya format: lora_unet_* with lora_down/lora_up
+     if has_lora_unet_prefix and has_lora_down_up:
+         return "kohya"
+
+     return "unknown"
+
+
+ def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+     """
+     Convert mixed LoRA format to pure PEFT format.
+
+     SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B
+     and others use .lora.down./.lora.up. This converts all to lora_A/lora_B.
+     """
+     new_state_dict = {}
+     converted_count = 0
+     kept_count = 0
+     skipped_count = 0
+     renames = []
+
+     # Get all keys
+     all_keys = sorted(state_dict.keys())
+
+     print("\nProcessing keys:")
+     print("-" * 80)
+
+     for key in all_keys:
+         # Already in correct format (lora_A or lora_B)
+         if ".lora_A." in key or ".lora_B." in key:
+             new_state_dict[key] = state_dict[key]
+             kept_count += 1
+
+         # Needs conversion: .lora.down. -> .lora_A.
+         elif ".lora.down.weight" in key:
+             new_key = key.replace(".lora.down.weight", ".lora_A.weight")
+             new_state_dict[new_key] = state_dict[key]
+             renames.append((key, new_key))
+             converted_count += 1
+
+         # Needs conversion: .lora.up. -> .lora_B.
+         elif ".lora.up.weight" in key:
+             new_key = key.replace(".lora.up.weight", ".lora_B.weight")
+             new_state_dict[new_key] = state_dict[key]
+             renames.append((key, new_key))
+             converted_count += 1
+
+         # Skip alpha keys (not used in PEFT format)
+         elif ".alpha" in key:
+             skipped_count += 1
+             continue
+
+         # Other keys (shouldn't happen, but keep them just in case)
+         else:
+             new_state_dict[key] = state_dict[key]
+             print(f"⚠ Warning: Unexpected key format: {key}")
+
+     print(f"\nSummary:")
+     print(f" ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)")
+     print(f" ✓ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B")
+     print(f" ✓ Skipped {skipped_count} alpha keys")
+
+     if renames:
+         print(f"\nRenames applied ({len(renames)} conversions):")
+         print("-" * 80)
+         for old_key, new_key in renames:
+             # Show the difference more clearly
+             if ".lora.down.weight" in old_key:
+                 layer = old_key.replace(".lora.down.weight", "")
+                 print(f" {layer}")
+                 print(f" .lora.down.weight → .lora_A.weight")
+             elif ".lora.up.weight" in old_key:
+                 layer = old_key.replace(".lora.up.weight", "")
+                 print(f" {layer}")
+                 print(f" .lora.up.weight → .lora_B.weight")
+
+     return new_state_dict
+
+
+ def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+     """
+     Convert SimpleTuner transformer format (already has transformer. prefix but uses lora_down/lora_up)
+     to diffusers PEFT format (transformer. prefix with lora_A/lora_B).
+
+     This is a simpler conversion since the key structure is already correct.
+     """
+     new_state_dict = {}
+     renames = []
+
+     # Get all unique LoRA layer base names (without .lora_down/.lora_up/.alpha suffix)
+     all_keys = list(state_dict.keys())
+     base_keys = set()
+
+     for key in all_keys:
+         if ".lora_down.weight" in key:
+             base_key = key.replace(".lora_down.weight", "")
+             base_keys.add(base_key)
+
+     print(f"\nFound {len(base_keys)} LoRA layers to convert")
+     print("-" * 80)
+
+     # Convert each layer
+     for base_key in sorted(base_keys):
+         down_key = f"{base_key}.lora_down.weight"
+         up_key = f"{base_key}.lora_up.weight"
+         alpha_key = f"{base_key}.alpha"
+
+         if down_key not in state_dict or up_key not in state_dict:
+             print(f"⚠ Warning: Missing weights for {base_key}")
+             continue
+
+         down_weight = state_dict.pop(down_key)
+         up_weight = state_dict.pop(up_key)
+
+         # Handle alpha scaling
+         has_alpha = False
+         if alpha_key in state_dict:
+             alpha = state_dict.pop(alpha_key)
+             lora_rank = down_weight.shape[0]
+             scale = alpha / lora_rank
+
+             # Calculate scale_down and scale_up to preserve the scale value
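+             # Note: the loop below keeps scale_down * scale_up equal to alpha / rank,
+             # so the effective LoRA delta (up @ down scaled by alpha / rank) is folded
+             # into the stored weights, since PEFT-style keys carry no separate alpha.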
+             scale_down = scale
+             scale_up = 1.0
+             while scale_down * 2 < scale_up:
+                 scale_down *= 2
+                 scale_up /= 2
+
+             down_weight = down_weight * scale_down
+             up_weight = up_weight * scale_up
+             has_alpha = True
+
+         # Store in PEFT format (lora_A = down, lora_B = up)
+         new_down_key = f"{base_key}.lora_A.weight"
+         new_up_key = f"{base_key}.lora_B.weight"
+
+         new_state_dict[new_down_key] = down_weight
+         new_state_dict[new_up_key] = up_weight
+
+         renames.append((down_key, new_down_key, has_alpha))
+         renames.append((up_key, new_up_key, has_alpha))
+
+     # Check for any remaining keys
+     remaining = [k for k in state_dict.keys() if not k.startswith("text_encoder")]
+     if remaining:
+         print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}")
+
+     print(f"\nRenames applied ({len(renames)} conversions):")
+     print("-" * 80)
+
+     # Group by layer
+     current_layer = None
+     for old_key, new_key, has_alpha in renames:
+         layer = old_key.replace(".lora_down.weight", "").replace(".lora_up.weight", "")
+
+         if layer != current_layer:
+             alpha_str = " (alpha scaled)" if has_alpha else ""
+             print(f"\n {layer}{alpha_str}")
+             current_layer = layer
+
+         if ".lora_down.weight" in old_key:
+             print(f" .lora_down.weight → .lora_A.weight")
+         elif ".lora_up.weight" in old_key:
+             print(f" .lora_up.weight → .lora_B.weight")
+
+     return new_state_dict
+
+
+ def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+     """
+     Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format.
+
+     SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts,
+     but for transformer-based models like AuraFlow, the keys may differ.
+     """
+     new_state_dict = {}
+
+     def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+         """Helper to convert a single LoRA layer."""
+         down_key = f"{original_key}.lora_down.weight"
+         if down_key not in state_dict:
+             return False
+
+         down_weight = state_dict.pop(down_key)
+         lora_rank = down_weight.shape[0]
+
+         up_weight_key = f"{original_key}.lora_up.weight"
+         up_weight = state_dict.pop(up_weight_key)
+
+         # Handle alpha scaling
+         alpha_key = f"{original_key}.alpha"
+         if alpha_key in state_dict:
+             alpha = state_dict.pop(alpha_key)
+             scale = alpha / lora_rank
+
+             # Calculate scale_down and scale_up to preserve the scale value
+             scale_down = scale
+             scale_up = 1.0
+             while scale_down * 2 < scale_up:
+                 scale_down *= 2
+                 scale_up /= 2
+
+             down_weight = down_weight * scale_down
+             up_weight = up_weight * scale_up
+
+         # Store in PEFT format (lora_A = down, lora_B = up)
+         diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+         new_state_dict[diffusers_down_key] = down_weight
+         new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+         return True
+
+     # Get all unique LoRA layer names
+     all_unique_keys = {
+         k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+         for k in state_dict
+         if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k
+     }
+
+     # Process transformer blocks
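+     # Example of the mapping performed below (illustrative key, not taken from a real file):
+     #   lora_transformer_transformer_blocks_3_attn_to_q.lora_down.weight
+     #       -> transformer_blocks.3.attn.to_q.lora_A.weight
+     #       -> transformer.transformer_blocks.3.attn.to_q.lora_A.weight (prefix added at the end)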
+     for original_key in sorted(all_unique_keys):
+         if original_key.startswith("lora_transformer_single_transformer_blocks_"):
+             # Single transformer blocks
+             parts = original_key.split("lora_transformer_single_transformer_blocks_")[-1].split("_")
+             block_idx = int(parts[0])
+             diffusers_key = f"single_transformer_blocks.{block_idx}"
+
+             # Map the rest of the key
+             remaining = "_".join(parts[1:])
+             if "attn_to_q" in remaining:
+                 diffusers_key += ".attn.to_q"
+             elif "attn_to_k" in remaining:
+                 diffusers_key += ".attn.to_k"
+             elif "attn_to_v" in remaining:
+                 diffusers_key += ".attn.to_v"
+             elif "proj_out" in remaining:
+                 diffusers_key += ".proj_out"
+             elif "proj_mlp" in remaining:
+                 diffusers_key += ".proj_mlp"
+             elif "norm_linear" in remaining:
+                 diffusers_key += ".norm.linear"
+             else:
+                 print(f"Warning: Unhandled single block key pattern: {original_key}")
+                 continue
+
+         elif original_key.startswith("lora_transformer_transformer_blocks_"):
+             # Double transformer blocks
+             parts = original_key.split("lora_transformer_transformer_blocks_")[-1].split("_")
+             block_idx = int(parts[0])
+             diffusers_key = f"transformer_blocks.{block_idx}"
+
+             # Map the rest of the key
+             remaining = "_".join(parts[1:])
+             if "attn_to_out_0" in remaining:
+                 diffusers_key += ".attn.to_out.0"
+             elif "attn_to_add_out" in remaining:
+                 diffusers_key += ".attn.to_add_out"
+             elif "attn_to_q" in remaining:
+                 diffusers_key += ".attn.to_q"
+             elif "attn_to_k" in remaining:
+                 diffusers_key += ".attn.to_k"
+             elif "attn_to_v" in remaining:
+                 diffusers_key += ".attn.to_v"
+             elif "attn_add_q_proj" in remaining:
+                 diffusers_key += ".attn.add_q_proj"
+             elif "attn_add_k_proj" in remaining:
+                 diffusers_key += ".attn.add_k_proj"
+             elif "attn_add_v_proj" in remaining:
+                 diffusers_key += ".attn.add_v_proj"
+             elif "ff_net_0_proj" in remaining:
+                 diffusers_key += ".ff.net.0.proj"
+             elif "ff_net_2" in remaining:
+                 diffusers_key += ".ff.net.2"
+             elif "ff_context_net_0_proj" in remaining:
+                 diffusers_key += ".ff_context.net.0.proj"
+             elif "ff_context_net_2" in remaining:
+                 diffusers_key += ".ff_context.net.2"
+             elif "norm1_linear" in remaining:
+                 diffusers_key += ".norm1.linear"
+             elif "norm1_context_linear" in remaining:
+                 diffusers_key += ".norm1_context.linear"
+             else:
+                 print(f"Warning: Unhandled double block key pattern: {original_key}")
+                 continue
+
+         elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"):
+             # Text encoder keys - handle separately
+             print(f"Found text encoder key: {original_key}")
+             continue
+
+         else:
+             print(f"Warning: Unknown key pattern: {original_key}")
+             continue
+
+         # Perform the conversion
+         _convert(original_key, diffusers_key, state_dict, new_state_dict)
+
+     # Add "transformer." prefix to all keys
+     transformer_state_dict = {
+         f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+     }
+
+     # Check for remaining unconverted keys
+     if len(state_dict) > 0:
+         remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")]
+         if remaining_keys:
+             print(f"Warning: Some keys were not converted: {remaining_keys[:10]}")
+
+     return transformer_state_dict
+
+
+ def convert_lora(input_path: str, output_path: str) -> None:
+     """
+     Main conversion function.
+
+     Args:
+         input_path: Path to input LoRA safetensors file
+         output_path: Path to output diffusers-compatible safetensors file
+     """
+     print(f"Loading LoRA from: {input_path}")
+     state_dict = safetensors.torch.load_file(input_path)
+
+     print(f"Detecting LoRA format...")
+     format_type = detect_lora_format(state_dict)
+     print(f"Detected format: {format_type}")
+
+     if format_type == "peft":
+         print("LoRA is already in diffusers-compatible PEFT format!")
+         print("No conversion needed. Copying file...")
+         import shutil
+         shutil.copy(input_path, output_path)
+         return
+
+     elif format_type == "mixed":
+         print("Converting MIXED format LoRA to pure PEFT format...")
+         print("(Some layers use lora_A/B, others use .lora.down/.lora.up)")
+         converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy())
+
+     elif format_type == "simpletuner_transformer":
+         print("Converting SimpleTuner transformer format to diffusers...")
+         print("(has transformer. prefix but uses lora_down/lora_up naming)")
+         converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy())
+
+     elif format_type == "simpletuner_auraflow":
+         print("Converting SimpleTuner AuraFlow format to diffusers...")
+         converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
+
+     elif format_type == "kohya":
+         print("Note: Detected Kohya format. This converter is optimized for AuraFlow.")
+         print("For other models, diffusers has built-in conversion.")
+         converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy())
+
+     else:
+         print("Error: Unknown LoRA format!")
+         print("Sample keys from the state dict:")
+         for i, key in enumerate(list(state_dict.keys())[:20]):
+             print(f" {key}")
+         sys.exit(1)
+
+     print(f"Saving converted LoRA to: {output_path}")
+     safetensors.torch.save_file(converted_state_dict, output_path)
+
+     print("\nConversion complete!")
+     print(f"Original keys: {len(state_dict)}")
+     print(f"Converted keys: {len(converted_state_dict)}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Convert SimpleTuner LoRA to diffusers-compatible format",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Convert a SimpleTuner LoRA for AuraFlow
+   python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors
+
+   # Check format without converting
+   python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors --dry-run
+ """
+     )
+
+     parser.add_argument(
+         "input",
+         type=str,
+         help="Input LoRA file (SimpleTuner format)"
+     )
+
+     parser.add_argument(
+         "output",
+         type=str,
+         help="Output LoRA file (diffusers-compatible format)"
+     )
+
+     parser.add_argument(
+         "--dry-run",
+         action="store_true",
+         help="Only detect format, don't convert"
+     )
+
+     args = parser.parse_args()
+
+     # Validate input file exists
+     if not Path(args.input).exists():
+         print(f"Error: Input file not found: {args.input}")
+         sys.exit(1)
+
+     if args.dry_run:
+         print(f"Loading LoRA from: {args.input}")
+         state_dict = safetensors.torch.load_file(args.input)
+         format_type = detect_lora_format(state_dict)
+         print(f"Detected format: {format_type}")
+         print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):")
+         for key in list(state_dict.keys())[:10]:
+             print(f" {key}")
+         return
+
+     convert_lora(args.input, args.output)
+
+
+ if __name__ == "__main__":
+     main()
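For reference, a minimal sketch of how the converted file could then be used. This is not part of the commit; it assumes a diffusers release with AuraFlow LoRA support, and the checkpoint id, file name, and prompt below are placeholders:

import torch
from diffusers import AuraFlowPipeline

# Load the base AuraFlow pipeline (placeholder checkpoint id).
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Load the LoRA produced by lora/convert_simpletuner_lora.py.
pipe.load_lora_weights("diffusers_compatible_lora.safetensors")

image = pipe(prompt="a placeholder prompt", num_inference_steps=28).images[0]
image.save("lora_test.png")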