|
|
""" |
|
|
Mixed-Precision Quantization Script for Small Language Models |
|
|
Supports selective quantization of different model components with configurable bitwidths. |
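
Example (script filename and model name are illustrative, not fixed by this file):

    python quantize_mixed_precision.py --model_name gpt2 \
        --attention_bits 4 --ffn_bits 8 --embedding_bits 8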
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
|
|
import argparse |
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Tuple |
|
|
import time |
|
|
|
|
|
class MixedPrecisionQuantizer: |
|
|
""" |
|
|
Quantizes model components with different precision levels. |
|
|
Supports more aggressive quantization for attention layers while |
|
|
preserving higher precision for FFN layers. |
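
    Example (programmatic use; "gpt2" is only an illustrative model name):
        quantizer = MixedPrecisionQuantizer("gpt2", attention_bits=4, ffn_bits=8)
        quantizer.run()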
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_name: str, |
|
|
attention_bits: int = 4, |
|
|
ffn_bits: int = 8, |
|
|
embedding_bits: int = 8, |
|
|
output_dir: str = "./quantized_models", |
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
): |
|
|
self.model_name = model_name |
|
|
self.attention_bits = attention_bits |
|
|
self.ffn_bits = ffn_bits |
|
|
self.embedding_bits = embedding_bits |
|
|
self.output_dir = Path(output_dir) |
|
|
self.device = device |
|
|
|
|
|
|
|
|
self.output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"Initializing quantizer for {model_name}") |
|
|
print(f"Attention layers: {attention_bits}-bit") |
|
|
print(f"FFN layers: {ffn_bits}-bit") |
|
|
print(f"Embeddings: {embedding_bits}-bit") |
|
|
print(f"Device: {device}") |
|
|
|
|
|
def load_model(self) -> Tuple[nn.Module, AutoTokenizer]: |
|
|
"""Load the pretrained model and tokenizer.""" |
|
|
print(f"\nLoading model: {self.model_name}") |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
self.model_name, |
|
|
torch_dtype=torch.float32, |
|
|
low_cpu_mem_usage=True, |
|
|
trust_remote_code=True |
|
|
        ).to(self.device)  # move to the configured device so quantization runs there
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
|
self.model_name, |
|
|
trust_remote_code=True |
|
|
) |
|
|
|
|
|
load_time = time.time() - start_time |
|
|
print(f"Model loaded in {load_time:.2f} seconds") |
|
|
|
|
|
|
|
|
param_count = sum(p.numel() for p in model.parameters()) |
|
|
param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2) |
|
|
print(f"Parameters: {param_count:,} ({param_size_mb:.2f} MB)") |
|
|
|
|
|
return model, tokenizer |
|
|
|
|
|
def quantize_linear_layer(self, layer: nn.Linear, bits: int) -> nn.Linear: |
|
|
""" |
|
|
Quantize a linear layer to specified bit width using symmetric quantization. |
|
|
""" |
|
|
if bits == 32: |
|
|
return layer |
|
|
|
|
|
weight = layer.weight.data.clone() |
|
|
|
|
|
|
|
|
qmin = -(2 ** (bits - 1)) |
|
|
qmax = 2 ** (bits - 1) - 1 |
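        # e.g. bits=4 -> qmin=-8, qmax=7; bits=8 -> qmin=-128, qmax=127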
|
|
|
|
|
|
|
|
|
|
|
max_val = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] |
|
|
max_val = torch.clamp(max_val, min=1e-5) |
|
|
scale = max_val / qmax |
|
|
|
|
|
|
|
|
weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax) |
|
|
weight_dq = weight_q * scale |
|
|
|
|
|
|
|
|
layer.weight.data = weight_dq.contiguous() |
|
|
|
|
|
|
|
|
layer.weight_scale = scale |
|
|
layer.quantized = True |
|
|
layer.bits = bits |
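        # Note: weight_scale, quantized, and bits are plain Python attributes,
        # not registered buffers, so they are not part of the state_dict that
        # save_pretrained() writes out.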
|
|
|
|
|
return layer |
|
|
|
|
|
def identify_layer_type(self, name: str, module: nn.Module) -> str: |
|
|
""" |
|
|
        Identify whether a layer belongs to the attention, FFN, embedding, or other components.
|
|
""" |
|
|
name_lower = name.lower() |
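        # Substring heuristics covering common GPT-2-, LLaMA-, and BERT-style
        # layer names; other architectures may need additional patterns.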
|
|
|
|
|
|
|
|
attention_patterns = [ |
|
|
'attn', 'attention', 'q_proj', 'k_proj', 'v_proj', |
|
|
'qkv', 'query', 'key', 'value', 'o_proj', 'out_proj', |
|
|
'c_attn', 'c_proj' |
|
|
] |
|
|
|
|
|
|
|
|
ffn_patterns = [ |
|
|
'mlp', 'ffn', 'fc', 'dense', 'intermediate', |
|
|
'gate_proj', 'up_proj', 'down_proj', 'w1', 'w2', 'w3' |
|
|
] |
|
|
|
|
|
|
|
|
embedding_patterns = ['embed', 'wte', 'wpe', 'lm_head'] |
|
|
|
|
|
        # 'c_proj' is used by both attention and MLP blocks in GPT-2-style
        # models, so resolve that ambiguity by checking for an FFN parent first.
        if any(pattern in name_lower for pattern in ('mlp', 'ffn')):
            return 'ffn'
        elif any(pattern in name_lower for pattern in attention_patterns):
            return 'attention'
        elif any(pattern in name_lower for pattern in ffn_patterns):
            return 'ffn'
        elif any(pattern in name_lower for pattern in embedding_patterns):
            return 'embedding'
        else:
            return 'other'
|
|
|
|
|
def quantize_model(self, model: nn.Module) -> Tuple[nn.Module, Dict]: |
|
|
""" |
|
|
Apply mixed-precision quantization to the model. |
|
|
""" |
|
|
print("\nApplying mixed-precision quantization...") |
|
|
start_time = time.time() |
|
|
|
|
|
stats = { |
|
|
'attention_layers': 0, |
|
|
'ffn_layers': 0, |
|
|
'embedding_layers': 0, |
|
|
'other_layers': 0, |
|
|
'total_quantized': 0 |
|
|
} |
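        # Only nn.Linear modules are quantized below; nn.Embedding weights
        # (e.g. a GPT-2 'wte' table) are left untouched, so embedding_bits in
        # practice applies to lm_head-style projection layers.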
|
|
|
|
|
|
|
|
for name, module in model.named_modules(): |
|
|
if isinstance(module, nn.Linear): |
|
|
layer_type = self.identify_layer_type(name, module) |
|
|
|
|
|
|
|
|
if layer_type == 'attention': |
|
|
bits = self.attention_bits |
|
|
stats['attention_layers'] += 1 |
|
|
elif layer_type == 'ffn': |
|
|
bits = self.ffn_bits |
|
|
stats['ffn_layers'] += 1 |
|
|
elif layer_type == 'embedding': |
|
|
bits = self.embedding_bits |
|
|
stats['embedding_layers'] += 1 |
|
|
else: |
|
|
bits = self.ffn_bits |
|
|
stats['other_layers'] += 1 |
|
|
|
|
|
|
|
|
self.quantize_linear_layer(module, bits) |
|
|
stats['total_quantized'] += 1 |
|
|
|
|
|
quant_time = time.time() - start_time |
|
|
print(f"\nQuantization completed in {quant_time:.2f} seconds") |
|
|
print(f"Quantized layers breakdown:") |
|
|
print(f" - Attention: {stats['attention_layers']} layers ({self.attention_bits}-bit)") |
|
|
print(f" - FFN: {stats['ffn_layers']} layers ({self.ffn_bits}-bit)") |
|
|
print(f" - Embedding: {stats['embedding_layers']} layers ({self.embedding_bits}-bit)") |
|
|
print(f" - Other: {stats['other_layers']} layers ({self.ffn_bits}-bit)") |
|
|
print(f" - Total quantized: {stats['total_quantized']} layers") |
|
|
|
|
|
return model, stats |
|
|
|
|
|
def save_quantized_model( |
|
|
self, |
|
|
model: nn.Module, |
|
|
tokenizer: AutoTokenizer, |
|
|
stats: Dict |
|
|
) -> str: |
|
|
"""Save the quantized model, tokenizer, and metadata.""" |
|
|
|
|
|
model_short_name = self.model_name.split('/')[-1] |
|
|
quant_config = f"attn{self.attention_bits}_ffn{self.ffn_bits}_emb{self.embedding_bits}" |
|
|
save_dir = self.output_dir / f"{model_short_name}_{quant_config}" |
|
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"\nSaving quantized model to: {save_dir}") |
|
|
|
|
|
|
|
|
model.save_pretrained(save_dir) |
|
|
|
|
|
|
|
|
tokenizer.save_pretrained(save_dir) |
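        # Because the quantization above is simulated, the saved weights are
        # still fp32 and the on-disk size matches the original model; a
        # genuinely smaller checkpoint would require packed integer storage
        # (e.g. a dedicated format such as GPTQ), which this script does not produce.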
|
|
|
|
|
|
|
|
quantized_size_mb = sum( |
|
|
p.numel() * p.element_size() for p in model.parameters() |
|
|
) / (1024 ** 2) |
|
|
|
|
|
|
|
|
metadata = { |
|
|
'original_model': self.model_name, |
|
|
'quantization_config': { |
|
|
'attention_bits': self.attention_bits, |
|
|
'ffn_bits': self.ffn_bits, |
|
|
'embedding_bits': self.embedding_bits |
|
|
}, |
|
|
'layer_stats': stats, |
|
|
'model_size_mb': quantized_size_mb, |
|
|
'quantization_timestamp': time.strftime('%Y-%m-%d %H:%M:%S') |
|
|
} |
|
|
|
|
|
with open(save_dir / 'quantization_metadata.json', 'w') as f: |
|
|
json.dump(metadata, f, indent=2) |
|
|
|
|
|
print(f"Quantized model size: {quantized_size_mb:.2f} MB") |
|
|
print(f"Metadata saved to: {save_dir / 'quantization_metadata.json'}") |
|
|
|
|
|
return str(save_dir) |
|
|
|
|
|
def run(self) -> str: |
|
|
"""Execute the full quantization pipeline.""" |
|
|
print("=" * 80) |
|
|
print("MIXED-PRECISION QUANTIZATION PIPELINE") |
|
|
print("=" * 80) |
|
|
|
|
|
|
|
|
model, tokenizer = self.load_model() |
|
|
|
|
|
|
|
|
quantized_model, stats = self.quantize_model(model) |
|
|
|
|
|
|
|
|
save_path = self.save_quantized_model(quantized_model, tokenizer, stats) |
|
|
|
|
|
print("\n" + "=" * 80) |
|
|
print("QUANTIZATION COMPLETE") |
|
|
print("=" * 80) |
|
|
print(f"Saved to: {save_path}") |
|
|
|
|
|
return save_path |
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Mixed-Precision Quantization for Small Language Models" |
|
|
) |
|
|
parser.add_argument( |
|
|
'--model_name', |
|
|
type=str, |
|
|
required=True, |
|
|
help='HuggingFace model name or path' |
|
|
) |
|
|
parser.add_argument( |
|
|
'--attention_bits', |
|
|
type=int, |
|
|
default=4, |
|
|
help='Bit width for attention layers (default: 4)' |
|
|
) |
|
|
parser.add_argument( |
|
|
'--ffn_bits', |
|
|
type=int, |
|
|
default=8, |
|
|
help='Bit width for FFN layers (default: 8)' |
|
|
) |
|
|
parser.add_argument( |
|
|
'--embedding_bits', |
|
|
type=int, |
|
|
default=8, |
|
|
help='Bit width for embedding layers (default: 8)' |
|
|
) |
|
|
parser.add_argument( |
|
|
'--output_dir', |
|
|
type=str, |
|
|
default='./quantized_models', |
|
|
help='Output directory for quantized models' |
|
|
) |
|
|
parser.add_argument( |
|
|
'--device', |
|
|
type=str, |
|
|
default='cuda' if torch.cuda.is_available() else 'cpu', |
|
|
help='Device to use (cuda/cpu)' |
|
|
) |
|
|
|
|
|
args = parser.parse_args() |
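
    # Optional sanity check (an addition in this sketch, not required by the
    # quantizer itself): the symmetric scheme only makes sense for 2-32 bits.
    for arg_name, bits in [("attention_bits", args.attention_bits),
                           ("ffn_bits", args.ffn_bits),
                           ("embedding_bits", args.embedding_bits)]:
        if not 2 <= bits <= 32:
            parser.error(f"--{arg_name} must be between 2 and 32, got {bits}")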
|
|
|
|
|
|
|
|
quantizer = MixedPrecisionQuantizer( |
|
|
model_name=args.model_name, |
|
|
attention_bits=args.attention_bits, |
|
|
ffn_bits=args.ffn_bits, |
|
|
embedding_bits=args.embedding_bits, |
|
|
output_dir=args.output_dir, |
|
|
device=args.device |
|
|
) |
|
|
|
|
|
|
|
|
quantizer.run() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |