Upload 5 files
Browse files- test.py +89 -0
- test.sh +1 -0
- util/__init__.py +1 -0
- util/image_pool.py +31 -0
- util/util.py +94 -0
test.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from options.test_options import TestOptions
|
3 |
+
from data.data_loader_test import CreateDataLoader
|
4 |
+
from models.networks import ResUnetGenerator, load_checkpoint
|
5 |
+
from models.afwm import AFWM
|
6 |
+
import torch.nn as nn
|
7 |
+
import os
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import cv2
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
opt = TestOptions().parse()
|
14 |
+
|
15 |
+
start_epoch, epoch_iter = 1, 0
|
16 |
+
|
17 |
+
# list human-cloth pairs
|
18 |
+
with open('demo.txt', 'w') as file:
|
19 |
+
lines = [f'3.png {cloth_img_fn}\n' for cloth_img_fn in os.listdir('dataset/test_clothes')]
|
20 |
+
file.writelines(lines)
|
21 |
+
|
22 |
+
data_loader = CreateDataLoader(opt)
|
23 |
+
dataset = data_loader.load_data()
|
24 |
+
dataset_size = len(data_loader)
|
25 |
+
print(dataset_size)
|
26 |
+
|
27 |
+
warp_model = AFWM(opt, 3)
|
28 |
+
print(warp_model)
|
29 |
+
warp_model.eval()
|
30 |
+
warp_model.cuda()
|
31 |
+
load_checkpoint(warp_model, opt.warp_checkpoint)
|
32 |
+
|
33 |
+
gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
|
34 |
+
print(gen_model)
|
35 |
+
gen_model.eval()
|
36 |
+
gen_model.cuda()
|
37 |
+
load_checkpoint(gen_model, opt.gen_checkpoint)
|
38 |
+
|
39 |
+
total_steps = (start_epoch-1) * dataset_size + epoch_iter
|
40 |
+
step = 0
|
41 |
+
step_per_batch = dataset_size / opt.batchSize
|
42 |
+
|
43 |
+
for epoch in range(1,2):
|
44 |
+
|
45 |
+
for i, data in enumerate(dataset, start=epoch_iter):
|
46 |
+
iter_start_time = time.time()
|
47 |
+
total_steps += opt.batchSize
|
48 |
+
epoch_iter += opt.batchSize
|
49 |
+
|
50 |
+
real_image = data['image']
|
51 |
+
clothes = data['clothes']
|
52 |
+
##edge is extracted from the clothes image with the built-in function in python
|
53 |
+
edge = data['edge']
|
54 |
+
edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
|
55 |
+
clothes = clothes * edge
|
56 |
+
|
57 |
+
flow_out = warp_model(real_image.cuda(), clothes.cuda())
|
58 |
+
warped_cloth, last_flow, = flow_out
|
59 |
+
warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
|
60 |
+
mode='bilinear', padding_mode='zeros')
|
61 |
+
|
62 |
+
gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
|
63 |
+
gen_outputs = gen_model(gen_inputs)
|
64 |
+
p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
|
65 |
+
p_rendered = torch.tanh(p_rendered)
|
66 |
+
m_composite = torch.sigmoid(m_composite)
|
67 |
+
m_composite = m_composite * warped_edge
|
68 |
+
p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
|
69 |
+
|
70 |
+
path = 'results/' + opt.name
|
71 |
+
os.makedirs(path, exist_ok=True)
|
72 |
+
sub_path = path + '/PFAFN'
|
73 |
+
os.makedirs(sub_path,exist_ok=True)
|
74 |
+
|
75 |
+
if step % 1 == 0:
|
76 |
+
a = real_image.float().cuda()
|
77 |
+
b= clothes.cuda()
|
78 |
+
c = p_tryon
|
79 |
+
combine = torch.cat([a[0],b[0],c[0]], 2).squeeze()
|
80 |
+
cv_img=(combine.permute(1,2,0).detach().cpu().numpy()+1)/2
|
81 |
+
rgb=(cv_img*255).astype(np.uint8)
|
82 |
+
bgr=cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR)
|
83 |
+
cv2.imwrite(sub_path+'/'+str(step)+'.jpg',bgr)
|
84 |
+
|
85 |
+
step += 1
|
86 |
+
if epoch_iter >= dataset_size:
|
87 |
+
break
|
88 |
+
|
89 |
+
|
test.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0
|
util/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# util_init
|
util/image_pool.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
from torch.autograd import Variable
|
4 |
+
class ImagePool():
|
5 |
+
def __init__(self, pool_size):
|
6 |
+
self.pool_size = pool_size
|
7 |
+
if self.pool_size > 0:
|
8 |
+
self.num_imgs = 0
|
9 |
+
self.images = []
|
10 |
+
|
11 |
+
def query(self, images):
|
12 |
+
if self.pool_size == 0:
|
13 |
+
return images
|
14 |
+
return_images = []
|
15 |
+
for image in images.data:
|
16 |
+
image = torch.unsqueeze(image, 0)
|
17 |
+
if self.num_imgs < self.pool_size:
|
18 |
+
self.num_imgs = self.num_imgs + 1
|
19 |
+
self.images.append(image)
|
20 |
+
return_images.append(image)
|
21 |
+
else:
|
22 |
+
p = random.uniform(0, 1)
|
23 |
+
if p > 0.5:
|
24 |
+
random_id = random.randint(0, self.pool_size-1)
|
25 |
+
tmp = self.images[random_id].clone()
|
26 |
+
self.images[random_id] = image
|
27 |
+
return_images.append(tmp)
|
28 |
+
else:
|
29 |
+
return_images.append(image)
|
30 |
+
return_images = Variable(torch.cat(return_images, 0))
|
31 |
+
return return_images
|
util/util.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
|
8 |
+
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
|
9 |
+
if isinstance(image_tensor, list):
|
10 |
+
image_numpy = []
|
11 |
+
for i in range(len(image_tensor)):
|
12 |
+
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
|
13 |
+
return image_numpy
|
14 |
+
image_numpy = image_tensor.cpu().float().numpy()
|
15 |
+
|
16 |
+
image_numpy = (image_numpy + 1) / 2.0
|
17 |
+
image_numpy = np.clip(image_numpy, 0, 1)
|
18 |
+
if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
|
19 |
+
image_numpy = image_numpy[:,:,0]
|
20 |
+
|
21 |
+
return image_numpy
|
22 |
+
|
23 |
+
def tensor2label(label_tensor, n_label, imtype=np.uint8):
|
24 |
+
if n_label == 0:
|
25 |
+
return tensor2im(label_tensor, imtype)
|
26 |
+
label_tensor = label_tensor.cpu().float()
|
27 |
+
if label_tensor.size()[0] > 1:
|
28 |
+
label_tensor = label_tensor.max(0, keepdim=True)[1]
|
29 |
+
label_tensor = Colorize(n_label)(label_tensor)
|
30 |
+
label_numpy = label_tensor.numpy()
|
31 |
+
label_numpy = label_numpy / 255.0
|
32 |
+
|
33 |
+
return label_numpy
|
34 |
+
|
35 |
+
def save_image(image_numpy, image_path):
|
36 |
+
image_pil = Image.fromarray(image_numpy)
|
37 |
+
image_pil.save(image_path)
|
38 |
+
|
39 |
+
def mkdirs(paths):
|
40 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
41 |
+
for path in paths:
|
42 |
+
mkdir(path)
|
43 |
+
else:
|
44 |
+
mkdir(paths)
|
45 |
+
|
46 |
+
def mkdir(path):
|
47 |
+
if not os.path.exists(path):
|
48 |
+
os.makedirs(path)
|
49 |
+
|
50 |
+
|
51 |
+
def uint82bin(n, count=8):
|
52 |
+
"""returns the binary of integer n, count refers to amount of bits"""
|
53 |
+
return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
|
54 |
+
|
55 |
+
def labelcolormap(N):
|
56 |
+
if N == 35: # cityscape
|
57 |
+
cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
|
58 |
+
(128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
|
59 |
+
(180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
|
60 |
+
(107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
|
61 |
+
( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
|
62 |
+
dtype=np.uint8)
|
63 |
+
else:
|
64 |
+
cmap = np.zeros((N, 3), dtype=np.uint8)
|
65 |
+
for i in range(N):
|
66 |
+
r, g, b = 0, 0, 0
|
67 |
+
id = i
|
68 |
+
for j in range(7):
|
69 |
+
str_id = uint82bin(id)
|
70 |
+
r = r ^ (np.uint8(str_id[-1]) << (7-j))
|
71 |
+
g = g ^ (np.uint8(str_id[-2]) << (7-j))
|
72 |
+
b = b ^ (np.uint8(str_id[-3]) << (7-j))
|
73 |
+
id = id >> 3
|
74 |
+
cmap[i, 0] = r
|
75 |
+
cmap[i, 1] = g
|
76 |
+
cmap[i, 2] = b
|
77 |
+
return cmap
|
78 |
+
|
79 |
+
class Colorize(object):
|
80 |
+
def __init__(self, n=35):
|
81 |
+
self.cmap = labelcolormap(n)
|
82 |
+
self.cmap = torch.from_numpy(self.cmap[:n])
|
83 |
+
|
84 |
+
def __call__(self, gray_image):
|
85 |
+
size = gray_image.size()
|
86 |
+
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
|
87 |
+
|
88 |
+
for label in range(0, len(self.cmap)):
|
89 |
+
mask = (label == gray_image[0]).cpu()
|
90 |
+
color_image[0][mask] = self.cmap[label][0]
|
91 |
+
color_image[1][mask] = self.cmap[label][1]
|
92 |
+
color_image[2][mask] = self.cmap[label][2]
|
93 |
+
|
94 |
+
return color_image
|