Egrt commited on
Commit
253a98e
1 Parent(s): e50762b

更强的模型

Browse files
cyclegan.py CHANGED
@@ -19,6 +19,10 @@ class CYCLEGAN(object):
19
  #-----------------------------------------------#
20
  "input_shape" : [112, 112],
21
  #-------------------------------#
 
 
 
 
22
  # 是否使用Cuda
23
  # 没有GPU可以设置成False
24
  #-------------------------------#
@@ -64,9 +68,14 @@ class CYCLEGAN(object):
64
  #---------------------------------------------------------#
65
  image = cvtColor(image)
66
  #---------------------------------------------------------#
 
 
 
 
 
67
  # 添加上batch_size维度
68
  #---------------------------------------------------------#
69
- image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)), 0)
70
 
71
  with torch.no_grad():
72
  images = torch.from_numpy(image_data)
@@ -80,10 +89,17 @@ class CYCLEGAN(object):
80
  #---------------------------------------------------#
81
  # 转为numpy
82
  #---------------------------------------------------#
83
- pr = pr.permute(1, 2, 0).cpu().numpy()
 
 
 
 
 
 
 
 
84
 
85
  image = postprocess_output(pr)
86
- image = np.clip(image, 0, 255)
87
  image = Image.fromarray(np.uint8(image))
88
 
89
  return image
 
19
  #-----------------------------------------------#
20
  "input_shape" : [112, 112],
21
  #-------------------------------#
22
+ # 是否进行不失真的resize
23
+ #-------------------------------#
24
+ "letterbox_image" : True,
25
+ #-------------------------------#
26
  # 是否使用Cuda
27
  # 没有GPU可以设置成False
28
  #-------------------------------#
 
68
  #---------------------------------------------------------#
69
  image = cvtColor(image)
70
  #---------------------------------------------------------#
71
+ # 给图像增加灰条,实现不失真的resize
72
+ # 也可以直接resize进行识别
73
+ #---------------------------------------------------------#
74
+ image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
75
+ #---------------------------------------------------------#
76
  # 添加上batch_size维度
77
  #---------------------------------------------------------#
78
+ image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
79
 
80
  with torch.no_grad():
81
  images = torch.from_numpy(image_data)
 
89
  #---------------------------------------------------#
90
  # 转为numpy
91
  #---------------------------------------------------#
92
+ pr = pr.permute(1, 2, 0).cpu().numpy()
93
+
94
+ #--------------------------------------#
95
+ # 将灰条部分截取掉
96
+ #--------------------------------------#
97
+ if nw is not None:
98
+ pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
99
+ int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
100
+
101
 
102
  image = postprocess_output(pr)
 
103
  image = Image.fromarray(np.uint8(image))
104
 
105
  return image
model_data/G_model_B2A_last_epoch_weights.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:03aa46c6ebd8da9196749e02e214687ea3dd7143976d030f92b91b0cb2496583
3
  size 11888773
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1815cd8f77471a8712b9a80b20da4cd7afe7aad2b32ad48cd205d1c370a65dc2
3
  size 11888773