# NOTE(review): removed non-source header lines that leaked in from a web
# hosting page ("Spaces:", "Build error", file size, commit hash, and a
# line-number gutter). They were not valid Python and broke the module.
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import os
import math
from timm.models.layers import trunc_normal_
from model.blocks import CBlock_ln, SwinTransformerBlock
from model.global_net import Global_pred
class Local_pred(nn.Module):
    """Local enhancement branch.

    Predicts per-pixel multiplicative (`mul`) and additive (`add`)
    adjustment maps from an RGB image.

    Args:
        dim: channel width of the intermediate feature maps.
        number: number of repeated blocks for the 'ttt' variant.
        type: block layout — 'ccc' (three conv blocks with increasing
            drop-path), 'ttt' (transformer blocks), or 'cct'
            (conv + conv + transformer).
    """
    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution: lift RGB (3 channels) to `dim` channels
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # BUG FIX: width was hard-coded to 16, which broke any
            # non-default `dim`; behavior is unchanged for dim=16.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE(review): the SAME module instance is repeated here, so
            # every position shares one set of weights — confirm intended.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            # NOTE(review): `block` is likewise shared across positions.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        # heads: mul is non-negative (ReLU); add is bounded in [-1, 1] (Tanh)
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        """Return (mul, add) adjustment maps, each of shape (B, 3, H, W)."""
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add
# Short Cut Connection on Final Layer
# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    """Local enhancement branch with residual (shortcut) connections.

    Like `Local_pred`, but the block stacks add their input back
    (identity shortcut) before the output heads, and weights are
    initialized explicitly via `_init_weights`.

    Args:
        in_dim: number of input image channels.
        dim: channel width of the intermediate feature maps.
        number: number of repeated blocks for the 'ttt' variant.
        type: block layout — 'ccc', 'ttt', or 'cct' (see `Local_pred`).
    """
    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution: lift `in_dim` channels to `dim` channels
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # BUG FIX: width was hard-coded to 16, which broke any
            # non-default `dim`; behavior is unchanged for dim=16.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE(review): the SAME module instance is repeated here, so
            # every position shares one set of weights — confirm intended.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            # NOTE(review): `block` is likewise shared across positions.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)
        # heads: mul is non-negative (ReLU); add is bounded in [-1, 1] (Tanh)
        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer applied recursively by `self.apply`.

        Linear: truncated-normal weights (std 0.02), zero bias.
        LayerNorm: unit weight, zero bias.
        Conv2d: He/Kaiming-style normal scaled by fan-out.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        """Return (mul, add) adjustment maps, each of shape (B, 3, H, W)."""
        img1 = self.relu(self.conv1(img))
        # shortcut connection: add the stem features back after each stack
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
class IAT(nn.Module):
    """Illumination Adaptive Transformer.

    Enhances a low-light image with a local per-pixel `mul`/`add`
    correction, optionally followed by a global per-image gamma and
    3x3 colour-matrix correction predicted by `Global_pred`.
    """
    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Apply a 3x3 colour-correction matrix `ccm` to an (..., 3) image
        and clamp the result into (0, 1]."""
        original_shape = image.shape
        flat = image.view(-1, 3)
        corrected = torch.tensordot(flat, ccm, dims=[[-1], [-1]])
        return torch.clamp(corrected.view(original_shape), 1e-8, 1.0)

    def forward(self, img_low):
        """Return the enhanced image, shape (B, C, H, W)."""
        mul, add = self.local_net(img_low)
        img_high = img_low * mul + add
        if self.with_global:
            gamma, color = self.global_net(img_low)
            channels_last = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
            # per-image colour-matrix correction followed by gamma curve
            corrected = [
                self.apply_color(channels_last[i], color[i]) ** gamma[i]
                for i in range(channels_last.shape[0])
            ]
            img_high = torch.stack(corrected, dim=0).permute(0, 3, 1, 2)  # back to (B,C,H,W)
        return img_high
if __name__ == "__main__":
    # Smoke test: build the model and run one forward pass.
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    # BUG FIX: torch.Tensor(1, 3, 400, 600) allocates UNINITIALIZED memory
    # (may contain NaN/garbage); use random data in [0, 1) instead.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    # BUG FIX: IAT.forward returns a single enhanced tensor, not a 3-tuple;
    # the old `_, _, high = net(img)` would raise on unpacking.
    high = net(img)