| |
| """ |
| Phase 6: INT8 Weight-Only Quantization for All Modules |
| ======================================================= |
| Applies torchao int8_weight_only quantization to each module, |
| re-exports to torch.export, and lowers to ExecuTorch .pte. |
| |
| int8_weight_only is INSTANT β no calibration data needed. |
| """ |
|
|
| import sys |
| import os |
| import copy |
| import time |
| import gc |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# --- Path configuration ------------------------------------------------------
# Source model checkpoint, the sibling project's venv/site-packages, its
# sources, and the directory the .pte artifacts are written to.
MODEL_PATH = os.path.expanduser("~/Documents/Qwen3-TTS/models/1.7B-Base")
VENV_SITE = os.path.expanduser("~/Documents/Qwen3-TTS/.venv/lib/python3.10/site-packages")
QWEN_TTS_SRC = os.path.expanduser("~/Documents/Qwen3-TTS")
OUTPUT_DIR = os.path.expanduser("~/Documents/Qwen3-TTS-ExecuTorch/exported")

# Make the Qwen3-TTS project's packages importable from this script.
if VENV_SITE not in sys.path:
    sys.path.insert(0, VENV_SITE)
if QWEN_TTS_SRC not in sys.path:
    sys.path.insert(0, QWEN_TTS_SRC)

os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
| from torchao.quantization import quantize_, int8_weight_only |
| from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig |
| from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
|
|
| print("=" * 70) |
| print("PHASE 6: INT8 Weight-Only Quantization") |
| print("=" * 70) |
|
|
|
|
def export_and_lower_int8(module, example_args, name, output_dir):
    """Quantize a module to INT8 weights, export it, and lower it to a .pte.

    Args:
        module: Eval-mode ``nn.Module``. NOTE: quantized **in place** by
            ``torchao.quantize_`` — the caller's copy is modified.
        example_args: Tuple of example tensors used to trace ``torch.export``.
        name: Basename of the output artifact (``<name>_int8.pte``).
        output_dir: Directory the .pte file is written into.

    Returns:
        float: size of the written .pte file in MB.

    Raises:
        Propagates whatever torch.export / ExecuTorch lowering raises;
        callers wrap this function in try/except per module.
    """
    # 1) Weight-only INT8 quantization: instant, no calibration data needed.
    #    (Removed needless f-prefixes on the constant status strings.)
    print("  Applying int8_weight_only quantization...")
    t0 = time.time()
    quantize_(module, int8_weight_only())
    print(f"  Quantized in {time.time() - t0:.1f}s")

    # 2) Capture the quantized module as an ExportedProgram.
    #    strict=False tolerates Python constructs the strict tracer rejects.
    print("  Running torch.export...")
    t0 = time.time()
    exported = torch.export.export(module, example_args, strict=False)
    print(f"  Exported in {time.time() - t0:.1f}s ({len(exported.graph.nodes)} nodes)")

    # 3) Lower to ExecuTorch with the XNNPACK CPU backend.
    #    _check_ir_validity=False: quantized graphs may not pass strict IR checks.
    print("  Lowering to ExecuTorch .pte...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et_program = edge.to_executorch()

    pte_path = os.path.join(output_dir, f"{name}_int8.pte")
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)

    pte_size = os.path.getsize(pte_path) / 1e6
    print(f"  Saved: {pte_path} ({pte_size:.1f} MB)")
    print(f"  Lowered in {time.time() - t0:.1f}s")
    return pte_size
|
|
|
|
| |
|
|
| print("\n[0/4] Loading base model...") |
| from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig |
| from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration |
|
|
| config = Qwen3TTSConfig.from_pretrained(MODEL_PATH) |
| model = Qwen3TTSForConditionalGeneration.from_pretrained( |
| MODEL_PATH, config=config, dtype=torch.float32, |
| attn_implementation="sdpa", device_map="cpu", |
| ) |
| model.eval() |
| print(" Model loaded.") |
|
|
| results = {} |
|
|
| |
| |
| |
|
|
| print("\n[1/4] Speaker Encoder INT8") |
|
|
| |
class _ExplicitPadConv1d(nn.Module):
    """Conv1d with its padding applied explicitly via ``F.pad``.

    Rebuilds the original conv with ``padding=0`` and performs the (possibly
    asymmetric) padding in an explicit ``F.pad`` call before convolving, so
    ``padding='same'`` convs survive torch.export. Weight and bias Parameters
    are shared with (not copied from) the original conv.
    """

    def __init__(self, original_conv, pad_left, pad_right, pad_mode):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=original_conv.in_channels, out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size[0], stride=original_conv.stride[0],
            padding=0, dilation=original_conv.dilation[0], groups=original_conv.groups,
            bias=original_conv.bias is not None)
        # Share the source conv's parameters so quantization sees one tensor.
        self.conv.weight = original_conv.weight
        if original_conv.bias is not None:
            self.conv.bias = original_conv.bias
        self.pad_left = pad_left
        self.pad_right = pad_right
        # BUG FIX: Conv1d's padding_mode 'zeros' is not a valid F.pad mode
        # (F.pad calls zero padding 'constant'). Without this mapping,
        # forward() raised NotImplementedError for the default padding mode.
        self.pad_mode = 'constant' if pad_mode == 'zeros' else pad_mode

    def forward(self, x):
        # Pad explicitly (left/right may differ for even kernels), then convolve.
        if self.pad_left > 0 or self.pad_right > 0:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.pad_mode)
        return self.conv(x)
