Nick088 commited on
Commit
f84f6c9
·
verified ·
1 Parent(s): 64daa3d

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +27 -54
infer.py CHANGED
@@ -1,89 +1,62 @@
1
  from PIL import Image
2
- import cv2
3
  import torch
4
  from RealESRGAN import RealESRGAN
5
  import tempfile
6
  import numpy as np
7
- from tqdm import tqdm
8
- import pydub
9
- from pydub import AudioSegment
10
- from moviepy.editor import VideoFileClip, AudioFileClip
11
 
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
-
14
- def infer_image(img: Image.Image, size_modifier: int ) -> Image.Image:
15
- if img is None:
16
- raise Exception("Image not uploaded")
17
-
18
- width, height = img.size
19
-
20
- if width >= 5000 or height >= 5000:
21
- raise Exception("The image is too large.")
22
 
23
- model = RealESRGAN(device, scale=size_modifier)
24
- model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
25
 
26
- result = model.predict(img.convert('RGB'))
27
- print(f"Image size ({device}): {size_modifier} ... OK")
28
- return result
29
 
30
  def infer_video(video_filepath: str, size_modifier: int) -> str:
31
  model = RealESRGAN(device, scale=size_modifier)
32
  model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
33
-
34
- # Extract audio from the original video file
35
- audio = AudioSegment.from_file(video_filepath, format="mp4")
36
- audio_array = np.array(audio.get_array_of_samples())
37
-
38
- # Create a VideoCapture object for the video file
39
- cap = cv2.VideoCapture(video_filepath)
40
 
41
- # Create a temporary file for the output video
 
42
  tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
43
  vid_output = tmpfile.name
44
  tmpfile.close()
45
 
46
- # Create a VideoWriter object for the output video
47
- vid_writer = cv2.VideoWriter(
 
 
 
48
  vid_output,
49
- fourcc=cv2.VideoWriter.fourcc(*'mp4v'),
50
- fps=cap.get(cv2.CAP_PROP_FPS),
51
- frameSize=(int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
52
  )
53
 
54
- # Process each frame of the video and write it to the output video
55
- n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
56
- for i in tqdm.tqdm(range(n_frames)):
57
- # Read the next frame
58
  ret, frame = cap.read()
59
  if not ret:
60
  break
61
 
62
- # Convert the frame to RGB and feed it to the model
63
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
64
  frame = Image.fromarray(frame)
65
- upscaled_frame = model.predict(frame.convert('RGB'))
66
 
67
- # Convert the upscaled frame back to BGR and write it to the output video
 
68
  upscaled_frame = np.array(upscaled_frame)
69
- upscaled_frame = cv2.cvtColor(upscaled_frame, cv2.COLOR_RGB2BGR)
70
 
71
- # Write the upscaled frame to the output video
72
  vid_writer.write(upscaled_frame)
73
 
74
- # Release the VideoCapture and VideoWriter objects
75
- cap.release()
76
  vid_writer.release()
77
 
78
- # Create a new VideoFileClip object from the output video
79
- output_clip = VideoFileClip(vid_output)
80
 
81
- # Add the audio back to the output video
82
- audio_clip = AudioFileClip(f"{video_filepath.split('.')[0]}.wav", fps=output_clip.fps)
83
- output_clip = output_clip.set_audio(audio_clip)
84
 
85
- # Save the output video to a new file
86
- output_clip.write_videofile(f'output_{video_filepath}')
87
 
88
- return f'output_{video_filepath}'
89
-
 
1
  from PIL import Image
2
+ import cv2 as cv
3
  import torch
4
  from RealESRGAN import RealESRGAN
5
  import tempfile
6
  import numpy as np
7
+ import tqdm
8
+ import ffmpeg
 
 
9
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
12
 
 
 
 
13
 
14
  def infer_video(video_filepath: str, size_modifier: int) -> str:
15
  model = RealESRGAN(device, scale=size_modifier)
16
  model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
 
 
 
 
 
 
 
17
 
18
+ cap = cv.VideoCapture(video_filepath)
19
+
20
  tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
21
  vid_output = tmpfile.name
22
  tmpfile.close()
23
 
24
+ # Extract audio from the input video
25
+ audio_file = video_filepath.replace(".mp4", ".wav")
26
+ ffmpeg.input(video_filepath).output(audio_file, format='wav', ac=1).run(overwrite_output=True)
27
+
28
+ vid_writer = cv.VideoWriter(
29
  vid_output,
30
+ fourcc=cv.VideoWriter.fourcc(*'mp4v'),
31
+ fps=cap.get(cv.CAP_PROP_FPS),
32
+ frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
33
  )
34
 
35
+ n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
36
+
37
+ for _ in tqdm.tqdm(range(n_frames)):
 
38
  ret, frame = cap.read()
39
  if not ret:
40
  break
41
 
42
+ frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
 
43
  frame = Image.fromarray(frame)
 
44
 
45
+ upscaled_frame = model.predict(frame.convert('RGB'))
46
+
47
  upscaled_frame = np.array(upscaled_frame)
48
+ upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
49
 
 
50
  vid_writer.write(upscaled_frame)
51
 
 
 
52
  vid_writer.release()
53
 
54
+ # Re-encode the video with the modified audio
55
+ ffmpeg.input(vid_output).output(video_filepath.replace(".mp4", "_upscaled.mp4"), vcodec='libx264', acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
56
 
57
+ # Replace the original audio with the upscaled audio
58
+ ffmpeg.input(audio_file).output(video_filepath.replace(".mp4", "_upscaled.mp4"), acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
 
59
 
60
+ print(f"Video file : {video_filepath}")
 
61
 
62
+ return vid_output.replace(".mp4", "_upscaled.mp4")