# The star import supplies everything used below: the FateZero entry point
# test(), get_time_string(), OmegaConf, and the model classes
# (AutoTokenizer, CLIPTextModel, AutoencoderKL, UNetPseudo3DConditionModel).
from FateZero.test_fatezero import *

import copy
import os

import gradio as gr

class merge_config_then_run():
    def __init__(self) -> None:
        # Load the pretrained Stable Diffusion sub-models once at startup and
        # cache them, so repeated run() calls skip the expensive reload.
        pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None

        cache_ckpt = True
        if cache_ckpt:
            self.tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_path,
                subfolder="tokenizer",
                use_fast=False,
            )

            # Load models and create wrapper for stable diffusion
            self.text_encoder = CLIPTextModel.from_pretrained(
                pretrained_model_path,
                subfolder="text_encoder",
            )

            self.vae = AutoencoderKL.from_pretrained(
                pretrained_model_path,
                subfolder="vae",
            )
            model_config = {
                "lora": 160,
                # temporal_downsample_time: 4
                "SparseCausalAttention_index": ['mid'],
                "least_sc_channel": 640
            }
            # Inflate the 2D Stable Diffusion UNet into FateZero's pseudo-3D
            # UNet so attention can run across video frames.
            self.unet = UNetPseudo3DConditionModel.from_2d_model(
                os.path.join(pretrained_model_path, "unet"), model_config=model_config
            )

    def run(
        self,
        model_id,
        data_path,
        source_prompt,
        target_prompt,
        cross_replace_steps,
        self_replace_steps,
        enhance_words,
        enhance_words_value,
        num_steps,
        guidance_scale,
        user_input_video=None,

        # Temporal and spatial crop of the input video
        start_sample_frame=0,
        n_sample_frame=8,
        stride=1,
        left_crop=0,
        right_crop=0,
        top_crop=0,
        bottom_crop=0,
    ):
        """Merge the UI arguments into the default edit config, then run FateZero."""
        # Start from the shipped low-resource teaser config and override it
        # with the UI arguments.
        default_edit_config = 'FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'
        Omegadict_default_edit_config = OmegaConf.load(default_edit_config)

        dataset_time_string = get_time_string()
        config_now = copy.deepcopy(Omegadict_default_edit_config)
        # Overriding pretrained_model_path with model_id is currently disabled;
        # log the requested id for debugging.
        print(f'requested model_id: {model_id}')
        # config_now['pretrained_model_path'] = model_id
        config_now['train_dataset']['prompt'] = source_prompt
        config_now['train_dataset']['path'] = data_path
        offset_dict = {
            "left": left_crop,
            "right": right_crop,
            "top": top_crop,
            "bottom": bottom_crop,
        }
        ImageSequenceDataset_dict = {
            "start_sample_frame": start_sample_frame,
            "n_sample_frame": n_sample_frame,
            "sampling_rate": stride,
            "offset": offset_dict,
        }
        config_now['train_dataset'].update(ImageSequenceDataset_dict)
        if user_input_video is None and data_path is None:
            raise gr.Error('You need to upload a video or choose a provided video')
        # Gradio may pass the upload as a filepath string or as a tempfile-like
        # object with a .name attribute; handle both.
        if user_input_video is not None:
            if isinstance(user_input_video, str):
                config_now['train_dataset']['path'] = user_input_video
            elif hasattr(user_input_video, 'name') and user_input_video.name is not None:
                config_now['train_dataset']['path'] = user_input_video.name
        config_now['validation_sample_logger_config']['prompts'] = [target_prompt]

        # FateZero prompt-to-prompt (attention control) config
        p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
        p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
        p2p_config_now['self_replace_steps'] = self_replace_steps
        # split() instead of split(" ") so repeated spaces do not yield empty words
        p2p_config_now['eq_params']['words'] = enhance_words.split()
        p2p_config_now['eq_params']['values'] = [enhance_words_value] * len(p2p_config_now['eq_params']['words'])
        config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)


        # DDIM sampling parameters
        config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
        config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
        

        # Mirror the config path under result/ and append a timestamp, e.g.
        # FateZero/config/... -> FateZero/result/..._<timestamp>
        logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + f'_{dataset_time_string}'
        config_now['logdir'] = logdir
        print(f'Saving at {logdir}')
        save_path = test(tokenizer=self.tokenizer,
                         text_encoder=self.text_encoder,
                         vae=self.vae,
                         unet=self.unet,
                         config=default_edit_config, **config_now)
        # test() returns the path of a saved GIF; the corresponding MP4 sits
        # next to it and is what the Gradio video component plays.
        mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
        return mp4_path
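
# A minimal usage sketch (an assumption, not part of the original app): shows
# how run() might be invoked directly. The prompts mirror the jeep_watercolor
# teaser config; the checkpoint id, data path, and numeric values below are
# illustrative guesses.
if __name__ == "__main__":
    pipeline = merge_config_then_run()
    result_mp4 = pipeline.run(
        model_id='CompVis/stable-diffusion-v1-4',   # assumed checkpoint id
        data_path='FateZero/data/teaser_car-turn',  # assumed example clip
        source_prompt='a silver jeep driving down a curvy road in the countryside',
        target_prompt='watercolor painting of a silver jeep driving down a curvy road in the countryside',
        cross_replace_steps=0.8,
        self_replace_steps=0.8,
        enhance_words='watercolor',
        enhance_words_value=10.0,
        num_steps=10,
        guidance_scale=7.5,
    )
    print(f'Edited video written to {result_mp4}')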