# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
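
# Unit tests for diffusers' ModelHook / HookRegistry machinery, exercised
# against small dummy modules so hook side effects are easy to verify.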
import gc
import unittest

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
from diffusers.utils.testing_utils import CaptureLogger, torch_device


logger = get_logger(__name__)  # pylint: disable=invalid-name
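

# Throwaway fixtures: a small MLP block, and a model that stacks num_layers of
# them between two outer linear layers.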
class DummyBlock(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(x)
        x = self.activation(x)
        x = self.proj_out(x)
        return x


class DummyModel(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x)
        x = self.activation(x)
        for block in self.blocks:
            x = block(x)
        x = self.linear_2(x)
        return x
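

# AddHook: adds a constant to every tensor positional argument before the
# wrapped forward runs; its post_forward returns the output unchanged.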
class AddHook(ModelHook):
    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        logger.debug("AddHook pre_forward")
        # Materialize a tuple rather than a generator expression, which would
        # be exhausted after a single unpacking.
        args = tuple((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("AddHook post_forward")
        return output
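

# MultiplyHook: scales tensor positional arguments by a constant. It defines a
# custom __repr__, which the registry-repr test below relies on.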
class MultiplyHook(ModelHook):
    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("MultiplyHook pre_forward")
        args = tuple((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"
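

# StatefulAddHook: the added value grows by one on every call. _is_stateful
# marks it for resetting through registry.reset_stateful_hooks().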
class StatefulAddHook(ModelHook):
    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        self.increment = 0

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("StatefulAddHook pre_forward")
        add_value = self.value + self.increment
        self.increment += 1
        args = tuple((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        self.increment = 0
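

# SkipLayerHook: overrides new_forward so the hooked module can be bypassed
# entirely, returning the first positional argument as-is.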
class SkipLayerHook(ModelHook):
    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook new_forward")
        if self.skip_layer:
            return args[0]
        return self.fn_ref.original_forward(*args, **kwargs)

    def post_forward(self, module, output):
        logger.debug("SkipLayerHook post_forward")
        return output
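

# The test harness: a fresh DummyModel is built for every test in setUp and
# released in tearDown; hooks are attached through a HookRegistry per test.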
class HookTests(unittest.TestCase):
    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        params = self.get_module_parameters()
        self.model = DummyModel(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()
        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        # torch.manual_seed returns the default (CPU) generator, so test inputs
        # are sampled on CPU and then moved to torch_device.
        return torch.manual_seed(0)
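
    # Registering two hooks should preserve insertion order, show up in the
    # registry repr, and be removable by name.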
    def test_hook_registry(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        expected_repr = "HookRegistry(\n  (0) add_hook - AddHook\n  (1) multiply_hook - MultiplyHook(value=2)\n)"
        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")
        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])
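
    # A stateful hook's counter advances once per forward pass and returns to
    # zero after reset_stateful_hooks, so repeated runs become reproducible.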
    def test_stateful_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        # Sample on CPU (where the seeded generator lives), then move to the
        # test device; passing a CPU generator to a CUDA randn would error out.
        input = torch.randn(1, 4, generator=self.get_generator()).to(torch_device)
        num_repeats = 3

        for i in range(num_repeats):
            result = self.model(input)
            if i == 0:
                output1 = result

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        registry.reset_stateful_hooks()
        output2 = self.model(input)
        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))
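
    # With AddHook(1) and MultiplyHook(2) both attached, pre_forward hooks run
    # in reverse registration order, so the model effectively sees
    # input * 2 + 1; removing hooks and pre-transforming the input by hand
    # must therefore reproduce the same output.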
    def test_inference(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        input = torch.randn(1, 4, generator=self.get_generator()).to(torch_device)
        output1 = self.model(input).mean().detach().cpu().item()

        registry.remove_hook("multiply_hook")
        new_input = input * 2
        output2 = self.model(new_input).mean().detach().cpu().item()

        registry.remove_hook("add_hook")
        new_input = input * 2 + 1
        output3 = self.model(new_input).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)
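
    # Skipping the whole model must return the zero input unchanged (mean 0),
    # while a non-skipping SkipLayerHook falls through to the original forward.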
    def test_skip_layer_hook(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        input = torch.zeros(1, 4, device=torch_device)
        output = self.model(input).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)
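
    # Skipping linear_1 feeds the 4-feature input straight into blocks that
    # expect 8 features, so the matmul must fail; skipping an interior block
    # whose input and output widths match (8 -> 8) keeps the model functional.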
    def test_skip_layer_internal_block(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        input = torch.zeros(1, 4, device=torch_device)

        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(input).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(input).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)
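
    # The three invocation-order tests below pin down the chaining semantics:
    # the most recently registered hook's pre_forward runs first and its
    # post_forward runs last, and a hook that defines no post_forward (the
    # stateful one) contributes no post_forward log line.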
    def test_invocation_order_stateful_first(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, generator=self.get_generator()).to(torch_device)

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "MultiplyHook pre_forward\n"
            "AddHook pre_forward\n"
            "StatefulAddHook pre_forward\n"
            "AddHook post_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "MultiplyHook pre_forward\n"
            "AddHook pre_forward\n"
            "AddHook post_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)
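
    # Same chain with the stateful hook registered in the middle; removing
    # hooks one at a time shrinks the logged sequence accordingly.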
    def test_invocation_order_stateful_middle(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        input = torch.randn(1, 4, generator=self.get_generator()).to(torch_device)

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "MultiplyHook pre_forward\n"
            "StatefulAddHook pre_forward\n"
            "AddHook pre_forward\n"
            "AddHook post_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "MultiplyHook pre_forward\n"
            "StatefulAddHook pre_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook_2")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "MultiplyHook pre_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)
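
    # With the stateful hook registered last, it owns the outermost pre_forward
    # and fires before the other two hooks.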
    def test_invocation_order_stateful_last(self):
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        input = torch.randn(1, 4, generator=self.get_generator()).to(torch_device)

        logger = get_logger(__name__)
        logger.setLevel("DEBUG")

        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "StatefulAddHook pre_forward\n"
            "MultiplyHook pre_forward\n"
            "AddHook pre_forward\n"
            "AddHook post_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        with CaptureLogger(logger) as cap_logger:
            self.model(input)
        output = cap_logger.out.replace(" ", "").replace("\n", "")
        expected_invocation_order_log = (
            "StatefulAddHook pre_forward\n"
            "MultiplyHook pre_forward\n"
            "MultiplyHook post_forward\n"
        ).replace(" ", "").replace("\n", "")
        self.assertEqual(output, expected_invocation_order_log)