""" | |
Here we reproduce DAAM, but for Flux DiT models. This is effectively a visualization of the cross attention | |
layers of a Flux model. | |
""" | |
from torch import nn | |
import torch | |
import einops | |
from concept_attention.image_generator import FluxGenerator | |
from concept_attention.segmentation import SegmentationAbstractClass | |
class DAAM(nn.Module):

    def __init__(
        self,
        model_name: str = "flux-schnell",
        device: str = "cuda",
        offload: bool = True,
    ):
        """
        Initialize the DAAM model.
        """
        super(DAAM, self).__init__()
        # Load up the flux generator
        self.generator = FluxGenerator(
            model_name=model_name,
            device=device,
            offload=offload,
        )
        # Unpack the tokenizer
        self.tokenizer = self.generator.t5.tokenizer
    def __call__(
        self,
        prompt,
        seed=4,
        num_steps=4,
        timesteps=None,
        layers=None
    ):
        """
        Generate cross-attention heatmap visualizations.

        Args:
            - prompt: str, the prompt to generate the visualizations for
            - seed: int, the random seed to use for generation
            - num_steps: int, the number of denoising steps
            - timesteps: list[int] | None, the timesteps to average over (defaults to all steps)
            - layers: list[int] | None, the double-block layers to average over (defaults to all 19)

        Returns:
            - attention_maps: torch.Tensor, the attention maps for the prompt tokens
            - tokens: list[str], the tokens in the prompt
            - image: the image generated for the prompt
        """
        if timesteps is None:
            timesteps = list(range(num_steps))
        if layers is None:
            layers = list(range(19))
        # Run the tokenizer and get the list of tokens
        token_strings = self.tokenizer.tokenize(prompt)
        # Run the image generator
        image = self.generator.generate_image(
            width=1024,
            height=1024,
            num_steps=num_steps,
            guidance=0.0,
            seed=seed,
            prompt=prompt,
            concepts=token_strings
        )
        # Pull out and average the attention maps
        cross_attention_maps = []
        for double_block in self.generator.model.double_blocks:
            cross_attention_map = torch.stack(
                double_block.cross_attention_maps
            ).squeeze(1)
            # Clear the cached maps so this layer does not reuse them on the next call
            double_block.clear_cached_vectors()
            # Append to the list
            cross_attention_maps.append(cross_attention_map)
        # Stack layers
        cross_attention_maps = torch.stack(cross_attention_maps).to(torch.float32)
        # Pull out the desired timesteps
        cross_attention_maps = cross_attention_maps[:, timesteps]
        # Pull out the desired layers
        cross_attention_maps = cross_attention_maps[layers]
        # Average over layers and time
        attention_maps = einops.reduce(
            cross_attention_maps,
            "layers time concepts height width -> concepts height width",
            reduction="mean"
        )
        # Keep only the attention maps corresponding to the prompt tokens
        attention_maps = attention_maps[:len(token_strings)]
        return attention_maps, token_strings, image
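

# Minimal usage sketch (an addition, not part of the original module). It assumes the
# concept_attention package and Flux weights are available locally, and relies only on
# DAAM.__call__ as defined above, which returns per-token heatmaps of shape
# (num_tokens, height, width), the token strings, and the generated image. The prompt
# and output filenames are illustrative only.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    daam = DAAM(model_name="flux-schnell", device="cuda", offload=True)
    attention_maps, tokens, image = daam(
        "a dog wearing a red hat in a park",
        seed=4,
        num_steps=4,
    )
    # Save one heatmap per prompt token.
    for index, (token, heatmap) in enumerate(zip(tokens, attention_maps)):
        plt.figure()
        plt.imshow(heatmap.cpu().numpy(), cmap="viridis")
        plt.title(token)
        plt.axis("off")
        plt.savefig(f"daam_token_{index}.png", bbox_inches="tight")
        plt.close()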