Ravi21 commited on
Commit
6e858ba
1 Parent(s): 70755df

Upload 4 files

Browse files
Files changed (4) hide show
  1. demo.txt +6 -0
  2. extract_clothes_edges.py +29 -0
  3. inference.py +75 -0
  4. inference_cpu.py +78 -0
demo.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ input.png 006026_1.jpg
2
+ input.png 017575_1.jpg
3
+ input.png 014396_1.jpg
4
+ input.png 003434_1.jpg
5
+ input.png 019119_1.jpg
6
+ input.png 010567_1.jpg
extract_clothes_edges.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from tqdm.auto import tqdm
6
+
7
+ clothes_dir = 'dataset/test_clothes'
8
+ clothes_edges_dir = 'dataset/test_edge'
9
+
10
+
11
+ for img_fn in tqdm(os.listdir(clothes_dir)):
12
+ cloth_img_fp = os.path.join(clothes_dir, img_fn)
13
+ img = cv2.imread(cloth_img_fp)
14
+ OLD_IMG = img.copy()
15
+ mask = np.zeros(img.shape[:2], np.uint8)
16
+ SIZE = (1, 65)
17
+ bgdModle = np.zeros(SIZE, np.float64)
18
+
19
+ fgdModle = np.zeros(SIZE, np.float64)
20
+ rect = (1, 1, img.shape[1], img.shape[0])
21
+ cv2.grabCut(img, mask, rect, bgdModle, fgdModle, 10, cv2.GC_INIT_WITH_RECT)
22
+
23
+ mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
24
+ img *= mask2[:, :, np.newaxis]
25
+
26
+ mask2 *= 255
27
+
28
+ cloth_edges_img_fp = os.path.join(clothes_edges_dir, img_fn)
29
+ cv2.imwrite(cloth_edges_img_fp, mask2)
inference.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from options.test_options import TestOptions
3
+ from data.data_loader_test import CreateDataLoader
4
+ from models.networks import ResUnetGenerator, load_checkpoint
5
+ from models.afwm import AFWM
6
+ import torch.nn as nn
7
+ import os
8
+ import numpy as np
9
+ import torch
10
+ import cv2
11
+ import torch.nn.functional as F
12
+ from tqdm.auto import tqdm
13
+
14
+ opt = TestOptions().parse()
15
+
16
+ # list human-cloth pairs
17
+ with open('demo.txt', 'w') as file:
18
+ lines = [f'input.png {cloth_img_fn}\n' for cloth_img_fn in os.listdir('dataset/test_clothes')]
19
+ file.writelines(lines)
20
+
21
+ data_loader = CreateDataLoader(opt)
22
+ dataset = data_loader.load_data()
23
+ dataset_size = len(data_loader)
24
+ print('[INFO] Data Loaded')
25
+
26
+ warp_model = AFWM(opt, 3)
27
+ warp_model.eval()
28
+ warp_model.cuda()
29
+ load_checkpoint(warp_model, opt.warp_checkpoint)
30
+ print('[INFO] Warp Model Loaded')
31
+
32
+ gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
33
+ gen_model.eval()
34
+ gen_model.cuda()
35
+ load_checkpoint(gen_model, opt.gen_checkpoint)
36
+ print('[INFO] Gen Model Loaded')
37
+
38
+ def get_result_images():
39
+
40
+ result_images = []
41
+ for i, data in tqdm(enumerate(dataset)):
42
+
43
+ real_image = data['image']
44
+ clothes = data['clothes']
45
+ ##edge is extracted from the clothes image with the built-in function in python
46
+ edge = data['edge']
47
+ edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
48
+ clothes = clothes * edge
49
+
50
+ flow_out = warp_model(real_image.cuda(), clothes.cuda())
51
+ warped_cloth, last_flow, = flow_out
52
+ warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
53
+ mode='bilinear', padding_mode='zeros')
54
+
55
+
56
+ gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
57
+ gen_outputs = gen_model(gen_inputs)
58
+ p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
59
+ p_rendered = torch.tanh(p_rendered)
60
+ m_composite = torch.sigmoid(m_composite)
61
+ m_composite = m_composite * warped_edge
62
+ p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
63
+
64
+ a = real_image.float().cuda()
65
+ b= clothes.cuda()
66
+ c = p_tryon
67
+
68
+ combine = torch.cat([b[0], c[0]], 2).squeeze()
69
+ cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
70
+ rgb = (cv_img * 255).astype(np.uint8)
71
+
72
+ result_images.append(rgb)
73
+
74
+
75
+ return result_images
inference_cpu.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from options.test_options import TestOptions
3
+ from data.data_loader_test import CreateDataLoader
4
+ from models.networks import ResUnetGenerator, load_checkpoint
5
+ from models.afwm import AFWM
6
+ import torch.nn as nn
7
+ import os
8
+ import numpy as np
9
+ import torch
10
+ import cv2
11
+ import torch.nn.functional as F
12
+ from tqdm.auto import tqdm
13
+
14
+ opt = TestOptions().parse()
15
+
16
+ # list human-cloth pairs
17
+ with open('demo.txt', 'w') as file:
18
+ lines = [f'input.png {cloth_img_fn}\n' for cloth_img_fn in os.listdir('dataset/test_clothes')]
19
+ file.writelines(lines)
20
+
21
+ data_loader = CreateDataLoader(opt)
22
+ dataset = data_loader.load_data()
23
+ dataset_size = len(data_loader)
24
+ print('[INFO] Data Loaded')
25
+
26
+ warp_model = AFWM(opt, 3)
27
+ warp_model.eval()
28
+
29
+ load_checkpoint(warp_model, opt.warp_checkpoint)
30
+ print('[INFO] Warp Model Loaded')
31
+
32
+ gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
33
+ gen_model.eval()
34
+
35
+ load_checkpoint(gen_model, opt.gen_checkpoint)
36
+ print('[INFO] Gen Model Loaded')
37
+
38
+ def get_result_images():
39
+
40
+ result_images = []
41
+ for i, data in tqdm(enumerate(dataset)):
42
+
43
+ real_image = data['image']
44
+ clothes = data['clothes']
45
+ ##edge is extracted from the clothes image with the built-in function in python
46
+ edge = data['edge']
47
+ edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
48
+ clothes = clothes * edge
49
+ print(clothes.device, edge.device)
50
+
51
+ flow_out = warp_model(real_image, clothes)
52
+ warped_cloth, last_flow, = flow_out
53
+ warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
54
+ mode='bilinear', padding_mode='zeros')
55
+
56
+
57
+ gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
58
+ gen_outputs = gen_model(gen_inputs)
59
+ p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
60
+ p_rendered = torch.tanh(p_rendered)
61
+ m_composite = torch.sigmoid(m_composite)
62
+ m_composite = m_composite * warped_edge
63
+ p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
64
+
65
+ a = real_image.float().cuda()
66
+ b= clothes.cuda()
67
+ c = p_tryon
68
+
69
+ combine = torch.cat([b[0], c[0]], 2).squeeze()
70
+ cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
71
+ rgb = (cv_img * 255).astype(np.uint8)
72
+
73
+ result_images.append(rgb)
74
+
75
+
76
+ return result_images
77
+
78
+ get_result_images()