import torch
import torch.nn.functional as F

from diffusers import VQDiffusionScheduler

from .test_schedulers import SchedulerCommonTest


class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Test case for ``VQDiffusionScheduler``.

    The ``check_over_configs`` / ``check_over_forward`` helpers used below are
    not defined in this class; presumably they are inherited from
    ``SchedulerCommonTest`` (verify against that module).
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, with ``kwargs`` overriding it.

        Args:
            **kwargs: Entries that replace or extend the default config keys.

        Returns:
            A new dict holding ``num_vec_classes`` and ``num_train_timesteps``
            defaults, updated with ``kwargs``.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }
        config.update(**kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Build a random integer sample.

        Args:
            num_vec_classes: Exclusive upper bound for the sampled values.

        Returns:
            An integer tensor of shape ``(4, 64)`` (``batch_size`` by
            ``height * width``) with values drawn from
            ``[0, num_vec_classes)``.
        """
        batch_size = 4
        height = 8
        width = 8

        return torch.randint(0, num_vec_classes, (batch_size, height * width))

    @property
    def dummy_sample_deter(self):
        """Always fail: this test case provides no deterministic sample.

        Raises:
            AssertionError: On every access.
        """
        # Raise explicitly rather than `assert False`: assert statements are
        # removed under `python -O`, which would make this property silently
        # return None instead of failing.
        raise AssertionError(
            "dummy_sample_deter is not available for VQDiffusionSchedulerTest"
        )

    def dummy_model(self, num_vec_classes):
        """Create a stand-in model that emits random log-probabilities.

        Args:
            num_vec_classes: Class count; the model's output has
                ``num_vec_classes - 1`` entries along dimension 1.

        Returns:
            A callable ``model(sample, t, *args)``. It reads only the shape of
            ``sample`` (expected to be 2-D: batch by latent pixels), ignores
            ``t`` and any extra positional arguments, and returns a float
            tensor of shape
            ``(batch_size, num_vec_classes - 1, num_latent_pixels)`` that is
            log-softmax-normalised over dimension 1.
        """

        def model(sample, t, *args):
            batch_size, num_latent_pixels = sample.shape
            logits = torch.rand(
                (batch_size, num_vec_classes - 1, num_latent_pixels)
            )
            # The log-softmax is evaluated in double precision and the result
            # is cast back to float afterwards.
            return F.log_softmax(logits.double(), dim=1).float()

        return model

    def test_timesteps(self):
        """Run ``check_over_configs`` for several ``num_train_timesteps``."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run ``check_over_configs`` for several ``num_vec_classes`` values."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run ``check_over_forward`` at the time steps 0, 50 and 99."""
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    def test_add_noise_device(self):
        """Do nothing.

        NOTE(review): presumably this overrides an inherited test of the same
        name in order to disable it for this scheduler -- confirm against
        ``SchedulerCommonTest``.
        """
        pass