codemichaeld committed · verified · Commit 672b8b5 · 1 parent: c31eee4

Update app.py

Files changed (1): app.py (+304 −142)
app.py CHANGED
@@ -10,6 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
@@ -17,64 +19,192 @@ try:
except ImportError:
    MODELScope_AVAILABLE = False

- def low_rank_decomposition(weight, rank=128):
-     """Improved LoRA decomposition that maintains compatibility with existing merge scripts."""
-     if weight.ndim != 2:
        return None, None

    try:
-         # Convert to float32 for numerical stability during SVD
-         weight_f32 = weight.float()
-
-         # Perform SVD
-         U, S, Vh = torch.linalg.svd(weight_f32, full_matrices=False)
-
-         # Ensure rank doesn't exceed available singular values
-         actual_rank = min(rank, len(S))
-         if actual_rank < 8:
            return None, None

-         # Create LoRA matrices using standard factorization
-         A = Vh[:actual_rank, :].contiguous()
-         B = U[:, :actual_rank] @ torch.diag(S[:actual_rank])

-         return A.to(torch.float16), B.to(torch.float16)
    except Exception as e:
-         print(f"Decomposition error: {e}")
-         return None, None

def extract_correction_factors(original_weight, fp8_weight):
-     """Extract per-channel/tensor correction factors instead of LoRA decomposition for VAE."""
    with torch.no_grad():
-         # Convert to float32 for precision
        orig = original_weight.float()
        quant = fp8_weight.float()
-
-         # Compute error (what needs to be added to FP8 to recover original)
        error = orig - quant

-         # Skip if error is negligible
        error_norm = torch.norm(error)
        orig_norm = torch.norm(orig)
-         if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
            return None

-         # For 4D tensors (common in VAE), compute per-channel correction
        if orig.ndim == 4:
-             # Channel dimension is typically dimension 0 (output channels)
-             channel_dim = 0
            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
            return channel_mean.to(original_weight.dtype)
-         # For 2D tensors, use per-row correction
        elif orig.ndim == 2:
            row_mean = error.mean(dim=1, keepdim=True)
            return row_mean.to(original_weight.dtype)
        else:
-             # For bias/batchnorm etc., use scalar correction
            return error.mean().to(original_weight.dtype)

def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=128, architecture="auto", progress=gr.Progress()):
-     progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")
    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:
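For contrast with the error-based helper added later in this commit, the removed `low_rank_decomposition` above factorized the weight matrix itself, so everything beyond the kept rank was simply discarded. A minimal standalone sketch of that removed idea (illustrative shapes, plain PyTorch, not the app's exact code path):

```python
import torch

# Sketch of the removed approach: factor the weight itself, not the quantization error.
W = torch.randn(512, 256)
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
r = 128
A = Vh[:r, :]                      # (r, in_features)
B = U[:, :r] @ torch.diag(S[:r])   # (out_features, r)

# B @ A only approximates W; the truncated singular values are lost entirely,
# which is why the new code in this commit factorizes the (much smaller) FP8 error instead.
print((W - B @ A).norm() / W.norm())
```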
@@ -89,105 +219,111 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
        state_dict = load_file(safetensors_path)
        progress(0.4, desc="Loaded weights.")

-         if fp8_format == "e5m2":
-             fp8_dtype = torch.float8_e5m2
-         else:
-             fp8_dtype = torch.float8_e4m3fn

        sd_fp8 = {}
        lora_weights = {}
        correction_factors = {}
-         total = len(state_dict)
        stats = {
-             "total_layers": total,
            "eligible_layers": 0,
            "processed_layers": 0,
            "correction_layers": 0,
            "skipped_layers": [],
-             "architecture_detected": ""
        }

-         # Auto-detect architecture if needed
-         if architecture == "auto":
-             model_keys = " ".join(state_dict.keys()).lower()
-             if "text" in model_keys or "emb" in model_keys:
-                 architecture = "text_encoder"
-             elif "vae" in model_keys or "encoder" in model_keys or "decoder" in model_keys:
-                 architecture = "vae"
-             elif "attn" in model_keys or "transformer" in model_keys:
-                 architecture = "transformer"
-             else:
-                 architecture = "all"
-
-         stats["architecture_detected"] = architecture
-         use_correction = architecture == "vae"

        for i, key in enumerate(state_dict):
            progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
            weight = state_dict[key]
-             lower_key = key.lower()

            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
-                 fp8_weight = weight.to(fp8_dtype)
-                 sd_fp8[key] = fp8_weight

-                 # Determine if this layer should be processed based on architecture
-                 should_process = False

-                 if architecture == "text_encoder":
-                     should_process = "text" in lower_key or "emb" in lower_key or "encoder" in lower_key or "attn" in lower_key
-                 elif architecture == "transformer":
-                     should_process = "attn" in lower_key or "transformer" in lower_key or "mlp" in lower_key
-                 elif architecture == "vae":
-                     should_process = "vae" in lower_key or "decoder" in lower_key or "encoder" in lower_key or "conv" in lower_key
-                 elif architecture == "all":
-                     should_process = True
-                 else:  # "auto" fallback
-                     should_process = True

                if should_process:
-                     if use_correction:
-                         # For VAE, use correction factors instead of LoRA
-                         corr = extract_correction_factors(weight, fp8_weight)
-                         if corr is not None:
-                             correction_factors[f"correction.{key}"] = corr
-                             stats["correction_layers"] += 1
-                             stats["processed_layers"] += 1
-                     else:
-                         # For other architectures, use LoRA
-                         stats["eligible_layers"] += 1

-                         # Handle 2D tensors with standard LoRA
-                         if weight.ndim == 2:
                            try:
-                                 # Adjust rank for smaller matrices
-                                 adjusted_rank = lora_rank
-                                 if min(weight.shape) < lora_rank:
-                                     adjusted_rank = max(8, min(weight.shape) // 2)

-                                 A, B = low_rank_decomposition(weight, rank=adjusted_rank)
                                if A is not None and B is not None:
-                                     lora_weights[f"lora_A.{key}"] = A
-                                     lora_weights[f"lora_B.{key}"] = B
                                    stats["processed_layers"] += 1
                                else:
                                    stats["skipped_layers"].append(f"{key}: decomposition failed")
                            except Exception as e:
                                stats["skipped_layers"].append(f"{key}: error - {str(e)}")
-                         # Skip 4D tensors for non-VAE architectures
-                         elif weight.ndim == 4:
-                             stats["skipped_layers"].append(f"{key}: 4D tensor skipped for non-VAE architecture")
            else:
                sd_fp8[key] = weight
                stats["skipped_layers"].append(f"{key}: non-float dtype")

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")

-         # Save FP8 model
        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

-         # Save LoRA weights if any were generated
        if lora_weights:
            lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
            lora_metadata = {
@@ -199,7 +335,6 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
            }
            save_file(lora_weights, lora_path, metadata=lora_metadata)

-         # Save correction factors if any were generated (for VAE)
        if correction_factors:
            correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
            correction_metadata = {
@@ -210,24 +345,25 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
            }
            save_file(correction_factors, correction_path, metadata=correction_metadata)

-         progress(0.9, desc="Saved FP8 and LoRA/correction files.")
-         progress(1.0, desc="✅ FP8 + LoRA/correction extraction complete!")

        stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
-         stats_msg += f"Architecture detected: {stats['architecture_detected']}\n"

-         if use_correction:
            stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
        else:
-             stats_msg += f"Processed {stats['processed_layers']}/{stats['eligible_layers']} eligible layers with LoRA rank {lora_rank}."

        if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
-             stats_msg += "\n⚠️ No precision recovery weights were generated. Try a different architecture selection or parameters."

        return True, stats_msg, stats

    except Exception as e:
-         import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        return False, error_msg, None
@@ -336,12 +472,13 @@ def process_and_upload_fp8(

        # Determine which precision recovery file was generated
        precision_recovery_file = ""
-         precision_recovery_type = "LoRA"
-         if stats.get("correction_layers", 0) > 0:
            precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
            precision_recovery_type = "Correction Factors"
-         elif stats.get("processed_layers", 0) > 0:
            precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"

        readme = f"""---
library_name: diffusers
@@ -358,42 +495,51 @@ tags:
- **FP8 Format**: `{fp8_format.upper()}`
- **Architecture**: {architecture}
- **Precision Recovery Type**: {precision_recovery_type}
- - **Precision Recovery File**: `{precision_recovery_file}`
- **FP8 File**: `{fp8_filename}`

## Usage (Inference)
```python
from safetensors.torch import load_file
import torch

# Load FP8 model
fp8_state = load_file("{fp8_filename}")
- # Load precision recovery file
- recovery_state = load_file("{precision_recovery_file}") if "{precision_recovery_file}" else {{}}

# Reconstruct high-precision weights
reconstructed = {{}}
for key in fp8_state:
-     fp8_weight = fp8_state[key].to(torch.float32)

    if recovery_state:
        # For LoRA approach
-         if "lora_A" in recovery_state:
-             if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
-                 A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
-                 B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
-                 lora_weight = B @ A
-                 reconstructed[key] = fp8_weight + lora_weight
-             else:
-                 reconstructed[key] = fp8_weight
        # For correction factor approach
        elif f"correction.{{key}}" in recovery_state:
            correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
-             reconstructed[key] = fp8_weight + correction
        else:
-             reconstructed[key] = fp8_weight
    else:
-         reconstructed[key] = fp8_weight
```
- > Requires PyTorch ≥ 2.1 for FP8 support.

"""
-
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

@@ -407,17 +553,22 @@ for key in fp8_state:
        )

        progress(1.0, desc="✅ Done!")
        result_html = f"""
✅ Success!
Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
- Includes: FP8 model + precision recovery ({precision_recovery_type}).
"""

        return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg

    except Exception as e:
-         import traceback
-         error_details = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
-         return None, error_details, ""

    finally:
        if temp_dir:
@@ -425,8 +576,8 @@ Includes: FP8 model + precision recovery ({precision_recovery_type}).
            shutil.rmtree(output_dir, ignore_errors=True)

with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
-     gr.Markdown("# 🔄 FP8 Pruner with Architecture-Specific Precision Recovery")
-     gr.Markdown("Convert `.safetensors` → **FP8** + **precision recovery** (LoRA or correction factors). Supports Hugging Face ↔ ModelScope.")

    with gr.Row():
        with gr.Column():
@@ -436,13 +587,15 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:

            with gr.Accordion("Advanced Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
-                 lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128, label="LoRA Rank (for text/transformers)")
                architecture = gr.Dropdown(
                    choices=[
                        ("Auto-detect architecture", "auto"),
                        ("Text Encoder (LoRA)", "text_encoder"),
                        ("Transformer blocks (LoRA)", "transformer"),
                        ("VAE (Correction Factors)", "vae"),
                        ("All layers (LoRA where applicable)", "all")
                    ],
                    value="auto",
@@ -451,11 +604,12 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:

            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
-                 modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)

        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
-             new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)

    status_output = gr.Markdown()
@@ -485,31 +639,39 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:

    gr.Examples(
        examples=[
-             ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "model.safetensors", "e5m2", 96, "text_encoder", "huggingface"],
-             ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae", "diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae", "huggingface"],
-             ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer", "huggingface"]
        ],
-         inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture, target_type],
        label="Example Conversions"
    )

    gr.Markdown("""
-     ## 💡 Architecture-Specific Precision Recovery

-     This tool automatically selects the best precision recovery method based on architecture:

-     - **Text Encoder & Transformers**: Uses **LoRA decomposition** (best for attention layers)
-       - Higher ranks (96-128) recommended for text encoders
-       - Medium ranks (64-128) for transformers

-     - **VAE**: Uses **per-channel correction factors** (better for convolutional layers)
-       - No rank parameter needed - automatically computes channel-wise corrections
-       - Works with 4D convolutional weights that LoRA cannot handle well

-     - **Auto-detect**: Analyzes model structure to select appropriate method

-     > **Note**: VAE models typically contain 4D convolutional weights that don't work well with standard LoRA.
-     > The correction factor approach used for VAE matches the successful method from the attached file.
    """)

demo.launch()
 
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
+ import traceback
+ import math
try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi

except ImportError:
    MODELScope_AVAILABLE = False

+ def get_fp8_dtype(fp8_format):
+     """Get torch FP8 dtype."""
+     if fp8_format == "e5m2":
+         return torch.float8_e5m2
+     else:
+         return torch.float8_e4m3fn
+
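`get_fp8_dtype` maps the UI choice onto one of PyTorch's two FP8 dtypes, which trade precision for range. A quick way to see the difference (a minimal sketch, assuming PyTorch ≥ 2.1, where these dtypes and their `finfo` are available):

```python
import torch

# e4m3fn keeps more mantissa bits (finer steps); e5m2 keeps more exponent bits (wider range).
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)
```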
+ def quantize_and_get_error(weight, fp8_dtype):
+     """Quantize weight to FP8 and return both quantized weight and error."""
+     weight_fp8 = weight.to(fp8_dtype)
+     weight_dequantized = weight_fp8.to(weight.dtype)
+     error = weight - weight_dequantized
+     return weight_fp8, error
+
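`quantize_and_get_error` relies on the fact that casting to FP8 and back is lossy, but the loss is recoverable if the difference is stored alongside the FP8 copy. A minimal sketch of that round trip (assumes PyTorch ≥ 2.1):

```python
import torch

# The dequantized FP8 copy plus the recorded error reproduces the original weight.
w = torch.randn(4, 4, dtype=torch.float32)
w_fp8 = w.to(torch.float8_e4m3fn)        # lossy FP8 cast
error = w - w_fp8.to(torch.float32)      # what the cast threw away
restored = w_fp8.to(torch.float32) + error
assert torch.allclose(restored, w)       # adding the error back recovers the original
```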
+ def low_rank_decomposition_error(error_tensor, rank=32, min_error_threshold=1e-6):
+     """Decompose error tensor with proper rank reduction."""
+     if error_tensor.ndim not in [2, 4]:
        return None, None

    try:
+         # Calculate error magnitude
+         error_norm = torch.norm(error_tensor.float())
+         if error_norm < min_error_threshold:
            return None, None

+         # For 2D tensors (linear layers)
+         if error_tensor.ndim == 2:
+             U, S, Vh = torch.linalg.svd(error_tensor.float(), full_matrices=False)
+
+             # Calculate rank based on variance explained (keep 95% of error)
+             total_variance = torch.sum(S ** 2)
+             cumulative = torch.cumsum(S ** 2, dim=0)
+             keep_components = torch.sum(cumulative <= 0.95 * total_variance).item() + 1
+
+             # Limit rank to much smaller than original
+             max_rank = min(error_tensor.shape)
+             actual_rank = min(rank, keep_components, max_rank // 2)
+
+             if actual_rank < 2:
+                 return None, None
+
+             A = Vh[:actual_rank, :].contiguous()
+             B = U[:, :actual_rank] @ torch.diag(S[:actual_rank]).contiguous()
+
+             return A, B

+         # For 4D convolutions
+         elif error_tensor.ndim == 4:
+             out_ch, in_ch, kH, kW = error_tensor.shape
+
+             # Reshape to 2D for decomposition
+             error_2d = error_tensor.view(out_ch, in_ch * kH * kW)
+             U, S, Vh = torch.linalg.svd(error_2d.float(), full_matrices=False)
+
+             # Calculate rank based on variance explained (90% for conv)
+             total_variance = torch.sum(S ** 2)
+             cumulative = torch.cumsum(S ** 2, dim=0)
+             keep_components = torch.sum(cumulative <= 0.90 * total_variance).item() + 1
+
+             # Use even lower rank for conv
+             max_rank = min(error_2d.shape)
+             actual_rank = min(rank // 2, keep_components, max_rank // 4)
+
+             if actual_rank < 2:
+                 return None, None
+
+             A = Vh[:actual_rank, :].contiguous()
+             B = U[:, :actual_rank] @ torch.diag(S[:actual_rank]).contiguous()
+
+             # Reshape back for convolutional format
+             if kH == 1 and kW == 1:
+                 B = B.view(out_ch, actual_rank, 1, 1)
+                 A = A.view(actual_rank, in_ch, 1, 1)
+             else:
+                 B = B.view(out_ch, actual_rank, 1, 1)
+                 A = A.view(actual_rank, in_ch, kH, kW)
+
+             return A, B
+
    except Exception as e:
+         print(f"Error decomposition failed: {e}")
+
+     return None, None
106
  def extract_correction_factors(original_weight, fp8_weight):
107
+ """Extract simple correction factors for VAE."""
108
  with torch.no_grad():
 
109
  orig = original_weight.float()
110
  quant = fp8_weight.float()
 
 
111
  error = orig - quant
112
 
 
113
  error_norm = torch.norm(error)
114
  orig_norm = torch.norm(orig)
115
+ if orig_norm > 1e-6 and error_norm / orig_norm < 0.001:
116
  return None
117
 
118
+ # For 4D tensors (VAE), compute per-channel correction
119
  if orig.ndim == 4:
 
 
120
  channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
121
  return channel_mean.to(original_weight.dtype)
 
122
  elif orig.ndim == 2:
123
  row_mean = error.mean(dim=1, keepdim=True)
124
  return row_mean.to(original_weight.dtype)
125
  else:
 
126
  return error.mean().to(original_weight.dtype)
127
 
128
+ def get_architecture_settings(architecture, base_rank):
129
+ """Get optimal settings for different architectures."""
130
+ settings = {
131
+ "text_encoder": {
132
+ "rank": base_rank,
133
+ "error_threshold": 5e-5,
134
+ "min_rank": 8,
135
+ "max_rank_factor": 0.4,
136
+ "method": "lora"
137
+ },
138
+ "transformer": {
139
+ "rank": base_rank,
140
+ "error_threshold": 1e-5,
141
+ "min_rank": 12,
142
+ "max_rank_factor": 0.35,
143
+ "method": "lora"
144
+ },
145
+ "vae": {
146
+ "rank": base_rank // 2,
147
+ "error_threshold": 1e-4,
148
+ "min_rank": 4,
149
+ "max_rank_factor": 0.3,
150
+ "method": "correction"
151
+ },
152
+ "unet_conv": {
153
+ "rank": base_rank // 3,
154
+ "error_threshold": 2e-5,
155
+ "min_rank": 8,
156
+ "max_rank_factor": 0.25,
157
+ "method": "lora"
158
+ },
159
+ "auto": {
160
+ "rank": base_rank,
161
+ "error_threshold": 1e-5,
162
+ "min_rank": 8,
163
+ "max_rank_factor": 0.3,
164
+ "method": "lora"
165
+ },
166
+ "all": {
167
+ "rank": base_rank,
168
+ "error_threshold": 1e-5,
169
+ "min_rank": 8,
170
+ "max_rank_factor": 0.3,
171
+ "method": "lora"
172
+ }
173
+ }
174
+
175
+ return settings.get(architecture, settings["auto"])
176
+
177
+ def should_process_layer(key, weight, architecture):
178
+ """Determine if layer should be processed for LoRA/correction."""
179
+ lower_key = key.lower()
180
+
181
+ # Skip biases and normalization layers
182
+ if 'bias' in key or 'norm' in key.lower() or 'bn' in key.lower():
183
+ return False
184
+
185
+ if weight.numel() < 100:
186
+ return False
187
+
188
+ # Architecture-specific filtering
189
+ if architecture == "text_encoder":
190
+ return ('text' in lower_key or 'emb' in lower_key or
191
+ 'encoder' in lower_key or 'attn' in lower_key)
192
+ elif architecture == "transformer":
193
+ return ('attn' in lower_key or 'transformer' in lower_key or
194
+ 'mlp' in lower_key or 'to_out' in lower_key)
195
+ elif architecture == "vae":
196
+ return ('vae' in lower_key or 'encoder' in lower_key or
197
+ 'decoder' in lower_key or 'conv' in lower_key)
198
+ elif architecture == "unet_conv":
199
+ return ('conv' in lower_key or 'resnet' in lower_key or
200
+ 'downsample' in lower_key or 'upsample' in lower_key)
201
+ elif architecture in ["all", "auto"]:
202
+ return True
203
+
204
+ return False
205
+
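As a quick illustration of the filter above (the tensor names are made up, and `should_process_layer` from this diff is assumed to be in scope), only sufficiently large non-bias, non-norm weights matching the architecture keywords pass:

```python
import torch

# Illustrative keys; assumes should_process_layer defined above is importable/in scope.
state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.weight": torch.randn(768, 768),
    "text_model.encoder.layers.0.self_attn.q_proj.bias": torch.randn(768),
    "text_model.encoder.layers.0.layer_norm1.weight": torch.randn(768),
}
eligible = [k for k, w in state_dict.items()
            if should_process_layer(k, w, "text_encoder")]
# Only q_proj.weight survives: the bias and the norm weight are filtered out.
```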
def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=128, architecture="auto", progress=gr.Progress()):
+     progress(0.1, desc="Starting FP8 conversion with error recovery...")
    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:

        state_dict = load_file(safetensors_path)
        progress(0.4, desc="Loaded weights.")

+         # Auto-detect architecture if needed
+         if architecture == "auto":
+             model_keys = " ".join(state_dict.keys()).lower()
+             if "vae" in model_keys or ("encoder" in model_keys and "decoder" in model_keys):
+                 architecture = "vae"
+             elif "text" in model_keys or "emb" in model_keys:
+                 architecture = "text_encoder"
+             elif "attn" in model_keys or "transformer" in model_keys:
+                 architecture = "transformer"
+             elif "conv" in model_keys or "resnet" in model_keys:
+                 architecture = "unet_conv"
+             else:
+                 architecture = "all"
+
+         settings = get_architecture_settings(architecture, lora_rank)
+         fp8_dtype = get_fp8_dtype(fp8_format)

        sd_fp8 = {}
        lora_weights = {}
        correction_factors = {}
        stats = {
+             "total_layers": len(state_dict),
            "eligible_layers": 0,
+             "layers_with_error": 0,
            "processed_layers": 0,
            "correction_layers": 0,
            "skipped_layers": [],
+             "architecture": architecture,
+             "method": settings["method"],
+             "error_magnitudes": []
        }

+         total = len(state_dict)

        for i, key in enumerate(state_dict):
            progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
            weight = state_dict[key]

            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
+                 # Quantize to FP8 and calculate error
+                 weight_fp8, error = quantize_and_get_error(weight, fp8_dtype)
+                 sd_fp8[key] = weight_fp8
+
+                 # Calculate error magnitude
+                 error_norm = torch.norm(error.float())
+                 weight_norm = torch.norm(weight.float())
+                 relative_error = (error_norm / weight_norm).item() if weight_norm > 0 else 0

+                 stats["error_magnitudes"].append({
+                     "key": key,
+                     "relative_error": relative_error
+                 })

+                 # Check if layer should be processed
+                 should_process = should_process_layer(key, weight, architecture)

                if should_process:
+                     stats["eligible_layers"] += 1
+
+                     # Only process if error is significant
+                     if relative_error > settings["error_threshold"]:
+                         stats["layers_with_error"] += 1

+                         if settings["method"] == "correction":
+                             # Use correction factors for VAE
+                             correction = extract_correction_factors(weight, weight_fp8)
+                             if correction is not None:
+                                 correction_factors[f"correction.{key}"] = correction
+                                 stats["correction_layers"] += 1
+                                 stats["processed_layers"] += 1
+                         else:
+                             # Use LoRA decomposition for other architectures
                            try:
+                                 A, B = low_rank_decomposition_error(
+                                     error,
+                                     rank=settings["rank"],
+                                     min_error_threshold=settings["error_threshold"]
+                                 )

                                if A is not None and B is not None:
+                                     lora_weights[f"lora_A.{key}"] = A.to(torch.float16)
+                                     lora_weights[f"lora_B.{key}"] = B.to(torch.float16)
                                    stats["processed_layers"] += 1
                                else:
                                    stats["skipped_layers"].append(f"{key}: decomposition failed")
                            except Exception as e:
                                stats["skipped_layers"].append(f"{key}: error - {str(e)}")
+                     else:
+                         stats["skipped_layers"].append(f"{key}: error too small ({relative_error:.6f})")
            else:
                sd_fp8[key] = weight
                stats["skipped_layers"].append(f"{key}: non-float dtype")

+         # Calculate average error
+         if stats["error_magnitudes"]:
+             errors = [e["relative_error"] for e in stats["error_magnitudes"]]
+             stats["avg_error"] = sum(errors) / len(errors) if errors else 0
+             stats["max_error"] = max(errors) if errors else 0
+
        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")

        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

+         # Save precision recovery weights
        if lora_weights:
            lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
            lora_metadata = {

            }
            save_file(lora_weights, lora_path, metadata=lora_metadata)

        if correction_factors:
            correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
            correction_metadata = {

            }
            save_file(correction_factors, correction_path, metadata=correction_metadata)

+         progress(0.9, desc="Saved FP8 and precision recovery files.")
+         progress(1.0, desc="✅ FP8 + precision recovery extraction complete!")

        stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
+         stats_msg += f"Architecture: {architecture}\n"
+         stats_msg += f"Method: {settings['method']}\n"
+         stats_msg += f"Average quantization error: {stats.get('avg_error', 0):.6f}\n"

+         if settings["method"] == "correction":
            stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
        else:
+             stats_msg += f"LoRA generated for {stats['processed_layers']}/{stats['eligible_layers']} eligible layers (rank {lora_rank})."

        if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
+             stats_msg += "\n⚠️ No precision recovery weights were generated. FP8 quantization error may be too small."

        return True, stats_msg, stats

    except Exception as e:
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        return False, error_msg, None

 
473
  # Determine which precision recovery file was generated
474
  precision_recovery_file = ""
475
+ precision_recovery_type = ""
476
+ if stats.get("method") == "correction" and stats.get("correction_layers", 0) > 0:
477
  precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
478
  precision_recovery_type = "Correction Factors"
479
+ elif stats.get("method") == "lora" and stats.get("processed_layers", 0) > 0:
480
  precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
481
+ precision_recovery_type = "LoRA"
482
 
483
  readme = f"""---
484
  library_name: diffusers
 
495
  - **FP8 Format**: `{fp8_format.upper()}`
496
  - **Architecture**: {architecture}
497
  - **Precision Recovery Type**: {precision_recovery_type}
498
+ - **Precision Recovery File**: `{precision_recovery_file}` if available
499
  - **FP8 File**: `{fp8_filename}`
500
+
501
  ## Usage (Inference)
502
  ```python
503
  from safetensors.torch import load_file
504
  import torch
505
+
506
  # Load FP8 model
507
  fp8_state = load_file("{fp8_filename}")
508
+
509
+ # Load precision recovery file if available
510
+ recovery_state = {{}}
511
+ if "{precision_recovery_file}":
512
+ recovery_state = load_file("{precision_recovery_file}")
513
+
514
  # Reconstruct high-precision weights
515
  reconstructed = {{}}
516
  for key in fp8_state:
517
+ # Dequantize FP8 to target precision
518
+ fp_weight = fp8_state[key].to(torch.float32)
519
+
520
  if recovery_state:
521
  # For LoRA approach
522
+ if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
523
+ A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
524
+ B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
525
+ error_correction = B @ A
526
+ reconstructed[key] = fp_weight + error_correction
 
 
 
527
  # For correction factor approach
528
  elif f"correction.{{key}}" in recovery_state:
529
  correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
530
+ reconstructed[key] = fp_weight + correction
531
  else:
532
+ reconstructed[key] = fp_weight
533
  else:
534
+ reconstructed[key] = fp_weight
535
+
536
+ print("Model reconstructed with FP8 error recovery")
537
  ```
538
+
539
+ > **Note**: This precision recovery targets FP8 quantization errors.
540
+ > Average quantization error: {stats.get('avg_error', 0):.6f}
541
  """
542
+
543
  with open(os.path.join(output_dir, "README.md"), "w") as f:
544
  f.write(readme)
545
 
 
        )

        progress(1.0, desc="✅ Done!")
+
        result_html = f"""
✅ Success!
Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
+ Includes: FP8 model + precision recovery ({precision_recovery_type}).
+ Average quantization error: {stats.get('avg_error', 0):.6f}
"""
+
+         if stats['processed_layers'] > 0 or stats['correction_layers'] > 0:
+             result_html += f"<br>Precision recovery applied to {stats['processed_layers'] + stats['correction_layers']} layers."
+
        return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg

    except Exception as e:
+         error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
+         return None, error_msg, ""

    finally:
        if temp_dir:

            shutil.rmtree(output_dir, ignore_errors=True)

with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
+     gr.Markdown("# 🔄 FP8 Converter with Architecture-Specific Precision Recovery")
+     gr.Markdown("Convert models to **FP8** with **error-based precision recovery**.")

    with gr.Row():
        with gr.Column():

            with gr.Accordion("Advanced Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
+                 lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128,
+                                       label="LoRA Rank (for text/transformers)")
                architecture = gr.Dropdown(
                    choices=[
                        ("Auto-detect architecture", "auto"),
                        ("Text Encoder (LoRA)", "text_encoder"),
                        ("Transformer blocks (LoRA)", "transformer"),
                        ("VAE (Correction Factors)", "vae"),
+                         ("UNet Convolutions (LoRA)", "unet_conv"),
                        ("All layers (LoRA where applicable)", "all")
                    ],
                    value="auto",

            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
+                 modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password",
+                                               visible=MODELScope_AVAILABLE)

        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
+             new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-precision")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)

    status_output = gr.Markdown()

    gr.Examples(
        examples=[
+             ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
+              "model.safetensors", "e5m2", 96, "text_encoder"],
+             ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae",
+              "diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae"],
+             ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main",
+              "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer"]
        ],
+         inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
        label="Example Conversions"
    )

    gr.Markdown("""
+     ## 🎯 What This Tool Does
+
+     Unlike traditional LoRA fine-tuning, this tool:

+     1. **Quantizes** the model to FP8 (loses precision)
+     2. **Measures** the quantization error for each weight
+     3. **Extracts recovery weights** that specifically recover this error
+     4. **Only applies** recovery where error is significant (>0.001%)

+     ## 💡 Recommended Settings

+     - **Text Encoders**: rank 64-96 (text is sensitive)
+     - **Transformers**: rank 96-128
+     - **VAE**: Uses correction factors (no rank needed)
+     - **UNet Convolutions**: rank 32-64

+     ## ⚠️ Important Notes

+     - This recovers **FP8 quantization errors**, not fine-tuning changes
+     - If FP8 error is tiny (<0.0001%), recovery may not be generated
+     - Higher rank ≠ better for error recovery (use recommended ranges)
    """)

demo.launch()
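The four steps listed in the help text can be exercised end to end on a toy state dict. A minimal sketch that mirrors the quantize → measure → decompose → recover flow (hypothetical layer name, illustrative rank, assumes PyTorch ≥ 2.1):

```python
import torch

# Toy end-to-end sketch of the FP8 + error-recovery idea on a single linear weight.
torch.manual_seed(0)
state_dict = {"blocks.0.attn.to_q.weight": torch.randn(256, 256, dtype=torch.float16)}

fp8_sd, lora_sd = {}, {}
for key, w in state_dict.items():
    w_fp8 = w.to(torch.float8_e4m3fn)          # 1. quantize
    error = (w - w_fp8.to(w.dtype)).float()    # 2. measure the quantization error
    U, S, Vh = torch.linalg.svd(error, full_matrices=False)
    r = 32                                     # 3. low-rank factorization of the error
    lora_sd[f"lora_A.{key}"] = Vh[:r, :].to(torch.float16)
    lora_sd[f"lora_B.{key}"] = (U[:, :r] @ torch.diag(S[:r])).to(torch.float16)
    fp8_sd[key] = w_fp8

# 4. recovery at load time: dequantized FP8 weight + B @ A
key = "blocks.0.attn.to_q.weight"
recovered = fp8_sd[key].to(torch.float32) + (
    lora_sd[f"lora_B.{key}"].float() @ lora_sd[f"lora_A.{key}"].float()
)
print((recovered - state_dict[key].float()).norm() / state_dict[key].float().norm())
```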