# AutoQuantNX/src/handlers/base_handler.py
import json
import logging
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

import numpy as np
import torch

from optimizations.quantize import ModelQuantizer

logger = logging.getLogger(__name__)


@dataclass
class ModelMetrics:
    """Container for model size, inference latency, and output-comparison metrics."""

    model_sizes: Dict[str, float]
    inference_times: Dict[str, float]
    comparison_metrics: Dict[str, Any]


class ModelHandler:
    """Base class for handling different types of models"""

    def __init__(self, model_name, model_class, quantization_type, test_text=None):
        self.model_name = model_name
        self.model_class = model_class
        self.quantization_type = quantization_type
        self.test_text = test_text
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load models
        self.original_model = self._load_original_model()
        self.quantized_model = self._load_quantized_model()
        self.metrics: Optional[ModelMetrics] = None

    def _load_original_model(self):
        """Load the original model"""
        model = self.model_class.from_pretrained(self.model_name)
        return model.to(self.device)

    def _load_quantized_model(self):
        """Load the quantized model using ModelQuantizer"""
        model = ModelQuantizer.quantize_model(
            self.model_class,
            self.model_name,
            self.quantization_type
        )
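        # Note (assumption: ModelQuantizer uses a bitsandbytes-style backend
        # for 4-/8-bit): such models are dispatched to the GPU at load time
        # and do not support .to(), so only move full-precision models here.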
        if self.quantization_type not in ["4-bit", "8-bit"]:
            model = model.to(self.device)
        return model

    @staticmethod
    def _convert_to_serializable(obj):
        """Recursively convert metrics into JSON-serializable Python types"""
        if isinstance(obj, np.generic):
            # Covers all NumPy scalars, including np.float32/64 and np.int32/64
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            # detach() guards against tensors that still require grad
            return obj.detach().cpu().numpy().tolist()
        if isinstance(obj, dict):
            return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [ModelHandler._convert_to_serializable(v) for v in obj]
        return obj

    def _format_metric_value(self, value):
        """Format metric value based on its type"""
        if isinstance(value, (float, np.float32, np.float64)):
            return f"{value:.8f}"
        elif isinstance(value, (int, np.int32, np.int64)):
            return str(value)
        elif isinstance(value, list):
            return "\n" + "\n".join([f" - {item}" for item in value])
        elif isinstance(value, dict):
            return "\n" + "\n".join([f" {k}: {v}" for k, v in value.items()])
        else:
            return str(value)

    def run_inference(self, model, text):
        """Run model inference, returning (outputs, elapsed_seconds) - to be implemented by subclasses"""
        raise NotImplementedError

    def decode_output(self, outputs):
        """Decode model outputs - to be implemented by subclasses"""
        raise NotImplementedError

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Compare the two models' outputs, returning a metrics dict - to be implemented by subclasses"""
        raise NotImplementedError

    def compare(self):
        """Compare original and quantized models"""
        try:
            if self.test_text is None:
                logger.warning("No test text provided. Skipping inference testing.")
                return self.quantized_model
            # Run inference
            original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
            quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)
            original_size = ModelQuantizer.get_model_size(self.original_model)
            quantized_size = ModelQuantizer.get_model_size(self.quantized_model)
            logger.info(f"Original Model Size: {original_size:.2f} MB")
            logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
            logger.info(f"Original Inference Time: {original_time:.4f} seconds")
            logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")
            # Compare outputs
            comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
            comparison_metrics = self._convert_to_serializable(comparison_metrics)
            self.metrics = ModelMetrics(
                model_sizes={
                    "original": float(original_size),
                    "quantized": float(quantized_size)
                },
                inference_times={
                    "original": float(original_time),
                    "quantized": float(quantized_time)
                },
                comparison_metrics=comparison_metrics
            )
            return self.quantized_model
        except Exception as e:
            logger.error(f"Quantization and comparison failed: {str(e)}")
            raise

    def get_metrics(self) -> Dict[str, Any]:
        """Return the metrics as a JSON-serializable dictionary"""
        empty_metrics = {
            "model_sizes": {"original": 0.0, "quantized": 0.0},
            "inference_times": {"original": 0.0, "quantized": 0.0},
            "comparison_metrics": {}
        }
        if self.metrics is None:
            return empty_metrics
        serializable_metrics = self._convert_to_serializable(asdict(self.metrics))
        try:
            json.dumps(serializable_metrics)
            return serializable_metrics
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing metrics: {str(e)}")
            return empty_metrics
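

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the tested API): a minimal,
# hypothetical subclass wiring up the three abstract hooks. The handler class,
# checkpoint name, and test text below are assumptions chosen for
# demonstration; "8-bit" quantization assumes a CUDA device is available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    class SequenceClassificationHandler(ModelHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        def run_inference(self, model, text):
            # Tokenize on the model's device and time a single forward pass
            inputs = self.tokenizer(text, return_tensors="pt").to(model.device)
            start = time.perf_counter()
            with torch.no_grad():
                outputs = model(**inputs)
            return outputs.logits, time.perf_counter() - start

        def decode_output(self, outputs):
            return outputs.argmax(dim=-1).tolist()

        def compare_outputs(self, original_outputs, quantized_outputs):
            # Element-wise logit drift between the two variants
            diff = (original_outputs.float() - quantized_outputs.float()).abs()
            return {"max_abs_logit_diff": diff.max(), "mean_abs_logit_diff": diff.mean()}

    handler = SequenceClassificationHandler(
        model_name="distilbert-base-uncased-finetuned-sst-2-english",
        model_class=AutoModelForSequenceClassification,
        quantization_type="8-bit",
        test_text="Quantization keeps this review positive.",
    )
    handler.compare()
    print(json.dumps(handler.get_metrics(), indent=2))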