# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils.testing_utils import skip_mps, torch_device

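# EMAModel keeps an exponential moving average ("shadow") copy of a model's
# parameters. These tests exercise its step counting, shadow-parameter updates,
# and (de)serialization against a tiny UNet checkpoint.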
class EMAModelTests(unittest.TestCase):
    model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
    batch_size = 1
    prompt_length = 77
    text_encoder_hidden_dim = 32
    num_in_channels = 4
    latent_height = latent_width = 64
    generator = torch.manual_seed(0)

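    # Load the tiny UNet and wrap its parameters in an EMAModel that tracks
    # their exponential moving average.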
    def get_models(self, decay=0.9999):
        unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet")
        unet = unet.to(torch_device)
        ema_unet = EMAModel(unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config)
        return unet, ema_unet

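    # Random latents, timesteps, and encoder hidden states shaped to match the
    # tiny UNet above.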
    def get_dummy_inputs(self):
        noisy_latents = torch.randn(
            self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator
        ).to(torch_device)
        timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device)
        encoder_hidden_states = torch.randn(
            self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator
        ).to(torch_device)
        return noisy_latents, timesteps, encoder_hidden_states

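    # Stand-in for an optimizer step: overwrite every parameter with random
    # values so the model's weights diverge from what the EMAModel has recorded.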
    def simulate_backprop(self, unet):
        updated_state_dict = {}
        for k, param in unet.state_dict().items():
            updated_param = torch.randn_like(param) + (param * torch.randn_like(param))
            updated_state_dict.update({k: updated_param})
        unet.load_state_dict(updated_state_dict)
        return unet

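    # `EMAModel.optimization_step` should count every call to `step()`.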
    def test_optimization_steps_updated(self):
        unet, ema_unet = self.get_models()
        # Take the first (hypothetical) EMA step.
        ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 1

        # Take two more.
        for _ in range(2):
            ema_unet.step(unet.parameters())
        assert ema_unet.optimization_step == 3

    def test_shadow_params_not_updated(self):
        unet, ema_unet = self.get_models()
        # Since the `unet` is not being updated (i.e., backprop'd), there won't be
        # any difference between the params of `unet` and `ema_unet`, even if we
        # call `ema_unet.step(unet.parameters())`.
        ema_unet.step(unet.parameters())
        orig_params = list(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

        # The above holds true even if we call `ema_unet.step()` multiple times, since
        # the `unet` params are still not being updated.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert torch.allclose(s_param, param)

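    # Each EMA step blends the shadow params toward the live params, roughly
    # shadow <- decay * shadow + (1 - decay) * param; the effective decay may
    # also be warmed up internally over the first optimization steps.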
    def test_shadow_params_updated(self):
        unet, ema_unet = self.get_models()
        # Here we simulate the parameter updates for `unet`. Since some parameters
        # may be initialized to zero, we take extra care to move their values to
        # something non-zero before the multiplication.
        unet_pseudo_updated_step_one = self.simulate_backprop(unet)

        # Take the EMA step.
        ema_unet.step(unet_pseudo_updated_step_one.parameters())

        # Now the EMA'd parameters won't be equal to the original model parameters.
        orig_params = list(unet_pseudo_updated_step_one.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert ~torch.allclose(s_param, param)

        # Ensure this is the case when we take multiple EMA steps.
        for _ in range(4):
            ema_unet.step(unet.parameters())
        for s_param, param in zip(ema_unet.shadow_params, orig_params):
            assert ~torch.allclose(s_param, param)

    def test_consecutive_shadow_params_updated(self):
        # If we take an EMA step after each of two consecutive backpropagations,
        # the shadow params from those two steps should be different.
        unet, ema_unet = self.get_models()

        # First backprop + EMA
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        step_one_shadow_params = ema_unet.shadow_params

        # Second backprop + EMA
        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        step_two_shadow_params = ema_unet.shadow_params

        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
            assert ~torch.allclose(step_one, step_two)

    def test_zero_decay(self):
        # If there's no decay, then even with backprops the EMA steps
        # won't have any effect, i.e., the shadow params would remain the same.
        unet, ema_unet = self.get_models(decay=0.0)
        unet_step_one = self.simulate_backprop(unet)
        ema_unet.step(unet_step_one.parameters())
        step_one_shadow_params = ema_unet.shadow_params

        unet_step_two = self.simulate_backprop(unet_step_one)
        ema_unet.step(unet_step_two.parameters())
        step_two_shadow_params = ema_unet.shadow_params

        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
            assert torch.allclose(step_one, step_two)

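    # `EMAModel.save_pretrained` materializes the shadow weights into a regular
    # model checkpoint, so it can be reloaded via the model class's `from_pretrained`.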
    def test_serialization(self):
        unet, ema_unet = self.get_models()
        noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

        with tempfile.TemporaryDirectory() as tmpdir:
            ema_unet.save_pretrained(tmpdir)
            loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
            loaded_unet = loaded_unet.to(unet.device)

        # Since no EMA step has been performed the outputs should match.
        output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

        assert torch.allclose(output, output_loaded, atol=1e-4)