YANGYYYY committed on
Commit
f3bbb27
1 Parent(s): 7232d95

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +17 -35
inference.py CHANGED
@@ -233,7 +233,7 @@ class Predictor:
233
  video_writer.close()
234
  #output_file.close()
235
 
236
- def transform_video(self, input, batch_size=4, start=0, end=0):
237
  end = end or None
238
 
239
  # if not os.path.isfile(input_path):
@@ -247,55 +247,37 @@ class Predictor:
247
  # if is_gg_drive:
248
  # temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
249
 
250
- def transform_and_write(frames, count, writer):
 
251
  anime_images = self.transform(frames)
252
  for i in range(count):
253
  img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
254
- writer.write(img)
255
-
256
- video_capture = input
257
- frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
258
- frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
259
- fps = int(video_capture.get(cv2.CAP_PROP_FPS))
260
- frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
261
-
262
- if start or end:
263
- start_frame = int(start * fps)
264
- end_frame = int(end * fps) if end else frame_count
265
- video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
266
- frame_count = end_frame - start_frame
267
-
268
- # video_writer = cv2.VideoWriter(
269
- # output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
270
- video_buffer = []
271
 
272
- #print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
 
273
 
274
- batch_shape = (batch_size, frame_height, frame_width, 3)
275
  frames = np.zeros(batch_shape, dtype=np.uint8)
276
  frame_idx = 0
277
 
278
  try:
279
- for _ in tqdm(range(frame_count)):
280
- ret, frame = video_capture.read()
281
- if not ret:
282
- break
283
  frames[frame_idx] = frame
284
  frame_idx += 1
285
  if frame_idx == batch_size:
286
- transform_and_write(frames, frame_idx, video_buffer)
 
287
  frame_idx = 0
288
  except Exception as e:
289
  print(e)
290
- finally:
291
- video_capture.release()
292
- #video_writer.release()
293
-
294
- return video_buffer
295
- # if temp_file:
296
- # shutil.move(temp_file, output_path)
297
-
298
- # print(f'Animation video saved to {output_path}')
299
  def preprocess_images(self, images):
300
  '''
301
  Preprocess image for inference
 
233
  video_writer.close()
234
  #output_file.close()
235
 
236
def transform_and_save(self, frames, count):
    """Run the anime-style model on a batch of frames and return the results.

    Args:
        frames: np.ndarray batch of shape (batch_size, H, W, 3); only the
            first ``count`` entries hold valid frames, the rest is padding.
        count: number of valid frames at the front of ``frames``.

    Returns:
        List of ``count`` transformed frames as uint8 np.ndarray images.
    """
    anime_images = self.transform(frames)
    # Clip the model output into the valid pixel range before casting back
    # to uint8, otherwise out-of-range floats would wrap around.
    return [np.clip(anime_images[i], 0, 255).astype(np.uint8)
            for i in range(count)]

def transform_video(self, video_frames, batch_size=4):
    """Transform a sequence of video frames, feeding the model in batches.

    Args:
        video_frames: sequence of frames (np.ndarray, all the same shape,
            presumably H x W x 3 uint8 — TODO confirm against caller).
        batch_size: number of frames sent to the model per forward pass.

    Returns:
        List of transformed frames (uint8 np.ndarray), one per input frame.
        On an internal error, returns whatever frames were transformed so
        far (best-effort, matching the original behavior).
    """
    # Guard the empty case: video_frames[0] below would raise otherwise.
    if len(video_frames) == 0:
        return []

    transformed_video_frames = []
    # Fixed-size staging buffer reused for every batch.
    batch_shape = (batch_size,) + video_frames[0].shape
    frames = np.zeros(batch_shape, dtype=np.uint8)
    frame_idx = 0

    try:
        for frame in video_frames:
            frames[frame_idx] = frame
            frame_idx += 1
            if frame_idx == batch_size:
                transformed_video_frames.extend(
                    self.transform_and_save(frames, frame_idx))
                frame_idx = 0
        # BUG FIX: flush the trailing partial batch. The original dropped
        # the last len(video_frames) % batch_size frames.
        if frame_idx > 0:
            transformed_video_frames.extend(
                self.transform_and_save(frames, frame_idx))
    except Exception as e:
        # Best-effort: log and keep the frames transformed before failure.
        # (The original also had a stray `end = end or None` referencing an
        # undefined name, which raised NameError on every call — removed.)
        print(e)

    return transformed_video_frames
 
 
 
281
  def preprocess_images(self, images):
282
  '''
283
  Preprocess image for inference