codemichaeld committed (verified)
Commit 2a57dcf · Parent(s): 1fd8c55

Update app.py

Files changed (1): app.py (+390 -132)
app.py CHANGED
@@ -10,6 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
  from safetensors.torch import load_file, save_file
  import torch
  import torch.nn.functional as F
  try:
  from modelscope.hub.file_download import model_file_download as ms_file_download
  from modelscope.hub.api import HubApi as ModelScopeApi
@@ -17,75 +19,182 @@ try:
  except ImportError:
  MODELScope_AVAILABLE = False

- def low_rank_decomposition(weight, rank=64):
- """
- Correct LoRA decomposition supporting 2D and 4D tensors.
- Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A for 2D,
- or appropriate conv form for 4D.
- """
  original_shape = weight.shape
  original_dtype = weight.dtype
-
  try:
  if weight.ndim == 2:
- actual_rank = min(rank, min(weight.shape) // 2)
- if actual_rank < 4:
- return None, None
-
  U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
- S_sqrt = torch.sqrt(S[:actual_rank])
-
- # Standard LoRA factorization: W ≈ W_B @ W_A
- W_A = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous() # [rank, in_features]
- W_B = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous() # [out_features, rank]
-
- return W_A.to(original_dtype), W_B.to(original_dtype)
-
  elif weight.ndim == 4:
- out_ch, in_ch, k_h, k_w = weight.shape
- if k_h * k_w <= 9: # small conv kernels (e.g., 3x3)
- # Reshape to 2D: [out_ch, in_ch * k_h * k_w]
- weight_2d = weight.view(out_ch, -1)
- actual_rank = min(rank, min(weight_2d.shape) // 2)
- if actual_rank < 4:
- return None, None
-
- U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
- S_sqrt = torch.sqrt(S[:actual_rank])
-
- W_A_2d = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()
- W_B_2d = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()
-
- # Reshape back to conv format
- W_A = W_A_2d.view(actual_rank, in_ch, k_h, k_w).contiguous()
- W_B = W_B_2d.view(out_ch, actual_rank, 1, 1).contiguous()
-
- return W_A.to(original_dtype), W_B.to(original_dtype)
-
- return None, None
-
  except Exception as e:
- print(f"Decomposition error for {original_shape}: {e}")
- return None, None
-
- def should_apply_lora(key, weight, architecture="auto"):
- """Architecture-aware LoRA eligibility."""
- lower_key = key.lower()
-
- # Skip bias, norm, and tiny tensors
- if 'bias' in lower_key or 'norm' in lower_key or weight.numel() < 256:
  return False
-
  if architecture == "text_encoder":
- return any(t in lower_key for t in ['emb', 'embed', 'attn', 'mlp'])
  elif architecture == "unet_transformer":
- return any(t in lower_key for t in ['attn', 'transformer', 'to_q', 'to_k', 'to_v', 'to_out'])
  elif architecture == "unet_conv":
- return any(t in lower_key for t in ['conv', 'resnet', 'down', 'up', 'skip'])
  elif architecture == "vae":
- return any(t in lower_key for t in ['encoder', 'decoder', 'quant', 'post_quant', 'pre_quant'])
- else: # "auto" or "all"
- return weight.ndim in [2, 4]

  def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=64, architecture="auto", progress=gr.Progress()):
  progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")
@@ -96,84 +205,196 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
  header_json = f.read(header_size).decode('utf-8')
  header = json.loads(header_json)
  return header.get('__metadata__', {})
-
  metadata = read_safetensors_metadata(safetensors_path)
  progress(0.2, desc="Loaded metadata.")
-
  state_dict = load_file(safetensors_path)
  progress(0.4, desc="Loaded weights.")
-
  if fp8_format == "e5m2":
  fp8_dtype = torch.float8_e5m2
  else:
  fp8_dtype = torch.float8_e4m3fn
-
  sd_fp8 = {}
  lora_weights = {}
- total = len(state_dict)
- lora_keys = []
-
  lora_stats = {
- 'total_layers': total,
  'layers_eligible': 0,
  'layers_processed': 0,
  'layers_skipped': [],
  }
-
  for i, key in enumerate(state_dict):
- progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
  weight = state_dict[key]
-
  if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
  fp8_weight = weight.to(fp8_dtype)
  sd_fp8[key] = fp8_weight
-
- if should_apply_lora(key, weight, architecture):
  lora_stats['layers_eligible'] += 1
-
  try:
- A, B = low_rank_decomposition(weight, rank=lora_rank)
- if A is not None and B is not None:
- lora_weights[f"lora_A.{key}"] = A.to(torch.float16)
- lora_weights[f"lora_B.{key}"] = B.to(torch.float16)
  lora_keys.append(key)
  lora_stats['layers_processed'] += 1
  else:
- lora_stats['layers_skipped'].append(f"{key}: decomposition failed")
  except Exception as e:
- lora_stats['layers_skipped'].append(f"{key}: exception: {e}")
  else:
- reason = "filtered by architecture" if architecture != "auto" else "not 2D/4D or too small"
- lora_stats['layers_skipped'].append(f"{key}: skipped ({reason})")
  else:
  sd_fp8[key] = weight
- lora_stats['layers_skipped'].append(f"{key}: non-float dtype")
-
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
  fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
  lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
-
  save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
- save_file(lora_weights, lora_path, metadata={
- "format": "pt",
  "lora_rank": str(lora_rank),
  "architecture": architecture,
  "stats": json.dumps(lora_stats)
- })
-
  progress(0.9, desc="Saved FP8 and LoRA files.")
  progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
-
- stats_msg = f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA ({architecture}) saved.\n"
- stats_msg += f"Processed {lora_stats['layers_processed']}/{lora_stats['layers_eligible']} eligible layers."
  if lora_stats['layers_processed'] == 0:
- stats_msg += " ⚠️ No valid LoRA weights generated."
-
- return True, stats_msg, lora_stats

  except Exception as e:
- import traceback
- return False, f"Error: {str(e)}\n{traceback.format_exc()}", None

  def parse_hf_url(url):
  url = url.strip().rstrip("/")
@@ -251,9 +472,11 @@ def process_and_upload_fp8(
  return None, "❌ Hugging Face token required for source.", ""
  if target_type == "huggingface" and not hf_token:
  return None, "❌ Hugging Face token required for target.", ""
  if lora_rank < 4:
  return None, "❌ LoRA rank must be at least 4.", ""
-
  temp_dir = None
  output_dir = tempfile.mkdtemp()
  try:
@@ -261,24 +484,24 @@ def process_and_upload_fp8(
  safetensors_path, temp_dir = download_safetensors_file(
  source_type, repo_url, safetensors_filename, hf_token, progress
  )
-
  progress(0.25, desc=f"Converting to FP8 with LoRA ({architecture})...")
  success, msg, stats = convert_safetensors_to_fp8_with_lora(
  safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
  )
-
  if not success:
  return None, f"❌ Conversion failed: {msg}", ""
-
  progress(0.9, desc="Uploading...")
  repo_url_final = upload_to_target(
  target_type, new_repo_id, output_dir, fp8_format, architecture, hf_token, modelscope_token, private_repo
  )
-
  base_name = os.path.splitext(safetensors_filename)[0]
  lora_filename = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
-
  readme = f"""---
  library_name: diffusers
  tags:
@@ -288,7 +511,7 @@ tags:
  - low-rank
  - diffusion
  - architecture-{architecture}
- - converted-by-gradio
  ---
  # FP8 Model with Low-Rank LoRA
  - **Source**: `{repo_url}`
@@ -299,6 +522,16 @@ tags:
  - **LoRA File**: `{lora_filename}`
  - **FP8 File**: `{fp8_filename}`

  ## Usage (Inference)
  ```python
  from safetensors.torch import load_file
@@ -311,26 +544,38 @@ lora_state = load_file("{lora_filename}")
  # Reconstruct approximate original weights
  reconstructed = {{}}
  for key in fp8_state:
- if f"lora_A.{{key}}" in lora_state and f"lora_B.{{key}}" in lora_state:
- A = lora_state[f"lora_A.{{key}}"].to(torch.float32)
- B = lora_state[f"lora_B.{{key}}"].to(torch.float32)
  if A.ndim == 2 and B.ndim == 2:
  lora_weight = B @ A
  else:
- # Conv LoRA: simplified reconstruction
- lora_weight = F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), A, groups=1)[:, :B.shape[0]]
- lora_weight = lora_weight.squeeze(0) + F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), B, groups=1).squeeze(0)
  reconstructed[key] = fp8_state[key].to(torch.float32) + lora_weight
  else:
  reconstructed[key] = fp8_state[key].to(torch.float32)
  ```

- > Requires PyTorch ≥ 2.1 for FP8 support. Use matching architecture during inference.
  """
-
  with open(os.path.join(output_dir, "README.md"), "w") as f:
  f.write(readme)
-
  if target_type == "huggingface":
  HfApi(token=hf_token).upload_file(
  path_or_fileobj=os.path.join(output_dir, "README.md"),
@@ -339,18 +584,27 @@ for key in fp8_state:
  repo_type="model",
  token=hf_token
  )
-
  progress(1.0, desc="✅ Done!")
  result_html = f"""
  ✅ Success!
  Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
- Includes: FP8 + rank-{lora_rank} LoRA ({architecture}).
  """
  return gr.HTML(result_html), "✅ FP8 + LoRA upload successful!", msg
-
  except Exception as e:
- import traceback
- return None, f"❌ Error: {str(e)}\n{traceback.format_exc()}", ""
  finally:
  if temp_dir:
  shutil.rmtree(temp_dir, ignore_errors=True)
@@ -358,17 +612,18 @@ Includes: FP8 + rank-{lora_rank} LoRA ({architecture}).

  with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
  gr.Markdown("# 🔄 Advanced FP8 Pruner with Architecture-Specific LoRA Extraction")
- gr.Markdown("Convert `.safetensors` → **FP8** + **targeted LoRA** for precision recovery. Supports Hugging Face ↔ ModelScope.")
-
  with gr.Row():
  with gr.Column():
  source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
  repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
  safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
-
  with gr.Accordion("Advanced LoRA Settings", open=True):
  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
  lora_rank = gr.Slider(minimum=4, maximum=256, step=4, value=64, label="LoRA Rank")
  architecture = gr.Dropdown(
  choices=[
  ("Auto-detect components", "auto"),
@@ -379,24 +634,25 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
  ("All components", "all")
  ],
  value="auto",
- label="Target Architecture"
  )
-
  with gr.Accordion("Authentication", open=False):
  hf_token = gr.Textbox(label="Hugging Face Token", type="password")
  modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
-
  with gr.Column():
  target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
  new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-lora")
  private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
-
  status_output = gr.Markdown()
  detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
-
  convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
  repo_link_output = gr.HTML()
-
  convert_btn.click(
  fn=process_and_upload_fp8,
  inputs=[
@@ -415,7 +671,7 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
  outputs=[repo_link_output, status_output, detailed_log],
  show_progress=True
  )
-
  gr.Examples(
  examples=[
  ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "unet_transformer"],
@@ -425,15 +681,17 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
  inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
  label="Example Conversions"
  )
-
  gr.Markdown("""
  ## 💡 Usage Tips
- - **Text Encoder**: Use rank 32–64 with `text_encoder` architecture
- - **UNet Attention**: Use `unet_transformer` with rank 64–128
- - **UNet Convolutions**: Use `unet_conv` with rank 16–32
- - **VAE**: Use `vae` with rank 16–32
- - **Auto Mode**: Let the tool analyze and select layers automatically
- - Higher ranks = better quality but larger LoRA files
  """)

  demo.launch()
  from safetensors.torch import load_file, save_file
  import torch
  import torch.nn.functional as F
+ import traceback
+ import math
  try:
  from modelscope.hub.file_download import model_file_download as ms_file_download
  from modelscope.hub.api import HubApi as ModelScopeApi
  except ImportError:
  MODELScope_AVAILABLE = False

+ def low_rank_decomposition(weight, rank=64, approximation_factor=0.8):
+ """Low-rank decomposition with controlled approximation error."""
  original_shape = weight.shape
  original_dtype = weight.dtype
+
  try:
+ # Handle 2D tensors (linear layers, attention)
  if weight.ndim == 2:
+ # Compute SVD
  U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
+
+ # Calculate how much variance we want to keep
+ total_variance = torch.sum(S ** 2)
+ cumulative_variance = torch.cumsum(S ** 2, dim=0)
+
+ # Find minimal rank that preserves approximation_factor of variance
+ minimal_rank = torch.searchsorted(cumulative_variance, approximation_factor * total_variance).item() + 1
+
+ # Use the smaller of: requested rank or minimal rank for approximation_factor
+ actual_rank = min(rank, len(S))
+
+ # If actual_rank is too close to full rank, reduce it to create meaningful approximation
+ if actual_rank > len(S) * 0.8: # If using more than 80% of full rank
+ actual_rank = max(min(rank // 2, len(S) // 2), 8) # Use half the requested rank
+
+ # Ensure we're actually approximating, not just reparameterizing
+ if actual_rank >= min(weight.shape):
+ # Force approximation by using lower rank
+ actual_rank = max(min(weight.shape) // 4, 8)
+
+ U_k = U[:, :actual_rank] @ torch.diag(torch.sqrt(S[:actual_rank]))
+ Vh_k = torch.diag(torch.sqrt(S[:actual_rank])) @ Vh[:actual_rank, :]
+
+ return U_k.contiguous(), Vh_k.contiguous()
+
+ # Handle 4D tensors (convolutional layers)
  elif weight.ndim == 4:
+ out_ch, in_ch, kH, kW = weight.shape
+
+ # Reshape to 2D for SVD
+ weight_2d = weight.view(out_ch, in_ch * kH * kW)
+
+ # Compute SVD on flattened version
+ U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
+
+ # Calculate appropriate rank
+ total_variance = torch.sum(S ** 2)
+ cumulative_variance = torch.cumsum(S ** 2, dim=0)
+ minimal_rank = torch.searchsorted(cumulative_variance, approximation_factor * total_variance).item() + 1
+
+ # Adjust rank for convolutions - typically need lower ranks
+ conv_rank = min(rank // 2, len(S))
+ if conv_rank > len(S) * 0.7:
+ conv_rank = max(len(S) // 4, 8)
+
+ actual_rank = max(min(conv_rank, minimal_rank), 8)
+
+ # Decompose
+ U_k = U[:, :actual_rank] @ torch.diag(torch.sqrt(S[:actual_rank]))
+ Vh_k = torch.diag(torch.sqrt(S[:actual_rank])) @ Vh[:actual_rank, :]
+
+ # Reshape back to convolutional format
+ if kH == 1 and kW == 1: # 1x1 convolutions
+ U_k = U_k.view(out_ch, actual_rank, 1, 1)
+ Vh_k = Vh_k.view(actual_rank, in_ch, 1, 1)
+ else:
+ # For larger kernels, use spatial decomposition
+ U_k = U_k.view(out_ch, actual_rank, 1, 1)
+ Vh_k = Vh_k.view(actual_rank, in_ch, kH, kW)
+
+ return U_k.contiguous(), Vh_k.contiguous()
+
+ # Handle 1D tensors (biases, embeddings)
+ elif weight.ndim == 1:
+ # Don't decompose 1D tensors
+ return None, None
+
  except Exception as e:
+ print(f"Decomposition error for tensor with shape {original_shape}: {str(e)[:100]}")
+
+ return None, None
+
+ def get_architecture_specific_settings(architecture, base_rank):
+ """Get optimal settings for different model architectures."""
+ settings = {
+ "text_encoder": {
+ "rank": base_rank,
+ "approximation_factor": 0.95, # Text encoders need high accuracy
+ "min_rank": 8,
+ "max_rank_factor": 0.5 # Use at most 50% of full rank
+ },
+ "unet_transformer": {
+ "rank": base_rank,
+ "approximation_factor": 0.90,
+ "min_rank": 16,
+ "max_rank_factor": 0.4
+ },
+ "unet_conv": {
+ "rank": base_rank // 2, # Convolutions compress better
+ "approximation_factor": 0.85,
+ "min_rank": 8,
+ "max_rank_factor": 0.3
+ },
+ "vae": {
+ "rank": base_rank // 3, # VAE compresses very well
+ "approximation_factor": 0.80,
+ "min_rank": 4,
+ "max_rank_factor": 0.25
+ },
+ "auto": {
+ "rank": base_rank,
+ "approximation_factor": 0.90,
+ "min_rank": 8,
+ "max_rank_factor": 0.5
+ },
+ "all": {
+ "rank": base_rank,
+ "approximation_factor": 0.90,
+ "min_rank": 8,
+ "max_rank_factor": 0.5
+ }
+ }
+
+ return settings.get(architecture, settings["auto"])
+
+ def should_apply_lora(key, weight, architecture, lora_rank):
+ """Determine if LoRA should be applied to a specific weight based on architecture selection."""
+
+ # Skip bias terms, batchnorm, and very small tensors
+ if 'bias' in key or 'norm' in key.lower() or 'bn' in key.lower():
  return False
+
+ # Skip very small tensors
+ if weight.numel() < 100:
+ return False
+
+ # Skip 1D tensors
+ if weight.ndim == 1:
+ return False
+
+ # Architecture-specific rules
+ lower_key = key.lower()
+
  if architecture == "text_encoder":
+ # Text encoder: focus on embeddings and attention layers
+ return ('emb' in lower_key or 'embed' in lower_key or
+ 'attn' in lower_key or 'qkv' in lower_key or 'mlp' in lower_key)
+
  elif architecture == "unet_transformer":
+ # UNet transformers: focus on attention blocks
+ return ('attn' in lower_key or 'transformer' in lower_key or
+ 'qkv' in lower_key or 'to_out' in lower_key)
+
  elif architecture == "unet_conv":
+ # UNet convolutional layers
+ return ('conv' in lower_key or 'resnet' in lower_key or
+ 'downsample' in lower_key or 'upsample' in lower_key)
+
  elif architecture == "vae":
+ # VAE components
+ return ('encoder' in lower_key or 'decoder' in lower_key or
+ 'conv' in lower_key or 'post_quant' in lower_key)
+
+ elif architecture == "all":
+ # Apply to all eligible tensors
+ return True
+
+ elif architecture == "auto":
+ # Auto-detect based on tensor properties
+ if weight.ndim == 2 and min(weight.shape) > lora_rank // 4:
+ return True
+ if weight.ndim == 4 and (weight.shape[0] > lora_rank // 4 or weight.shape[1] > lora_rank // 4):
+ return True
+ return False
+
+ return False

  def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=64, architecture="auto", progress=gr.Progress()):
  progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")
  header_json = f.read(header_size).decode('utf-8')
  header = json.loads(header_json)
  return header.get('__metadata__', {})
+
  metadata = read_safetensors_metadata(safetensors_path)
  progress(0.2, desc="Loaded metadata.")
+
  state_dict = load_file(safetensors_path)
  progress(0.4, desc="Loaded weights.")
+
+ # Architecture analysis
+ architecture_stats = {
+ 'text_encoder': 0,
+ 'unet_transformer': 0,
+ 'unet_conv': 0,
+ 'vae': 0,
+ 'other': 0
+ }
+
+ for key in state_dict:
+ lower_key = key.lower()
+ if 'text' in lower_key or 'emb' in lower_key:
+ architecture_stats['text_encoder'] += 1
+ elif 'attn' in lower_key or 'transformer' in lower_key:
+ architecture_stats['unet_transformer'] += 1
+ elif 'conv' in lower_key or 'resnet' in lower_key:
+ architecture_stats['unet_conv'] += 1
+ elif 'vae' in lower_key or 'encoder' in lower_key or 'decoder' in lower_key:
+ architecture_stats['vae'] += 1
+ else:
+ architecture_stats['other'] += 1
+
+ print("Architecture analysis:")
+ for arch, count in architecture_stats.items():
+ print(f"- {arch}: {count} layers")
+
  if fp8_format == "e5m2":
  fp8_dtype = torch.float8_e5m2
  else:
  fp8_dtype = torch.float8_e4m3fn
+
  sd_fp8 = {}
  lora_weights = {}
  lora_stats = {
+ 'total_layers': len(state_dict),
+ 'layers_analyzed': 0,
  'layers_eligible': 0,
  'layers_processed': 0,
  'layers_skipped': [],
+ 'architecture_distro': architecture_stats,
+ 'reconstruction_errors': []
  }
+
+ total = len(state_dict)
+ lora_keys = []
+
  for i, key in enumerate(state_dict):
+ progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
  weight = state_dict[key]
+ lora_stats['layers_analyzed'] += 1
+
  if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
  fp8_weight = weight.to(fp8_dtype)
  sd_fp8[key] = fp8_weight
+
+ # Determine if we should apply LoRA
+ eligible_for_lora = should_apply_lora(key, weight, architecture, lora_rank)
+
+ if eligible_for_lora:
  lora_stats['layers_eligible'] += 1
+
  try:
+ # Get architecture-specific settings
+ arch_settings = get_architecture_specific_settings(architecture, lora_rank)
+
+ # Adjust rank based on tensor properties
+ if weight.ndim == 2:
+ max_possible_rank = min(weight.shape)
+ actual_rank = min(
+ arch_settings["rank"],
+ int(max_possible_rank * arch_settings["max_rank_factor"])
+ )
+ actual_rank = max(actual_rank, arch_settings["min_rank"])
+ elif weight.ndim == 4:
+ # For conv layers, use smaller rank
+ actual_rank = min(
+ arch_settings["rank"],
+ max(weight.shape[0], weight.shape[1]) // 4
+ )
+ actual_rank = max(actual_rank, arch_settings["min_rank"])
+ else:
+ # Skip non-2D/4D tensors for LoRA
+ lora_stats['layers_skipped'].append(f"{key}: unsupported ndim={weight.ndim}")
+ continue
+
+ if actual_rank < 4:
+ lora_stats['layers_skipped'].append(f"{key}: rank too small ({actual_rank})")
+ continue
+
+ # Perform decomposition with approximation
+ U, V = low_rank_decomposition(
+ weight,
+ rank=actual_rank,
+ approximation_factor=arch_settings["approximation_factor"]
+ )
+
+ if U is not None and V is not None:
+ # Store as half-precision: A is the down-projection [rank, in], B is the up-projection [out, rank]
+ lora_weights[f"lora_A.{key}"] = V.to(torch.float16)
+ lora_weights[f"lora_B.{key}"] = U.to(torch.float16)
  lora_keys.append(key)
  lora_stats['layers_processed'] += 1
+
+ # Calculate and store reconstruction error
+ if U.ndim == 2 and V.ndim == 2:
+ reconstructed = U @ V
+ error = torch.norm(weight.float() - reconstructed.float()) / torch.norm(weight.float())
+ lora_stats['reconstruction_errors'].append({
+ 'key': key,
+ 'error': error.item(),
+ 'original_shape': list(weight.shape),
+ 'rank': actual_rank
+ })
  else:
+ lora_stats['layers_skipped'].append(f"{key}: decomposition returned None")
+
  except Exception as e:
+ error_msg = f"{key}: {str(e)[:100]}"
+ lora_stats['layers_skipped'].append(error_msg)
+
  else:
+ reason = "not eligible for selected architecture" if architecture != "auto" else f"ndim={weight.ndim}"
+ lora_stats['layers_skipped'].append(f"{key}: {reason}")
  else:
  sd_fp8[key] = weight
+ lora_stats['layers_skipped'].append(f"{key}: unsupported dtype {weight.dtype}")
+
+ # Add reconstruction error statistics
+ if lora_stats['reconstruction_errors']:
+ errors = [e['error'] for e in lora_stats['reconstruction_errors']]
+ lora_stats['avg_reconstruction_error'] = sum(errors) / len(errors) if errors else 0
+ lora_stats['max_reconstruction_error'] = max(errors) if errors else 0
+ lora_stats['min_reconstruction_error'] = min(errors) if errors else 0
+
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
  fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
  lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
+
  save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
+
+ # Always save LoRA file, even if empty
+ lora_metadata = {
+ "format": "pt",
  "lora_rank": str(lora_rank),
  "architecture": architecture,
+ "original_filename": os.path.basename(safetensors_path),
+ "fp8_format": fp8_format,
  "stats": json.dumps(lora_stats)
+ }
+
+ save_file(lora_weights, lora_path, metadata=lora_metadata)
+
+ # Generate detailed statistics message
+ stats_msg = f"""
+ 📊 LoRA Extraction Statistics:
+ - Total layers analyzed: {lora_stats['layers_analyzed']}
+ - Layers eligible for LoRA: {lora_stats['layers_eligible']}
+ - Successfully processed: {lora_stats['layers_processed']}
+ - Architecture: {architecture}
+ - FP8 Format: {fp8_format.upper()}
+ """
+
+ if 'avg_reconstruction_error' in lora_stats:
+ stats_msg += f"- Avg reconstruction error: {lora_stats['avg_reconstruction_error']:.6f}\n"
+ stats_msg += f"- Max reconstruction error: {lora_stats['max_reconstruction_error']:.6f}\n"
+
  progress(0.9, desc="Saved FP8 and LoRA files.")
  progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
+
  if lora_stats['layers_processed'] == 0:
+ stats_msg += "\n\n⚠️ WARNING: No LoRA weights were generated. Try a different architecture selection or lower rank."
+ elif lora_stats.get('avg_reconstruction_error', 1) < 0.0001:
+ stats_msg += "\n\nℹ️ NOTE: Very low reconstruction error detected. LoRA may be reconstructing almost perfectly. Consider using lower rank for better compression."
+
+ return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved.\n{stats_msg}", lora_stats

  except Exception as e:
+ error_msg = f"Conversion error: {str(e)}\n{traceback.format_exc()}"
+ print(error_msg)
+ return False, error_msg, None

  def parse_hf_url(url):
  url = url.strip().rstrip("/")
  return None, "❌ Hugging Face token required for source.", ""
  if target_type == "huggingface" and not hf_token:
  return None, "❌ Hugging Face token required for target.", ""
+
+ # Validate lora_rank
  if lora_rank < 4:
  return None, "❌ LoRA rank must be at least 4.", ""
+
  temp_dir = None
  output_dir = tempfile.mkdtemp()
  try:
  safetensors_path, temp_dir = download_safetensors_file(
  source_type, repo_url, safetensors_filename, hf_token, progress
  )
+
  progress(0.25, desc=f"Converting to FP8 with LoRA ({architecture})...")
  success, msg, stats = convert_safetensors_to_fp8_with_lora(
  safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
  )
+
  if not success:
  return None, f"❌ Conversion failed: {msg}", ""
+
  progress(0.9, desc="Uploading...")
  repo_url_final = upload_to_target(
  target_type, new_repo_id, output_dir, fp8_format, architecture, hf_token, modelscope_token, private_repo
  )
+
  base_name = os.path.splitext(safetensors_filename)[0]
  lora_filename = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
+
  readme = f"""---
  library_name: diffusers
  tags:
  - low-rank
  - diffusion
  - architecture-{architecture}
+ - converted-by-ai-toolkit
  ---
  # FP8 Model with Low-Rank LoRA
  - **Source**: `{repo_url}`
  - **LoRA File**: `{lora_filename}`
  - **FP8 File**: `{fp8_filename}`

+ ## Architecture Distribution
+ """
+
+ # Add architecture stats to README if available
+ if stats and 'architecture_distro' in stats:
+ readme += "\n| Component | Layer Count |\n|-----------|------------|\n"
+ for arch, count in stats['architecture_distro'].items():
+ readme += f"| {arch.replace('_', ' ').title()} | {count} |\n"
+
+ readme += f"""
  ## Usage (Inference)
  ```python
  from safetensors.torch import load_file
  # Reconstruct approximate original weights
  reconstructed = {{}}
  for key in fp8_state:
+ lora_a_key = f"lora_A.{{key}}"
+ lora_b_key = f"lora_B.{{key}}"
+
+ if lora_a_key in lora_state and lora_b_key in lora_state:
+ A = lora_state[lora_a_key].to(torch.float32)
+ B = lora_state[lora_b_key].to(torch.float32)
+
+ # Handle different tensor dimensions
  if A.ndim == 2 and B.ndim == 2:
  lora_weight = B @ A
+ elif A.ndim == 4 and B.ndim == 4:
+ # Convolutional LoRA: B is [out, rank, 1, 1], A is [rank, in, kH, kW]
+ out_ch, rank = B.shape[0], B.shape[1]
+ lora_weight = (B.view(out_ch, rank) @ A.view(rank, -1)).view(fp8_state[key].shape)
  else:
+ # Fallback for mixed-dimension cases: flatten both factors
+ lora_weight = (B.reshape(B.shape[0], -1) @ A.reshape(A.shape[0], -1)).reshape(fp8_state[key].shape)
+
  reconstructed[key] = fp8_state[key].to(torch.float32) + lora_weight
  else:
  reconstructed[key] = fp8_state[key].to(torch.float32)
  ```

+ > **Note**: Requires PyTorch ≥ 2.1 for FP8 support. For best results, use the same architecture selection ({architecture}) during inference as was used during extraction.
  """
+
  with open(os.path.join(output_dir, "README.md"), "w") as f:
  f.write(readme)
+
  if target_type == "huggingface":
  HfApi(token=hf_token).upload_file(
  path_or_fileobj=os.path.join(output_dir, "README.md"),
  repo_type="model",
  token=hf_token
  )
+
  progress(1.0, desc="✅ Done!")
  result_html = f"""
  ✅ Success!
  Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
+ Includes:
+ - FP8 model: `{fp8_filename}`
+ - LoRA weights: `{lora_filename}` (rank {lora_rank}, architecture: {architecture})
+
+ 📊 Stats: {stats['layers_processed']}/{stats['layers_eligible']} eligible layers processed
  """
+ if 'avg_reconstruction_error' in stats:
+ result_html += f"<br>Avg reconstruction error: {stats['avg_reconstruction_error']:.6f}"
+
  return gr.HTML(result_html), "✅ FP8 + LoRA upload successful!", msg
+
  except Exception as e:
+ error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
+ print(error_msg)
+ return None, error_msg, ""
+
  finally:
  if temp_dir:
  shutil.rmtree(temp_dir, ignore_errors=True)

  with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
  gr.Markdown("# 🔄 Advanced FP8 Pruner with Architecture-Specific LoRA Extraction")
+ gr.Markdown("Convert `.safetensors` → **FP8** + **targeted LoRA** weights for precision recovery. Supports Hugging Face ↔ ModelScope.")
+
  with gr.Row():
  with gr.Column():
  source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
  repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
  safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
+
  with gr.Accordion("Advanced LoRA Settings", open=True):
  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
  lora_rank = gr.Slider(minimum=4, maximum=256, step=4, value=64, label="LoRA Rank")
+
  architecture = gr.Dropdown(
  choices=[
  ("Auto-detect components", "auto"),
  ("All components", "all")
  ],
  value="auto",
+ label="Target Architecture",
+ info="Select which model components to apply LoRA to"
  )
+
  with gr.Accordion("Authentication", open=False):
  hf_token = gr.Textbox(label="Hugging Face Token", type="password")
  modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
+
  with gr.Column():
  target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
  new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-lora")
  private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
+
  status_output = gr.Markdown()
  detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
+
  convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
  repo_link_output = gr.HTML()
+
  convert_btn.click(
  fn=process_and_upload_fp8,
  inputs=[
  outputs=[repo_link_output, status_output, detailed_log],
  show_progress=True
  )
+
  gr.Examples(
  examples=[
  ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "unet_transformer"],
  inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
  label="Example Conversions"
  )
+
  gr.Markdown("""
  ## 💡 Usage Tips
+
+ - **For Text Encoders**: Use rank 32-64 with `text_encoder` architecture for optimal results.
+ - **For UNet Attention**: Use `unet_transformer` with rank 64-128 for best quality preservation.
+ - **For UNet Convolutions**: Use `unet_conv` with lower ranks (16-32) as convolutions compress better.
+ - **For VAE**: Use `vae` architecture with rank 16-32.
+ - **Auto Mode**: Let the tool analyze and target appropriate layers automatically.
+
+ ⚠️ **Note**: Higher ranks produce better quality but larger LoRA files. Start with lower ranks and increase if needed.
  """)

  demo.launch()
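The rank and architecture tips above trade LoRA file size against reconstruction fidelity. A minimal sketch of how that trade-off can be checked offline against the files this converter produces, mirroring the relative-error statistic computed during conversion; the file names are illustrative, and it assumes the `lora_A.`/`lora_B.` key prefixes and the B @ A orientation (A = down-projection, B = up-projection) used in the code above:

```python
import torch
from safetensors.torch import load_file

# Illustrative paths; substitute the actual source model and generated LoRA file.
original = load_file("model.safetensors")
lora = load_file("model-lora-r64-auto.safetensors")

for key, weight in original.items():
    a_key, b_key = f"lora_A.{key}", f"lora_B.{key}"
    if a_key not in lora or b_key not in lora:
        continue
    A = lora[a_key].to(torch.float32)  # [rank, in_features] or [rank, in, kH, kW]
    B = lora[b_key].to(torch.float32)  # [out_features, rank] or [out, rank, 1, 1]
    # Flatten both factors so the same check covers linear and conv tensors.
    approx = (B.reshape(B.shape[0], -1) @ A.reshape(A.shape[0], -1)).reshape(weight.shape)
    rel_err = torch.norm(weight.float() - approx) / torch.norm(weight.float())
    print(f"{key}: rank={A.shape[0]}, relative error={rel_err.item():.4f}")
```

If the reported errors are far above the target approximation factor for a component, a higher rank (or a narrower architecture selection) is usually the first thing to try.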