Spaces:
Runtime error
Runtime error
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import random | |
| from .nn import timestep_embedding | |
| from .unet import UNetModel | |
| from .xf import LayerNorm, Transformer, convert_module_to_f16 | |
| from timm.models.vision_transformer import PatchEmbed | |
class Text2ImModel(nn.Module):
    """
    Conditional diffusion model built from a reference encoder and a UNet
    decoder.

    The `Encoder` maps the reference input `ref` to a dict of latent features,
    and the `Text2ImUNet` decoder produces the prediction from the noisy input
    `xt`, the timesteps and those features.

    Notes on the constructor arguments:
    - `text_ctx`, `xf_layers` and `xf_final_ln` are accepted but not used by
      this constructor; the encoder is always built with 8 transformer layers.
    - `xf_width` is used both as the encoder transformer width and as the
      decoder's `encoder_channels`.
    - `n_class` is the number of input channels of the reference encoder.
    - every other UNet argument is forwarded to `Text2ImUNet` unchanged.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout,
        channel_mult,
        use_fp16,
        num_heads,
        num_heads_upsample,
        num_head_channels,
        use_scale_shift_norm,
        resblock_updown,
        in_channels=3,
        n_class=3,
        image_size=64,
    ):
        super().__init__()

        # Reference branch. The depth is fixed to 8 here rather than taken
        # from `xf_layers`.
        encoder_config = dict(
            img_size=image_size,
            patch_size=image_size // 16,
            in_chans=n_class,
            xf_width=xf_width,
            xf_layers=8,
            xf_heads=xf_heads,
            model_channels=model_channels,
        )
        self.encoder = Encoder(**encoder_config)

        self.in_channels = in_channels

        # Denoising branch; it receives the encoder width as
        # `encoder_channels`.
        decoder_config = dict(
            dropout=dropout,
            channel_mult=channel_mult,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_heads_upsample=num_heads_upsample,
            num_head_channels=num_head_channels,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            encoder_channels=xf_width,
        )
        self.decoder = Text2ImUNet(
            in_channels,
            model_channels,
            out_channels,
            num_res_blocks,
            attention_resolutions,
            **decoder_config,
        )

    def forward(self, xt, timesteps, ref=None, uncond_p=0.0):
        """
        Encode `ref` and run the decoder on `xt`.

        :param xt: the noisy input passed to the decoder.
        :param timesteps: the timesteps passed to the decoder.
        :param ref: the reference input passed to the encoder.
        :param uncond_p: forwarded to the encoder as its second argument.
        :return: the decoder output.
        """
        conditioning = self.encoder(ref, uncond_p)
        return self.decoder(xt, timesteps, conditioning)
class Text2ImUNet(UNetModel):
    """
    A UNet that is additionally conditioned on encoder features.

    All constructor arguments are forwarded to `UNetModel`. On top of the
    base model, a linear layer (`transformer_proj`) maps the pooled encoder
    feature to `model_channels * 4` so that it can be added to the timestep
    embedding.

    The input width of that layer follows the `encoder_channels` keyword
    argument when one is given, and falls back to 512 (the previously
    hard-coded value) otherwise.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # The pooled feature and the per-token features come from the same
        # encoder output, so the projection must accept `encoder_channels`
        # inputs. `or 512` keeps the old behaviour when the keyword is absent
        # or None.
        proj_in = kwargs.get("encoder_channels") or 512
        self.transformer_proj = nn.Linear(proj_in, self.model_channels * 4)

    def forward(self, x, timesteps, latent_outputs):
        """
        Run the UNet on `x`, conditioned on `latent_outputs`.

        :param x: the input tensor.
        :param timesteps: the timesteps, embedded with `timestep_embedding`.
        :param latent_outputs: a dict with keys "xf_proj" (pooled feature,
                               projected and added to the timestep embedding)
                               and "xf_out" (passed to every block).
        :return: the output of `self.out`.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        xf_proj = latent_outputs["xf_proj"]
        xf_out = latent_outputs["xf_out"]
        # Match the embedding's dtype/device before summing.
        emb = emb + self.transformer_proj(xf_proj).to(emb)

        # Down path: remember every activation for the skip connections.
        skips = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            skips.append(h)

        h = self.middle_block(h, emb, xf_out)

        # Up path: consume the stored activations in reverse order,
        # concatenating along the channel dimension.
        for module in self.output_blocks:
            h = th.cat([h, skips.pop()], dim=1)
            h = module(h, emb, xf_out)

        # Cast back to the caller's dtype before the output head.
        h = h.type(x.dtype)
        return self.out(h)
class Encoder(nn.Module):
    """
    Encode a reference image into conditioning features.

    The input goes through `CNN`, the resulting feature map is flattened into
    a token sequence, a learned class token is appended, and the sequence is
    processed by a `Transformer` followed by a `LayerNorm`.

    :param img_size: accepted but not used by this class.
    :param patch_size: accepted but not used by this class.
    :param in_chans: number of input channels given to the CNN.
    :param xf_width: transformer width. It has to match the channel count of
                     the CNN output, because the CNN features are summed with
                     the positional embedding of this width.
    :param xf_layers: number of transformer layers.
    :param xf_heads: number of attention heads.
    :param model_channels: accepted but not used by this class.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        xf_width,
        xf_layers,
        xf_heads,
        model_channels,
    ):
        super().__init__()
        self.transformer = Transformer(
            xf_width,
            xf_layers,
            xf_heads,
        )
        self.cnn = CNN(in_chans)
        self.final_ln = LayerNorm(xf_width)

        # One class token plus 256 feature tokens. Slot 0 of the positional
        # embedding belongs to the class token, slots 1..256 to the features.
        self.cls_token = nn.Parameter(th.empty(1, 1, xf_width, dtype=th.float32))
        self.positional_embedding = nn.Parameter(
            th.empty(1, 256 + 1, xf_width, dtype=th.float32)
        )
        # `th.empty` returns uninitialised memory (possibly NaN/Inf), which
        # would poison training from scratch. Give both parameters a defined
        # starting value; a loaded checkpoint still overwrites them.
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.02)

    def forward(self, ref, uncond_p=0.0):
        """
        Compute the conditioning features for `ref`.

        :param ref: the reference input given to the CNN. Its feature map
                    must flatten to exactly 256 tokens.
        :param uncond_p: accepted for interface compatibility; not used here.
        :return: a dict with "xf_proj", the class-token feature of shape
                 [N x C], and "xf_out", the remaining token features
                 permuted to [N x C x L].
        :raises ValueError: if the number of feature tokens does not match
                            the positional embedding.
        """
        x = self.cnn(ref)
        x = x.flatten(2).transpose(1, 2)  # NCHW -> N(HW)C

        # Fail loudly instead of relying on broadcasting: a single-token
        # input would otherwise be silently expanded to all 256 positions.
        expected_tokens = self.positional_embedding.shape[1] - 1
        if x.shape[1] != expected_tokens:
            raise ValueError(
                f"Encoder expects {expected_tokens} feature tokens, "
                f"got {x.shape[1]}"
            )

        x = x + self.positional_embedding[:, 1:, :]
        cls_token = self.cls_token + self.positional_embedding[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        # The class token is appended at the END of the sequence.
        x = th.cat((x, cls_tokens), dim=1)

        xf_out = self.transformer(x)
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)

        # Last position is the class token; everything before it is features.
        xf_proj = xf_out[:, -1]
        xf_out = xf_out[:, :-1].permute(0, 2, 1)  # NLC -> NCL
        return dict(xf_proj=xf_proj, xf_out=xf_out)
class SuperResText2ImModel(Text2ImModel):
    """
    A text2im model that performs super-resolution.
    Expects an extra kwarg `low_res` to condition on a low-resolution image.

    The low-resolution image is upsampled to the size of the input and
    concatenated to it along the channel dimension, so the underlying model
    is built with twice the requested `in_channels`.
    """

    def __init__(self, *args, **kwargs):
        # Local import keeps this block self-contained.
        import inspect

        # Resolve the arguments against the parent signature so that
        # `in_channels` is found whether it was passed by keyword, by
        # position, or left at its default. Doubling a fixed positional
        # index would hit a different parameter of `Text2ImModel.__init__`.
        bound = inspect.signature(Text2ImModel.__init__).bind(
            self, *args, **kwargs
        )
        bound.apply_defaults()
        bound.arguments["in_channels"] = bound.arguments["in_channels"] * 2
        # `bound.args[0]` is `self`, which `super().__init__` supplies itself.
        super().__init__(*bound.args[1:], **bound.kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Upsample `low_res` to the size of `x`, concatenate and run the model.

        :param x: an [N x C x H x W] input tensor.
        :param timesteps: the timesteps, forwarded to the parent model.
        :param low_res: the low-resolution conditioning image (required).
        :param kwargs: forwarded to `Text2ImModel.forward`.
        :return: the output of `Text2ImModel.forward`.
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            raise ValueError(
                "SuperResText2ImModel.forward requires the `low_res` argument"
            )
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        # NOTE: an earlier experiment added timestep-dependent Gaussian noise
        # to `upsampled` at this point; it is disabled.
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
def conv3x3(in_channels, out_channels, stride=1):
    """
    Build a biased 3x3 2D convolution with padding 1.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :return: the `nn.Conv2d` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer
def conv7x7(in_channels, out_channels, stride=1):
    """
    Build a biased 7x7 2D convolution with padding 3.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param stride: convolution stride.
    :return: the `nn.Conv2d` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=7,
        stride=stride,
        padding=3,
        bias=True,
    )
    return layer
class CNN(nn.Module):
    """
    Convolutional feature extractor with five normalised stages and a final
    plain convolution.

    Stage 1 is a stride-1 7x7 convolution; stages 2 to 5 are stride-2 3x3
    convolutions, so the spatial size is divided by 16 overall (256 -> 16
    for a 256x256 input). Each stage is followed by an affine instance norm
    and a LeakyReLU(0.2). The output has 512 channels.

    :param in_channels: number of channels of the input image.
    """

    # Output width of each normalised stage, in order.
    _STAGE_WIDTHS = (32, 64, 128, 256, 512)

    def __init__(self, in_channels=3):
        super().__init__()
        previous = in_channels
        # Register layers as conv{i}/norm{i}/LReLU{i}, stage by stage, so
        # attribute names and registration order stay stable.
        for stage, width in enumerate(self._STAGE_WIDTHS, start=1):
            if stage == 1:
                conv = conv7x7(previous, width)
            else:
                conv = conv3x3(previous, width, 2)
            setattr(self, f"conv{stage}", conv)
            setattr(self, f"norm{stage}", nn.InstanceNorm2d(width, affine=True))
            setattr(self, f"LReLU{stage}", nn.LeakyReLU(0.2))
            previous = width
        # Final projection: no normalisation or activation after it.
        self.conv6 = conv3x3(512, 512, 1)

    def forward(self, x):
        """
        Apply the five conv/norm/activation stages, then the final conv.

        :param x: an [N x C x H x W] input tensor.
        :return: the feature map produced by `conv6`.
        """
        for stage in range(1, len(self._STAGE_WIDTHS) + 1):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"norm{stage}")
            act = getattr(self, f"LReLU{stage}")
            x = act(norm(conv(x)))
        return self.conv6(x)