PKUWilliamYang committed
Commit 5f3f3d5
1 Parent(s): 2432b66

Update vtoonify_model.py

Files changed (1)
  1. vtoonify_model.py +21 -8
vtoonify_model.py CHANGED
@@ -91,6 +91,8 @@ class Model():
             self.color_transfer = True
         else:
             self.color_transfer = False
+        if style_type not in self.style_types.keys():
+            return torch.zeros(1,18,512).to(self.device), 'Oops, wrong Style Type. Please select a valid model.'
         model_path, ind = self.style_types[style_type]
         style_path = os.path.join('models',os.path.dirname(model_path),'exstyle_code.npy')
         self.vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,'models/'+model_path),
@@ -102,7 +104,7 @@ class Model():
         return exstyle, 'Model of %s loaded.'%(style_type)
 
     def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
-        message = 'Error: no face detected!'
+        message = 'Error: no face detected! Please retry or change the photo.'
         paras = get_video_crop_parameter(frame, self.landmarkpredictor, [left, right, top, bottom])
         instyle = torch.zeros(1,18,512).to(self.device)
         h, w, scale = 0, 0, 0
@@ -136,6 +138,8 @@ class Model():
     ) -> tuple[np.ndarray, torch.Tensor, str]:
 
         frame = cv2.imread(image)
+        if frame is None:
+            return np.zeros((256,256,3), np.uint8), torch.zeros(1,18,512).to(self.device), 'Error: fail to load the image.'
         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
         return self.detect_and_align(frame, top, bottom, left, right)
 
@@ -143,13 +147,15 @@ class Model():
     ) -> tuple[np.ndarray, torch.Tensor, str]:
 
         video_cap = cv2.VideoCapture(video)
+        if video_cap.get(7) == 0:
+            return np.zeros((256,256,3), np.uint8), torch.zeros(1,18,512).to(self.device), 'Error: fail to load the video.'
         success, frame = video_cap.read()
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         video_cap.release()
         return self.detect_and_align(frame, top, bottom, left, right)
 
     def detect_and_align_full_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple[str, torch.Tensor, str]:
-        message = 'Error: no face detected!'
+        message = 'Error: no face detected! Please retry or change the video.'
         instyle = torch.zeros(1,18,512).to(self.device)
         video_cap = cv2.VideoCapture(video)
         num = min(300, int(video_cap.get(7)))
@@ -178,7 +184,11 @@ class Model():
 
         return 'input.mp4', instyle, 'Successfully rescale the video to (%d, %d)'%(bottom-top, right-left)
 
-    def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float) -> np.ndarray:
+    def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float) -> tuple[np.ndarray, str]:
+        if instyle is None or aligned_face is None:
+            return np.zeros((256,256,3), np.uint8), 'Opps, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
+        if exstyle is None:
+            return np.zeros((256,256,3), np.uint8), 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
         if exstyle is None:
             exstyle = self.exstyle
         with torch.no_grad():
@@ -195,12 +205,15 @@ class Model():
             y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s = style_degree)
             y_tilde = torch.clamp(y_tilde, -1, 1)
 
-        return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
+        return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image'
 
-    def video_tooniy(self, aligned_video: str, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float) -> str:
-        if exstyle is None:
-            exstyle = self.exstyle
+    def video_tooniy(self, aligned_video: str, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float) -> tuple[str, str]:
         video_cap = cv2.VideoCapture(aligned_video)
+        if instyle is None or aligned_face is None or video_cap.get(7) == 0:
+            video_cap.release()
+            return 'output.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
+        if exstyle is None:
+            return 'output.mp4', 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
         num = min(300, int(video_cap.get(7)))
         if self.device == 'cpu':
             num = min(100, num)
@@ -243,6 +256,6 @@ class Model():
 
         videoWriter.release()
         video_cap.release()
-        return 'output.mp4'
+        return 'output.mp4', 'Successfully toonify video of %d frames'%(num)
 
 
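A note on the guards added in this commit: `video_cap.get(7)` reads OpenCV property 7, which is `cv2.CAP_PROP_FRAME_COUNT`, so a reported frame count of zero is treated as a failed load. Below is a minimal standalone sketch of that check, with a hypothetical file path and a print in place of the returned status string; it is illustrative only, not code from this commit.

import cv2

# Property id 7 is cv2.CAP_PROP_FRAME_COUNT; a missing or unreadable file reports 0 frames.
video_cap = cv2.VideoCapture('input.mp4')  # hypothetical path, not part of this diff
if video_cap.get(cv2.CAP_PROP_FRAME_COUNT) == 0:
    video_cap.release()
    print('Error: fail to load the video.')
else:
    success, frame = video_cap.read()  # first frame can be read once the count is non-zero
    video_cap.release()

The commit also changes `image_toonify` and `video_tooniy` to return a `(result, message)` pair instead of a bare result, so callers now unpack two values and can surface the status message directly.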