FreeScale / scale_attention.py
arthur-qiu
init
3adee15
raw
history blame
16.1 kB
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
import random
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
    """Build a normalised 2-D Gaussian kernel replicated once per channel.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian, in kernel cells.
        channels: Number of identical copies of the kernel to stack.

    Returns:
        A tensor of shape ``(channels, 1, kernel_size, kernel_size)``; each
        ``kernel_size x kernel_size`` slice sums to 1. The layout matches the
        weight expected by a grouped ``F.conv2d`` with ``groups=channels``.
    """
    # Signed distance of every tap from the kernel centre.
    centre = (kernel_size - 1) / 2
    offsets = torch.arange(kernel_size) - centre

    # 1-D Gaussian profile, normalised so the taps sum to one.
    profile = torch.exp(-offsets ** 2 / (2 * sigma ** 2))
    profile = profile / profile.sum()

    # The 2-D kernel is separable: outer product of the profile with itself.
    plane = profile.unsqueeze(1) * profile.unsqueeze(0)

    # Add (out_channel, in_channel) axes and copy once per channel.
    return plane.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1, 1)
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
    """Blur a 4-D tensor with a per-channel (depthwise) Gaussian convolution.

    Dimension 1 of ``latents`` is treated as the channel axis and the last two
    dimensions as the spatial axes of ``F.conv2d``. Zero padding of
    ``kernel_size // 2`` is applied, so the output keeps the input size for
    odd ``kernel_size``.

    Args:
        latents: 4-D tensor to blur.
        kernel_size: Side length of the Gaussian kernel.
        sigma: Standard deviation of the Gaussian.

    Returns:
        The blurred tensor, on the same device and with the same dtype as
        ``latents``.
    """
    num_channels = latents.shape[1]

    # Build the kernel, then match the input's device and dtype.
    weight = gaussian_kernel(kernel_size, sigma, num_channels)
    weight = weight.to(latents.device, latents.dtype)

    # groups == channels -> every channel is filtered independently.
    return F.conv2d(latents, weight, padding=kernel_size // 2, groups=num_channels)
def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8):
    """Tile a ``height x width`` grid with half-overlapping, randomly jittered windows.

    Window size and stride (half a window) are given in un-scaled units and
    divided by ``scale_factor`` to obtain grid units. Windows are laid out
    row-major; a window that would run past the grid is shifted back inside.
    Every window then receives a random integer shift of at most
    ``window // 8`` cells per axis (never outwards for windows touching a
    border in the un-jittered layout... see the jitter helper below), and all
    coordinates are finally offset by that same ``window // 8`` amount. The
    returned coordinates therefore index a grid that has been padded by
    ``window // 8`` cells on each side.

    Uses the module-level ``random`` generator, so results depend on its state.

    Args:
        height: Grid height (converted with ``int``).
        width: Grid width (converted with ``int``).
        h_window_size: Window height before division by ``scale_factor``.
        w_window_size: Window width before division by ``scale_factor``.
        scale_factor: Divisor converting window size/stride to grid units.

    Returns:
        A list of ``(h_start, h_end, w_start, w_end)`` tuples.
    """
    height, width = int(height), int(width)

    # Half-window stride, computed on the un-scaled size, then both the
    # window and the stride are converted to grid units.
    h_stride = int((h_window_size // 2) / scale_factor)
    w_stride = int((w_window_size // 2) / scale_factor)
    win_h = int(h_window_size / scale_factor)
    win_w = int(w_window_size / scale_factor)

    # Number of windows per axis; a grid no larger than one window gets one.
    # The 1e-6 keeps an exact multiple from producing a redundant window.
    if height > win_h:
        rows = int((height - win_h) / h_stride - 1e-6) + 2
    else:
        rows = 1
    if width > win_w:
        cols = int((width - win_w) / w_stride - 1e-6) + 2
    else:
        cols = 1

    h_jitter_max = win_h // 8
    w_jitter_max = win_w // 8

    def place(index, stride, size, limit):
        # Un-jittered [start, end) of window `index`, shifted to stay in [0, limit].
        start = int(index * stride)
        end = start + size
        if end > limit:
            start = int(start + limit - end)
            end = int(limit)
        if start < 0:
            end = int(end - start)
            start = 0
        return start, end

    def draw_jitter(start, end, limit, max_shift):
        # Random shift; its sign is restricted for windows touching a border,
        # and a window spanning the whole axis is not shifted at all.
        touches_start = start == 0
        touches_end = end == limit
        if touches_start and touches_end:
            return 0
        if touches_start:
            return random.randint(-max_shift, 0)
        if touches_end:
            return random.randint(0, max_shift)
        return random.randint(-max_shift, max_shift)

    views = []
    for row in range(rows):
        top, bottom = place(row, h_stride, win_h, height)
        for col in range(cols):
            left, right = place(col, w_stride, win_w, width)
            # Width jitter is drawn before height jitter for every window.
            w_shift = draw_jitter(left, right, width, w_jitter_max) + w_jitter_max
            h_shift = draw_jitter(top, bottom, height, h_jitter_max) + h_jitter_max
            views.append((top + h_shift, bottom + h_shift, left + w_shift, right + w_shift))
    return views
def _windowed_self_attention(
    self,
    states_hw,
    views,
    window_h,
    h_pad,
    w_pad,
    context,
    attention_mask,
    cross_attention_kwargs,
):
    """Run ``self.attn1`` on each window of a zero-padded map and average the overlaps.

    Args:
        states_hw: Normalised hidden states laid out as ``(bh, h, w, d)``.
        views: ``(h_start, h_end, w_start, w_end)`` windows indexing the padded map.
        window_h: Height of every window, used to un-flatten the attention output.
        h_pad: Zero padding added above and below the map.
        w_pad: Zero padding added left and right of the map.
        context: Value passed to ``attn1`` as ``encoder_hidden_states``.
        attention_mask: Mask forwarded to ``attn1``.
        cross_attention_kwargs: Extra keyword arguments forwarded to ``attn1``.

    Returns:
        Tensor with the same shape as ``states_hw``: per-position mean of all
        window outputs covering that position (0 where no window covers it).
    """
    padded = F.pad(states_hw, (0, 0, w_pad, w_pad, h_pad, h_pad), 'constant', 0)
    value = torch.zeros_like(padded)
    count = torch.zeros_like(padded)
    for h_start, h_end, w_start, w_end in views:
        local_states = padded[:, h_start:h_end, w_start:w_end, :]
        local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
        local_output = self.attn1(
            local_states,
            encoder_hidden_states=context,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h=window_h)
        value[:, h_start:h_end, w_start:w_end, :] += local_output
        count[:, h_start:h_end, w_start:w_end, :] += 1

    # Strip the padding with explicit end indices: a `pad:-pad` slice would be
    # empty whenever pad == 0 (because `0:-0` is `0:0`).
    h_stop = value.shape[1] - h_pad
    w_stop = value.shape[2] - w_pad
    value = value[:, h_pad:h_stop, w_pad:w_stop, :]
    count = count[:, h_pad:h_stop, w_pad:w_stop, :]
    return torch.where(count > 0, value / count, value)


def _strided_self_attention(
    self,
    states_hw,
    stride_h,
    stride_w,
    context,
    attention_mask,
    cross_attention_kwargs,
):
    """Run ``self.attn1`` on every strided sub-grid of the map and scatter the results back.

    The map is split into ``stride_h * stride_w`` interleaved sub-grids
    (``states_hw[:, i::stride_h, j::stride_w]``); each is attended to
    separately and written back to the positions it was taken from.

    Args:
        states_hw: Normalised hidden states laid out as ``(bh, h, w, d)``.
        stride_h: Sampling stride along the height axis.
        stride_w: Sampling stride along the width axis.
        context: Value passed to ``attn1`` as ``encoder_hidden_states``.
        attention_mask: Mask forwarded to ``attn1``.
        cross_attention_kwargs: Extra keyword arguments forwarded to ``attn1``.

    Returns:
        Tensor with the same shape as ``states_hw``.
    """
    value = torch.zeros_like(states_hw)
    count = torch.zeros_like(states_hw)
    for h in range(stride_h):
        for w in range(stride_w):
            global_states = states_hw[:, h::stride_h, w::stride_w, :]
            # Remember the real sub-grid height instead of assuming it is square.
            grid_h = global_states.shape[1]
            global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
            global_output = self.attn1(
                global_states,
                encoder_hidden_states=context,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h=grid_h)
            value[:, h::stride_h, w::stride_w, :] += global_output
            count[:, h::stride_h, w::stride_w, :] += 1
    return torch.where(count > 0, value / count, value)


def scale_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Transformer-block forward pass with scale-fused self-attention.

    Differs from ``ori_forward`` only in the self-attention step, which is
    computed as::

        blur(local) + (global - blur(global))

    where ``local`` is the average of ``attn1`` over jittered, overlapping
    windows, and ``global`` is ``attn1`` over the whole token sequence
    (``self.fast_mode`` falsy) or over interleaved strided sub-grids
    (``self.fast_mode`` truthy). GLIGEN fusion, cross-attention and the
    feed-forward step follow unchanged.

    Reads from ``self``: ``current_hw``, ``fast_mode``, ``use_ada_layer_norm``,
    ``use_ada_layer_norm_zero``, ``only_cross_attention``, ``norm1/2/3``,
    ``attn1``, ``attn2``, ``fuser``, ``ff``, ``_chunk_size`` and ``_chunk_dim``.

    Args:
        hidden_states: Token sequence; dimension 1 is the flattened ``h * w`` grid.
        attention_mask: Mask forwarded to ``attn1``.
        encoder_hidden_states: Conditioning passed to ``attn2`` (and to
            ``attn1`` when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Mask forwarded to ``attn2``.
        timestep: Passed to the adaptive layer norms when they are enabled.
        cross_attention_kwargs: Extra keyword arguments for the attention
            calls; an optional ``"gligen"`` entry is removed from a copy and
            its ``"objs"`` value is passed to ``self.fuser``.
        class_labels: Passed to ``norm1`` when ``use_ada_layer_norm_zero`` is set.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # Notice that normalization is always applied before the real computation in the following blocks.
    # Up-scaling factor per axis relative to 1024 (presumably the base
    # resolution in pixels -- TODO confirm).
    # NOTE(review): a `current_hw` entry below 1024 yields a factor of 0 and a
    # division by zero further down; confirm callers never pass that.
    if self.current_hw:
        current_scale_num_h = self.current_hw[0] // 1024
        current_scale_num_w = self.current_hw[1] // 1024
    else:
        current_scale_num_h, current_scale_num_w = 1, 1

    # 0. Self-Attention
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 2. Prepare GLIGEN inputs (work on a copy so the caller's dict is untouched)
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    # Recover the 2-D grid from the token count and the aspect ratio.
    ratio_hw = current_scale_num_h / current_scale_num_w
    latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
    latent_w = int(latent_h / ratio_hw)

    # Window size expressed in units where the longer side is 128;
    # `scale_factor` converts those units to cells of the current grid.
    scale_factor = 128 * current_scale_num_h / latent_h
    if ratio_hw > 1:
        sub_h = 128
        sub_w = int(128 / ratio_hw)
    else:
        sub_h = int(128 * ratio_hw)
        sub_w = 128

    # Must agree with the jitter range get_views() derives internally,
    # because its coordinates index the map padded by exactly this amount.
    h_jitter_range = int(sub_h / scale_factor // 8)
    w_jitter_range = int(sub_w / scale_factor // 8)
    window_h = int(sub_h / scale_factor)
    views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor=scale_factor)

    blur_kernel_size = 2 * max(current_scale_num_h, current_scale_num_w) - 1
    self_attn_context = encoder_hidden_states if self.only_cross_attention else None

    # Local branch: windowed attention, averaged where windows overlap.
    states_hw = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h=latent_h)
    attn_output_local = _windowed_self_attention(
        self,
        states_hw,
        views,
        window_h,
        h_jitter_range,
        w_jitter_range,
        self_attn_context,
        attention_mask,
        cross_attention_kwargs,
    )
    # NOTE(review): gaussian_filter treats dim 1 as channels, and the tensors
    # here are (bh, h, w, d) -- so the blur runs over the (w, d) axes with h as
    # the channel axis. Kept as in the original; confirm this is intended.
    gaussian_local = gaussian_filter(attn_output_local, kernel_size=blur_kernel_size, sigma=1.0)

    # Global branch.
    if self.fast_mode:
        # Cheaper: attend within each interleaved strided sub-grid.
        attn_output_global = _strided_self_attention(
            self,
            states_hw,
            current_scale_num_h,
            current_scale_num_w,
            self_attn_context,
            attention_mask,
            cross_attention_kwargs,
        )
    else:
        # Full self-attention over every token.
        attn_output_global = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=self_attn_context,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h=latent_h)
    gaussian_global = gaussian_filter(attn_output_global, kernel_size=blur_kernel_size, sigma=1.0)

    # Blurred local result plus the un-blurred residual of the global result.
    attn_output = gaussian_local + (attn_output_global - gaussian_global)
    attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
    # 2.5 ends

    # 3. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [
                self.ff(hid_slice)
                for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    hidden_states = ff_output + hidden_states

    return hidden_states
def ori_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
    """Un-modified transformer-block forward pass.

    Applies, in order and each with a residual connection: self-attention
    (``attn1``), optional GLIGEN fusion (``fuser``), optional cross-attention
    (``attn2``) and the feed-forward network (``ff``). Every step is preceded
    by its own normalisation layer.

    Reads from ``self``: ``use_ada_layer_norm``, ``use_ada_layer_norm_zero``,
    ``only_cross_attention``, ``norm1/2/3``, ``attn1``, ``attn2``, ``fuser``,
    ``ff``, ``_chunk_size`` and ``_chunk_dim``.

    Args:
        hidden_states: Input token sequence.
        attention_mask: Mask forwarded to ``attn1``.
        encoder_hidden_states: Conditioning passed to ``attn2`` (and to
            ``attn1`` when ``self.only_cross_attention`` is set).
        encoder_attention_mask: Mask forwarded to ``attn2``.
        timestep: Passed to the adaptive layer norms when they are enabled.
        cross_attention_kwargs: Extra keyword arguments for the attention
            calls; an optional ``"gligen"`` entry is removed from a copy and
            its ``"objs"`` value is passed to ``self.fuser``.
        class_labels: Passed to ``norm1`` when ``use_ada_layer_norm_zero`` is set.

    Returns:
        The updated hidden states.

    Raises:
        ValueError: If feed-forward chunking is enabled and the chunked
            dimension is not divisible by the chunk size.
    """
    # --- Self-attention -----------------------------------------------------
    # The adaptive-zero norm additionally returns the gates/shift/scale that
    # modulate the attention and feed-forward outputs further down.
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # Copy before popping so the caller's dictionary is left intact.
    attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs.copy()
    gligen_kwargs = attn_kwargs.pop("gligen", None)

    self_attn_context = encoder_hidden_states if self.only_cross_attention else None
    residual = self.attn1(
        norm_hidden_states,
        encoder_hidden_states=self_attn_context,
        attention_mask=attention_mask,
        **attn_kwargs,
    )
    if self.use_ada_layer_norm_zero:
        residual = gate_msa.unsqueeze(1) * residual
    hidden_states = residual + hidden_states

    # --- GLIGEN control -----------------------------------------------------
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

    # --- Cross-attention ----------------------------------------------------
    if self.attn2 is not None:
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm2(hidden_states, timestep)
        else:
            norm_hidden_states = self.norm2(hidden_states)
        residual = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **attn_kwargs,
        )
        hidden_states = residual + hidden_states

    # --- Feed-forward -------------------------------------------------------
    norm_hidden_states = self.norm3(hidden_states)
    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is None:
        ff_output = self.ff(norm_hidden_states)
    else:
        # Chunked evaluation trades speed for lower peak memory.
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )
        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        pieces = [
            self.ff(piece)
            for piece in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
        ]
        ff_output = torch.cat(pieces, dim=self._chunk_dim)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output
    return ff_output + hidden_states