# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import unittest

import torch
from parameterized import parameterized

from diffusers import AutoencoderTiny
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    load_hf_numpy,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderTiny
    main_input_name = "sample"
    base_precision = 1e-2

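    # Builds a minimal TAESD-style config: two 32-channel stages by default, with each
    # stage's block count derived from its channel width relative to the smallest width.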
    def get_autoencoder_tiny_config(self, block_out_channels=None):
        block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "encoder_block_out_channels": block_out_channels,
            "decoder_block_out_channels": block_out_channels,
            "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
            "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
        }
        return init_dict

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_tiny_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_enable_disable_tiling(self):
        pass

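    # enable_slicing() processes the batch one sample at a time to save memory; the sliced
    # and unsliced forward passes should produce (nearly) identical outputs.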
    def test_enable_disable_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict)[0]

        torch.manual_seed(0)
        model.enable_slicing()
        output_with_slicing = model(**inputs_dict)[0]

        self.assertLess(
            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
            0.5,
            "VAE slicing should not affect the inference results",
        )

        torch.manual_seed(0)
        model.disable_slicing()
        output_without_slicing_2 = model(**inputs_dict)[0]

        self.assertTrue(
            torch_all_close(output_without_slicing.detach(), output_without_slicing_2.detach(), atol=1e-5),
            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
        )

    def test_outputs_equivalence(self):
        pass

    def test_forward_with_norm_groups(self):
        pass

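    # gradient checkpointing should be applied to the EncoderTiny and DecoderTiny submodules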
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DecoderTiny", "EncoderTiny"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_effective_gradient_checkpointing(self):
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        inputs_dict_copy = copy.deepcopy(inputs_dict)
        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backwards pass on the model; for simplicity we backprop through the
        # surrogate loss (out - labels).mean() rather than a real training objective
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        torch.manual_seed(0)
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict_copy).sample
        # run the backwards pass with the same surrogate loss as above
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-3)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())

        for name, param in named_params.items():
            if "encoder.layers" in name:
                continue
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2))

    def test_layerwise_casting_inference(self):
        pass

    def test_layerwise_casting_memory(self):
        pass


@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

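    # loads the tiny autoencoder (TAESD) checkpoint in fp32 or fp16 and puts it in eval mode on the test device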
    def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
        torch_dtype = torch.float16 if fp16 else torch.float32
        model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
        model.to(torch_device).eval()
        return model

    # NOTE: the (in_shape, out_shape) pairs below are illustrative latent/image shapes;
    # TAESD decodes 4-channel latents to 3-channel images with an 8x spatial scale factor.
    @parameterized.expand([[(1, 4, 73, 97), (1, 3, 584, 776)], [(1, 4, 97, 73), (1, 3, 776, 584)]])
    def test_tae_tiling(self, in_shape, out_shape):
        model = self.get_sd_vae_model()
        model.enable_tiling()
        with torch.no_grad():
            zeros = torch.zeros(in_shape).to(torch_device)
            dec = model.decode(zeros).sample
            assert dec.shape == out_shape

    def test_stable_diffusion(self):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed=33)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand([(True,), (False,)])
    def test_tae_roundtrip(self, enable_tiling):
        # load the autoencoder
        model = self.get_sd_vae_model()
        if enable_tiling:
            model.enable_tiling()

        # make a black image with a white square in the middle,
        # which is large enough to split across multiple tiles
        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
        image[..., 256:768, 256:768] = 1.0

        # round-trip the image through the autoencoder
        with torch.no_grad():
            sample = model(image).sample

        # the autoencoder reconstruction should roughly match the original image
        def downscale(x):
            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)

        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)