YANGYYYY commited on
Commit
4c4c8ef
1 Parent(s): 195428b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +67 -2
inference.py CHANGED
@@ -233,7 +233,70 @@ class Predictor:
233
  video_writer.close()
234
  #output_file.close()
235
 
236
- def transform_video(self, video_frames, batch_size,start,end):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  #end = end or None
238
 
239
  # if not os.path.isfile(input_path):
@@ -262,11 +325,13 @@ class Predictor:
262
  if success:
263
  video_buffer.append(encoded_image.tobytes())
264
 
265
- video_capture = cv2.VideoCapture(video_frames)
266
  frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
267
  frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
268
  fps = int(video_capture.get(cv2.CAP_PROP_FPS))
269
  frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
270
 
271
  if start or end:
272
  start_frame = int(start * fps)
 
233
  video_writer.close()
234
  #output_file.close()
235
 
236
+
237
+ def transform_video(self, batch_size=4, start=0, end=0):
238
+ end = end or None
239
+
240
+ # if not os.path.isfile(input_path):
241
+ # raise FileNotFoundError(f'{input_path} does not exist')
242
+
243
+ # output_dir = "/".join(output_path.split("/")[:-1])
244
+ # os.makedirs(output_dir, exist_ok=True)
245
+ # is_gg_drive = '/drive/' in output_path
246
+ # temp_file = ''
247
+ #
248
+ # if is_gg_drive:
249
+ # temp_file = f'tmp_anime.{output_path.split(".")[-1]}'
250
+
251
+ def transform_and_write(frames, count, writer):
252
+ anime_images = self.transform(frames)
253
+ for i in range(count):
254
+ img = np.clip(anime_images[i], 0, 255).astype(np.uint8)
255
+ writer.write(img)
256
+
257
+ video_capture = cv2.VideoCapture(input_path)
258
+ frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
259
+ frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
260
+ fps = int(video_capture.get(cv2.CAP_PROP_FPS))
261
+ frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
262
+
263
+ if start or end:
264
+ start_frame = int(start * fps)
265
+ end_frame = int(end * fps) if end else frame_count
266
+ video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
267
+ frame_count = end_frame - start_frame
268
+
269
+ video_writer = cv2.VideoWriter(
270
+ output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
271
+
272
+ print(f'Transforming video {input_path}, {frame_count} frames, size: ({frame_width}, {frame_height})')
273
+
274
+ batch_shape = (batch_size, frame_height, frame_width, 3)
275
+ frames = np.zeros(batch_shape, dtype=np.uint8)
276
+ frame_idx = 0
277
+
278
+ try:
279
+ for _ in tqdm(range(frame_count)):
280
+ ret, frame = video_capture.read()
281
+ if not ret:
282
+ break
283
+ frames[frame_idx] = frame
284
+ frame_idx += 1
285
+ if frame_idx == batch_size:
286
+ transform_and_write(frames, frame_idx, video_writer)
287
+ frame_idx = 0
288
+ except Exception as e:
289
+ print(e)
290
+ finally:
291
+ video_capture.release()
292
+ video_writer.release()
293
+
294
+ if temp_file:
295
+ shutil.move(temp_file, output_path)
296
+
297
+ print(f'Animation video saved to {output_path}')
298
+
299
+ def transform_video1(self, video, batch_size, start, end):
300
  #end = end or None
301
 
302
  # if not os.path.isfile(input_path):
 
325
  if success:
326
  video_buffer.append(encoded_image.tobytes())
327
 
328
+ video_capture = cv2.VideoCapture(video)
329
  frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
330
  frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
331
  fps = int(video_capture.get(cv2.CAP_PROP_FPS))
332
  frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
333
+
334
+ print(f'Transforming video {frame_count} frames, size: ({frame_width}, {frame_height})')
335
 
336
  if start or end:
337
  start_frame = int(start * fps)