Spaces:
Runtime error
Runtime error
LivePortrait
/
stf
/stf-api-alternative
/src
/stf_alternative
/.ipynb_checkpoints
/util-checkpoint.py
| import json | |
| import random | |
| import string | |
| from datetime import datetime | |
| from pathlib import Path | |
| import ffmpeg | |
| import imageio_ffmpeg | |
| import numpy as np | |
| import torch | |
| from addict import Dict | |
def icycle(iterable):
    """Yield the items of *iterable* forever, re-iterating it on every pass.

    Unlike ``itertools.cycle`` this does not cache items: every pass starts a
    fresh ``for`` loop over *iterable*, so a re-iterable source can produce
    a different sequence on each pass.

    If a whole pass yields nothing, the generator returns instead of looping
    forever. This happens when the iterable is empty, or when it is a one-shot
    iterator that is already exhausted. Previously that case was an infinite
    busy loop.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # Nothing left to cycle over; stop rather than spin forever.
            return
async def acycle(aiterable):
    """Asynchronously yield the items of *aiterable* forever.

    Every pass runs a fresh ``async for`` over *aiterable*, so no items are
    cached between passes.

    If a whole pass yields nothing, the generator returns instead of looping
    forever. This happens when the source is empty or is an exhausted one-shot
    async iterator. Previously that case spun in ``while True``. When the
    source's ``__anext__`` finishes without awaiting, that spin could also
    block the event loop.
    """
    while True:
        produced = False
        async for item in aiterable:
            produced = True
            yield item
        if not produced:
            # Nothing left to cycle over; stop rather than spin forever.
            return
def read_config(config_path):
    """Load a JSON config file and wrap the parsed object in an ``addict.Dict``.

    The file is decoded as UTF-8, the encoding JSON is defined in. Before, it
    was decoded with the platform's locale default, which breaks non-ASCII
    configs on systems whose default encoding is not UTF-8.

    Args:
        config_path: path of the JSON file to read.

    Returns:
        ``Dict(parsed_json)``.

    Raises:
        Whatever ``open``/``json.load``/``Dict`` raise. The offending path is
        printed first, and then the original exception is re-raised unchanged.
    """
    try:
        with open(config_path, encoding="utf-8") as fd:
            conf = Dict(json.load(fd))
    except Exception:
        print("read config exception in ", config_path)
        # Bare ``raise`` re-raises the active exception with its traceback intact.
        raise
    return conf
def get_preprocess_dir(work_root_path, name):
    """Return ``<work_root_path>/preprocess/<name>`` as a string."""
    preprocess_path = Path(work_root_path).joinpath("preprocess", name)
    return str(preprocess_path)
def get_crop_mp4_dir(preprocess_dir, video_path):
    """Return the crop-video directory ``<preprocess_dir>/crop_video_<stem>``.

    ``<stem>`` is the file name of *video_path* without its extension.
    """
    video_stem = Path(video_path).stem
    return "{}/crop_video_{}".format(preprocess_dir, video_stem)
def get_frame_dir(preprocess_dir, video_path, ratio):
    """Return the frame directory for *video_path* at the given resize *ratio*.

    The layout is ``<preprocess_dir>/<stem>/frames``. When ``ratio`` is not
    1.0, ``_<ratio>`` is appended, e.g. ``frames_0.5``.
    """
    if ratio == 1.0:
        frames_name = "frames"
    else:
        frames_name = f"frames_{ratio}"
    video_stem = Path(video_path).stem
    return f"{preprocess_dir}/{video_stem}/{frames_name}"
def get_template_ratio_file_path(preprocess_dir, video_path, ratio):
    """Return the path of the resized copy of a template video.

    For ``ratio == 1.0`` the original *video_path* is returned unchanged.
    Otherwise the path is
    ``<preprocess_dir>/<name>/<name>_ratio_<ratio><suffix>``.

    ``<name>`` is the full file name including its extension, unlike the
    ``stem`` used by the frame directory, so ``clip.mp4`` at ratio 0.5 maps to
    ``<preprocess_dir>/clip.mp4/clip.mp4_ratio_0.5.mp4``.
    NOTE(review): presumably kept this way because existing preprocessed files
    already use this layout; confirm before changing it.
    """
    if ratio == 1.0:
        return video_path
    video = Path(video_path)
    file_name = video.name
    return f"{preprocess_dir}/{file_name}/{file_name}_ratio_{ratio}{video.suffix}"
| class _CallBack(object): | |
| def __init__(self, callback, min_per, max_per, desc, verbose=False): | |
| assert max_per > min_per | |
| self.callback = callback | |
| self.min_per = min_per | |
| self.max_per = max_per | |
| if isinstance(callback, _CallBack): | |
| self.desc = callback.desc + "/" + desc | |
| else: | |
| self.desc = desc | |
| self.last_per = -1 | |
| self.verbose = verbose | |
| self.callback_interval = 1 | |
| def __call__(self, per): | |
| if self.callback is None: | |
| return | |
| my_per = self.min_per + (per + 1) / 100.0 * (self.max_per - self.min_per) | |
| my_per = int(my_per) | |
| if my_per - self.last_per >= self.callback_interval: | |
| # if self.verbose: | |
| # print(self.desc, ' : ', my_per) | |
| self.callback(my_per) | |
| self.last_per = my_per | |
def callback_inter(callback, min_per=0, max_per=100, desc="", verbose=False):
    """Build a ``_CallBack`` that maps local progress onto [min_per, max_per].

    Requires ``0 <= min_per < max_per``. A non-negative ``max_per`` follows
    from that, so it needs no separate check. Raises ``AssertionError``
    otherwise.
    """
    assert 0 <= min_per < max_per
    return _CallBack(
        callback,
        min_per=min_per,
        max_per=max_per,
        desc=desc,
        verbose=verbose,
    )
def callback_test():
    """Manual demo: print how nested progress callbacks map onto 0..100.

    The range 0..50 is split into three sub-ranges (a, b, c). Those are driven
    first, followed by the 50..90 and 90..100 top-level ranges.
    """

    def callback(per):
        print("real callback", per)

    callback1 = callback_inter(callback, min_per=0, max_per=50, desc="1")
    callback2 = callback_inter(callback, min_per=50, max_per=90, desc="2")
    callback3 = callback_inter(callback, min_per=90, max_per=100, desc="3")
    nested = [
        callback_inter(callback1, min_per=0, max_per=20, desc="a"),
        callback_inter(callback1, min_per=20, max_per=80, desc="b"),
        callback_inter(callback1, min_per=80, max_per=100, desc="c"),
    ]
    for progress_cb in nested + [callback2, callback3]:
        for step in range(101):
            progress_cb(step)
def fix_seed(random_seed):
    """Seed torch (CPU and CUDA), NumPy and ``random`` to make runs repeatable.

    It also puts cuDNN into deterministic mode and turns off its benchmark
    autotuner, trading speed for reproducibility.
    """
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # every visible GPU
    ):
        seed_fn(random_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    np.random.seed(random_seed)
    random.seed(random_seed)
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's current initial seed.

    The seed is reduced modulo 2**32 so it fits the range NumPy's legacy seeding
    accepts. ``worker_id`` is accepted but not used.
    """
    derived_seed = torch.initial_seed() % 2**32
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)
def get_three_channel_ffmpeg_reader(path):
    """Open *path* with imageio-ffmpeg's default frame reader.

    Returns:
        ``(reader, meta)``. ``reader`` is the frame generator after its first
        item has been consumed. ``meta`` is that first item, the metadata dict,
        e.g. ``meta["size"] -> (width, height)``.
    """
    frames = imageio_ffmpeg.read_frames(path)
    header = next(frames)  # the first yielded item is metadata, not a frame
    return frames, header
