| """Tests for external cache provider API.""" |
|
|
| import importlib.util |
| import pytest |
| from typing import Optional |
|
|
|
|
| def _torch_available() -> bool: |
| """Check if PyTorch is available.""" |
| return importlib.util.find_spec("torch") is not None |
|
|
|
|
| from comfy_execution.cache_provider import ( |
| CacheProvider, |
| CacheContext, |
| CacheValue, |
| register_cache_provider, |
| unregister_cache_provider, |
| _get_cache_providers, |
| _has_cache_providers, |
| _clear_cache_providers, |
| _serialize_cache_key, |
| _contains_self_unequal, |
| _estimate_value_size, |
| _canonicalize, |
| ) |
|
|
|
|
| class TestCanonicalize: |
| """Test _canonicalize function for deterministic ordering.""" |
|
|
| def test_frozenset_ordering_is_deterministic(self): |
| """Frozensets should produce consistent canonical form regardless of iteration order.""" |
| |
| fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) |
| fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) |
|
|
| result1 = _canonicalize(fs1) |
| result2 = _canonicalize(fs2) |
|
|
| assert result1 == result2 |
|
|
| def test_nested_frozenset_ordering(self): |
| """Nested frozensets should also be deterministically ordered.""" |
| inner1 = frozenset([1, 2, 3]) |
| inner2 = frozenset([3, 2, 1]) |
|
|
| fs1 = frozenset([("key", inner1)]) |
| fs2 = frozenset([("key", inner2)]) |
|
|
| result1 = _canonicalize(fs1) |
| result2 = _canonicalize(fs2) |
|
|
| assert result1 == result2 |
|
|
| def test_dict_ordering(self): |
| """Dicts should be sorted by key.""" |
| d1 = {"z": 1, "a": 2, "m": 3} |
| d2 = {"a": 2, "m": 3, "z": 1} |
|
|
| result1 = _canonicalize(d1) |
| result2 = _canonicalize(d2) |
|
|
| assert result1 == result2 |
|
|
| def test_tuple_preserved(self): |
| """Tuples should be marked and preserved.""" |
| t = (1, 2, 3) |
| result = _canonicalize(t) |
|
|
| assert result[0] == "__tuple__" |
|
|
| def test_list_preserved(self): |
| """Lists should be recursively canonicalized.""" |
| lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] |
| result = _canonicalize(lst) |
|
|
| |
| assert "__dict__" in result[0] |
| |
| assert result[1][0] == "__frozenset__" |
|
|
| def test_primitives_include_type(self): |
| """Primitive types should include type name for disambiguation.""" |
| assert _canonicalize(42) == ("int", 42) |
| assert _canonicalize(3.14) == ("float", 3.14) |
| assert _canonicalize("hello") == ("str", "hello") |
| assert _canonicalize(True) == ("bool", True) |
| assert _canonicalize(None) == ("NoneType", None) |
|
|
| def test_int_and_str_distinguished(self): |
| """int 7 and str '7' must produce different canonical forms.""" |
| assert _canonicalize(7) != _canonicalize("7") |
|
|
| def test_bytes_converted(self): |
| """Bytes should be converted to hex string.""" |
| b = b"\x00\xff" |
| result = _canonicalize(b) |
|
|
| assert result[0] == "__bytes__" |
| assert result[1] == "00ff" |
|
|
| def test_set_ordering(self): |
| """Sets should be sorted like frozensets.""" |
| s1 = {3, 1, 2} |
| s2 = {1, 2, 3} |
|
|
| result1 = _canonicalize(s1) |
| result2 = _canonicalize(s2) |
|
|
| assert result1 == result2 |
| assert result1[0] == "__set__" |
|
|
| def test_unknown_type_raises(self): |
| """Unknown types should raise ValueError (fail-closed).""" |
| class CustomObj: |
| pass |
| with pytest.raises(ValueError): |
| _canonicalize(CustomObj()) |
|
|
| def test_object_with_value_attr_raises(self): |
| """Objects with .value attribute (Unhashable-like) should raise ValueError.""" |
| class FakeUnhashable: |
| def __init__(self): |
| self.value = float('nan') |
| with pytest.raises(ValueError): |
| _canonicalize(FakeUnhashable()) |
|
|
|
|
| class TestSerializeCacheKey: |
| """Test _serialize_cache_key for deterministic hashing.""" |
|
|
| def test_same_content_same_hash(self): |
| """Same content should produce same hash.""" |
| key1 = frozenset([("node_1", frozenset([("input", "value")]))]) |
| key2 = frozenset([("node_1", frozenset([("input", "value")]))]) |
|
|
| hash1 = _serialize_cache_key(key1) |
| hash2 = _serialize_cache_key(key2) |
|
|
| assert hash1 == hash2 |
|
|
| def test_different_content_different_hash(self): |
| """Different content should produce different hash.""" |
| key1 = frozenset([("node_1", "value_a")]) |
| key2 = frozenset([("node_1", "value_b")]) |
|
|
| hash1 = _serialize_cache_key(key1) |
| hash2 = _serialize_cache_key(key2) |
|
|
| assert hash1 != hash2 |
|
|
| def test_returns_hex_string(self): |
| """Should return hex string (SHA256 hex digest).""" |
| key = frozenset([("test", 123)]) |
| result = _serialize_cache_key(key) |
|
|
| assert isinstance(result, str) |
| assert len(result) == 64 |
|
|
| def test_complex_nested_structure(self): |
| """Complex nested structures should hash deterministically.""" |
| |
| |
| key = frozenset([ |
| ("node_1", frozenset([ |
| ("input_a", ("tuple", "value")), |
| ("input_b", frozenset([("nested", "dict")])), |
| ])), |
| ("node_2", frozenset([ |
| ("param", 42), |
| ])), |
| ]) |
|
|
| |
| hash1 = _serialize_cache_key(key) |
| hash2 = _serialize_cache_key(key) |
|
|
| assert hash1 == hash2 |
|
|
| def test_dict_in_cache_key(self): |
| """Dicts passed directly to _serialize_cache_key should work.""" |
| key = {"node_1": {"input": "value"}, "node_2": 42} |
|
|
| hash1 = _serialize_cache_key(key) |
| hash2 = _serialize_cache_key(key) |
|
|
| assert hash1 == hash2 |
| assert isinstance(hash1, str) |
| assert len(hash1) == 64 |
|
|
| def test_unknown_type_returns_none(self): |
| """Non-cacheable types should return None (fail-closed).""" |
| class CustomObj: |
| pass |
| assert _serialize_cache_key(CustomObj()) is None |
|
|
|
|
| class TestContainsSelfUnequal: |
| """Test _contains_self_unequal utility function.""" |
|
|
| def test_nan_float_detected(self): |
| """NaN floats should be detected (not equal to itself).""" |
| assert _contains_self_unequal(float('nan')) is True |
|
|
| def test_regular_float_not_detected(self): |
| """Regular floats are equal to themselves.""" |
| assert _contains_self_unequal(3.14) is False |
| assert _contains_self_unequal(0.0) is False |
| assert _contains_self_unequal(-1.5) is False |
|
|
| def test_infinity_not_detected(self): |
| """Infinity is equal to itself.""" |
| assert _contains_self_unequal(float('inf')) is False |
| assert _contains_self_unequal(float('-inf')) is False |
|
|
| def test_nan_in_list(self): |
| """NaN in list should be detected.""" |
| assert _contains_self_unequal([1, 2, float('nan'), 4]) is True |
| assert _contains_self_unequal([1, 2, 3, 4]) is False |
|
|
| def test_nan_in_tuple(self): |
| """NaN in tuple should be detected.""" |
| assert _contains_self_unequal((1, float('nan'))) is True |
| assert _contains_self_unequal((1, 2, 3)) is False |
|
|
| def test_nan_in_frozenset(self): |
| """NaN in frozenset should be detected.""" |
| assert _contains_self_unequal(frozenset([1, float('nan')])) is True |
| assert _contains_self_unequal(frozenset([1, 2, 3])) is False |
|
|
| def test_nan_in_dict_value(self): |
| """NaN in dict value should be detected.""" |
| assert _contains_self_unequal({"key": float('nan')}) is True |
| assert _contains_self_unequal({"key": 42}) is False |
|
|
| def test_nan_in_nested_structure(self): |
| """NaN in deeply nested structure should be detected.""" |
| nested = {"level1": [{"level2": (1, 2, float('nan'))}]} |
| assert _contains_self_unequal(nested) is True |
|
|
| def test_non_numeric_types(self): |
| """Non-numeric types should not be self-unequal.""" |
| assert _contains_self_unequal("string") is False |
| assert _contains_self_unequal(None) is False |
| assert _contains_self_unequal(True) is False |
|
|
| def test_object_with_nan_value_attribute(self): |
| """Objects wrapping NaN in .value should be detected.""" |
| class NanWrapper: |
| def __init__(self): |
| self.value = float('nan') |
| assert _contains_self_unequal(NanWrapper()) is True |
|
|
| def test_custom_self_unequal_object(self): |
| """Custom objects where not (x == x) should be detected.""" |
| class NeverEqual: |
| def __eq__(self, other): |
| return False |
| assert _contains_self_unequal(NeverEqual()) is True |
|
|
|
|
| class TestEstimateValueSize: |
| """Test _estimate_value_size utility function.""" |
|
|
| def test_empty_outputs(self): |
| """Empty outputs should have zero size.""" |
| value = CacheValue(outputs=[]) |
| assert _estimate_value_size(value) == 0 |
|
|
| @pytest.mark.skipif( |
| not _torch_available(), |
| reason="PyTorch not available" |
| ) |
| def test_tensor_size_estimation(self): |
| """Tensor size should be estimated correctly.""" |
| import torch |
|
|
| |
| tensor = torch.zeros(1000, dtype=torch.float32) |
| value = CacheValue(outputs=[[tensor]]) |
|
|
| size = _estimate_value_size(value) |
| assert size == 4000 |
|
|
| @pytest.mark.skipif( |
| not _torch_available(), |
| reason="PyTorch not available" |
| ) |
| def test_nested_tensor_in_dict(self): |
| """Tensors nested in dicts should be counted.""" |
| import torch |
|
|
| tensor = torch.zeros(100, dtype=torch.float32) |
| value = CacheValue(outputs=[[{"samples": tensor}]]) |
|
|
| size = _estimate_value_size(value) |
| assert size == 400 |
|
|
|
|
| class TestProviderRegistry: |
| """Test cache provider registration and retrieval.""" |
|
|
| def setup_method(self): |
| """Clear providers before each test.""" |
| _clear_cache_providers() |
|
|
| def teardown_method(self): |
| """Clear providers after each test.""" |
| _clear_cache_providers() |
|
|
| def test_register_provider(self): |
| """Provider should be registered successfully.""" |
| provider = MockCacheProvider() |
| register_cache_provider(provider) |
|
|
| assert _has_cache_providers() is True |
| providers = _get_cache_providers() |
| assert len(providers) == 1 |
| assert providers[0] is provider |
|
|
| def test_unregister_provider(self): |
| """Provider should be unregistered successfully.""" |
| provider = MockCacheProvider() |
| register_cache_provider(provider) |
| unregister_cache_provider(provider) |
|
|
| assert _has_cache_providers() is False |
|
|
| def test_multiple_providers(self): |
| """Multiple providers can be registered.""" |
| provider1 = MockCacheProvider() |
| provider2 = MockCacheProvider() |
|
|
| register_cache_provider(provider1) |
| register_cache_provider(provider2) |
|
|
| providers = _get_cache_providers() |
| assert len(providers) == 2 |
|
|
| def test_duplicate_registration_ignored(self): |
| """Registering same provider twice should be ignored.""" |
| provider = MockCacheProvider() |
|
|
| register_cache_provider(provider) |
| register_cache_provider(provider) |
|
|
| providers = _get_cache_providers() |
| assert len(providers) == 1 |
|
|
| def test_clear_providers(self): |
| """_clear_cache_providers should remove all providers.""" |
| provider1 = MockCacheProvider() |
| provider2 = MockCacheProvider() |
|
|
| register_cache_provider(provider1) |
| register_cache_provider(provider2) |
| _clear_cache_providers() |
|
|
| assert _has_cache_providers() is False |
| assert len(_get_cache_providers()) == 0 |
|
|
|
|
| class TestCacheContext: |
| """Test CacheContext dataclass.""" |
|
|
| def test_context_creation(self): |
| """CacheContext should be created with all fields.""" |
| context = CacheContext( |
| node_id="node-456", |
| class_type="KSampler", |
| cache_key_hash="a" * 64, |
| ) |
|
|
| assert context.node_id == "node-456" |
| assert context.class_type == "KSampler" |
| assert context.cache_key_hash == "a" * 64 |
|
|
|
|
| class TestCacheValue: |
| """Test CacheValue dataclass.""" |
|
|
| def test_value_creation(self): |
| """CacheValue should be created with outputs.""" |
| outputs = [[{"samples": "tensor_data"}]] |
| value = CacheValue(outputs=outputs) |
|
|
| assert value.outputs == outputs |
|
|
|
|
| class MockCacheProvider(CacheProvider): |
| """Mock cache provider for testing.""" |
|
|
| def __init__(self): |
| self.lookups = [] |
| self.stores = [] |
|
|
| async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: |
| self.lookups.append(context) |
| return None |
|
|
| async def on_store(self, context: CacheContext, value: CacheValue) -> None: |
| self.stores.append((context, value)) |
|
|