|
|
|
|
class SpeakerEncoderForExport_Q(nn.Module):
    """Export-friendly wrapper around the speaker encoder.

    Deep-copies the encoder and replaces every Conv1d that uses
    padding='same' with an _ExplicitPadConv1d, since 'same' padding does
    not survive torch.export / ExecuTorch lowering.
    """

    def __init__(self, original_encoder):
        super().__init__()
        self.encoder = copy.deepcopy(original_encoder)
        self._fix_conv_padding(self.encoder)

    def _fix_conv_padding(self, module):
        # Walk the module tree; rewrite 'same'-padded Conv1d leaves in place.
        for child_name, child in module.named_children():
            needs_rewrite = isinstance(child, nn.Conv1d) and child.padding == 'same'
            if not needs_rewrite:
                self._fix_conv_padding(child)
                continue
            kernel = child.kernel_size[0]
            dilation = child.dilation[0]
            total_pad = dilation * (kernel - 1)
            left = total_pad // 2
            replacement = _ExplicitPadConv1d(child, left, total_pad - left, child.padding_mode)
            setattr(module, child_name, replacement)

    def forward(self, mel_input):
        return self.encoder(mel_input)
|
|
|
|
# Static input shape: torch.export needs fixed dims.
FIXED_MEL_FRAMES = 469  # NOTE(review): presumably matches the FP32 export; confirm.
se = SpeakerEncoderForExport_Q(model.speaker_encoder)
se.eval()
se_args = (torch.randn(1, FIXED_MEL_FRAMES, 128),)
# Baseline: size of the previously exported FP32 .pte (must already exist).
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "speaker_encoder.pte")) / 1e6

