Egrt commited on
Commit
7c02a0d
1 Parent(s): 72d244a
Files changed (1) hide show
  1. cyclegan.py +2 -33
cyclegan.py CHANGED
@@ -1,4 +1,3 @@
1
- import cv2
2
  import numpy as np
3
  import torch
4
  from PIL import Image
@@ -15,19 +14,11 @@ class CYCLEGAN(object):
15
  # model_path指向logs文件夹下的权值文件
16
  #-----------------------------------------------#
17
  "model_path" : 'model_data/G_model_B2A_last_epoch_weights.pth',
18
- #-----------------------------------------------#
19
- # 输入图像大小的设置
20
- #-----------------------------------------------#
21
- "input_shape" : [112, 112],
22
- #-------------------------------#
23
- # 是否进行不失真的resize
24
- #-------------------------------#
25
- "letterbox_image" : True,
26
  #-------------------------------#
27
  # 是否使用Cuda
28
  # 没有GPU可以设置成False
29
  #-------------------------------#
30
- "cuda" : False,
31
  }
32
 
33
  #---------------------------------------------------#
@@ -68,16 +59,6 @@ class CYCLEGAN(object):
68
  # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
69
  #---------------------------------------------------------#
70
  image = cvtColor(image)
71
- #---------------------------------------------------#
72
- # 获得高宽
73
- #---------------------------------------------------#
74
- orininal_h = np.array(image).shape[0]
75
- orininal_w = np.array(image).shape[1]
76
- #---------------------------------------------------------#
77
- # 给图像增加灰条,实现不失真的resize
78
- # 也可以直接resize进行识别
79
- #---------------------------------------------------------#
80
- image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
81
  #---------------------------------------------------------#
82
  # 添加上batch_size维度
83
  #---------------------------------------------------------#
@@ -95,19 +76,7 @@ class CYCLEGAN(object):
95
  #---------------------------------------------------#
96
  # 转为numpy
97
  #---------------------------------------------------#
98
- pr = pr.permute(1, 2, 0).cpu().numpy()
99
-
100
- #--------------------------------------#
101
- # 将灰条部分截取掉
102
- #--------------------------------------#
103
- if nw is not None:
104
- pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
105
- int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
106
-
107
- #---------------------------------------------------#
108
- # 进行图片的resize
109
- #---------------------------------------------------#
110
- pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)
111
 
112
  image = postprocess_output(pr)
113
  image = np.clip(image, 0, 255)
 
 
1
  import numpy as np
2
  import torch
3
  from PIL import Image
 
14
  # model_path指向logs文件夹下的权值文件
15
  #-----------------------------------------------#
16
  "model_path" : 'model_data/G_model_B2A_last_epoch_weights.pth',
 
 
 
 
 
 
 
 
17
  #-------------------------------#
18
  # 是否使用Cuda
19
  # 没有GPU可以设置成False
20
  #-------------------------------#
21
+ "cuda" : True,
22
  }
23
 
24
  #---------------------------------------------------#
 
59
  # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
60
  #---------------------------------------------------------#
61
  image = cvtColor(image)
 
 
 
 
 
 
 
 
 
 
62
  #---------------------------------------------------------#
63
  # 添加上batch_size维度
64
  #---------------------------------------------------------#
 
76
  #---------------------------------------------------#
77
  # 转为numpy
78
  #---------------------------------------------------#
79
+ pr = pr.permute(1, 2, 0).cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  image = postprocess_output(pr)
82
  image = np.clip(image, 0, 255)