import os
import torch
import onnx
import logging
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BitsAndBytesConfig
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelQuantizer:
    """Handles model quantization and comparison operations"""

    @staticmethod
    def quantize_model(model_class, model_name, quantization_type):
        """Quantizes a model based on the specified quantization type"""
        try:
            if quantization_type == "4-bit":
                quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "8-bit":
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "16-bit-float":
                model = model_class.from_pretrained(model_name)
                model = model.to(torch.float16)
            else:
                raise ValueError(f"Unsupported quantization type: {quantization_type}")
            return model
        except Exception as e:
            logger.error(f"Quantization failed: {str(e)}")
            raise

    @staticmethod
    def get_model_size(model):
        """Calculate model size in MB"""
        try:
            # Serialize the state dict to a temporary file and measure it on disk
            torch.save(model.state_dict(), "temp.pth")
            size = os.path.getsize("temp.pth") / (1024 * 1024)
            os.remove("temp.pth")
            return size
        except Exception as e:
            logger.error(f"Failed to get model size: {str(e)}")
            raise

    @staticmethod
    def compare_model_outputs(original_outputs, quantized_outputs):
        """Compare outputs between original and quantized models"""
        try:
            if original_outputs is None or quantized_outputs is None:
                return None
            if hasattr(original_outputs, 'logits') and hasattr(quantized_outputs, 'logits'):
                original_logits = original_outputs.logits.detach().cpu().numpy()
                quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
                metrics = {
                    'mse': ((original_logits - quantized_logits) ** 2).mean(),
                    'spearman_corr': spearmanr(original_logits.flatten(), quantized_logits.flatten())[0],
                    'cosine_sim': cosine_similarity(original_logits.reshape(1, -1), quantized_logits.reshape(1, -1))[0][0]
                }
                return metrics
            return None
        except Exception as e:
            logger.error(f"Output comparison failed: {str(e)}")
            raise
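
# Example usage (a minimal sketch; the model class and checkpoint name below are
# illustrative assumptions, not requirements of this module):
#
#   from transformers import AutoModelForSequenceClassification
#
#   quantized = ModelQuantizer.quantize_model(
#       AutoModelForSequenceClassification, "bert-base-uncased", "8-bit")
#   logger.info(f"Quantized size: {ModelQuantizer.get_model_size(quantized):.2f} MB")
#
#   # compare_model_outputs expects the forward-pass outputs (objects exposing .logits)
#   # of the original and quantized models run on the same inputs:
#   # metrics = ModelQuantizer.compare_model_outputs(original_outputs, quantized_outputs)
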
def quantize_onnx_model(model_dir, quantization_type):
    """
    Quantize every ONNX model in the specified directory, replacing each file in place.
    """
    logger.info(f"Quantizing ONNX model in: {model_dir}")
    for filename in os.listdir(model_dir):
        if filename.endswith('.onnx'):
            input_model_path = os.path.join(model_dir, filename)
            output_model_path = os.path.join(model_dir, f"quantized_{filename}")
            try:
                model = onnx.load(input_model_path)
                if quantization_type == "16-bit-float":
                    model_fp16 = float16.convert_float_to_float16(model)
                    onnx.save(model_fp16, output_model_path)
                elif quantization_type in ["8-bit", "16-bit-int"]:
                    quant_type_mapping = {
                        "8-bit": QuantType.QInt8,
                        "16-bit-int": QuantType.QInt16,
                    }
                    quantize_dynamic(
                        model_input=input_model_path,
                        model_output=output_model_path,
                        weight_type=quant_type_mapping[quantization_type]
                    )
                else:
                    logger.error(f"Unsupported quantization type: {quantization_type}")
                    continue
                # Replace the original model file with its quantized version
                os.remove(input_model_path)
                os.rename(output_model_path, input_model_path)
                logger.info(f"Quantized ONNX model saved to: {input_model_path}")
            except Exception as e:
                logger.error(f"Error during ONNX quantization: {str(e)}")
                # Clean up any partially written output so the directory stays consistent
                if os.path.exists(output_model_path):
                    os.remove(output_model_path)
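

if __name__ == "__main__":
    # Minimal smoke test (a sketch): assumes an already-exported ONNX model exists in
    # ./onnx_model; the directory name and quantization type are illustrative only.
    quantize_onnx_model("onnx_model", "8-bit")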