def get_four_channel_ffmpeg_reader(path):
    """Open a .mov or .webm video as an RGBA (32 bits per pixel) frame reader.

    Args:
        path: str or path-like ending in ".mov" or ".webm".

    Returns:
        ``(reader, meta)``. ``reader`` is the frame generator after its first
        item has been consumed. ``meta`` is that first item, the metadata dict,
        e.g. ``meta["size"] -> (width, height)``.

    Raises:
        ValueError: *path* has any other extension. Previously this fell
            through to an ``UnboundLocalError`` on ``reader``.
        IndexError: a .webm file has no video stream (unchanged behaviour).
    """
    # Normalise once so path-like objects work with ``endswith`` below.
    path = str(path)
    if path.endswith(".mov"):
        reader = imageio_ffmpeg.read_frames(path, pix_fmt="rgba", bits_per_pixel=32)
    elif path.endswith(".webm"):
        stream_meta = [
            stream
            for stream in ffmpeg.probe(path)["streams"]
            if stream["codec_type"] == "video"
        ][0]
        # Force the libvpx decoder that matches the stream's codec.
        # NOTE(review): presumably so the WebM alpha plane is decoded; ffmpeg's
        # native VP8/VP9 decoders are understood to ignore it. Confirm.
        decoder = "libvpx-vp9" if stream_meta["codec_name"] == "vp9" else "libvpx"
        reader = imageio_ffmpeg.read_frames(
            path=path,
            pix_fmt="rgba",
            input_params=["-c:v", decoder],
            bits_per_pixel=32,
        )
    else:
        raise ValueError(
            f"unsupported extension for a 4-channel reader: {path!r} "
            "(expected .mov or .webm)"
        )
    meta = next(reader)  # metadata dict, e.g. meta["size"] -> (width, height)
    return reader, meta
