Cosmos-Predict2-2BASD/diffusers_repo/tests/models/test_attention_processor.py
multimodalart's picture
multimodalart HF Staff
Upload 2025 files
22a452a verified
raw
history blame
4.57 kB
import tempfile
import unittest
import numpy as np
import pytest
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
from diffusers.utils.testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
    """Checks for an ``Attention`` module driven by ``AttnAddedKVProcessor``."""

    def get_constructor_arguments(self, only_cross_attention: bool = False):
        """Return the keyword arguments used to construct ``Attention``.

        Args:
            only_cross_attention: Forwarded to ``Attention``; it also selects
                the cross-attention width (12 when set, otherwise the query
                width).

        Returns:
            A dict of constructor keyword arguments, including a fresh
            ``AttnAddedKVProcessor`` instance.
        """
        query_dim = 10
        # Without only-cross-attention, the cross attention dim must be the
        # same as the query dim.
        cross_attention_dim = 12 if only_cross_attention else query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        """Return random forward-pass inputs for the attention module.

        Args:
            query_dim: Size of dimension 1 of the generated ``hidden_states``.
            added_kv_proj_dim: Size of the last dimension of the generated
                ``encoder_hidden_states``.

        Returns:
            A dict with ``hidden_states``, ``encoder_hidden_states`` and a
            ``None`` ``attention_mask``.
        """
        batch_size = 2
        # The two tensors are drawn in this order; callers seed the global
        # RNG beforehand, so the draw order is kept stable.
        return {
            "hidden_states": torch.rand(batch_size, query_dim, 3, 2),
            "encoder_hidden_states": torch.rand(batch_size, 4, added_kv_proj_dim),
            "attention_mask": None,
        }

    def test_only_cross_attention(self):
        """Compare outputs with and without ``only_cross_attention``."""
        # Case 1: self and cross attention -- the k/v projections exist.
        torch.manual_seed(0)
        joint_kwargs = self.get_constructor_arguments(only_cross_attention=False)
        joint_attn = Attention(**joint_kwargs)

        self.assertIsNotNone(joint_attn.to_k)
        self.assertIsNotNone(joint_attn.to_v)

        joint_inputs = self.get_forward_arguments(
            query_dim=joint_kwargs["query_dim"],
            added_kv_proj_dim=joint_kwargs["added_kv_proj_dim"],
        )
        self_and_cross_attn_out = joint_attn(**joint_inputs)

        # Case 2: only cross attention -- the k/v projections are absent.
        # Reseed so both cases start from the same RNG state.
        torch.manual_seed(0)
        cross_kwargs = self.get_constructor_arguments(only_cross_attention=True)
        cross_attn = Attention(**cross_kwargs)

        self.assertIsNone(cross_attn.to_k)
        self.assertIsNone(cross_attn.to_v)

        cross_inputs = self.get_forward_arguments(
            query_dim=cross_kwargs["query_dim"],
            added_kv_proj_dim=cross_kwargs["added_kv_proj_dim"],
        )
        only_cross_attn_out = cross_attn(**cross_inputs)

        # Every element of the two outputs is expected to differ.
        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
class DeprecatedAttentionBlockTests(unittest.TestCase):
    """Round-trip check for a tiny pipeline loaded with ``device_map="balanced"``."""

    @pytest.fixture(scope="session")
    def is_dist_enabled(pytestconfig):
        return pytestconfig.getoption("dist") == "loadfile"

    # NOTE(review): `is_dist_enabled` in the condition below is evaluated while
    # the class body executes, so it refers to the fixture-decorated function
    # defined just above, not to a resolved fixture value. That object is
    # presumably always truthy, which would make the condition depend only on
    # the device type -- confirm whether that is intended.
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
        reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
        strict=True,
    )
    def test_conversion_when_using_device_map(self):
        """Images must match before, after, and after re-loading a saved pipeline."""
        repo_id = "hf-internal-testing/tiny-stable-diffusion-torch"

        def render(pipeline):
            # A new CPU generator seeded with 0 is built for every call so
            # that all three runs start from the same random state.
            result = pipeline(
                "foo",
                num_inference_steps=2,
                generator=torch.Generator("cpu").manual_seed(0),
                output_type="np",
            )
            return result.images

        # Reference run: pipeline loaded without a device map.
        pipeline = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
        pre_conversion = render(pipeline)

        # The initial conversion succeeds when loading with a device map.
        pipeline = DiffusionPipeline.from_pretrained(
            repo_id, device_map="balanced", safety_checker=None
        )
        conversion = render(pipeline)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Save the converted model ...
            pipeline.save_pretrained(tmpdir)
            # ... and check the converted weights can be loaded back.
            pipeline = DiffusionPipeline.from_pretrained(
                tmpdir, device_map="balanced", safety_checker=None
            )
        after_conversion = render(pipeline)

        self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
        self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))