codemichaeld committed
Commit c31eee4 · verified · 1 Parent(s): 9f9518a

Update app.py

Files changed (1):
  1. app.py +178 -89
app.py CHANGED
@@ -18,10 +18,7 @@ except ImportError:
     MODELScope_AVAILABLE = False
 
 def low_rank_decomposition(weight, rank=128):
-    """
-    Improved LoRA decomposition that maintains compatibility with existing merge scripts.
-    This implementation focuses on extracting meaningful low-rank components from 2D weights.
-    """
+    """Improved LoRA decomposition that maintains compatibility with existing merge scripts."""
     if weight.ndim != 2:
         return None, None
 
@@ -34,10 +31,10 @@ def low_rank_decomposition(weight, rank=128):
 
         # Ensure rank doesn't exceed available singular values
         actual_rank = min(rank, len(S))
+        if actual_rank < 8:
+            return None, None
 
         # Create LoRA matrices using standard factorization
-        # W ≈ U[:, :r] * diag(S[:r]) * Vh[:r, :]
-        # We split as: A = Vh[:r, :], B = U[:, :r] * diag(S[:r])
         A = Vh[:actual_rank, :].contiguous()
         B = U[:, :actual_rank] @ torch.diag(S[:actual_rank])
 
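The comments removed above described the truncated-SVD split that the code still performs: W ≈ B @ A with A = Vh[:r, :] and B = U[:, :r] @ diag(S[:r]). A standalone sketch of that factorization, with arbitrary illustrative shapes, for readers checking the math outside the app:

```python
import torch

# Illustrative only: verify the rank-r split used by low_rank_decomposition.
m, n, r = 512, 768, 128          # arbitrary example shapes
W = torch.randn(m, n)

U, S, Vh = torch.linalg.svd(W, full_matrices=False)
A = Vh[:r, :].contiguous()                # (r, n)
B = U[:, :r] @ torch.diag(S[:r])          # (m, r)

rel_err = torch.norm(W - B @ A) / torch.norm(W)
print(f"relative error of the rank-{r} approximation: {rel_err:.4f}")
```

As in the committed code, the decomposition operates on the full weight matrix rather than on the FP8 residual.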
@@ -46,6 +43,36 @@
         print(f"Decomposition error: {e}")
         return None, None
 
+def extract_correction_factors(original_weight, fp8_weight):
+    """Extract per-channel/tensor correction factors instead of LoRA decomposition for VAE."""
+    with torch.no_grad():
+        # Convert to float32 for precision
+        orig = original_weight.float()
+        quant = fp8_weight.float()
+
+        # Compute error (what needs to be added to FP8 to recover original)
+        error = orig - quant
+
+        # Skip if error is negligible
+        error_norm = torch.norm(error)
+        orig_norm = torch.norm(orig)
+        if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
+            return None
+
+        # For 4D tensors (common in VAE), compute per-channel correction
+        if orig.ndim == 4:
+            # Channel dimension is typically dimension 0 (output channels)
+            channel_dim = 0
+            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
+            return channel_mean.to(original_weight.dtype)
+        # For 2D tensors, use per-row correction
+        elif orig.ndim == 2:
+            row_mean = error.mean(dim=1, keepdim=True)
+            return row_mean.to(original_weight.dtype)
+        else:
+            # For bias/batchnorm etc., use scalar correction
+            return error.mean().to(original_weight.dtype)
+
 def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=128, architecture="auto", progress=gr.Progress()):
     progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")
     try:
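To make the new helper concrete: a minimal round trip showing how a per-output-channel correction of the kind `extract_correction_factors` computes would be added back onto an FP8 conv weight at load time. Shapes and variable names here are illustrative, not from the commit; the FP8 cast needs PyTorch ≥ 2.1.

```python
import torch

# Illustrative sketch of the per-channel correction idea (not part of the commit).
w = torch.randn(128, 64, 3, 3)                      # example 4D conv weight
w_fp8 = w.to(torch.float8_e4m3fn)                   # FP8 quantization (PyTorch >= 2.1)
dequant = w_fp8.to(torch.float32)

error = w - dequant                                 # what FP8 rounding lost
corr = error.mean(dim=(1, 2, 3), keepdim=True)      # one value per output channel, shape (128, 1, 1, 1)

recovered = dequant + corr                          # broadcast-add at load time
# Compare per-element error before and after applying the correction.
print(torch.norm(w - dequant).item(), torch.norm(w - recovered).item())
```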
@@ -69,96 +96,133 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
 
         sd_fp8 = {}
         lora_weights = {}
+        correction_factors = {}
         total = len(state_dict)
-        lora_keys = []
         stats = {
             "total_layers": total,
             "eligible_layers": 0,
             "processed_layers": 0,
-            "skipped_layers": []
+            "correction_layers": 0,
+            "skipped_layers": [],
+            "architecture_detected": ""
         }
 
+        # Auto-detect architecture if needed
+        if architecture == "auto":
+            model_keys = " ".join(state_dict.keys()).lower()
+            if "text" in model_keys or "emb" in model_keys:
+                architecture = "text_encoder"
+            elif "vae" in model_keys or "encoder" in model_keys or "decoder" in model_keys:
+                architecture = "vae"
+            elif "attn" in model_keys or "transformer" in model_keys:
+                architecture = "transformer"
+            else:
+                architecture = "all"
+
+        stats["architecture_detected"] = architecture
+        use_correction = architecture == "vae"
+
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
+            lower_key = key.lower()
 
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
                 sd_fp8[key] = fp8_weight
 
-                # Apply architecture filtering
-                lower_key = key.lower()
+                # Determine if this layer should be processed based on architecture
                 should_process = False
 
                 if architecture == "text_encoder":
-                    should_process = "text" in lower_key or "emb" in lower_key or "encoder" in lower_key
+                    should_process = "text" in lower_key or "emb" in lower_key or "encoder" in lower_key or "attn" in lower_key
                 elif architecture == "transformer":
-                    should_process = "attn" in lower_key or "transformer" in lower_key
+                    should_process = "attn" in lower_key or "transformer" in lower_key or "mlp" in lower_key
                 elif architecture == "vae":
-                    should_process = "vae" in lower_key or "decoder" in lower_key or "encoder" in lower_key
+                    should_process = "vae" in lower_key or "decoder" in lower_key or "encoder" in lower_key or "conv" in lower_key
                 elif architecture == "all":
                     should_process = True
-                else:  # "auto" or unknown
+                else:  # "auto" fallback
                     should_process = True
 
-                # Only process 2D tensors that meet rank requirements and pass architecture filter
-                if should_process and weight.ndim == 2 and min(weight.shape) > lora_rank:
-                    stats["eligible_layers"] += 1
-                    try:
-                        A, B = low_rank_decomposition(weight, rank=lora_rank)
-                        if A is not None and B is not None:
-                            lora_weights[f"lora_A.{key}"] = A
-                            lora_weights[f"lora_B.{key}"] = B
-                            lora_keys.append(key)
+                if should_process:
+                    if use_correction:
+                        # For VAE, use correction factors instead of LoRA
+                        corr = extract_correction_factors(weight, fp8_weight)
+                        if corr is not None:
+                            correction_factors[f"correction.{key}"] = corr
+                            stats["correction_layers"] += 1
                             stats["processed_layers"] += 1
-                        else:
-                            stats["skipped_layers"].append(f"{key}: decomposition failed")
-                    except Exception as e:
-                        stats["skipped_layers"].append(f"{key}: error - {str(e)}")
-                elif should_process and weight.ndim == 2:
-                    # Handle smaller 2D tensors with reduced rank
-                    smaller_rank = min(lora_rank, min(weight.shape) // 2)
-                    if smaller_rank >= 8:  # Minimum useful rank
+                    else:
+                        # For other architectures, use LoRA
                         stats["eligible_layers"] += 1
-                        try:
-                            A, B = low_rank_decomposition(weight, rank=smaller_rank)
-                            if A is not None and B is not None:
-                                lora_weights[f"lora_A.{key}"] = A
-                                lora_weights[f"lora_B.{key}"] = B
-                                lora_keys.append(key)
-                                stats["processed_layers"] += 1
-                            else:
-                                stats["skipped_layers"].append(f"{key}: small tensor decomposition failed")
-                        except Exception as e:
-                            stats["skipped_layers"].append(f"{key}: small tensor error - {str(e)}")
+
+                        # Handle 2D tensors with standard LoRA
+                        if weight.ndim == 2:
+                            try:
+                                # Adjust rank for smaller matrices
+                                adjusted_rank = lora_rank
+                                if min(weight.shape) < lora_rank:
+                                    adjusted_rank = max(8, min(weight.shape) // 2)
+
+                                A, B = low_rank_decomposition(weight, rank=adjusted_rank)
+                                if A is not None and B is not None:
+                                    lora_weights[f"lora_A.{key}"] = A
+                                    lora_weights[f"lora_B.{key}"] = B
+                                    stats["processed_layers"] += 1
+                                else:
+                                    stats["skipped_layers"].append(f"{key}: decomposition failed")
+                            except Exception as e:
+                                stats["skipped_layers"].append(f"{key}: error - {str(e)}")
+                        # Skip 4D tensors for non-VAE architectures
+                        elif weight.ndim == 4:
+                            stats["skipped_layers"].append(f"{key}: 4D tensor skipped for non-VAE architecture")
             else:
                 sd_fp8[key] = weight
                 stats["skipped_layers"].append(f"{key}: non-float dtype")
 
         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-        lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}.safetensors")
 
+        # Save FP8 model
         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
 
-        # Always save LoRA file if any weights were processed
+        # Save LoRA weights if any were generated
        if lora_weights:
+            lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
             lora_metadata = {
                 "format": "pt",
                 "lora_rank": str(lora_rank),
                 "architecture": architecture,
-                "stats": json.dumps(stats)
+                "stats": json.dumps(stats),
+                "method": "lora"
             }
             save_file(lora_weights, lora_path, metadata=lora_metadata)
 
+        # Save correction factors if any were generated (for VAE)
+        if correction_factors:
+            correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
+            correction_metadata = {
+                "format": "pt",
+                "architecture": architecture,
+                "stats": json.dumps(stats),
+                "method": "correction"
+            }
+            save_file(correction_factors, correction_path, metadata=correction_metadata)
+
-        progress(0.9, desc="Saved FP8 and LoRA files.")
-        progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
+        progress(0.9, desc="Saved FP8 and LoRA/correction files.")
+        progress(1.0, desc="✅ FP8 + LoRA/correction extraction complete!")
 
-        stats_msg = f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved.\n"
-        stats_msg += f"Processed {stats['processed_layers']}/{stats['eligible_layers']} eligible layers."
+        stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
+        stats_msg += f"Architecture detected: {stats['architecture_detected']}\n"
+
-        if stats['processed_layers'] == 0:
-            stats_msg += "\n⚠️ No LoRA weights were generated. Try reducing rank or selecting a specific architecture."
+        if use_correction:
+            stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
+        else:
+            stats_msg += f"Processed {stats['processed_layers']}/{stats['eligible_layers']} eligible layers with LoRA rank {lora_rank}."
+
+        if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
+            stats_msg += "\n⚠️ No precision recovery weights were generated. Try a different architecture selection or parameters."
 
         return True, stats_msg, stats
 
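The auto-detect branch added in this hunk keys off substrings in the checkpoint's tensor names. Restated as a standalone helper for readability (hypothetical function name, same substring checks and precedence as the commit):

```python
def detect_architecture(keys):
    """Same substring checks and ordering as the committed auto-detect block."""
    joined = " ".join(keys).lower()
    if "text" in joined or "emb" in joined:
        return "text_encoder"
    if "vae" in joined or "encoder" in joined or "decoder" in joined:
        return "vae"
    if "attn" in joined or "transformer" in joined:
        return "transformer"
    return "all"


print(detect_architecture(["text_model.encoder.layers.0.self_attn.q_proj.weight"]))  # text_encoder
print(detect_architecture(["decoder.up_blocks.0.resnets.0.conv1.weight"]))           # vae
```

Because the checks run in that order, a checkpoint whose keys contain any "emb" substring (for example time-embedding layers) resolves to "text_encoder" before the other branches are considered.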
@@ -254,7 +318,7 @@ def process_and_upload_fp8(
             source_type, repo_url, safetensors_filename, hf_token, progress
         )
 
-        progress(0.25, desc="Converting to FP8 with LoRA extraction...")
+        progress(0.25, desc="Converting to FP8 with precision recovery...")
         success, msg, stats = convert_safetensors_to_fp8_with_lora(
             safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
         )
@@ -268,26 +332,33 @@ def process_and_upload_fp8(
         )
 
         base_name = os.path.splitext(safetensors_filename)[0]
-        lora_filename = f"{base_name}-lora-r{lora_rank}.safetensors"
         fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
 
+        # Determine which precision recovery file was generated
+        precision_recovery_file = ""
+        precision_recovery_type = "LoRA"
+        if stats.get("correction_layers", 0) > 0:
+            precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
+            precision_recovery_type = "Correction Factors"
+        elif stats.get("processed_layers", 0) > 0:
+            precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
+
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
-- lora
-- low-rank
+- precision-recovery
 - diffusion
 - converted-by-gradio
 ---
-# FP8 Model with Low-Rank LoRA
+# FP8 Model with Precision Recovery
 - **Source**: `{repo_url}`
 - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
-- **LoRA Rank**: {lora_rank}
 - **Architecture**: {architecture}
-- **LoRA File**: `{lora_filename}`
+- **Precision Recovery Type**: {precision_recovery_type}
+- **Precision Recovery File**: `{precision_recovery_file}`
 - **FP8 File**: `{fp8_filename}`
 ## Usage (Inference)
 ```python
@@ -295,18 +366,30 @@ from safetensors.torch import load_file
 import torch
 # Load FP8 model
 fp8_state = load_file("{fp8_filename}")
-lora_state = load_file("{lora_filename}")
-# Reconstruct approximate original weights
+# Load precision recovery file
+recovery_state = load_file("{precision_recovery_file}") if "{precision_recovery_file}" else {{}}
+# Reconstruct high-precision weights
 reconstructed = {{}}
 for key in fp8_state:
-    if f"lora_A.{{key}}" in lora_state and f"lora_B.{{key}}" in lora_state:
-        A = lora_state[f"lora_A.{{key}}"].to(torch.float32)
-        B = lora_state[f"lora_B.{{key}}"].to(torch.float32)
-        lora_weight = B @ A  # (out_features, rank) @ (rank, in_features) -> (out_features, in_features)
-        fp8_weight = fp8_state[key].to(torch.float32)
-        reconstructed[key] = fp8_weight + lora_weight
+    fp8_weight = fp8_state[key].to(torch.float32)
+    if recovery_state:
+        # For LoRA approach
+        if "lora_A" in recovery_state:
+            if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
+                A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
+                B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
+                lora_weight = B @ A
+                reconstructed[key] = fp8_weight + lora_weight
+            else:
+                reconstructed[key] = fp8_weight
+        # For correction factor approach
+        elif f"correction.{{key}}" in recovery_state:
+            correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
+            reconstructed[key] = fp8_weight + correction
+        else:
+            reconstructed[key] = fp8_weight
     else:
-        reconstructed[key] = fp8_state[key].to(torch.float32)
+        reconstructed[key] = fp8_weight
 ```
 > Requires PyTorch ≥ 2.1 for FP8 support.
 """
@@ -327,9 +410,9 @@ for key in fp8_state:
         result_html = f"""
 ✅ Success!
 Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-Includes: FP8 model + rank-{lora_rank} LoRA.
+Includes: FP8 model + precision recovery ({precision_recovery_type}).
 """
-        return gr.HTML(result_html), "✅ FP8 + LoRA upload successful!", msg
+        return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg
 
     except Exception as e:
         import traceback
@@ -341,9 +424,9 @@ Includes: FP8 model + rank-{lora_rank} LoRA.
         shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)
 
-with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
-    gr.Markdown("# 🔄 FP8 Pruner with Enhanced Low-Rank LoRA Extraction")
-    gr.Markdown("Convert `.safetensors` → **FP8** + **high-quality LoRA** for precision recovery. Supports Hugging Face ↔ ModelScope.")
+with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
+    gr.Markdown("# 🔄 FP8 Pruner with Architecture-Specific Precision Recovery")
+    gr.Markdown("Convert `.safetensors` → **FP8** + **precision recovery** (LoRA or correction factors). Supports Hugging Face ↔ ModelScope.")
 
     with gr.Row():
         with gr.Column():
@@ -353,14 +436,14 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
 
             with gr.Accordion("Advanced Settings", open=True):
                 fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
-                lora_rank = gr.Slider(minimum=8, maximum=512, step=8, value=128, label="LoRA Rank")
+                lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128, label="LoRA Rank (for text/transformers)")
                 architecture = gr.Dropdown(
                     choices=[
-                        ("Auto-detect components", "auto"),
-                        ("Text Encoder only", "text_encoder"),
-                        ("Transformer blocks only", "transformer"),
-                        ("VAE only", "vae"),
-                        ("All eligible layers", "all")
+                        ("Auto-detect architecture", "auto"),
+                        ("Text Encoder (LoRA)", "text_encoder"),
+                        ("Transformer blocks (LoRA)", "transformer"),
+                        ("VAE (Correction Factors)", "vae"),
+                        ("All layers (LoRA where applicable)", "all")
                     ],
                     value="auto",
                     label="Target Architecture"
@@ -372,7 +455,7 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
 
         with gr.Column():
             target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
-            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-lora")
+            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
             private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
 
     status_output = gr.Markdown()
@@ -402,25 +485,31 @@ with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
 
     gr.Examples(
         examples=[
-            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer", "huggingface"],
+            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "model.safetensors", "e5m2", 96, "text_encoder", "huggingface"],
             ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae", "diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae", "huggingface"],
-            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "model.safetensors", "e5m2", 96, "text_encoder", "huggingface"]
+            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer", "huggingface"]
         ],
         inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture, target_type],
         label="Example Conversions"
     )
 
     gr.Markdown("""
-    ## 💡 Usage Tips
+    ## 💡 Architecture-Specific Precision Recovery
+
+    This tool automatically selects the best precision recovery method based on architecture:
+
+    - **Text Encoder & Transformers**: Uses **LoRA decomposition** (best for attention layers)
+      - Higher ranks (96-128) recommended for text encoders
+      - Medium ranks (64-128) for transformers
+
+    - **VAE**: Uses **per-channel correction factors** (better for convolutional layers)
+      - No rank parameter needed - automatically computes channel-wise corrections
+      - Works with 4D convolutional weights that LoRA cannot handle well
 
-    - **Higher ranks (128-256)**: Best quality recovery for important layers
-    - **Smaller ranks (32-64)**: Good balance of quality and file size
-    - **Architecture selection**: Focus LoRA on specific components for better results
-    - **Text Encoder**: Use rank 96-128 for best text understanding
-    - **Transformers**: Use rank 128-256 for maximum quality retention
-    - **VAE**: Use rank 64-128 for good image reconstruction
+    - **Auto-detect**: Analyzes model structure to select appropriate method
 
-    > **Note**: This implementation maintains compatibility with existing merge scripts while providing significantly better precision recovery through improved LoRA extraction.
+    > **Note**: VAE models typically contain 4D convolutional weights that don't work well with standard LoRA.
+    > The correction factor approach used for VAE matches the successful method from the attached file.
     """)
 
 demo.launch()
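For readers weighing the rank setting exposed in the UI: a rank-r LoRA pair stores r·(m+n) values for an (m, n) layer versus m·n for the layer itself, so the relative size is easy to estimate. A quick back-of-the-envelope check with illustrative numbers:

```python
# Rough size estimate for one (m, n) weight at LoRA rank r: A is (r, n), B is (m, r).
m, n, r = 4096, 4096, 128
full_params = m * n                      # 16,777,216
lora_params = r * (m + n)                # 1,048,576
print(f"LoRA adds {lora_params / full_params:.1%} of the layer's parameter count")  # about 6.2%
```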
 