Ahsen Khaliq committed
Commit 34ad95a
Parent: 1d11e4b

Update style_transfer_folder.py

Files changed (1)
  1. style_transfer_folder.py +21 -27
style_transfer_folder.py CHANGED
@@ -58,35 +58,29 @@ if __name__ == '__main__':
     psp_encoder = PSPEncoder(args.psp_encoder_ckpt, output_size=args.size).to(device)
     psp_encoder.eval()
 
-    input_img_paths = sorted(glob.glob(os.path.join(args.input_img_path, '*.*')))
-    style_img_paths = sorted(glob.glob(os.path.join(args.style_img_path, '*.*')))[:]
-
     num = 0
 
-    for input_img_path in input_img_paths:
-        print(num)
-        num += 1
-
-        name_in = os.path.splitext(os.path.basename(input_img_path))[0]
-        img_in = cv2.imread(input_img_path, 1)
-        img_in_ten = cv2ten(img_in, device)
-        img_in = cv2.resize(img_in, (args.size, args.size))
-
-        for style_img_path in style_img_paths:
-            name_style = os.path.splitext(os.path.basename(style_img_path))[0]
-            img_style = cv2.imread(style_img_path, 1)
-            img_style_ten = cv2ten(img_style, device)
-            img_style = cv2.resize(img_style, (args.size, args.size))
-
-            with torch.no_grad():
-                sample_style = g_ema.get_z_embed(img_style_ten)
-                sample_in = psp_encoder(img_in_ten)
-                img_out_ten, _ = g_ema([sample_in], z_embed=sample_style, add_weight_index=args.add_weight_index,
-                                       input_is_latent=True, return_latents=False, randomize_noise=False)
-                img_out = ten2cv(img_out_ten)
-            out = np.concatenate([img_in, img_style, img_out], axis=1)
-            # out = img_out
-            cv2.imwrite(f'{args.outdir}/{name_in}_v_{name_style}.jpg', out)
+
+    print(num)
+    num += 1
+
+    img_in = cv2.imread(args.input_img_path)
+    img_in_ten = cv2ten(img_in, device)
+    img_in = cv2.resize(img_in, (args.size, args.size))
+
+
+    img_style = cv2.imread(args.style_img_path)
+    img_style_ten = cv2ten(img_style, device)
+    img_style = cv2.resize(img_style, (args.size, args.size))
+
+    with torch.no_grad():
+        sample_style = g_ema.get_z_embed(img_style_ten)
+        sample_in = psp_encoder(img_in_ten)
+        img_out_ten, _ = g_ema([sample_in], z_embed=sample_style, add_weight_index=args.add_weight_index,
+                               input_is_latent=True, return_latents=False, randomize_noise=False)
+        img_out = ten2cv(img_out_ten)
+    out = np.concatenate([img_in, img_style, img_out], axis=1)
+    cv2.imwrite('out.jpg', out)
 
     print('Done!')
 
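For readers following the change: after this commit the script processes a single content/style pair taken directly from args.input_img_path and args.style_img_path instead of looping over two folders, and writes the result to out.jpg. The sketch below is not part of the commit; it only factors the post-commit flow into a standalone helper. The function name style_transfer_single and the out_path parameter are illustrative, and the model objects (g_ema, psp_encoder) and tensor helpers (cv2ten, ten2cv) are assumed to be the ones already defined in the surrounding script.

import cv2
import numpy as np
import torch


def style_transfer_single(input_img_path, style_img_path, out_path,
                          g_ema, psp_encoder, cv2ten, ten2cv,
                          device, size, add_weight_index):
    # Load the content image, keeping a tensor copy for the models and a
    # resized BGR copy for the side-by-side output.
    img_in = cv2.imread(input_img_path)
    img_in_ten = cv2ten(img_in, device)
    img_in = cv2.resize(img_in, (size, size))

    # Same for the style image.
    img_style = cv2.imread(style_img_path)
    img_style_ten = cv2ten(img_style, device)
    img_style = cv2.resize(img_style, (size, size))

    with torch.no_grad():
        # Style embedding from the generator, content latent from the PSPEncoder.
        sample_style = g_ema.get_z_embed(img_style_ten)
        sample_in = psp_encoder(img_in_ten)
        img_out_ten, _ = g_ema([sample_in], z_embed=sample_style,
                               add_weight_index=add_weight_index,
                               input_is_latent=True, return_latents=False,
                               randomize_noise=False)
        img_out = ten2cv(img_out_ten)

    # Write content | style | result as one image, mirroring the script's output.
    out = np.concatenate([img_in, img_style, img_out], axis=1)
    cv2.imwrite(out_path, out)
    return out

Called with the script's own objects, e.g. style_transfer_single(args.input_img_path, args.style_img_path, 'out.jpg', g_ema, psp_encoder, cv2ten, ten2cv, device, args.size, args.add_weight_index), this reproduces what the new hunk does inline.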