#!/usr/bin/env python
"""
Quantize Qwen/Qwen3-235B-A22B (MoE) to INT4-W4A16 on a CPU-only machine.
Output: Qwen3-235B-A22B-INT4-W4A16
"""
import os, warnings
import torch
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoModelForCausalLM
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# --------------------------------------------------------------------
# Optional: silence CUDA warnings on machines without a GPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""
warnings.filterwarnings("ignore", message="Can't initialize NVML")
model_id = "Qwen/Qwen3-235B-A22B"
output_dir = "Qwen3-235B-A22B-INT4-W4A16"
# --------------------------------------------------------------------
# 1) Build a dummy model (no weights) to infer a device map
with init_empty_weights():
    dummy = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
device_map = infer_auto_device_map(
    dummy, no_split_module_classes=dummy._no_split_modules
)
del dummy
# Force every sub-module onto CPU
device_map = {name: "cpu" for name in device_map}
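# Optional sanity check (added sketch, not in the original script): every entry
# should now map to CPU, so nothing gets silently offloaded to disk.
assert all(device == "cpu" for device in device_map.values())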
# --------------------------------------------------------------------
# 2) Load the full model weights (BF16) on CPU
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
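# Note: 235B parameters in BF16 need roughly 235e9 * 2 bytes ≈ 470 GB of host
# RAM just for the weights, plus working memory for quantization. Plan capacity
# accordingly before running this on a CPU-only box.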
# --------------------------------------------------------------------
# 3) Quantization recipe — keep only router gates + lm_head in BF16
recipe = QuantizationModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "lm_head",
        r"re:.*\.mlp\.gate$",  # router gates (tiny but accuracy-critical)
    ],
    dampening_frac=0.1,  # GPTQ-style damping knob; see the note below
)
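# Note (assumption, worth verifying against your llm-compressor version):
# dampening_frac is documented on GPTQModifier, not on the data-free
# QuantizationModifier above, which may silently ignore it. A calibrated GPTQ
# recipe that actually applies the damping would look like this sketch:
#
#     from llmcompressor.modifiers.quantization import GPTQModifier
#     recipe = GPTQModifier(
#         targets="Linear",
#         scheme="W4A16",
#         ignore=["lm_head", r"re:.*\.mlp\.gate$"],
#         dampening_frac=0.1,  # Hessian damping; mitigates INT4 noise
#     )
#     # oneshot(...) would then also need a calibration dataset.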
# --------------------------------------------------------------------
# 4) One-shot quantization
oneshot(
    model=model,
    recipe=recipe,
    output_dir=output_dir,
)
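# Optional (assumption: you want the tokenizer shipped alongside the quantized
# weights, since oneshot only serializes what it is given):
#
#     from transformers import AutoTokenizer
#     AutoTokenizer.from_pretrained(model_id, trust_remote_code=True).save_pretrained(output_dir)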
print(f"\n✅ Quantized model written to: {output_dir}")
print("   (router gates & lm_head remain in BF16; everything else INT4 W4A16)")