import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import numpy as np
import models
from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d
from einops import rearrange
import torch.fft as fft
from modules.speed_util import checkpoint
def batched_linear_mm(x, wb):
    # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2)
    one = torch.ones(*x.shape[:-1], 1, device=x.device)
    return torch.matmul(torch.cat([x, one], dim=-1), wb)
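# Illustrative shapes (a sketch, not part of the training code):
#   x  = torch.randn(4, 100, 32)     # (B, N, D1)
#   wb = torch.randn(4, 33, 64)      # (B, D1 + 1, D2): weight rows plus a final bias row
#   y  = batched_linear_mm(x, wb)    # -> (4, 100, 64)
# The appended column of ones makes the last row of wb act as the bias.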
def make_coord_grid(shape, range, device=None):
    """
    Args:
        shape: tuple of spatial sizes
        range: [minv, maxv] shared by all dims, or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] per dim
    Returns:
        grid: tensor of shape (*shape, len(shape)) holding the coordinate of each cell center
    """
    l_lst = []
    for i, s in enumerate(shape):
        l = (0.5 + torch.arange(s, device=device)) / s
        if isinstance(range[0], (list, tuple)):
            minv, maxv = range[i]
        else:
            minv, maxv = range
        l = minv + (maxv - minv) * l
        l_lst.append(l)
    grid = torch.meshgrid(*l_lst, indexing='ij')
    grid = torch.stack(grid, dim=-1)
    return grid
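# Illustrative usage (a sketch): cell-center coordinates in [-1, 1] for a 2x2 grid.
#   make_coord_grid((2, 2), (-1, 1))
#   -> tensor of shape (2, 2, 2) with values
#      [[[-0.5, -0.5], [-0.5,  0.5]],
#       [[ 0.5, -0.5], [ 0.5,  0.5]]]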
def init_wb(shape):
    # Kaiming-style init for a fused weight+bias matrix of shape (shape[0], shape[1]),
    # where the final input row acts as the bias.
    weight = torch.empty(shape[1], shape[0] - 1)
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

    bias = torch.empty(shape[1], 1)
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    nn.init.uniform_(bias, -bound, bound)

    return torch.cat([weight, bias], dim=1).t().detach()


def init_wb_rewrite(shape):
    # Xavier-initialized variant; not used in this file.
    weight = torch.empty(shape[1], shape[0] - 1)
    torch.nn.init.xavier_uniform_(weight)

    bias = torch.empty(shape[1], 1)
    torch.nn.init.xavier_uniform_(bias)

    return torch.cat([weight, bias], dim=1).t().detach()
class HypoMlp(nn.Module):
    def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024):
        super().__init__()
        self.use_pe = use_pe
        self.pe_dim = pe_dim
        self.pe_sigma = pe_sigma
        self.depth = depth
        self.param_shapes = dict()
        if use_pe:
            last_dim = in_dim * pe_dim
        else:
            last_dim = in_dim
        for i in range(depth):  # record the (in + 1, out) shape of each layer's fused weight/bias
            cur_dim = hidden_dim if i < depth - 1 else out_dim
            self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim)
            last_dim = cur_dim
        self.relu = nn.ReLU()
        self.params = None
        self.out_bias = out_bias

    def set_params(self, params):
        self.params = params

    def convert_posenc(self, x):
        w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device))
        x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1)
        x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1)
        return x

    def forward(self, x):
        B, query_shape = x.shape[0], x.shape[1:-1]
        x = x.view(B, -1, x.shape[-1])
        if self.use_pe:
            x = self.convert_posenc(x)
        for i in range(self.depth):
            x = batched_linear_mm(x, self.params[f'wb{i}'])
            if i < self.depth - 1:
                x = self.relu(x)
            else:
                x = x + self.out_bias
        x = x.view(B, *query_shape, -1)
        return x
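# Minimal usage sketch (hypothetical dimensions): HypoMlp has no parameters of its own;
# fused weight/bias matrices are supplied per batch via set_params().
#   hypo = HypoMlp(depth=2, in_dim=2, out_dim=16, hidden_dim=64, use_pe=True, pe_dim=128)
#   B = 4
#   params = {name: einops.repeat(init_wb(shape), 'n m -> b n m', b=B)
#             for name, shape in hypo.param_shapes.items()}
#   hypo.set_params(params)
#   coord = make_coord_grid((24, 24), (-1, 1))
#   out = hypo(einops.repeat(coord, 'h w d -> b h w d', b=B))   # -> (4, 24, 24, 16)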
class Attention(nn.Module):
    def __init__(self, dim, n_head, head_dim, dropout=0.):
        super().__init__()
        self.n_head = n_head
        inner_dim = n_head * head_dim
        self.to_q = nn.Sequential(
            nn.SiLU(),
            Linear(dim, inner_dim))
        self.to_kv = nn.Sequential(
            nn.SiLU(),
            Linear(dim, inner_dim * 2))
        self.scale = head_dim ** -0.5
        # Note: no output projection; the concatenated heads are returned directly,
        # so inner_dim is expected to match dim.

    def forward(self, fr, to=None):
        # Self-attention when `to` is None, otherwise cross-attention from `fr` to `to`.
        if to is None:
            to = fr
        q = self.to_q(fr)
        k, v = self.to_kv(to).chunk(2, dim=-1)
        q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v])
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = F.softmax(dots, dim=-1)  # (b, h, n, n)
        out = torch.matmul(attn, v)
        out = einops.rearrange(out, 'b h n d -> b n (h d)')
        return out
class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            Linear(dim, ff_dim),
            nn.GELU(),
            # GlobalResponseNorm(ff_dim),
            nn.Dropout(dropout),
            Linear(ff_dim, dim)
        )

    def forward(self, x):
        return self.net(x)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


# TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds=[])
class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)),
            ]))

    def forward(self, x):
        for norm_attn, norm_ff in self.layers:
            x = x + norm_attn(x)
            x = x + norm_ff(x)
        return x
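# Usage sketch (hypothetical sizes): a plain pre-norm encoder over a token sequence.
#   enc = TransformerEncoder(dim=64, depth=1, n_head=4, head_dim=16, ff_dim=64)
#   tokens = torch.randn(2, 80, 64)   # (B, N, dim)
#   out = enc(tokens)                 # -> (2, 80, 64); residual connections keep dim unchanged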
class ImgrecTokenizer(nn.Module):
    def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        if isinstance(padding, int):
            padding = (padding, padding)
        self.patch_size = patch_size
        self.padding = padding
        self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim)
        self.posemb = nn.Parameter(torch.randn(input_size, dim))

    def forward(self, x):
        p = self.patch_size
        x = F.unfold(x, p, stride=p, padding=self.padding)  # (B, C * p * p, L)
        x = x.permute(0, 2, 1).contiguous()  # (B, L, C * p * p)
        x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0)
        return x
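# Shape sketch (hypothetical sizes): with patch_size=1 each spatial position becomes one token.
#   tok = ImgrecTokenizer(input_size=32 * 32, patch_size=1, dim=64, img_channels=16)
#   feat = torch.randn(2, 16, 16, 16)   # (B, C, H, W)
#   tokens = tok(feat)                  # -> (2, 256, 64), i.e. (B, H*W, dim)
# Note: H*W must not exceed input_size, the number of learned positional embeddings.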
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
class TimestepBlock_res(nn.Module):
    def __init__(self, c, c_timestep, conds=['sca']):
        super().__init__()
        self.mapper = Linear(c_timestep, c * 2)
        self.conds = conds
        for cname in conds:
            setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))

    def forward(self, x, t):
        # t packs the base timestep embedding plus one embedding per extra condition,
        # concatenated along dim=1; each chunk is mapped to a (scale, shift) pair.
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
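# Usage sketch (hypothetical sizes): modulate an NCHW feature map with a timestep embedding.
#   block = TimestepBlock_res(c=16, c_timestep=64, conds=['sca'])
#   x = torch.randn(2, 16, 24, 24)
#   t = torch.randn(2, 128)   # base embedding and the 'sca' embedding concatenated (2 x 64)
#   y = block(x, t)           # -> (2, 16, 24, 24)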
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
class ScaleNormalize_res(nn.Module):
    def __init__(self, c, scale_c, conds=['sca']):
        super().__init__()
        self.c_r = scale_c
        self.mapping = TimestepBlock_res(c, scale_c, conds=conds)
        self.t_conds = conds
        self.alpha = nn.Conv2d(c, c, kernel_size=1)
        self.gamma = nn.Conv2d(c, c, kernel_size=1)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)

    def gen_r_embedding(self, r, max_positions=10000):
        # Sinusoidal embedding of the scale value, dimension self.c_r.
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def forward(self, x, std_size=24*24):
        # The scale value grows with the input resolution relative to the reference
        # std_size (24x24), measured in log space.
        scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size))
        scale_val = torch.ones(x.shape[0]).to(x.device) * scale_val
        scale_val_f = self.gen_r_embedding(scale_val)
        for c in self.t_conds:
            t_cond = torch.zeros_like(scale_val)
            scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1)
        f = self.mapping(x, scale_val_f)
        return f + x
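# Usage sketch: adds a resolution-conditioned modulation as a residual.
# For a 24x24 input the scale value is sqrt(log_576(576)) = 1; larger inputs give values > 1.
#   sn = ScaleNormalize_res(c=16, scale_c=64, conds=[])
#   y = sn(torch.randn(2, 16, 48, 48))   # -> (2, 16, 48, 48)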
class TransInr_withnorm(nn.Module):
    def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]):
        super().__init__()
        self.input_layer = nn.Conv2d(ind, ch, 1)
        self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch)
        # Deeper variant:
        # self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
        # self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3 * f_dim)
        self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
        self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim)

        self.base_params = nn.ParameterDict()
        n_wtokens = 0
        self.wtoken_postfc = nn.ModuleDict()
        self.wtoken_rng = dict()
        for name, shape in self.hyponet.param_shapes.items():
            self.base_params[name] = nn.Parameter(init_wb(shape))
            g = min(n_groups, shape[1])
            assert shape[1] % g == 0
            self.wtoken_postfc[name] = nn.Sequential(
                nn.LayerNorm(f_dim),
                nn.Linear(f_dim, shape[0] - 1),
            )
            self.wtoken_rng[name] = (n_wtokens, n_wtokens + g)
            n_wtokens += g
        self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim))

        self.output_layer = nn.Conv2d(ch, ind, 1)
        self.mapp_t = TimestepBlock_res(ind, time_dim, conds=t_conds)
        self.hr_norm = ScaleNormalize_res(ind, 64, conds=[])
        self.normalize_final = nn.Sequential(
            LayerNorm2d(ind, elementwise_affine=False, eps=1e-6),
        )
        self.toout = nn.Sequential(
            Linear(ind * 2, ind // 4),
            nn.GELU(),
            Linear(ind // 4, ind)
        )
        self.apply(self._init_weights)

        # Low-frequency mask over a 32x32 frequency grid (constructed but not used in forward).
        mask = torch.zeros((1, 1, 32, 32))
        h, w = 32, 32
        center_h, center_w = h // 2, w // 2
        low_freq_h, low_freq_w = h // 4, w // 4
        mask[:, :, center_h - low_freq_h:center_h + low_freq_h, center_w - low_freq_w:center_w + low_freq_w] = 1
        self.mask = mask

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def adain(self, feature_a, feature_b):
        # Match the channel-wise mean/std of feature_b to those of feature_a.
        norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True)
        norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True)
        feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean
        return feature_b

    def forward(self, target_shape, target, dtokens, t_emb):
        hlr, wlr = dtokens.shape[2:]
        original = dtokens
        dtokens = self.input_layer(dtokens)
        dtokens = self.tokenizer(dtokens)
        B = dtokens.shape[0]
        wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B)
        trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1))
        trans_out = trans_out[:, -len(self.wtokens):, :]

        # Turn the transformer's weight tokens into per-sample hyponet parameters.
        params = dict()
        for name, shape in self.hyponet.param_shapes.items():
            wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B)
            w, b = wb[:, :-1, :], wb[:, -1:, :]
            l, r = self.wtoken_rng[name]
            x = self.wtoken_postfc[name](trans_out[:, l:r, :])
            x = x.transpose(-1, -2)  # (B, shape[0] - 1, g)
            w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1)
            wb = torch.cat([w, b], dim=1)
            params[name] = wb

        # Decode the target-resolution feature map with the predicted hyponet.
        coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device)
        coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0])
        self.hyponet.set_params(params)
        ori_up = F.interpolate(original.float(), target_shape[2:])
        hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up
        output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        output = self.mapp_t(output, t_emb)
        output = self.normalize_final(output)
        output = self.hr_norm(output)
        return output
class LayerNorm2d(nn.LayerNorm):
    # Note: this redefinition shadows the LayerNorm2d imported from modules.common_ckpt.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # Apply LayerNorm over the channel dimension of an NCHW tensor.
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    """from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"""
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x
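# Usage sketch: GlobalResponseNorm expects channels-last input, e.g. (B, H, W, C).
#   grn = GlobalResponseNorm(64)
#   y = grn(torch.randn(2, 24, 24, 64))   # -> (2, 24, 24, 64)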
if __name__ == '__main__':
    # Quick shape / parameter-count sanity check. Dimensions are chosen so that
    # ch == f_dim, which the tokenizer/transformer concatenation requires.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trans_inr = TransInr_withnorm(ind=16, ch=64, n_head=4, head_dim=16, n_groups=64,
                                  f_dim=64, time_dim=128, t_conds=[]).to(device)
    target = torch.randn((1, 16, 24, 24), device=device)
    dtokens = torch.randn((1, 16, 16, 16), device=device)
    t_emb = torch.randn((1, 128), device=device)
    output = trans_inr(target.shape, target, dtokens, t_emb)
    total_up = sum(param.nelement() for param in trans_inr.parameters())
    print(output.shape, total_up / 1e6)