try:
    int8_size = export_and_lower_int8(se, se_args, "speaker_encoder", OUTPUT_DIR)
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    # Best-effort pipeline: record the failure and continue with other modules.
    print(f"  FAILED: {e}")
    results["speaker_encoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}

del se; gc.collect()  # free the deep-copied encoder before the next module
|
|
| |
| |
| |
|
|
| print("\n[2/4] Talker INT8") |
|
|
| |
| |
| |
| import importlib.util |
| spec = importlib.util.spec_from_file_location( |
| "export_talker_mod", |
| os.path.join(os.path.dirname(os.path.abspath(__file__)), "export_talker.py") |
| ) |
| |
| |
| |
|
|
| |
| |
| |
|
|
| MAX_SEQ_LEN = 2048; NUM_LAYERS = 28; NUM_KV_HEADS = 8; HEAD_DIM = 128 |
| NUM_HEADS = 16; HIDDEN_SIZE = 2048; CODEC_VOCAB = 3072 |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (weight only, no bias), computed in fp32."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize in float32 for stability, cast back to the input dtype.
        orig_dtype = x.dtype
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * (xf * inv_rms)).to(orig_dtype)
|
|
def rotate_half(x):
    """Rotate the last dim: (x1, x2) -> (-x2, x1), as used by RoPE."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
class TalkerAttnQ(nn.Module):
    """Talker self-attention rewritten for static-shape export.

    GQA attention (NUM_HEADS query heads over NUM_KV_HEADS KV heads) with
    RoPE applied to q/k, per-head RMSNorm on q/k, and a fixed-size KV cache
    passed in and returned functionally (no in-place module state).
    """

    def __init__(self, orig, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx; self.head_dim = HEAD_DIM
        self.num_heads = NUM_HEADS; self.num_kv_heads = NUM_KV_HEADS
        self.num_kv_groups = NUM_HEADS // NUM_KV_HEADS; self.scaling = HEAD_DIM**-0.5
        # Deep-copy projections so quantizing this module can't touch the
        # original model's weights.
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(HEAD_DIM); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(HEAD_DIM); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # h: (B, S, hidden); cos/sin: RoPE tables; cp: cache positions (S,);
        # kc/vc: (B, KV_HEADS, MAX_SEQ_LEN, HEAD_DIM) caches; am: additive mask.
        B, S, _ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,self.num_heads,HEAD_DIM)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM)).transpose(1,2)
        v = self.v_proj(h).view(B,S,self.num_kv_heads,HEAD_DIM).transpose(1,2)
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        # Functional cache update: clone so export sees pure ops, then scatter
        # the new k/v rows into the cache at positions cp.
        kc = kc.clone(); vc = vc.clone()
        kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        # Repeat KV heads to match query heads (GQA expansion).
        ke = kc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        ve = vc.unsqueeze(2).repeat(1,1,self.num_kv_groups,1,1).reshape(B,self.num_heads,MAX_SEQ_LEN,HEAD_DIM)
        o = F.scaled_dot_product_attention(q, ke, ve, attn_mask=am, scale=self.scaling)
        # Returns (attn output, updated k cache, updated v cache).
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc
|
|
class TalkerLayerQ(nn.Module):
    """One talker transformer layer: attention + SwiGLU MLP, pre-norm residuals."""

    def __init__(self, orig, i):
        super().__init__()
        self.attn = TalkerAttnQ(orig.self_attn, i)
        self.gate_proj = copy.deepcopy(orig.mlp.gate_proj)
        self.up_proj = copy.deepcopy(orig.mlp.up_proj)
        self.down_proj = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(HIDDEN_SIZE); self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(HIDDEN_SIZE); self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Attention sub-block with residual connection.
        attn_out, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am)
        h = h + attn_out
        # SwiGLU feed-forward sub-block with residual connection.
        normed = self.n2(h)
        h = h + self.down_proj(F.silu(self.gate_proj(normed)) * self.up_proj(normed))
        return h, kc, vc
|
|
class TalkerQ(nn.Module):
    """Full talker stack for export: layers + final norm + codec head.

    Computes RoPE cos/sin from position ids internally and threads the
    per-layer KV caches through as explicit inputs/outputs (2 per layer).
    """

    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList([TalkerLayerQ(l, i) for i, l in enumerate(orig.model.layers)])
        self.norm = RMSNorm(HIDDEN_SIZE); self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        self.codec_head = copy.deepcopy(orig.codec_head)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rope_scaling = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        # ie: input embeddings (B, S, hidden); pid: stacked position ids (only
        # plane 0 is used for RoPE); cp: cache positions; am: additive mask;
        # kv: flat list of (k_cache, v_cache) pairs, one pair per layer.
        pos = pid[0].float()
        freqs = pos.unsqueeze(-1) * self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = (emb.cos() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        sin = (emb.sin() * self.rope_scaling).to(ie.dtype).unsqueeze(1)
        h = ie; ukv = []
        for i, layer in enumerate(self.layers):
            h, nk, nv = layer(h, cos, sin, cp, kv[i*2], kv[i*2+1], am)
            ukv.append(nk); ukv.append(nv)
        # Codec-vocabulary logits plus every layer's updated caches.
        return (self.codec_head(self.norm(h)), *ukv)
|
|
t_mod = TalkerQ(model.talker); t_mod.eval()
sl = 10  # example prefill sequence length for tracing
# Causal additive mask over the full MAX_SEQ_LEN cache: row i may attend to
# cache slots 0..i; everything else stays -inf.
cm = torch.full((1,1,sl,MAX_SEQ_LEN), float('-inf'))
for i in range(sl): cm[:,:,i,:i+1] = 0.0
t_args = (
    torch.randn(1,sl,HIDDEN_SIZE),
    torch.arange(sl).unsqueeze(0).unsqueeze(0).repeat(3,1,1),  # stacked position ids
    torch.arange(sl), cm,
    *[torch.zeros(1,NUM_KV_HEADS,MAX_SEQ_LEN,HEAD_DIM) for _ in range(NUM_LAYERS*2)]
)

# Baseline: the previously exported FP32 prefill .pte (must already exist).
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "talker_prefill.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(t_mod, t_args, "talker", OUTPUT_DIR)
    results["talker"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f"  FAILED: {e}")
    results["talker"] = {"fp32": fp32_size, "int8": None, "error": str(e)}

del t_mod; gc.collect()  # free the talker copy before the next module
|
|
| |
| |
| |
|
|
| print("\n[3/4] Code Predictor INT8") |
|
|
| CP_MAX = 17; CPL = 5; CPKV = 8; CPHD = 128; CPH = 16; CPHS = 1024; THD = 2048 |
|
|
class CPAttnQ(nn.Module):
    """Code-predictor self-attention for export (GQA + RoPE + explicit cache).

    Same structure as TalkerAttnQ but sized by the code-predictor constants
    (CPH query heads, CPKV KV heads, head dim CPHD, cache length CP_MAX).
    """

    def __init__(self, orig, i):
        super().__init__()
        # Deep-copied projections so quantization leaves the source model intact.
        self.q_proj = copy.deepcopy(orig.q_proj); self.k_proj = copy.deepcopy(orig.k_proj)
        self.v_proj = copy.deepcopy(orig.v_proj); self.o_proj = copy.deepcopy(orig.o_proj)
        self.q_norm = RMSNorm(CPHD); self.q_norm.weight = copy.deepcopy(orig.q_norm.weight)
        self.k_norm = RMSNorm(CPHD); self.k_norm.weight = copy.deepcopy(orig.k_norm.weight)
        self.g = CPH // CPKV  # GQA expansion factor (query heads per KV head)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # h: (B, S, CPHS); kc/vc: (B, CPKV, CP_MAX, CPHD) caches; am: additive mask.
        B,S,_ = h.shape
        q = self.q_norm(self.q_proj(h).view(B,S,CPH,CPHD)).transpose(1,2)
        k = self.k_norm(self.k_proj(h).view(B,S,CPKV,CPHD)).transpose(1,2)
        v = self.v_proj(h).view(B,S,CPKV,CPHD).transpose(1,2)
        # RoPE on q/k, then a functional KV-cache scatter at positions cp.
        q = q*cos + rotate_half(q)*sin; k = k*cos + rotate_half(k)*sin
        kc = kc.clone(); vc = vc.clone(); kc[:,:,cp,:] = k; vc[:,:,cp,:] = v
        # Repeat KV heads to match query heads (GQA).
        ke = kc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        ve = vc.unsqueeze(2).repeat(1,1,self.g,1,1).reshape(B,CPH,CP_MAX,CPHD)
        o = F.scaled_dot_product_attention(q,ke,ve,attn_mask=am,scale=CPHD**-0.5)
        # Returns (attn output, updated k cache, updated v cache).
        return self.o_proj(o.transpose(1,2).reshape(B,S,-1)), kc, vc
|
|
class CPLayerQ(nn.Module):
    """One code-predictor layer: attention + SwiGLU MLP, pre-norm residuals."""

    def __init__(self, orig, i):
        super().__init__()
        self.attn = CPAttnQ(orig.self_attn, i)
        self.gp = copy.deepcopy(orig.mlp.gate_proj)
        self.up = copy.deepcopy(orig.mlp.up_proj)
        self.dp = copy.deepcopy(orig.mlp.down_proj)
        self.n1 = RMSNorm(CPHS); self.n1.weight = copy.deepcopy(orig.input_layernorm.weight)
        self.n2 = RMSNorm(CPHS); self.n2.weight = copy.deepcopy(orig.post_attention_layernorm.weight)

    def forward(self, h, cos, sin, cp, kc, vc, am):
        # Attention sub-block with residual connection.
        attn_out, kc, vc = self.attn(self.n1(h), cos, sin, cp, kc, vc, am)
        h = h + attn_out
        # SwiGLU feed-forward sub-block with residual connection.
        normed = self.n2(h)
        h = h + self.dp(F.silu(self.gp(normed)) * self.up(normed))
        return h, kc, vc
|
|
class CPQ(nn.Module):
    """Code-predictor stack for export: input projection + layers + final norm.

    Unlike TalkerQ there is no output head here: the normed hidden states are
    returned directly, along with every layer's updated KV caches.
    """

    def __init__(self, orig):
        super().__init__()
        self.layers = nn.ModuleList([CPLayerQ(l,i) for i,l in enumerate(orig.model.layers)])
        self.norm = RMSNorm(CPHS); self.norm.weight = copy.deepcopy(orig.model.norm.weight)
        # Projects talker-sized hidden states (THD) into the predictor's width.
        self.proj = copy.deepcopy(orig.small_to_mtp_projection)
        self.register_buffer("inv_freq", orig.model.rotary_emb.inv_freq.clone())
        self.rs = getattr(orig.model.rotary_emb, 'attention_scaling', 1.0)

    def forward(self, ie, pid, cp, am, *kv):
        # ie: (B, S, THD) talker hiddens; pid: (B, S) position ids; cp: cache
        # positions; am: additive mask; kv: flat (k, v) caches, one pair/layer.
        h = self.proj(ie)
        # RoPE tables computed on the fly from pid and inv_freq.
        pos = pid.float()
        freqs = pos.unsqueeze(-1)*self.inv_freq.float().unsqueeze(0).unsqueeze(0)
        emb = torch.cat([freqs,freqs],dim=-1)
        cos = (emb.cos()*self.rs).to(h.dtype).unsqueeze(1)
        sin = (emb.sin()*self.rs).to(h.dtype).unsqueeze(1)
        ukv = []
        for i, l in enumerate(self.layers):
            h,nk,nv = l(h,cos,sin,cp,kv[i*2],kv[i*2+1],am); ukv.append(nk); ukv.append(nv)
        return (self.norm(h), *ukv)
|
|
cp_mod = CPQ(model.talker.code_predictor); cp_mod.eval()
csl = 2  # example sequence length for tracing
# Causal additive mask over the CP_MAX-slot cache.
ccm = torch.full((1,1,csl,CP_MAX), float('-inf'))
for i in range(csl): ccm[:,:,i,:i+1] = 0.0
cp_args = (
    torch.randn(1,csl,THD), torch.arange(csl).unsqueeze(0), torch.arange(csl), ccm,
    *[torch.zeros(1,CPKV,CP_MAX,CPHD) for _ in range(CPL*2)]
)

# Baseline: the previously exported FP32 .pte (must already exist).
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "code_predictor.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(cp_mod, cp_args, "code_predictor", OUTPUT_DIR)
    results["code_predictor"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f"  FAILED: {e}")
    results["code_predictor"] = {"fp32": fp32_size, "int8": None, "error": str(e)}

del cp_mod; gc.collect()  # free the predictor copy before the next module
|
|
| |
| |
| |
|
|
| print("\n[4/4] Vocoder INT8") |
|
|
class VocQ(nn.Module):
    """Thin export wrapper around the codec decoder (vocoder)."""

    def __init__(self, dec):
        super().__init__()
        # Deep-copy so quantization can't modify the source model's decoder.
        self.decoder = copy.deepcopy(dec)

    def forward(self, codes):
        # All logic lives in the copied decoder; this class only wraps it.
        return self.decoder(codes)
|
|
v_mod = VocQ(model.speech_tokenizer.model.decoder); v_mod.eval()
# Example input: integer codec indices — presumably 16 codebooks x 50 frames,
# values below 2048 (NOTE(review): CODEC_VOCAB above is 3072; confirm range).
v_args = (torch.randint(0, 2048, (1, 16, 50)),)

# Baseline: the previously exported FP32 .pte (must already exist).
fp32_size = os.path.getsize(os.path.join(OUTPUT_DIR, "vocoder.pte")) / 1e6
try:
    int8_size = export_and_lower_int8(v_mod, v_args, "vocoder", OUTPUT_DIR)
    results["vocoder"] = {"fp32": fp32_size, "int8": int8_size}
except Exception as e:
    print(f"  FAILED: {e}")
    results["vocoder"] = {"fp32": fp32_size, "int8": None, "error": str(e)}

del v_mod; gc.collect()  # free the decoder copy before the summary
|
|
| |
|
|
| print("\n" + "=" * 70) |
| print("QUANTIZATION SUMMARY") |
| print("=" * 70) |
| print(f"\n{'Module':25s} {'FP32 (MB)':>12s} {'INT8 (MB)':>12s} {'Reduction':>10s}") |
| print("-" * 60) |
|
|
| total_fp32 = 0; total_int8 = 0 |
| for name, r in results.items(): |
| fp32 = r.get("fp32", 0) or 0 |
| int8 = r.get("int8") |
| total_fp32 += fp32 |
| if int8 is not None: |
| total_int8 += int8 |
| red = f"{fp32/int8:.1f}x" if int8 > 0 else "β" |
| else: |
| red = f"FAILED: {r.get('error','')[:40]}" |
| int8 = 0 |
| print(f" {name:23s} {fp32:10.1f} {int8:10.1f} {red}") |
|
|
| print("-" * 60) |
| ovr = f"{total_fp32/total_int8:.1f}x" if total_int8 > 0 else "N/A" |
| print(f" {'TOTAL':23s} {total_fp32:10.1f} {total_int8:10.1f} {ovr}") |
| print("\nPhase 6 complete!") |
|
|