PKUWilliamYang committed on
Commit
82faf11
1 Parent(s): 5427b01

Update vtoonify_model.py

Files changed (1)
  1. vtoonify_model.py +12 -6
vtoonify_model.py CHANGED
@@ -92,7 +92,7 @@ class Model():
         else:
             self.color_transfer = False
         if style_type not in self.style_types.keys():
-            return torch.zeros(1,18,512).to(self.device), 'Oops, wrong Style Type. Please select a valid model.'
+            return None, 'Oops, wrong Style Type. Please select a valid model.'
         model_path, ind = self.style_types[style_type]
         style_path = os.path.join('models',os.path.dirname(model_path),'exstyle_code.npy')
         self.vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,'models/'+model_path),
@@ -106,7 +106,7 @@ class Model():
     def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
         message = 'Error: no face detected! Please retry or change the photo.'
         paras = get_video_crop_parameter(frame, self.landmarkpredictor, [left, right, top, bottom])
-        instyle = torch.zeros(1,18,512).to(self.device)
+        instyle = None
         h, w, scale = 0, 0, 0
         if paras is not None:
             h,w,top,bottom,left,right,scale = paras
@@ -136,16 +136,18 @@ class Model():
     #@torch.inference_mode()
     def detect_and_align_image(self, image: str, top: int, bottom: int, left: int, right: int
                                ) -> tuple[np.ndarray, torch.Tensor, str]:
-
+        if image is None:
+            return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load empty file.'
         frame = cv2.imread(image)
         if frame is None:
-            return np.zeros((256,256,3), np.uint8), torch.zeros(1,18,512).to(self.device), 'Error: fail to load the image.'
+            return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load the image.'
         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
         return self.detect_and_align(frame, top, bottom, left, right)
 
     def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int
                                ) -> tuple[np.ndarray, torch.Tensor, str]:
-
+        if video is None:
+            return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load empty file.'
         video_cap = cv2.VideoCapture(video)
         if video_cap.get(7) == 0:
             video_cap.release()
@@ -157,7 +159,9 @@ class Model():
 
     def detect_and_align_full_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple[str, torch.Tensor, str]:
         message = 'Error: no face detected! Please retry or change the video.'
-        instyle = torch.zeros(1,18,512).to(self.device)
+        instyle = None
+        if video is None:
+            return 'default.mp4', instyle, 'Error: fail to load empty file.'
         video_cap = cv2.VideoCapture(video)
         if video_cap.get(7) == 0:
             video_cap.release()
@@ -212,6 +216,8 @@ class Model():
         return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image'
 
     def video_tooniy(self, aligned_video: str, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float) -> tuple[str, str]:
+        if aligned_video is None:
+            return 'output.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
         video_cap = cv2.VideoCapture(aligned_video)
         if instyle is None or aligned_face is None or video_cap.get(7) == 0:
             video_cap.release()
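
Every early-exit path above now returns None in place of the old placeholder torch.zeros(1,18,512) latent, so callers can tell "no usable input" apart from a real style code (video_tooniy already guards with "if instyle is None or ..."). A minimal caller-side sketch of that contract, assuming only the detect_and_align_image method shown in this diff; the names model and run_image_pipeline are illustrative, not part of the commit:

import torch

def run_image_pipeline(model, image_path, top=200, bottom=200, left=200, right=200):
    # model is assumed to be an instance of the Model class patched above.
    aligned_face, instyle, message = model.detect_and_align_image(
        image_path, top, bottom, left, right)
    # After this commit, instyle is None whenever the file is missing,
    # unreadable, or no face was detected; message carries the reason.
    if instyle is None:
        print(message)  # e.g. 'Error: fail to load the image.'
        return None, None
    # Otherwise instyle is a real style-code tensor rather than a placeholder.
    assert isinstance(instyle, torch.Tensor)
    return aligned_face, instyle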