import os
import torch
import onnx
import logging
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BitsAndBytesConfig
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelQuantizer:
    """Handles model quantization and comparison operations."""
    @staticmethod
    def quantize_model(model_class, model_name, quantization_type):
        """Quantize a model with the specified quantization type."""
        try:
            if quantization_type == "4-bit":
                quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "8-bit":
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "16-bit-float":
                # Half precision: load the full model, then cast its weights to float16.
                model = model_class.from_pretrained(model_name)
                model = model.to(torch.float16)
            else:
                raise ValueError(f"Unsupported quantization type: {quantization_type}")
            return model
        except Exception as e:
            logger.error(f"Quantization failed: {str(e)}")
            raise
    @staticmethod
    def get_model_size(model):
        """Return the model size in MB, measured by serializing the state dict to disk."""
        try:
            torch.save(model.state_dict(), "temp.pth")
            size = os.path.getsize("temp.pth") / (1024 * 1024)
            os.remove("temp.pth")
            return size
        except Exception as e:
            logger.error(f"Failed to get model size: {str(e)}")
            raise
    @staticmethod
    def compare_model_outputs(original_outputs, quantized_outputs):
        """Compare outputs between the original and quantized models."""
        try:
            if original_outputs is None or quantized_outputs is None:
                return None
            if hasattr(original_outputs, 'logits') and hasattr(quantized_outputs, 'logits'):
                original_logits = original_outputs.logits.detach().cpu().numpy()
                quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
                # Agreement metrics over the two sets of logits: mean squared error,
                # Spearman rank correlation, and cosine similarity of the flattened logits.
                metrics = {
                    'mse': ((original_logits - quantized_logits) ** 2).mean(),
                    'spearman_corr': spearmanr(original_logits.flatten(), quantized_logits.flatten())[0],
                    'cosine_sim': cosine_similarity(original_logits.reshape(1, -1), quantized_logits.reshape(1, -1))[0][0]
                }
                return metrics
            return None
        except Exception as e:
            logger.error(f"Output comparison failed: {str(e)}")
            raise
    @staticmethod
    def quantize_onnx_model(model_dir, quantization_type):
        """Quantize every ONNX model found in the specified directory."""
        logger.info(f"Quantizing ONNX model in: {model_dir}")
        for filename in os.listdir(model_dir):
            if filename.endswith('.onnx'):
                input_model_path = os.path.join(model_dir, filename)
                output_model_path = os.path.join(model_dir, f"quantized_{filename}")
                try:
                    model = onnx.load(input_model_path)
                    if quantization_type == "16-bit-float":
                        # Convert float32 tensors in the graph to float16 via onnxconverter-common.
                        model_fp16 = float16.convert_float_to_float16(model)
                        onnx.save(model_fp16, output_model_path)
                    elif quantization_type in ["8-bit", "16-bit-int"]:
                        # Dynamic quantization of weights to int8 or int16
                        # (QuantType.QInt16 is only available in newer onnxruntime releases).
                        quant_type_mapping = {
                            "8-bit": QuantType.QInt8,
                            "16-bit-int": QuantType.QInt16,
                        }
                        quantize_dynamic(
                            model_input=input_model_path,
                            model_output=output_model_path,
                            weight_type=quant_type_mapping[quantization_type]
                        )
                    else:
                        logger.error(f"Unsupported quantization type: {quantization_type}")
                        continue
                    # Replace the original model file with its quantized version.
                    os.remove(input_model_path)
                    os.rename(output_model_path, input_model_path)
                    logger.info(f"Quantized ONNX model saved to: {input_model_path}")
                except Exception as e:
                    logger.error(f"Error during ONNX quantization: {str(e)}")
                    # Clean up any partially written output file.
                    if os.path.exists(output_model_path):
                        os.remove(output_model_path)
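

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how ModelQuantizer might be exercised end to end. The checkpoint name and
# input text below are illustrative placeholders; any Hugging Face model
# compatible with the chosen AutoModel class should work. The 4-bit and 8-bit
# paths additionally require a CUDA build of bitsandbytes, so this sketch
# sticks to the 16-bit-float path.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("Quantization should barely change this prediction.", return_tensors="pt")

    original_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    quantized_model = ModelQuantizer.quantize_model(
        AutoModelForSequenceClassification, model_name, "16-bit-float"
    )

    logger.info(f"Original size:  {ModelQuantizer.get_model_size(original_model):.1f} MB")
    logger.info(f"Quantized size: {ModelQuantizer.get_model_size(quantized_model):.1f} MB")

    # float16 inference is normally run on a GPU; CPU support for half-precision
    # ops is patchy, so the output comparison is skipped when CUDA is unavailable.
    if torch.cuda.is_available():
        original_model = original_model.to("cuda")
        quantized_model = quantized_model.to("cuda")
        cuda_inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            original_outputs = original_model(**cuda_inputs)
            quantized_outputs = quantized_model(**cuda_inputs)
        metrics = ModelQuantizer.compare_model_outputs(original_outputs, quantized_outputs)
        logger.info(f"Output comparison metrics: {metrics}")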