phazei committed on
Commit c1f87d4 · 1 Parent(s): 0873e19

Update conversion script and convert more triple-block weights to FP8

convert_safetensors_to_fp8.py CHANGED
@@ -1,192 +1,231 @@
- import torch
- import os
  import argparse
- from safetensors.torch import save_file
- from safetensors import safe_open
- from collections import OrderedDict
- from tqdm import tqdm
- import gc

- def should_convert_to_fp8(tensor_name: str) -> bool:
-     """
-     Conservative FP8 conversion policy:
-     - Only convert .weight tensors (not biases)
-     - Only convert transformer block layers
-     - Skip normalization layers (precision sensitive)
-     """
-     if not tensor_name.endswith(".weight"):
          return False
-     if not "blocks." in tensor_name:
          return False
-     if "cross_attn" in tensor_name or \
-        "ffn" in tensor_name or \
-        "self_attn" in tensor_name or \
-        "linear" in tensor_name:  # Added "linear" for broader coverage
-         if ".norm_k.weight" in tensor_name or \
-            ".norm_q.weight" in tensor_name or \
-            ".norm.weight" in tensor_name:
-             return False
-         return True
      return False

- def convert_safetensors_to_fp8(input_path: str, fp8_variant: str = "e4m3fn", device: str = "cuda"):
      """
-     Convert a single SafeTensors file to FP8, saving in the same directory
-     with fp8_e*** appended to the filename.
-
-     Args:
-         input_path: Path to input .safetensors file
-         fp8_variant: "e4m3fn" or "e5m2"
-         device: Device to use for conversion ("cuda" or "cpu")
      """
-     if not os.path.exists(input_path):
-         raise FileNotFoundError(f"Input file not found: {input_path}")
-
-     if not input_path.endswith('.safetensors'):
-         raise ValueError("Input file must be a .safetensors file")
-
-     # Determine target dtype
-     if fp8_variant == "e5m2":
-         target_dtype = torch.float8_e5m2
-     elif fp8_variant == "e4m3fn":
-         target_dtype = torch.float8_e4m3fn
-     else:
-         raise ValueError(f"Unsupported FP8 variant: {fp8_variant}. Use 'e4m3fn' or 'e5m2'")
-
-     # Generate output path
-     input_dir = os.path.dirname(input_path)
-     input_filename = os.path.basename(input_path)
-     name_without_ext = os.path.splitext(input_filename)[0]
-     output_filename = f"{name_without_ext}_fp8_{fp8_variant}.safetensors"
-     output_path = os.path.join(input_dir, output_filename)
-
-     print(f"Converting: {input_path}")
-     print(f"Output: {output_path}")
-     print(f"Target dtype: {target_dtype}")
-     print(f"Device: {device}")
-
-     # Check if output already exists
-     if os.path.exists(output_path):
-         response = input(f"Output file {output_path} already exists. Overwrite? (y/N): ")
-         if response.lower() != 'y':
-             print("Conversion cancelled.")
-             return
-
-     converted_state_dict = OrderedDict()
-     conversion_stats = {"converted": 0, "skipped": 0, "total": 0}
-
-     try:
-         # Load and process tensors
-         with safe_open(input_path, framework="pt", device="cpu") as f:
-             tensor_names = list(f.keys())
-             conversion_stats["total"] = len(tensor_names)
-
-             print(f"Processing {len(tensor_names)} tensors...")
-
-             for tensor_name in tqdm(tensor_names, desc="Converting tensors"):
-                 original_tensor = f.get_tensor(tensor_name)
-
-                 if should_convert_to_fp8(tensor_name):
-                     # Convert to FP8
-                     converted_tensor = original_tensor.to(device).to(target_dtype).to("cpu")
-                     converted_state_dict[tensor_name] = converted_tensor
-                     conversion_stats["converted"] += 1
-
-                     # Clean up GPU memory if using CUDA
-                     if device == "cuda" and torch.cuda.is_available():
-                         del converted_tensor
-                 else:
-                     # Keep original precision
-                     converted_state_dict[tensor_name] = original_tensor.to("cpu")
-                     conversion_stats["skipped"] += 1
-
-         # Save converted model
-         print(f"Saving converted model to: {output_path}")
-         save_file(converted_state_dict, output_path)
-
-         # Print conversion statistics
-         print(f"\nConversion complete!")
-         print(f"Total tensors: {conversion_stats['total']}")
-         print(f"Converted to FP8: {conversion_stats['converted']}")
-         print(f"Kept original precision: {conversion_stats['skipped']}")
-         print(f"Conversion rate: {conversion_stats['converted']/conversion_stats['total']*100:.1f}%")
-
-         # Calculate file sizes
-         input_size = os.path.getsize(input_path) / (1024**3)  # GB
-         output_size = os.path.getsize(output_path) / (1024**3)  # GB
-         size_reduction = (1 - output_size/input_size) * 100
-
-         print(f"\nFile size comparison:")
-         print(f"Original: {input_size:.2f} GB")
-         print(f"Converted: {output_size:.2f} GB")
-         print(f"Size reduction: {size_reduction:.1f}%")
-
-     except Exception as e:
-         print(f"Error during conversion: {e}")
-         if os.path.exists(output_path):
-             print(f"Removing incomplete output file: {output_path}")
-             os.remove(output_path)
-         raise
-
-     finally:
-         # Clean up memory
-         if 'converted_state_dict' in locals():
-             del converted_state_dict
-         if 'original_tensor' in locals():
-             del original_tensor
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()

  def main():
-     parser = argparse.ArgumentParser(
-         description="Convert SafeTensors model to FP8 precision",
-         formatter_class=argparse.RawDescriptionHelpFormatter,
-         epilog="""
- Examples:
-   python convert_safetensors_to_fp8.py model.safetensors
-   python convert_safetensors_to_fp8.py model.safetensors --variant e5m2
-   python convert_safetensors_to_fp8.py model.safetensors --device cpu
-         """
-     )
-
-     parser.add_argument(
-         "input_file",
-         help="Path to input .safetensors file"
-     )
-
-     parser.add_argument(
-         "--variant",
-         choices=["e4m3fn", "e5m2"],
-         default="e4m3fn",
-         help="FP8 variant to use (default: e4m3fn)"
-     )
-
-     parser.add_argument(
-         "--device",
-         choices=["cuda", "cpu"],
-         default="cuda" if torch.cuda.is_available() else "cpu",
-         help="Device to use for conversion"
      )
-
-     args = parser.parse_args()
-
-     print(f"=== SafeTensors to FP8 Converter ===")
-     print(f"PyTorch version: {torch.__version__}")
-     print(f"CUDA available: {torch.cuda.is_available()}")
-     if torch.cuda.is_available():
-         print(f"CUDA device: {torch.cuda.get_device_name()}")
-     print()
-
-     try:
-         convert_safetensors_to_fp8(
-             input_path=args.input_file,
-             fp8_variant=args.variant,
-             device=args.device
-         )
-     except Exception as e:
-         print(f"Conversion failed: {e}")
-         exit(1)

  if __name__ == "__main__":
-     main()

+ #!/usr/bin/env python3
+ """
+ Mixed-FP8 safetensors converter for Hunyuan-Foley checkpoints.
+
+ - Converts selected .weight tensors to FP8 storage (E5M2 by default on pre-Hopper).
+ - Keeps math in FP16/BF16; this is a storage-only change in the file.
+ - Honors existing FP8 tensors in the input unless --recode-fp8 is set.
+ - Skips norms, biases, visual_proj.*, final_layer.* by design.
+ - Optional --aggressive converts modulation linears too.
+
+ USAGE (simple):
+   python convert_fp8.py in.safetensors [out.safetensors]   # out is optional
+
+ USAGE (flags):
+   python convert_fp8.py in.safetensors out.safetensors --fp8 auto --aggressive
+
+ Notes:
+ - “auto” picks FP8_E5M2 on SM < 90 (e.g., 3090), else FP8_E4M3FN.
+ - You can force a format: --fp8 e5m2 | e4m3fn
+ - Dry run: add --dry to print what would change without writing.
+ """
+
  import argparse
+ import re
+ from typing import Dict, Tuple
+ from pathlib import Path

+ import torch
+ from safetensors.torch import load_file, save_file
+
+
+ # --------------------------- Policy (names) ---------------------------
+
+ # Skip norms/bias and sensitive endpoints explicitly
+ _DENY_SUBSTRINGS = (
+     ".bias", ".norm", "q_norm.", "k_norm.",
+     "final_layer.", "visual_proj.",
+ )
+
+ # Allowed patterns target this architecture’s large linears
+ _ALLOW_PATTERNS = tuple(re.compile(p) for p in (
+     # Single-stream blocks
+     r"^single_blocks\.\d+\.linear1\.weight$",
+     r"^single_blocks\.\d+\.linear2\.w[123]\.weight$",          # w1/w2/w3
+     r"^single_blocks\.\d+\.linear_qkv\.weight$",
+     r"^single_blocks\.\d+\.modulation\.linear\.weight$",       # gated by --aggressive
+
+     # Triple-stream blocks: MLPs (dominant size)
+     r"^triple_blocks\.\d+\.audio_mlp\.fc[12]\.weight$",
+     r"^triple_blocks\.\d+\.v_cond_mlp\.fc[12]\.weight$",
+
+     # Triple-stream blocks: attention projections
+     r"^triple_blocks\.\d+\.(audio_self_attn_qkv|v_cond_attn_qkv|text_cross_kv)\.weight$",
+     r"^triple_blocks\.\d+\.(audio_self_proj|v_cond_self_proj)\.weight$",
+
+     # r"^triple_blocks\.\d+\.(audio_cross_q|v_cond_cross_q)\.weight$",
+     # r"^triple_blocks\.\d+\.(audio_cross_proj|v_cond_cross_proj)\.weight$",
+
+     # Triple-stream blocks: modulation linears (gated)
+     r"^triple_blocks\.\d+\.(audio_mod|v_cond_mod)\.linear\.weight$",
+ ))
+
+
+ # --------------------------- Helpers ---------------------------
+
+ def default_out_path(in_path: str, tgt_dtype: torch.dtype) -> str:
+     """<in>_fp8_<e5m2|e4m3fn>.safetensors (idempotent if already suffixed)."""
+     suffix = "e5m2" if tgt_dtype == torch.float8_e5m2 else "e4m3fn"
+     p = Path(in_path)
+     stem = re.sub(r"_fp8_e(5m2|4m3fn)$", "", p.stem)  # strip prior suffix
+     ext = p.suffix or ".safetensors"
+     return str(p.with_name(f"{stem}_fp8_{suffix}{ext}"))
+
+
+ def pick_fp8_dtype(fp8_mode: str) -> torch.dtype:
+     """Pick target FP8 dtype."""
+     m = fp8_mode.lower()
+     if m == "e5m2":
+         return torch.float8_e5m2
+     if m == "e4m3fn":
+         return torch.float8_e4m3fn
+     # auto
+     try:
+         major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
+     except Exception:
+         major = 0
+     return torch.float8_e5m2 if major < 9 else torch.float8_e4m3fn
+
+
+ def bytes_of(t: torch.Tensor) -> int:
+     """Size in bytes (FP8=1 byte/elt)."""
+     if t.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+         return t.numel() * 1
+     return t.numel() * t.element_size()
+
+
+ def human_gb(nbytes: int) -> float:
+     return nbytes / (1024 ** 3)
+
+
+ def _is_denied(name: str) -> bool:
+     return any(tok in name for tok in _DENY_SUBSTRINGS)
+
+
+ def should_convert_to_fp8(name: str, aggressive: bool) -> bool:
+     """Match names for conversion, with modulation linears gated by --aggressive."""
+     if not name.endswith(".weight"):
          return False
+     if _is_denied(name):
          return False
+
+     for pat in _ALLOW_PATTERNS:
+         if pat.search(name):
+             # Gate modulation linears (single/triple) behind --aggressive
+             if (
+                 ".modulation.linear.weight" in name
+                 or ".audio_mod.linear.weight" in name
+                 or ".v_cond_mod.linear.weight" in name
+             ):
+                 return aggressive
+             return True
      return False

+
+ # --------------------------- Core ---------------------------
+
+ def convert_state_dict(
+     sd: Dict[str, torch.Tensor],
+     fp8_mode: str = "auto",
+     aggressive: bool = False,
+     recode_fp8: bool = False,
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
      """
+     Convert selected weights to FP8 storage according to the policy.
+     Honors existing FP8 unless recode_fp8=True.
+     Returns (new_sd, stats) with byte counts.
      """
+     tgt_dtype = pick_fp8_dtype(fp8_mode)
+     out: Dict[str, torch.Tensor] = {}
+     stats = {
+         "total_before": 0,
+         "total_after": 0,
+         "converted_count": 0,
+         "kept_fp8_count": 0,
+         "skipped_count": 0,
+     }
+
+     for name, tensor in sd.items():
+         before = bytes_of(tensor)
+         stats["total_before"] += before
+
+         # Respect existing FP8 unless asked to recode
+         if tensor.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+             if recode_fp8:
+                 out[name] = tensor.to(dtype=tgt_dtype)
+                 stats["converted_count"] += 1
+             else:
+                 out[name] = tensor
+                 stats["kept_fp8_count"] += 1
+             stats["total_after"] += bytes_of(out[name])
+             continue
+
+         # Decide conversion
+         if should_convert_to_fp8(name, aggressive):
+             out[name] = tensor.to(dtype=tgt_dtype)
+             stats["converted_count"] += 1
+         else:
+             out[name] = tensor
+             stats["skipped_count"] += 1
+
+         stats["total_after"] += bytes_of(out[name])
+
+     return out, stats
+
+
+ # --------------------------- CLI ---------------------------
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser(description="Convert selected weights in a safetensors file to FP8 storage.")
+     p.add_argument("in_path", help="Input .safetensors")
+     p.add_argument("out_path", nargs="?", help="Output .safetensors (optional)")
+     p.add_argument("--fp8", choices=["auto", "e5m2", "e4m3fn"], default="auto",
+                    help='Target FP8 storage dtype: "auto" (default), "e5m2", or "e4m3fn"')
+     p.add_argument("--aggressive", action="store_true",
+                    help="Also convert modulation linears (audio_mod/v_cond_mod + single modulation.linear).")
+     p.add_argument("--recode-fp8", action="store_true",
+                    help="Re-encode existing FP8 tensors to the chosen target dtype.")
+     p.add_argument("--dry", action="store_true",
+                    help="Dry run: report only; do not write output file.")
+     return p.parse_args()
+

  def main():
+     args = parse_args()
+
+     print(f"[load] {args.in_path}")
+     sd = load_file(args.in_path)
+
+     tgt = pick_fp8_dtype(args.fp8)
+     if not args.out_path:
+         args.out_path = default_out_path(args.in_path, tgt)
+         print(f"[auto-out] {args.out_path}")
+
+     print(f"[policy] fp8_mode={args.fp8} -> {str(tgt).replace('torch.','')}, "
+           f"aggressive={args.aggressive}, recode_fp8={args.recode_fp8}")
+
+     new_sd, stats = convert_state_dict(
+         sd,
+         fp8_mode=args.fp8,
+         aggressive=args.aggressive,
+         recode_fp8=args.recode_fp8,
      )
+
+     saved = stats["total_before"] - stats["total_after"]
+     print(f"[stats] tensors: {len(sd)}")
+     print(f"[stats] converted: {stats['converted_count']} | kept_fp8: {stats['kept_fp8_count']} "
+           f"| skipped: {stats['skipped_count']}")
+     print(f"[bytes] before={human_gb(stats['total_before']):.3f} GiB | "
+           f"after={human_gb(stats['total_after']):.3f} GiB | saved={human_gb(saved):.3f} GiB")
+
+     if args.dry:
+         print("[dry] no file written")
+         return
+
+     print(f"[save] {args.out_path}")
+     save_file(new_sd, args.out_path)
+     print("[done]")
+

  if __name__ == "__main__":
+     main()
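
A quick way to confirm which tensors the new policy actually stored as FP8 (a minimal sketch, not part of this commit; it assumes a PyTorch build with float8 dtypes and the safetensors package, and the file name below is only an example of the default output naming):

import torch
from safetensors import safe_open

fp8_dtypes = (torch.float8_e5m2, torch.float8_e4m3fn)
path = "hunyuanvideo_foley_fp8_e5m2.safetensors"  # example output name

with safe_open(path, framework="pt", device="cpu") as f:
    names = list(f.keys())
    n_fp8 = sum(1 for n in names if f.get_tensor(n).dtype in fp8_dtypes)

print(f"{n_fp8} of {len(names)} tensors are stored in FP8")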
fp8info.txt ADDED
The diff for this file is too large to render. See raw diff
 
hunyuanvideo_foley_fp8_e4m3fn.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ce2af6afbe910197e48020261d5ce2af30267b3d15e2f8f385570cc4049e3934
- size 6318689840

  version https://git-lfs.github.com/spec/v1
+ oid sha256:1b2fa56b9d9bd0c89f3d7e486f9f00032f247cf10dd860cc6d7f0b734bca8a31
+ size 5341941120
hunyuanvideo_foley_fp8_e5m2.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:43983270051d1b0bb3078f97040fd7bf84ffbee607d79ea1559c21a64741e5fc
- size 6318689840

  version https://git-lfs.github.com/spec/v1
+ oid sha256:3fa9d76e614aff32cf089aa8cf249b18547670c4be656e73a51804caec0f7963
+ size 5341941120
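
Per the LFS pointers above, each converted checkpoint shrinks from 6,318,689,840 bytes (about 5.88 GiB) to 5,341,941,120 bytes (about 4.98 GiB), a reduction of roughly 15.5%.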