Commit 4dff355 by chenyangqi
Parent(s): 8214cae

cache the ckpt; fix bugs when input new video

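In brief: the Stable Diffusion checkpoint components (tokenizer, text encoder, VAE, UNet) are now loaded once and cached in a long-lived object instead of being re-read from disk on every Gradio request, and frame extraction for user-uploaded videos now writes into a fresh `./tmp/fatezero_user_video` directory rather than a path derived from the upload name, which could leave stale frames behind when a new video arrived.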
.gitignore CHANGED
@@ -1 +1,2 @@
-trash/*
+trash/*
+tmp
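The new `tmp` entry keeps the frame cache out of version control: extracted frames for user uploads now land in `./tmp/fatezero_user_video` (see the `dataset.py` hunk below).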
FateZero/test_fatezero.py CHANGED
@@ -48,6 +48,10 @@ def test(
     config: str,
     pretrained_model_path: str,
     train_dataset: Dict,
+    tokenizer = None,
+    text_encoder = None,
+    vae = None,
+    unet = None,
     logdir: str = None,
     validation_sample_logger_config: Optional[Dict] = None,
     test_pipeline_config: Optional[Dict] = None,
@@ -79,26 +83,28 @@ def test(
     set_seed(seed)
 
     # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        pretrained_model_path,
-        subfolder="tokenizer",
-        use_fast=False,
-    )
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_path,
+            subfolder="tokenizer",
+            use_fast=False,
+        )
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        pretrained_model_path,
-        subfolder="text_encoder",
-    )
-
-    vae = AutoencoderKL.from_pretrained(
-        pretrained_model_path,
-        subfolder="vae",
-    )
-
-    unet = UNetPseudo3DConditionModel.from_2d_model(
-        os.path.join(pretrained_model_path, "unet"), model_config=model_config
-    )
+    if text_encoder is None:
+        text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_path,
+            subfolder="text_encoder",
+        )
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(
+            pretrained_model_path,
+            subfolder="vae",
+        )
+    if unet is None:
+        unet = UNetPseudo3DConditionModel.from_2d_model(
+            os.path.join(pretrained_model_path, "unet"), model_config=model_config
+        )
 
     if 'target' not in test_pipeline_config:
         test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline'
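The effect of this hunk: `test()` now accepts pre-loaded components and only falls back to `from_pretrained` when a component is `None`, so a caller holding the models in memory skips the checkpoint load entirely. A minimal, self-contained sketch of that contract (`expensive_load` is a stand-in for the `from_pretrained` calls, not a repo function):

```python
from typing import Optional

def expensive_load(path: str) -> dict:
    # Stand-in for AutoTokenizer/CLIPTextModel/AutoencoderKL.from_pretrained
    print(f"loading checkpoint from {path}")
    return {"path": path}

def test(path: str, tokenizer: Optional[dict] = None) -> dict:
    # Reuse whatever the caller hands in; hit the disk only on None.
    if tokenizer is None:
        tokenizer = expensive_load(path)
    return tokenizer

cached = expensive_load("ckpt")   # one load at startup...
test("ckpt", tokenizer=cached)    # ...later calls skip the load entirely
test("ckpt")                      # passing nothing still works standalone
```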
FateZero/video_diffusion/common/util.py CHANGED
@@ -4,7 +4,7 @@ import copy
 import inspect
 import datetime
 from typing import List, Tuple, Optional, Dict
-
+import torch
 
 def glob_files(
     root_path: str,
@@ -68,6 +68,12 @@ def get_time_string() -> str:
 def get_function_args() -> Dict:
     frame = sys._getframe(1)
     args, _, _, values = inspect.getargvalues(frame)
-    args_dict = copy.deepcopy({arg: values[arg] for arg in args})
+    tmp_dict = {}
+    for arg in args:
+        v = values[arg]
+        if not isinstance(v, torch.nn.Module) and arg != 'tokenizer':
+            tmp_dict[arg] = v
+
+    args_dict = copy.deepcopy(tmp_dict)
 
     return args_dict
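Context for this hunk: `get_function_args()` deep-copies the caller's arguments to snapshot the run configuration, and `test()` can now receive live models among those arguments. The `tokenizer` argument is excluded by name for the same goal of keeping heavyweight, non-config objects out of the snapshot. A small sketch (not from the commit) of why the module filter matters:

```python
import copy

import torch

# Stand-in for a cached model argument reaching get_function_args():
module = torch.nn.Linear(512, 512)

# deepcopy allocates fresh storage for every parameter tensor; for a
# full Stable Diffusion UNet that means duplicating gigabytes per call,
# which the isinstance(v, torch.nn.Module) check now avoids.
dup = copy.deepcopy(module)
print(dup.weight.data_ptr() == module.weight.data_ptr())  # False: separate memory
```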
FateZero/video_diffusion/data/dataset.py CHANGED
@@ -6,6 +6,7 @@ from einops import rearrange
 from pathlib import Path
 import imageio
 import cv2
+import shutil
 
 import torch
 from torch.utils.data import Dataset
@@ -156,7 +157,7 @@ class ImageSequenceDataset(Dataset):
         images = []
         if path[-4:] == '.mp4':
            path = self.mp4_to_png(path)
-        self.path = path
+
 
         for file in sorted(os.listdir(path)):
             if file.endswith(IMAGE_EXTENSION):
@@ -164,14 +165,19 @@ class ImageSequenceDataset(Dataset):
         return images
 
     # @staticmethod
+
     def mp4_to_png(self, video_source=None):
         reader = imageio.get_reader(video_source)
-        os.makedirs(video_source[:-4], exist_ok=True)
-
+        dir_path = './tmp/fatezero_user_video'
+        if os.path.exists(dir_path):
+            shutil.rmtree(dir_path)
+        os.makedirs(dir_path, exist_ok=True)
+
         for i, im in enumerate(reader):
             # use :05d to add zero, no space before the 05d
             # if (i+1)%10 == 0:
-            path = os.path.join(video_source[:-4], f"{i:05d}.png")
+            path = os.path.join(dir_path, f"{i:05d}.png")
             # print(path)
             cv2.imwrite(path, im[:, :, ::-1])
-        return video_source[:-4]
+        self.path = dir_path
+        return self.path
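This is the "input new video" fix from the commit message. Previously, frames were written to `video_source[:-4]`, a directory named after the uploaded file, and never cleaned up, so a re-used path (for example, the same file name uploaded twice) could mix PNGs from different videos in `sorted(os.listdir(path))`. Extraction now always targets `./tmp/fatezero_user_video`, wipes it first, and sets `self.path` in `mp4_to_png` itself instead of in the caller. The clean-slate pattern in isolation (a sketch reusing the path hard-coded by this diff):

```python
import os
import shutil

def fresh_frame_dir(dir_path: str = './tmp/fatezero_user_video') -> str:
    # Remove whatever a previous upload left behind, then recreate empty,
    # so frames of the new video can never interleave with old ones.
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path, exist_ok=True)
    return dir_path

print(fresh_frame_dir())
```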
app_fatezero.py CHANGED
@@ -28,7 +28,7 @@ from inference_fatezero import merge_config_then_run
 # TITLE = '# [FateZero](http://fate-zero-edit.github.io/)'
 HF_TOKEN = os.getenv('HF_TOKEN')
 # pipe = InferencePipeline(HF_TOKEN)
-# pipe = merge_config_then_run
+pipe = merge_config_then_run()
 # app = InferenceUtil(HF_TOKEN)
 
 with gr.Blocks(css='style.css') as demo:
@@ -288,7 +288,7 @@ with gr.Blocks(css='style.css') as demo:
             *ImageSequenceDataset_list
         ],
         outputs=result,
-        fn=merge_config_then_run,
+        fn=pipe.run,
         cache_examples=os.getenv('SYSTEM') == 'spaces')
 
     # model_id.change(fn=app.load_model_info,
@@ -312,8 +312,8 @@ with gr.Blocks(css='style.css') as demo:
         *ImageSequenceDataset_list
     ]
     # prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
-    target_prompt.submit(fn=merge_config_then_run, inputs=inputs, outputs=result)
+    target_prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
     # run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
-    run_button.click(fn=merge_config_then_run, inputs=inputs, outputs=result)
+    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
 
 demo.queue().launch()
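With `pipe = merge_config_then_run()` constructed at import time, the checkpoint load happens once when the Space boots; the examples cache and both event handlers then dispatch to the bound method `pipe.run` rather than to a bare function that would otherwise reload the models on every click.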
inference_fatezero.py CHANGED
@@ -4,8 +4,40 @@ from FateZero.test_fatezero import *
 import copy
 import gradio as gr
 
-
-def merge_config_then_run(
+class merge_config_then_run():
+    def __init__(self) -> None:
+        # Load the tokenizer
+        pretrained_model_path = 'FateZero/ckpt/stable-diffusion-v1-4'
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_path,
+            # 'FateZero/ckpt/stable-diffusion-v1-4',
+            subfolder="tokenizer",
+            use_fast=False,
+        )
+
+        # Load models and create wrapper for stable diffusion
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_path,
+            subfolder="text_encoder",
+        )
+
+        self.vae = AutoencoderKL.from_pretrained(
+            pretrained_model_path,
+            subfolder="vae",
+        )
+        model_config = {
+            "lora": 160,
+            # temporal_downsample_time: 4
+            "SparseCausalAttention_index": ['mid'],
+            "least_sc_channel": 640
+        }
+        self.unet = UNetPseudo3DConditionModel.from_2d_model(
+            os.path.join(pretrained_model_path, "unet"), model_config=model_config
+        )
+
+    def run(
+        self,
+        # def merge_config_then_run(
     model_id,
     data_path,
     source_prompt,
@@ -27,58 +59,59 @@ def merge_config_then_run(
     top_crop=0,
     bottom_crop=0,
 ):
-    # , ] = inputs
-    default_edit_config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'
-    Omegadict_default_edit_config = OmegaConf.load(default_edit_config)
-
-    dataset_time_string = get_time_string()
-    config_now = copy.deepcopy(Omegadict_default_edit_config)
-    print(f"config_now['pretrained_model_path'] = model_id {model_id}")
-    # config_now['pretrained_model_path'] = model_id
-    config_now['train_dataset']['prompt'] = source_prompt
-    config_now['train_dataset']['path'] = data_path
-    # ImageSequenceDataset_dict = { }
-    offset_dict = {
-        "left": left_crop,
-        "right": right_crop,
-        "top": top_crop,
-        "bottom": bottom_crop,
-    }
-    ImageSequenceDataset_dict = {
-        "start_sample_frame" : start_sample_frame,
-        "n_sample_frame" : n_sample_frame,
-        "stride" : stride,
-        "offset": offset_dict,
-    }
-    config_now['train_dataset'].update(ImageSequenceDataset_dict)
-    if user_input_video and data_path is None:
-        raise gr.Error('You need to upload a video or choose a provided video')
-    if user_input_video is not None and user_input_video.name is not None:
-        config_now['train_dataset']['path'] = user_input_video.name
-    config_now['validation_sample_logger_config']['prompts'] = [target_prompt]
-
-
-    # fatezero config
-    p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
-    p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
-    p2p_config_now['self_replace_steps'] = self_replace_steps
-    p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
-    p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
-    config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
-
-    # ddim config
-    config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
-    config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
-
-    logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_{dataset_time_string}'
-    config_now['logdir'] = logdir
-    print(f'Saving at {logdir}')
-    save_path = test(config=default_edit_config, **config_now)
-    mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
-    return mp4_path
-
-
-if __name__ == "__main__":
-    run()
+        # , ] = inputs
+        default_edit_config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'
+        Omegadict_default_edit_config = OmegaConf.load(default_edit_config)
+
+        dataset_time_string = get_time_string()
+        config_now = copy.deepcopy(Omegadict_default_edit_config)
+        print(f"config_now['pretrained_model_path'] = model_id {model_id}")
+        # config_now['pretrained_model_path'] = model_id
+        config_now['train_dataset']['prompt'] = source_prompt
+        config_now['train_dataset']['path'] = data_path
+        # ImageSequenceDataset_dict = { }
+        offset_dict = {
+            "left": left_crop,
+            "right": right_crop,
+            "top": top_crop,
+            "bottom": bottom_crop,
+        }
+        ImageSequenceDataset_dict = {
+            "start_sample_frame" : start_sample_frame,
+            "n_sample_frame" : n_sample_frame,
+            "stride" : stride,
+            "offset": offset_dict,
+        }
+        config_now['train_dataset'].update(ImageSequenceDataset_dict)
+        if user_input_video and data_path is None:
+            raise gr.Error('You need to upload a video or choose a provided video')
+        if user_input_video is not None and user_input_video.name is not None:
+            config_now['train_dataset']['path'] = user_input_video.name
+        config_now['validation_sample_logger_config']['prompts'] = [target_prompt]
+
+        # fatezero config
+        p2p_config_now = copy.deepcopy(config_now['validation_sample_logger_config']['p2p_config'][0])
+        p2p_config_now['cross_replace_steps']['default_'] = cross_replace_steps
+        p2p_config_now['self_replace_steps'] = self_replace_steps
+        p2p_config_now['eq_params']['words'] = enhance_words.split(" ")
+        p2p_config_now['eq_params']['values'] = [enhance_words_value,]*len(p2p_config_now['eq_params']['words'])
+        config_now['validation_sample_logger_config']['p2p_config'][0] = copy.deepcopy(p2p_config_now)
+
+        # ddim config
+        config_now['validation_sample_logger_config']['guidance_scale'] = guidance_scale
+        config_now['validation_sample_logger_config']['num_inference_steps'] = num_steps
+
+        logdir = default_edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_{dataset_time_string}'
+        config_now['logdir'] = logdir
+        print(f'Saving at {logdir}')
+        save_path = test(tokenizer = self.tokenizer,
+                         text_encoder = self.text_encoder,
+                         vae = self.vae,
+                         unet = self.unet,
+                         config=default_edit_config, **config_now)
+        mp4_path = save_path.replace('_0.gif', '_0_0_0.mp4')
+        return mp4_path
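This is the core of the caching change: `merge_config_then_run` becomes a class whose `__init__` loads the tokenizer, text encoder, VAE, and pseudo-3D UNet from `FateZero/ckpt/stable-diffusion-v1-4` exactly once, and whose `run` method hands the cached components to `test()` through the new keyword arguments, so only the lightweight config merging happens per request. Dropping the `if __name__ == "__main__": run()` block also removes dead code: `run` was never defined at module scope, so invoking the file directly would have raised a `NameError`.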