Ravi21 committed on
Commit
b2fcba7
1 Parent(s): 46a2636

Upload 5 files

Files changed (5)
  1. test.py +89 -0
  2. test.sh +1 -0
  3. util/__init__.py +1 -0
  4. util/image_pool.py +31 -0
  5. util/util.py +94 -0
test.py ADDED
@@ -0,0 +1,89 @@
+ import time
+ from options.test_options import TestOptions
+ from data.data_loader_test import CreateDataLoader
+ from models.networks import ResUnetGenerator, load_checkpoint
+ from models.afwm import AFWM
+ import torch.nn as nn
+ import os
+ import numpy as np
+ import torch
+ import cv2
+ import torch.nn.functional as F
+
+ opt = TestOptions().parse()
+
+ start_epoch, epoch_iter = 1, 0
+
+ # list human-cloth pairs: pair the fixed person image 3.png with every test cloth
+ with open('demo.txt', 'w') as file:
+     lines = [f'3.png {cloth_img_fn}\n' for cloth_img_fn in os.listdir('dataset/test_clothes')]
+     file.writelines(lines)
+
+ data_loader = CreateDataLoader(opt)
+ dataset = data_loader.load_data()
+ dataset_size = len(data_loader)
+ print(dataset_size)
+
+ warp_model = AFWM(opt, 3)
+ print(warp_model)
+ warp_model.eval()
+ warp_model.cuda()
+ load_checkpoint(warp_model, opt.warp_checkpoint)
+
+ gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
+ print(gen_model)
+ gen_model.eval()
+ gen_model.cuda()
+ load_checkpoint(gen_model, opt.gen_checkpoint)
+
+ total_steps = (start_epoch - 1) * dataset_size + epoch_iter
+ step = 0
+ step_per_batch = dataset_size / opt.batchSize
+
+ for epoch in range(1, 2):
+
+     for i, data in enumerate(dataset, start=epoch_iter):
+         iter_start_time = time.time()
+         total_steps += opt.batchSize
+         epoch_iter += opt.batchSize
+
+         real_image = data['image']
+         clothes = data['clothes']
+         # the edge map is extracted from the clothes image; binarize it at 0.5
+         edge = data['edge']
+         edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(int))  # np.int is removed in recent NumPy
+         clothes = clothes * edge
+
+         flow_out = warp_model(real_image.cuda(), clothes.cuda())
+         warped_cloth, last_flow = flow_out
+         warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
+                                     mode='bilinear', padding_mode='zeros')
+
+         gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
+         gen_outputs = gen_model(gen_inputs)
+         p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
+         p_rendered = torch.tanh(p_rendered)
+         m_composite = torch.sigmoid(m_composite)
+         m_composite = m_composite * warped_edge
+         p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
+
+         path = 'results/' + opt.name
+         os.makedirs(path, exist_ok=True)
+         sub_path = path + '/PFAFN'
+         os.makedirs(sub_path, exist_ok=True)
+
+         if step % 1 == 0:
+             # side-by-side strip: person | cloth | try-on, mapped from [-1, 1] to [0, 255]
+             a = real_image.float().cuda()
+             b = clothes.cuda()
+             c = p_tryon
+             combine = torch.cat([a[0], b[0], c[0]], 2).squeeze()
+             cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
+             rgb = (cv_img * 255).astype(np.uint8)
+             bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
+             cv2.imwrite(sub_path + '/' + str(step) + '.jpg', bgr)
+
+         step += 1
+         if epoch_iter >= dataset_size:
+             break
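The heart of test.py is the warp-then-composite step: `F.grid_sample` warps the cloth mask along the predicted flow, and the generator's 4-channel output is split into a rendered person and a composition mask. Below is a minimal, self-contained sketch of that arithmetic on dummy CPU tensors (no checkpoints, dataset, or GPU needed); the shapes, the near-identity flow grid, and the random stand-in for `gen_model`'s output are all illustrative assumptions, not part of this upload.

```python
import torch
import torch.nn.functional as F

N, H, W = 1, 256, 192
real_image = torch.rand(N, 3, H, W) * 2 - 1        # images normalized to [-1, 1]
warped_cloth = torch.rand(N, 3, H, W) * 2 - 1      # stand-in for warp_model output
edge = (torch.rand(N, 1, H, W) > 0.5).float()      # binary cloth mask

# (approximately) identity flow: grid_sample expects an (N, H, W, 2) grid in [-1, 1]
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H),
                        torch.linspace(-1, 1, W), indexing='ij')  # PyTorch >= 1.10
grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)
warped_edge = F.grid_sample(edge, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)

# stand-in for gen_model(gen_inputs): any (N, 4, H, W) tensor works here
gen_outputs = torch.randn(N, 4, H, W)
p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
p_rendered = torch.tanh(p_rendered)                       # rendered person, in [-1, 1]
m_composite = torch.sigmoid(m_composite) * warped_edge    # composition mask
p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
print(p_tryon.shape)  # torch.Size([1, 3, 256, 192])
```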
test.sh ADDED
@@ -0,0 +1 @@
+ python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0
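Note that test.py also loads weights from `opt.warp_checkpoint` and `opt.gen_checkpoint`, which the one-liner above never sets, so it presumably relies on defaults baked into `TestOptions` (this diff does not include `options/test_options.py`). Assuming those options are exposed as flags, an explicit invocation might look like the following; the checkpoint paths are hypothetical:

```sh
python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0 \
    --warp_checkpoint checkpoints/PFAFN/warp_model_final.pth \
    --gen_checkpoint checkpoints/PFAFN/gen_model_final.pth
```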
util/__init__.py ADDED
@@ -0,0 +1 @@
+ # util_init
util/image_pool.py ADDED
@@ -0,0 +1,31 @@
+ import random
+ import torch
+ from torch.autograd import Variable
+
+
+ class ImagePool():
+     """History buffer of previously generated images (pix2pix/CycleGAN style)."""
+     def __init__(self, pool_size):
+         self.pool_size = pool_size
+         if self.pool_size > 0:
+             self.num_imgs = 0
+             self.images = []
+
+     def query(self, images):
+         # with pool_size == 0 the pool is a no-op
+         if self.pool_size == 0:
+             return images
+         return_images = []
+         for image in images.data:
+             image = torch.unsqueeze(image, 0)
+             if self.num_imgs < self.pool_size:
+                 # pool not full yet: store and return the fresh image
+                 self.num_imgs = self.num_imgs + 1
+                 self.images.append(image)
+                 return_images.append(image)
+             else:
+                 # pool full: with probability 0.5, swap the fresh image into the
+                 # buffer and return an older one in its place
+                 p = random.uniform(0, 1)
+                 if p > 0.5:
+                     random_id = random.randint(0, self.pool_size - 1)
+                     tmp = self.images[random_id].clone()
+                     self.images[random_id] = image
+                     return_images.append(tmp)
+                 else:
+                     return_images.append(image)
+         return_images = Variable(torch.cat(return_images, 0))
+         return return_images
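Nothing in this upload calls `ImagePool` yet; it is the standard pix2pix/CycleGAN-style history buffer, used during GAN training so the discriminator sees a mix of fresh and replayed fakes, which tends to stabilize adversarial training. A hypothetical usage sketch (the tensor shapes and the commented loss line are illustrative, not from this repo):

```python
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
for step in range(3):
    fake = torch.randn(2, 3, 64, 64)   # stand-in for a generator output batch
    fake_for_D = pool.query(fake)      # once full: 50/50 fresh vs. replayed fakes
    # loss_D_fake = criterion(netD(fake_for_D.detach()), False)  # typical pattern
    print(fake_for_D.shape)            # torch.Size([2, 3, 64, 64])
```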
util/util.py ADDED
@@ -0,0 +1,94 @@
+ from __future__ import print_function
+
+ import torch
+ from PIL import Image
+ import numpy as np
+ import os
+
+
+ def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
+     if isinstance(image_tensor, list):
+         image_numpy = []
+         for i in range(len(image_tensor)):
+             image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+         return image_numpy
+     image_numpy = image_tensor.cpu().float().numpy()
+
+     # map from [-1, 1] to [0, 1]
+     image_numpy = (image_numpy + 1) / 2.0
+     image_numpy = np.clip(image_numpy, 0, 1)
+     if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
+         image_numpy = image_numpy[:, :, 0]
+
+     return image_numpy
+
+
+ def tensor2label(label_tensor, n_label, imtype=np.uint8):
+     if n_label == 0:
+         return tensor2im(label_tensor, imtype)
+     label_tensor = label_tensor.cpu().float()
+     if label_tensor.size()[0] > 1:
+         # collapse per-class scores to a single label map via argmax
+         label_tensor = label_tensor.max(0, keepdim=True)[1]
+     label_tensor = Colorize(n_label)(label_tensor)
+     label_numpy = label_tensor.numpy()
+     label_numpy = label_numpy / 255.0
+
+     return label_numpy
+
+
+ def save_image(image_numpy, image_path):
+     image_pil = Image.fromarray(image_numpy)
+     image_pil.save(image_path)
+
+
+ def mkdirs(paths):
+     if isinstance(paths, list) and not isinstance(paths, str):
+         for path in paths:
+             mkdir(path)
+     else:
+         mkdir(paths)
+
+
+ def mkdir(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+
+ def uint82bin(n, count=8):
+     """Returns the binary representation of integer n; count is the number of bits."""
+     return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
+
+
+ def labelcolormap(N):
+     if N == 35:  # Cityscapes palette
+         cmap = np.array([(  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (111,  74,   0), ( 81,   0,  81),
+                          (128,  64, 128), (244,  35, 232), (250, 170, 160), (230, 150, 140), ( 70,  70,  70), (102, 102, 156), (190, 153, 153),
+                          (180, 165, 180), (150, 100, 100), (150, 120,  90), (153, 153, 153), (153, 153, 153), (250, 170,  30), (220, 220,   0),
+                          (107, 142,  35), (152, 251, 152), ( 70, 130, 180), (220,  20,  60), (255,   0,   0), (  0,   0, 142), (  0,   0,  70),
+                          (  0,  60, 100), (  0,   0,  90), (  0,   0, 110), (  0,  80, 100), (  0,   0, 230), (119,  11,  32), (  0,   0, 142)],
+                         dtype=np.uint8)
+     else:
+         # generic palette: spread the bits of each label id across the RGB channels
+         cmap = np.zeros((N, 3), dtype=np.uint8)
+         for i in range(N):
+             r, g, b = 0, 0, 0
+             id = i
+             for j in range(7):
+                 str_id = uint82bin(id)
+                 r = r ^ (np.uint8(str_id[-1]) << (7 - j))
+                 g = g ^ (np.uint8(str_id[-2]) << (7 - j))
+                 b = b ^ (np.uint8(str_id[-3]) << (7 - j))
+                 id = id >> 3
+             cmap[i, 0] = r
+             cmap[i, 1] = g
+             cmap[i, 2] = b
+     return cmap
+
+
+ class Colorize(object):
+     def __init__(self, n=35):
+         self.cmap = labelcolormap(n)
+         self.cmap = torch.from_numpy(self.cmap[:n])
+
+     def __call__(self, gray_image):
+         size = gray_image.size()
+         color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+         # paint every pixel of each label with that label's palette color
+         for label in range(0, len(self.cmap)):
+             mask = (label == gray_image[0]).cpu()
+             color_image[0][mask] = self.cmap[label][0]
+             color_image[1][mask] = self.cmap[label][1]
+             color_image[2][mask] = self.cmap[label][2]
+
+         return color_image
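A quick sketch of the label-visualization path above, run on a dummy segmentation map; `n_label = 20` and all shapes are arbitrary illustration choices, not values used anywhere in this upload:

```python
import torch
from util.util import Colorize, tensor2label

gray = torch.randint(0, 20, (1, 8, 8))                 # (1, H, W) integer label map
color = Colorize(n=20)(gray)                           # (3, H, W) uint8 color image
print(color.shape, color.dtype)                        # torch.Size([3, 8, 8]) torch.uint8

scores = torch.zeros(20, 8, 8).scatter_(0, gray, 1.0)  # (C, H, W) one-hot class scores
vis = tensor2label(scores, 20)                         # (3, H, W) floats in [0, 1]
print(vis.shape)
```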