File size: 5,684 Bytes
251e479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import os
from typing import Optional, Sequence, Tuple

from src.video_util import get_frame_count


class RerenderConfig:
    """Configuration container for a video re-rendering run.

    Instances start empty and are populated by either
    ``create_from_parameters`` (explicit arguments) or ``create_from_path``
    (a JSON config file). Populating an instance also creates the working
    directories on disk.
    """

    # Optional JSON keys forwarded unchanged to ``create_from_parameters``
    # when present and not null. 'canny_low'/'canny_high' and
    # 'color_preserve' need special handling in ``create_from_path``.
    _OPTIONAL_KEYS = (
        'work_dir', 'key_subdir', 'frame_count', 'interval', 'crop',
        'sd_model', 'a_prompt', 'n_prompt', 'ddim_steps', 'scale',
        'control_type', 'control_strength', 'seed', 'image_resolution',
        'x0_strength', 'style_update_freq', 'cross_period', 'warp_period',
        'mask_period', 'ada_period', 'mask_strength', 'inner_strength',
        'smooth_boundary',
    )

    def __init__(self):
        """Create an empty config; call a ``create_from_*`` method next."""
        ...

    def create_from_parameters(self,
                               input_path: str,
                               output_path: str,
                               prompt: str,
                               work_dir: Optional[str] = None,
                               key_subdir: str = 'keys',
                               frame_count: Optional[int] = None,
                               interval: int = 10,
                               crop: Sequence[int] = (0, 0, 0, 0),
                               sd_model: Optional[str] = None,
                               a_prompt: str = '',
                               n_prompt: str = '',
                               ddim_steps: int = 20,
                               scale: float = 7.5,
                               control_type: str = 'HED',
                               control_strength: float = 1,
                               seed: int = -1,
                               image_resolution: int = 512,
                               x0_strength: float = -1,
                               style_update_freq: int = 10,
                               cross_period: Tuple[float, float] = (0, 1),
                               warp_period: Tuple[float, float] = (0, 0.1),
                               mask_period: Tuple[float, float] = (0.5, 0.8),
                               ada_period: Tuple[float, float] = (1.0, 1.0),
                               mask_strength: float = 0.5,
                               inner_strength: float = 0.9,
                               smooth_boundary: bool = True,
                               color_preserve: bool = True,
                               **kwargs):
        """Populate the config from explicit arguments and create work dirs.

        Args:
            input_path: Path to the input video; must be an existing file.
            output_path: Path of the output file.
            prompt: Text prompt, stored as-is.
            work_dir: Working directory. Defaults to the directory of
                ``output_path``, or the current directory when
                ``output_path`` has no directory component.
            key_subdir: Name of the sub-directory of ``work_dir`` stored
                as ``key_dir``.
            frame_count: Number of frames; when None it is obtained from
                ``get_frame_count(input_path)``.
            **kwargs: Only ``canny_low`` (default 100) and ``canny_high``
                (default 200) are read, and only when
                ``control_type == 'canny'``. Other extra keys are ignored.

        The remaining parameters are stored unchanged as attributes of the
        same name; their interpretation is up to the consumers of this
        config.

        Raises:
            FileNotFoundError: If ``input_path`` is not an existing file.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.prompt = prompt

        if work_dir is None:
            work_dir = os.path.dirname(output_path)
        # dirname() of a bare file name is '', which os.makedirs rejects;
        # use the current directory instead.
        self.work_dir = work_dir or os.curdir
        self.key_dir = os.path.join(self.work_dir, key_subdir)
        self.first_dir = os.path.join(self.work_dir, 'first')

        # Validate the input before creating any directories.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'Cannot find video file {input_path}')
        self.input_dir = os.path.join(self.work_dir, 'video')

        if frame_count is None:
            frame_count = get_frame_count(self.input_path)
        self.frame_count = frame_count
        self.interval = interval
        self.crop = crop
        self.sd_model = sd_model
        self.a_prompt = a_prompt
        self.n_prompt = n_prompt
        self.ddim_steps = ddim_steps
        self.scale = scale
        self.control_type = control_type
        # Canny thresholds only exist for the 'canny' control type.
        if control_type == 'canny':
            self.canny_low = kwargs.get('canny_low', 100)
            self.canny_high = kwargs.get('canny_high', 200)
        else:
            self.canny_low = None
            self.canny_high = None
        self.control_strength = control_strength
        self.seed = seed
        self.image_resolution = image_resolution
        self.x0_strength = x0_strength
        self.style_update_freq = style_update_freq
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.warp_period = warp_period
        self.ada_period = ada_period
        self.mask_strength = mask_strength
        self.inner_strength = inner_strength
        self.smooth_boundary = smooth_boundary
        self.color_preserve = color_preserve

        # Parent first for readability; makedirs would create it anyway.
        for directory in (self.work_dir, self.input_dir, self.key_dir,
                          self.first_dir):
            os.makedirs(directory, exist_ok=True)

    def create_from_path(self, cfg_path: str):
        """Populate the config from a UTF-8 JSON file.

        The file must contain 'input', 'output' and 'prompt'. Every other
        key named after a ``create_from_parameters`` argument is forwarded
        when present and not null. 'canny_low'/'canny_high' are forwarded
        only when 'control_type' is 'canny'. The legacy misspelling
        'color_perserve' is accepted as a fallback for 'color_preserve'.

        Raises:
            KeyError: If 'input', 'output' or 'prompt' is missing.
            json.JSONDecodeError: If the file is not valid JSON.
            FileNotFoundError: If the file or the input video is missing.
        """
        with open(cfg_path, 'r', encoding='utf-8') as fp:
            cfg = json.load(fp)

        kwargs = {
            'input_path': cfg['input'],
            'output_path': cfg['output'],
            'prompt': cfg['prompt'],
        }
        optional_keys = list(self._OPTIONAL_KEYS)
        if cfg.get('control_type') == 'canny':
            optional_keys += ['canny_low', 'canny_high']
        for key in optional_keys:
            value = cfg.get(key)
            if value is not None:
                kwargs[key] = value

        # Older configs (and earlier versions of this loader) used the
        # misspelled 'color_perserve'; the correct spelling wins if both
        # are present. Always forward under the real parameter name.
        for key in ('color_preserve', 'color_perserve'):
            value = cfg.get(key)
            if value is not None:
                kwargs['color_preserve'] = value
                break

        self.create_from_parameters(**kwargs)

    @property
    def use_warp(self):
        """True when ``warp_period`` starts no later than it ends."""
        return self.warp_period[0] <= self.warp_period[1]

    @property
    def use_mask(self):
        """True when ``mask_period`` starts no later than it ends."""
        return self.mask_period[0] <= self.mask_period[1]

    @property
    def use_ada(self):
        """True when ``ada_period`` starts no later than it ends."""
        return self.ada_period[0] <= self.ada_period[1]