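# Hugging Face Space page residue (author avatar, commit title, commit hash):
# author: fantaxy ("fantaxy's picture")
# commit: 02a9a58 "Update app.py" (verified)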
import spaces
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
import torch
# Replace torch.jit.script with an identity function so that any code decorated
# with it runs as plain (eager) Python instead of being TorchScript-compiled.
# NOTE(review): this is kept before the transparent_background import below,
# presumably so that library picks up the patched function at import time —
# confirm why TorchScript had to be disabled in this environment.
torch.jit.script = lambda f: f
from transparent_background import Remover
@spaces.GPU()
def doo(video, mode, progress=gr.Progress()):
    """Replace the background of every frame of *video* with green.

    Each frame is passed through ``transparent_background.Remover`` with
    ``type='green'`` and written to a new ``.mp4`` file (random numeric name)
    in the current working directory.

    Args:
        video: Path of the input video, as supplied by the Gradio ``video``
            input component (``None`` if nothing was uploaded).
        mode: ``'Fast'`` selects ``Remover(mode='fast')``; any other value
            uses the default ``Remover()``.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Filename of the written ``.mp4``. If the time budget is nearly used
        up, processing stops early and the partially written file is returned.

    Raises:
        gr.Error: If no video was supplied or no frame could be processed.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")

    remover = Remover(mode='fast') if mode == 'Fast' else Remover()
    out_path = f"{random.randint(111111111, 999999999)}.mp4"
    # Stop 5 s before 20 minutes so the partial result can still be returned.
    # NOTE(review): this should match the GPU duration granted by
    # @spaces.GPU() — confirm.
    time_budget = 20 * 60 - 5

    cap = cv2.VideoCapture(video)
    writer = None
    try:
        # CAP_PROP_FRAME_COUNT comes from container metadata and may be 0 or
        # inaccurate, so it is only used for progress reporting.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        processed_frames = 0
        start_time = time.time()
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if time.time() - start_time >= time_budget:
                print("GPU Timeout is coming")
                break
            # OpenCV decodes BGR; the remover is fed an RGB PIL image.
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('RGB')
            if writer is None:
                # The output size is taken from the first frame (width, height).
                writer = cv2.VideoWriter(
                    out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size
                )
            processed_frames += 1
            print(f"Processing frame {processed_frames}")
            if total_frames > 0:
                # Clamp: the metadata count can be lower than the real count.
                progress(
                    min(processed_frames / total_frames, 1.0),
                    desc=f"Processing frame {processed_frames}/{total_frames}",
                )
            out = remover.process(img, type='green')
            # The frame handed to the remover was RGB, so its output is treated
            # as RGB; cv2.VideoWriter expects BGR.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        # Release both handles even if processing raised part-way through.
        cap.release()
        if writer is not None:
            writer.release()

    if writer is None:
        # No frame was decoded (empty/unreadable file), so no output exists.
        raise gr.Error("No frames could be processed from the video.")
    return out_path
# Example rows offered beneath the interface (only the video input is filled).
examples = [['./mp4.mp4']]

# Custom CSS that hides the default Gradio page footer.
css = """
footer {
visibility: hidden;
}
"""

# Second input: lets the user trade accuracy for speed in `doo`.
mode_selector = gr.Radio(
    ['Normal', 'Fast'],
    label='Select mode',
    value='Normal',
    info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
)

# Video in, (mode) -> processed video out.
iface = gr.Interface(
    fn=doo,
    inputs=["video", mode_selector],
    outputs="video",
    examples=examples,
    theme="Nymbo/Nymbo_Theme",
    css=css,
)
iface.launch()