Spaces:
Runtime error
Runtime error
birdortyedi
commited on
Commit
•
2a92dc2
1
Parent(s):
2529c2e
Add application file
Browse files- app.py +79 -0
- configs/__pycache__/default.cpython-37.pyc +0 -0
- configs/default.py +114 -0
- images/examples/10_Nashville.jpg +0 -0
- images/examples/11_Sutro.jpg +0 -0
- images/examples/12_Toaster.jpg +0 -0
- images/examples/14_Willow.jpg +0 -0
- images/examples/15_X-ProII.jpg +0 -0
- images/examples/16_Lo-Fi.jpg +0 -0
- images/examples/18_Gingham.jpg +0 -0
- images/examples/1_Clarendon.jpg +0 -0
- images/examples/2_Brannan.jpg +0 -0
- images/examples/30_Perpetua.jpg +0 -0
- images/examples/3_Mayfair.jpg +0 -0
- images/examples/4_Hudson.jpg +0 -0
- images/examples/5_Amaro.jpg +0 -0
- images/examples/6_1977.jpg +0 -0
- images/examples/8_Valencia.jpg +0 -0
- images/examples/98_He-Fe.jpg +0 -0
- images/examples/9_Lo-Fi.jpg +0 -0
- layers/__pycache__/blocks.cpython-37.pyc +0 -0
- layers/__pycache__/normalization.cpython-37.pyc +0 -0
- layers/blocks.py +93 -0
- layers/normalization.py +16 -0
- modeling/__pycache__/arch.cpython-37.pyc +0 -0
- modeling/__pycache__/base.cpython-37.pyc +0 -0
- modeling/__pycache__/build.cpython-37.pyc +0 -0
- modeling/arch.py +272 -0
- modeling/base.py +60 -0
- modeling/build.py +25 -0
- requirements.txt +6 -0
- utils/__pycache__/data_utils.cpython-37.pyc +0 -0
- utils/data_utils.py +29 -0
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.models as models
|
7 |
+
|
8 |
+
from configs.default import get_cfg_defaults
|
9 |
+
from modeling.build import build_model
|
10 |
+
from utils.data_utils import linear_scaling
|
11 |
+
|
12 |
+
|
13 |
+
url = "https://www.dropbox.com/s/uxvax5sjx5iysyl/cifr.pth?dl=0"
|
14 |
+
r = requests.get(url, stream=True)
|
15 |
+
if not os.path.exists("cifr.pth"):
|
16 |
+
with open("cifr.pth", 'wb') as f:
|
17 |
+
for data in r:
|
18 |
+
f.write(data)
|
19 |
+
|
20 |
+
cfg = get_cfg_defaults()
|
21 |
+
cfg.MODEL.CKPT = "cifr.pth"
|
22 |
+
net, _ = build_model(cfg)
|
23 |
+
net = net.eval()
|
24 |
+
vgg16 = models.vgg16(pretrained=True).features.eval()
|
25 |
+
|
26 |
+
|
27 |
+
def load_checkpoints_from_ckpt(ckpt_path):
|
28 |
+
checkpoints = torch.load(ckpt_path, map_location=torch.device('cuda'))
|
29 |
+
net.load_state_dict(checkpoints["ifr"])
|
30 |
+
|
31 |
+
|
32 |
+
load_checkpoints_from_ckpt(cfg.MODEL.CKPT)
|
33 |
+
|
34 |
+
|
35 |
+
def filter_removal(img):
|
36 |
+
arr = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
|
37 |
+
arr = torch.tensor(arr).float() / 255.
|
38 |
+
arr = linear_scaling(arr)
|
39 |
+
with torch.no_grad():
|
40 |
+
feat = vgg16(arr)
|
41 |
+
out, _ = net(arr, feat)
|
42 |
+
out = torch.clamp(out, max=1., min=0.)
|
43 |
+
return out.squeeze(0).permute(1, 2, 0).numpy()
|
44 |
+
|
45 |
+
|
46 |
+
title = "Contrastive Instagram Filter Removal (CIFR)"
|
47 |
+
description = "This is the demo for CIFR, contrastive strategy for filter removal on fashionable images on Instagram. " \
|
48 |
+
"To use it, simply upload your filtered image, or click one of the examples to load them."
|
49 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.07486'>Contrastive Instagram Filter Removal (CIFR)</a> | <a href='https://github.com/birdortyedi/cifr-pytorch'>Github Repo</a></p>"
|
50 |
+
|
51 |
+
gr.Interface(
|
52 |
+
filter_removal,
|
53 |
+
gr.inputs.Image(shape=(256, 256)),
|
54 |
+
gr.outputs.Image(),
|
55 |
+
title=title,
|
56 |
+
description=description,
|
57 |
+
article=article,
|
58 |
+
allow_flagging=False,
|
59 |
+
examples_per_page=17,
|
60 |
+
examples=[
|
61 |
+
["images/examples/98_He-Fe.jpg"],
|
62 |
+
["images/examples/2_Brannan.jpg"],
|
63 |
+
["images/examples/12_Toaster.jpg"],
|
64 |
+
["images/examples/18_Gingham.jpg"],
|
65 |
+
["images/examples/11_Sutro.jpg"],
|
66 |
+
["images/examples/9_Lo-Fi.jpg"],
|
67 |
+
["images/examples/3_Mayfair.jpg"],
|
68 |
+
["images/examples/4_Hudson.jpg"],
|
69 |
+
["images/examples/5_Amaro.jpg"],
|
70 |
+
["images/examples/6_1977.jpg"],
|
71 |
+
["images/examples/8_Valencia.jpg"],
|
72 |
+
["images/examples/16_Lo-Fi.jpg"],
|
73 |
+
["images/examples/10_Nashville.jpg"],
|
74 |
+
["images/examples/15_X-ProII.jpg"],
|
75 |
+
["images/examples/14_Willow.jpg"],
|
76 |
+
["images/examples/30_Perpetua.jpg"],
|
77 |
+
["images/examples/1_Clarendon.jpg"],
|
78 |
+
]
|
79 |
+
).launch()
|
configs/__pycache__/default.cpython-37.pyc
ADDED
Binary file (2.68 kB). View file
|
|
configs/default.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yacs.config import CfgNode as CN
|
2 |
+
|
3 |
+
_C = CN()
|
4 |
+
|
5 |
+
_C.SYSTEM = CN()
|
6 |
+
_C.SYSTEM.NUM_GPU = 2
|
7 |
+
_C.SYSTEM.NUM_WORKERS = 4
|
8 |
+
|
9 |
+
_C.WANDB = CN()
|
10 |
+
_C.WANDB.PROJECT_NAME = "contrastive-style-learning-for-ifr"
|
11 |
+
_C.WANDB.ENTITY = "vvgl-ozu"
|
12 |
+
_C.WANDB.RUN = 3
|
13 |
+
_C.WANDB.LOG_DIR = ""
|
14 |
+
_C.WANDB.NUM_ROW = 0
|
15 |
+
|
16 |
+
_C.TRAIN = CN()
|
17 |
+
_C.TRAIN.NUM_TOTAL_STEP = 200000
|
18 |
+
_C.TRAIN.START_STEP = 0
|
19 |
+
_C.TRAIN.BATCH_SIZE = 16
|
20 |
+
_C.TRAIN.SHUFFLE = True
|
21 |
+
_C.TRAIN.LOG_INTERVAL = 100
|
22 |
+
_C.TRAIN.EVAL_INTERVAL = 1000
|
23 |
+
_C.TRAIN.SAVE_INTERVAL = 1000
|
24 |
+
_C.TRAIN.SAVE_DIR = "./weights"
|
25 |
+
_C.TRAIN.RESUME = True
|
26 |
+
_C.TRAIN.VISUALIZE_INTERVAL = 100
|
27 |
+
_C.TRAIN.TUNE = False
|
28 |
+
|
29 |
+
_C.MODEL = CN()
|
30 |
+
_C.MODEL.NAME = "cifr"
|
31 |
+
_C.MODEL.IS_TRAIN = True
|
32 |
+
_C.MODEL.NUM_CLASS = 17
|
33 |
+
_C.MODEL.CKPT = ""
|
34 |
+
_C.MODEL.PRETRAINED = ""
|
35 |
+
|
36 |
+
_C.MODEL.IFR = CN()
|
37 |
+
_C.MODEL.IFR.NAME = "ContrastiveInstaFilterRemovalNetwork"
|
38 |
+
_C.MODEL.IFR.NUM_CHANNELS = 32
|
39 |
+
_C.MODEL.IFR.DESTYLER_CHANNELS = 32
|
40 |
+
_C.MODEL.IFR.SOLVER = CN()
|
41 |
+
_C.MODEL.IFR.SOLVER.LR = 2e-4
|
42 |
+
_C.MODEL.IFR.SOLVER.BETAS = (0.5, 0.999)
|
43 |
+
_C.MODEL.IFR.SOLVER.SCHEDULER = []
|
44 |
+
_C.MODEL.IFR.SOLVER.DECAY_RATE = 0.
|
45 |
+
_C.MODEL.IFR.DS_FACTOR = 4
|
46 |
+
|
47 |
+
_C.MODEL.PATCH = CN()
|
48 |
+
_C.MODEL.PATCH.NUM_CHANNELS = 256
|
49 |
+
_C.MODEL.PATCH.NUM_PATCHES = 256
|
50 |
+
_C.MODEL.PATCH.NUM_LAYERS = 6
|
51 |
+
_C.MODEL.PATCH.USE_MLP = True
|
52 |
+
_C.MODEL.PATCH.SHUFFLE_Y = True
|
53 |
+
_C.MODEL.PATCH.LR = 1e-4
|
54 |
+
_C.MODEL.PATCH.BETAS = (0.5, 0.999)
|
55 |
+
_C.MODEL.PATCH.T = 0.07
|
56 |
+
|
57 |
+
_C.MODEL.D = CN()
|
58 |
+
_C.MODEL.D.NAME = "1-ChOutputDiscriminator"
|
59 |
+
_C.MODEL.D.NUM_CHANNELS = 32
|
60 |
+
_C.MODEL.D.NUM_CRITICS = 3
|
61 |
+
_C.MODEL.D.SOLVER = CN()
|
62 |
+
_C.MODEL.D.SOLVER.LR = 1e-4
|
63 |
+
_C.MODEL.D.SOLVER.BETAS = (0.5, 0.999)
|
64 |
+
_C.MODEL.D.SOLVER.SCHEDULER = []
|
65 |
+
_C.MODEL.D.SOLVER.DECAY_RATE = 0.01
|
66 |
+
|
67 |
+
_C.ESRGAN = CN()
|
68 |
+
_C.ESRGAN.WEIGHTS = "weights/RealESRGAN_x{}plus.pth"
|
69 |
+
|
70 |
+
_C.FASHIONMASKRCNN = CN()
|
71 |
+
_C.FASHIONMASKRCNN.CFG_PATH = "configs/fashion.yaml"
|
72 |
+
_C.FASHIONMASKRCNN.WEIGHTS = "weights/fashion.pth"
|
73 |
+
_C.FASHIONMASKRCNN.SCORE_THRESH_TEST = 0.6
|
74 |
+
_C.FASHIONMASKRCNN.MIN_SIZE_TEST = 512
|
75 |
+
|
76 |
+
_C.OPTIM = CN()
|
77 |
+
_C.OPTIM.GP = 10.
|
78 |
+
_C.OPTIM.MASK = 1
|
79 |
+
_C.OPTIM.RECON = 1.4
|
80 |
+
_C.OPTIM.SEMANTIC = 1e-1
|
81 |
+
_C.OPTIM.TEXTURE = 2e-1
|
82 |
+
_C.OPTIM.ADVERSARIAL = 1e-3
|
83 |
+
_C.OPTIM.AUX = 0.5
|
84 |
+
_C.OPTIM.CONTRASTIVE = 0.1
|
85 |
+
_C.OPTIM.NLL = 1.0
|
86 |
+
|
87 |
+
_C.DATASET = CN()
|
88 |
+
_C.DATASET.NAME = "IFFI"
|
89 |
+
_C.DATASET.ROOT = "../../Downloads/IFFI-dataset/train" # "../../Downloads/IFFI-dataset/train"
|
90 |
+
_C.DATASET.TEST_ROOT = "../../Datasets/IFFI-dataset/test" # "../../Downloads/IFFI-dataset/test"
|
91 |
+
_C.DATASET.DS_TEST_ROOT = "../../Downloads/IFFI-dataset/test/" # "../../Downloads/IFFI-dataset/test"
|
92 |
+
_C.DATASET.DS_JSON_FILE = "../../Downloads/IFFI-dataset-only-orgs/instances_default.json"
|
93 |
+
_C.DATASET.SIZE = 256
|
94 |
+
_C.DATASET.CROP_SIZE = 512
|
95 |
+
_C.DATASET.MEAN = [0.5, 0.5, 0.5]
|
96 |
+
_C.DATASET.STD = [0.5, 0.5, 0.5]
|
97 |
+
|
98 |
+
_C.TEST = CN()
|
99 |
+
_C.TEST.OUTPUT_DIR = "./outputs"
|
100 |
+
_C.TEST.ABLATION = False
|
101 |
+
_C.TEST.WEIGHTS = ""
|
102 |
+
_C.TEST.BATCH_SIZE = 32
|
103 |
+
_C.TEST.IMG_ID = 52
|
104 |
+
|
105 |
+
|
106 |
+
def get_cfg_defaults():
|
107 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
108 |
+
# Return a clone so that the defaults will not be altered
|
109 |
+
# This is for the "local variable" use pattern
|
110 |
+
return _C.clone()
|
111 |
+
|
112 |
+
|
113 |
+
# provide a way to import the defaults as a global singleton:
|
114 |
+
cfg = _C # users can `from config import cfg`
|
images/examples/10_Nashville.jpg
ADDED
images/examples/11_Sutro.jpg
ADDED
images/examples/12_Toaster.jpg
ADDED
images/examples/14_Willow.jpg
ADDED
images/examples/15_X-ProII.jpg
ADDED
images/examples/16_Lo-Fi.jpg
ADDED
images/examples/18_Gingham.jpg
ADDED
images/examples/1_Clarendon.jpg
ADDED
images/examples/2_Brannan.jpg
ADDED
images/examples/30_Perpetua.jpg
ADDED
images/examples/3_Mayfair.jpg
ADDED
images/examples/4_Hudson.jpg
ADDED
images/examples/5_Amaro.jpg
ADDED
images/examples/6_1977.jpg
ADDED
images/examples/8_Valencia.jpg
ADDED
images/examples/98_He-Fe.jpg
ADDED
images/examples/9_Lo-Fi.jpg
ADDED
layers/__pycache__/blocks.cpython-37.pyc
ADDED
Binary file (2.93 kB). View file
|
|
layers/__pycache__/normalization.cpython-37.pyc
ADDED
Binary file (859 Bytes). View file
|
|
layers/blocks.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
from layers.normalization import AdaIN
|
4 |
+
|
5 |
+
|
6 |
+
class DestyleResBlock(nn.Module):
|
7 |
+
def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
|
8 |
+
super(DestyleResBlock, self).__init__()
|
9 |
+
|
10 |
+
# uses 1x1 convolutions for downsampling
|
11 |
+
if not channels_in or channels_in == channels_out:
|
12 |
+
channels_in = channels_out
|
13 |
+
self.projection = None
|
14 |
+
else:
|
15 |
+
self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
|
16 |
+
self.use_dropout = use_dropout
|
17 |
+
|
18 |
+
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
|
19 |
+
self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
20 |
+
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
|
21 |
+
self.adain = AdaIN()
|
22 |
+
if self.use_dropout:
|
23 |
+
self.dropout = nn.Dropout()
|
24 |
+
self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
25 |
+
|
26 |
+
def forward(self, x, feat):
|
27 |
+
residual = x
|
28 |
+
out = self.conv1(x)
|
29 |
+
out = self.lrelu1(out)
|
30 |
+
out = self.conv2(out)
|
31 |
+
_, _, h, w = out.size()
|
32 |
+
out = self.adain(out, feat)
|
33 |
+
if self.use_dropout:
|
34 |
+
out = self.dropout(out)
|
35 |
+
if self.projection:
|
36 |
+
residual = self.projection(x)
|
37 |
+
out = out + residual
|
38 |
+
out = self.lrelu2(out)
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
class ResBlock(nn.Module):
|
43 |
+
def __init__(self, channels_out, kernel_size, channels_in=None, stride=1, dilation=1, padding=1, use_dropout=False):
|
44 |
+
super(ResBlock, self).__init__()
|
45 |
+
|
46 |
+
# uses 1x1 convolutions for downsampling
|
47 |
+
if not channels_in or channels_in == channels_out:
|
48 |
+
channels_in = channels_out
|
49 |
+
self.projection = None
|
50 |
+
else:
|
51 |
+
self.projection = nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=stride, dilation=1)
|
52 |
+
self.use_dropout = use_dropout
|
53 |
+
|
54 |
+
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
|
55 |
+
self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
56 |
+
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
|
57 |
+
self.n2 = nn.BatchNorm2d(channels_out)
|
58 |
+
if self.use_dropout:
|
59 |
+
self.dropout = nn.Dropout()
|
60 |
+
self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
residual = x
|
64 |
+
out = self.conv1(x)
|
65 |
+
out = self.lrelu1(out)
|
66 |
+
out = self.conv2(out)
|
67 |
+
# out = self.n2(out)
|
68 |
+
if self.use_dropout:
|
69 |
+
out = self.dropout(out)
|
70 |
+
if self.projection:
|
71 |
+
residual = self.projection(x)
|
72 |
+
out = out + residual
|
73 |
+
out = self.lrelu2(out)
|
74 |
+
return out
|
75 |
+
|
76 |
+
|
77 |
+
class Destyler(nn.Module):
|
78 |
+
def __init__(self, in_features, num_features):
|
79 |
+
super(Destyler, self).__init__()
|
80 |
+
self.fc1 = nn.Linear(in_features, num_features)
|
81 |
+
self.fc2 = nn.Linear(num_features, num_features)
|
82 |
+
self.fc3 = nn.Linear(num_features, num_features)
|
83 |
+
self.fc4 = nn.Linear(num_features, num_features)
|
84 |
+
self.fc5 = nn.Linear(num_features, num_features)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
x = self.fc1(x)
|
88 |
+
x = self.fc2(x)
|
89 |
+
x = self.fc3(x)
|
90 |
+
x = self.fc4(x)
|
91 |
+
x = self.fc5(x)
|
92 |
+
return x
|
93 |
+
|
layers/normalization.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class AdaIN(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
def forward(self, x, y):
|
10 |
+
ch = y.size(1)
|
11 |
+
sigma, mu = torch.split(y.unsqueeze(-1).unsqueeze(-1), [ch // 2, ch // 2], dim=1)
|
12 |
+
|
13 |
+
x_mu = x.mean(dim=[2, 3], keepdim=True)
|
14 |
+
x_sigma = x.std(dim=[2, 3], keepdim=True)
|
15 |
+
|
16 |
+
return sigma * ((x - x_mu) / x_sigma) + mu
|
modeling/__pycache__/arch.cpython-37.pyc
ADDED
Binary file (8.89 kB). View file
|
|
modeling/__pycache__/base.cpython-37.pyc
ADDED
Binary file (2.66 kB). View file
|
|
modeling/__pycache__/build.cpython-37.pyc
ADDED
Binary file (1.22 kB). View file
|
|
modeling/arch.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn.utils import spectral_norm
|
4 |
+
|
5 |
+
from modeling.base import BaseNetwork
|
6 |
+
from layers.blocks import DestyleResBlock, Destyler, ResBlock
|
7 |
+
|
8 |
+
|
9 |
+
class IFRNet(BaseNetwork):
|
10 |
+
def __init__(self, base_n_channels, destyler_n_channels):
|
11 |
+
super(IFRNet, self).__init__()
|
12 |
+
self.destyler = Destyler(in_features=32768, num_features=destyler_n_channels) # from vgg features
|
13 |
+
|
14 |
+
self.ds_fc1 = nn.Linear(destyler_n_channels, base_n_channels * 2)
|
15 |
+
self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2)
|
16 |
+
self.ds_fc2 = nn.Linear(destyler_n_channels, base_n_channels * 4)
|
17 |
+
self.ds_res2 = DestyleResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1)
|
18 |
+
self.ds_fc3 = nn.Linear(destyler_n_channels, base_n_channels * 4)
|
19 |
+
self.ds_res3 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
20 |
+
self.ds_fc4 = nn.Linear(destyler_n_channels, base_n_channels * 8)
|
21 |
+
self.ds_res4 = DestyleResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1)
|
22 |
+
self.ds_fc5 = nn.Linear(destyler_n_channels, base_n_channels * 8)
|
23 |
+
self.ds_res5 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
24 |
+
self.ds_fc6 = nn.Linear(destyler_n_channels, base_n_channels * 16)
|
25 |
+
self.ds_res6 = DestyleResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1)
|
26 |
+
|
27 |
+
self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
|
28 |
+
|
29 |
+
self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
30 |
+
self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1)
|
31 |
+
self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
32 |
+
self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1)
|
33 |
+
self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1)
|
34 |
+
|
35 |
+
self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1)
|
36 |
+
|
37 |
+
self.init_weights(init_type="normal", gain=0.02)
|
38 |
+
|
39 |
+
def forward(self, x, vgg_feat):
|
40 |
+
b_size, ch, h, w = vgg_feat.size()
|
41 |
+
vgg_feat = vgg_feat.view(b_size, ch * h * w)
|
42 |
+
vgg_feat = self.destyler(vgg_feat)
|
43 |
+
|
44 |
+
out = self.ds_res1(x, self.ds_fc1(vgg_feat))
|
45 |
+
out = self.ds_res2(out, self.ds_fc2(vgg_feat))
|
46 |
+
out = self.ds_res3(out, self.ds_fc3(vgg_feat))
|
47 |
+
out = self.ds_res4(out, self.ds_fc4(vgg_feat))
|
48 |
+
out = self.ds_res5(out, self.ds_fc5(vgg_feat))
|
49 |
+
aux = self.ds_res6(out, self.ds_fc6(vgg_feat))
|
50 |
+
|
51 |
+
out = self.upsample(aux)
|
52 |
+
out = self.res1(out)
|
53 |
+
out = self.res2(out)
|
54 |
+
out = self.upsample(out)
|
55 |
+
out = self.res3(out)
|
56 |
+
out = self.res4(out)
|
57 |
+
out = self.upsample(out)
|
58 |
+
out = self.res5(out)
|
59 |
+
out = self.conv1(out)
|
60 |
+
|
61 |
+
return out, aux
|
62 |
+
|
63 |
+
|
64 |
+
class CIFR_Encoder(IFRNet):
|
65 |
+
def __init__(self, base_n_channels, destyler_n_channels):
|
66 |
+
super(CIFR_Encoder, self).__init__(base_n_channels, destyler_n_channels)
|
67 |
+
|
68 |
+
def forward(self, x, vgg_feat):
|
69 |
+
b_size, ch, h, w = vgg_feat.size()
|
70 |
+
vgg_feat = vgg_feat.view(b_size, ch * h * w)
|
71 |
+
vgg_feat = self.destyler(vgg_feat)
|
72 |
+
|
73 |
+
feat1 = self.ds_res1(x, self.ds_fc1(vgg_feat))
|
74 |
+
feat2 = self.ds_res2(feat1, self.ds_fc2(vgg_feat))
|
75 |
+
feat3 = self.ds_res3(feat2, self.ds_fc3(vgg_feat))
|
76 |
+
feat4 = self.ds_res4(feat3, self.ds_fc4(vgg_feat))
|
77 |
+
feat5 = self.ds_res5(feat4, self.ds_fc5(vgg_feat))
|
78 |
+
feat6 = self.ds_res6(feat5, self.ds_fc6(vgg_feat))
|
79 |
+
|
80 |
+
feats = [feat1, feat2, feat3, feat4, feat5, feat6]
|
81 |
+
|
82 |
+
out = self.upsample(feat6)
|
83 |
+
out = self.res1(out)
|
84 |
+
out = self.res2(out)
|
85 |
+
out = self.upsample(out)
|
86 |
+
out = self.res3(out)
|
87 |
+
out = self.res4(out)
|
88 |
+
out = self.upsample(out)
|
89 |
+
out = self.res5(out)
|
90 |
+
out = self.conv1(out)
|
91 |
+
|
92 |
+
return out, feats
|
93 |
+
|
94 |
+
|
95 |
+
class Normalize(nn.Module):
|
96 |
+
def __init__(self, power=2):
|
97 |
+
super(Normalize, self).__init__()
|
98 |
+
self.power = power
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
102 |
+
out = x.div(norm + 1e-7)
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class PatchSampleF(BaseNetwork):
|
107 |
+
def __init__(self, base_n_channels, style_or_content, use_mlp=False, nc=256):
|
108 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
109 |
+
super(PatchSampleF, self).__init__()
|
110 |
+
self.is_content = True if style_or_content == "content" else False
|
111 |
+
self.l2norm = Normalize(2)
|
112 |
+
self.use_mlp = use_mlp
|
113 |
+
self.nc = nc # hard-coded
|
114 |
+
|
115 |
+
self.mlp_0 = nn.Sequential(*[nn.Linear(base_n_channels, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
116 |
+
self.mlp_1 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
117 |
+
self.mlp_2 = nn.Sequential(*[nn.Linear(base_n_channels * 2, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
118 |
+
self.mlp_3 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
119 |
+
self.mlp_4 = nn.Sequential(*[nn.Linear(base_n_channels * 4, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
120 |
+
self.mlp_5 = nn.Sequential(*[nn.Linear(base_n_channels * 8, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]).cuda()
|
121 |
+
self.init_weights(init_type="normal", gain=0.02)
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def gram_matrix(x):
|
125 |
+
# a, b, c, d = x.size() # a=batch size(=1)
|
126 |
+
a, b = x.size()
|
127 |
+
# b=number of feature maps
|
128 |
+
# (c,d)=dimensions of a f. map (N=c*d)
|
129 |
+
|
130 |
+
# features = x.view(a * b, c * d) # resise F_XL into \hat F_XL
|
131 |
+
|
132 |
+
G = torch.mm(x, x.t()) # compute the gram product
|
133 |
+
|
134 |
+
# we 'normalize' the values of the gram matrix
|
135 |
+
# by dividing by the number of element in each feature maps.
|
136 |
+
return G.div(a * b)
|
137 |
+
|
138 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
139 |
+
return_ids = []
|
140 |
+
return_feats = []
|
141 |
+
|
142 |
+
for feat_id, feat in enumerate(feats):
|
143 |
+
B, C, H, W = feat.shape
|
144 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
145 |
+
if num_patches > 0:
|
146 |
+
if patch_ids is not None:
|
147 |
+
patch_id = patch_ids[feat_id]
|
148 |
+
else:
|
149 |
+
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
150 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
|
151 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
|
152 |
+
else:
|
153 |
+
x_sample = feat_reshape
|
154 |
+
patch_id = []
|
155 |
+
if self.use_mlp:
|
156 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
157 |
+
x_sample = mlp(x_sample)
|
158 |
+
if not self.is_content:
|
159 |
+
x_sample = self.gram_matrix(x_sample)
|
160 |
+
return_ids.append(patch_id)
|
161 |
+
x_sample = self.l2norm(x_sample)
|
162 |
+
|
163 |
+
if num_patches == 0:
|
164 |
+
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
|
165 |
+
return_feats.append(x_sample)
|
166 |
+
return return_feats, return_ids
|
167 |
+
|
168 |
+
|
169 |
+
class MLP(nn.Module):
|
170 |
+
def __init__(self, base_n_channels, out_features=14):
|
171 |
+
super(MLP, self).__init__()
|
172 |
+
self.aux_classifier = nn.Sequential(
|
173 |
+
nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
|
174 |
+
nn.MaxPool2d(2),
|
175 |
+
nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
|
176 |
+
nn.MaxPool2d(2),
|
177 |
+
# nn.Conv2d(base_n_channels * 2, base_n_channels * 1, kernel_size=3, stride=1, padding=1),
|
178 |
+
# nn.MaxPool2d(2),
|
179 |
+
Flatten(),
|
180 |
+
nn.Linear(base_n_channels * 8 * 8 * 2, out_features),
|
181 |
+
# nn.Softmax(dim=-1)
|
182 |
+
)
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
return self.aux_classifier(x)
|
186 |
+
|
187 |
+
|
188 |
+
class Flatten(nn.Module):
|
189 |
+
def forward(self, input):
|
190 |
+
"""
|
191 |
+
Note that input.size(0) is usually the batch size.
|
192 |
+
So what it does is that given any input with input.size(0) # of batches,
|
193 |
+
will flatten to be 1 * nb_elements.
|
194 |
+
"""
|
195 |
+
batch_size = input.size(0)
|
196 |
+
out = input.view(batch_size, -1)
|
197 |
+
return out # (batch_size, *size)
|
198 |
+
|
199 |
+
|
200 |
+
class Discriminator(BaseNetwork):
|
201 |
+
def __init__(self, base_n_channels):
|
202 |
+
"""
|
203 |
+
img_size : (int, int, int)
|
204 |
+
Height and width must be powers of 2. E.g. (32, 32, 1) or
|
205 |
+
(64, 128, 3). Last number indicates number of channels, e.g. 1 for
|
206 |
+
grayscale or 3 for RGB
|
207 |
+
"""
|
208 |
+
super(Discriminator, self).__init__()
|
209 |
+
|
210 |
+
self.image_to_features = nn.Sequential(
|
211 |
+
spectral_norm(nn.Conv2d(3, base_n_channels, 5, 2, 2)),
|
212 |
+
nn.LeakyReLU(0.2, inplace=True),
|
213 |
+
spectral_norm(nn.Conv2d(base_n_channels, 2 * base_n_channels, 5, 2, 2)),
|
214 |
+
nn.LeakyReLU(0.2, inplace=True),
|
215 |
+
spectral_norm(nn.Conv2d(2 * base_n_channels, 2 * base_n_channels, 5, 2, 2)),
|
216 |
+
nn.LeakyReLU(0.2, inplace=True),
|
217 |
+
spectral_norm(nn.Conv2d(2 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
|
218 |
+
nn.LeakyReLU(0.2, inplace=True),
|
219 |
+
# spectral_norm(nn.Conv2d(4 * base_n_channels, 4 * base_n_channels, 5, 2, 2)),
|
220 |
+
# nn.LeakyReLU(0.2, inplace=True),
|
221 |
+
spectral_norm(nn.Conv2d(4 * base_n_channels, 8 * base_n_channels, 5, 1, 1)),
|
222 |
+
nn.LeakyReLU(0.2, inplace=True),
|
223 |
+
)
|
224 |
+
|
225 |
+
output_size = 8 * base_n_channels * 3 * 3
|
226 |
+
self.features_to_prob = nn.Sequential(
|
227 |
+
spectral_norm(nn.Conv2d(8 * base_n_channels, 2 * base_n_channels, 5, 2, 1)),
|
228 |
+
Flatten(),
|
229 |
+
nn.Linear(output_size, 1)
|
230 |
+
)
|
231 |
+
|
232 |
+
self.init_weights(init_type="normal", gain=0.02)
|
233 |
+
|
234 |
+
def forward(self, input_data):
|
235 |
+
x = self.image_to_features(input_data)
|
236 |
+
return self.features_to_prob(x)
|
237 |
+
|
238 |
+
|
239 |
+
class PatchDiscriminator(Discriminator):
|
240 |
+
def __init__(self, base_n_channels):
|
241 |
+
super(PatchDiscriminator, self).__init__(base_n_channels)
|
242 |
+
|
243 |
+
self.features_to_prob = nn.Sequential(
|
244 |
+
spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
|
245 |
+
Flatten()
|
246 |
+
)
|
247 |
+
|
248 |
+
def forward(self, input_data):
|
249 |
+
x = self.image_to_features(input_data)
|
250 |
+
return self.features_to_prob(x)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == '__main__':
|
254 |
+
import torchvision
|
255 |
+
ifrnet = CIFR_Encoder(32, 128).cuda()
|
256 |
+
x = torch.rand((2, 3, 256, 256)).cuda()
|
257 |
+
vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda()
|
258 |
+
with torch.no_grad():
|
259 |
+
vgg_feat = vgg16(x)
|
260 |
+
output, feats = ifrnet(x, vgg_feat)
|
261 |
+
print(output.size())
|
262 |
+
for i, feat in enumerate(feats):
|
263 |
+
print(i, feat.size())
|
264 |
+
|
265 |
+
disc = Discriminator(32).cuda()
|
266 |
+
d_out = disc(output)
|
267 |
+
print(d_out.size())
|
268 |
+
|
269 |
+
patch_disc = PatchDiscriminator(32).cuda()
|
270 |
+
p_d_out = patch_disc(output)
|
271 |
+
print(p_d_out.size())
|
272 |
+
|
modeling/base.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
class BaseNetwork(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super(BaseNetwork, self).__init__()
|
7 |
+
|
8 |
+
def forward(self, x, y):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def print_network(self):
|
12 |
+
if isinstance(self, list):
|
13 |
+
self = self[0]
|
14 |
+
num_params = 0
|
15 |
+
for param in self.parameters():
|
16 |
+
num_params += param.numel()
|
17 |
+
print('Network [%s] was created. Total number of parameters: %.1f million. '
|
18 |
+
'To see the architecture, do print(network).'
|
19 |
+
% (type(self).__name__, num_params / 1000000))
|
20 |
+
|
21 |
+
def set_requires_grad(self, requires_grad=False):
|
22 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
23 |
+
Parameters:
|
24 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
25 |
+
"""
|
26 |
+
for param in self.parameters():
|
27 |
+
param.requires_grad = requires_grad
|
28 |
+
|
29 |
+
def init_weights(self, init_type='xavier', gain=0.02):
|
30 |
+
def init_func(m):
|
31 |
+
classname = m.__class__.__name__
|
32 |
+
if classname.find('BatchNorm2d') != -1:
|
33 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
34 |
+
nn.init.normal_(m.weight.data, 1.0, gain)
|
35 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
36 |
+
nn.init.constant_(m.bias.data, 0.0)
|
37 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
38 |
+
if init_type == 'normal':
|
39 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
40 |
+
elif init_type == 'xavier':
|
41 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
42 |
+
elif init_type == 'xavier_uniform':
|
43 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
44 |
+
elif init_type == 'kaiming':
|
45 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
46 |
+
elif init_type == 'orthogonal':
|
47 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
48 |
+
elif init_type == 'none': # uses pytorch's default init method
|
49 |
+
m.reset_parameters()
|
50 |
+
else:
|
51 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
52 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
53 |
+
nn.init.constant_(m.bias.data, 0.0)
|
54 |
+
|
55 |
+
self.apply(init_func)
|
56 |
+
|
57 |
+
# propagate to children
|
58 |
+
for m in self.children():
|
59 |
+
if hasattr(m, 'init_weights'):
|
60 |
+
m.init_weights(init_type, gain)
|
modeling/build.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modeling.arch import IFRNet, CIFR_Encoder, Discriminator, PatchDiscriminator, MLP, PatchSampleF
|
2 |
+
|
3 |
+
|
4 |
+
def build_model(args):
|
5 |
+
if args.MODEL.NAME.lower() == "ifrnet":
|
6 |
+
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
7 |
+
mlp = MLP(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, out_features=args.MODEL.NUM_CLASS)
|
8 |
+
elif args.MODEL.NAME.lower() == "cifr":
|
9 |
+
net = CIFR_Encoder(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
10 |
+
mlp = None
|
11 |
+
elif args.MODEL.NAME.lower() == "ifr-no-aux":
|
12 |
+
net = IFRNet(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, destyler_n_channels=args.MODEL.IFR.DESTYLER_CHANNELS)
|
13 |
+
mlp = None
|
14 |
+
else:
|
15 |
+
raise NotImplementedError
|
16 |
+
return net, mlp
|
17 |
+
|
18 |
+
|
19 |
+
def build_discriminators(args):
|
20 |
+
return Discriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS), PatchDiscriminator(base_n_channels=args.MODEL.D.NUM_CHANNELS)
|
21 |
+
|
22 |
+
|
23 |
+
def build_patch_sampler(args):
|
24 |
+
return PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="content", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS), \
|
25 |
+
PatchSampleF(base_n_channels=args.MODEL.IFR.NUM_CHANNELS, style_or_content="style", use_mlp=args.MODEL.PATCH.USE_MLP, nc=args.MODEL.PATCH.NUM_CHANNELS)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==2.9.4
|
2 |
+
numpy==1.21.2
|
3 |
+
requests==2.27.1
|
4 |
+
torch==1.10.1
|
5 |
+
torchvision==0.11.2
|
6 |
+
yacs==0.1.8
|
utils/__pycache__/data_utils.cpython-37.pyc
ADDED
Binary file (1.07 kB). View file
|
|
utils/data_utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
def linear_scaling(x):
|
5 |
+
return (x * 255.) / 127.5 - 1.
|
6 |
+
|
7 |
+
|
8 |
+
def linear_unscaling(x):
|
9 |
+
return (x + 1.) * 127.5 / 255.
|
10 |
+
|
11 |
+
|
12 |
+
def read_json(path):
|
13 |
+
"""
|
14 |
+
:param path (str or os.Path): JSON file path.
|
15 |
+
:return: (Dict): the data in the JSON file.
|
16 |
+
"""
|
17 |
+
with open(path) as f:
|
18 |
+
data = json.load(f)
|
19 |
+
return data
|
20 |
+
|
21 |
+
|
22 |
+
def write_json(path, datagroup):
|
23 |
+
"""
|
24 |
+
:param path (str or os.Path): File path for the output JSON file.
|
25 |
+
:param datagroup (Dict): The data which should be dump to the JSON file.
|
26 |
+
:return: void.
|
27 |
+
"""
|
28 |
+
with open(path, "w+") as f:
|
29 |
+
json.dump(datagroup, f)
|