# Rerender / src / config.py
# Anonymous-sub's picture
# merge (#1)
# 251e479
import json
import os
from typing import Optional, Sequence, Tuple
from src.video_util import get_frame_count
class RerenderConfig:
    """Container for the settings of one rerender run.

    A new instance carries no settings. Populate it with
    :meth:`create_from_parameters` (explicit arguments) or
    :meth:`create_from_path` (a JSON config file).
    """

    def __init__(self):
        ...

    def create_from_parameters(self,
                               input_path: str,
                               output_path: str,
                               prompt: str,
                               work_dir: Optional[str] = None,
                               key_subdir: str = 'keys',
                               frame_count: Optional[int] = None,
                               interval: int = 10,
                               crop: Sequence[int] = (0, 0, 0, 0),
                               sd_model: Optional[str] = None,
                               a_prompt: str = '',
                               n_prompt: str = '',
                               ddim_steps=20,
                               scale=7.5,
                               control_type: str = 'HED',
                               control_strength=1,
                               seed: int = -1,
                               image_resolution: int = 512,
                               x0_strength: float = -1,
                               style_update_freq: int = 10,
                               cross_period: Tuple[float, float] = (0, 1),
                               warp_period: Tuple[float, float] = (0, 0.1),
                               mask_period: Tuple[float, float] = (0.5, 0.8),
                               ada_period: Tuple[float, float] = (1.0, 1.0),
                               mask_strength: float = 0.5,
                               inner_strength: float = 0.9,
                               smooth_boundary: bool = True,
                               color_preserve: bool = True,
                               **kwargs):
        """Populate this config from explicit arguments.

        Every named argument is stored on the instance under the same
        name. In addition the method derives ``key_dir``, ``first_dir``
        and ``input_dir`` below the working directory and creates those
        directories on disk.

        Args:
            input_path: Path of the input video; must be an existing file.
            output_path: Path of the output; its directory is used as the
                working directory when ``work_dir`` is ``None``.
            prompt: Stored as ``self.prompt``.
            work_dir: Working directory. ``None`` means "the directory
                part of ``output_path``".
            key_subdir: Name of the sub-directory of the working
                directory used for ``key_dir``.
            frame_count: Number of frames. ``None`` means "ask
                ``get_frame_count`` for ``input_path``".
            control_type: When equal to ``'canny'``, ``canny_low`` and
                ``canny_high`` are taken from ``kwargs`` (defaults 100
                and 200); otherwise both are set to ``None``.
            **kwargs: Only ``canny_low`` / ``canny_high`` are consulted;
                any other extra keyword is ignored.

        The remaining arguments (``interval``, ``crop``, ``sd_model``,
        ``a_prompt``, ``n_prompt``, ``ddim_steps``, ``scale``,
        ``control_strength``, ``seed``, ``image_resolution``,
        ``x0_strength``, ``style_update_freq``, the ``*_period`` pairs,
        ``mask_strength``, ``inner_strength``, ``smooth_boundary`` and
        ``color_preserve``) are stored unchanged; their interpretation is
        left to the code that consumes this config.

        Raises:
            FileNotFoundError: If ``input_path`` is not an existing file.
                Raised before any directory is created.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.prompt = prompt

        self.work_dir = work_dir
        if work_dir is None:
            # May be '' when output_path has no directory component.
            self.work_dir = os.path.dirname(output_path)
        self.key_dir = os.path.join(self.work_dir, key_subdir)
        self.first_dir = os.path.join(self.work_dir, 'first')

        # The input video must exist. This method only reserves the
        # directory for its frames; it does not extract them itself.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'Cannot find video file {input_path}')
        self.input_dir = os.path.join(self.work_dir, 'video')

        self.frame_count = frame_count
        if frame_count is None:
            self.frame_count = get_frame_count(self.input_path)
        self.interval = interval
        self.crop = crop
        self.sd_model = sd_model
        self.a_prompt = a_prompt
        self.n_prompt = n_prompt
        self.ddim_steps = ddim_steps
        self.scale = scale

        self.control_type = control_type
        if self.control_type == 'canny':
            self.canny_low = kwargs.get('canny_low', 100)
            self.canny_high = kwargs.get('canny_high', 200)
        else:
            self.canny_low = None
            self.canny_high = None
        self.control_strength = control_strength

        self.seed = seed
        self.image_resolution = image_resolution
        self.x0_strength = x0_strength
        self.style_update_freq = style_update_freq
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.warp_period = warp_period
        self.ada_period = ada_period
        self.mask_strength = mask_strength
        self.inner_strength = inner_strength
        self.smooth_boundary = smooth_boundary
        self.color_preserve = color_preserve

        # os.makedirs('') raises FileNotFoundError, so skip empty paths:
        # an empty working directory simply means "the current directory",
        # which already exists.
        for directory in (self.input_dir, self.work_dir, self.key_dir,
                          self.first_dir):
            if directory:
                os.makedirs(directory, exist_ok=True)

    def create_from_path(self, cfg_path: str):
        """Populate this config from a JSON file.

        The JSON object must contain ``input``, ``output`` and ``prompt``.
        Every other recognised key is optional and is forwarded to
        :meth:`create_from_parameters` only when it is present and not
        ``null``, so that method's defaults apply otherwise.
        ``canny_low`` / ``canny_high`` are forwarded only when the file
        sets ``control_type`` to ``'canny'``.

        Args:
            cfg_path: Path of the JSON config file.

        Raises:
            KeyError: If ``input``, ``output`` or ``prompt`` is missing.
            FileNotFoundError: Propagated from
                :meth:`create_from_parameters` when the input video does
                not exist (or from ``open`` when ``cfg_path`` does not).
        """
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's default encoding.
        with open(cfg_path, 'r', encoding='utf-8') as fp:
            cfg = json.load(fp)

        kwargs = {
            'input_path': cfg['input'],
            'output_path': cfg['output'],
            'prompt': cfg['prompt'],
        }

        def append_if_not_none(key):
            value = cfg.get(key, None)
            if value is not None:
                kwargs[key] = value

        for key in ('work_dir', 'key_subdir', 'frame_count', 'interval',
                    'crop', 'sd_model', 'a_prompt', 'n_prompt',
                    'ddim_steps', 'scale', 'control_type'):
            append_if_not_none(key)

        if kwargs.get('control_type', '') == 'canny':
            append_if_not_none('canny_low')
            append_if_not_none('canny_high')

        for key in ('control_strength', 'seed', 'image_resolution',
                    'x0_strength', 'style_update_freq', 'cross_period',
                    'warp_period', 'mask_period', 'ada_period',
                    'mask_strength', 'inner_strength', 'smooth_boundary'):
            append_if_not_none(key)

        # This loader used to look up the misspelled key
        # 'color_perserve', so a correctly spelled 'color_preserve' entry
        # was silently dropped. Read the correct key and keep accepting
        # the misspelling for config files written against it.
        color_preserve = cfg.get('color_preserve')
        if color_preserve is None:
            color_preserve = cfg.get('color_perserve')
        if color_preserve is not None:
            kwargs['color_preserve'] = color_preserve

        self.create_from_parameters(**kwargs)

    @property
    def use_warp(self):
        """True when ``warp_period`` starts no later than it ends."""
        return self.warp_period[0] <= self.warp_period[1]

    @property
    def use_mask(self):
        """True when ``mask_period`` starts no later than it ends."""
        return self.mask_period[0] <= self.mask_period[1]

    @property
    def use_ada(self):
        """True when ``ada_period`` starts no later than it ends."""
        return self.ada_period[0] <= self.ada_period[1]