# RAG-Diffusion / cross_attention.py
import math
import torch
import torchvision.transforms.functional as F
TOKENS = 75

def hook_forwards(self, root_module: torch.nn.Module):
    """Replace the forward of every Flux ``Attention`` module with a region-aware hook,
    using separate hooks for double-stream (``transformer_blocks``) and
    single-stream (``single_transformer_blocks``) attention."""
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_hook_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_hook_forward(self, module)
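
# Usage sketch (an assumption about the surrounding pipeline, not defined in this file):
# `self` is expected to expose .h, .w, .split_ratio and .SR_delta, and the hooks are
# installed on the Flux transformer before region-controlled denoising, e.g.
#
#     hook_forwards(pipe, pipe.transformer)   # `pipe` is a hypothetical RAG pipeline object
#
# and switched back to plain processor calls afterwards via init_forwards(pipe, pipe.transformer).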

def FluxSingleTransformerBlock_hook_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Plain attention output for the full (text + image) sequence.
        flux_hidden_states = module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
        height = self.h
        width = self.w
        # In single-stream blocks the first 512 tokens are text tokens; the rest are image tokens.
        x_t = hidden_states.size()[1] - 512
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)
        contexts_list = SR_norm_hidden_states_list

        def single_matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1] - 512
            (latent_h, latent_w) = split_dims(x_t, height, width, self)
            latent_out = latent_w
            latent_in = latent_h
            i = 0
            sumout = 0
            SR_all_out_list = []
            # Run attention once per region and keep each region's image tokens for its
            # cell of the latent grid defined by self.split_ratio.
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    SR_all_out = module.processor(module, hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = SR_all_out[:, 512:, ...]
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    # Rounding corrections so the last cell in each direction reaches the edge.
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                    sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                    if drow.end >= 0.999:
                        addout = sumout - latent_out
                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_all_out_list.append(SR_all_out)
                output_x = torch.cat(v_states, dim=2)
                h_states.append(output_x)
            output_x = torch.cat(h_states, dim=1)
            output_x = output_x.reshape(x.size()[0], x.size()[1] - 512, x.size()[2])
            # Write the stitched image tokens back into every per-region output ...
            new_SR_all_out_list = []
            for SR_all_out in SR_all_out_list:
                SR_all_out[:, 512:, ...] = output_x
                new_SR_all_out_list.append(SR_all_out)
            # ... and blend them into the base output, weighted by SR_delta.
            x[:, 512:, ...] = output_x * self.SR_delta + x[:, 512:, ...] * (1 - self.SR_delta)
            return x, new_SR_all_out_list

        return single_matsepcalc(flux_hidden_states, contexts_list, image_rotary_emb)
    return forward
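
# Assumed layout of self.split_ratio (inferred from the loops above, not defined in this file):
# an iterable of row entries, each with fractional .start/.end bounds along the image height
# and a .cols list of cells; each cell carries fractional .start/.end bounds along the width
# plus a .breaks count that skips extra entries in the per-region context list.
# For example, one row spanning the full height with two equal-width cells would look roughly
# like rows=[{start: 0.0, end: 1.0, cols: [{start: 0.0, end: 0.5, breaks: 0},
# {start: 0.5, end: 1.0, breaks: 0}]}] in whatever container the pipeline uses.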

def FluxTransformerBlock_hook_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Plain attention output for the base prompt; double-stream blocks also return
        # the updated text (encoder) states.
        flux_hidden_states, flux_encoder_hidden_states = module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
        height = self.h
        width = self.w
        x_t = hidden_states.size()[1]
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)
        contexts_list = SR_norm_encoder_hidden_states_list

        def matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1]
            (latent_h, latent_w) = split_dims(x_t, height, width, self)
            latent_out = latent_w
            latent_in = latent_h
            i = 0
            sumout = 0
            SR_context_attn_output_list = []
            # Attend to each regional prompt separately and keep only that region's cell.
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    out, SR_context_attn_output = module.processor(module, hidden_states=x, encoder_hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    # Rounding corrections so the last cell in each direction reaches the edge.
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                    sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                    if drow.end >= 0.999:
                        addout = sumout - latent_out
                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_context_attn_output_list.append(SR_context_attn_output)
                output_x = torch.cat(v_states, dim=2)
                h_states.append(output_x)
            output_x = torch.cat(h_states, dim=1)
            output_x = output_x.reshape(x.size()[0], x.size()[1], x.size()[2])
            # Blend the stitched regional output with the base output, weighted by SR_delta.
            return output_x * self.SR_delta + flux_hidden_states * (1 - self.SR_delta), flux_encoder_hidden_states, SR_context_attn_output_list

        return matsepcalc(hidden_states, contexts_list, image_rotary_emb)
    return forward
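
# Note on inputs (an assumption about the caller): the SR_*_list arguments are expected to be
# per-region hidden/encoder states prepared upstream by the patched transformer blocks; only the
# normalized variants (SR_norm_hidden_states_list / SR_norm_encoder_hidden_states_list) are
# consumed here, one entry per cell of the split grid, in row-major order.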

def split_dims(x_t, height, width, self=None):
    """Split an attention layer dimension to height + width.
    The original estimate was latent_h = sqrt(hw_ratio*x_t),
    rounding to the nearest value. However, this proved inaccurate.
    The actual operation seems to be as follows:
    - Divide h,w by 8, rounding DOWN.
    - For every new layer (of 4), divide both by 2 and round UP (then back up).
    - Multiply h*w to yield x_t.
    There is no inverse function to this set of operations,
    so instead we mimic them without the multiplication part using the original h+w.
    It's worth noting that no known checkpoints follow a different system of layering,
    but it's theoretically possible. Please report if encountered.
    """
    scale = math.ceil(math.log2(math.sqrt(height * width / x_t)))
    latent_h = repeat_div(height, scale)
    latent_w = repeat_div(width, scale)
    if x_t > latent_h * latent_w and hasattr(self, "nei_multi"):
        latent_h, latent_w = self.nei_multi[1], self.nei_multi[0]
        while latent_h * latent_w != x_t:
            latent_h, latent_w = latent_h // 2, latent_w // 2
    return latent_h, latent_w


def repeat_div(x, y):
    """Imitates dimension halving common in convolution operations.
    This is a pretty big assumption of the model,
    but then if some model doesn't work like that it will be easy to spot.
    """
    while y > 0:
        x = math.ceil(x / 2)
        y = y - 1
    return x
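
# Worked example (values chosen for illustration): for a 1024x1024 image, Flux packs the
# 128x128 VAE latent into 2x2 patches, giving x_t = 64 * 64 = 4096 image tokens.
# Then scale = ceil(log2(sqrt(1024 * 1024 / 4096))) = 4, so repeat_div halves 1024 four
# times to 64 and split_dims(4096, 1024, 1024) -> (64, 64).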

def init_forwards(self, root_module: torch.nn.Module):
    """Install pass-through forwards that ignore the extra region arguments and simply
    call the default attention processor."""
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_init_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_init_forward(self, module)


def FluxSingleTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
    return forward


def FluxTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
    return forward
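
# Minimal self-check (a sketch added for illustration; not part of the upstream module).
# It exercises only the pure helpers above, assuming a 16x downsampling from pixels to
# image tokens as in Flux (8x VAE + 2x2 patch packing).
if __name__ == "__main__":
    height, width = 1024, 768
    x_t = (height // 16) * (width // 16)  # 64 * 48 = 3072 image tokens
    latent_h, latent_w = split_dims(x_t, height, width)
    print(f"x_t={x_t} -> latent grid {latent_h}x{latent_w}")
    assert latent_h * latent_w == x_t
    assert repeat_div(height, 4) == latent_h and repeat_div(width, 4) == latent_w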