try:
    import spaces
except ImportError:
    # Create a dummy decorator if spaces is not available; wrap it in
    # staticmethod so the stand-in stays a plain function instead of a bound
    # method (a bare function in the class dict would receive the instance as
    # its first argument and break @spaces.GPU).
    def spaces_gpu(func):
        return func
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu)})()
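
# On Hugging Face ZeroGPU hardware, CUDA is only usable inside functions
# decorated with @spaces.GPU, which is why the models below are loaded per
# call rather than once at import time.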

import gradio as gr
import torch
from torchvision.transforms import functional as F
from PIL import Image
import os
import cv2
import numpy as np
from super_image import EdsrModel, ImageLoader
from safetensors.torch import load_file

# RIFEModel is assumed to be defined in a local module of this repo (the class
# implementing the RIFE flownet architecture); adjust the import path to match.
from rife_model import RIFEModel


@spaces.GPU
def upscale_video(video_path, scale_factor, progress=gr.Progress()):
    """
    Upscales a video using EDSR model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.
    """
    # Load the model inside the function for ZeroGPU compatibility
    if scale_factor not in (2, 4):
        raise gr.Error("Invalid scale factor. Choose 2 or 4.")
    model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale_factor)

    # Run on the GPU when one is available (e.g. inside a ZeroGPU allocation)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device).eval()

    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

    output_width = width * scale_factor
    output_height = height * scale_factor
    
    output_path = f"upscaled_{scale_factor}x_{os.path.basename(video_path)}"
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (output_width, output_height))
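    # NOTE: 'mp4v' keeps this dependency-free, but browsers often cannot play
    # MPEG-4 Part 2 streams inline; if the local OpenCV build ships an H.264
    # encoder, cv2.VideoWriter_fourcc(*'avc1') usually displays more reliably
    # in a gr.Video component.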

    for _ in progress.tqdm(range(frame_count), desc=f"Upscaling {scale_factor}x"):
        ret, frame = video_capture.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; the model expects RGB
        pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        inputs = ImageLoader.load_image(pil_frame).to(device)
        with torch.no_grad():
            preds = model(inputs)

        # preds is a (1, C, H, W) float tensor in [0, 1]; convert it back to a
        # uint8 RGB array by hand before handing it to the BGR video writer
        output = preds.squeeze(0).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
        output_frame = (output * 255.0).round().astype(np.uint8)
        video_writer.write(cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR))

    video_capture.release()
    video_writer.release()
    
    return output_path
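
# NOTE: cv2.VideoWriter emits a silent file, so the source audio track is lost.
# A minimal sketch for muxing the original audio back in with the ffmpeg CLI
# (assumed to be installed on the host):
#
#     import subprocess
#     subprocess.run([
#         "ffmpeg", "-y",
#         "-i", output_path,              # upscaled video (no audio)
#         "-i", video_path,               # original video (audio source)
#         "-map", "0:v:0", "-map", "1:a:0?",
#         "-c:v", "copy", "-c:a", "aac",
#         "muxed_" + output_path,
#     ], check=True)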


@spaces.GPU
def rife_interpolate_video(video_path, progress=gr.Progress()):
    """
    Interpolates a video using the RIFE model.
    This function is decorated with @spaces.GPU to run on ZeroGPU.
    """
    if not os.path.exists(video_path):
        raise gr.Error(f"Input file not found at {video_path}")

    # Load the RIFE model; the checkpoint path is relative to the repo root
    # rather than a machine-specific absolute path
    model = RIFEModel()
    model.load_state_dict(load_file("rife_model_new/rife-flownet-4.13.2.safetensors"))
    model.eval()

    # Fall back to CPU so the function does not crash when no GPU is attached
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error(f"Could not open video file {video_path}")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

    output_path = f"interpolated_{os.path.basename(video_path)}"
    video_writer = cv2.VideoWriter(output_path, fourcc, fps * 2, (width, height))

    prev_frame = None
    for i in progress.tqdm(range(frame_count), desc="Interpolating"):
        ret, frame = video_capture.read()
        if not ret:
            break

        if prev_frame is not None:
            # Preprocess frames: HWC uint8 -> 1xCxHxW float in [0, 1]
            img0 = torch.from_numpy(prev_frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.
            img1 = torch.from_numpy(frame.transpose(2, 0, 1)).float().unsqueeze(0).to(device) / 255.

            # Synthesize the in-between frame; clamp before the uint8 cast to
            # avoid wrap-around on out-of-range values
            with torch.no_grad():
                interpolated_frame = model.inference(img0, img1)[0].clamp(0, 1).cpu().numpy().transpose(1, 2, 0) * 255

            video_writer.write(interpolated_frame.astype(np.uint8))

        video_writer.write(frame)
        prev_frame = frame

    video_capture.release()
    video_writer.release()

    return output_path
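
# NOTE: most RIFE flownet variants expect frame dimensions that are a multiple
# of 32; odd-sized inputs can fail or produce edge artifacts. A hypothetical
# padding helper under that assumption (crop the output back to (h, w) after
# inference with out[..., :h, :w]):
#
#     def pad_to_multiple(img, multiple=32):
#         # img: 1xCxHxW tensor; pad right/bottom edges by replication
#         _, _, h, w = img.shape
#         ph = (multiple - h % multiple) % multiple
#         pw = (multiple - w % multiple) % multiple
#         return torch.nn.functional.pad(img, (0, pw, 0, ph), mode="replicate"), (h, w)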


with gr.Blocks() as demo:
    gr.Markdown("# Video Upscaler and Frame Interpolator")
    with gr.Tab("Upscale"):
        with gr.Row():
            with gr.Column():
                video_input_upscale = gr.Video(label="Input Video")
                scale_factor = gr.Radio([2, 4], label="Scale Factor", value=2)
                upscale_button = gr.Button("Upscale Video")
            with gr.Column():
                video_output_upscale = gr.Video(label="Upscaled Video")
    with gr.Tab("Interpolate"):
        with gr.Row():
            with gr.Column():
                video_input_rife = gr.Video(label="Input Video")
                rife_button = gr.Button("Interpolate Frames")
            with gr.Column():
                video_output_rife = gr.Video(label="Interpolated Video")

    upscale_button.click(
        fn=upscale_video,
        inputs=[video_input_upscale, scale_factor],
        outputs=video_output_upscale
    )
    
    rife_button.click(
        fn=rife_interpolate_video,
        inputs=[video_input_rife],
        outputs=video_output_rife
    )

if __name__ == "__main__":
    demo.launch(share=True)