import json
import logging
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional

import numpy as np
import torch

from optimizations.quantize import ModelQuantizer

logger = logging.getLogger(__name__)
@dataclass
class ModelMetrics:
    """Container for model size, inference latency, and output-comparison metrics."""
    model_sizes: Dict[str, float]
    inference_times: Dict[str, float]
    comparison_metrics: Dict[str, Any]
class ModelHandler:
    """Base class for handling different types of models"""

    def __init__(self, model_name, model_class, quantization_type, test_text=None):
        self.model_name = model_name
        self.model_class = model_class
        self.quantization_type = quantization_type
        self.test_text = test_text
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load both variants up front so compare() can run immediately
        self.original_model = self._load_original_model()
        self.quantized_model = self._load_quantized_model()
        self.metrics: Optional[ModelMetrics] = None
    def _load_original_model(self):
        """Load the original model"""
        model = self.model_class.from_pretrained(self.model_name)
        return model.to(self.device)
    def _load_quantized_model(self):
        """Load the quantized model using ModelQuantizer"""
        model = ModelQuantizer.quantize_model(
            self.model_class,
            self.model_name,
            self.quantization_type
        )
        # 4-bit and 8-bit models are already placed on device at load time;
        # moving them again with .to() would fail, so only move other variants.
        if self.quantization_type not in ["4-bit", "8-bit"]:
            model = model.to(self.device)
        return model
    @staticmethod
    def _convert_to_serializable(obj):
        """Recursively convert metrics into JSON-serializable Python types."""
        # np.generic covers every NumPy scalar (np.float32, np.int64, ...),
        # so a single .item() call handles all of them.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().numpy().tolist()
        if isinstance(obj, dict):
            return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [ModelHandler._convert_to_serializable(v) for v in obj]
        return obj
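    # Example (illustrative): _convert_to_serializable(
    #     {"a": np.float32(1.5), "b": torch.tensor([1, 2])}
    # ) returns {"a": 1.5, "b": [1, 2]}, which json.dumps accepts.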
    def _format_metric_value(self, value):
        """Format metric value based on its type"""
        if isinstance(value, (float, np.float32, np.float64)):
            return f"{value:.8f}"
        elif isinstance(value, (int, np.int32, np.int64)):
            return str(value)
        elif isinstance(value, list):
            return "\n" + "\n".join([f"  - {item}" for item in value])
        elif isinstance(value, dict):
            return "\n" + "\n".join([f"  {k}: {v}" for k, v in value.items()])
        else:
            return str(value)
    def run_inference(self, model, text):
        """Run model inference - to be implemented by subclasses"""
        raise NotImplementedError

    def decode_output(self, outputs):
        """Decode model outputs - to be implemented by subclasses"""
        raise NotImplementedError

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Compare original and quantized outputs - to be implemented by subclasses"""
        raise NotImplementedError
    def compare(self):
        """Compare original and quantized models"""
        try:
            if self.test_text is None:
                logger.warning("No test text provided. Skipping inference testing.")
                return self.quantized_model

            # Run inference on both models
            original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
            quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)

            original_size = ModelQuantizer.get_model_size(self.original_model)
            quantized_size = ModelQuantizer.get_model_size(self.quantized_model)

            logger.info(f"Original Model Size: {original_size:.2f} MB")
            logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
            logger.info(f"Original Inference Time: {original_time:.4f} seconds")
            logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")

            # Compare outputs and make every value JSON-serializable
            comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
            comparison_metrics = {
                key: self._convert_to_serializable(value)
                for key, value in comparison_metrics.items()
            }

            self.metrics = ModelMetrics(
                model_sizes={
                    "original": float(original_size),
                    "quantized": float(quantized_size),
                },
                inference_times={
                    "original": float(original_time),
                    "quantized": float(quantized_time),
                },
                comparison_metrics=comparison_metrics,
            )
            return self.quantized_model
        except Exception as e:
            logger.error(f"Quantization and comparison failed: {e}")
            raise
    def get_metrics(self) -> Dict[str, Any]:
        """Return the metrics as a JSON-serializable dictionary"""
        empty_metrics = {
            "model_sizes": {"original": 0.0, "quantized": 0.0},
            "inference_times": {"original": 0.0, "quantized": 0.0},
            "comparison_metrics": {}
        }
        if self.metrics is None:
            return empty_metrics

        serializable_metrics = self._convert_to_serializable(asdict(self.metrics))
        try:
            json.dumps(serializable_metrics)  # verify everything survived conversion
            return serializable_metrics
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing metrics: {e}")
            return empty_metrics
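

# --- Illustrative usage sketch ---
# A minimal subclass, assuming a Hugging Face-style model and tokenizer and a
# simple max-logit-difference metric. The class name, tokenizer handling, and
# comparison metric below are hypothetical examples for illustration; the real
# subclasses in this project may implement these hooks differently.
import time


class ExampleClassificationHandler(ModelHandler):
    """Hypothetical handler for a sequence-classification model."""

    def __init__(self, model_name, model_class, quantization_type, tokenizer, test_text=None):
        # Hypothetical: the tokenizer must be set before super().__init__
        # loads the models, since compare() may run inference immediately after.
        self.tokenizer = tokenizer
        super().__init__(model_name, model_class, quantization_type, test_text)

    def run_inference(self, model, text):
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        start = time.perf_counter()
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs, time.perf_counter() - start

    def decode_output(self, outputs):
        return int(torch.argmax(outputs.logits, dim=-1).item())

    def compare_outputs(self, original_outputs, quantized_outputs):
        # Max absolute difference between logits, compared on CPU to avoid
        # device mismatches between the original and quantized models.
        diff = (original_outputs.logits.float().cpu()
                - quantized_outputs.logits.float().cpu()).abs().max()
        return {
            "max_logit_difference": diff,
            "predictions_match": self.decode_output(original_outputs)
            == self.decode_output(quantized_outputs),
        }

# Typical flow (assuming the hypothetical subclass above):
#   handler = ExampleClassificationHandler("distilbert-base-uncased",
#                                          AutoModelForSequenceClassification,
#                                          "8-bit", tokenizer, test_text="hello")
#   quantized = handler.compare()
#   metrics = handler.get_metrics()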