# improve.py
"""Extract frames from a video, upscale them with Real-ESRGAN, and reassemble
the upscaled frames with the original audio track into a new video."""

import time
from config import *
import cv2
import glob
import numpy as np
import os
from basicsr.utils import imwrite      # used only by the commented-out GFPGAN path below
from pathos.pools import ParallelPool  # used only by the commented-out GFPGAN path below
import subprocess
import platform
from mutagen.wave import WAVE
from p_tqdm import t_map
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
def vid2frames(vidPath, framesOutPath):
    """Dump every frame of the video at vidPath into framesOutPath as zero-padded PNGs."""
    print(vidPath)
    print(framesOutPath)
    vidcap = cv2.VideoCapture(vidPath)
    success, image = vidcap.read()
    frame = 1
    while success:
        # Frames are numbered 00001.png, 00002.png, ... so ffmpeg can read them back in order.
        cv2.imwrite(os.path.join(framesOutPath, str(frame).zfill(5) + '.png'), image)
        success, image = vidcap.read()
        frame += 1
    vidcap.release()
def restore_frames(audiofilePath, videoOutPath, improveOutputPath):
    """Re-encode the upscaled frames into a video, muxing in the original audio.

    The frame rate is chosen so that the frame sequence spans exactly the audio duration."""
    no_of_frames = count_files(improveOutputPath)
    audio_duration = get_audio_duration(audiofilePath)
    framesPath = os.path.join(improveOutputPath, "%05d.png")  # matches the zero-padded frame names
    fps = no_of_frames / audio_duration
    command = f"ffmpeg -y -r {fps} -f image2 -i {framesPath} -i {audiofilePath} -vcodec mpeg4 -b:v 20000k {videoOutPath}"
    print(command)
    subprocess.call(command, shell=platform.system() != 'Windows')
def get_audio_duration(audioPath):
    """Return the length of a WAV file in seconds."""
    audio = WAVE(audioPath)
    duration = audio.info.length
    return duration


def count_files(directory):
    """Count regular files (not subdirectories) in a directory."""
    return len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
def improve(disassembledPath, improvedPath):
    """Upscale every PNG frame in disassembledPath 4x with Real-ESRGAN into improvedPath."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RealESRGAN(device, scale=4)
    model.load_weights('weights/RealESRGAN_x4.pth', download=True)
    files = glob.glob(os.path.join(disassembledPath, "*.png"))
    # Parallel alternative (left commented out in favour of the sequential t_map below):
    # pool = ParallelPool(nodes=20)
    # results = pool.amap(real_esrgan, files, [model]*len(files), [improvedPath] * len(files))
    results = t_map(real_esrgan, files, [model] * len(files), [improvedPath] * len(files))


def real_esrgan(img_path, model, improvedPath):
    """Run a single frame through the Real-ESRGAN model and save it under the same name."""
    image = Image.open(img_path).convert('RGB')
    sr_image = model.predict(image)
    img_name = os.path.basename(img_path)
    sr_image.save(os.path.join(improvedPath, img_name))
# def process(img_path, improveOutputPath):
#     only_center_face = True
#     aligned = True
#     ext = 'auto'
#     weight = 0.5
#     upscale = 1
#     arch = 'clean'
#     channel_multiplier = 2
#     model_name = 'GFPGANv1.3'
#     url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
#     # determine model paths
#     model_path = os.path.join('gfpgan_models', model_name + '.pth')
#     if not os.path.isfile(model_path):
#         model_path = os.path.join('gfpgan/weights', model_name + '.pth')
#     if not os.path.isfile(model_path):
#         # download pre-trained models from url
#         model_path = url
#     restorer = GFPGANer(
#         model_path=model_path,
#         upscale=upscale,
#         arch=arch,
#         channel_multiplier=channel_multiplier,
#         bg_upsampler=None)
#     # read image
#     img_name = os.path.basename(img_path)
#     basename, file_ext = os.path.splitext(img_name)  # kept separate from the 'ext' setting above
#     input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
#     # restore faces and background if necessary
#     cropped_faces, restored_faces, restored_img = restorer.enhance(
#         input_img,
#         has_aligned=aligned,
#         only_center_face=only_center_face,
#         paste_back=True,
#         weight=weight)
#     # save faces
#     for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
#         # save cropped face
#         save_crop_path = os.path.join(improveOutputPath, 'cropped_faces', f'{basename}.png')
#         imwrite(cropped_face, save_crop_path)
#         # save restored face
#         save_face_name = f'{basename}.png'
#         save_restore_path = os.path.join(improveOutputPath, 'restored_faces', save_face_name)
#         imwrite(restored_face, save_restore_path)
#         # save comparison image
#         cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
#         imwrite(cmp_img, os.path.join(improveOutputPath, 'cmp', f'{basename}.png'))
#     # save restored img
#     if restored_img is not None:
#         if ext == 'auto':
#             extension = file_ext[1:]
#         else:
#             extension = ext
#         save_restore_path = os.path.join(improveOutputPath, 'restored_imgs', f'{basename}.{extension}')
#         imwrite(restored_img, save_restore_path)
#     print(f'Processed {img_name} ...')


# def improve_faces(improveInputPath, improveOutputPath):
#     if improveInputPath.endswith('/'):
#         improveInputPath = improveInputPath[:-1]
#     if os.path.isfile(improveInputPath):
#         img_list = [improveInputPath]
#     else:
#         img_list = sorted(glob.glob(os.path.join(improveInputPath, '*')))
#     os.makedirs(improveInputPath, exist_ok=True)
#     os.makedirs(improveOutputPath, exist_ok=True)
#     pool = ParallelPool(nodes=10)
#     results = pool.amap(process, img_list, [improveOutputPath] * len(img_list))
#     while not results.ready():
#         time.sleep(5)
#         print(".", end=' ')