import gradio as gr
import numpy as np
import cv2
import os
import glob
import torch
import shutil
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from matplotlib import pyplot as plt
from modules.components.upr_net_freq import upr_freq as upr_freq002
from modules.components.upr_basic import upr as upr_basic
import datetime
import zipfile
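# Upgrade pip at startup.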
os.system('python -m pip install --upgrade pip')
#from scipy.interpolate import make_interp_spline
# python3 -m vfi_inference_triplet --cuda_index 0 \
# --root ../VFI_Inference/thistriplet_notarget --pretrain_path ./pretrained/upr_freq002.pth \
# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1
# Add (username, password) tuples to this list to let more users log in.
# Keeping the credentials in a separate file is recommended.
KEY = [("test", "test"),
]
# ๋กœ๊ทธ์ธ ์‹œ ํ˜ธ์ถœ๋˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
#ํ˜น์‹œ ๋กœ๊ทธ์ธ์— ๋Œ€ํ•œ ์ •๋ณด๋‚˜, ๋‹คip ๋“ฑ์„ ์–ป๊ณ  ์‹ถ์œผ๋ฉด ์ด ๋ถ€๋ถ„ ์ˆ˜์ •๋ฐ”๋ž๋‹ˆ๋‹ค.
def check_valid_login(user_name, password):
#client_ip = request.client.host
#print(client_ip)
flag = (user_name, password) in KEY
return flag
# ๋น„๋””์˜ค์—์„œ ์ฒ˜์Œ ๋ช‡ ํ”„๋ ˆ์ž„์„ ์ž๋ฅผ์ง€ ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค.
MAX_FRAME = 24
# The VFI inference code below is taken unchanged from the standalone inference script.
DEVICE = 0  # CUDA device index (an integer here behaves like "cuda:0")
torch.cuda.set_device(DEVICE)
#ROOT = args.root
#SAVE_ROOT = f'output'
SCALE = 1
pyr_level = 7
nr_lvl_skipped = 0
splat_mode = "average"
pretrain_path = "./pretrained/upr_freq002.pth"
model = upr_freq002.Model(pyr_level=pyr_level,
nr_lvl_skipped=nr_lvl_skipped,
splat_mode=splat_mode)
sd = torch.load(pretrain_path, map_location='cpu')
sd = sd['model'] if 'model' in sd.keys() else sd
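# load_state_dict returns a missing/unexpected-keys report; printing it surfaces load problems.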
print(model.load_state_dict(sd))
model = model.to(DEVICE)
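# Frames are sorted by their numeric suffix so fractional in-between frames land in
# temporal order: "img_000.5.png".split("_")[-1][:-4] == "000.5" -> 0.5, giving
# img_000.png, img_000.5.png, img_001.png, ...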
def get_sorted_img(file_path):
return sorted(glob.glob(os.path.join(file_path, f"*.png")), key = lambda x : float(x.split("_")[-1][:-4]))
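# Pads H and W up to the next multiple, e.g. H=1080 with multiple=7 gives pad1=5 (1085 rows).
# With SCALE == 1 (set above), the calls in multiple_VFIx2 are effectively no-ops.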
def multiple_pad(image, multiple):
_,_,H,W = image.size()
pad1 = multiple-(H%multiple) if H%multiple!=0 else 0
pad2 = multiple-(W%multiple) if W%multiple!=0 else 0
return TF.pad(image, (0,0,pad2,pad1))
# Interpolates the images at path1 and path2 and writes the midpoint (t=0.5) frame.
def multiple_VFIx2(path1, path2, output_name):
file_list = [path1, path2]
img_list = [(torch.from_numpy(cv2.imread(file)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE) for file in file_list]
img_list = [multiple_pad(img, SCALE) for k, img in enumerate(img_list)]
img_list = [F.interpolate(img, scale_factor=1/SCALE, mode='bicubic') for k, img in enumerate(img_list)]
img0,img1 = img_list
_,_,Hori,Wori = img0.size()
with torch.no_grad():
result_dict, extra_dict = model(img0, img1, pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped, time_step=0.5)
out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')[:,:,:Hori,:Wori].clamp(0,1)
cv2.imwrite(output_name, (out[0].cpu().permute(1,2,0)*255).numpy().astype(np.uint8)[:,:,[2,1,0]])
torch.cuda.empty_cache()
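# Example call (hypothetical paths):
#   multiple_VFIx2("input/img_000.png", "input/img_001.png", "input/img_000.5.png")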
# Runs VFIx2 three times on two frames to produce the three intermediate frames.
"""
def multiple_VFIx4(path1, path2, name1, name2, name3):
multiple_VFIx2(path1, path2, name2)
multiple_VFIx2(path1, name2, name1)
multiple_VFIx2(name2, path2, name3)
"""
def multiple_VFIx4(path1, path2):
frac = [".25", ".5", ".75"]
name1 , name2, name3 = [f"{path1[:-4]}{f}.png" for f in frac]
multiple_VFIx2(path1, path2, name2)
multiple_VFIx2(path1, name2, name1)
multiple_VFIx2(name2, path2, name3)
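# The .25/.5/.75 filename suffixes keep these outputs in temporal order under get_sorted_img.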
# Generates 5 new frames so the sequence covers t = 0, 0.125, 0.25, 0.5, 0.75, 0.875, 1.
def multiple_VFIx6(path1, path2):
frac = [".125", ".25", ".75", ".875"]
name_inf1 , name1, name2, name_inf2 = [f"{path1[:-4]}{f}.png" for f in frac]
multiple_VFIx4(path1, path2)
multiple_VFIx2(path1, name1, name_inf1)
multiple_VFIx2(name2, path2, name_inf2)
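# Sampling is denser near the endpoints (0.125, 0.875), which gives the ease-in/ease-out effect.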
#๋น„๋””์˜ค์—์„œ fix๋ฅผ ํ•˜์—ฌ ์ด๋ฏธ์ง€๋ฅผ ๋Œ€์ฒดํ•˜์—ฌ ์ถœ๋ ฅํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
def fix_img(idx, fixed_list, input_dir = "input", output_dir = "output"):
idx = int(idx)
    # No change if the index is invalid or the frame has already been fixed.
if idx < 1 or idx > MAX_FRAME - 2 or fixed_list[idx] == 1:
return {
fix_result_gallery : gr.Gallery(),
fix_result_group : gr.Group(),
fixed_frame : gr.Text()
}
now_time = os.path.basename(input_dir)
output_dir = os.path.join(output_dir, f"fix_{now_time}")
os.makedirs(output_dir, exist_ok = True)
output_name = os.path.join(output_dir, f"img_{idx:03d}.png")
multiple_VFIx2(os.path.join(input_dir, f"img_{idx - 1:03d}.png"),
os.path.join(input_dir, f"img_{idx + 1:03d}.png"),
output_name)
fixed_list[idx] = 1
fixed_frame_string = ""
result_list = []
name_list = []
#์ˆœ์ฐจ์ ์œผ๋กœ ๊ฒฐ๊ณผ ๊ฐค๋Ÿฌ๋ฆฌ ๊ฐฑ์‹ 
for i in range(MAX_FRAME):
if fixed_list[i] == 1:
name_list.append(f"(fixed) frame {i}")
result_list.append(os.path.join(output_dir, f"img_{i:03d}.png"))
fixed_frame_string += f"{i}, "
else:
name_list.append(f"frame {i}")
result_list.append(os.path.join(input_dir, f"img_{i:03d}.png"))
return {
fix_result_gallery : gr.Gallery(value = [(img, name) for img, name in zip(result_list, name_list)], selected_index = idx),
fix_result_group : gr.Group(visible=True),
fixed_frame : gr.Text(visible=True, value = fixed_frame_string[:-2]),
}
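# Note: all fixes for one uploaded video share the same output/fix_<timestamp> directory,
# so previously fixed frames stay in place when the gallery is rebuilt above.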
# Applies the easing (slow-motion) pass according to the per-segment values in ease_val.
def ease_frames(ease_val, input_dir = "input", output_dir = "output", progress=gr.Progress(track_tqdm=False)):
#now = os.path.basename(input_dir)
now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir = os.path.join(output_dir, f"ease_{now}")
os.makedirs(output_dir, exist_ok = True)
out_frame_list = [os.path.join(output_dir, f"img_{i:03d}.png") for i in range(MAX_FRAME)]
for i, f in enumerate([os.path.join(input_dir, f"img_{i:03d}.png") for i in range(MAX_FRAME)]):
shutil.copyfile(f, out_frame_list[i])
img_name = []
for i in progress.tqdm(range(MAX_FRAME - 1), desc = "VFI frames..."):
img_name.append(f"frame {i}")
if ease_val[i] == 1: pass
        # x1 adds no frames, x2 adds one, x4 adds three.
        # The appends below set the captions shown for the newly generated frames.
elif ease_val[i] == 2:
multiple_VFIx2(out_frame_list[i], out_frame_list[i + 1]
, os.path.join(output_dir, f"img_{i:03d}.5.png"))
img_name.append(f"(new) frame {i + 0.5}")
elif ease_val[i] == 3:
multiple_VFIx4(out_frame_list[i], out_frame_list[i + 1])
img_name.append(f"(new) frame {i + 0.25}")
img_name.append(f"(new) frame {i + 0.5}")
img_name.append(f"(new) frame {i + 0.75}")
img_name.append(f"frame {MAX_FRAME - 1}")
files = get_sorted_img(output_dir)
    # Zip archive of all frames for download.
zip_name = os.path.join(output_dir,"frame_list.zip")
with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
for x in progress.tqdm(files, desc ="compress file..."):
new_zip.write(x, os.path.basename(x))
return {
ease_result_gallery : [(file, name) for file, name in zip(files, img_name)],
ease_make_video : gr.Accordion(visible = True),
last_ease_dir : output_dir,
ease_zip : gr.File(value = zip_name)
}
# Takes two uploaded images and runs VFI between them.
def VFI_two(l, r, flag, output_dir = "output"):
now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir = os.path.join(output_dir, f"fix_img_{now}")
os.makedirs(output_dir, exist_ok = True)
l = Image.fromarray(l)
r = Image.fromarray(r)
    # Cap the pixel count to avoid running out of GPU memory.
W, H = l.size
    # 1920 * 1080 * 1.2 * 1.2 is roughly 3e6, which is the cap used here.
mul = ((3e6) / (W * H)) ** (1/2)
H, W = int(H * mul), int(W * mul)
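    # e.g. a 3840x2160 input gives mul ~= 0.60 and a resize to roughly 2309x1299 (~3.0e6 px).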
    # Downscale only if the image is too large; otherwise keep it as-is.
if mul < 1:
l = l.resize((W, H))
r = r.resize((W, H))
l_name, r_name = f"{output_dir}/img_000.png", f"{output_dir}/img_001.png"
l.save(l_name)
r.save(r_name)
if flag == "x4":
multiple_VFIx4(l_name, r_name)
elif flag == "x2":
output_name = f"{output_dir}/img_000.5.png"
multiple_VFIx2(l_name, r_name, output_name)
else:
multiple_VFIx6(l_name, r_name)
return {
frame_gen_result_gallery : gr.Gallery(visible=True, value=get_sorted_img(output_dir))
}
#๋‹ค๋ฅธ ์ด๋ฏธ์ง€ ์ž…๋ ฅ์„ ์œ„ํ•ด ์ž…๋ ฅ๋œ ์ด๋ฏธ์ง€๋ฅผ ๋‚ ๋ฆฌ๋Š” ํ•ฉ์ˆ˜์ž…๋‹ˆ๋‹ค.
def clear_fix():
return{
img_0 : gr.Image(label="start image", sources =["upload"], value = None),
img_1 : gr.Image(label="end image", sources =["upload"], value = None),
frame_gen_result_gallery : gr.Gallery(visible=True, value=None)
}
with gr.Blocks(theme=gr.themes.Default(), title = "Inshorts Animator V. 0.5") as demo:
def info(request: gr.Request):
        # Reads the client IP.
        # Refer to this if you later need to allow or block specific IPs.
headers = request.headers
print(headers["x-forwarded-for"].split(","))
demo.load(info, None)
gr.Markdown(f"""# Inshorts Animator V. 0.5 WebUI (Permitted User Only)""")
with gr.Tab("Mid Frame Generator"):
with gr.Column():
with gr.Row():
img_0 = gr.Image(label="start image", sources =["upload"])
img_1 = gr.Image(label="end image", sources =["upload"])
with gr.Row():
VFI_flag = gr.Radio(["x2", "x4", "x6(side ease)"], label="VFI ratio", value = "x2", interactive = True)
image_button = gr.Button("Run model")
frame_gen_result_gallery = gr.Gallery(visible=True,
label="result", columns=[5], rows=[1], object_fit="contain", height="auto", preview = True,
interactive = False)
image_button.click(VFI_two, inputs=[img_0, img_1, VFI_flag],
outputs=[frame_gen_result_gallery])
clear_button = gr.Button("Clear images")
clear_button.click(clear_fix, inputs=[],
outputs=[img_0, img_1, frame_gen_result_gallery])
with gr.Tab("Video"):
with gr.Group(visible=True) as video_input_group:
            gr.Markdown(f"""#### can only handle the first {MAX_FRAME} frames""")
with gr.Column():
input_dir = gr.State("")
fps = gr.Number(visible=False)
video_input = gr.Video(label="Input Video", interactive=True, sources=['upload'])
                gr.Markdown("""If the video frames are large, they will be resized""")
upload_button = gr.Button("upload video")
with gr.Group(visible=False) as image_edit_group:
with gr.Row():
with gr.Column():
with gr.Tab("Original Frame (for Monitoring)"):
fixed_list = gr.State([0] * (MAX_FRAME))
selected = gr.Number(visible=False, label = "selected frame", interactive = False)
image_gallery = gr.Gallery(
label="inputs", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
show_download_button=False)
clear_video = gr.Button("clear video")
with gr.Tab("Frame Fixer"):
with gr.Column():
#with gr.Row():
#with gr.Row():
with gr.Group(visible=True) as fix_result_group:
fix_result_gallery = gr.Gallery(
label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
interactive = False)
fix_button = gr.Button(visible = True)
with gr.Row():
fixed_frame = gr.Text(visible=False, label = "fixed frame", interactive = False)
fix_button.click(fix_img, inputs=[selected, fixed_list, input_dir],
outputs=[fix_result_gallery, fix_result_group, fixed_frame])
def update_fix_button_visible(evt: gr.SelectData):
flag = 0 < evt.index < MAX_FRAME - 1
msg = f"fix frame {evt.index}" if flag else f"can only fix 1 ~ {MAX_FRAME - 2}"
return {
fix_button:gr.Button(msg, visible=True),
selected : evt.index,
fix_result_gallery : gr.Gallery(selected_index = evt.index)
}
image_gallery.select(update_fix_button_visible, None, [fix_button, selected, fix_result_gallery])
with gr.Tab("Motion easer"):
with gr.Column():
with gr.Column():
#with gr.Row():
with gr.Group(visible=True) as ease_result_group:
last_ease_dir = gr.State("")
ease_result_gallery = gr.Gallery(
label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True,
interactive = False)
ease_button = gr.Button("ease", visible = True)
plt_data = gr.State([1] * (MAX_FRAME - 1))
VFI_x = gr.Radio([("x1", 1), ("x2", 2), ("x4", 3)], value = 1, label="Slow ratio", info="adjust Slow ratio", interactive = True)
with gr.Row():
edit_one_button = gr.Button("edit one scale", visible = True)
edit_all_button = gr.Button("edit all scale", visible = True)
now_frame = gr.Slider(0, MAX_FRAME - 1 - 1, step=1, label="Start frame", info="Choose Start frame to make slow. Interpolation will apply to (frame ~ frame + 1)")
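                # Plots the per-segment slow ratio; points sit at segment midpoints (frame + 0.5).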
def plt_edit(data):
fig = plt.figure()
x = np.arange(0, MAX_FRAME - 1) + 0.5
y = np.array(data)
plt.plot(x , y, color = 'black', marker = "o", linewidth = "2.5")
plt.xticks(np.arange(0, MAX_FRAME))
plt.yticks([1, 2, 3], ["x1", "x2\nslow", "x4\nslow"])
plt.gca().invert_yaxis()
plt.grid(True)
plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
return fig
ease_plot = gr.Plot(value = plt_edit(plt_data.value), show_label=False)
with gr.Accordion("get result", visible = False) as ease_make_video:
ease_zip = gr.File(label = "Download all image frames in Zip", interactive = False)
make_video_button = gr.Button("make video")
result_video = gr.Video(interactive = False)
def make_video(frame_dir, fps):
t = os.path.basename(frame_dir)
output_name = f"{frame_dir}/{t}.mp4"
if os.path.exists(output_name):
os.remove(output_name)
frame_list = get_sorted_img(frame_dir)
with open(f"{frame_dir}/input.txt", "w") as f:
for line in frame_list:
f.write(f"file '{os.path.basename(line)}'\n")
cmd = f'ffmpeg -r {fps} -f concat -safe 0 -i {frame_dir}/input.txt -c:v libx264 -preset veryslow -crf 10 {output_name}'
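        # ffmpeg's concat demuxer reads input.txt (entries are resolved relative to the list
        # file's directory); -safe 0 permits those paths and -crf 10 is near-lossless H.264.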
os.system(cmd)
return output_name
make_video_button.click(make_video, inputs = [last_ease_dir, fps], outputs = [result_video])
ease_button.click(ease_frames, inputs=[plt_data, input_dir], outputs=[ease_result_gallery, ease_make_video, last_ease_dir, ease_zip])
def edit_one_scale(data, idx, x):
if idx < MAX_FRAME - 1:
data[idx] = x if x else 1
return plt_edit(data)
edit_one_button.click(edit_one_scale, inputs=[plt_data, now_frame, VFI_x] , outputs=[ease_plot])
def edit_all_scale(data, x):
for i in range(len(data)): data[i] = x if x else 1
return plt_edit(data)
edit_all_button.click(edit_all_scale, inputs=[plt_data, VFI_x], outputs=[ease_plot])
def clear_vd(plt_data, fixed_list):
for i in range(len(plt_data)): plt_data[i] = 1
for i in range(len(fixed_list)): fixed_list[i] = 0
return {video_input:gr.Video(label="Input Video", interactive=True, sources=['upload'], value = None),
ease_result_gallery : gr.Gallery(
label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True,
interactive = False, value = None),
fix_result_gallery : gr.Gallery(
label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True,
interactive = False, value = None),
fixed_frame : gr.Text(visible=False, label = "fixed frame", interactive = False, value = None),
        ease_make_video : gr.Accordion(visible = False),
video_input_group:gr.Group(visible=True),
image_edit_group:gr.Group(visible=False),
ease_plot : gr.Plot(value = plt_edit(plt_data))}
clear_video.click(clear_vd, inputs=[plt_data, fixed_list],outputs=[video_input, ease_result_gallery, fix_result_gallery, fixed_frame, ease_make_video, video_input_group, image_edit_group, ease_plot])
def update_video_visible(video):
if not video:
return {video_input_group:gr.Group(visible=True),
image_edit_group:gr.Group(visible=False),
image_gallery:[],
input_dir : "",
fps : 0
}
now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
input_now = os.path.join("input", now)
os.makedirs(input_now, exist_ok = True)
cap = cv2.VideoCapture(video)
frame_count = 0
video_fps = cap.get(cv2.CAP_PROP_FPS)
#print('video fps:', video_fps)
H = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
W = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
mul = ((3e6) / (W * H)) ** (1/2)
H, W = int(H * mul), int(W * mul)
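        # Same ~3e6-pixel cap as in VFI_two above.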
frame_name_list = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
img_name = os.path.join(input_now, f"img_{frame_count:03d}.png")
if mul < 1:
frame = cv2.resize(frame, (W, H), interpolation=cv2.INTER_CUBIC)
cv2.imwrite(img_name, frame)
frame_name_list.append((img_name, f"frame {frame_count}"))
frame_count += 1
if frame_count >= MAX_FRAME: break
cap.release()
return {video_input_group:gr.Group(visible=False),
image_edit_group:gr.Group(visible=True),
image_gallery:frame_name_list,
input_dir : input_now,
fps : video_fps
}
upload_button.click(update_video_visible,
[video_input],
[video_input_group, image_edit_group, image_gallery, input_dir, fps])
if __name__ == '__main__':
demo.launch(allowed_paths=["./input", "./output"], auth = check_valid_login, auth_message = "Inshorts Animator V. 0.5 WebUI (Permitted User Only)", share = True)