Commit 1cae80b · AlekseyKorshuk
Parent(s): 61020b4
First commit
Files changed:
- README.md +13 -13
- app.py +79 -0
- config.yaml +98 -0
- configs/default.py +88 -0
- images/examples/10_Nashville.jpg +0 -0
- images/examples/11_Sutro.jpg +0 -0
- images/examples/12_Toaster.jpg +0 -0
- images/examples/14_Willow.jpg +0 -0
- images/examples/15_X-ProII.jpg +0 -0
- images/examples/16_Lo-Fi.jpg +0 -0
- images/examples/18_Gingham.jpg +0 -0
- images/examples/1_Clarendon.jpg +0 -0
- images/examples/2_Brannan.jpg +0 -0
- images/examples/30_Perpetua.jpg +0 -0
- images/examples/3_Mayfair.jpg +0 -0
- images/examples/4_Hudson.jpg +0 -0
- images/examples/5_Amaro.jpg +0 -0
- images/examples/6_1977.jpg +0 -0
- images/examples/8_Valencia.jpg +0 -0
- images/examples/98_He-Fe.jpg +0 -0
- images/examples/9_Lo-Fi.jpg +0 -0
- modeling/base.py +60 -0
- modeling/benchmark.py +62 -0
- modeling/build.py +19 -0
- modeling/ifrnet.py +166 -0
- modules/blocks.py +93 -0
- modules/normalization.py +16 -0
- requirements.txt +12 -0
- utils/data_utils.py +6 -0
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Instagram Filter Removal
-emoji:
-colorFrom:
-colorTo:
+emoji: π
+colorFrom: gray
+colorTo: green
 sdk: gradio
 app_file: app.py
 pinned: false
@@ -10,28 +10,28 @@ pinned: false
 
 # Configuration
 
-`title`: _string_
+`title`: _string_
 Display title for the Space
 
-`emoji`: _string_
+`emoji`: _string_
 Space emoji (emoji-only character allowed)
 
-`colorFrom`: _string_
+`colorFrom`: _string_
 Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
 
-`colorTo`: _string_
+`colorTo`: _string_
 Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
 
-`sdk`: _string_
-Can be either `gradio
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
 
-`sdk_version` : _string_
+`sdk_version` : _string_
 Only applicable for `streamlit` SDK.
 See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
 
-`app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
 Path is relative to the root of the repository.
 
-`pinned`: _boolean_
+`pinned`: _boolean_
 Whether the Space stays on top of your list.
app.py
ADDED
@@ -0,0 +1,79 @@
import requests
import os
import gradio as gr
import numpy as np
import torch
import torchvision.models as models

from configs.default import get_cfg_defaults
from modeling.build import build_model
from utils.data_utils import linear_scaling


url = "https://www.dropbox.com/s/y97z812sxa1kvrg/ifrnet.pth?dl=1"
r = requests.get(url, stream=True)
if not os.path.exists("ifrnet.pth"):
    with open("ifrnet.pth", 'wb') as f:
        for data in r:
            f.write(data)

cfg = get_cfg_defaults()
cfg.MODEL.CKPT = "ifrnet.pth"
net, _ = build_model(cfg)
net = net.eval()
vgg16 = models.vgg16(pretrained=True).features.eval()


def load_checkpoints_from_ckpt(ckpt_path):
    checkpoints = torch.load(ckpt_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoints["ifr"])


load_checkpoints_from_ckpt(cfg.MODEL.CKPT)


def filter_removal(img):
    arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
    arr = torch.tensor(arr).float() / 255.
    arr = linear_scaling(arr)
    with torch.no_grad():
        feat = vgg16(arr)
        out, _ = net(arr, feat)
        out = torch.clamp(out, max=1., min=0.)
    return out.squeeze(0).permute(1, 2, 0).numpy()


title = "Instagram Filter Removal on Fashionable Images"
description = "This is the demo for IFRNet, filter removal on fashionable images on Instagram. " \
              "To use it, simply upload your filtered image, or click one of the examples to load them."
article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Kinli_Instagram_Filter_Removal_on_Fashionable_Images_CVPRW_2021_paper.pdf'>Paper</a> | <a href='https://github.com/birdortyedi/instagram-filter-removal-pytorch'>Github</a></p>"

gr.Interface(
    filter_removal,
    gr.inputs.Image(shape=(256, 256)),
    gr.outputs.Image(),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples_per_page=17,
    examples=[
        ["images/examples/98_He-Fe.jpg"],
        ["images/examples/2_Brannan.jpg"],
        ["images/examples/12_Toaster.jpg"],
        ["images/examples/18_Gingham.jpg"],
        ["images/examples/11_Sutro.jpg"],
        ["images/examples/9_Lo-Fi.jpg"],
        ["images/examples/3_Mayfair.jpg"],
        ["images/examples/4_Hudson.jpg"],
        ["images/examples/5_Amaro.jpg"],
        ["images/examples/6_1977.jpg"],
        ["images/examples/8_Valencia.jpg"],
        ["images/examples/16_Lo-Fi.jpg"],
        ["images/examples/10_Nashville.jpg"],
        ["images/examples/15_X-ProII.jpg"],
        ["images/examples/14_Willow.jpg"],
        ["images/examples/30_Perpetua.jpg"],
        ["images/examples/1_Clarendon.jpg"],
    ]
).launch()
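
Note: filter_removal can also be smoke-tested outside the Gradio UI by calling it directly on a numpy image. A minimal sketch, assuming Pillow (already in requirements.txt); the resize to 256×256 approximates what gr.inputs.Image(shape=(256, 256)) does before the function is invoked:

import numpy as np
from PIL import Image

# Hypothetical local test: load a bundled example, resize as Gradio would,
# and run the restoration function directly.
img = Image.open("images/examples/1_Clarendon.jpg").convert("RGB").resize((256, 256))
restored = filter_removal(np.array(img))  # float array in [0, 1], HWC layout
Image.fromarray((restored * 255).astype(np.uint8)).save("restored.png")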
config.yaml
ADDED
@@ -0,0 +1,98 @@
wandb_version: 1

DATASET:
  desc: null
  value:
    MEAN:
    - 0.5
    - 0.5
    - 0.5
    NAME: IFFI
    ROOT: ../../Downloads/IFFI-dataset/train
    SIZE: 256
    STD:
    - 0.5
    - 0.5
    - 0.5
    TEST_ROOT: ../../Downloads/IFFI-dataset/test
MODEL:
  desc: null
  value:
    D:
      NAME: 1-ChOutputDiscriminator
      NUM_CHANNELS: 32
      NUM_CRITICS: 5
      SOLVER:
        BETAS:
        - 0.5
        - 0.9
        DECAY_RATE: 0.5
        LR: 0.001
        SCHEDULER: []
    IFR:
      DESTYLER_CHANNELS: 32
      NAME: InstaFilterRemovalNetwork
      NUM_CHANNELS: 32
      SOLVER:
        BETAS:
        - 0.5
        - 0.9
        DECAY_RATE: 0
        LR: 0.0002
        SCHEDULER: []
    IS_TRAIN: true
    NAME: ifrnet
    NUM_CLASS: 17
OPTIM:
  desc: null
  value:
    ADVERSARIAL: 0.001
    AUX: 0.5
    GP: 10
    MASK: 1
    RECON: 1.4
    SEMANTIC: 0.0001
    TEXTURE: 0.001
SYSTEM:
  desc: null
  value:
    NUM_GPU: 2
    NUM_WORKERS: 4
TEST:
  desc: null
  value:
    ABLATION: false
    BATCH_SIZE: 64
    IMG_ID: 52
    OUTPUT_DIR: ./outputs
    WEIGHTS: ''
TRAIN:
  desc: null
  value:
    BATCH_SIZE: 8
    IS_TRAIN: true
    LOG_INTERVAL: 100
    NUM_TOTAL_STEP: 120000
    RESUME: true
    SAVE_DIR: ./weights
    SAVE_INTERVAL: 1000
    SHUFFLE: true
    START_STEP: 0
    TUNE: false
    VISUALIZE_INTERVAL: 100
WANDB:
  desc: null
  value:
    ENTITY: vvgl-ozu
    LOG_DIR: ./logs/ifrnet_IFFI_120000step_8bs_0.0002lr_2gpu_9run
    NUM_ROW: 0
    PROJECT_NAME: instagram-filter-removal
    RUN: 9
_wandb:
  desc: null
  value:
    cli_version: 0.9.1
    framework: torch
    is_jupyter_run: false
    is_kaggle_kernel: false
    python_version: 3.6.9
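
Note: config.yaml is a Weights & Biases run export (hence the wandb_version key and the desc/value wrapper around every section), not a file that yacs can merge directly. A minimal sketch of unwrapping it into plain sections, with a hypothetical helper name (unwrap_wandb_config is not part of the repo):

import yaml

def unwrap_wandb_config(path):
    # Keep only the SECTION -> value payloads; drop the desc fields and the
    # internal wandb_version / _wandb bookkeeping entries.
    with open(path) as f:
        raw = yaml.safe_load(f)
    return {k: v["value"] for k, v in raw.items()
            if isinstance(v, dict) and "value" in v and not k.startswith("_")}

run_cfg = unwrap_wandb_config("config.yaml")
print(run_cfg["MODEL"]["IFR"]["SOLVER"]["LR"])  # 0.0002, matching _C.MODEL.IFR.SOLVER.LR

The resulting dict could then be merged into the yacs defaults from configs/default.py, though the export carries keys (e.g. TRAIN.IS_TRAIN) that the defaults below do not define, so a blind merge would need yacs's new-key allowance.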
configs/default.py
ADDED
@@ -0,0 +1,88 @@
from yacs.config import CfgNode as CN

_C = CN()

_C.SYSTEM = CN()
_C.SYSTEM.NUM_GPU = 2
_C.SYSTEM.NUM_WORKERS = 4

_C.WANDB = CN()
_C.WANDB.PROJECT_NAME = "instagram-filter-removal"
_C.WANDB.ENTITY = "vvgl-ozu"
_C.WANDB.RUN = 12
_C.WANDB.LOG_DIR = ""
_C.WANDB.NUM_ROW = 0

_C.TRAIN = CN()
_C.TRAIN.NUM_TOTAL_STEP = 120000
_C.TRAIN.START_STEP = 0
_C.TRAIN.BATCH_SIZE = 8
_C.TRAIN.SHUFFLE = True
_C.TRAIN.LOG_INTERVAL = 100
_C.TRAIN.SAVE_INTERVAL = 5000
_C.TRAIN.SAVE_DIR = "./weights"
_C.TRAIN.RESUME = True
_C.TRAIN.VISUALIZE_INTERVAL = 100
_C.TRAIN.TUNE = False

_C.MODEL = CN()
_C.MODEL.NAME = "ifr-no-aux"
_C.MODEL.IS_TRAIN = True
_C.MODEL.NUM_CLASS = 17
_C.MODEL.CKPT = ""

_C.MODEL.IFR = CN()
_C.MODEL.IFR.NAME = "InstaFilterRemovalNetwork"
_C.MODEL.IFR.NUM_CHANNELS = 32
_C.MODEL.IFR.DESTYLER_CHANNELS = 32
_C.MODEL.IFR.SOLVER = CN()
_C.MODEL.IFR.SOLVER.LR = 2e-4
_C.MODEL.IFR.SOLVER.BETAS = (0.5, 0.9)
_C.MODEL.IFR.SOLVER.SCHEDULER = []
_C.MODEL.IFR.SOLVER.DECAY_RATE = 0.

_C.MODEL.D = CN()
_C.MODEL.D.NAME = "1-ChOutputDiscriminator"
_C.MODEL.D.NUM_CHANNELS = 32
_C.MODEL.D.NUM_CRITICS = 5
_C.MODEL.D.SOLVER = CN()
_C.MODEL.D.SOLVER.LR = 1e-3
_C.MODEL.D.SOLVER.BETAS = (0.5, 0.9)
_C.MODEL.D.SOLVER.SCHEDULER = []
_C.MODEL.D.SOLVER.DECAY_RATE = 0.5

_C.OPTIM = CN()
_C.OPTIM.GP = 10
_C.OPTIM.MASK = 1
_C.OPTIM.RECON = 1.4
_C.OPTIM.SEMANTIC = 1e-4
_C.OPTIM.TEXTURE = 1e-3
_C.OPTIM.ADVERSARIAL = 1e-3
_C.OPTIM.AUX = 0.5

_C.DATASET = CN()
_C.DATASET.NAME = "IFFI"  # "IFFI" # "DIV2K?" #
_C.DATASET.ROOT = "../../Datasets/IFFI-dataset/train"  # "../../Datasets/IFFI-dataset" # "/media/birdortyedi/e5042b8f-ca5e-4a22-ac68-7e69ff648bc4/IFFI-dataset"
_C.DATASET.TEST_ROOT = "../../Datasets/IFFI-dataset"
_C.DATASET.SIZE = 256
_C.DATASET.CROP_SIZE = 512
_C.DATASET.MEAN = [0.5, 0.5, 0.5]
_C.DATASET.STD = [0.5, 0.5, 0.5]

_C.TEST = CN()
_C.TEST.OUTPUT_DIR = "./outputs"
_C.TEST.ABLATION = False
_C.TEST.WEIGHTS = ""
_C.TEST.BATCH_SIZE = 64
_C.TEST.IMG_ID = 52


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()


# provide a way to import the defaults as a global singleton:
cfg = _C  # users can `from config import cfg`
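
Note: get_cfg_defaults clones _C, so experiment code can mutate its copy freely, as app.py does with cfg.MODEL.CKPT. A short sketch of the usual yacs pattern on top of these defaults (merge_from_list and freeze are standard yacs CfgNode methods):

from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()
# Override individual keys; merge_from_list takes alternating key/value pairs.
cfg.merge_from_list(["TRAIN.BATCH_SIZE", 4, "MODEL.NAME", "ifrnet"])
cfg.freeze()  # lock the config so later code cannot mutate it silently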
images/examples/10_Nashville.jpg
ADDED
images/examples/11_Sutro.jpg
ADDED
images/examples/12_Toaster.jpg
ADDED
images/examples/14_Willow.jpg
ADDED
images/examples/15_X-ProII.jpg
ADDED
images/examples/16_Lo-Fi.jpg
ADDED
images/examples/18_Gingham.jpg
ADDED
images/examples/1_Clarendon.jpg
ADDED
images/examples/2_Brannan.jpg
ADDED
images/examples/30_Perpetua.jpg
ADDED
images/examples/3_Mayfair.jpg
ADDED
images/examples/4_Hudson.jpg
ADDED
images/examples/5_Amaro.jpg
ADDED
images/examples/6_1977.jpg
ADDED
images/examples/8_Valencia.jpg
ADDED
images/examples/98_He-Fe.jpg
ADDED
images/examples/9_Lo-Fi.jpg
ADDED
modeling/base.py
ADDED
@@ -0,0 +1,60 @@
from torch import nn


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def forward(self, x, y):
        pass

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def set_requires_grad(self, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            requires_grad (bool) -- whether the networks require gradients or not
        """
        for param in self.parameters():
            param.requires_grad = requires_grad

    def init_weights(self, init_type='xavier', gain=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
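
Note: BaseNetwork is a utility mixin; forward is a stub, and the concrete models below (IFRNet, UNet) call init_weights from their constructors. A minimal sketch of how a subclass picks up these helpers, using a hypothetical TinyNet:

import torch
from torch import nn
from modeling.base import BaseNetwork

class TinyNet(BaseNetwork):  # illustrative model, not part of the repo
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.init_weights(init_type="xavier", gain=0.02)

    def forward(self, x):
        return self.conv(x)

net = TinyNet()
net.print_network()           # reports the parameter count in millions
net.set_requires_grad(False)  # freeze all weights, e.g. for inference only
print(net(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])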
modeling/benchmark.py
ADDED
@@ -0,0 +1,62 @@
import torch
from torch import nn

from modeling.base import BaseNetwork
from modules.blocks import ResBlock


class UNet(BaseNetwork):
    def __init__(self, base_n_channels):
        super(UNet, self).__init__()

        self.ds_res1 = ResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
        self.ds_res2 = ResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
        self.ds_res3 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.ds_res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
        self.ds_res5 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.ds_res6 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)

        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)

        self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)

        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, x):
        out = self.ds_res1(x)
        out = self.ds_res2(out)
        out = self.ds_res3(out)
        out = self.ds_res4(out)
        out = self.ds_res5(out)
        aux = self.ds_res6(out)

        out = self.upsample(aux)
        out = self.res1(out)
        out = self.res2(out)
        out = self.upsample(out)
        out = self.res3(out)
        out = self.res4(out)
        out = self.upsample(out)
        out = self.res5(out)
        out = self.conv1(out)

        return out, aux


if __name__ == '__main__':
    x = torch.rand((2, 3, 256, 256)).cuda()
    unet = UNet(32).cuda()
    with torch.no_grad():
        out, aux = unet(x)

    print(out.size())
    print(aux.size())
modeling/build.py
ADDED
@@ -0,0 +1,19 @@
from modeling.ifrnet import IFRNet, Discriminator, PatchDiscriminator, MLP
from modeling.benchmark import UNet


def build_model(args):
    if args.MODEL.NAME.lower() == "ifrnet":
        net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
        mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, num_class=args.MODEL.NUM_CLASS)
    elif args.MODEL.NAME.lower() == "ifr-no-aux":
        net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
        mlp = None
    else:
        raise NotImplementedError
    return net, mlp


def build_discriminators(args):
    return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS)
modeling/ifrnet.py
ADDED
@@ -0,0 +1,166 @@
import torch
from torch import nn
from torch.nn.utils import spectral_norm

from modeling.base import BaseNetwork
from modules.blocks import DestyleResBlock, Destyler, ResBlock


class IFRNet(BaseNetwork):
    def __init__(self, base_n_channels, destyler_n_channels):
        super(IFRNet, self).__init__()
        self.destyler = Destyler(in_features=32768, num_features=destyler_n_channels)  # from vgg features

        self.ds_fc1 = nn.Linear(destyler_n_channels, base_n_channels * 2)
        self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
        self.ds_fc2 = nn.Linear(destyler_n_channels, base_n_channels * 4)
        self.ds_res2 = DestyleResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
        self.ds_fc3 = nn.Linear(destyler_n_channels, base_n_channels * 4)
        self.ds_res3 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.ds_fc4 = nn.Linear(destyler_n_channels, base_n_channels * 8)
        self.ds_res4 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
        self.ds_fc5 = nn.Linear(destyler_n_channels, base_n_channels * 8)
        self.ds_res5 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.ds_fc6 = nn.Linear(destyler_n_channels, base_n_channels * 16)
        self.ds_res6 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)

        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)

        self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
        self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
        self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)

        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, x, vgg_feat):
        b_size, ch, h, w = vgg_feat.size()
        vgg_feat = vgg_feat.view(b_size, ch * h * w)
        vgg_feat = self.destyler(vgg_feat)

        out = self.ds_res1(x, self.ds_fc1(vgg_feat))
        out = self.ds_res2(out, self.ds_fc2(vgg_feat))
        out = self.ds_res3(out, self.ds_fc3(vgg_feat))
        out = self.ds_res4(out, self.ds_fc4(vgg_feat))
        out = self.ds_res5(out, self.ds_fc5(vgg_feat))
        aux = self.ds_res6(out, self.ds_fc6(vgg_feat))

        out = self.upsample(aux)
        out = self.res1(out)
        out = self.res2(out)
        out = self.upsample(out)
        out = self.res3(out)
        out = self.res4(out)
        out = self.upsample(out)
        out = self.res5(out)
        out = self.conv1(out)

        return out, aux


class MLP(nn.Module):
    def __init__(self, base_n_channels, num_class=14):
        super(MLP, self).__init__()
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            # nn.Conv2d(base_n_channels * 2, base_n_channels * 1, kernel_size=3, stride=1, padding=1),
            # nn.MaxPool2d(2),
            Flatten(),
            nn.Linear(base_n_channels * 8 * 8 * 2, num_class),
            # nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.aux_classifier(x)


class Flatten(nn.Module):
    def forward(self, input):
        """
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        """
        batch_size = input.size(0)
        out = input.view(batch_size, -1)
        return out  # (batch_size, *size)


class Discriminator(BaseNetwork):
    def __init__(self, base_n_channels):
        """
        img_size : (int, int, int)
            Height and width must be powers of 2.  E.g. (32, 32, 1) or
            (64, 128, 3). Last number indicates number of channels, e.g. 1 for
            grayscale or 3 for RGB
        """
        super(Discriminator, self).__init__()

        self.image_to_features = nn.Sequential(
            spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            # spectral_norm(nn.Conv2d(4 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
            # nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        )

        output_size = 8 * base_n_channels * 3 * 3
        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
            Flatten(),
            nn.Linear(output_size, 1)
        )

        self.init_weights(init_type="normal", gain=0.02)

    def forward(self, input_data):
        x = self.image_to_features(input_data)
        return self.features_to_prob(x)


class PatchDiscriminator(Discriminator):
    def __init__(self, base_n_channels):
        super(PatchDiscriminator, self).__init__(base_n_channels)

        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
            Flatten()
        )

    def forward(self, input_data):
        x = self.image_to_features(input_data)
        return self.features_to_prob(x)


if __name__ == '__main__':
    import torchvision
    ifrnet = IFRNet(32, 128).cuda()
    x = torch.rand((2, 3, 256, 256)).cuda()
    vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda()
    with torch.no_grad():
        vgg_feat = vgg16(x)
    output, aux_out = ifrnet(x, vgg_feat)
    print(output.size())
    print(aux_out.size())

    disc = Discriminator(32).cuda()
    d_out = disc(output)
    print(d_out.size())

    patch_disc = PatchDiscriminator(32).cuda()
    p_d_out = patch_disc(output)
    print(p_d_out.size())
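
Note: the hard-coded in_features=32768 in the Destyler ties IFRNet to 256×256 inputs. VGG16's features stack downsamples by a factor of 32 and ends with 512 channels, so a 256×256 image yields a 512×8×8 map, and 512 · 8 · 8 = 32768. A quick sanity check of that arithmetic:

import torch
import torchvision.models as models

vgg16 = models.vgg16(pretrained=True).features.eval()
with torch.no_grad():
    feat = vgg16(torch.rand(1, 3, 256, 256))
print(feat.shape)             # torch.Size([1, 512, 8, 8])
print(feat.flatten(1).shape)  # torch.Size([1, 32768]) -> Destyler in_features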
modules/blocks.py
ADDED
@@ -0,0 +1,93 @@
from torch import nn

from modules.normalization import AdaIN


class DestyleResBlock(nn.Module):
    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
        super(DestyleResBlock, self).__init__()

        # uses 1x1 convolutions for downsampling
        if not channels_in or channels_in == channels_out:
            channels_in = channels_out
            self.projection = None
        else:
            self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
        self.use_dropout = use_dropout

        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
        self.adain = AdaIN()
        if self.use_dropout:
            self.dropout = nn.Dropout()
        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, feat):
        residual = x
        out = self.conv1(x)
        out = self.lrelu1(out)
        out = self.conv2(out)
        _, _, h, w = out.size()
        out = self.adain(out, feat)
        if self.use_dropout:
            out = self.dropout(out)
        if self.projection:
            residual = self.projection(x)
        out = out + residual
        out = self.lrelu2(out)
        return out


class ResBlock(nn.Module):
    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
        super(ResBlock, self).__init__()

        # uses 1x1 convolutions for downsampling
        if not channels_in or channels_in == channels_out:
            channels_in = channels_out
            self.projection = None
        else:
            self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
        self.use_dropout = use_dropout

        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
        self.n2 = nn.BatchNorm2d(channels_out)
        if self.use_dropout:
            self.dropout = nn.Dropout()
        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.lrelu1(out)
        out = self.conv2(out)
        # out = self.n2(out)
        if self.use_dropout:
            out = self.dropout(out)
        if self.projection:
            residual = self.projection(x)
        out = out + residual
        out = self.lrelu2(out)
        return out


class Destyler(nn.Module):
    def __init__(self, in_features, num_features):
        super(Destyler, self).__init__()
        self.fc1 = nn.Linear(in_features, num_features)
        self.fc2 = nn.Linear(num_features, num_features)
        self.fc3 = nn.Linear(num_features, num_features)
        self.fc4 = nn.Linear(num_features, num_features)
        self.fc5 = nn.Linear(num_features, num_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        return x
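
Note: DestyleResBlock conditions its output with AdaIN, so the feat vector must carry 2 * channels_out entries (see modules/normalization.py below), which is exactly why each ds_fcN in IFRNet outputs twice the channel count of the block it feeds. A quick standalone check of that wiring:

import torch
from modules.blocks import DestyleResBlock

block = DestyleResBlock(channels_in=3, channels_out=32, kernel_size=5, stride=1, padding=2)
x = torch.rand(2, 3, 64, 64)
feat = torch.rand(2, 64)  # 2 * channels_out: split into per-channel sigma and mu
print(block(x, feat).shape)  # torch.Size([2, 32, 64, 64])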
modules/normalization.py
ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


class AdaIN(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        ch = y.size(1)
        sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1)

        x_mu = x.mean(dim=[2, 3], keepdim=True)
        x_sigma = x.std(dim=[2, 3], keepdim=True)

        return sigma * ((x - x_mu) / x_sigma) + mu
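
Note: AdaIN instance-normalizes x with its own per-sample spatial statistics, then re-styles it with statistics taken from y: the first half of y's channels is consumed as sigma, the second half as mu. A small sketch of the contract, with made-up sizes:

import torch
from modules.normalization import AdaIN

adain = AdaIN()
x = torch.rand(2, 64, 32, 32)  # feature map with 64 channels
y = torch.rand(2, 128)         # 2 * 64: first 64 -> sigma, last 64 -> mu
print(adain(x, y).shape)       # torch.Size([2, 64, 32, 32])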
requirements.txt
ADDED
@@ -0,0 +1,12 @@
numpy>=1.17.0
requests>=2.25.1
torchvision>=0.6.0
yacs>=0.1.7
kornia>=0.3.1
matplotlib>=3.3.4
torch>=1.5.0
glog>=0.3.1
gradio>=1.6.4
seaborn>=0.11.0
Pillow>=8.2.0
scikit_learn>=0.24.1
utils/data_utils.py
ADDED
@@ -0,0 +1,6 @@
def linear_scaling(x):
    return (x * 255.) / 127.5 - 1.


def linear_unscaling(x):
    return (x + 1.) * 127.5 / 255.
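
Note: both helpers are affine maps between value ranges: linear_scaling simplifies to 2x - 1, taking [0, 1] to [-1, 1] (i.e. normalizing with the DATASET.MEAN/STD of 0.5), and linear_unscaling inverts it. A quick check:

import torch
from utils.data_utils import linear_scaling, linear_unscaling

x = torch.tensor([0.0, 0.5, 1.0])
print(linear_scaling(x))                    # tensor([-1.,  0.,  1.])
print(linear_unscaling(linear_scaling(x)))  # tensor([0.0000, 0.5000, 1.0000])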