import json
import logging
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional

import numpy as np
import torch

from optimizations.quantize import ModelQuantizer

logger = logging.getLogger(__name__)


@dataclass
class ModelMetrics:
    model_sizes: Dict[str, float]
    inference_times: Dict[str, float]
    comparison_metrics: Dict[str, Any]


class ModelHandler:
    """Base class for handling different types of models."""

    def __init__(self, model_name, model_class, quantization_type, test_text=None):
        self.model_name = model_name
        self.model_class = model_class
        self.quantization_type = quantization_type
        self.test_text = test_text
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load models
        self.original_model = self._load_original_model()
        self.quantized_model = self._load_quantized_model()
        self.metrics: Optional[ModelMetrics] = None

    def _load_original_model(self):
        """Load the original model."""
        model = self.model_class.from_pretrained(self.model_name)
        return model.to(self.device)

    def _load_quantized_model(self):
        """Load the quantized model using ModelQuantizer."""
        model = ModelQuantizer.quantize_model(
            self.model_class,
            self.model_name,
            self.quantization_type
        )
        # 4-bit/8-bit models are placed on a device at load time and must not
        # be moved with .to(); only full-precision variants need it here.
        if self.quantization_type not in ("4-bit", "8-bit"):
            model = model.to(self.device)
        return model

    @staticmethod
    def _convert_to_serializable(obj):
        """Recursively convert numpy/torch objects into JSON-serializable types."""
        # np.generic covers every numpy scalar (float32, int64, ...), so no
        # per-dtype branches are needed.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().numpy().tolist()
        if isinstance(obj, dict):
            return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [ModelHandler._convert_to_serializable(v) for v in obj]
        return obj

    def _format_metric_value(self, value):
        """Format a metric value for logging based on its type."""
        if isinstance(value, (float, np.floating)):
            return f"{value:.8f}"
        if isinstance(value, (int, np.integer)):
            return str(value)
        if isinstance(value, list):
            return "\n" + "\n".join(f"  - {item}" for item in value)
        if isinstance(value, dict):
            return "\n" + "\n".join(f"  {k}: {v}" for k, v in value.items())
        return str(value)

    def run_inference(self, model, text):
        """Run model inference - to be implemented by subclasses."""
        raise NotImplementedError

    def decode_output(self, outputs):
        """Decode model outputs - to be implemented by subclasses."""
        raise NotImplementedError

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Compare raw model outputs - to be implemented by subclasses."""
        raise NotImplementedError

    def compare(self):
        """Compare the original and quantized models on size, speed, and outputs."""
        try:
            if self.test_text is None:
                logger.warning("No test text provided. Skipping inference testing.")
                return self.quantized_model

            # Run inference on both models
            original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
            quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)

            original_size = ModelQuantizer.get_model_size(self.original_model)
            quantized_size = ModelQuantizer.get_model_size(self.quantized_model)

            logger.info(f"Original Model Size: {original_size:.2f} MB")
            logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
            logger.info(f"Original Inference Time: {original_time:.4f} seconds")
            logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")

            # Compare outputs and make every metric JSON-serializable
            comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
            comparison_metrics = {
                key: self._convert_to_serializable(value)
                for key, value in comparison_metrics.items()
            }

            self.metrics = ModelMetrics(
                model_sizes={
                    "original": float(original_size),
                    "quantized": float(quantized_size),
                },
                inference_times={
                    "original": float(original_time),
                    "quantized": float(quantized_time),
                },
                comparison_metrics=comparison_metrics,
            )

            return self.quantized_model

        except Exception as e:
            logger.error(f"Quantization and comparison failed: {e}")
            raise

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics as a JSON-serializable dictionary."""
        empty_metrics = {
            "model_sizes": {"original": 0.0, "quantized": 0.0},
            "inference_times": {"original": 0.0, "quantized": 0.0},
            "comparison_metrics": {},
        }
        if self.metrics is None:
            return empty_metrics

        serializable_metrics = self._convert_to_serializable(asdict(self.metrics))
        try:
            json.dumps(serializable_metrics)
            return serializable_metrics
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing metrics: {e}")
            return empty_metrics
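

# ---------------------------------------------------------------------------
# Example usage: a minimal subclass sketch, not part of the handler API above.
# It illustrates the contract the base class expects: run_inference returns
# (outputs, elapsed_seconds) and compare_outputs returns a dict of metrics.
# The transformers imports, the "distilbert-base-uncased" checkpoint, and the
# "8-bit" setting are assumptions chosen for demonstration; substitute the
# models and quantization types this project actually supports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import time

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    class SequenceClassificationHandler(ModelHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        def run_inference(self, model, text):
            # Tokenize on the model's device and time a single forward pass
            inputs = self.tokenizer(text, return_tensors="pt").to(model.device)
            start = time.perf_counter()
            with torch.no_grad():
                outputs = model(**inputs)
            return outputs, time.perf_counter() - start

        def compare_outputs(self, original_outputs, quantized_outputs):
            # Mean absolute difference between the two models' logits,
            # upcast to float32 in case the quantized model emits fp16
            diff = (original_outputs.logits.float() - quantized_outputs.logits.float()).abs()
            return {"mean_logit_diff": diff.mean()}

    handler = SequenceClassificationHandler(
        model_name="distilbert-base-uncased",
        model_class=AutoModelForSequenceClassification,
        quantization_type="8-bit",
        test_text="Quantization keeps this model fast and small.",
    )
    handler.compare()
    print(json.dumps(handler.get_metrics(), indent=2))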