"""
Tests for model management functionality.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock

import pytest

from models import (
    ModelRegistry,
    ModelInfo,
    ModelStatus,
    ModelTask,
    ModelFramework,
    ModelDownloader,
    ModelLoader,
    ModelOptimizer,
)
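
# NOTE: the `temp_dir` and `mock_registry` fixtures used throughout this module
# are not defined here; they are assumed to be provided by a shared
# conftest.py, roughly along these lines (kept as a comment so it does not
# shadow the real fixtures):
#
#     @pytest.fixture
#     def temp_dir(tmp_path):
#         return tmp_path
#
#     @pytest.fixture
#     def mock_registry(tmp_path):
#         return ModelRegistry(models_dir=tmp_path)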


class TestModelRegistry:
    """Test model registry functionality."""

    @pytest.fixture
    def registry(self):
        """Create a test registry."""
        temp_dir = tempfile.mkdtemp()
        return ModelRegistry(models_dir=Path(temp_dir))

    def test_registry_initialization(self, registry):
        """Test registry initialization."""
        assert registry is not None
        assert len(registry.models) > 0
        assert registry.models_dir.exists()

    def test_register_model(self, registry):
        """Test registering a new model."""
        model = ModelInfo(
            model_id="test-model",
            name="Test Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="http://example.com/model.pth",
            filename="test.pth",
            file_size=1000000
        )

        success = registry.register_model(model)
        assert success is True
        assert "test-model" in registry.models

    def test_get_model(self, registry):
        """Test getting a model by ID."""
        model = registry.get_model("rmbg-1.4")
        assert model is not None
        assert model.model_id == "rmbg-1.4"
        assert model.task == ModelTask.SEGMENTATION

    def test_list_models_by_task(self, registry):
        """Test listing models by task."""
        segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)
        assert len(segmentation_models) > 0
        assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)

    def test_list_models_by_framework(self, registry):
        """Test listing models by framework."""
        pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
        onnx_models = registry.list_models(framework=ModelFramework.ONNX)

        assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
        assert all(m.framework == ModelFramework.ONNX for m in onnx_models)

    def test_get_best_model(self, registry):
        """Test getting the best model for a task."""
        # Best model when accuracy is preferred.
        best_accuracy = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=False
        )
        assert best_accuracy is not None

        # Best model when speed is preferred.
        best_speed = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=True
        )
        assert best_speed is not None

    def test_update_model_usage(self, registry):
        """Test updating model usage statistics."""
        model_id = "rmbg-1.4"
        initial_count = registry.models[model_id].use_count

        registry.update_model_usage(model_id)

        assert registry.models[model_id].use_count == initial_count + 1
        assert registry.models[model_id].last_used is not None

    def test_get_total_size(self, registry):
        """Test calculating total model size."""
        total_size = registry.get_total_size()
        assert total_size > 0

        # A fresh registry has no downloaded models, so nothing is in the
        # AVAILABLE state yet.
        available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
        assert available_size == 0

    def test_export_registry(self, registry, temp_dir):
        """Test exporting the registry to a file."""
        export_path = temp_dir / "registry_export.json"
        registry.export_registry(export_path)

        assert export_path.exists()

        with open(export_path) as f:
            data = json.load(f)
        assert "models" in data
        assert len(data["models"]) > 0


class TestModelDownloader:
    """Test model downloading functionality."""

    @pytest.fixture
    def downloader(self, mock_registry):
        """Create a test downloader."""
        return ModelDownloader(mock_registry)

    @patch('requests.get')
    def test_download_model(self, mock_get, downloader):
        """Test downloading a model."""
        # Mock a streaming HTTP response for the model file.
        mock_response = MagicMock()
        mock_response.headers = {'content-length': '1000000'}
        mock_response.iter_content = MagicMock(
            return_value=[b'data' * 1000]
        )
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        downloader.download_model("test-model", force=True)

        assert mock_get.called

    def test_download_progress_tracking(self, downloader):
        """Test download progress tracking."""
        progress_values = []

        def progress_callback(progress):
            progress_values.append(progress.progress)

        with patch.object(downloader, '_download_model_task', return_value=True):
            downloader.download_model(
                "test-model",
                progress_callback=progress_callback
            )

        assert "test-model" in downloader.downloads

    def test_cancel_download(self, downloader):
        """Test cancelling a download."""
        # Simulate an in-flight download.
        downloader.downloads["test-model"] = Mock()
        downloader._stop_events["test-model"] = Mock()

        success = downloader.cancel_download("test-model")

        assert success is True
        assert downloader._stop_events["test-model"].set.called

    def test_download_with_resume(self, downloader, temp_dir):
        """Test download with resume support."""
        # Simulate a partial file left behind by an interrupted download.
        partial_file = temp_dir / "test.pth.part"
        partial_file.write_bytes(b"partial_data")

        assert partial_file.exists()
        assert partial_file.stat().st_size > 0


class TestModelLoader:
    """Test model loading functionality."""

    @pytest.fixture
    def loader(self, mock_registry):
        """Create a test loader."""
        return ModelLoader(mock_registry, device='cpu')

    def test_loader_initialization(self, loader):
        """Test loader initialization."""
        assert loader is not None
        assert loader.device == 'cpu'
        assert loader.max_memory_bytes > 0

    @patch('torch.load')
    def test_load_pytorch_model(self, mock_torch_load, loader):
        """Test loading a PyTorch model."""
        mock_model = MagicMock()
        mock_torch_load.return_value = mock_model

        model_info = ModelInfo(
            model_id="test-pytorch",
            name="Test PyTorch Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="",
            filename="model.pth",
            local_path="/tmp/model.pth",
            status=ModelStatus.AVAILABLE
        )

        loader.registry.get_model = Mock(return_value=model_info)

        with patch.object(Path, 'exists', return_value=True):
            loader.load_model("test-pytorch")

        assert mock_torch_load.called

    def test_memory_management(self, loader):
        """Test memory management during model loading."""
        # Fill the loader with five 100 MB mock models.
        for i in range(5):
            loader.loaded_models[f"model_{i}"] = Mock(
                memory_usage=100 * 1024 * 1024
            )
        loader.current_memory_usage = 500 * 1024 * 1024

        # Asking for 200 MB should evict at least one model.
        loader._free_memory(200 * 1024 * 1024)

        assert len(loader.loaded_models) < 5

    def test_unload_model(self, loader):
        """Test unloading a model."""
        loader.loaded_models["test"] = Mock(
            model=Mock(),
            memory_usage=100 * 1024 * 1024
        )
        loader.current_memory_usage = 100 * 1024 * 1024

        success = loader.unload_model("test")

        assert success is True
        assert "test" not in loader.loaded_models
        assert loader.current_memory_usage == 0

    def test_get_memory_usage(self, loader):
        """Test getting memory usage statistics."""
        loader.loaded_models["model1"] = Mock(memory_usage=100 * 1024 * 1024)
        loader.loaded_models["model2"] = Mock(memory_usage=200 * 1024 * 1024)
        loader.current_memory_usage = 300 * 1024 * 1024

        usage = loader.get_memory_usage()

        assert usage["current_usage_mb"] == 300
        assert usage["loaded_models"] == 2
        assert "model1" in usage["models"]
        assert "model2" in usage["models"]


class TestModelOptimizer:
    """Test model optimization functionality."""

    @pytest.fixture
    def optimizer(self, mock_registry):
        """Create a test optimizer."""
        loader = ModelLoader(mock_registry, device='cpu')
        return ModelOptimizer(loader)

    @patch('torch.quantization.quantize_dynamic')
    def test_quantize_pytorch_model(self, mock_quantize, optimizer):
        """Test PyTorch model quantization."""
        mock_model = MagicMock()
        mock_quantize.return_value = mock_model

        loaded = Mock(
            model_id="test",
            model=mock_model,
            framework=ModelFramework.PYTORCH,
            metadata={'input_size': (1, 3, 512, 512)}
        )

        # Stub out size and speed measurement so only quantization runs.
        with patch.object(optimizer, '_get_model_size', return_value=1000000):
            with patch.object(optimizer, '_benchmark_model', return_value=0.1):
                optimizer._quantize_pytorch(
                    loaded,
                    Path("/tmp"),
                    "dynamic"
                )

        assert mock_quantize.called

    def test_optimization_result(self, optimizer):
        """Test the optimization result structure."""
        from models.optimizer import OptimizationResult

        result = OptimizationResult(
            original_size_mb=100,
            optimized_size_mb=25,
            compression_ratio=4.0,
            original_speed_ms=100,
            optimized_speed_ms=50,
            speedup=2.0,
            accuracy_loss=0.01,
            optimization_time=10.0,
            output_path="/tmp/optimized.pth"
        )

        assert result.compression_ratio == 4.0
        assert result.speedup == 2.0
        assert result.accuracy_loss == 0.01


class TestModelIntegration:
    """Integration tests for model management."""

    @pytest.mark.integration
    @pytest.mark.slow
    def test_model_registry_persistence(self, temp_dir):
        """Test registry persistence across instances."""
        # Register a model with the first registry instance.
        registry1 = ModelRegistry(models_dir=temp_dir)

        test_model = ModelInfo(
            model_id="persistence-test",
            name="Persistence Test",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="http://example.com/model.pth",
            filename="persist.pth"
        )

        registry1.register_model(test_model)

        # A second instance pointed at the same directory should see the model.
        registry2 = ModelRegistry(models_dir=temp_dir)

        loaded_model = registry2.get_model("persistence-test")
        assert loaded_model is not None
        assert loaded_model.name == "Persistence Test"

    @pytest.mark.integration
    def test_model_manager_workflow(self):
        """Test the complete model manager workflow."""
        from models import create_model_manager

        manager = create_model_manager()

        stats = manager.get_stats()
        assert "registry" in stats
        assert stats["registry"]["total_models"] > 0

        # Benchmark with model loading mocked out so no real weights are needed.
        with patch.object(manager.loader, 'load_model', return_value=Mock()):
            benchmarks = manager.benchmark()

        assert isinstance(benchmarks, dict)