birdortyedi committed on
Commit
2a92dc2
1 Parent(s): 2529c2e

Add application file

app.py ADDED
@@ -0,0 +1,79 @@
+ import requests
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchvision.models as models
+
+ from configs.default import get_cfg_defaults
+ from modeling.build import build_model
+ from utils.data_utils import linear_scaling
+
+
+ # Fetch the pretrained checkpoint once; dl=1 requests the raw file,
+ # whereas dl=0 would return the Dropbox preview page instead of the weights.
+ url = "https://www.dropbox.com/s/uxvax5sjx5iysyl/cifr.pth?dl=1"
+ if not os.path.exists("cifr.pth"):
+     r = requests.get(url, stream=True)
+     with open("cifr.pth", 'wb') as f:
+         for chunk in r.iter_content(chunk_size=8192):
+             f.write(chunk)
+
+ cfg = get_cfg_defaults()
+ cfg.MODEL.CKPT = "cifr.pth"
+ net, _ = build_model(cfg)
+ net = net.eval()
+ vgg16 = models.vgg16(pretrained=True).features.eval()
+
+
+ def load_checkpoints_from_ckpt(ckpt_path):
+     # The demo runs on CPU, so map the CUDA-trained weights accordingly.
+     checkpoints = torch.load(ckpt_path, map_location=torch.device('cpu'))
+     net.load_state_dict(checkpoints["ifr"])
+
+
+ load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
+
+
+ def filter_removal(img):
+     # HWC uint8 image -> NCHW float tensor scaled to [-1, 1]
+     arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
+     arr = torch.tensor(arr).float() / 255.
+     arr = linear_scaling(arr)
+     with torch.no_grad():
+         feat = vgg16(arr)
+         out, _ = net(arr, feat)
+         out = torch.clamp(out, max=1., min=0.)
+     return out.squeeze(0).permute(1, 2, 0).numpy()
+
+
+ title = "Contrastive Instagram Filter Removal (CIFR)"
+ description = "This is the demo for CIFR, a contrastive strategy for removing Instagram filters from fashionable images. " \
+               "To use it, simply upload your filtered image, or click one of the examples to load them."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.07486'>Contrastive Instagram Filter Removal (CIFR)</a> | <a href='https://github.com/birdortyedi/cifr-pytorch'>Github Repo</a></p>"
+
+ gr.Interface(
+     filter_removal,
+     gr.inputs.Image(shape=(256, 256)),
+     gr.outputs.Image(),
+     title=title,
+     description=description,
+     article=article,
+     allow_flagging=False,
+     examples_per_page=17,
+     examples=[
+         ["images/examples/98_He-Fe.jpg"],
+         ["images/examples/2_Brannan.jpg"],
+         ["images/examples/12_Toaster.jpg"],
+         ["images/examples/18_Gingham.jpg"],
+         ["images/examples/11_Sutro.jpg"],
+         ["images/examples/9_Lo-Fi.jpg"],
+         ["images/examples/3_Mayfair.jpg"],
+         ["images/examples/4_Hudson.jpg"],
+         ["images/examples/5_Amaro.jpg"],
+         ["images/examples/6_1977.jpg"],
+         ["images/examples/8_Valencia.jpg"],
+         ["images/examples/16_Lo-Fi.jpg"],
+         ["images/examples/10_Nashville.jpg"],
+         ["images/examples/15_X-ProII.jpg"],
+         ["images/examples/14_Willow.jpg"],
+         ["images/examples/30_Perpetua.jpg"],
+         ["images/examples/1_Clarendon.jpg"],
+     ]
+ ).launch()
configs/__pycache__/default.cpython-37.pyc ADDED
Binary file (2.68 kB).
 
configs/default.py ADDED
@@ -0,0 +1,114 @@
+ from yacs.config import CfgNode as CN
+
+ _C = CN()
+
+ _C.SYSTEM = CN()
+ _C.SYSTEM.NUM_GPU = 2
+ _C.SYSTEM.NUM_WORKERS = 4
+
+ _C.WANDB = CN()
+ _C.WANDB.PROJECT_NAME = "contrastive-style-learning-for-ifr"
+ _C.WANDB.ENTITY = "vvgl-ozu"
+ _C.WANDB.RUN = 3
+ _C.WANDB.LOG_DIR = ""
+ _C.WANDB.NUM_ROW = 0
+
+ _C.TRAIN = CN()
+ _C.TRAIN.NUM_TOTAL_STEP = 200000
+ _C.TRAIN.START_STEP = 0
+ _C.TRAIN.BATCH_SIZE = 16
+ _C.TRAIN.SHUFFLE = True
+ _C.TRAIN.LOG_INTERVAL = 100
+ _C.TRAIN.EVAL_INTERVAL = 1000
+ _C.TRAIN.SAVE_INTERVAL = 1000
+ _C.TRAIN.SAVE_DIR = "./weights"
+ _C.TRAIN.RESUME = True
+ _C.TRAIN.VISUALIZE_INTERVAL = 100
+ _C.TRAIN.TUNE = False
+
+ _C.MODEL = CN()
+ _C.MODEL.NAME = "cifr"
+ _C.MODEL.IS_TRAIN = True
+ _C.MODEL.NUM_CLASS = 17
+ _C.MODEL.CKPT = ""
+ _C.MODEL.PRETRAINED = ""
+
+ _C.MODEL.IFR = CN()
+ _C.MODEL.IFR.NAME = "ContrastiveInstaFilterRemovalNetwork"
+ _C.MODEL.IFR.NUM_CHANNELS = 32
+ _C.MODEL.IFR.DESTYLER_CHANNELS = 32
+ _C.MODEL.IFR.SOLVER = CN()
+ _C.MODEL.IFR.SOLVER.LR = 2e-4
+ _C.MODEL.IFR.SOLVER.BETAS = (0.5, 0.999)
+ _C.MODEL.IFR.SOLVER.SCHEDULER = []
+ _C.MODEL.IFR.SOLVER.DECAY_RATE = 0.
+ _C.MODEL.IFR.DS_FACTOR = 4
+
+ _C.MODEL.PATCH = CN()
+ _C.MODEL.PATCH.NUM_CHANNELS = 256
+ _C.MODEL.PATCH.NUM_PATCHES = 256
+ _C.MODEL.PATCH.NUM_LAYERS = 6
+ _C.MODEL.PATCH.USE_MLP = True
+ _C.MODEL.PATCH.SHUFFLE_Y = True
+ _C.MODEL.PATCH.LR = 1e-4
+ _C.MODEL.PATCH.BETAS = (0.5, 0.999)
+ _C.MODEL.PATCH.T = 0.07
+
+ _C.MODEL.D = CN()
+ _C.MODEL.D.NAME = "1-ChOutputDiscriminator"
+ _C.MODEL.D.NUM_CHANNELS = 32
+ _C.MODEL.D.NUM_CRITICS = 3
+ _C.MODEL.D.SOLVER = CN()
+ _C.MODEL.D.SOLVER.LR = 1e-4
+ _C.MODEL.D.SOLVER.BETAS = (0.5, 0.999)
+ _C.MODEL.D.SOLVER.SCHEDULER = []
+ _C.MODEL.D.SOLVER.DECAY_RATE = 0.01
+
+ _C.ESRGAN = CN()
+ _C.ESRGAN.WEIGHTS = "weights/RealESRGAN_x{}plus.pth"
+
+ _C.FASHIONMASKRCNN = CN()
+ _C.FASHIONMASKRCNN.CFG_PATH = "configs/fashion.yaml"
+ _C.FASHIONMASKRCNN.WEIGHTS = "weights/fashion.pth"
+ _C.FASHIONMASKRCNN.SCORE_THRESH_TEST = 0.6
+ _C.FASHIONMASKRCNN.MIN_SIZE_TEST = 512
+
+ _C.OPTIM = CN()
+ _C.OPTIM.GP = 10.
+ _C.OPTIM.MASK = 1
+ _C.OPTIM.RECON = 1.4
+ _C.OPTIM.SEMANTIC = 1e-1
+ _C.OPTIM.TEXTURE = 2e-1
+ _C.OPTIM.ADVERSARIAL = 1e-3
+ _C.OPTIM.AUX = 0.5
+ _C.OPTIM.CONTRASTIVE = 0.1
+ _C.OPTIM.NLL = 1.0
+
+ _C.DATASET = CN()
+ _C.DATASET.NAME = "IFFI"
+ _C.DATASET.ROOT = "../../Downloads/IFFI-dataset/train"  # "../../Downloads/IFFI-dataset/train"
+ _C.DATASET.TEST_ROOT = "../../Datasets/IFFI-dataset/test"  # "../../Downloads/IFFI-dataset/test"
+ _C.DATASET.DS_TEST_ROOT = "../../Downloads/IFFI-dataset/test/"  # "../../Downloads/IFFI-dataset/test"
+ _C.DATASET.DS_JSON_FILE = "../../Downloads/IFFI-dataset-only-orgs/instances_default.json"
+ _C.DATASET.SIZE = 256
+ _C.DATASET.CROP_SIZE = 512
+ _C.DATASET.MEAN = [0.5, 0.5, 0.5]
+ _C.DATASET.STD = [0.5, 0.5, 0.5]
+
+ _C.TEST = CN()
+ _C.TEST.OUTPUT_DIR = "./outputs"
+ _C.TEST.ABLATION = False
+ _C.TEST.WEIGHTS = ""
+ _C.TEST.BATCH_SIZE = 32
+ _C.TEST.IMG_ID = 52
+
+
+ def get_cfg_defaults():
+     """Get a yacs CfgNode object with default values for my_project."""
+     # Return a clone so that the defaults will not be altered
+     # This is for the "local variable" use pattern
+     return _C.clone()
+
+
+ # provide a way to import the defaults as a global singleton:
+ cfg = _C  # users can `from config import cfg`
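
For reference, a minimal sketch of how these defaults are typically consumed with yacs (not part of this commit; the override values below are illustrative):

    from configs.default import get_cfg_defaults

    cfg = get_cfg_defaults()                         # fresh clone of _C
    cfg.merge_from_list(["MODEL.CKPT", "cifr.pth"])  # override individual keys
    cfg.freeze()                                     # lock against accidental edits
    print(cfg.MODEL.IFR.NUM_CHANNELS)                # 32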
images/examples/10_Nashville.jpg ADDED
images/examples/11_Sutro.jpg ADDED
images/examples/12_Toaster.jpg ADDED
images/examples/14_Willow.jpg ADDED
images/examples/15_X-ProII.jpg ADDED
images/examples/16_Lo-Fi.jpg ADDED
images/examples/18_Gingham.jpg ADDED
images/examples/1_Clarendon.jpg ADDED
images/examples/2_Brannan.jpg ADDED
images/examples/30_Perpetua.jpg ADDED
images/examples/3_Mayfair.jpg ADDED
images/examples/4_Hudson.jpg ADDED
images/examples/5_Amaro.jpg ADDED
images/examples/6_1977.jpg ADDED
images/examples/8_Valencia.jpg ADDED
images/examples/98_He-Fe.jpg ADDED
images/examples/9_Lo-Fi.jpg ADDED
layers/__pycache__/blocks.cpython-37.pyc ADDED
Binary file (2.93 kB).
 
layers/__pycache__/normalization.cpython-37.pyc ADDED
Binary file (859 Bytes).
 
layers/blocks.py ADDED
@@ -0,0 +1,93 @@
+ from torch import nn
+
+ from layers.normalization import AdaIN
+
+
+ class DestyleResBlock(nn.Module):
+     def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
+         super(DestyleResBlock, self).__init__()
+
+         # uses 1x1 convolutions for downsampling
+         if not channels_in or channels_in == channels_out:
+             channels_in = channels_out
+             self.projection = None
+         else:
+             self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
+         self.use_dropout = use_dropout
+
+         self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
+         self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
+         self.adain = AdaIN()
+         if self.use_dropout:
+             self.dropout = nn.Dropout()
+         self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x, feat):
+         residual = x
+         out = self.conv1(x)
+         out = self.lrelu1(out)
+         out = self.conv2(out)
+         _, _, h, w = out.size()
+         out = self.adain(out, feat)
+         if self.use_dropout:
+             out = self.dropout(out)
+         if self.projection:
+             residual = self.projection(x)
+         out = out + residual
+         out = self.lrelu2(out)
+         return out
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
+         super(ResBlock, self).__init__()
+
+         # uses 1x1 convolutions for downsampling
+         if not channels_in or channels_in == channels_out:
+             channels_in = channels_out
+             self.projection = None
+         else:
+             self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
+         self.use_dropout = use_dropout
+
+         self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
+         self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
+         self.n2 = nn.BatchNorm2d(channels_out)
+         if self.use_dropout:
+             self.dropout = nn.Dropout()
+         self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         residual = x
+         out = self.conv1(x)
+         out = self.lrelu1(out)
+         out = self.conv2(out)
+         # out = self.n2(out)
+         if self.use_dropout:
+             out = self.dropout(out)
+         if self.projection:
+             residual = self.projection(x)
+         out = out + residual
+         out = self.lrelu2(out)
+         return out
+
+
+ class Destyler(nn.Module):
+     def __init__(self, in_features, num_features):
+         super(Destyler, self).__init__()
+         self.fc1 = nn.Linear(in_features, num_features)
+         self.fc2 = nn.Linear(num_features, num_features)
+         self.fc3 = nn.Linear(num_features, num_features)
+         self.fc4 = nn.Linear(num_features, num_features)
+         self.fc5 = nn.Linear(num_features, num_features)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.fc2(x)
+         x = self.fc3(x)
+         x = self.fc4(x)
+         x = self.fc5(x)
+         return x
layers/normalization.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+
+
+ class AdaIN(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, y):
+         ch = y.size(1)
+         sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1)
+
+         x_mu = x.mean(dim=[2, 3], keepdim=True)
+         x_sigma = x.std(dim=[2, 3], keepdim=True)
+
+         return sigma * ((x - x_mu) / x_sigma) + mu
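
A minimal shape sketch of this AdaIN variant (illustrative, not part of the commit): the conditioning vector y packs per-channel (sigma, mu) pairs, so it must carry twice as many channels as x.

    import torch
    from layers.normalization import AdaIN

    adain = AdaIN()
    x = torch.rand(2, 64, 32, 32)   # content features (B, C, H, W)
    y = torch.rand(2, 128)          # first 64 values -> sigma, last 64 -> mu
    out = adain(x, y)               # x is normalized per channel, then re-styled
    print(out.shape)                # torch.Size([2, 64, 32, 32])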
modeling/__pycache__/arch.cpython-37.pyc ADDED
Binary file (8.89 kB).
 
modeling/__pycache__/base.cpython-37.pyc ADDED
Binary file (2.66 kB).
 
modeling/__pycache__/build.cpython-37.pyc ADDED
Binary file (1.22 kB).
 
modeling/arch.py ADDED
@@ -0,0 +1,272 @@
+ import torch
+ from torch import nn
+ from torch.nn.utils import spectral_norm
+
+ from modeling.base import BaseNetwork
+ from layers.blocks import DestyleResBlock, Destyler, ResBlock
+
+
+ class IFRNet(BaseNetwork):
+     def __init__(self, base_n_channels, destyler_n_channels):
+         super(IFRNet, self).__init__()
+         self.destyler = Destyler(in_features=32768, num_features=destyler_n_channels)  # from vgg features
+
+         self.ds_fc1 = nn.Linear(destyler_n_channels, base_n_channels * 2)
+         self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
+         self.ds_fc2 = nn.Linear(destyler_n_channels, base_n_channels * 4)
+         self.ds_res2 = DestyleResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
+         self.ds_fc3 = nn.Linear(destyler_n_channels, base_n_channels * 4)
+         self.ds_res3 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+         self.ds_fc4 = nn.Linear(destyler_n_channels, base_n_channels * 8)
+         self.ds_res4 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
+         self.ds_fc5 = nn.Linear(destyler_n_channels, base_n_channels * 8)
+         self.ds_res5 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+         self.ds_fc6 = nn.Linear(destyler_n_channels, base_n_channels * 16)
+         self.ds_res6 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)
+
+         self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
+
+         self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+         self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+         self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+         self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+         self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)
+
+         self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)
+
+         self.init_weights(init_type="normal", gain=0.02)
+
+     def forward(self, x, vgg_feat):
+         b_size, ch, h, w = vgg_feat.size()
+         vgg_feat = vgg_feat.view(b_size, ch * h * w)
+         vgg_feat = self.destyler(vgg_feat)
+
+         out = self.ds_res1(x, self.ds_fc1(vgg_feat))
+         out = self.ds_res2(out, self.ds_fc2(vgg_feat))
+         out = self.ds_res3(out, self.ds_fc3(vgg_feat))
+         out = self.ds_res4(out, self.ds_fc4(vgg_feat))
+         out = self.ds_res5(out, self.ds_fc5(vgg_feat))
+         aux = self.ds_res6(out, self.ds_fc6(vgg_feat))
+
+         out = self.upsample(aux)
+         out = self.res1(out)
+         out = self.res2(out)
+         out = self.upsample(out)
+         out = self.res3(out)
+         out = self.res4(out)
+         out = self.upsample(out)
+         out = self.res5(out)
+         out = self.conv1(out)
+
+         return out, aux
+
+
+ class CIFR_Encoder(IFRNet):
+     def __init__(self, base_n_channels, destyler_n_channels):
+         super(CIFR_Encoder, self).__init__(base_n_channels, destyler_n_channels)
+
+     def forward(self, x, vgg_feat):
+         b_size, ch, h, w = vgg_feat.size()
+         vgg_feat = vgg_feat.view(b_size, ch * h * w)
+         vgg_feat = self.destyler(vgg_feat)
+
+         feat1 = self.ds_res1(x, self.ds_fc1(vgg_feat))
+         feat2 = self.ds_res2(feat1, self.ds_fc2(vgg_feat))
+         feat3 = self.ds_res3(feat2, self.ds_fc3(vgg_feat))
+         feat4 = self.ds_res4(feat3, self.ds_fc4(vgg_feat))
+         feat5 = self.ds_res5(feat4, self.ds_fc5(vgg_feat))
+         feat6 = self.ds_res6(feat5, self.ds_fc6(vgg_feat))
+
+         feats = [feat1, feat2, feat3, feat4, feat5, feat6]
+
+         out = self.upsample(feat6)
+         out = self.res1(out)
+         out = self.res2(out)
+         out = self.upsample(out)
+         out = self.res3(out)
+         out = self.res4(out)
+         out = self.upsample(out)
+         out = self.res5(out)
+         out = self.conv1(out)
+
+         return out, feats
+
+
+ class Normalize(nn.Module):
+     def __init__(self, power=2):
+         super(Normalize, self).__init__()
+         self.power = power
+
+     def forward(self, x):
+         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
+         out = x.div(norm + 1e-7)
+         return out
+
+
+ class PatchSampleF(BaseNetwork):
+     def __init__(self, base_n_channels, style_or_content, use_mlp=False, nc=256):
+         # potential issues: currently, we use the same patch_ids for multiple images in the batch
+         super(PatchSampleF, self).__init__()
+         self.is_content = True if style_or_content == "content" else False
+         self.l2norm = Normalize(2)
+         self.use_mlp = use_mlp
+         self.nc = nc  # hard-coded
+
+         self.mlp_0 = nn.Sequential(*[nn.Linear(base_n_channels, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.mlp_1 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.mlp_2 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.mlp_3 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.mlp_4 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.mlp_5 = nn.Sequential(*[nn.Linear(base_n_channels * 8, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
+         self.init_weights(init_type="normal", gain=0.02)
+
+     @staticmethod
+     def gram_matrix(x):
+         # a, b, c, d = x.size()  # a=batch size(=1)
+         a, b = x.size()
+         # b=number of feature maps
+         # (c,d)=dimensions of a f. map (N=c*d)
+
+         # features = x.view(a * b, c * d)  # resize F_XL into \hat F_XL
+
+         G = torch.mm(x, x.t())  # compute the gram product
+
+         # we 'normalize' the values of the gram matrix
+         # by dividing by the number of elements in each feature map.
+         return G.div(a * b)
+
+     def forward(self, feats, num_patches=64, patch_ids=None):
+         return_ids = []
+         return_feats = []
+
+         for feat_id, feat in enumerate(feats):
+             B, C, H, W = feat.shape
+             feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
+             if num_patches > 0:
+                 if patch_ids is not None:
+                     patch_id = patch_ids[feat_id]
+                 else:
+                     patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
+                     patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]  # .to(patch_ids.device)
+                 x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # reshape(-1, x.shape[1])
+             else:
+                 x_sample = feat_reshape
+                 patch_id = []
+             if self.use_mlp:
+                 mlp = getattr(self, 'mlp_%d' % feat_id)
+                 x_sample = mlp(x_sample)
+             if not self.is_content:
+                 x_sample = self.gram_matrix(x_sample)
+             return_ids.append(patch_id)
+             x_sample = self.l2norm(x_sample)
+
+             if num_patches == 0:
+                 x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
+             return_feats.append(x_sample)
+         return return_feats, return_ids
+
+
+ class MLP(nn.Module):
+     def __init__(self, base_n_channels, out_features=14):
+         super(MLP, self).__init__()
+         self.aux_classifier = nn.Sequential(
+             nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(2),
+             nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(2),
+             # nn.Conv2d(base_n_channels * 2, base_n_channels * 1, kernel_size=3, stride=1, padding=1),
+             # nn.MaxPool2d(2),
+             Flatten(),
+             nn.Linear(base_n_channels * 8 * 8 * 2, out_features),
+             # nn.Softmax(dim=-1)
+         )
+
+     def forward(self, x):
+         return self.aux_classifier(x)
+
+
+ class Flatten(nn.Module):
+     def forward(self, input):
+         """
+         Note that input.size(0) is usually the batch size.
+         Given any input with input.size(0) batches, this flattens
+         each sample to a 1-D vector of nb_elements.
+         """
+         batch_size = input.size(0)
+         out = input.view(batch_size, -1)
+         return out  # (batch_size, *size)
+
+
+ class Discriminator(BaseNetwork):
+     def __init__(self, base_n_channels):
+         """
+         img_size : (int, int, int)
+             Height and width must be powers of 2. E.g. (32, 32, 1) or
+             (64, 128, 3). The last number indicates the number of channels,
+             e.g. 1 for grayscale or 3 for RGB.
+         """
+         super(Discriminator, self).__init__()
+
+         self.image_to_features = nn.Sequential(
+             spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
+             nn.LeakyReLU(0.2, inplace=True),
+             spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
+             nn.LeakyReLU(0.2, inplace=True),
+             spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
+             nn.LeakyReLU(0.2, inplace=True),
+             spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
+             nn.LeakyReLU(0.2, inplace=True),
+             # spectral_norm(nn.Conv2d(4 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
+             # nn.LeakyReLU(0.2, inplace=True),
+             spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
+             nn.LeakyReLU(0.2, inplace=True),
+         )
+
+         output_size = 8 * base_n_channels * 3 * 3
+         self.features_to_prob = nn.Sequential(
+             spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
+             Flatten(),
+             nn.Linear(output_size, 1)
+         )
+
+         self.init_weights(init_type="normal", gain=0.02)
+
+     def forward(self, input_data):
+         x = self.image_to_features(input_data)
+         return self.features_to_prob(x)
+
+
+ class PatchDiscriminator(Discriminator):
+     def __init__(self, base_n_channels):
+         super(PatchDiscriminator, self).__init__(base_n_channels)
+
+         self.features_to_prob = nn.Sequential(
+             spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
+             Flatten()
+         )
+
+     def forward(self, input_data):
+         x = self.image_to_features(input_data)
+         return self.features_to_prob(x)
+
+
+ if __name__ == '__main__':
+     import torchvision
+     ifrnet = CIFR_Encoder(32, 128).cuda()
+     x = torch.rand((2, 3, 256, 256)).cuda()
+     vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda()
+     with torch.no_grad():
+         vgg_feat = vgg16(x)
+     output, feats = ifrnet(x, vgg_feat)
+     print(output.size())
+     for i, feat in enumerate(feats):
+         print(i, feat.size())
+
+     disc = Discriminator(32).cuda()
+     d_out = disc(output)
+     print(d_out.size())
+
+     patch_disc = PatchDiscriminator(32).cuda()
+     p_d_out = patch_disc(output)
+     print(p_d_out.size())
modeling/base.py ADDED
@@ -0,0 +1,60 @@
+ from torch import nn
+
+
+ class BaseNetwork(nn.Module):
+     def __init__(self):
+         super(BaseNetwork, self).__init__()
+
+     def forward(self, x, y):
+         pass
+
+     def print_network(self):
+         if isinstance(self, list):
+             self = self[0]
+         num_params = 0
+         for param in self.parameters():
+             num_params += param.numel()
+         print('Network [%s] was created. Total number of parameters: %.1f million. '
+               'To see the architecture, do print(network).'
+               % (type(self).__name__, num_params / 1000000))
+
+     def set_requires_grad(self, requires_grad=False):
+         """Set requires_grad=False for all the networks to avoid unnecessary computations
+         Parameters:
+             requires_grad (bool) -- whether the networks require gradients or not
+         """
+         for param in self.parameters():
+             param.requires_grad = requires_grad
+
+     def init_weights(self, init_type='xavier', gain=0.02):
+         def init_func(m):
+             classname = m.__class__.__name__
+             if classname.find('BatchNorm2d') != -1:
+                 if hasattr(m, 'weight') and m.weight is not None:
+                     nn.init.normal_(m.weight.data, 1.0, gain)
+                 if hasattr(m, 'bias') and m.bias is not None:
+                     nn.init.constant_(m.bias.data, 0.0)
+             elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+                 if init_type == 'normal':
+                     nn.init.normal_(m.weight.data, 0.0, gain)
+                 elif init_type == 'xavier':
+                     nn.init.xavier_normal_(m.weight.data, gain=gain)
+                 elif init_type == 'xavier_uniform':
+                     nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+                 elif init_type == 'kaiming':
+                     nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                 elif init_type == 'orthogonal':
+                     nn.init.orthogonal_(m.weight.data, gain=gain)
+                 elif init_type == 'none':  # uses pytorch's default init method
+                     m.reset_parameters()
+                 else:
+                     raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+                 if hasattr(m, 'bias') and m.bias is not None:
+                     nn.init.constant_(m.bias.data, 0.0)
+
+         self.apply(init_func)
+
+         # propagate to children
+         for m in self.children():
+             if hasattr(m, 'init_weights'):
+                 m.init_weights(init_type, gain)
modeling/build.py ADDED
@@ -0,0 +1,25 @@
+ from modeling.arch import IFRNet, CIFR_Encoder, Discriminator, PatchDiscriminator, MLP, PatchSampleF
+
+
+ def build_model(args):
+     if args.MODEL.NAME.lower() == "ifrnet":
+         net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
+         mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, out_features=args.MODEL.NUM_CLASS)
+     elif args.MODEL.NAME.lower() == "cifr":
+         net = CIFR_Encoder(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
+         mlp = None
+     elif args.MODEL.NAME.lower() == "ifr-no-aux":
+         net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
+         mlp = None
+     else:
+         raise NotImplementedError
+     return net, mlp
+
+
+ def build_discriminators(args):
+     return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS)
+
+
+ def build_patch_sampler(args):
+     return PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="content", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS), \
+            PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="style", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==2.9.4
+ numpy==1.21.2
+ requests==2.27.1
+ torch==1.10.1
+ torchvision==0.11.2
+ yacs==0.1.8
utils/__pycache__/data_utils.cpython-37.pyc ADDED
Binary file (1.07 kB).
 
utils/data_utils.py ADDED
@@ -0,0 +1,29 @@
+ import json
+
+
+ def linear_scaling(x):
+     return (x * 255.) / 127.5 - 1.
+
+
+ def linear_unscaling(x):
+     return (x + 1.) * 127.5 / 255.
+
+
+ def read_json(path):
+     """
+     :param path (str or os.Path): JSON file path.
+     :return: (Dict): the data in the JSON file.
+     """
+     with open(path) as f:
+         data = json.load(f)
+     return data
+
+
+ def write_json(path, datagroup):
+     """
+     :param path (str or os.Path): file path for the output JSON file.
+     :param datagroup (Dict): the data which should be dumped to the JSON file.
+     :return: void.
+     """
+     with open(path, "w+") as f:
+         json.dump(datagroup, f)
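
As a quick sanity check (illustrative, not part of the commit), linear_scaling maps a [0, 1] tensor to [-1, 1] and linear_unscaling inverts it:

    import torch
    from utils.data_utils import linear_scaling, linear_unscaling

    x = torch.tensor([0.0, 0.5, 1.0])
    y = linear_scaling(x)                          # tensor([-1., 0., 1.])
    assert torch.allclose(linear_unscaling(y), x)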