Upload 5 files
Browse files
gan.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import torch
|
6 |
+
import io
|
7 |
+
import os
|
8 |
+
import functools
|
9 |
+
|
10 |
+
class DataLoader():
|
11 |
+
|
12 |
+
def __init__(self, opt, cv_img):
|
13 |
+
super(DataLoader, self).__init__()
|
14 |
+
|
15 |
+
self.dataset = Dataset()
|
16 |
+
self.dataset.initialize(opt, cv_img)
|
17 |
+
|
18 |
+
self.dataloader = torch.utils.data.DataLoader(
|
19 |
+
self.dataset,
|
20 |
+
batch_size=opt.batchSize,
|
21 |
+
shuffle=not opt.serial_batches,
|
22 |
+
num_workers=int(opt.nThreads))
|
23 |
+
|
24 |
+
def load_data(self):
|
25 |
+
return self.dataloader
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return 1
|
29 |
+
|
30 |
+
class Dataset(torch.utils.data.Dataset):
|
31 |
+
def __init__(self):
|
32 |
+
super(Dataset, self).__init__()
|
33 |
+
|
34 |
+
def initialize(self, opt, cv_img):
|
35 |
+
self.opt = opt
|
36 |
+
self.root = opt.dataroot
|
37 |
+
|
38 |
+
self.A = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
|
39 |
+
self.dataset_size = 1
|
40 |
+
|
41 |
+
def __getitem__(self, index):
|
42 |
+
|
43 |
+
transform_A = get_transform(self.opt)
|
44 |
+
A_tensor = transform_A(self.A.convert('RGB'))
|
45 |
+
|
46 |
+
B_tensor = inst_tensor = feat_tensor = 0
|
47 |
+
|
48 |
+
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
|
49 |
+
'feat': feat_tensor, 'path': ""}
|
50 |
+
|
51 |
+
return input_dict
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return 1
|
55 |
+
|
56 |
+
class DeepModel(torch.nn.Module):
|
57 |
+
|
58 |
+
def initialize(self, opt):
|
59 |
+
|
60 |
+
torch.cuda.empty_cache()
|
61 |
+
|
62 |
+
self.opt = opt
|
63 |
+
|
64 |
+
self.gpu_ids = [] #FIX CPU
|
65 |
+
|
66 |
+
self.netG = self.__define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
|
67 |
+
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
|
68 |
+
opt.n_blocks_local, opt.norm, self.gpu_ids)
|
69 |
+
|
70 |
+
# load networks
|
71 |
+
self.__load_network(self.netG)
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def inference(self, label, inst):
|
76 |
+
|
77 |
+
# Encode Inputs
|
78 |
+
input_label, inst_map, _, _ = self.__encode_input(label, inst, infer=True)
|
79 |
+
|
80 |
+
# Fake Generation
|
81 |
+
input_concat = input_label
|
82 |
+
|
83 |
+
with torch.no_grad():
|
84 |
+
fake_image = self.netG.forward(input_concat)
|
85 |
+
|
86 |
+
return fake_image
|
87 |
+
|
88 |
+
# helper loading function that can be used by subclasses
|
89 |
+
def __load_network(self, network):
|
90 |
+
|
91 |
+
save_path = os.path.join(self.opt.checkpoints_dir)
|
92 |
+
|
93 |
+
network.load_state_dict(torch.load(save_path))
|
94 |
+
|
95 |
+
def __encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
|
96 |
+
if (len(self.gpu_ids) > 0):
|
97 |
+
input_label = label_map.data.cuda() #GPU
|
98 |
+
else:
|
99 |
+
input_label = label_map.data #CPU
|
100 |
+
|
101 |
+
return input_label, inst_map, real_image, feat_map
|
102 |
+
|
103 |
+
def __weights_init(self, m):
|
104 |
+
classname = m.__class__.__name__
|
105 |
+
if classname.find('Conv') != -1:
|
106 |
+
m.weight.data.normal_(0.0, 0.02)
|
107 |
+
elif classname.find('BatchNorm2d') != -1:
|
108 |
+
m.weight.data.normal_(1.0, 0.02)
|
109 |
+
m.bias.data.fill_(0)
|
110 |
+
|
111 |
+
def __define_G(self, input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
|
112 |
+
n_blocks_local=3, norm='instance', gpu_ids=[]):
|
113 |
+
norm_layer = self.__get_norm_layer(norm_type=norm)
|
114 |
+
netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
115 |
+
|
116 |
+
if len(gpu_ids) > 0:
|
117 |
+
netG.cuda(gpu_ids[0])
|
118 |
+
netG.apply(self.__weights_init)
|
119 |
+
return netG
|
120 |
+
|
121 |
+
def __get_norm_layer(self, norm_type='instance'):
|
122 |
+
norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
|
123 |
+
return norm_layer
|
124 |
+
|
125 |
+
##############################################################################
|
126 |
+
# Generator
|
127 |
+
##############################################################################
|
128 |
+
class GlobalGenerator(torch.nn.Module):
|
129 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=torch.nn.BatchNorm2d,
|
130 |
+
padding_type='reflect'):
|
131 |
+
assert(n_blocks >= 0)
|
132 |
+
super(GlobalGenerator, self).__init__()
|
133 |
+
activation = torch.nn.ReLU(True)
|
134 |
+
|
135 |
+
model = [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
136 |
+
### downsample
|
137 |
+
for i in range(n_downsampling):
|
138 |
+
mult = 2**i
|
139 |
+
model += [torch.nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
140 |
+
norm_layer(ngf * mult * 2), activation]
|
141 |
+
|
142 |
+
### resnet blocks
|
143 |
+
mult = 2**n_downsampling
|
144 |
+
for i in range(n_blocks):
|
145 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
146 |
+
|
147 |
+
### upsample
|
148 |
+
for i in range(n_downsampling):
|
149 |
+
mult = 2**(n_downsampling - i)
|
150 |
+
model += [torch.nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
|
151 |
+
norm_layer(int(ngf * mult / 2)), activation]
|
152 |
+
model += [torch.nn.ReflectionPad2d(3), torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), torch.nn.Tanh()]
|
153 |
+
self.model = torch.nn.Sequential(*model)
|
154 |
+
|
155 |
+
def forward(self, input):
|
156 |
+
return self.model(input)
|
157 |
+
|
158 |
+
# Define a resnet block
|
159 |
+
class ResnetBlock(torch.nn.Module):
|
160 |
+
def __init__(self, dim, padding_type, norm_layer, activation=torch.nn.ReLU(True), use_dropout=False):
|
161 |
+
super(ResnetBlock, self).__init__()
|
162 |
+
self.conv_block = self.__build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
|
163 |
+
|
164 |
+
def __build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
|
165 |
+
conv_block = []
|
166 |
+
p = 0
|
167 |
+
if padding_type == 'reflect':
|
168 |
+
conv_block += [torch.nn.ReflectionPad2d(1)]
|
169 |
+
elif padding_type == 'replicate':
|
170 |
+
conv_block += [torch.nn.ReplicationPad2d(1)]
|
171 |
+
elif padding_type == 'zero':
|
172 |
+
p = 1
|
173 |
+
else:
|
174 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
175 |
+
|
176 |
+
conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
177 |
+
norm_layer(dim),
|
178 |
+
activation]
|
179 |
+
if use_dropout:
|
180 |
+
conv_block += [torch.nn.Dropout(0.5)]
|
181 |
+
|
182 |
+
p = 0
|
183 |
+
if padding_type == 'reflect':
|
184 |
+
conv_block += [torch.nn.ReflectionPad2d(1)]
|
185 |
+
elif padding_type == 'replicate':
|
186 |
+
conv_block += [torch.nn.ReplicationPad2d(1)]
|
187 |
+
elif padding_type == 'zero':
|
188 |
+
p = 1
|
189 |
+
else:
|
190 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
191 |
+
conv_block += [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
192 |
+
norm_layer(dim)]
|
193 |
+
|
194 |
+
return torch.nn.Sequential(*conv_block)
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
out = x + self.conv_block(x)
|
198 |
+
return out
|
199 |
+
|
200 |
+
# Data utils:
|
201 |
+
def get_transform(opt, method=Image.BICUBIC, normalize=True):
|
202 |
+
transform_list = []
|
203 |
+
|
204 |
+
base = float(2 ** opt.n_downsample_global)
|
205 |
+
if opt.netG == 'local':
|
206 |
+
base *= (2 ** opt.n_local_enhancers)
|
207 |
+
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
208 |
+
|
209 |
+
transform_list += [transforms.ToTensor()]
|
210 |
+
|
211 |
+
if normalize:
|
212 |
+
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
|
213 |
+
(0.5, 0.5, 0.5))]
|
214 |
+
return transforms.Compose(transform_list)
|
215 |
+
|
216 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
|
217 |
+
ow, oh = img.size
|
218 |
+
h = int(round(oh / base) * base)
|
219 |
+
w = int(round(ow / base) * base)
|
220 |
+
if (h == oh) and (w == ow):
|
221 |
+
return img
|
222 |
+
return img.resize((w, h), method)
|
223 |
+
|
224 |
+
# Converts a Tensor into a Numpy array
|
225 |
+
# |imtype|: the desired type of the converted numpy array
|
226 |
+
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
|
227 |
+
if isinstance(image_tensor, list):
|
228 |
+
image_numpy = []
|
229 |
+
for i in range(len(image_tensor)):
|
230 |
+
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
|
231 |
+
return image_numpy
|
232 |
+
image_numpy = image_tensor.cpu().float().numpy()
|
233 |
+
if normalize:
|
234 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
235 |
+
else:
|
236 |
+
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
|
237 |
+
image_numpy = np.clip(image_numpy, 0, 255)
|
238 |
+
if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
|
239 |
+
image_numpy = image_numpy[:,:,0]
|
240 |
+
return image_numpy.astype(imtype)
|
gggg.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from run import process
|
2 |
+
import time
|
3 |
+
import subprocess
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import cv2
|
7 |
+
import sys
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
|
13 |
+
TESTdevice = "cpu"
|
14 |
+
|
15 |
+
index = 1
|
16 |
+
|
17 |
+
|
18 |
+
"""
|
19 |
+
main.py
|
20 |
+
|
21 |
+
How to run:
|
22 |
+
python main.py
|
23 |
+
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
def mainTest(inputpath, outpath):
|
28 |
+
watermark = deep_nude_process(inputpath)
|
29 |
+
cv2.imwrite(outpath, watermark)
|
30 |
+
return watermark
|
31 |
+
#
|
32 |
+
|
33 |
+
|
34 |
+
def deep_nude_process(item):
|
35 |
+
# print('Processing {}'.format(item))
|
36 |
+
# dress = cv2.imread(item)
|
37 |
+
dress = (item)
|
38 |
+
h = dress.shape[0]
|
39 |
+
w = dress.shape[1]
|
40 |
+
dress = cv2.resize(dress, (512, 512), interpolation=cv2.INTER_CUBIC)
|
41 |
+
watermark = process(dress)
|
42 |
+
watermark = cv2.resize(watermark, (w, h), interpolation=cv2.INTER_CUBIC)
|
43 |
+
return watermark
|
44 |
+
|
45 |
+
|
46 |
+
def inference(img):
|
47 |
+
global index
|
48 |
+
# inputpath = "input/" + str(index) + ".jpg"
|
49 |
+
outputpath = "out_" + str(index) + ".jpg"
|
50 |
+
# cv2.imwrite(inputpath, img)
|
51 |
+
index += 1
|
52 |
+
print(time.strftime("START!!!!!!!!! %Y-%m-%d %H:%M:%S", time.localtime()))
|
53 |
+
output = mainTest(img, outputpath)
|
54 |
+
print(time.strftime("FINISH!!!!!!!!! %Y-%m-%d %H:%M:%S", time.localtime()))
|
55 |
+
return output
|
56 |
+
|
57 |
+
|
58 |
+
title = "AI脱衣"
|
59 |
+
description = "传入人物照片,类似最下方测试图的那种,将制作脱衣图,一张图至少等40秒,别传私人照片,禁止传真人照片"
|
60 |
+
|
61 |
+
examples = [
|
62 |
+
['input.png', '测试图'],
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
web = gr.Interface(inference,
|
67 |
+
inputs="image",
|
68 |
+
outputs="image",
|
69 |
+
title=title,
|
70 |
+
description=description,
|
71 |
+
examples=examples,
|
72 |
+
)
|
73 |
+
|
74 |
+
if __name__ == '__main__':
|
75 |
+
web.launch(
|
76 |
+
share=True,
|
77 |
+
enable_queue=True
|
78 |
+
)
|
input.png
ADDED
main.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
import sys
|
3 |
+
import cv2
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import subprocess
|
8 |
+
import time
|
9 |
+
|
10 |
+
from run import process
|
11 |
+
|
12 |
+
"""
|
13 |
+
main.py
|
14 |
+
|
15 |
+
How to run:
|
16 |
+
python main.py
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
def main(inputpath, outpath, show):
|
22 |
+
if isinstance(inputpath, list):
|
23 |
+
for item in inputpath:
|
24 |
+
watermark = deep_nude_process(item)
|
25 |
+
cv2.imwrite("output_"+item, watermark)
|
26 |
+
else:
|
27 |
+
watermark = deep_nude_process(inputpath)
|
28 |
+
cv2.imwrite(outpath, watermark)
|
29 |
+
|
30 |
+
def deep_nude_process(item):
|
31 |
+
print('Processing {}'.format(item))
|
32 |
+
dress = cv2.imread(item)
|
33 |
+
h = dress.shape[0]
|
34 |
+
w = dress.shape[1]
|
35 |
+
dress = cv2.resize(dress, (512,512), interpolation=cv2.INTER_CUBIC)
|
36 |
+
watermark = process(dress)
|
37 |
+
watermark = cv2.resize(watermark, (w,h), interpolation=cv2.INTER_CUBIC)
|
38 |
+
return watermark
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
parser = argparse.ArgumentParser(description="simple deep nude script tool")
|
42 |
+
parser.add_argument("-i", "--input", action="store", nargs = "*", default="input.png", help = "Use to enter input one or more files's name")
|
43 |
+
parser.add_argument("-o", "--output", action="store", default="output.png", help = "Use to enter output file name")
|
44 |
+
parser.add_argument("-s", "--show", action="store", default="false", help = "Use to automatically display or not display generated images")
|
45 |
+
inputpath, outputpath, show = parser.parse_args().input, parser.parse_args().output, parser.parse_args().show
|
46 |
+
|
47 |
+
print (time.strftime("START!!!!!!!!! %Y-%m-%d %H:%M:%S", time.localtime()))
|
48 |
+
main(inputpath, outputpath, show)
|
49 |
+
print (time.strftime("FINISH!!!!!!!!! %Y-%m-%d %H:%M:%S", time.localtime()))
|
run.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
#Import Neural Network Model
|
4 |
+
from gan import DataLoader, DeepModel, tensor2im
|
5 |
+
|
6 |
+
#OpenCv Transform:
|
7 |
+
from opencv_transform.mask_to_maskref import create_maskref
|
8 |
+
from opencv_transform.maskdet_to_maskfin import create_maskfin
|
9 |
+
from opencv_transform.dress_to_correct import create_correct
|
10 |
+
from opencv_transform.nude_to_watermark import create_watermark
|
11 |
+
|
12 |
+
"""
|
13 |
+
run.py
|
14 |
+
|
15 |
+
This script manage the entire transormation.
|
16 |
+
|
17 |
+
Transformation happens in 6 phases:
|
18 |
+
0: dress -> correct [opencv] dress_to_correct
|
19 |
+
1: correct -> mask: [GAN] correct_to_mask
|
20 |
+
2: mask -> maskref [opencv] mask_to_maskref
|
21 |
+
3: maskref -> maskdet [GAN] maskref_to_maskdet
|
22 |
+
4: maskdet -> maskfin [opencv] maskdet_to_maskfin
|
23 |
+
5: maskfin -> nude [GAN] maskfin_to_nude
|
24 |
+
6: nude -> watermark [opencv] nude_to_watermark
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
phases = ["dress_to_correct", "correct_to_mask", "mask_to_maskref", "maskref_to_maskdet", "maskdet_to_maskfin", "maskfin_to_nude", "nude_to_watermark"]
|
29 |
+
|
30 |
+
class Options():
|
31 |
+
|
32 |
+
#Init options with default values
|
33 |
+
def __init__(self):
|
34 |
+
|
35 |
+
# experiment specifics
|
36 |
+
self.norm = 'batch' #instance normalization or batch normalization
|
37 |
+
self.use_dropout = False #use dropout for the generator
|
38 |
+
self.data_type = 32 #Supported data type i.e. 8, 16, 32 bit
|
39 |
+
|
40 |
+
# input/output sizes
|
41 |
+
self.batchSize = 1 #input batch size
|
42 |
+
self.input_nc = 3 # of input image channels
|
43 |
+
self.output_nc = 3 # of output image channels
|
44 |
+
|
45 |
+
# for setting inputs
|
46 |
+
self.serial_batches = True #if true, takes images in order to make batches, otherwise takes them randomly
|
47 |
+
self.nThreads = 1 ## threads for loading data (???)
|
48 |
+
self.max_dataset_size = 1 #Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.
|
49 |
+
|
50 |
+
# for generator
|
51 |
+
self.netG = 'global' #selects model to use for netG
|
52 |
+
self.ngf = 64 ## of gen filters in first conv layer
|
53 |
+
self.n_downsample_global = 4 #number of downsampling layers in netG
|
54 |
+
self.n_blocks_global = 9 #number of residual blocks in the global generator network
|
55 |
+
self.n_blocks_local = 0 #number of residual blocks in the local enhancer network
|
56 |
+
self.n_local_enhancers = 0 #number of local enhancers to use
|
57 |
+
self.niter_fix_global = 0 #number of epochs that we only train the outmost local enhancer
|
58 |
+
|
59 |
+
#Phase specific options
|
60 |
+
self.checkpoints_dir = ""
|
61 |
+
self.dataroot = ""
|
62 |
+
|
63 |
+
#Changes options accordlying to actual phase
|
64 |
+
def updateOptions(self, phase):
|
65 |
+
|
66 |
+
if phase == "correct_to_mask":
|
67 |
+
self.checkpoints_dir = "checkpoints/cm.lib"
|
68 |
+
|
69 |
+
elif phase == "maskref_to_maskdet":
|
70 |
+
self.checkpoints_dir = "checkpoints/mm.lib"
|
71 |
+
|
72 |
+
elif phase == "maskfin_to_nude":
|
73 |
+
self.checkpoints_dir = "checkpoints/mn.lib"
|
74 |
+
|
75 |
+
# process(cv_img, mode)
|
76 |
+
# return:
|
77 |
+
# watermark image
|
78 |
+
def process(cv_img):
|
79 |
+
|
80 |
+
#InMemory cv2 images:
|
81 |
+
dress = cv_img
|
82 |
+
correct = None
|
83 |
+
mask = None
|
84 |
+
maskref = None
|
85 |
+
maskfin = None
|
86 |
+
maskdet = None
|
87 |
+
nude = None
|
88 |
+
watermark = None
|
89 |
+
|
90 |
+
for index, phase in enumerate(phases):
|
91 |
+
|
92 |
+
print("Executing phase: " + phase)
|
93 |
+
|
94 |
+
#GAN phases:
|
95 |
+
if (phase == "correct_to_mask") or (phase == "maskref_to_maskdet") or (phase == "maskfin_to_nude"):
|
96 |
+
|
97 |
+
#Load global option
|
98 |
+
opt = Options()
|
99 |
+
|
100 |
+
#Load custom phase options:
|
101 |
+
opt.updateOptions(phase)
|
102 |
+
|
103 |
+
#Load Data
|
104 |
+
if (phase == "correct_to_mask"):
|
105 |
+
data_loader = DataLoader(opt, correct)
|
106 |
+
elif (phase == "maskref_to_maskdet"):
|
107 |
+
data_loader = DataLoader(opt, maskref)
|
108 |
+
elif (phase == "maskfin_to_nude"):
|
109 |
+
data_loader = DataLoader(opt, maskfin)
|
110 |
+
|
111 |
+
dataset = data_loader.load_data()
|
112 |
+
|
113 |
+
#Create Model
|
114 |
+
model = DeepModel()
|
115 |
+
model.initialize(opt)
|
116 |
+
|
117 |
+
#Run for every image:
|
118 |
+
for i, data in enumerate(dataset):
|
119 |
+
|
120 |
+
generated = model.inference(data['label'], data['inst'])
|
121 |
+
|
122 |
+
im = tensor2im(generated.data[0])
|
123 |
+
|
124 |
+
#Save Data
|
125 |
+
if (phase == "correct_to_mask"):
|
126 |
+
mask = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
127 |
+
|
128 |
+
elif (phase == "maskref_to_maskdet"):
|
129 |
+
maskdet = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
130 |
+
|
131 |
+
elif (phase == "maskfin_to_nude"):
|
132 |
+
nude = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
|
133 |
+
|
134 |
+
#Correcting:
|
135 |
+
elif (phase == 'dress_to_correct'):
|
136 |
+
correct = create_correct(dress)
|
137 |
+
|
138 |
+
#mask_ref phase (opencv)
|
139 |
+
elif (phase == "mask_to_maskref"):
|
140 |
+
maskref = create_maskref(mask, correct)
|
141 |
+
|
142 |
+
#mask_fin phase (opencv)
|
143 |
+
elif (phase == "maskdet_to_maskfin"):
|
144 |
+
maskfin = create_maskfin(maskref, maskdet)
|
145 |
+
|
146 |
+
#nude_to_watermark phase (opencv)
|
147 |
+
elif (phase == "nude_to_watermark"):
|
148 |
+
watermark = create_watermark(nude)
|
149 |
+
|
150 |
+
return watermark
|