Spaces:
Build error
Build error
# NAI compatible | |
import torch | |
class HypernetworkModule(torch.nn.Module):
    """Residual two-layer linear block computing ``x + multiplier * L2(L1(x))``.

    ``L1`` maps ``dim -> 2 * dim`` and ``L2`` maps ``2 * dim -> dim``; there is
    no activation between them. Both weights are re-drawn from a normal
    distribution with mean 0 and std 0.01 and both biases are zeroed, so the
    learned residual starts out small.

    Args:
        dim: feature width of the input and output.
        multiplier: scale applied to the residual in ``forward``.
    """

    def __init__(self, dim, multiplier=1.0):
        super().__init__()
        widen = torch.nn.Linear(dim, dim * 2)
        narrow = torch.nn.Linear(dim * 2, dim)
        # Re-initialise the layers in order: weight ~ N(0, 0.01), bias = 0.
        for layer in (widen, narrow):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            torch.nn.init.zeros_(layer.bias)
        # Kept as a Sequential so parameters are named linear.0.* / linear.1.*,
        # the names the state-dict loading code expects.
        self.linear = torch.nn.Sequential(widen, narrow)
        self.multiplier = multiplier

    def forward(self, x):
        """Return ``x + multiplier * linear(x)``."""
        delta = self.linear(x) * self.multiplier
        return x + delta
