import math
import torch
import torchvision.transforms.functional as F

TOKENS = 75

def hook_forwards(self, root_module: torch.nn.Module):
    # Replace the forward of every Flux attention module with a region-aware
    # version: double-stream blocks ("transformer_blocks") and single-stream
    # blocks ("single_transformer_blocks") get separate hooks.
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_hook_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_hook_forward(self, module)

def FluxSingleTransformerBlock_hook_forward(self, module):
    # Regional attention for Flux single-stream blocks, where the 512 text tokens
    # and the image tokens live in one concatenated sequence.
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        flux_hidden_states = module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)

        height = self.h
        width = self.w
        x_t = hidden_states.size()[1] - 512  # number of image tokens (the first 512 are text tokens)
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)

        contexts_list = SR_norm_hidden_states_list

        def single_matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1] - 512
            (latent_h, latent_w) = split_dims(x_t, height, width, self)

            latent_out = latent_w
            latent_in = latent_h

            i = 0
            sumout = 0
            SR_all_out_list = []
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks

                    # Run attention with this region's context and keep only its tile
                    # of the image-token grid.
                    SR_all_out = module.processor(module, hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = SR_all_out[:, 512:, ...]
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])

                    # Accumulate integer-rounding error so the last cell/row closes the grid exactly.
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out

                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_all_out_list.append(SR_all_out)

                output_x = torch.cat(v_states, dim=2)  # stitch cells along the width axis
                h_states.append(output_x)

            output_x = torch.cat(h_states, dim=1)  # stitch rows along the height axis
            output_x = output_x.reshape(x.size()[0], x.size()[1] - 512, x.size()[2])

            new_SR_all_out_list = []
            for SR_all_out in SR_all_out_list:
                SR_all_out[:, 512:, ...] = output_x
                new_SR_all_out_list.append(SR_all_out)

            # Blend the stitched regional output into the base output by SR_delta.
            x[:, 512:, ...] = output_x * self.SR_delta + x[:, 512:, ...] * (1 - self.SR_delta)
            return x, new_SR_all_out_list

        return single_matsepcalc(flux_hidden_states, contexts_list, image_rotary_emb)

    return forward

def FluxTransformerBlock_hook_forward(self, module):
    # Regional attention for Flux double-stream blocks, where text and image
    # tokens are processed as separate encoder_hidden_states / hidden_states.
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        flux_hidden_states, flux_encoder_hidden_states = module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)

        height = self.h
        width = self.w
        x_t = hidden_states.size()[1]
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)

        contexts_list = SR_norm_encoder_hidden_states_list

        def matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1]
            (latent_h, latent_w) = split_dims(x_t, height, width, self)

            latent_out = latent_w
            latent_in = latent_h

            i = 0
            sumout = 0
            SR_context_attn_output_list = []
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks

                    # Run attention against this region's prompt embedding and keep
                    # only its tile of the image-token grid.
                    out, SR_context_attn_output = module.processor(module, hidden_states=x, encoder_hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])

                    # Accumulate integer-rounding error so the last cell/row closes the grid exactly.
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out

                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_context_attn_output_list.append(SR_context_attn_output)

                output_x = torch.cat(v_states, dim=2)  # stitch cells along the width axis
                h_states.append(output_x)

            output_x = torch.cat(h_states, dim=1)  # stitch rows along the height axis
            output_x = output_x.reshape(x.size()[0], x.size()[1], x.size()[2])

            # Blend the stitched regional output with the base attention output.
            return output_x * self.SR_delta + flux_hidden_states * (1 - self.SR_delta), flux_encoder_hidden_states, SR_context_attn_output_list

        return matsepcalc(hidden_states, contexts_list, image_rotary_emb)

    return forward

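# Region bookkeeping sketch (illustrative only; the row fractions below are an
# assumed example, not taken from this module, and the loop is simplified to one
# cell per row): with a 64-row latent grid and two rows covering [0.0, 0.5) and
# [0.5, 1.0], the slice arithmetic used in matsepcalc yields the contiguous bands
# 0:32 and 32:64, so concatenating the tiles along the height axis rebuilds the
# full grid with no gaps or overlaps.
def _region_slicing_example():
    latent_h = 64
    rows = [(0.0, 0.5), (0.5, 1.0)]  # hypothetical single-cell rows
    sumout, bands = 0, []
    for start, end in rows:
        addout = 0
        sumout += int(latent_h * end) - int(latent_h * start)
        if end >= 0.999:
            addout = sumout - latent_h  # absorb rounding error in the last row
        bands.append((int(latent_h * start) + addout, int(latent_h * end)))
    assert bands == [(0, 32), (32, 64)]
    return bands
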
def split_dims(x_t, height, width, self=None):
    """Split an attention layer dimension into height and width.

    The original estimate was latent_h = sqrt(hw_ratio * x_t), rounded to the
    nearest value, but this proved inaccurate. The actual operation seems to be:
    - Divide h and w by 8, rounding down.
    - For every new layer (of 4), divide both by 2, rounding up.
    - Multiply h * w to yield x_t.
    There is no inverse to this set of operations, so instead we mimic them
    (without the multiplication) starting from the original h and w. It is worth
    noting that no known checkpoints follow a different system of layering, but
    it is theoretically possible. Please report it if encountered.
    (A worked example follows repeat_div below.)
    """
    scale = math.ceil(math.log2(math.sqrt(height * width / x_t)))
    latent_h = repeat_div(height, scale)
    latent_w = repeat_div(width, scale)
    if x_t > latent_h * latent_w and hasattr(self, "nei_multi"):
        latent_h, latent_w = self.nei_multi[1], self.nei_multi[0]
        while latent_h * latent_w != x_t:
            latent_h, latent_w = latent_h // 2, latent_w // 2
    return latent_h, latent_w

def repeat_div(x, y):
    """Imitates the dimension halving common in convolution operations.

    This is a fairly strong assumption about the model, but if some model does
    not work like that, it will be easy to spot.
    """
    while y > 0:
        x = math.ceil(x / 2)
        y = y - 1
    return x

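# Worked example for split_dims / repeat_div (an illustrative sketch; the
# 1024x1024 / 4096-token figures are an assumption about typical Flux packing,
# i.e. 8x VAE downsampling followed by 2x2 patch packing, not something read
# from this module): scale = ceil(log2(sqrt(1024 * 1024 / 4096))) = 4 and
# repeat_div(1024, 4) = 64, so split_dims recovers a 64x64 token grid.
def _split_dims_example():
    latent_h, latent_w = split_dims(4096, 1024, 1024)
    assert (latent_h, latent_w) == (64, 64)
    assert latent_h * latent_w == 4096
    return latent_h, latent_w
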
def init_forwards(self, root_module: torch.nn.Module):
    # Restore plain pass-through forwards (no regional control) on every hooked
    # Flux attention module.
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_init_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_init_forward(self, module)

def FluxSingleTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
    return forward

def FluxTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
    return forward

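# Usage sketch (an assumption about the hosting pipeline, not part of the
# original hooks): the forwards above read `self.h`, `self.w`, `self.split_ratio`
# (rows of cells with start/end fractions and a `breaks` count), `self.SR_delta`,
# and optionally `self.nei_multi`. The dataclasses below are hypothetical
# stand-ins for whatever object the real pipeline passes in; only
# hook_forwards / init_forwards come from this file.
from dataclasses import dataclass, field

@dataclass
class _Cell:
    start: float
    end: float
    breaks: int = 0

@dataclass
class _Row:
    start: float
    end: float
    cols: list = field(default_factory=list)

@dataclass
class _RegionalState:
    h: int = 1024
    w: int = 1024
    SR_delta: float = 0.7
    split_ratio: list = field(default_factory=lambda: [
        _Row(0.0, 1.0, cols=[_Cell(0.0, 0.5), _Cell(0.5, 1.0)]),  # two side-by-side regions
    ])

# state = _RegionalState()
# hook_forwards(state, pipe.transformer)   # install the regional attention hooks
# ...run the regional denoising loop, passing the SR_* lists to the transformer...
# init_forwards(state, pipe.transformer)   # restore plain pass-through forwards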