import torch
import torch.nn as nn
from .base_color import *


class SIGGRAPHGenerator(BaseColor):
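    """SIGGRAPH 2017 interactive colorization generator.

    The input is the image L channel concatenated with optional user ab hints
    and a hint mask (4 channels in total). An encoder-decoder with shortcut
    connections from conv1-conv3 predicts the two ab channels.
    """
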
    def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
        super(SIGGRAPHGenerator, self).__init__()

        # Conv1
        model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]
        # add a subsampling operation

        # Conv2
        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]
        # add a subsampling layer operation

        # Conv3
        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]
        # add a subsampling layer operation

        # Conv4
        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        # Conv5
        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        # Conv6
        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        # Conv7
        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        # Conv8
        model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
        model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        # add the two feature maps above
        model8=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[norm_layer(256),]

        # Conv9
        model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
        model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        # add the two feature maps above
        model9=[nn.ReLU(True),]
        model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model9+=[nn.ReLU(True),]
        model9+=[norm_layer(128),]

        # Conv10
        model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
        model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        # add the two feature maps above
        model10=[nn.ReLU(True),]
        model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
        model10+=[nn.LeakyReLU(negative_slope=.2),]

        # classification output (not used by the forward pass below)
        model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]

        # regression output
        model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
        model_out+=[nn.Tanh()]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8up = nn.Sequential(*model8up)
        self.model8 = nn.Sequential(*model8)
        self.model9up = nn.Sequential(*model9up)
        self.model9 = nn.Sequential(*model9)
        self.model10up = nn.Sequential(*model10up)
        self.model10 = nn.Sequential(*model10)
        self.model3short8 = nn.Sequential(*model3short8)
        self.model2short9 = nn.Sequential(*model2short9)
        self.model1short10 = nn.Sequential(*model1short10)
        self.model_class = nn.Sequential(*model_class)
        self.model_out = nn.Sequential(*model_out)
        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
        self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])

    def forward(self, input_A, input_B=None, mask_B=None):
        # input_A: L channel; input_B: optional ab hints; mask_B: optional hint mask.
        # Missing hints default to zeros (fully automatic colorization).
        if(input_B is None):
            input_B = torch.cat((input_A*0, input_A*0), dim=1)
        if(mask_B is None):
            mask_B = input_A*0

        # encoder; the strided slicing performs the 2x subsampling noted above
        conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
        conv2_2 = self.model2(conv1_2[:,:,::2,::2])
        conv3_3 = self.model3(conv2_2[:,:,::2,::2])
        conv4_3 = self.model4(conv3_3[:,:,::2,::2])
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)

        # decoder with shortcut connections from conv1-conv3
        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
        conv8_3 = self.model8(conv8_up)
        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
        conv9_3 = self.model9(conv9_up)
        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
        conv10_2 = self.model10(conv10_up)
        out_reg = self.model_out(conv10_2)

        return self.unnormalize_ab(out_reg)


def siggraph17(pretrained=True):
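    """Construct the SIGGRAPH 2017 colorization network; with pretrained=True,
    download and load the weights from the URL below.

    Minimal usage sketch (the zero tensor stands in for a real L channel, values
    roughly in [0, 100], with height and width divisible by 8):

        colorizer = siggraph17(pretrained=True)
        colorizer.eval()
        tens_l = torch.zeros(1, 1, 256, 256)
        with torch.no_grad():
            out_ab = colorizer(tens_l)  # predicted ab channels, shape (1, 2, 256, 256)
    """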
    model = SIGGRAPHGenerator()
    if(pretrained):
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth', map_location='cpu', check_hash=True))
    return model