class Hypernetwork(torch.nn.Module):
    """NAI-compatible hypernetwork holding two HypernetworkModules per supported width.

    For every width in ``enable_sizes`` a pair of modules is built and registered
    under the names ``"{size}_0"`` and ``"{size}_1"``. ``forward`` selects the
    pair whose width equals ``context.shape[-1]`` and returns both transformed
    copies of ``context``. These are presumably used as the key/value context of
    the patched attention layers; confirm against the attention code that reads
    ``attn.hypernetwork``.
    """

    enable_sizes = [320, 640, 768, 1280]

    # Old-version per-module keys -> names used by HypernetworkModule's
    # ``linear`` Sequential.
    _OLD_TO_NEW_KEYS = {
        'linear1.bias': 'linear.0.bias',
        'linear1.weight': 'linear.0.weight',
        'linear2.bias': 'linear.1.bias',
        'linear2.weight': 'linear.1.weight',
    }

    def __init__(self, multiplier=1.0) -> None:
        """Build one module pair per entry of ``enable_sizes``.

        Args:
            multiplier: residual scale forwarded to every HypernetworkModule.
        """
        super().__init__()
        # NOTE(review): this plain list shadows torch.nn.Module.modules() on the
        # instance, so ``self.modules()`` is not callable. The name is kept because
        # callers may index ``hypernetwork.modules`` directly. The register_module
        # calls are what make the pairs visible to parameters()/state_dict()/to().
        self.modules = []
        for size in Hypernetwork.enable_sizes:
            pair = (HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))
            self.modules.append(pair)
            self.register_module(f"{size}_0", pair[0])
            self.register_module(f"{size}_1", pair[1])

    def _attach_to_attention(self, attn, size):
        """Set ``attn.hypernetwork`` to ``self`` if ``size`` is supported, else to None."""
        # NOTE(review): assigning an nn.Module as an attribute of another nn.Module
        # registers it as a submodule. The attention layer's parameters() and
        # state_dict() will therefore include this network; confirm that is intended.
        attn.hypernetwork = self if size in Hypernetwork.enable_sizes else None

    def apply_to_stable_diffusion(self, text_encoder, vae, unet):
        """Attach this hypernetwork to the attention layers of an LDM-style UNet.

        Visits ``input_blocks``, ``middle_block`` and ``output_blocks``. For every
        sub-block whose type name contains ``SpatialTransformer``, it sets
        ``hypernetwork`` on ``attn1``/``attn2`` of each transformer block, keyed by
        ``attn.context_dim``. ``text_encoder`` and ``vae`` are unused.

        Returns:
            True, matching ``apply_to_diffusers`` (no error checking is done).
        """
        # Unpacking avoids relying on ModuleList.__add__, which older torch
        # releases may lack; iteration order is unchanged.
        blocks = [*unet.input_blocks, unet.middle_block, *unet.output_blocks]
        for block in blocks:
            for subblk in block:
                if 'SpatialTransformer' not in str(type(subblk)):
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        self._attach_to_attention(attn, attn.context_dim)
        return True

    def apply_to_diffusers(self, text_encoder, vae, unet):
        """Attach this hypernetwork to the attention layers of a diffusers UNet.

        Visits ``down_blocks``, ``mid_block`` and ``up_blocks``, skipping blocks
        without an ``attentions`` attribute. For each transformer sub-block, it sets
        ``hypernetwork`` on ``attn1``/``attn2``, keyed by ``attn.to_k.in_features``.
        ``text_encoder`` and ``vae`` are unused.

        Returns:
            True (no error checking is done).
        """
        blocks = [*unet.down_blocks, unet.mid_block, *unet.up_blocks]
        for block in blocks:
            if not hasattr(block, 'attentions'):
                continue
            for subblk in block.attentions:
                type_name = str(type(subblk))
                # The class name differs by diffusers version (0.6.0 vs 0.7 and later).
                if 'SpatialTransformer' not in type_name and 'Transformer2DModel' not in type_name:
                    continue
                for tf_block in subblk.transformer_blocks:
                    for attn in (tf_block.attn1, tf_block.attn2):
                        self._attach_to_attention(attn, attn.to_k.in_features)
        return True  # TODO error checking

    def forward(self, x, context):
        """Return the pair of hypernetwork outputs for ``context``.

        Args:
            x: unused.
            context: tensor whose last dimension selects the module pair.

        Returns:
            ``(module_0(context), module_1(context))`` for the matching width.

        Raises:
            AssertionError: if ``context.shape[-1]`` is not in ``enable_sizes``.
                When asserts are disabled (``python -O``), ``list.index`` raises
                ValueError instead.
        """
        size = context.shape[-1]
        assert size in Hypernetwork.enable_sizes
        first, second = self.modules[Hypernetwork.enable_sizes.index(size)]
        # Call the modules rather than .forward() so registered hooks run.
        return first(context), second(context)

    @staticmethod
    def _upgrade_module_state_dict(module_sd):
        """Rename old-version keys of one module's state dict in place and return it."""
        for key_from, key_to in Hypernetwork._OLD_TO_NEW_KEYS.items():
            if key_from in module_sd:
                module_sd[key_to] = module_sd.pop(key_from)
        return module_sd

    def load_from_state_dict(self, state_dict):
        """Load weights stored in the ``{size: [sd_0, sd_1]}`` layout.

        Top-level keys that are not ints are ignored. Each per-module dict is
        migrated from the old ``linear1.*``/``linear2.*`` key names, in place,
        before being loaded with ``strict=True``.

        Returns:
            True.

        Raises:
            ValueError: if an int key is not one of ``enable_sizes``.
            RuntimeError: from ``torch.nn.Module.load_state_dict`` on missing or
                unexpected keys.
        """
        for size, sd in state_dict.items():
            if not isinstance(size, int):
                continue
            if size not in Hypernetwork.enable_sizes:
                raise ValueError(
                    f"unsupported hypernetwork size {size}; expected one of {Hypernetwork.enable_sizes}"
                )
            first, second = self.modules[Hypernetwork.enable_sizes.index(size)]
            # The old->new key migration must run on the per-module dicts: that is
            # where the linear1.* / linear2.* keys live, not at the top level.
            first.load_state_dict(self._upgrade_module_state_dict(sd[0]), strict=True)
            second.load_state_dict(self._upgrade_module_state_dict(sd[1]), strict=True)
        return True

    def get_state_dict(self):
        """Return weights in the ``{size: [sd_0, sd_1]}`` layout read by load_from_state_dict."""
        return {
            size: [first.state_dict(), second.state_dict()]
            for size, (first, second) in zip(Hypernetwork.enable_sizes, self.modules)
        }