YANGYYYY commited on
Commit
623c086
1 Parent(s): f877ad1

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +21 -19
inference.py CHANGED
@@ -233,19 +233,19 @@ class Predictor:
233
  video_writer.close()
234
  #output_file.close()
235
 
236
- def transform_video(self, input_path, output_path, batch_size=4, start=0, end=0):
237
  end = end or None
238
 
239
- if not os.path.isfile(input_path):
240
- raise FileNotFoundError(f'{input_path} does not exist')
241
 
242
- output_dir = "/".join(output_path.split("/")[:-1])
243
- os.makedirs(output_dir, exist_ok=True)
244
- is_gg_drive = '/drive/' in output_path
245
- temp_file = ''
246
 
247
- if is_gg_drive:
248
- temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
 
250
  def transform_and_write(frames, count, writer):
251
  anime_images = self.transform(frames)
@@ -253,7 +253,7 @@ class Predictor:
253
  img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
254
  writer.write(img)
255
 
256
- video_capture = cv2.VideoCapture(input_path)
257
  frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
258
  frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
259
  fps = int(video_capture.get(cv2.CAP_PROP_FPS))
@@ -265,10 +265,11 @@ class Predictor:
265
  video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
266
  frame_count = end_frame - start_frame
267
 
268
- video_writer = cv2.VideoWriter(
269
- output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
 
270
 
271
- print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
272
 
273
  batch_shape = (batch_size, frame_height, frame_width, 3)
274
  frames = np.zeros(batch_shape, dtype=np.uint8)
@@ -282,18 +283,19 @@ class Predictor:
282
  frames[frame_idx] = frame
283
  frame_idx += 1
284
  if frame_idx == batch_size:
285
- transform_and_write(frames, frame_idx, video_writer)
286
  frame_idx = 0
287
  except Exception as e:
288
  print(e)
289
  finally:
290
  video_capture.release()
291
- video_writer.release()
292
-
293
- if temp_file:
294
- shutil.move(temp_file, output_path)
 
295
 
296
- print(f'Animation video saved to {output_path}')
297
  def preprocess_images(self, images):
298
  '''
299
  Preprocess image for inference
 
233
  video_writer.close()
234
  #output_file.close()
235
 
236
+ def transform_video(self, input, batch_size=4, start=0, end=0):
237
  end = end or None
238
 
239
+ # if not os.path.isfile(input_path):
240
+ # raise FileNotFoundError(f'{input_path} does not exist')
241
 
242
+ # output_dir = "/".join(output_path.split("/")[:-1])
243
+ # os.makedirs(output_dir, exist_ok=True)
244
+ # is_gg_drive = '/drive/' in output_path
245
+ # temp_file = ''
246
 
247
+ # if is_gg_drive:
248
+ # temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
 
250
  def transform_and_write(frames, count, writer):
251
  anime_images = self.transform(frames)
 
253
  img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
254
  writer.write(img)
255
 
256
+ video_capture = input
257
  frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
258
  frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
259
  fps = int(video_capture.get(cv2.CAP_PROP_FPS))
 
265
  video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
266
  frame_count = end_frame - start_frame
267
 
268
+ # video_writer = cv2.VideoWriter(
269
+ # output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
270
+ video_buffer = []
271
 
272
+ #print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
273
 
274
  batch_shape = (batch_size, frame_height, frame_width, 3)
275
  frames = np.zeros(batch_shape, dtype=np.uint8)
 
283
  frames[frame_idx] = frame
284
  frame_idx += 1
285
  if frame_idx == batch_size:
286
+ transform_and_write(frames, frame_idx, video_buffer)
287
  frame_idx = 0
288
  except Exception as e:
289
  print(e)
290
  finally:
291
  video_capture.release()
292
+ #video_writer.release()
293
+
294
+ return video_buffer
295
+ # if temp_file:
296
+ # shutil.move(temp_file, output_path)
297
 
298
+ # print(f'Animation video saved to {output_path}')
299
  def preprocess_images(self, images):
300
  '''
301
  Preprocess image for inference