def get_three_channel_ffmpeg_writer(out_path, size, fps, ffmpeg_params, wav_path):
    """Create an imageio-ffmpeg frame writer for a regular (non-alpha) video.

    Args:
        out_path: output file path.
        size: frame size passed straight to ``write_frames``.
        fps: output frame rate.
        ffmpeg_params: extra ffmpeg output arguments, or None.
        wav_path: audio file muxed into the output.

    Returns:
        The generator returned by ``imageio_ffmpeg.write_frames``.
        NOTE(review): imageio-ffmpeg documents that this generator must be
        primed with ``send(None)`` before frames are sent. That is not done
        here, so presumably the caller does it.
    """
    writer_options = {
        "size": size,
        "fps": fps,
        "ffmpeg_log_level": "error",
        "quality": 10,  # imageio-ffmpeg scale 0~10
        "output_params": ffmpeg_params,
        "audio_path": wav_path,
        # 1 = no size constraint, so frames are not rescaled to a block multiple
        "macro_block_size": 1,
    }
    return imageio_ffmpeg.write_frames(out_path, **writer_options)
def get_webm_ffmpeg_writer(out_path, size, fps, wav_path, low_quality=False):
    """Create an imageio-ffmpeg writer for a VP8 WebM video with alpha.

    RGBA input frames are encoded with libvpx to ``yuva420p`` at 10M bitrate
    and CRF 4. With ``low_quality=True`` the output frame rate is halved and
    libvpx's ``realtime`` deadline is used.

    Returns:
        The generator returned by ``imageio_ffmpeg.write_frames``.
    """
    if low_quality:
        output_fps = fps / 2
    else:
        output_fps = fps

    # NOTE(review): "-auto-alt-ref 0" is presumably required for alpha
    # (yuva420p) encoding with libvpx; confirm.
    encoder_args = ["-crf", "4", "-auto-alt-ref", "0"]
    if low_quality:
        encoder_args += ["-deadline", "realtime"]

    return imageio_ffmpeg.write_frames(
        out_path,
        size=size,
        fps=output_fps,
        ffmpeg_log_level="error",
        quality=10,  # imageio-ffmpeg scale 0~10
        pix_fmt_in="rgba",
        pix_fmt_out="yuva420p",
        codec="libvpx",
        bitrate="10M",
        output_params=encoder_args,
        audio_path=wav_path,
        macro_block_size=1,
    )
def get_mov_ffmpeg_writer(out_path, size, fps, wav_path):
    """Create an imageio-ffmpeg writer for a ProRes .mov video with alpha.

    RGBA input frames are written as ``yuva444p10le`` through ffmpeg's
    ``prores_ks`` encoder. The encoder is selected with ``-c:v`` in the output
    params rather than the ``codec=`` argument.

    Returns:
        The generator returned by ``imageio_ffmpeg.write_frames``.
    """
    prores_args = [
        "-c:v", "prores_ks",
        "-profile:v", "4",  # NOTE(review): 4 should be ProRes 4444 in prores_ks numbering
        "-vendor", "apl0",
        "-bits_per_mb", "8000",
    ]
    return imageio_ffmpeg.write_frames(
        out_path,
        size=size,
        fps=fps,
        ffmpeg_log_level="error",
        quality=10,  # imageio-ffmpeg scale 0~10
        pix_fmt_in="rgba",
        pix_fmt_out="yuva444p10le",
        output_params=prores_args,
        audio_path=wav_path,
        macro_block_size=1,
    )
