AlekseyKorshuk committed
Commit 1cae80b
1 Parent(s): 61020b4

First commit
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Instagram Filter Removal
-emoji: 🐨
-colorFrom: purple
-colorTo: pink
+emoji: 👀
+colorFrom: gray
+colorTo: green
 sdk: gradio
 app_file: app.py
 pinned: false
@@ -10,28 +10,28 @@ pinned: false
 
 # Configuration
 
 `title`: _string_
 Display title for the Space
 
 `emoji`: _string_
 Space emoji (emoji-only character allowed)
 
 `colorFrom`: _string_
 Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
 
 `colorTo`: _string_
 Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
 
 `sdk`: _string_
-Can be either `gradio`, `streamlit`, or `static`
+Can be either `gradio` or `streamlit`
 
 `sdk_version` : _string_
 Only applicable for `streamlit` SDK.
 See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
 
 `app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
 Path is relative to the root of the repository.
 
 `pinned`: _boolean_
 Whether the Space stays on top of your list.
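For reference, the Space front matter produced by this commit, reconstructed from the diff above:

```yaml
---
title: Instagram Filter Removal
emoji: 👀
colorFrom: gray
colorTo: green
sdk: gradio
app_file: app.py
pinned: false
---
```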
app.py ADDED
@@ -0,0 +1,79 @@
+import requests
+import os
+import gradio as gr
+import numpy as np
+import torch
+import torchvision.models as models
+
+from configs.default import get_cfg_defaults
+from modeling.build import build_model
+from utils.data_utils import linear_scaling
+
+
+url = "https://www.dropbox.com/s/y97z812sxa1kvrg/ifrnet.pth?dl=1"
+if not os.path.exists("ifrnet.pth"):  # download the pretrained checkpoint only once
+    r = requests.get(url, stream=True)
+    with open("ifrnet.pth", 'wb') as f:
+        for data in r:
+            f.write(data)
+
+cfg = get_cfg_defaults()
+cfg.MODEL.CKPT = "ifrnet.pth"
+net, _ = build_model(cfg)
+net = net.eval()
+vgg16 = models.vgg16(pretrained=True).features.eval()
+
+
+def load_checkpoints_from_ckpt(ckpt_path):
+    checkpoints = torch.load(ckpt_path, map_location=torch.device('cpu'))
+    net.load_state_dict(checkpoints["ifr"])
+
+
+load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
+
+
+def filter_removal(img):
+    arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # HWC uint8 -> NCHW
+    arr = torch.tensor(arr).float() / 255.
+    arr = linear_scaling(arr)  # [0, 1] -> [-1, 1]
+    with torch.no_grad():
+        feat = vgg16(arr)
+        out, _ = net(arr, feat)
+        out = torch.clamp(out, max=1., min=0.)
+    return out.squeeze(0).permute(1, 2, 0).numpy()
+
+
+title = "Instagram Filter Removal on Fashionable Images"
+description = "This is the demo for IFRNet, which removes Instagram filters from fashionable images. " \
+              "To use it, simply upload your filtered image, or click one of the examples to load it."
+article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Kinli_Instagram_Filter_Removal_on_Fashionable_Images_CVPRW_2021_paper.pdf'>Paper</a> | <a href='https://github.com/birdortyedi/instagram-filter-removal-pytorch'>GitHub</a></p>"
+
+gr.Interface(
+    filter_removal,
+    gr.inputs.Image(shape=(256, 256)),
+    gr.outputs.Image(),
+    title=title,
+    description=description,
+    article=article,
+    allow_flagging=False,
+    examples_per_page=17,
+    examples=[
+        ["images/examples/98_He-Fe.jpg"],
+        ["images/examples/2_Brannan.jpg"],
+        ["images/examples/12_Toaster.jpg"],
+        ["images/examples/18_Gingham.jpg"],
+        ["images/examples/11_Sutro.jpg"],
+        ["images/examples/9_Lo-Fi.jpg"],
+        ["images/examples/3_Mayfair.jpg"],
+        ["images/examples/4_Hudson.jpg"],
+        ["images/examples/5_Amaro.jpg"],
+        ["images/examples/6_1977.jpg"],
+        ["images/examples/8_Valencia.jpg"],
+        ["images/examples/16_Lo-Fi.jpg"],
+        ["images/examples/10_Nashville.jpg"],
+        ["images/examples/15_X-ProII.jpg"],
+        ["images/examples/14_Willow.jpg"],
+        ["images/examples/30_Perpetua.jpg"],
+        ["images/examples/1_Clarendon.jpg"],
+    ]
+).launch()
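A minimal sketch of exercising `filter_removal` outside the UI, run in the same module as the function (note that importing `app` would also start the demo, since `launch()` executes at import time); `unfiltered.png` is a hypothetical output path, and `Image` comes from Pillow, which `requirements.txt` already pins:

```python
import numpy as np
from PIL import Image

# Mimic the Gradio input: a 256x256 RGB uint8 array.
img = np.array(Image.open("images/examples/2_Brannan.jpg").convert("RGB").resize((256, 256)))
out = filter_removal(img)  # float HWC array with values in [0, 1]
Image.fromarray((out * 255).astype(np.uint8)).save("unfiltered.png")
```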
config.yaml ADDED
@@ -0,0 +1,98 @@
+wandb_version: 1
+
+DATASET:
+  desc: null
+  value:
+    MEAN:
+    - 0.5
+    - 0.5
+    - 0.5
+    NAME: IFFI
+    ROOT: ../../Downloads/IFFI-dataset/train
+    SIZE: 256
+    STD:
+    - 0.5
+    - 0.5
+    - 0.5
+    TEST_ROOT: ../../Downloads/IFFI-dataset/test
+MODEL:
+  desc: null
+  value:
+    D:
+      NAME: 1-ChOutputDiscriminator
+      NUM_CHANNELS: 32
+      NUM_CRITICS: 5
+      SOLVER:
+        BETAS:
+        - 0.5
+        - 0.9
+        DECAY_RATE: 0.5
+        LR: 0.001
+        SCHEDULER: []
+    IFR:
+      DESTYLER_CHANNELS: 32
+      NAME: InstaFilterRemovalNetwork
+      NUM_CHANNELS: 32
+      SOLVER:
+        BETAS:
+        - 0.5
+        - 0.9
+        DECAY_RATE: 0
+        LR: 0.0002
+        SCHEDULER: []
+    IS_TRAIN: true
+    NAME: ifrnet
+    NUM_CLASS: 17
+OPTIM:
+  desc: null
+  value:
+    ADVERSARIAL: 0.001
+    AUX: 0.5
+    GP: 10
+    MASK: 1
+    RECON: 1.4
+    SEMANTIC: 0.0001
+    TEXTURE: 0.001
+SYSTEM:
+  desc: null
+  value:
+    NUM_GPU: 2
+    NUM_WORKERS: 4
+TEST:
+  desc: null
+  value:
+    ABLATION: false
+    BATCH_SIZE: 64
+    IMG_ID: 52
+    OUTPUT_DIR: ./outputs
+    WEIGHTS: ''
+TRAIN:
+  desc: null
+  value:
+    BATCH_SIZE: 8
+    IS_TRAIN: true
+    LOG_INTERVAL: 100
+    NUM_TOTAL_STEP: 120000
+    RESUME: true
+    SAVE_DIR: ./weights
+    SAVE_INTERVAL: 1000
+    SHUFFLE: true
+    START_STEP: 0
+    TUNE: false
+    VISUALIZE_INTERVAL: 100
+WANDB:
+  desc: null
+  value:
+    ENTITY: vvgl-ozu
+    LOG_DIR: ./logs/ifrnet_IFFI_120000step_8bs_0.0002lr_2gpu_9run
+    NUM_ROW: 0
+    PROJECT_NAME: instagram-filter-removal
+    RUN: 9
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.9.1
+    framework: torch
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    python_version: 3.6.9
configs/default.py ADDED
@@ -0,0 +1,88 @@
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+_C.SYSTEM = CN()
+_C.SYSTEM.NUM_GPU = 2
+_C.SYSTEM.NUM_WORKERS = 4
+
+_C.WANDB = CN()
+_C.WANDB.PROJECT_NAME = "instagram-filter-removal"
+_C.WANDB.ENTITY = "vvgl-ozu"
+_C.WANDB.RUN = 12
+_C.WANDB.LOG_DIR = ""
+_C.WANDB.NUM_ROW = 0
+
+_C.TRAIN = CN()
+_C.TRAIN.NUM_TOTAL_STEP = 120000
+_C.TRAIN.START_STEP = 0
+_C.TRAIN.BATCH_SIZE = 8
+_C.TRAIN.SHUFFLE = True
+_C.TRAIN.LOG_INTERVAL = 100
+_C.TRAIN.SAVE_INTERVAL = 5000
+_C.TRAIN.SAVE_DIR = "./weights"
+_C.TRAIN.RESUME = True
+_C.TRAIN.VISUALIZE_INTERVAL = 100
+_C.TRAIN.TUNE = False
+
+_C.MODEL = CN()
+_C.MODEL.NAME = "ifr-no-aux"
+_C.MODEL.IS_TRAIN = True
+_C.MODEL.NUM_CLASS = 17
+_C.MODEL.CKPT = ""
+
+_C.MODEL.IFR = CN()
+_C.MODEL.IFR.NAME = "InstaFilterRemovalNetwork"
+_C.MODEL.IFR.NUM_CHANNELS = 32
+_C.MODEL.IFR.DESTYLER_CHANNELS = 32
+_C.MODEL.IFR.SOLVER = CN()
+_C.MODEL.IFR.SOLVER.LR = 2e-4
+_C.MODEL.IFR.SOLVER.BETAS = (0.5, 0.9)
+_C.MODEL.IFR.SOLVER.SCHEDULER = []
+_C.MODEL.IFR.SOLVER.DECAY_RATE = 0.
+
+_C.MODEL.D = CN()
+_C.MODEL.D.NAME = "1-ChOutputDiscriminator"
+_C.MODEL.D.NUM_CHANNELS = 32
+_C.MODEL.D.NUM_CRITICS = 5
+_C.MODEL.D.SOLVER = CN()
+_C.MODEL.D.SOLVER.LR = 1e-3
+_C.MODEL.D.SOLVER.BETAS = (0.5, 0.9)
+_C.MODEL.D.SOLVER.SCHEDULER = []
+_C.MODEL.D.SOLVER.DECAY_RATE = 0.5
+
+_C.OPTIM = CN()
+_C.OPTIM.GP = 10
+_C.OPTIM.MASK = 1
+_C.OPTIM.RECON = 1.4
+_C.OPTIM.SEMANTIC = 1e-4
+_C.OPTIM.TEXTURE = 1e-3
+_C.OPTIM.ADVERSARIAL = 1e-3
+_C.OPTIM.AUX = 0.5
+
+_C.DATASET = CN()
+_C.DATASET.NAME = "IFFI"  # "IFFI" # "DIV2K?" #
+_C.DATASET.ROOT = "../../Datasets/IFFI-dataset/train"  # "../../Datasets/IFFI-dataset" # "/media/birdortyedi/e5042b8f-ca5e-4a22-ac68-7e69ff648bc4/IFFI-dataset"
+_C.DATASET.TEST_ROOT = "../../Datasets/IFFI-dataset"
+_C.DATASET.SIZE = 256
+_C.DATASET.CROP_SIZE = 512
+_C.DATASET.MEAN = [0.5, 0.5, 0.5]
+_C.DATASET.STD = [0.5, 0.5, 0.5]
+
+_C.TEST = CN()
+_C.TEST.OUTPUT_DIR = "./outputs"
+_C.TEST.ABLATION = False
+_C.TEST.WEIGHTS = ""
+_C.TEST.BATCH_SIZE = 64
+_C.TEST.IMG_ID = 52
+
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    return _C.clone()
+
+
+# provide a way to import the defaults as a global singleton:
+cfg = _C  # users can `from config import cfg`
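For context, a short sketch of overriding these defaults at runtime with standard `yacs` `CfgNode` methods (`merge_from_list`, `freeze`):

```python
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()                      # fresh clone of _C
cfg.merge_from_list(["MODEL.NAME", "ifrnet",  # alternating key/value overrides
                     "TRAIN.BATCH_SIZE", 16])
cfg.freeze()                                  # make the config read-only
print(cfg.MODEL.IFR.SOLVER.LR)                # 0.0002
```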
images/examples/10_Nashville.jpg ADDED
images/examples/11_Sutro.jpg ADDED
images/examples/12_Toaster.jpg ADDED
images/examples/14_Willow.jpg ADDED
images/examples/15_X-ProII.jpg ADDED
images/examples/16_Lo-Fi.jpg ADDED
images/examples/18_Gingham.jpg ADDED
images/examples/1_Clarendon.jpg ADDED
images/examples/2_Brannan.jpg ADDED
images/examples/30_Perpetua.jpg ADDED
images/examples/3_Mayfair.jpg ADDED
images/examples/4_Hudson.jpg ADDED
images/examples/5_Amaro.jpg ADDED
images/examples/6_1977.jpg ADDED
images/examples/8_Valencia.jpg ADDED
images/examples/98_He-Fe.jpg ADDED
images/examples/9_Lo-Fi.jpg ADDED
modeling/base.py ADDED
@@ -0,0 +1,60 @@
+from torch import nn
+
+
+class BaseNetwork(nn.Module):
+    def __init__(self):
+        super(BaseNetwork, self).__init__()
+
+    def forward(self, x, y):
+        pass
+
+    def print_network(self):
+        if isinstance(self, list):
+            self = self[0]
+        num_params = 0
+        for param in self.parameters():
+            num_params += param.numel()
+        print('Network [%s] was created. Total number of parameters: %.1f million. '
+              'To see the architecture, do print(network).'
+              % (type(self).__name__, num_params / 1000000))
+
+    def set_requires_grad(self, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            requires_grad (bool) -- whether the networks require gradients or not
+        """
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def init_weights(self, init_type='xavier', gain=0.02):
+        def init_func(m):
+            classname = m.__class__.__name__
+            if classname.find('BatchNorm2d') != -1:
+                if hasattr(m, 'weight') and m.weight is not None:
+                    nn.init.normal_(m.weight.data, 1.0, gain)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.constant_(m.bias.data, 0.0)
+            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+                if init_type == 'normal':
+                    nn.init.normal_(m.weight.data, 0.0, gain)
+                elif init_type == 'xavier':
+                    nn.init.xavier_normal_(m.weight.data, gain=gain)
+                elif init_type == 'xavier_uniform':
+                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+                elif init_type == 'kaiming':
+                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                elif init_type == 'orthogonal':
+                    nn.init.orthogonal_(m.weight.data, gain=gain)
+                elif init_type == 'none':  # uses pytorch's default init method
+                    m.reset_parameters()
+                else:
+                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.constant_(m.bias.data, 0.0)
+
+        self.apply(init_func)
+
+        # propagate to children
+        for m in self.children():
+            if hasattr(m, 'init_weights'):
+                m.init_weights(init_type, gain)
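A hedged sketch of how `BaseNetwork`'s helpers are typically used in the adversarial setup implied by the rest of this commit (the actual training script is not part of this commit):

```python
from modeling.ifrnet import Discriminator

disc = Discriminator(base_n_channels=32)
disc.print_network()           # reports the parameter count in millions
disc.set_requires_grad(False)  # freeze the critic during the generator step
disc.set_requires_grad(True)   # unfreeze it for the critic step
```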
modeling/benchmark.py ADDED
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from modeling.base import BaseNetwork
+from modules.blocks import ResBlock
+
+
+class UNet(BaseNetwork):
+    def __init__(self, base_n_channels):
+        super(UNet, self).__init__()
+
+        self.ds_res1 = ResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
+        self.ds_res2 = ResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
+        self.ds_res3 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.ds_res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
+        self.ds_res5 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.ds_res6 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)
+
+        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
+
+        self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)
+
+        self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)
+
+        self.init_weights(init_type="normal", gain=0.02)
+
+    def forward(self, x):
+        out = self.ds_res1(x)
+        out = self.ds_res2(out)
+        out = self.ds_res3(out)
+        out = self.ds_res4(out)
+        out = self.ds_res5(out)
+        aux = self.ds_res6(out)
+
+        out = self.upsample(aux)
+        out = self.res1(out)
+        out = self.res2(out)
+        out = self.upsample(out)
+        out = self.res3(out)
+        out = self.res4(out)
+        out = self.upsample(out)
+        out = self.res5(out)
+        out = self.conv1(out)
+
+        return out, aux
+
+
+if __name__ == '__main__':
+    # Benchmark model: unlike IFRNet, UNet takes no VGG feature input.
+    x = torch.rand((2, 3, 256, 256)).cuda()
+    unet = UNet(32).cuda()
+    with torch.no_grad():
+        out, aux = unet(x)
+
+    print(out.size())
+    print(aux.size())
modeling/build.py ADDED
@@ -0,0 +1,19 @@
+from modeling.ifrnet import IFRNet, Discriminator, PatchDiscriminator, MLP
+from modeling.benchmark import UNet
+
+
+def build_model(args):
+    if args.MODEL.NAME.lower() == "ifrnet":
+        net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
+        mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, num_class=args.MODEL.NUM_CLASS)
+    elif args.MODEL.NAME.lower() == "ifr-no-aux":
+        net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
+        mlp = None
+    else:
+        raise NotImplementedError
+    return net, mlp
+
+
+def build_discriminators(args):
+    return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS)
+
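A minimal usage sketch for the factories above, combined with the defaults from `configs/default.py`; with the default `MODEL.NAME` of `"ifr-no-aux"`, the auxiliary classifier comes back as `None`:

```python
from configs.default import get_cfg_defaults
from modeling.build import build_model, build_discriminators

cfg = get_cfg_defaults()
net, mlp = build_model(cfg)                   # mlp is None for "ifr-no-aux"
disc, patch_disc = build_discriminators(cfg)
print(type(net).__name__, mlp)                # IFRNet None
```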
modeling/ifrnet.py ADDED
@@ -0,0 +1,166 @@
+import torch
+from torch import nn
+from torch.nn.utils import spectral_norm
+
+from modeling.base import BaseNetwork
+from modules.blocks import DestyleResBlock, Destyler, ResBlock
+
+
+class IFRNet(BaseNetwork):
+    def __init__(self, base_n_channels, destyler_n_channels):
+        super(IFRNet, self).__init__()
+        self.destyler = Destyler(in_features=32768, num_features=destyler_n_channels)  # from vgg features
+
+        self.ds_fc1 = nn.Linear(destyler_n_channels, base_n_channels * 2)
+        self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
+        self.ds_fc2 = nn.Linear(destyler_n_channels, base_n_channels * 4)
+        self.ds_res2 = DestyleResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
+        self.ds_fc3 = nn.Linear(destyler_n_channels, base_n_channels * 4)
+        self.ds_res3 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.ds_fc4 = nn.Linear(destyler_n_channels, base_n_channels * 8)
+        self.ds_res4 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
+        self.ds_fc5 = nn.Linear(destyler_n_channels, base_n_channels * 8)
+        self.ds_res5 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.ds_fc6 = nn.Linear(destyler_n_channels, base_n_channels * 16)
+        self.ds_res6 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)
+
+        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
+
+        self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
+        self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
+        self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)
+
+        self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)
+
+        self.init_weights(init_type="normal", gain=0.02)
+
+    def forward(self, x, vgg_feat):
+        b_size, ch, h, w = vgg_feat.size()
+        vgg_feat = vgg_feat.view(b_size, ch * h * w)
+        vgg_feat = self.destyler(vgg_feat)
+
+        out = self.ds_res1(x, self.ds_fc1(vgg_feat))
+        out = self.ds_res2(out, self.ds_fc2(vgg_feat))
+        out = self.ds_res3(out, self.ds_fc3(vgg_feat))
+        out = self.ds_res4(out, self.ds_fc4(vgg_feat))
+        out = self.ds_res5(out, self.ds_fc5(vgg_feat))
+        aux = self.ds_res6(out, self.ds_fc6(vgg_feat))
+
+        out = self.upsample(aux)
+        out = self.res1(out)
+        out = self.res2(out)
+        out = self.upsample(out)
+        out = self.res3(out)
+        out = self.res4(out)
+        out = self.upsample(out)
+        out = self.res5(out)
+        out = self.conv1(out)
+
+        return out, aux
+
+
+class MLP(nn.Module):
+    def __init__(self, base_n_channels, num_class=14):
+        super(MLP, self).__init__()
+        self.aux_classifier = nn.Sequential(
+            nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
+            nn.MaxPool2d(2),
+            nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
+            nn.MaxPool2d(2),
+            # nn.Conv2d(base_n_channels * 2, base_n_channels * 1, kernel_size=3, stride=1, padding=1),
+            # nn.MaxPool2d(2),
+            Flatten(),
+            nn.Linear(base_n_channels * 8 * 8 * 2, num_class),
+            # nn.Softmax(dim=-1)
+        )
+
+    def forward(self, x):
+        return self.aux_classifier(x)
+
+
+class Flatten(nn.Module):
+    def forward(self, input):
+        """
+        Flattens the input to shape (batch_size, n_elements),
+        where batch_size is input.size(0) and n_elements is the
+        product of the remaining dimensions.
+        """
+        batch_size = input.size(0)
+        out = input.view(batch_size, -1)
+        return out  # (batch_size, n_elements)
+
+
+class Discriminator(BaseNetwork):
+    def __init__(self, base_n_channels):
+        """
+        base_n_channels : int
+            Width multiplier for the convolutional layers. Input images
+            are expected to be RGB, with height and width that are
+            powers of 2, e.g. 256 x 256.
+        """
+        super(Discriminator, self).__init__()
+
+        self.image_to_features = nn.Sequential(
+            spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
+            nn.LeakyReLU(0.2, inplace=True),
+            spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
+            nn.LeakyReLU(0.2, inplace=True),
+            spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
+            nn.LeakyReLU(0.2, inplace=True),
+            spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
+            nn.LeakyReLU(0.2, inplace=True),
+            # spectral_norm(nn.Conv2d(4 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
+            # nn.LeakyReLU(0.2, inplace=True),
+            spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+        output_size = 8 * base_n_channels * 3 * 3
+        self.features_to_prob = nn.Sequential(
+            spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
+            Flatten(),
+            nn.Linear(output_size, 1)
+        )
+
+        self.init_weights(init_type="normal", gain=0.02)
+
+    def forward(self, input_data):
+        x = self.image_to_features(input_data)
+        return self.features_to_prob(x)
+
+
+class PatchDiscriminator(Discriminator):
+    def __init__(self, base_n_channels):
+        super(PatchDiscriminator, self).__init__(base_n_channels)
+
+        self.features_to_prob = nn.Sequential(
+            spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
+            Flatten()
+        )
+
+    def forward(self, input_data):
+        x = self.image_to_features(input_data)
+        return self.features_to_prob(x)
+
+
+if __name__ == '__main__':
+    import torchvision
+    ifrnet = IFRNet(32, 128).cuda()
+    x = torch.rand((2, 3, 256, 256)).cuda()
+    vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda()
+    with torch.no_grad():
+        vgg_feat = vgg16(x)
+    output, aux_out = ifrnet(x, vgg_feat)
+    print(output.size())
+    print(aux_out.size())
+
+    disc = Discriminator(32).cuda()
+    d_out = disc(output)
+    print(d_out.size())
+
+    patch_disc = PatchDiscriminator(32).cuda()
+    p_d_out = patch_disc(output)
+    print(p_d_out.size())
+
modules/blocks.py ADDED
@@ -0,0 +1,93 @@
+from torch import nn
+
+from modules.normalization import AdaIN
+
+
+class DestyleResBlock(nn.Module):
+    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
+        super(DestyleResBlock, self).__init__()
+
+        # uses 1x1 convolutions for downsampling
+        if not channels_in or channels_in == channels_out:
+            channels_in = channels_out
+            self.projection = None
+        else:
+            self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
+        self.use_dropout = use_dropout
+
+        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
+        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
+        self.adain = AdaIN()
+        if self.use_dropout:
+            self.dropout = nn.Dropout()
+        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x, feat):
+        residual = x
+        out = self.conv1(x)
+        out = self.lrelu1(out)
+        out = self.conv2(out)
+        _, _, h, w = out.size()
+        out = self.adain(out, feat)
+        if self.use_dropout:
+            out = self.dropout(out)
+        if self.projection:
+            residual = self.projection(x)
+        out = out + residual
+        out = self.lrelu2(out)
+        return out
+
+
+class ResBlock(nn.Module):
+    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
+        super(ResBlock, self).__init__()
+
+        # uses 1x1 convolutions for downsampling
+        if not channels_in or channels_in == channels_out:
+            channels_in = channels_out
+            self.projection = None
+        else:
+            self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
+        self.use_dropout = use_dropout
+
+        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
+        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
+        self.n2 = nn.BatchNorm2d(channels_out)
+        if self.use_dropout:
+            self.dropout = nn.Dropout()
+        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.lrelu1(out)
+        out = self.conv2(out)
+        # out = self.n2(out)
+        if self.use_dropout:
+            out = self.dropout(out)
+        if self.projection:
+            residual = self.projection(x)
+        out = out + residual
+        out = self.lrelu2(out)
+        return out
+
+
+class Destyler(nn.Module):
+    def __init__(self, in_features, num_features):
+        super(Destyler, self).__init__()
+        self.fc1 = nn.Linear(in_features, num_features)
+        self.fc2 = nn.Linear(num_features, num_features)
+        self.fc3 = nn.Linear(num_features, num_features)
+        self.fc4 = nn.Linear(num_features, num_features)
+        self.fc5 = nn.Linear(num_features, num_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = self.fc3(x)
+        x = self.fc4(x)
+        x = self.fc5(x)
+        return x
+
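A quick shape check for the two residual blocks, mirroring how `IFRNet` wires them; note that `DestyleResBlock` expects a style vector with `2 * channels_out` entries, which `AdaIN` splits into a per-channel scale and shift:

```python
import torch
from modules.blocks import DestyleResBlock, ResBlock

x = torch.rand(2, 3, 256, 256)
feat = torch.rand(2, 64)  # 2 * channels_out entries for AdaIN
block = DestyleResBlock(channels_in=3, channels_out=32, kernel_size=5, stride=1, padding=2)
h = block(x, feat)
print(h.shape)            # torch.Size([2, 32, 256, 256])

res = ResBlock(channels_in=32, channels_out=64, kernel_size=3, stride=2, padding=1)
print(res(h).shape)       # torch.Size([2, 64, 128, 128])
```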
modules/normalization.py ADDED
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+
+class AdaIN(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        ch = y.size(1)
+        sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1)
+
+        x_mu = x.mean(dim=[2, 3], keepdim=True)
+        x_sigma = x.std(dim=[2, 3], keepdim=True)
+
+        return sigma * ((x - x_mu) / x_sigma) + mu
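The `torch.split` in `forward` means the style vector `y` must carry twice as many entries as `x` has channels: the first half becomes the per-channel scale (`sigma`), the second half the shift (`mu`). A small shape sketch:

```python
import torch
from modules.normalization import AdaIN

x = torch.rand(2, 64, 32, 32)  # content features, 64 channels
y = torch.rand(2, 128)         # first 64 entries -> sigma, last 64 -> mu
print(AdaIN()(x, y).shape)     # torch.Size([2, 64, 32, 32])
```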
requirements.txt ADDED
@@ -0,0 +1,12 @@
+numpy>=1.17.0
+requests>=2.25.1
+torchvision>=0.6.0
+yacs>=0.1.7
+kornia>=0.3.1
+matplotlib>=3.3.4
+torch>=1.5.0
+glog>=0.3.1
+gradio>=1.6.4
+seaborn>=0.11.0
+Pillow>=8.2.0
+scikit_learn>=0.24.1
utils/data_utils.py ADDED
@@ -0,0 +1,6 @@
+def linear_scaling(x):
+    return (x * 255.) / 127.5 - 1.  # [0, 1] -> [-1, 1]
+
+
+def linear_unscaling(x):
+    return (x + 1.) * 127.5 / 255.  # [-1, 1] -> [0, 1]
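Both helpers are affine maps between the two value ranges the code uses: `linear_scaling` sends [0, 1] to [-1, 1] (since x * 255 / 127.5 - 1 = 2x - 1) and `linear_unscaling` inverts it. A small round-trip check:

```python
import torch
from utils.data_utils import linear_scaling, linear_unscaling

x = torch.rand(1, 3, 256, 256)  # values in [0, 1]
y = linear_scaling(x)           # values in [-1, 1]
assert torch.allclose(linear_unscaling(y), x, atol=1e-6)
```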