| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import gc |
| | import inspect |
| | import unittest |
| |
|
| | import torch |
| | from parameterized import parameterized |
| |
|
| | from diffusers import PriorTransformer |
| | from diffusers.utils.testing_utils import ( |
| | backend_empty_cache, |
| | enable_full_determinism, |
| | floats_tensor, |
| | slow, |
| | torch_all_close, |
| | torch_device, |
| | ) |
| |
|
| | from ..test_modeling_common import ModelTesterMixin |
| |
|
| |
|
# Runs once at import time, before any test class below is defined.
# NOTE(review): presumably switches torch/CUDA into deterministic mode so the
# hard-coded expected output slices in these tests are reproducible -- confirm
# against diffusers.utils.testing_utils.
enable_full_determinism()
| |
|
| |
|
class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Fast tests for ``PriorTransformer`` built from a small configuration.

    The common model checks come from ``ModelTesterMixin``; this class supplies
    the model class, the init arguments and the dummy inputs those checks use,
    plus a few ``PriorTransformer``-specific tests.
    """

    model_class = PriorTransformer
    main_input_name = "hidden_states"

    @property
    def dummy_input(self):
        """Return forward kwargs built with ``floats_tensor`` on ``torch_device``.

        Shapes: ``hidden_states`` and ``proj_embedding`` are ``(4, 8)``,
        ``encoder_hidden_states`` is ``(4, 7, 8)``; ``timestep`` is fixed to 2.
        """
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        # Keep this creation order: the tensors are drawn one after another.
        hidden_states = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = floats_tensor((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def get_dummy_seed_input(self, seed=0):
        """Return forward kwargs drawn with ``torch.randn`` after seeding torch.

        Args:
            seed: Value passed to ``torch.manual_seed`` before sampling.

        Returns:
            Dict with the same keys and shapes as ``dummy_input``.
        """
        torch.manual_seed(seed)
        batch_size = 4
        embedding_dim = 8
        num_embeddings = 7

        # The sampling order determines the values; do not reorder these calls,
        # the expected slice in ``test_output_pretrained`` depends on it.
        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    @property
    def input_shape(self):
        """Shape of the main input: ``(batch_size, embedding_dim)``."""
        return (4, 8)

    @property
    def output_shape(self):
        """Shape of the model output: ``(batch_size, embedding_dim)``."""
        return (4, 8)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` for the shared mixin tests.

        ``embedding_dim`` and ``num_embeddings`` match the shapes produced by
        ``dummy_input``.
        """
        init_dict = {
            "num_attention_heads": 2,
            "attention_head_dim": 4,
            "num_layers": 2,
            "embedding_dim": 8,
            "num_embeddings": 7,
            "additional_embeddings": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """The dummy checkpoint loads with no missing keys and runs a forward pass."""
        model, loading_info = PriorTransformer.from_pretrained(
            "hf-internal-testing/prior-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        hidden_states = model(**self.dummy_input)[0]

        # unittest assertion instead of a bare ``assert`` so the check is not
        # stripped under ``python -O`` and matches the assertions above.
        self.assertIsNotNone(hidden_states, "Make sure output is not None")

    def test_forward_signature(self):
        """``forward`` takes ``hidden_states`` and ``timestep`` as its first two parameters."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an ordered mapping, so the list keeps the
        # declaration order of the arguments.
        arg_names = list(signature.parameters)

        expected_arg_names = ["hidden_states", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_output_pretrained(self):
        """The dummy checkpoint reproduces a stored output slice for seed 0."""
        model = PriorTransformer.from_pretrained("hf-internal-testing/prior-dummy")
        model = model.to(torch_device)

        # Only switch attention processors when the model exposes the hook.
        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        inputs = self.get_dummy_seed_input()

        with torch.no_grad():
            output = model(**inputs)[0]

        # Compare the first five values of the first sample, on CPU.
        output_slice = output[0, :5].flatten().cpu()

        expected_output_slice = torch.tensor([-1.3436, -0.2870, 0.7538, 0.4368, -0.0239])
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
| |
|
| |
|
@slow
class PriorTransformerIntegrationTests(unittest.TestCase):
    """Slow tests that run the ``kandinsky-community/kandinsky-2-1-prior`` checkpoint."""

    def get_dummy_seed_input(self, batch_size=1, embedding_dim=768, num_embeddings=77, seed=0):
        """Return forward kwargs drawn with ``torch.randn`` after seeding torch.

        Args:
            batch_size: Leading dimension of every tensor.
            embedding_dim: Last dimension of every tensor.
            num_embeddings: Middle dimension of ``encoder_hidden_states``.
            seed: Value passed to ``torch.manual_seed`` before sampling.

        Returns:
            Dict with ``hidden_states`` and ``proj_embedding`` of shape
            ``(batch_size, embedding_dim)``, ``encoder_hidden_states`` of shape
            ``(batch_size, num_embeddings, embedding_dim)`` and ``timestep`` 2.
        """
        torch.manual_seed(seed)

        # The sampling order determines the values; do not reorder these calls,
        # the expected slices in ``test_kandinsky_prior`` depend on it.
        hidden_states = torch.randn((batch_size, embedding_dim)).to(torch_device)
        proj_embedding = torch.randn((batch_size, embedding_dim)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, num_embeddings, embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "timestep": 2,
            "proj_embedding": proj_embedding,
            "encoder_hidden_states": encoder_hidden_states,
        }

    def tearDown(self):
        """Collect garbage and empty the backend cache after each test."""
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    @parameterized.expand(
        [
            [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
            [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
        ]
    )
    def test_kandinsky_prior(self, seed, expected_slice):
        """The pretrained prior reproduces a stored output slice for each seed.

        Args:
            seed: Seed used to draw the random inputs.
            expected_slice: Expected first eight values of the output's first row.
        """
        model = PriorTransformer.from_pretrained("kandinsky-community/kandinsky-2-1-prior", subfolder="prior")
        model.to(torch_device)
        inputs = self.get_dummy_seed_input(seed=seed)

        with torch.no_grad():
            sample = model(**inputs)[0]

        # unittest assertions instead of bare ``assert``: they are not stripped
        # under ``python -O`` and report both values on failure.
        self.assertEqual(list(sample.shape), [1, 768])

        output_slice = sample[0, :8].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, atol=1e-3))
| |
|