import os
import cv2
import argparse
import glob
import spaces
import torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from scipy.ndimage import gaussian_filter1d
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.video_util import VideoReader, VideoWriter
from basicsr.utils.registry import ARCH_REGISTRY
import gradio as gr
from torch.hub import download_url_to_file
title = r"""<h1 align="center">KEEP: Kalman-Inspired Feature Propagation for Video Face Super-Resolution</h1>"""
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/jnjaby/KEEP' target='_blank'><b>Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution (ECCV 2024)</b></a>.<br>
🔥 KEEP is a robust video face super-resolution algorithm.<br>
🤗 Try to drop your own face video, and get the restored results!<br>
"""
post_article = r"""
If you found KEEP helpful, please consider ⭐ the <a href='https://github.com/jnjaby/KEEP' target='_blank'>Github Repo</a>. Thanks!
---
📑 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{feng2024keep,
    title     = {Kalman-Inspired FEaturE Propagation for Video Face Super-Resolution},
    author    = {Feng, Ruicheng and Li, Chongyi and Loy, Chen Change},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2024}
}
```
📋 **License**
<br>
This project is licensed under <a rel="license" href="https://github.com/jnjaby/KEEP/blob/main/LICENSE">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.
<br><br>
📧 **Contact**
<br>
If you have any questions, please feel free to reach out via <b>ruicheng002@ntu.edu.sg</b>.
"""
def interpolate_sequence(sequence):
    interpolated_sequence = np.copy(sequence)
    missing_indices = np.isnan(sequence)
    if np.any(missing_indices):
        valid_indices = ~missing_indices
        x = np.arange(len(sequence))
        interpolated_sequence[missing_indices] = np.interp(x[missing_indices], x[valid_indices], sequence[valid_indices])
    return interpolated_sequence
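# Build the RealESRGAN x2 upsampler used for background enhancement. Half
# precision is enabled on CUDA except on GTX 1650/1660 cards, which are known to
# produce artifacts in fp16; tiling (400 px tiles, 40 px padding) bounds memory use.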
def set_realesrgan():
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils.realesrgan_utils import RealESRGANer

    use_half = False
    if torch.cuda.is_available():
        no_half_gpu_list = ['1650', '1660']
        if not any(gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list):
            use_half = True

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path="https://github.com/jnjaby/KEEP/releases/download/v1.0.0/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=use_half)

    if not gpu_is_available():
        import warnings
        warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA. '
                      'The unoptimized RealESRGAN is slow on CPU.', category=RuntimeWarning)
    return upsampler
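# Full pipeline: detect and temporally smooth 5-point landmarks, crop and align
# each face, restore the aligned clip with KEEP, paste the faces back with the
# inverse affine transform, and re-encode the video.
# This Space runs on ZeroGPU ("Running on Zero"), where a GPU is attached only
# inside functions decorated with @spaces.GPU; without the decorator the
# `import spaces` above would be unused and inference would fall back to CPU.
@spaces.GPU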
def process_video(input_video, draw_box, bg_enhancement, progress=gr.Progress(track_tqdm=True)):
    # gr.Progress must be a default-valued parameter of the Gradio fn for it to
    # be injected and track tqdm progress bars; it does nothing as a plain value.
    device = get_device()
    args = argparse.Namespace(
        input_path=input_video,
        upscale=1,
        max_length=20,
        has_aligned=False,
        only_center_face=True,
        draw_box=draw_box,
        detection_model='retinaface_resnet50',
        bg_enhancement=bg_enhancement,
        face_upsample=False,
        bg_tile=400,
        suffix=None,
        save_video_fps=None,
        model_type='KEEP',
    )
    output_dir = './results/'
    os.makedirs(output_dir, exist_ok=True)
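    # Architecture hyperparameters matching the released KEEP-b76feb75 checkpoint,
    # keyed by model type so additional variants could be registered alongside 'KEEP'.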
    model_configs = {
        'KEEP': {
            'architecture': {
                'img_size': 512, 'emb_dim': 256, 'dim_embd': 512, 'n_head': 8, 'n_layers': 9,
                'codebook_size': 1024, 'cft_list': ['16', '32', '64'], 'kalman_attn_head_dim': 48,
                'num_uncertainty_layers': 3, 'cfa_list': ['16', '32'], 'cfa_nhead': 4, 'cfa_dim': 256, 'cond': 1
            },
            'checkpoint_dir': '/home/user/app/weights/KEEP',
            'checkpoint_url': 'https://github.com/jnjaby/KEEP/releases/download/v1.0.0/KEEP-b76feb75.pth'
        },
    }
    if args.bg_enhancement:
        bg_upsampler = set_realesrgan()
    else:
        bg_upsampler = None
    if args.face_upsample:
        face_upsampler = bg_upsampler if bg_upsampler is not None else set_realesrgan()
    else:
        face_upsampler = None
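    # Instantiate KEEP from BasicSR's ARCH_REGISTRY, download the release
    # checkpoint if needed, and load the EMA weights for inference.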
    if args.model_type not in model_configs:
        raise ValueError(f"Unknown model type: {args.model_type}. Available options: {list(model_configs.keys())}")
    config = model_configs[args.model_type]
    net = ARCH_REGISTRY.get('KEEP')(**config['architecture']).to(device)
    ckpt_path = load_file_from_url(url=config['checkpoint_url'], model_dir=config['checkpoint_dir'], progress=True, file_name=None)
    checkpoint = torch.load(ckpt_path, weights_only=True)
    net.load_state_dict(checkpoint['params_ema'])
    net.eval()
    if not args.has_aligned:
        print(f'Face detection model: {args.detection_model}')
    if bg_upsampler is not None:
        print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
    else:
        print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')

    face_helper = FaceRestoreHelper(args.upscale, face_size=512, crop_ratio=(1, 1), det_model=args.detection_model, save_ext='png', use_parse=True, device=device)
    # Reading the input video.
    input_img_list = []
    if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):
        vidreader = VideoReader(args.input_path)
        image = vidreader.get_frame()
        while image is not None:
            input_img_list.append(image)
            image = vidreader.get_frame()
        fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
        vidreader.close()
        clip_name = os.path.basename(args.input_path)[:-4]
    else:
        raise TypeError(f'Unrecognized type of input video {args.input_path}.')
    if len(input_img_list) == 0:
        raise FileNotFoundError('No frames were read from the input video.')
    print('Detecting keypoints and smooth alignment ...')
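    # Stage 1: detect 5-point landmarks on every frame (center face only), mark
    # frames where detection fails as NaN, fill the gaps by interpolation, then
    # smooth each coordinate over time with a Gaussian (sigma=5) so the crop
    # window does not jitter between frames.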
    if not args.has_aligned:
        raw_landmarks = []
        for i, img in enumerate(input_img_list):
            face_helper.clean_all()
            face_helper.read_image(img)
            num_det_faces = face_helper.get_face_landmarks_5(only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5, only_keep_largest=True)
            if num_det_faces == 1:
                raw_landmarks.append(face_helper.all_landmarks_5[0].reshape((10,)))
            elif num_det_faces == 0:
                raw_landmarks.append(np.array([np.nan] * 10))

        raw_landmarks = np.array(raw_landmarks)
        for i in range(10):
            raw_landmarks[:, i] = interpolate_sequence(raw_landmarks[:, i])
        video_length = len(input_img_list)
        avg_landmarks = gaussian_filter1d(raw_landmarks, 5, axis=0).reshape(video_length, 5, 2)
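    # Stage 2: re-crop every frame with the smoothed landmarks, convert BGR to
    # RGB, and normalize to [-1, 1], yielding a (1, T, 3, 512, 512) input tensor.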
    cropped_faces = []
    for i, img in enumerate(input_img_list):
        face_helper.clean_all()
        face_helper.read_image(img)
        face_helper.all_landmarks_5 = [avg_landmarks[i]]
        face_helper.align_warp_face()

        cropped_face_t = img2tensor(face_helper.cropped_faces[0] / 255., bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_faces.append(cropped_face_t)
    cropped_faces = torch.stack(cropped_faces, dim=0).unsqueeze(0).to(device)
    print('Restoring faces ...')
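    # Stage 3: run KEEP on the aligned clip in chunks of at most args.max_length
    # (20) frames to bound GPU memory. A single leftover frame is duplicated so
    # the temporal model still receives a pair, and only its first output is kept.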
    with torch.no_grad():
        video_length = cropped_faces.shape[1]
        output = []
        for start_idx in range(0, video_length, args.max_length):
            end_idx = min(start_idx + args.max_length, video_length)
            if end_idx - start_idx == 1:
                output.append(net(cropped_faces[:, [start_idx, start_idx], ...], need_upscale=False)[:, 0:1, ...])
            else:
                output.append(net(cropped_faces[:, start_idx:end_idx, ...], need_upscale=False))
        output = torch.cat(output, dim=1).squeeze(0)
        assert output.shape[0] == video_length, "Different number of frames"
        restored_faces = [tensor2img(x, rgb2bgr=True, min_max=(-1, 1)) for x in output]
    del output
    torch.cuda.empty_cache()
    print('Pasting faces back ...')
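    # Stage 4: warp each restored face back into its original frame via the
    # inverse affine transform, optionally compositing over a RealESRGAN-enhanced
    # background and drawing the detection box.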
    restored_frames = []
    for i, img in enumerate(input_img_list):
        face_helper.clean_all()
        if args.has_aligned:
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=10)
            if face_helper.is_gray:
                print('Grayscale input: True')
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.all_landmarks_5 = [avg_landmarks[i]]
            face_helper.align_warp_face()

        face_helper.add_restored_face(restored_faces[i].astype('uint8'))
        if not args.has_aligned:
            if bg_upsampler is not None:
                bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
            else:
                bg_img = None
            face_helper.get_inverse_affine(None)
            if args.face_upsample and face_upsampler is not None:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
            else:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
        restored_frames.append(restored_img)
    # Saving the output video.
    print('Saving video ...')
    height, width = restored_frames[0].shape[:2]
    save_restore_path = os.path.join(output_dir, f'{clip_name}.mp4')
    vidwriter = VideoWriter(save_restore_path, height, width, fps)
    for f in restored_frames:
        vidwriter.write_frame(f)
    vidwriter.close()
    print(f'All results are saved in {save_restore_path}.')
    return save_restore_path
# Downloading necessary models and sample videos.
sample_videos_dir = os.path.join("/home/user/app/hugging_face/", "test_sample/")
os.makedirs(sample_videos_dir, exist_ok=True)
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_1.mp4", os.path.join(sample_videos_dir, "real_1.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_2.mp4", os.path.join(sample_videos_dir, "real_2.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_3.mp4", os.path.join(sample_videos_dir, "real_3.mp4"))
download_url_to_file("https://github.com/jnjaby/KEEP/releases/download/media/real_4.mp4", os.path.join(sample_videos_dir, "real_4.mp4"))
model_dir = "/home/user/app/weights/" | |
model_url = "https://github.com/jnjaby/KEEP/releases/download/v1.0.0/" | |
_ = load_file_from_url(url=os.path.join(model_url, 'KEEP-b76feb75.pth'), | |
model_dir=os.path.join(model_dir, "KEEP"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'detection_Resnet50_Final.pth'), | |
model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'detection_mobilenet0.25_Final.pth'), | |
model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'yolov5n-face.pth'), | |
model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'yolov5l-face.pth'), | |
model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'parsing_parsenet.pth'), | |
model_dir=os.path.join(model_dir, "facelib"), progress=True, file_name=None) | |
_ = load_file_from_url(url=os.path.join(model_url, 'RealESRGAN_x2plus.pth'), | |
model_dir=os.path.join(model_dir, "realesrgan"), progress=True, file_name=None) | |
# Launching the Gradio interface.
demo = gr.Interface(
    fn=process_video,
    title=title,
    description=description,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Checkbox(label="Draw Box", value=False),
        gr.Checkbox(label="Background Enhancement", value=False),
    ],
    outputs=gr.Video(label="Processed Video"),
    examples=[
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_1.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_2.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_3.mp4"), True, False],
        [os.path.join(os.path.dirname(__file__), sample_videos_dir, "real_4.mp4"), True, False],
    ],
    cache_examples=False,
    article=post_article,
)
demo.launch(share=True)
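# Note: on Hugging Face Spaces the app is already served publicly and share=True
# is ignored with a warning; the flag only creates a public link for local runs.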