"""
Weight-only INT8 quantization — no calibration, no forward passes needed.
Uses torchao int8_weight_only which packs weights instantly.
Then re-exports to ExecuTorch XNNPACK .pte.
"""

import gc
import os
import sys
import time

import torch

# Make sibling export_* modules importable when run from the repo root.
sys.path.insert(0, ".")

# Local directory holding the original LightOnOCR checkpoint.
MODEL_DIR = "./models/LightOnOCR-2-1B"
# Fixed input image resolution (height, width) the vision encoder is exported at.
FIXED_H, FIXED_W = 1120, 1540
|
def quantize_vision(orig):
    """Quantize the vision encoder to INT8 weight-only and export an XNNPACK .pte.

    Args:
        orig: Loaded original model (as returned by
            ``export_vision.load_original_model``); handed to
            ``build_vision_module`` to extract the vision tower.

    Returns:
        str: Path of the written ``vision_encoder_int8.pte`` file.
    """
    from export_vision import build_vision_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== VISION ENCODER (INT8 weight-only) ===")
    vision = build_vision_module(orig)
    # Export runs on CPU in fp32; torchao then repacks the weights to int8 in place.
    vision = vision.to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in vision.parameters())/1e6:.1f}M")

    print(" Applying int8_weight_only...")
    t0 = time.time()
    # Weight-only quantization: no calibration data or forward passes required.
    quantize_(vision, int8_weight_only())
    print(f" Quantization took {time.time()-t0:.1f}s")

    print(" torch.export...")
    # The exported graph is specialized to a single fixed input shape.
    example = (torch.randn(1, 3, FIXED_H, FIXED_W),)
    t0 = time.time()
    ep = export(vision, example)
    print(f" Export took {time.time()-t0:.1f}s")

    print(" XNNPACK lowering...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        ep,
        # NOTE(review): IR validity check disabled — presumably the int8 tensor
        # subclasses fail strict validation; confirm against executorch version.
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et = edge.to_executorch()
    print(f" Lowering took {time.time()-t0:.1f}s")

    path = "vision_encoder_int8.pte"
    with open(path, "wb") as f:
        f.write(et.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")
    # Drop the large graph/model objects before the decoder is processed.
    del vision, ep, edge, et
    gc.collect()
    return path
|
|
|
|
def quantize_decoder(orig):
    """Quantize the text decoder to INT8 weight-only and export an XNNPACK .pte.

    Args:
        orig: Loaded original model; handed to ``build_decoder_module`` to
            extract the language-model decoder.

    Returns:
        str: Path of the written ``text_decoder_int8.pte`` file.
    """
    import export_decoder as ed
    from export_decoder import build_decoder_module
    from torchao.quantization import quantize_, int8_weight_only
    from torch.export import export
    from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    print("\n=== TEXT DECODER (INT8 weight-only) ===")
    decoder = build_decoder_module(orig)
    # Export runs on CPU in fp32; torchao then repacks the weights to int8 in place.
    decoder = decoder.to("cpu").to(torch.float32).eval()
    print(f" Params: {sum(p.numel() for p in decoder.parameters())/1e6:.1f}M")

    print(" Applying int8_weight_only...")
    t0 = time.time()
    # Weight-only quantization: no calibration data or forward passes required.
    quantize_(decoder, int8_weight_only())
    print(f" Quantization took {time.time()-t0:.1f}s")

    print(" torch.export...")
    # Example inputs: token ids, causal mask, position ids, cache positions,
    # then the flattened empty KV-cache tensors (batch=1, fp32, CPU).
    kv = ed.create_empty_kv_caches(1, torch.float32, "cpu")
    example = (
        torch.ones(1, 8, dtype=torch.long),
        ed.create_causal_mask(8, ed.MAX_SEQ_LEN, torch.float32),
        torch.arange(8).unsqueeze(0),
        torch.arange(8),
        *kv,
    )
    t0 = time.time()
    ep = export(decoder, example)
    print(f" Export took {time.time()-t0:.1f}s")

    print(" XNNPACK lowering...")
    t0 = time.time()
    edge = to_edge_transform_and_lower(
        ep,
        # NOTE(review): IR validity check disabled — presumably the int8 tensor
        # subclasses fail strict validation; confirm against executorch version.
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )
    et = edge.to_executorch()
    print(f" Lowering took {time.time()-t0:.1f}s")

    path = "text_decoder_int8.pte"
    with open(path, "wb") as f:
        f.write(et.buffer)
    print(f" ✅ {path}: {os.path.getsize(path)/1024/1024:.1f} MB")
    # Drop the large graph/model objects before returning.
    del decoder, ep, edge, et
    gc.collect()
    return path
|
|
|
|
def main():
    """Quantize both submodules to INT8 weight-only and report size savings.

    Loads the original model once, runs :func:`quantize_vision` and
    :func:`quantize_decoder`, then — if the earlier fp32 exports exist on
    disk — prints a size comparison for each .pte pair.
    """
    from export_vision import load_original_model

    print("LightOnOCR INT8 Weight-Only Quantization")
    print("No calibration needed — weights quantized instantly\n")

    print("Loading model...")
    orig = load_original_model()

    vis_path = quantize_vision(orig)
    dec_path = quantize_decoder(orig)
    # The original fp32 model is no longer needed; release it eagerly.
    del orig
    gc.collect()

    print("\n=== RESULTS ===")
    # Compare against the previously exported fp32 .pte files, when present.
    for fp32, int8 in [("vision_encoder.pte", vis_path),
                       ("text_decoder_4096.pte", dec_path)]:
        if os.path.exists(fp32) and os.path.exists(int8):
            orig_mb = os.path.getsize(fp32) / 1024 / 1024
            quant_mb = os.path.getsize(int8) / 1024 / 1024
            ratio = quant_mb / orig_mb * 100
            print(f" {fp32}: {orig_mb:.0f} MB → {int8}: {quant_mb:.0f} MB ({ratio:.0f}%)")


if __name__ == "__main__":
    main()
|
|