def get_reader(template_video_path):
    """Open a template video with the reader that matches its extension.

    ``.mp4`` files use the default three-channel reader. ``.mov`` and ``.webm``
    use the RGBA four-channel reader.
    Reference: https://github.com/imageio/imageio-ffmpeg

    Args:
        template_video_path: str or path-like video path.

    Returns:
        ``(reader, meta)`` from the selected reader.

    Raises:
        AssertionError: unsupported extension. It is raised explicitly rather
            than with ``assert False``: under ``python -O`` the assert was
            stripped and the function then failed with ``UnboundLocalError``.
            The exception type is kept for existing callers.
    """
    path = str(template_video_path)
    if path.endswith(".mp4"):
        return get_three_channel_ffmpeg_reader(path)
    if path.endswith((".mov", ".webm")):
        return get_four_channel_ffmpeg_reader(path)
    raise AssertionError(
        f"unsupported template video {path!r}: "
        'expected one of [".mp4", ".mov", ".webm"]'
    )
def get_writer(out_path, size, fps, wav_path, slow_write):
    """Create the frame writer that matches the extension of *out_path*.

    Args:
        out_path: str or path-like output path ending in .mp4, .mov or .webm.
        size: frame size forwarded to the writer.
        fps: output frame rate.
        wav_path: audio file muxed into the output.
        slow_write: for .mp4 only. When true, ``-acodec aac -crf 17`` is passed
            as extra ffmpeg output params.

    Returns:
        The generator created by the matching ``get_*_ffmpeg_writer`` helper.

    Raises:
        AssertionError: unsupported extension. It is raised explicitly rather
            than with ``assert False``, which ``python -O`` strips and which
            then led to ``UnboundLocalError``. The type is kept for existing
            callers.
    """
    out_path = str(out_path)
    if out_path.endswith(".mp4"):
        # Video is generated while frames are being composited.
        ffmpeg_params = ["-acodec", "aac", "-crf", "17"] if slow_write else None
        return get_three_channel_ffmpeg_writer(
            out_path, size, fps, ffmpeg_params, wav_path
        )
    if out_path.endswith(".mov"):
        return get_mov_ffmpeg_writer(out_path, size, fps, wav_path)
    if out_path.endswith(".webm"):
        # NOTE(review): an earlier comment said the webm fps is changed for speed,
        # but low_quality is not passed, so the helper's default (False) applies
        # and fps is NOT halved here.
        return get_webm_ffmpeg_writer(out_path, size, fps, wav_path)
    raise AssertionError(
        f'out_path should be one of ["mp4", "mov", "webm"], got {out_path!r}'
    )
def pretty_string_dict(d, tab=4):
    """Render a (possibly nested) dict as a multi-line string.

    Each entry goes on its own line, prefixed by ``tab`` spaces and written as
    ``repr(key): value,``. Nested dicts recurse with ``tab + 1`` spaces; other
    values use ``repr``. The closing brace is prefixed by the same ``tab``
    spaces.
    """
    indent = " " * tab
    pieces = ["{\n"]
    for key, value in d.items():
        rendered = (
            pretty_string_dict(value, tab + 1) if isinstance(value, dict) else repr(value)
        )
        pieces.append(f"{indent}{key!r}: {rendered},\n")
    pieces.append(f"{indent}}}")
    return "".join(pieces)
def get_random_string_with_len(size: int):
    """Return a local-time stamp ``YYMMDD_HHMMSS_`` followed by *size* random letters.

    The letters are drawn from ``string.ascii_letters`` with the ``random``
    module, so the output follows any seed set with ``random.seed``. The result
    is not suitable for security-sensitive tokens.
    """
    stamp = datetime.now().strftime("%y%m%d_%H%M%S_")
    letters = "".join(random.choices(string.ascii_letters, k=size))
    return stamp + letters