"""
Comprehensive tests for the adaptive batching module.

Tests cover:
- ModelMemoryProfile dataclass
- Memory utility functions
- AdaptiveBatchSizeCalculator
- BatchInfo and adaptive_batch_iterator
- High-level API functions
- Edge cases and error handling
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
import torch

from depth_anything_3.utils.adaptive_batching import (
    MODEL_MEMORY_PROFILES,
    AdaptiveBatchConfig,
    AdaptiveBatchSizeCalculator,
    BatchInfo,
    ModelMemoryProfile,
    adaptive_batch_iterator,
    estimate_max_batch_size,
    get_available_memory_mb,
    get_total_memory_mb,
    log_batch_plan,
    process_with_adaptive_batching,
)
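
# Orientation sketch (comments only, not executed by the tests): the
# high-level entry point batches a sequence of items through a callable,
# sizing each batch from the model's memory profile and the target device.
#
#     results = process_with_adaptive_batching(
#         items=image_paths,          # any sequence of work items
#         process_fn=run_inference,   # hypothetical per-batch callable
#         model_name="da3-large",
#         device=torch.device("cuda:0"),
#     )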


@pytest.fixture
def cpu_device():
    """Return CPU device."""
    return torch.device("cpu")


@pytest.fixture
def mock_cuda_device():
    """Return mock CUDA device."""
    return torch.device("cuda:0")


@pytest.fixture
def mock_mps_device():
    """Return mock MPS device."""
    return torch.device("mps")


@pytest.fixture
def default_config():
    """Return default adaptive batch config."""
    return AdaptiveBatchConfig()


@pytest.fixture
def calculator_cpu(cpu_device):
    """Return calculator for CPU."""
    return AdaptiveBatchSizeCalculator("da3-large", cpu_device)


class TestModelMemoryProfile:
    """Tests for ModelMemoryProfile dataclass."""

    def test_default_values(self):
        """Test default values are set correctly."""
        profile = ModelMemoryProfile(
            base_memory_mb=1000,
            per_image_mb_at_504=500,
        )
        assert profile.base_memory_mb == 1000
        assert profile.per_image_mb_at_504 == 500
        assert profile.activation_scale == 1.0
        assert profile.safety_margin == 0.15

    def test_custom_values(self):
        """Test custom values override defaults."""
        profile = ModelMemoryProfile(
            base_memory_mb=2000,
            per_image_mb_at_504=800,
            activation_scale=1.5,
            safety_margin=0.2,
        )
        assert profile.base_memory_mb == 2000
        assert profile.per_image_mb_at_504 == 800
        assert profile.activation_scale == 1.5
        assert profile.safety_margin == 0.2

    def test_all_models_have_profiles(self):
        """Test that all expected models have memory profiles."""
        expected_models = [
            "da3-small",
            "da3-base",
            "da3-large",
            "da3-giant",
            "da3metric-large",
            "da3mono-large",
            "da3nested-giant-large",
        ]
        for model_name in expected_models:
            assert model_name in MODEL_MEMORY_PROFILES
            profile = MODEL_MEMORY_PROFILES[model_name]
            assert profile.base_memory_mb > 0
            assert profile.per_image_mb_at_504 > 0

    def test_profiles_size_ordering(self):
        """Test that model profiles have expected size ordering."""
        small = MODEL_MEMORY_PROFILES["da3-small"]
        base = MODEL_MEMORY_PROFILES["da3-base"]
        large = MODEL_MEMORY_PROFILES["da3-large"]
        giant = MODEL_MEMORY_PROFILES["da3-giant"]

        # Base memory should grow with model capacity.
        assert small.base_memory_mb < base.base_memory_mb
        assert base.base_memory_mb < large.base_memory_mb
        assert large.base_memory_mb < giant.base_memory_mb

        # Per-image activation memory should follow the same ordering.
        assert small.per_image_mb_at_504 < base.per_image_mb_at_504
        assert base.per_image_mb_at_504 < large.per_image_mb_at_504
        assert large.per_image_mb_at_504 < giant.per_image_mb_at_504
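
    def test_profiles_have_sane_margins(self):
        """Sanity sketch: margin/scale values should be usable.

        Assumption (not guaranteed by the source): every bundled profile
        keeps safety_margin in [0, 1) and a positive activation_scale;
        values outside those ranges would make the memory model degenerate.
        """
        for profile in MODEL_MEMORY_PROFILES.values():
            assert 0.0 <= profile.safety_margin < 1.0
            assert profile.activation_scale > 0.0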


class TestGetAvailableMemory:
    """Tests for get_available_memory_mb function."""

    def test_cpu_returns_infinity(self, cpu_device):
        """CPU should return infinite memory."""
        result = get_available_memory_mb(cpu_device)
        assert result == float("inf")

    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.synchronize")
    @patch("torch.cuda.get_device_properties")
    @patch("torch.cuda.memory_reserved")
    def test_cuda_memory_calculation(
        self,
        mock_reserved,
        mock_properties,
        mock_sync,
        mock_available,
        mock_cuda_device,
    ):
        """Test CUDA memory calculation."""
        # 16 GB total with 4 GB already reserved.
        mock_props = MagicMock()
        mock_props.total_memory = 16 * 1024 * 1024 * 1024
        mock_properties.return_value = mock_props
        mock_reserved.return_value = 4 * 1024 * 1024 * 1024

        result = get_available_memory_mb(mock_cuda_device)

        # (16 - 4) GB free, expressed in MB.
        expected = (16 - 4) * 1024
        assert result == expected

    def test_mps_memory_with_env_var(self, mock_mps_device, monkeypatch):
        """Test MPS memory respects environment variable."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "16")

        with patch("torch.mps.current_allocated_memory", return_value=0):
            result = get_available_memory_mb(mock_mps_device)
            assert result == 16 * 1024

    def test_mps_memory_default(self, mock_mps_device, monkeypatch):
        """Test MPS memory uses default when env var not set."""
        monkeypatch.delenv("DA3_MPS_MAX_MEMORY_GB", raising=False)

        with patch("torch.mps.current_allocated_memory", return_value=0):
            result = get_available_memory_mb(mock_mps_device)
            assert result == 8 * 1024

    def test_mps_memory_subtracts_allocated(self, mock_mps_device, monkeypatch):
        """Test MPS memory subtracts allocated memory."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "8")

        allocated_bytes = 2 * 1024 * 1024 * 1024
        with patch("torch.mps.current_allocated_memory", return_value=allocated_bytes):
            result = get_available_memory_mb(mock_mps_device)
            expected = (8 - 2) * 1024
            assert result == expected
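
    # A parametrized restatement of the cap-minus-allocated arithmetic
    # verified above (assumes the same formula, nothing new).
    @pytest.mark.parametrize(
        "allocated_gb,expected_mb",
        [(0, 8 * 1024), (2, 6 * 1024), (4, 4 * 1024)],
    )
    def test_mps_memory_parametrized(
        self, mock_mps_device, monkeypatch, allocated_gb, expected_mb
    ):
        """MPS available memory should be the env cap minus allocation."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "8")

        allocated_bytes = allocated_gb * 1024 * 1024 * 1024
        with patch("torch.mps.current_allocated_memory", return_value=allocated_bytes):
            assert get_available_memory_mb(mock_mps_device) == expected_mb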


class TestGetTotalMemory:
    """Tests for get_total_memory_mb function."""

    def test_cpu_returns_infinity(self, cpu_device):
        """CPU should return infinite total memory."""
        result = get_total_memory_mb(cpu_device)
        assert result == float("inf")

    @patch("torch.cuda.get_device_properties")
    def test_cuda_total_memory(self, mock_properties, mock_cuda_device):
        """Test CUDA total memory retrieval."""
        mock_props = MagicMock()
        mock_props.total_memory = 24 * 1024 * 1024 * 1024
        mock_properties.return_value = mock_props

        result = get_total_memory_mb(mock_cuda_device)
        assert result == 24 * 1024

    def test_mps_total_memory_env_var(self, mock_mps_device, monkeypatch):
        """Test MPS total memory from environment variable."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "32")
        result = get_total_memory_mb(mock_mps_device)
        assert result == 32 * 1024
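
    # Consistency sketch (relies only on the mocked values below): with a
    # nonzero reservation, available memory must come out below total.
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.synchronize")
    @patch("torch.cuda.get_device_properties")
    @patch("torch.cuda.memory_reserved")
    def test_cuda_available_below_total(
        self,
        mock_reserved,
        mock_properties,
        mock_sync,
        mock_available,
        mock_cuda_device,
    ):
        """Available memory should not exceed total memory."""
        mock_props = MagicMock()
        mock_props.total_memory = 8 * 1024 * 1024 * 1024
        mock_properties.return_value = mock_props
        mock_reserved.return_value = 1 * 1024 * 1024 * 1024

        available = get_available_memory_mb(mock_cuda_device)
        total = get_total_memory_mb(mock_cuda_device)
        assert available < total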


class TestAdaptiveBatchConfig:
    """Tests for AdaptiveBatchConfig dataclass."""

    def test_default_values(self):
        """Test default configuration values."""
        config = AdaptiveBatchConfig()
        assert config.min_batch_size == 1
        assert config.max_batch_size == 64
        assert config.target_memory_utilization == 0.85
        assert config.enable_profiling is True
        assert config.profile_warmup_batches == 2

    def test_custom_values(self):
        """Test custom configuration values."""
        config = AdaptiveBatchConfig(
            min_batch_size=2,
            max_batch_size=32,
            target_memory_utilization=0.90,
            enable_profiling=False,
            profile_warmup_batches=5,
        )
        assert config.min_batch_size == 2
        assert config.max_batch_size == 32
        assert config.target_memory_utilization == 0.90
        assert config.enable_profiling is False
        assert config.profile_warmup_batches == 5


class TestAdaptiveBatchSizeCalculator:
    """Tests for AdaptiveBatchSizeCalculator class."""

    def test_initialization_known_model(self, cpu_device):
        """Test initialization with known model."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        assert calc.model_name == "da3-large"
        assert calc.device == cpu_device
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_unknown_model_uses_fallback(self, cpu_device):
        """Test initialization with unknown model falls back to da3-large."""
        calc = AdaptiveBatchSizeCalculator("unknown-model", cpu_device)
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_with_custom_config(self, cpu_device):
        """Test initialization with custom config."""
        config = AdaptiveBatchConfig(max_batch_size=16)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        assert calc.config.max_batch_size == 16

    def test_compute_optimal_batch_size_cpu(self, cpu_device):
        """CPU should return min(num_images, max_batch_size)."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # Fewer images than the cap: batch size equals the image count.
        result = calc.compute_optimal_batch_size(num_images=10)
        assert result == 10

        # More images than the cap: batch size equals max_batch_size.
        result = calc.compute_optimal_batch_size(num_images=100)
        assert result == 64

    def test_compute_optimal_batch_size_respects_min(self, cpu_device):
        """num_images caps the batch size even below min_batch_size."""
        config = AdaptiveBatchConfig(min_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        result = calc.compute_optimal_batch_size(num_images=2)

        # There are only two images, so the batch cannot be larger than 2.
        assert result == 2

    def test_compute_optimal_batch_size_respects_max(self, cpu_device):
        """Batch size should not exceed max_batch_size."""
        config = AdaptiveBatchConfig(max_batch_size=8)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        result = calc.compute_optimal_batch_size(num_images=100)
        assert result == 8

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_optimal_batch_size_memory_based(
        self, mock_memory, mock_cuda_device
    ):
        """Test memory-based batch size calculation."""
        # Pretend 10 GB of device memory is available.
        mock_memory.return_value = 10000

        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)

        result = calc.compute_optimal_batch_size(num_images=100, process_res=504)

        # Memory-limited: within configured bounds and below the image count.
        assert 1 <= result <= 64
        assert result < 100

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_low_memory_returns_min(self, mock_memory, mock_cuda_device):
        """Low memory should return min batch size."""
        # Only 500 MB available -- not enough for even one typical image.
        mock_memory.return_value = 500

        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)
        result = calc.compute_optimal_batch_size(num_images=100)

        assert result == 1

    def test_estimate_per_image_memory_resolution_scaling(self, cpu_device):
        """Test that memory scales quadratically with resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        mem_504 = calc._estimate_per_image_memory(504)
        mem_1008 = calc._estimate_per_image_memory(1008)

        # Doubling the resolution quadruples the pixel count, so memory
        # should grow by roughly 4x.
        ratio = mem_1008 / mem_504
        assert 3.5 <= ratio <= 4.5
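
    # Monotonicity sketch (follows from the quadratic scaling checked
    # above): per-image memory should not decrease with resolution.
    @pytest.mark.parametrize("res_pair", [(252, 504), (504, 756), (756, 1008)])
    def test_estimate_per_image_memory_monotonic(self, cpu_device, res_pair):
        """Per-image memory estimate should grow with resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        low, high = res_pair
        low_mem = calc._estimate_per_image_memory(low)
        high_mem = calc._estimate_per_image_memory(high)
        assert low_mem <= high_mem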

    def test_update_from_profiling_warmup(self, cpu_device):
        """Test that warmup batches are skipped during profiling."""
        config = AdaptiveBatchConfig(profile_warmup_batches=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # First two updates fall inside the warmup window and are ignored.
        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is None

        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is None

        # The third update is past warmup and records a measurement.
        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is not None

    def test_update_from_profiling_disabled(self, cpu_device):
        """Test that profiling can be disabled."""
        config = AdaptiveBatchConfig(enable_profiling=False)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        for _ in range(5):
            calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)

        assert calc._measured_per_image_mb is None

    def test_update_from_profiling_ema(self, cpu_device):
        """Test exponential moving average in profiling."""
        config = AdaptiveBatchConfig(profile_warmup_batches=0)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # First measurement seeds the moving average.
        calc.update_from_profiling(batch_size=4, memory_used_mb=4000, process_res=504)
        first_value = calc._measured_per_image_mb

        # A different second measurement should shift the average.
        calc.update_from_profiling(batch_size=4, memory_used_mb=5000, process_res=504)
        second_value = calc._measured_per_image_mb

        assert second_value is not None
        assert second_value != first_value

    def test_get_memory_estimate(self, cpu_device):
        """Test memory estimation for batch."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        estimate = calc.get_memory_estimate(batch_size=4, process_res=504)

        # The estimate must exceed the model's base footprint alone.
        assert estimate > calc.profile.base_memory_mb
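
    # Scaling sketch (assumes only a positive per-image cost, which the
    # profile tests above establish): a larger batch needs more memory.
    def test_get_memory_estimate_monotonic_in_batch_size(self, cpu_device):
        """Memory estimate should grow with batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        est_small = calc.get_memory_estimate(batch_size=2, process_res=504)
        est_large = calc.get_memory_estimate(batch_size=8, process_res=504)
        assert est_small < est_large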


class TestBatchInfo:
    """Tests for BatchInfo dataclass."""

    def test_batch_info_creation(self):
        """Test basic BatchInfo creation."""
        items = ["a", "b", "c"]
        info = BatchInfo(
            batch_idx=0,
            start_idx=0,
            end_idx=3,
            items=items,
            is_last=True,
        )
        assert info.batch_idx == 0
        assert info.start_idx == 0
        assert info.end_idx == 3
        assert info.items == ["a", "b", "c"]
        assert info.batch_size == 3
        assert info.is_last is True

    def test_batch_size_computed_from_items(self):
        """Test that batch_size is computed from items."""
        info = BatchInfo(
            batch_idx=0,
            start_idx=0,
            end_idx=5,
            items=[1, 2, 3, 4, 5],
        )
        assert info.batch_size == 5

    def test_empty_batch(self):
        """Test empty batch handling."""
        info = BatchInfo(
            batch_idx=0,
            start_idx=0,
            end_idx=0,
            items=[],
        )
        assert info.batch_size == 0


class TestAdaptiveBatchIterator:
    """Tests for adaptive_batch_iterator function."""

    def test_single_batch(self, calculator_cpu):
        """Test single batch when all items fit."""
        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calculator_cpu))

        assert len(batches) == 1
        assert batches[0].items == items
        assert batches[0].is_last is True

    def test_multiple_batches(self, cpu_device):
        """Test multiple batches with small max_batch_size."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calc))

        # 10 items in batches of 3 -> sizes 3, 3, 3, 1.
        assert len(batches) == 4
        assert batches[0].batch_size == 3
        assert batches[-1].batch_size == 1
        assert batches[-1].is_last is True

    def test_batch_indices_are_correct(self, cpu_device):
        """Test that batch indices are sequential."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(6))
        batches = list(adaptive_batch_iterator(items, calc))

        for i, batch in enumerate(batches):
            assert batch.batch_idx == i

    def test_start_end_indices_cover_all_items(self, cpu_device):
        """Test that batches cover all items without gaps."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calc))

        # Each batch must start where the previous one ended.
        prev_end = 0
        for batch in batches:
            assert batch.start_idx == prev_end
            assert batch.end_idx > batch.start_idx
            prev_end = batch.end_idx

        assert prev_end == len(items)

    def test_items_are_preserved(self, cpu_device):
        """Test that all items are preserved in batches."""
        config = AdaptiveBatchConfig(max_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        original_items = ["a", "b", "c", "d", "e", "f", "g"]
        batches = list(adaptive_batch_iterator(original_items, calc))

        # Re-concatenating the batches must reproduce the input order.
        collected = []
        for batch in batches:
            collected.extend(batch.items)

        assert collected == original_items

    def test_empty_sequence(self, calculator_cpu):
        """Test empty sequence returns no batches."""
        batches = list(adaptive_batch_iterator([], calculator_cpu))
        assert len(batches) == 0

    def test_last_batch_flag(self, cpu_device):
        """Test that only last batch has is_last=True."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(5))
        batches = list(adaptive_batch_iterator(items, calc))

        # All intermediate batches are flagged as not last.
        for batch in batches[:-1]:
            assert batch.is_last is False

        # Only the final batch carries the terminating flag.
        assert batches[-1].is_last is True
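
    # Coverage sketch (relies on the CPU behaviour checked above, where
    # batch size is min(num_images, max_batch_size)): the batch count
    # should equal ceil(len(items) / max_batch_size).
    @pytest.mark.parametrize(
        "num_items,max_bs,expected",
        [(10, 3, 4), (6, 2, 3), (5, 5, 1), (7, 4, 2)],
    )
    def test_batch_count_matches_ceiling(self, cpu_device, num_items, max_bs, expected):
        """Batch count should be ceil(num_items / max_batch_size) on CPU."""
        config = AdaptiveBatchConfig(max_batch_size=max_bs)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        batches = list(adaptive_batch_iterator(list(range(num_items)), calc))
        assert len(batches) == expected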


class TestProcessWithAdaptiveBatching:
    """Tests for process_with_adaptive_batching function."""

    def test_basic_processing(self, cpu_device):
        """Test basic batch processing."""
        items = list(range(10))

        def process_fn(batch):
            return [x * 2 for x in batch]

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )

        assert results == [x * 2 for x in items]

    def test_progress_callback(self, cpu_device):
        """Test progress callback is called."""
        items = list(range(10))
        progress_calls = []

        def process_fn(batch):
            return batch

        def progress_callback(processed, total):
            progress_calls.append((processed, total))

        config = AdaptiveBatchConfig(max_batch_size=3)

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
            progress_callback=progress_callback,
        )

        # With max_batch_size=3 and 10 items, the callback fires per batch.
        assert len(progress_calls) > 1

        # The final call reports that every item has been processed.
        assert progress_calls[-1][0] == len(items)
        assert progress_calls[-1][1] == len(items)

    def test_single_result_handling(self, cpu_device):
        """Test handling of non-list results."""
        items = list(range(5))

        def process_fn(batch):
            # Return a scalar rather than a list of per-item results.
            return sum(batch)

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )

        # Non-list results are still collected into a list.
        assert isinstance(results, list)

    def test_empty_items(self, cpu_device):
        """Test with empty items list."""
        results = process_with_adaptive_batching(
            items=[],
            process_fn=lambda x: x,
            model_name="da3-large",
            device=cpu_device,
        )
        assert results == []
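
    # End-to-end sketch with tensor payloads (uses only the API exercised
    # above; the stub process_fn stands in for a real model forward pass).
    def test_tensor_processing(self, cpu_device):
        """Batch processing should work with tensor items."""
        items = [torch.zeros(3) for _ in range(6)]

        def process_fn(batch):
            return [t + 1 for t in batch]

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )

        assert len(results) == 6
        assert all(torch.equal(t, torch.ones(3)) for t in results)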


class TestEstimateMaxBatchSize:
    """Tests for estimate_max_batch_size function."""

    def test_returns_positive_integer(self, cpu_device):
        """Test that function returns positive integer."""
        result = estimate_max_batch_size("da3-large", cpu_device)
        assert isinstance(result, int)
        assert result > 0

    def test_different_resolutions(self, cpu_device):
        """Test that higher resolution gives lower batch size (for GPU)."""
        # On CPU memory is effectively unbounded, so only positivity is
        # asserted here; the resolution-driven ordering applies on GPU.
        low_res = estimate_max_batch_size("da3-large", cpu_device, process_res=504)
        high_res = estimate_max_batch_size("da3-large", cpu_device, process_res=1008)

        assert low_res > 0
        assert high_res > 0

    def test_different_utilization(self, cpu_device):
        """Test different target utilization values."""
        low_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.5
        )
        high_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.95
        )

        # As above, CPU is unbounded, so both should simply be positive.
        assert low_util > 0
        assert high_util > 0


class TestLogBatchPlan:
    """Tests for log_batch_plan function."""

    def test_log_batch_plan_runs(self, cpu_device, caplog):
        """Test that log_batch_plan runs without error."""
        import logging

        with caplog.at_level(logging.INFO):
            # Smoke test: logging the plan should not raise.
            log_batch_plan(
                num_images=100,
                model_name="da3-large",
                device=cpu_device,
                process_res=504,
            )

    def test_log_batch_plan_different_models(self, cpu_device):
        """Test log_batch_plan with different models."""
        for model_name in ["da3-small", "da3-base", "da3-large", "da3-giant"]:
            # Should not raise for any known model profile.
            log_batch_plan(
                num_images=50,
                model_name=model_name,
                device=cpu_device,
            )


class TestIntegration:
    """Integration tests for the adaptive batching module."""

    def test_full_workflow_cpu(self, cpu_device):
        """Test complete workflow on CPU."""
        # Simulate a directory of image paths.
        images = [f"image_{i}.jpg" for i in range(25)]

        # Track the size of each batch the processor sees.
        processed_batches = []

        def process_fn(batch):
            processed_batches.append(len(batch))
            return [f"result_{item}" for item in batch]

        config = AdaptiveBatchConfig(max_batch_size=8)
        results = process_with_adaptive_batching(
            items=images,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
        )

        # Every image yields exactly one result.
        assert len(results) == len(images)
        assert all(r.startswith("result_") for r in results)

        # All items were processed, and no batch exceeded the cap.
        assert sum(processed_batches) == len(images)
        assert max(processed_batches) <= 8

    def test_calculator_reuse(self, cpu_device):
        """Test that calculator can be reused across multiple iterations."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # First computation hits the default cap of 64.
        batch1 = calc.compute_optimal_batch_size(num_images=100)

        # Second computation is limited by the smaller image count.
        batch2 = calc.compute_optimal_batch_size(num_images=50)

        assert batch1 == 64
        assert batch2 == 50

    def test_iterator_with_strings(self, cpu_device):
        """Test iterator works with string items."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = [
            "path/to/image1.jpg",
            "path/to/image2.jpg",
            "path/to/image3.jpg",
            "path/to/image4.jpg",
        ]

        batches = list(adaptive_batch_iterator(items, calc))

        # Concatenating the batches reproduces the original paths in order.
        all_paths = []
        for batch in batches:
            all_paths.extend(batch.items)

        assert all_paths == items

    def test_iterator_with_tuples(self, cpu_device):
        """Test iterator works with tuple items."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = [(1, "a"), (2, "b"), (3, "c")]

        batches = list(adaptive_batch_iterator(items, calc))

        # Tuples pass through untouched, in order.
        all_items = []
        for batch in batches:
            all_items.extend(batch.items)

        assert all_items == list(items)


class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_single_image(self, cpu_device):
        """Test with single image."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(num_images=1)
        assert result == 1

        batches = list(adaptive_batch_iterator(["single"], calc))
        assert len(batches) == 1
        assert batches[0].items == ["single"]
        assert batches[0].is_last is True

    def test_exact_batch_size_multiple(self, cpu_device):
        """Test when num_images is exact multiple of batch_size."""
        config = AdaptiveBatchConfig(max_batch_size=5)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(15))
        batches = list(adaptive_batch_iterator(items, calc))

        assert len(batches) == 3
        assert all(b.batch_size == 5 for b in batches)

    def test_very_large_num_images(self, cpu_device):
        """Test with very large number of images."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(num_images=1_000_000)
        assert result == 64

    def test_zero_reserved_memory(self, cpu_device):
        """Test with zero reserved memory."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=504,
            reserved_memory_mb=0,
        )
        assert result > 0

    def test_high_resolution(self, cpu_device):
        """Test with very high resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # A 4K-class resolution should still yield a usable batch size.
        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=2160,
        )
        assert result > 0

    def test_low_resolution(self, cpu_device):
        """Test with very low resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=128,
        )
        assert result > 0

    def test_negative_memory_edge_case(self, cpu_device):
        """Test handling when calculations could go negative."""
        config = AdaptiveBatchConfig(
            min_batch_size=1,
            target_memory_utilization=0.01,
        )
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # Even with a near-zero utilization target, the result must stay
        # at or above min_batch_size.
        result = calc.compute_optimal_batch_size(num_images=100)
        assert result >= 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])