Nick088 committed on
Commit
a734e0b
1 Parent(s): 0fa50f3

Update infer.py

Files changed (1)
  1. infer.py +32 -17
infer.py CHANGED
@@ -25,46 +25,61 @@ def infer_image(img: Image.Image, size_modifier: int ) -> Image.Image:
     return result
 
 def infer_video(video_filepath: str, size_modifier: int) -> str:
+    # Extract audio from the original video file
+    audio = cv2.AudioCapture(video_filepath)
+    audio_data = np.frombuffer(audio.readAll(), dtype=np.int16)
+    audio_array = np.array(audio_data, dtype=np.int16)
+
     model = RealESRGAN(device, scale=size_modifier)
     model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
 
     cap = cv.VideoCapture(video_filepath)
 
+    # Create a temporary file for the output video
     tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
     vid_output = tmpfile.name
     tmpfile.close()
 
-    vid_writer = cv.VideoWriter(
+    # Create a VideoWriter object for the output video
+    vid_writer = cv2.VideoWriter(
         vid_output,
-        fourcc=cv.VideoWriter.fourcc(*'mp4v'),
-        fps=cap.get(cv.CAP_PROP_FPS),
-        frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
+        fourcc=cv2.VideoWriter.fourcc(*'mp4v'),
+        fps=cap.get(cv2.CAP_PROP_FPS),
+        frameSize=(int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
     )
 
-    n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
-
-    # while cap.isOpened():
-    for _ in tqdm.tqdm(range(n_frames)):
+    # Process each frame of the video and write it to the output video
+    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    for i in tqdm(range(n_frames)):
+        # Read the next frame
         ret, frame = cap.read()
         if not ret:
             break
 
-        frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
+        # Convert the frame to RGB and feed it to the model
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         frame = Image.fromarray(frame)
-
         upscaled_frame = model.predict(frame.convert('RGB'))
-
-        upscaled_frame = np.array(upscaled_frame)
-        upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
 
-        print(upscaled_frame.shape)
+        # Convert the upscaled frame back to BGR and write it to the output video
+        upscaled_frame = np.array(upscaled_frame)
+        upscaled_frame = cv2.cvtColor(upscaled_frame, cv2.COLOR_RGB2BGR)
 
+        # Write the upscaled frame to the output video
         vid_writer.write(upscaled_frame)
 
+    # Release the VideoCapture and VideoWriter objects
+    cap.release()
     vid_writer.release()
 
-    print(f"Video file : {video_filepath}")
+    # Create a new VideoFileClip object from the output video
+    output_clip = mpy.VideoFileClip(vid_output)
 
-    return vid_output
-
+    # Add the audio back to the output video
+    output_clip = output_clip.set_audio(mpy.AudioFileClip(video_filepath, fps=output_clip.fps))
 
+    # Save the output video to a new file
+    output_clip.write_videofile(f'output_{video_filepath}')
+
+    return f'output_{video_filepath}'
+
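
Note on the added audio handling: `cv2.AudioCapture` and `readAll()` are not part of OpenCV's Python API, so the first three added lines would fail at runtime, and the `audio_array` they build is never used later; the audio that actually reaches the output comes from the moviepy calls at the end (`VideoFileClip`, `AudioFileClip`, `set_audio`, `write_videofile`). Below is a minimal sketch of the same upscale-then-remux flow using only calls that exist, assuming moviepy 1.x (where `set_audio` is available) and that `RealESRGAN` and `device` are defined elsewhere in infer.py as in the current code; the helper name `upscale_video_with_audio` and the output-path handling are illustrative, not taken from the repo.

import tempfile

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import moviepy.editor as mpy

# `RealESRGAN` and `device` are assumed to come from infer.py's existing
# imports and globals; they are not redefined here.

def upscale_video_with_audio(video_filepath: str, size_modifier: int) -> str:
    model = RealESRGAN(device, scale=size_modifier)
    model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)

    cap = cv2.VideoCapture(video_filepath)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) * size_modifier
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) * size_modifier
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Write the upscaled frames to a temporary, video-only file first.
    frames_tmp = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    frames_tmp.close()
    vid_writer = cv2.VideoWriter(frames_tmp.name,
                                 cv2.VideoWriter_fourcc(*'mp4v'),
                                 fps, (width, height))

    for _ in tqdm(range(n_frames)):
        ret, frame = cap.read()
        if not ret:
            break
        # BGR -> RGB for the model, then RGB -> BGR for VideoWriter.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        upscaled = model.predict(Image.fromarray(rgb).convert('RGB'))
        vid_writer.write(cv2.cvtColor(np.array(upscaled), cv2.COLOR_RGB2BGR))

    cap.release()
    vid_writer.release()

    # Re-attach the source audio (if the input has any) with moviepy and
    # write the final file next to the temporary one.
    out_path = frames_tmp.name.replace('.mp4', '_with_audio.mp4')
    source_clip = mpy.VideoFileClip(video_filepath)
    upscaled_clip = mpy.VideoFileClip(frames_tmp.name)
    if source_clip.audio is not None:
        upscaled_clip = upscaled_clip.set_audio(source_clip.audio)
    upscaled_clip.write_videofile(out_path, codec='libx264', audio_codec='aac')
    return out_path

Muxing the audio after the frame loop this way avoids needing any audio support from OpenCV, and deriving the result path from the temporary file sidesteps the `f'output_{video_filepath}'` pattern, which can produce an unwritable path when `video_filepath` is absolute.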