Doubiiu commited on
Commit
66033d4
1 Parent(s): 24fdf34

Upload 9 files

Browse files
scripts/evaluation/__pycache__/funcs.cpython-39.pyc CHANGED
Binary files a/scripts/evaluation/__pycache__/funcs.cpython-39.pyc and b/scripts/evaluation/__pycache__/funcs.cpython-39.pyc differ
 
scripts/evaluation/__pycache__/inference.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
scripts/evaluation/ddp_wrapper.py CHANGED
@@ -1,47 +1,47 @@
1
- import datetime
2
- import argparse, importlib
3
- from pytorch_lightning import seed_everything
4
-
5
- import torch
6
- import torch.distributed as dist
7
-
8
- def setup_dist(local_rank):
9
- if dist.is_initialized():
10
- return
11
- torch.cuda.set_device(local_rank)
12
- torch.distributed.init_process_group('nccl', init_method='env://')
13
-
14
-
15
- def get_dist_info():
16
- if dist.is_available():
17
- initialized = dist.is_initialized()
18
- else:
19
- initialized = False
20
- if initialized:
21
- rank = dist.get_rank()
22
- world_size = dist.get_world_size()
23
- else:
24
- rank = 0
25
- world_size = 1
26
- return rank, world_size
27
-
28
-
29
- if __name__ == '__main__':
30
- now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
31
- parser = argparse.ArgumentParser()
32
- parser.add_argument("--module", type=str, help="module name", default="inference")
33
- parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
34
- args, unknown = parser.parse_known_args()
35
- inference_api = importlib.import_module(args.module, package=None)
36
-
37
- inference_parser = inference_api.get_parser()
38
- inference_args, unknown = inference_parser.parse_known_args()
39
-
40
- seed_everything(inference_args.seed)
41
- setup_dist(args.local_rank)
42
- torch.backends.cudnn.benchmark = True
43
- rank, gpu_num = get_dist_info()
44
-
45
- inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed)
46
- print("@CoLVDM Inference [rank%d]: %s"%(rank, now))
47
  inference_api.run_inference(inference_args, gpu_num, rank)
 
1
+ import datetime
2
+ import argparse, importlib
3
+ from pytorch_lightning import seed_everything
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ def setup_dist(local_rank):
9
+ if dist.is_initialized():
10
+ return
11
+ torch.cuda.set_device(local_rank)
12
+ torch.distributed.init_process_group('nccl', init_method='env://')
13
+
14
+
15
+ def get_dist_info():
16
+ if dist.is_available():
17
+ initialized = dist.is_initialized()
18
+ else:
19
+ initialized = False
20
+ if initialized:
21
+ rank = dist.get_rank()
22
+ world_size = dist.get_world_size()
23
+ else:
24
+ rank = 0
25
+ world_size = 1
26
+ return rank, world_size
27
+
28
+
29
+ if __name__ == '__main__':
30
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--module", type=str, help="module name", default="inference")
33
+ parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0)
34
+ args, unknown = parser.parse_known_args()
35
+ inference_api = importlib.import_module(args.module, package=None)
36
+
37
+ inference_parser = inference_api.get_parser()
38
+ inference_args, unknown = inference_parser.parse_known_args()
39
+
40
+ seed_everything(inference_args.seed)
41
+ setup_dist(args.local_rank)
42
+ torch.backends.cudnn.benchmark = True
43
+ rank, gpu_num = get_dist_info()
44
+
45
+ # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed)
46
+ print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now))
47
  inference_api.run_inference(inference_args, gpu_num, rank)
scripts/evaluation/funcs.py CHANGED
@@ -1,205 +1,226 @@
1
- import os, sys, glob
2
- import numpy as np
3
- from collections import OrderedDict
4
- from decord import VideoReader, cpu
5
- import cv2
6
-
7
- import torch
8
- import torchvision
9
- sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
10
- from lvdm.models.samplers.ddim import DDIMSampler
11
- from einops import rearrange
12
-
13
-
14
- def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
15
- cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
16
- ddim_sampler = DDIMSampler(model)
17
- uncond_type = model.uncond_type
18
- batch_size = noise_shape[0]
19
- fs = cond["fs"]
20
- del cond["fs"]
21
- ## construct unconditional guidance
22
- if cfg_scale != 1.0:
23
- if uncond_type == "empty_seq":
24
- prompts = batch_size * [""]
25
- #prompts = N * T * [""] ## if is_imgbatch=True
26
- uc_emb = model.get_learned_conditioning(prompts)
27
- elif uncond_type == "zero_embed":
28
- c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
29
- uc_emb = torch.zeros_like(c_emb)
30
-
31
- ## process image embedding token
32
- if hasattr(model, 'embedder'):
33
- uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
34
- ## img: b c h w >> b l c
35
- uc_img = model.embedder(uc_img)
36
- uc_img = model.image_proj_model(uc_img)
37
- uc_emb = torch.cat([uc_emb, uc_img], dim=1)
38
-
39
- if isinstance(cond, dict):
40
- uc = {key:cond[key] for key in cond.keys()}
41
- uc.update({'c_crossattn': [uc_emb]})
42
- else:
43
- uc = uc_emb
44
- else:
45
- uc = None
46
-
47
- x_T = None
48
- batch_variants = []
49
-
50
- for _ in range(n_samples):
51
- if ddim_sampler is not None:
52
- kwargs.update({"clean_cond": True})
53
- samples, _ = ddim_sampler.sample(S=ddim_steps,
54
- conditioning=cond,
55
- batch_size=noise_shape[0],
56
- shape=noise_shape[1:],
57
- verbose=False,
58
- unconditional_guidance_scale=cfg_scale,
59
- unconditional_conditioning=uc,
60
- eta=ddim_eta,
61
- temporal_length=noise_shape[2],
62
- conditional_guidance_scale_temporal=temporal_cfg_scale,
63
- x_T=x_T,
64
- fs=fs,
65
- **kwargs
66
- )
67
- ## reconstruct from latent to pixel space
68
- batch_images = model.decode_first_stage(samples)
69
- batch_variants.append(batch_images)
70
- ## batch, <samples>, c, t, h, w
71
- batch_variants = torch.stack(batch_variants, dim=1)
72
- return batch_variants
73
-
74
-
75
- def get_filelist(data_dir, ext='*'):
76
- file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
77
- file_list.sort()
78
- return file_list
79
-
80
- def get_dirlist(path):
81
- list = []
82
- if (os.path.exists(path)):
83
- files = os.listdir(path)
84
- for file in files:
85
- m = os.path.join(path,file)
86
- if (os.path.isdir(m)):
87
- list.append(m)
88
- list.sort()
89
- return list
90
-
91
-
92
- def load_model_checkpoint(model, ckpt):
93
- def load_checkpoint(model, ckpt, full_strict):
94
- state_dict = torch.load(ckpt, map_location="cpu")
95
- try:
96
- ## deepspeed
97
- new_pl_sd = OrderedDict()
98
- for key in state_dict['module'].keys():
99
- new_pl_sd[key[16:]]=state_dict['module'][key]
100
- model.load_state_dict(new_pl_sd, strict=full_strict)
101
- except:
102
- if "state_dict" in list(state_dict.keys()):
103
- state_dict = state_dict["state_dict"]
104
- model.load_state_dict(state_dict, strict=full_strict)
105
- return model
106
- load_checkpoint(model, ckpt, full_strict=True)
107
- print('>>> model checkpoint loaded.')
108
- return model
109
-
110
-
111
- def load_prompts(prompt_file):
112
- f = open(prompt_file, 'r')
113
- prompt_list = []
114
- for idx, line in enumerate(f.readlines()):
115
- l = line.strip()
116
- if len(l) != 0:
117
- prompt_list.append(l)
118
- f.close()
119
- return prompt_list
120
-
121
-
122
- def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
123
- '''
124
- Notice about some special cases:
125
- 1. video_frames=-1 means to take all the frames (with fs=1)
126
- 2. when the total video frames is less than required, padding strategy will be used (repreated last frame)
127
- '''
128
- fps_list = []
129
- batch_tensor = []
130
- assert frame_stride > 0, "valid frame stride should be a positive interge!"
131
- for filepath in filepath_list:
132
- padding_num = 0
133
- vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
134
- fps = vidreader.get_avg_fps()
135
- total_frames = len(vidreader)
136
- max_valid_frames = (total_frames-1) // frame_stride + 1
137
- if video_frames < 0:
138
- ## all frames are collected: fs=1 is a must
139
- required_frames = total_frames
140
- frame_stride = 1
141
- else:
142
- required_frames = video_frames
143
- query_frames = min(required_frames, max_valid_frames)
144
- frame_indices = [frame_stride*i for i in range(query_frames)]
145
-
146
- ## [t,h,w,c] -> [c,t,h,w]
147
- frames = vidreader.get_batch(frame_indices)
148
- frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
149
- frame_tensor = (frame_tensor / 255. - 0.5) * 2
150
- if max_valid_frames < required_frames:
151
- padding_num = required_frames - max_valid_frames
152
- frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
153
- print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
154
- batch_tensor.append(frame_tensor)
155
- sample_fps = int(fps/frame_stride)
156
- fps_list.append(sample_fps)
157
-
158
- return torch.stack(batch_tensor, dim=0)
159
-
160
- from PIL import Image
161
- def load_image_batch(filepath_list, image_size=(256,256)):
162
- batch_tensor = []
163
- for filepath in filepath_list:
164
- _, filename = os.path.split(filepath)
165
- _, ext = os.path.splitext(filename)
166
- if ext == '.mp4':
167
- vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
168
- frame = vidreader.get_batch([0])
169
- img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
170
- elif ext == '.png' or ext == '.jpg':
171
- img = Image.open(filepath).convert("RGB")
172
- rgb_img = np.array(img, np.float32)
173
- #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
174
- #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
175
- rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
176
- img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
177
- else:
178
- print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]')
179
- raise NotImplementedError
180
- img_tensor = (img_tensor / 255. - 0.5) * 2
181
- batch_tensor.append(img_tensor)
182
- return torch.stack(batch_tensor, dim=0)
183
-
184
-
185
- def save_videos(batch_tensors, savedir, filenames, fps=10):
186
- # b,samples,c,t,h,w
187
- n_samples = batch_tensors.shape[1]
188
- for idx, vid_tensor in enumerate(batch_tensors):
189
- video = vid_tensor.detach().cpu()
190
- video = torch.clamp(video.float(), -1., 1.)
191
- video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
192
- frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
193
- grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
194
- grid = (grid + 1.0) / 2.0
195
- grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
196
- savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
197
- torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
198
-
199
-
200
- def get_latent_z(model, videos):
201
- b, c, t, h, w = videos.shape
202
- x = rearrange(videos, 'b c t h w -> (b t) c h w')
203
- z = model.encode_first_stage(x)
204
- z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  return z
 
1
+ import os, sys, glob
2
+ import numpy as np
3
+ from collections import OrderedDict
4
+ from decord import VideoReader, cpu
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
10
+ from lvdm.models.samplers.ddim import DDIMSampler
11
+ from einops import rearrange
12
+
13
+
14
+ def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\
15
+ cfg_scale=1.0, temporal_cfg_scale=None, **kwargs):
16
+ ddim_sampler = DDIMSampler(model)
17
+ uncond_type = model.uncond_type
18
+ batch_size = noise_shape[0]
19
+ fs = cond["fs"]
20
+ del cond["fs"]
21
+ if noise_shape[-1] == 32:
22
+ timestep_spacing = "uniform"
23
+ guidance_rescale = 0.0
24
+ else:
25
+ timestep_spacing = "uniform_trailing"
26
+ guidance_rescale = 0.7
27
+ ## construct unconditional guidance
28
+ if cfg_scale != 1.0:
29
+ if uncond_type == "empty_seq":
30
+ prompts = batch_size * [""]
31
+ #prompts = N * T * [""] ## if is_imgbatch=True
32
+ uc_emb = model.get_learned_conditioning(prompts)
33
+ elif uncond_type == "zero_embed":
34
+ c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond
35
+ uc_emb = torch.zeros_like(c_emb)
36
+
37
+ ## process image embedding token
38
+ if hasattr(model, 'embedder'):
39
+ uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device)
40
+ ## img: b c h w >> b l c
41
+ uc_img = model.embedder(uc_img)
42
+ uc_img = model.image_proj_model(uc_img)
43
+ uc_emb = torch.cat([uc_emb, uc_img], dim=1)
44
+
45
+ if isinstance(cond, dict):
46
+ uc = {key:cond[key] for key in cond.keys()}
47
+ uc.update({'c_crossattn': [uc_emb]})
48
+ else:
49
+ uc = uc_emb
50
+ else:
51
+ uc = None
52
+
53
+ x_T = None
54
+ batch_variants = []
55
+
56
+ for _ in range(n_samples):
57
+ if ddim_sampler is not None:
58
+ kwargs.update({"clean_cond": True})
59
+ samples, _ = ddim_sampler.sample(S=ddim_steps,
60
+ conditioning=cond,
61
+ batch_size=noise_shape[0],
62
+ shape=noise_shape[1:],
63
+ verbose=False,
64
+ unconditional_guidance_scale=cfg_scale,
65
+ unconditional_conditioning=uc,
66
+ eta=ddim_eta,
67
+ temporal_length=noise_shape[2],
68
+ conditional_guidance_scale_temporal=temporal_cfg_scale,
69
+ x_T=x_T,
70
+ fs=fs,
71
+ timestep_spacing=timestep_spacing,
72
+ guidance_rescale=guidance_rescale,
73
+ **kwargs
74
+ )
75
+ ## reconstruct from latent to pixel space
76
+ batch_images = model.decode_first_stage(samples)
77
+ batch_variants.append(batch_images)
78
+ ## batch, <samples>, c, t, h, w
79
+ batch_variants = torch.stack(batch_variants, dim=1)
80
+ return batch_variants
81
+
82
+
83
+ def get_filelist(data_dir, ext='*'):
84
+ file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext))
85
+ file_list.sort()
86
+ return file_list
87
+
88
+ def get_dirlist(path):
89
+ list = []
90
+ if (os.path.exists(path)):
91
+ files = os.listdir(path)
92
+ for file in files:
93
+ m = os.path.join(path,file)
94
+ if (os.path.isdir(m)):
95
+ list.append(m)
96
+ list.sort()
97
+ return list
98
+
99
+
100
+ def load_model_checkpoint(model, ckpt):
101
+ def load_checkpoint(model, ckpt, full_strict):
102
+ state_dict = torch.load(ckpt, map_location="cpu")
103
+ if "state_dict" in list(state_dict.keys()):
104
+ state_dict = state_dict["state_dict"]
105
+ try:
106
+ model.load_state_dict(state_dict, strict=full_strict)
107
+ except:
108
+ ## rename the keys for 256x256 model
109
+ new_pl_sd = OrderedDict()
110
+ for k,v in state_dict.items():
111
+ new_pl_sd[k] = v
112
+
113
+ for k in list(new_pl_sd.keys()):
114
+ if "framestride_embed" in k:
115
+ new_key = k.replace("framestride_embed", "fps_embedding")
116
+ new_pl_sd[new_key] = new_pl_sd[k]
117
+ del new_pl_sd[k]
118
+ model.load_state_dict(new_pl_sd, strict=full_strict)
119
+ else:
120
+ ## deepspeed
121
+ new_pl_sd = OrderedDict()
122
+ for key in state_dict['module'].keys():
123
+ new_pl_sd[key[16:]]=state_dict['module'][key]
124
+ model.load_state_dict(new_pl_sd, strict=full_strict)
125
+
126
+ return model
127
+ load_checkpoint(model, ckpt, full_strict=True)
128
+ print('>>> model checkpoint loaded.')
129
+ return model
130
+
131
+
132
+ def load_prompts(prompt_file):
133
+ f = open(prompt_file, 'r')
134
+ prompt_list = []
135
+ for idx, line in enumerate(f.readlines()):
136
+ l = line.strip()
137
+ if len(l) != 0:
138
+ prompt_list.append(l)
139
+ f.close()
140
+ return prompt_list
141
+
142
+
143
+ def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16):
144
+ '''
145
+ Notice about some special cases:
146
+ 1. video_frames=-1 means to take all the frames (with fs=1)
147
+ 2. when the total video frames is less than required, padding strategy will be used (repreated last frame)
148
+ '''
149
+ fps_list = []
150
+ batch_tensor = []
151
+ assert frame_stride > 0, "valid frame stride should be a positive interge!"
152
+ for filepath in filepath_list:
153
+ padding_num = 0
154
+ vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0])
155
+ fps = vidreader.get_avg_fps()
156
+ total_frames = len(vidreader)
157
+ max_valid_frames = (total_frames-1) // frame_stride + 1
158
+ if video_frames < 0:
159
+ ## all frames are collected: fs=1 is a must
160
+ required_frames = total_frames
161
+ frame_stride = 1
162
+ else:
163
+ required_frames = video_frames
164
+ query_frames = min(required_frames, max_valid_frames)
165
+ frame_indices = [frame_stride*i for i in range(query_frames)]
166
+
167
+ ## [t,h,w,c] -> [c,t,h,w]
168
+ frames = vidreader.get_batch(frame_indices)
169
+ frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
170
+ frame_tensor = (frame_tensor / 255. - 0.5) * 2
171
+ if max_valid_frames < required_frames:
172
+ padding_num = required_frames - max_valid_frames
173
+ frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1)
174
+ print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.')
175
+ batch_tensor.append(frame_tensor)
176
+ sample_fps = int(fps/frame_stride)
177
+ fps_list.append(sample_fps)
178
+
179
+ return torch.stack(batch_tensor, dim=0)
180
+
181
+ from PIL import Image
182
+ def load_image_batch(filepath_list, image_size=(256,256)):
183
+ batch_tensor = []
184
+ for filepath in filepath_list:
185
+ _, filename = os.path.split(filepath)
186
+ _, ext = os.path.splitext(filename)
187
+ if ext == '.mp4':
188
+ vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0])
189
+ frame = vidreader.get_batch([0])
190
+ img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float()
191
+ elif ext == '.png' or ext == '.jpg':
192
+ img = Image.open(filepath).convert("RGB")
193
+ rgb_img = np.array(img, np.float32)
194
+ #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR)
195
+ #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
196
+ rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR)
197
+ img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float()
198
+ else:
199
+ print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]')
200
+ raise NotImplementedError
201
+ img_tensor = (img_tensor / 255. - 0.5) * 2
202
+ batch_tensor.append(img_tensor)
203
+ return torch.stack(batch_tensor, dim=0)
204
+
205
+
206
+ def save_videos(batch_tensors, savedir, filenames, fps=10):
207
+ # b,samples,c,t,h,w
208
+ n_samples = batch_tensors.shape[1]
209
+ for idx, vid_tensor in enumerate(batch_tensors):
210
+ video = vid_tensor.detach().cpu()
211
+ video = torch.clamp(video.float(), -1., 1.)
212
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
213
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w]
214
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w]
215
+ grid = (grid + 1.0) / 2.0
216
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
217
+ savepath = os.path.join(savedir, f"{filenames[idx]}.mp4")
218
+ torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
219
+
220
+
221
+ def get_latent_z(model, videos):
222
+ b, c, t, h, w = videos.shape
223
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
224
+ z = model.encode_first_stage(x)
225
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
226
  return z
scripts/evaluation/inference.py CHANGED
@@ -1,329 +1,347 @@
1
- import argparse, os, sys, glob
2
- import datetime, time
3
- from omegaconf import OmegaConf
4
- from tqdm import tqdm
5
- from einops import rearrange, repeat
6
- from collections import OrderedDict
7
-
8
- import torch
9
- import torchvision
10
- import torchvision.transforms as transforms
11
- from pytorch_lightning import seed_everything
12
- from PIL import Image
13
- sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
14
- from lvdm.models.samplers.ddim import DDIMSampler
15
- from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
16
- from utils.utils import instantiate_from_config
17
-
18
-
19
- def get_filelist(data_dir, postfixes):
20
- patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
21
- file_list = []
22
- for pattern in patterns:
23
- file_list.extend(glob.glob(pattern))
24
- file_list.sort()
25
- return file_list
26
-
27
- def load_model_checkpoint(model, ckpt):
28
- state_dict = torch.load(ckpt, map_location="cpu")
29
- if "state_dict" in list(state_dict.keys()):
30
- state_dict = state_dict["state_dict"]
31
- model.load_state_dict(state_dict, strict=True)
32
- else:
33
- # deepspeed
34
- new_pl_sd = OrderedDict()
35
- for key in state_dict['module'].keys():
36
- new_pl_sd[key[16:]]=state_dict['module'][key]
37
- model.load_state_dict(new_pl_sd)
38
- print('>>> model checkpoint loaded.')
39
- return model
40
-
41
- def load_prompts(prompt_file):
42
- f = open(prompt_file, 'r')
43
- prompt_list = []
44
- for idx, line in enumerate(f.readlines()):
45
- l = line.strip()
46
- if len(l) != 0:
47
- prompt_list.append(l)
48
- f.close()
49
- return prompt_list
50
-
51
- def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, gfi=False):
52
- transform = transforms.Compose([
53
- transforms.Resize(min(video_size)),
54
- transforms.CenterCrop(video_size),
55
- transforms.ToTensor(),
56
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
57
- ## load prompts
58
- prompt_file = get_filelist(data_dir, ['txt'])
59
- assert len(prompt_file) > 0, "Error: found NO prompt file!"
60
- ###### default prompt
61
- default_idx = 0
62
- default_idx = min(default_idx, len(prompt_file)-1)
63
- if len(prompt_file) > 1:
64
- print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
65
- ## only use the first one (sorted by name) if multiple exist
66
-
67
- ## load video
68
- file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
69
- # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
70
- data_list = []
71
- filename_list = []
72
- prompt_list = load_prompts(prompt_file[default_idx])
73
- n_samples = len(prompt_list)
74
- for idx in range(n_samples):
75
- image = Image.open(file_list[idx]).convert('RGB')
76
- image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
77
- frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
78
-
79
- data_list.append(frame_tensor)
80
- _, filename = os.path.split(file_list[idx])
81
- filename_list.append(filename)
82
-
83
- return filename_list, data_list, prompt_list
84
-
85
-
86
- def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
87
- filename = filename.split('.')[0]+'.mp4'
88
- prompt = prompt[0] if isinstance(prompt, list) else prompt
89
-
90
- ## save video
91
- videos = [samples]
92
- savedirs = [fakedir]
93
- for idx, video in enumerate(videos):
94
- if video is None:
95
- continue
96
- # b,c,t,h,w
97
- video = video.detach().cpu()
98
- video = torch.clamp(video.float(), -1., 1.)
99
- n = video.shape[0]
100
- video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
101
- if loop:
102
- video = video[:-1,...]
103
-
104
- frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
105
- grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
106
- grid = (grid + 1.0) / 2.0
107
- grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
108
- path = os.path.join(savedirs[idx], filename)
109
- torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality
110
-
111
-
112
- def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
113
- prompt = prompt[0] if isinstance(prompt, list) else prompt
114
-
115
- ## save video
116
- videos = [samples]
117
- savedirs = [fakedir]
118
- for idx, video in enumerate(videos):
119
- if video is None:
120
- continue
121
- # b,c,t,h,w
122
- video = video.detach().cpu()
123
- if loop: # remove the last frame
124
- video = video[:,:,:-1,...]
125
- video = torch.clamp(video.float(), -1., 1.)
126
- n = video.shape[0]
127
- for i in range(n):
128
- grid = video[i,...]
129
- grid = (grid + 1.0) / 2.0
130
- grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
131
- path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
132
- torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})
133
-
134
- def get_latent_z(model, videos):
135
- b, c, t, h, w = videos.shape
136
- x = rearrange(videos, 'b c t h w -> (b t) c h w')
137
- z = model.encode_first_stage(x)
138
- z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
139
- return z
140
-
141
-
142
- def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
143
- unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, gfi=False, **kwargs):
144
- ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
145
- batch_size = noise_shape[0]
146
- fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
147
-
148
- if not text_input:
149
- prompts = [""]*batch_size
150
-
151
- img = videos[:,:,0] #bchw
152
- img_emb = model.embedder(img) ## blc
153
- img_emb = model.image_proj_model(img_emb)
154
-
155
- cond_emb = model.get_learned_conditioning(prompts)
156
- cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
157
- if model.model.conditioning_key == 'hybrid':
158
- z = get_latent_z(model, videos) # b c t h w
159
- if loop or gfi:
160
- img_cat_cond = torch.zeros_like(z)
161
- img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
162
- img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
163
- else:
164
- img_cat_cond = z[:,:,:1,:,:]
165
- img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
166
- cond["c_concat"] = [img_cat_cond] # b c 1 h w
167
-
168
- if unconditional_guidance_scale != 1.0:
169
- if model.uncond_type == "empty_seq":
170
- prompts = batch_size * [""]
171
- uc_emb = model.get_learned_conditioning(prompts)
172
- elif model.uncond_type == "zero_embed":
173
- uc_emb = torch.zeros_like(cond_emb)
174
- uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
175
- uc_img_emb = model.image_proj_model(uc_img_emb)
176
- uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
177
- if model.model.conditioning_key == 'hybrid':
178
- uc["c_concat"] = [img_cat_cond]
179
- else:
180
- uc = None
181
-
182
- ## we need one more unconditioning image=yes, text=""
183
- if multiple_cond_cfg and cfg_img != 1.0:
184
- uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
185
- if model.model.conditioning_key == 'hybrid':
186
- uc_2["c_concat"] = [img_cat_cond]
187
- kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
188
- else:
189
- kwargs.update({"unconditional_conditioning_img_nonetext": None})
190
-
191
- z0 = None
192
- cond_mask = None
193
-
194
- batch_variants = []
195
- for _ in range(n_samples):
196
-
197
- if z0 is not None:
198
- cond_z0 = z0.clone()
199
- kwargs.update({"clean_cond": True})
200
- else:
201
- cond_z0 = None
202
- if ddim_sampler is not None:
203
-
204
- samples, _ = ddim_sampler.sample(S=ddim_steps,
205
- conditioning=cond,
206
- batch_size=batch_size,
207
- shape=noise_shape[1:],
208
- verbose=False,
209
- unconditional_guidance_scale=unconditional_guidance_scale,
210
- unconditional_conditioning=uc,
211
- eta=ddim_eta,
212
- cfg_img=cfg_img,
213
- mask=cond_mask,
214
- x0=cond_z0,
215
- fs=fs,
216
- **kwargs
217
- )
218
-
219
- ## reconstruct from latent to pixel space
220
- batch_images = model.decode_first_stage(samples)
221
- batch_variants.append(batch_images)
222
- ## variants, batch, c, t, h, w
223
- batch_variants = torch.stack(batch_variants)
224
- return batch_variants.permute(1, 0, 2, 3, 4, 5)
225
-
226
-
227
- def run_inference(args, gpu_num, gpu_no):
228
- ## model config
229
- config = OmegaConf.load(args.config)
230
- model_config = config.pop("model", OmegaConf.create())
231
-
232
- ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
233
- model_config['params']['unet_config']['params']['use_checkpoint'] = False
234
- model = instantiate_from_config(model_config)
235
- model = model.cuda(gpu_no)
236
-
237
- assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
238
- model = load_model_checkpoint(model, args.ckpt_path)
239
- model.eval()
240
-
241
- ## run over data
242
- assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
243
- assert args.bs == 1, "Current implementation only support [batch size = 1]!"
244
- ## latent noise shape
245
- h, w = args.height // 8, args.width // 8
246
- channels = model.model.diffusion_model.out_channels
247
- n_frames = args.video_length
248
- print(f'Inference with {n_frames} frames')
249
- noise_shape = [args.bs, channels, n_frames, h, w]
250
-
251
- fakedir = os.path.join(args.savedir, "samples")
252
- fakedir_separate = os.path.join(args.savedir, "samples_separate")
253
-
254
- # os.makedirs(fakedir, exist_ok=True)
255
- os.makedirs(fakedir_separate, exist_ok=True)
256
-
257
- ## prompt file setting
258
- assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
259
- filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, gfi=args.gfi)
260
- num_samples = len(prompt_list)
261
- samples_split = num_samples // gpu_num
262
- print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
263
- #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
264
- indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
265
- prompt_list_rank = [prompt_list[i] for i in indices]
266
- data_list_rank = [data_list[i] for i in indices]
267
- filename_list_rank = [filename_list[i] for i in indices]
268
-
269
- start = time.time()
270
- with torch.no_grad(), torch.cuda.amp.autocast():
271
- for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
272
- prompts = prompt_list_rank[indice:indice+args.bs]
273
- videos = data_list_rank[indice:indice+args.bs]
274
- filenames = filename_list_rank[indice:indice+args.bs]
275
- if isinstance(videos, list):
276
- videos = torch.stack(videos, dim=0).to("cuda")
277
- else:
278
- videos = videos.unsqueeze(0).to("cuda")
279
-
280
- batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
281
- args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.gfi)
282
-
283
- ## save each example individually
284
- for nn, samples in enumerate(batch_samples):
285
- ## samples : [n_samples,c,t,h,w]
286
- prompt = prompts[nn]
287
- filename = filenames[nn]
288
- # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
289
- save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
290
-
291
- print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
292
-
293
-
294
- def get_parser():
295
- parser = argparse.ArgumentParser()
296
- parser.add_argument("--savedir", type=str, default=None, help="results saving path")
297
- parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
298
- parser.add_argument("--config", type=str, help="config (yaml) path")
299
- parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
300
- parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
301
- parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
302
- parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
303
- parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
304
- parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
305
- parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
306
- parser.add_argument("--frame_stride", type=int, default=3, choices=[1, 2, 3, 4, 5, 6], help="frame stride control for results, smaller value->smaller motion magnitude and more stable, and vice versa")
307
- parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
308
- parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
309
- parser.add_argument("--video_length", type=int, default=16, help="inference video length")
310
- parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
311
- parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
312
- parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
313
- parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
314
-
315
- ## currently not support looping video and generative frame interpolation
316
- parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
317
- parser.add_argument("--gfi", action='store_true', default=False, help="generate generative frame interpolation (gfi) or not")
318
- return parser
319
-
320
-
321
- if __name__ == '__main__':
322
- now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
323
- print("@CoVideoGen cond-Inference: %s"%now)
324
- parser = get_parser()
325
- args = parser.parse_args()
326
-
327
- seed_everything(args.seed)
328
- rank, gpu_num = 0, 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  run_inference(args, gpu_num, rank)
 
1
+ import argparse, os, sys, glob
2
+ import datetime, time
3
+ from omegaconf import OmegaConf
4
+ from tqdm import tqdm
5
+ from einops import rearrange, repeat
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torchvision
10
+ import torchvision.transforms as transforms
11
+ from pytorch_lightning import seed_everything
12
+ from PIL import Image
13
+ sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
14
+ from lvdm.models.samplers.ddim import DDIMSampler
15
+ from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
16
+ from utils.utils import instantiate_from_config
17
+
18
+
19
+ def get_filelist(data_dir, postfixes):
20
+ patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
21
+ file_list = []
22
+ for pattern in patterns:
23
+ file_list.extend(glob.glob(pattern))
24
+ file_list.sort()
25
+ return file_list
26
+
27
+ def load_model_checkpoint(model, ckpt):
28
+ state_dict = torch.load(ckpt, map_location="cpu")
29
+ if "state_dict" in list(state_dict.keys()):
30
+ state_dict = state_dict["state_dict"]
31
+ try:
32
+ model.load_state_dict(state_dict, strict=True)
33
+ except:
34
+ ## rename the keys for 256x256 model
35
+ new_pl_sd = OrderedDict()
36
+ for k,v in state_dict.items():
37
+ new_pl_sd[k] = v
38
+
39
+ for k in list(new_pl_sd.keys()):
40
+ if "framestride_embed" in k:
41
+ new_key = k.replace("framestride_embed", "fps_embedding")
42
+ new_pl_sd[new_key] = new_pl_sd[k]
43
+ del new_pl_sd[k]
44
+ model.load_state_dict(new_pl_sd, strict=True)
45
+ else:
46
+ # deepspeed
47
+ new_pl_sd = OrderedDict()
48
+ for key in state_dict['module'].keys():
49
+ new_pl_sd[key[16:]]=state_dict['module'][key]
50
+ model.load_state_dict(new_pl_sd)
51
+ print('>>> model checkpoint loaded.')
52
+ return model
53
+
54
+ def load_prompts(prompt_file):
55
+ f = open(prompt_file, 'r')
56
+ prompt_list = []
57
+ for idx, line in enumerate(f.readlines()):
58
+ l = line.strip()
59
+ if len(l) != 0:
60
+ prompt_list.append(l)
61
+ f.close()
62
+ return prompt_list
63
+
64
+ def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, gfi=False):
65
+ transform = transforms.Compose([
66
+ transforms.Resize(min(video_size)),
67
+ transforms.CenterCrop(video_size),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
70
+ ## load prompts
71
+ prompt_file = get_filelist(data_dir, ['txt'])
72
+ assert len(prompt_file) > 0, "Error: found NO prompt file!"
73
+ ###### default prompt
74
+ default_idx = 0
75
+ default_idx = min(default_idx, len(prompt_file)-1)
76
+ if len(prompt_file) > 1:
77
+ print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
78
+ ## only use the first one (sorted by name) if multiple exist
79
+
80
+ ## load video
81
+ file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
82
+ # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
83
+ data_list = []
84
+ filename_list = []
85
+ prompt_list = load_prompts(prompt_file[default_idx])
86
+ n_samples = len(prompt_list)
87
+ for idx in range(n_samples):
88
+ image = Image.open(file_list[idx]).convert('RGB')
89
+ image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
90
+ frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
91
+
92
+ data_list.append(frame_tensor)
93
+ _, filename = os.path.split(file_list[idx])
94
+ filename_list.append(filename)
95
+
96
+ return filename_list, data_list, prompt_list
97
+
98
+
99
+ def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
100
+ filename = filename.split('.')[0]+'.mp4'
101
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
102
+
103
+ ## save video
104
+ videos = [samples]
105
+ savedirs = [fakedir]
106
+ for idx, video in enumerate(videos):
107
+ if video is None:
108
+ continue
109
+ # b,c,t,h,w
110
+ video = video.detach().cpu()
111
+ video = torch.clamp(video.float(), -1., 1.)
112
+ n = video.shape[0]
113
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
114
+ if loop:
115
+ video = video[:-1,...]
116
+
117
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
118
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
119
+ grid = (grid + 1.0) / 2.0
120
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
121
+ path = os.path.join(savedirs[idx], filename)
122
+ torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality
123
+
124
+
125
+ def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
126
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
127
+
128
+ ## save video
129
+ videos = [samples]
130
+ savedirs = [fakedir]
131
+ for idx, video in enumerate(videos):
132
+ if video is None:
133
+ continue
134
+ # b,c,t,h,w
135
+ video = video.detach().cpu()
136
+ if loop: # remove the last frame
137
+ video = video[:,:,:-1,...]
138
+ video = torch.clamp(video.float(), -1., 1.)
139
+ n = video.shape[0]
140
+ for i in range(n):
141
+ grid = video[i,...]
142
+ grid = (grid + 1.0) / 2.0
143
+ grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
144
+ path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
145
+ torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'})
146
+
147
+ def get_latent_z(model, videos):
148
+ b, c, t, h, w = videos.shape
149
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
150
+ z = model.encode_first_stage(x)
151
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
152
+ return z
153
+
154
+
155
+ def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
156
+ unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, gfi=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
157
+ ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
158
+ batch_size = noise_shape[0]
159
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
160
+
161
+ if not text_input:
162
+ prompts = [""]*batch_size
163
+
164
+ img = videos[:,:,0] #bchw
165
+ img_emb = model.embedder(img) ## blc
166
+ img_emb = model.image_proj_model(img_emb)
167
+
168
+ cond_emb = model.get_learned_conditioning(prompts)
169
+ cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
170
+ if model.model.conditioning_key == 'hybrid':
171
+ z = get_latent_z(model, videos) # b c t h w
172
+ if loop or gfi:
173
+ img_cat_cond = torch.zeros_like(z)
174
+ img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
175
+ img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
176
+ else:
177
+ img_cat_cond = z[:,:,:1,:,:]
178
+ img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
179
+ cond["c_concat"] = [img_cat_cond] # b c 1 h w
180
+
181
+ if unconditional_guidance_scale != 1.0:
182
+ if model.uncond_type == "empty_seq":
183
+ prompts = batch_size * [""]
184
+ uc_emb = model.get_learned_conditioning(prompts)
185
+ elif model.uncond_type == "zero_embed":
186
+ uc_emb = torch.zeros_like(cond_emb)
187
+ uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
188
+ uc_img_emb = model.image_proj_model(uc_img_emb)
189
+ uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
190
+ if model.model.conditioning_key == 'hybrid':
191
+ uc["c_concat"] = [img_cat_cond]
192
+ else:
193
+ uc = None
194
+
195
+ ## we need one more unconditioning image=yes, text=""
196
+ if multiple_cond_cfg and cfg_img != 1.0:
197
+ uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
198
+ if model.model.conditioning_key == 'hybrid':
199
+ uc_2["c_concat"] = [img_cat_cond]
200
+ kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
201
+ else:
202
+ kwargs.update({"unconditional_conditioning_img_nonetext": None})
203
+
204
+ z0 = None
205
+ cond_mask = None
206
+
207
+ batch_variants = []
208
+ for _ in range(n_samples):
209
+
210
+ if z0 is not None:
211
+ cond_z0 = z0.clone()
212
+ kwargs.update({"clean_cond": True})
213
+ else:
214
+ cond_z0 = None
215
+ if ddim_sampler is not None:
216
+
217
+ samples, _ = ddim_sampler.sample(S=ddim_steps,
218
+ conditioning=cond,
219
+ batch_size=batch_size,
220
+ shape=noise_shape[1:],
221
+ verbose=False,
222
+ unconditional_guidance_scale=unconditional_guidance_scale,
223
+ unconditional_conditioning=uc,
224
+ eta=ddim_eta,
225
+ cfg_img=cfg_img,
226
+ mask=cond_mask,
227
+ x0=cond_z0,
228
+ fs=fs,
229
+ timestep_spacing=timestep_spacing,
230
+ guidance_rescale=guidance_rescale,
231
+ **kwargs
232
+ )
233
+
234
+ ## reconstruct from latent to pixel space
235
+ batch_images = model.decode_first_stage(samples)
236
+ batch_variants.append(batch_images)
237
+ ## variants, batch, c, t, h, w
238
+ batch_variants = torch.stack(batch_variants)
239
+ return batch_variants.permute(1, 0, 2, 3, 4, 5)
240
+
241
+
242
+ def run_inference(args, gpu_num, gpu_no):
243
+ ## model config
244
+ config = OmegaConf.load(args.config)
245
+ model_config = config.pop("model", OmegaConf.create())
246
+
247
+ ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
248
+ model_config['params']['unet_config']['params']['use_checkpoint'] = False
249
+ model = instantiate_from_config(model_config)
250
+ model = model.cuda(gpu_no)
251
+ model.perframe_ae = args.perframe_ae
252
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
253
+ model = load_model_checkpoint(model, args.ckpt_path)
254
+ model.eval()
255
+
256
+ ## run over data
257
+ assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
258
+ assert args.bs == 1, "Current implementation only support [batch size = 1]!"
259
+ ## latent noise shape
260
+ h, w = args.height // 8, args.width // 8
261
+ channels = model.model.diffusion_model.out_channels
262
+ n_frames = args.video_length
263
+ print(f'Inference with {n_frames} frames')
264
+ noise_shape = [args.bs, channels, n_frames, h, w]
265
+
266
+ fakedir = os.path.join(args.savedir, "samples")
267
+ fakedir_separate = os.path.join(args.savedir, "samples_separate")
268
+
269
+ # os.makedirs(fakedir, exist_ok=True)
270
+ os.makedirs(fakedir_separate, exist_ok=True)
271
+
272
+ ## prompt file setting
273
+ assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
274
+ filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, gfi=args.gfi)
275
+ num_samples = len(prompt_list)
276
+ samples_split = num_samples // gpu_num
277
+ print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
278
+ #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
279
+ indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
280
+ prompt_list_rank = [prompt_list[i] for i in indices]
281
+ data_list_rank = [data_list[i] for i in indices]
282
+ filename_list_rank = [filename_list[i] for i in indices]
283
+
284
+ start = time.time()
285
+ with torch.no_grad(), torch.cuda.amp.autocast():
286
+ for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
287
+ prompts = prompt_list_rank[indice:indice+args.bs]
288
+ videos = data_list_rank[indice:indice+args.bs]
289
+ filenames = filename_list_rank[indice:indice+args.bs]
290
+ if isinstance(videos, list):
291
+ videos = torch.stack(videos, dim=0).to("cuda")
292
+ else:
293
+ videos = videos.unsqueeze(0).to("cuda")
294
+
295
+ batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
296
+ args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.gfi, args.timestep_spacing, args.guidance_rescale)
297
+
298
+ ## save each example individually
299
+ for nn, samples in enumerate(batch_samples):
300
+ ## samples : [n_samples,c,t,h,w]
301
+ prompt = prompts[nn]
302
+ filename = filenames[nn]
303
+ # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
304
+ save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
305
+
306
+ print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
307
+
308
+
309
+ def get_parser():
310
+ parser = argparse.ArgumentParser()
311
+ parser.add_argument("--savedir", type=str, default=None, help="results saving path")
312
+ parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
313
+ parser.add_argument("--config", type=str, help="config (yaml) path")
314
+ parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
315
+ parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
316
+ parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
317
+ parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
318
+ parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
319
+ parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
320
+ parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
321
+ parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
322
+ parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
323
+ parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
324
+ parser.add_argument("--video_length", type=int, default=16, help="inference video length")
325
+ parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
326
+ parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
327
+ parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
328
+ parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
329
+ parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
330
+ parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
331
+ parser.add_argument("--perframe_ae", action='store_true', default=False, help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024")
332
+
333
+ ## currently not support looping video and generative frame interpolation
334
+ parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
335
+ parser.add_argument("--gfi", action='store_true', default=False, help="generate generative frame interpolation (gfi) or not")
336
+ return parser
337
+
338
+
339
+ if __name__ == '__main__':
340
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
341
+ print("@DynamiCrafter cond-Inference: %s"%now)
342
+ parser = get_parser()
343
+ args = parser.parse_args()
344
+
345
+ seed_everything(args.seed)
346
+ rank, gpu_num = 0, 1
347
  run_inference(args, gpu_num, rank)
scripts/gradio/__pycache__/i2v_test.cpython-39.pyc CHANGED
Binary files a/scripts/gradio/__pycache__/i2v_test.cpython-39.pyc and b/scripts/gradio/__pycache__/i2v_test.cpython-39.pyc differ
 
scripts/gradio/i2v_test.py CHANGED
@@ -1,102 +1,106 @@
1
- import os
2
- import time
3
- from omegaconf import OmegaConf
4
- import torch
5
- from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
6
- from utils.utils import instantiate_from_config
7
- from huggingface_hub import hf_hub_download
8
- from einops import repeat
9
- import torchvision.transforms as transforms
10
- from pytorch_lightning import seed_everything
11
-
12
-
13
- class Image2Video():
14
- def __init__(self,result_dir='./tmp/',gpu_num=1) -> None:
15
- self.download_model()
16
- self.result_dir = result_dir
17
- if not os.path.exists(self.result_dir):
18
- os.mkdir(self.result_dir)
19
- ckpt_path='checkpoints/dynamicrafter_256_v1/model.ckpt'
20
- config_file='configs/inference_256_v1.0.yaml'
21
- config = OmegaConf.load(config_file)
22
- model_config = config.pop("model", OmegaConf.create())
23
- model_config['params']['unet_config']['params']['use_checkpoint']=False
24
- model_list = []
25
- for gpu_id in range(gpu_num):
26
- model = instantiate_from_config(model_config)
27
- # model = model.cuda(gpu_id)
28
- assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
29
- model = load_model_checkpoint(model, ckpt_path)
30
- model.eval()
31
- model_list.append(model)
32
- self.model_list = model_list
33
- self.save_fps = 8
34
-
35
- def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
36
- seed_everything(seed)
37
- transform = transforms.Compose([
38
- transforms.Resize(256),
39
- transforms.CenterCrop(256),
40
- ])
41
- torch.cuda.empty_cache()
42
- print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
43
- start = time.time()
44
- gpu_id=0
45
- if steps > 60:
46
- steps = 60
47
- model = self.model_list[gpu_id]
48
- model = model.cuda()
49
- batch_size=1
50
- channels = model.model.diffusion_model.out_channels
51
- frames = model.temporal_length
52
- h, w = 256 // 8, 256 // 8
53
- noise_shape = [batch_size, channels, frames, h, w]
54
-
55
- # text cond
56
- text_emb = model.get_learned_conditioning([prompt])
57
-
58
- # img cond
59
- img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
60
- img_tensor = (img_tensor / 255. - 0.5) * 2
61
-
62
- image_tensor_resized = transform(img_tensor) #3,256,256
63
- videos = image_tensor_resized.unsqueeze(0) # bchw
64
-
65
- z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
66
-
67
- img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
68
-
69
- cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
70
- img_emb = model.image_proj_model(cond_images)
71
-
72
- imtext_cond = torch.cat([text_emb, img_emb], dim=1)
73
-
74
- fs = torch.tensor([fs], dtype=torch.long, device=model.device)
75
- cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
76
-
77
- ## inference
78
- batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
79
- ## b,samples,c,t,h,w
80
- prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
81
- prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
82
- prompt_str=prompt_str[:40]
83
-
84
- save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
85
- print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
86
- model = model.cpu()
87
- return os.path.join(self.result_dir, f"{prompt_str}.mp4")
88
-
89
- def download_model(self):
90
- REPO_ID = 'Doubiiu/DynamiCrafter'
91
- filename_list = ['model.ckpt']
92
- if not os.path.exists('./checkpoints/dynamicrafter_256_v1/'):
93
- os.makedirs('./dynamicrafter_256_v1/')
94
- for filename in filename_list:
95
- local_file = os.path.join('./checkpoints/dynamicrafter_256_v1/', filename)
96
- if not os.path.exists(local_file):
97
- hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_256_v1/', local_dir_use_symlinks=False)
98
-
99
- if __name__ == '__main__':
100
- i2v = Image2Video()
101
- video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset')
 
 
 
 
102
  print('done', video_path)
 
1
+ import os
2
+ import time
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z
6
+ from utils.utils import instantiate_from_config
7
+ from huggingface_hub import hf_hub_download
8
+ from einops import repeat
9
+ import torchvision.transforms as transforms
10
+ from pytorch_lightning import seed_everything
11
+
12
+
13
+ class Image2Video():
14
+ def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') -> None:
15
+ self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw
16
+ self.download_model()
17
+
18
+ self.result_dir = result_dir
19
+ if not os.path.exists(self.result_dir):
20
+ os.mkdir(self.result_dir)
21
+ ckpt_path='checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_v1/model.ckpt'
22
+ config_file='configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml'
23
+ config = OmegaConf.load(config_file)
24
+ model_config = config.pop("model", OmegaConf.create())
25
+ model_config['params']['unet_config']['params']['use_checkpoint']=False
26
+ model_list = []
27
+ for gpu_id in range(gpu_num):
28
+ model = instantiate_from_config(model_config)
29
+ # model = model.cuda(gpu_id)
30
+ assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
31
+ model = load_model_checkpoint(model, ckpt_path)
32
+ model.eval()
33
+ model_list.append(model)
34
+ self.model_list = model_list
35
+ self.save_fps = 8
36
+
37
+ def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
38
+ seed_everything(seed)
39
+ transform = transforms.Compose([
40
+ transforms.Resize(min(self.resolution)),
41
+ transforms.CenterCrop(self.resolution),
42
+ ])
43
+ torch.cuda.empty_cache()
44
+ print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
45
+ start = time.time()
46
+ gpu_id=0
47
+ if steps > 60:
48
+ steps = 60
49
+ model = self.model_list[gpu_id]
50
+ model = model.cuda()
51
+ batch_size=1
52
+ channels = model.model.diffusion_model.out_channels
53
+ frames = model.temporal_length
54
+ h, w = self.resolution[0] // 8, self.resolution[1] // 8
55
+ noise_shape = [batch_size, channels, frames, h, w]
56
+
57
+ # text cond
58
+ text_emb = model.get_learned_conditioning([prompt])
59
+
60
+ # img cond
61
+ img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
62
+ img_tensor = (img_tensor / 255. - 0.5) * 2
63
+
64
+ image_tensor_resized = transform(img_tensor) #3,h,w
65
+ videos = image_tensor_resized.unsqueeze(0) # bchw
66
+
67
+ z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
68
+
69
+ img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
70
+
71
+ cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
72
+ img_emb = model.image_proj_model(cond_images)
73
+
74
+ imtext_cond = torch.cat([text_emb, img_emb], dim=1)
75
+
76
+ fs = torch.tensor([fs], dtype=torch.long, device=model.device)
77
+ cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
78
+
79
+ ## inference
80
+ batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
81
+ ## b,samples,c,t,h,w
82
+ prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt
83
+ prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str
84
+ prompt_str=prompt_str[:40]
85
+ if len(prompt_str) == 0:
86
+ prompt_str = 'empty_prompt'
87
+
88
+ save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps)
89
+ print(f"Saved in {prompt_str}. Time used: {(time.time() - start):.2f} seconds")
90
+ model = model.cpu()
91
+ return os.path.join(self.result_dir, f"{prompt_str}.mp4")
92
+
93
+ def download_model(self):
94
+ REPO_ID = 'Doubiiu/DynamiCrafter'
95
+ filename_list = ['model.ckpt']
96
+ if not os.path.exists('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/'):
97
+ os.makedirs('./dynamicrafter_'+str(self.resolution[1])+'_v1/')
98
+ for filename in filename_list:
99
+ local_file = os.path.join('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', filename)
100
+ if not os.path.exists(local_file):
101
+ hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', local_dir_use_symlinks=False)
102
+
103
+ if __name__ == '__main__':
104
+ i2v = Image2Video()
105
+ video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset')
106
  print('done', video_path)
scripts/run.sh CHANGED
@@ -1,25 +1,61 @@
1
- name="dynamicrafter_256"
 
 
2
 
3
- ckpt='checkpoints/dynamicrafter_256_v1/model.ckpt'
4
- config='configs/inference_256_v1.0.yaml'
5
 
6
- prompt_dir="prompts/"
7
  res_dir="results"
8
 
9
- CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \
10
- --seed 123 \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  --ckpt_path $ckpt \
12
  --config $config \
13
  --savedir $res_dir/$name \
14
  --n_samples 1 \
15
- --bs 1 --height 256 --width 256 \
16
  --unconditional_guidance_scale 7.5 \
17
  --ddim_steps 50 \
18
  --ddim_eta 1.0 \
19
  --prompt_dir $prompt_dir \
20
  --text_input \
21
  --video_length 16 \
22
- --frame_stride 3
 
 
 
23
 
24
  ## multi-cond CFG: the <unconditional_guidance_scale> is s_txt, <cfg_img> is s_img
25
- #--multiple_cond_cfg --cfg_img 7.5
 
 
1
+ version=$1 ##1024, 512, 256
2
+ seed=123
3
+ name=dynamicrafter_$1_seed${seed}
4
 
5
+ ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
6
+ config=configs/inference_$1_v1.0.yaml
7
 
8
+ prompt_dir=prompts/$1/
9
  res_dir="results"
10
 
11
+ if [ "$1" == "256" ]; then
12
+ H=256
13
+ FS=3 ## This model adopts frame stride=3, range recommended: 1-6 (larger value -> larger motion)
14
+ elif [ "$1" == "512" ]; then
15
+ H=320
16
+ FS=24 ## This model adopts FPS=24, range recommended: 15-30 (smaller value -> larger motion)
17
+ elif [ "$1" == "1024" ]; then
18
+ H=576
19
+ FS=10 ## This model adopts FPS=10, range recommended: 15-5 (smaller value -> larger motion)
20
+ else
21
+ echo "Invalid input. Please enter 256, 512, or 1024."
22
+ exit 1
23
+ fi
24
+
25
+ if [ "$1" == "256" ]; then
26
+ CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
27
+ --seed ${seed} \
28
+ --ckpt_path $ckpt \
29
+ --config $config \
30
+ --savedir $res_dir/$name \
31
+ --n_samples 1 \
32
+ --bs 1 --height ${H} --width $1 \
33
+ --unconditional_guidance_scale 7.5 \
34
+ --ddim_steps 50 \
35
+ --ddim_eta 1.0 \
36
+ --prompt_dir $prompt_dir \
37
+ --text_input \
38
+ --video_length 16 \
39
+ --frame_stride ${FS}
40
+ else
41
+ CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
42
+ --seed ${seed} \
43
  --ckpt_path $ckpt \
44
  --config $config \
45
  --savedir $res_dir/$name \
46
  --n_samples 1 \
47
+ --bs 1 --height ${H} --width $1 \
48
  --unconditional_guidance_scale 7.5 \
49
  --ddim_steps 50 \
50
  --ddim_eta 1.0 \
51
  --prompt_dir $prompt_dir \
52
  --text_input \
53
  --video_length 16 \
54
+ --frame_stride ${FS} \
55
+ --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
56
+ fi
57
+
58
 
59
  ## multi-cond CFG: the <unconditional_guidance_scale> is s_txt, <cfg_img> is s_img
60
+ #--multiple_cond_cfg --cfg_img 7.5
61
+ #--loop
scripts/run_mp.sh ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version=$1 ##1024, 512, 256
2
+ seed=123
3
+
4
+ name=dynamicrafter_$1_mp_seed${seed}
5
+
6
+ ckpt=checkpoints/dynamicrafter_$1_v1/model.ckpt
7
+ config=configs/inference_$1_v1.0.yaml
8
+
9
+ prompt_dir=prompts/$1/
10
+ res_dir="results"
11
+
12
+ if [ "$1" == "256" ]; then
13
+ H=256
14
+ FS=3 ## This model adopts frame stride=3
15
+ elif [ "$1" == "512" ]; then
16
+ H=320
17
+ FS=24 ## This model adopts FPS=24
18
+ elif [ "$1" == "1024" ]; then
19
+ H=576
20
+ FS=10 ## This model adopts FPS=10
21
+ else
22
+ echo "Invalid input. Please enter 256, 512, or 1024."
23
+ exit 1
24
+ fi
25
+
26
+ # if [ "$1" == "256" ]; then
27
+ # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
28
+ # --seed 123 \
29
+ # --ckpt_path $ckpt \
30
+ # --config $config \
31
+ # --savedir $res_dir/$name \
32
+ # --n_samples 1 \
33
+ # --bs 1 --height ${H} --width $1 \
34
+ # --unconditional_guidance_scale 7.5 \
35
+ # --ddim_steps 50 \
36
+ # --ddim_eta 1.0 \
37
+ # --prompt_dir $prompt_dir \
38
+ # --text_input \
39
+ # --video_length 16 \
40
+ # --frame_stride ${FS}
41
+ # else
42
+ # CUDA_VISIBLE_DEVICES=2 python3 scripts/evaluation/inference.py \
43
+ # --seed 123 \
44
+ # --ckpt_path $ckpt \
45
+ # --config $config \
46
+ # --savedir $res_dir/$name \
47
+ # --n_samples 1 \
48
+ # --bs 1 --height ${H} --width $1 \
49
+ # --unconditional_guidance_scale 7.5 \
50
+ # --ddim_steps 50 \
51
+ # --ddim_eta 1.0 \
52
+ # --prompt_dir $prompt_dir \
53
+ # --text_input \
54
+ # --video_length 16 \
55
+ # --frame_stride ${FS} \
56
+ # --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7
57
+ # fi
58
+
59
+
60
+ ## multi-cond CFG: the <unconditional_guidance_scale> is s_txt, <cfg_img> is s_img
61
+ #--multiple_cond_cfg --cfg_img 7.5
62
+ #--loop
63
+
64
+ ## inference using single node with multi-GPUs:
65
+ if [ "$1" == "256" ]; then
66
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
67
+ --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
68
+ scripts/evaluation/ddp_wrapper.py \
69
+ --module 'inference' \
70
+ --seed ${seed} \
71
+ --ckpt_path $ckpt \
72
+ --config $config \
73
+ --savedir $res_dir/$name \
74
+ --n_samples 1 \
75
+ --bs 1 --height ${H} --width $1 \
76
+ --unconditional_guidance_scale 7.5 \
77
+ --ddim_steps 50 \
78
+ --ddim_eta 1.0 \
79
+ --prompt_dir $prompt_dir \
80
+ --text_input \
81
+ --video_length 16 \
82
+ --frame_stride ${FS}
83
+ else
84
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
85
+ --nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=23456 --node_rank=0 \
86
+ scripts/evaluation/ddp_wrapper.py \
87
+ --module 'inference' \
88
+ --seed ${seed} \
89
+ --ckpt_path $ckpt \
90
+ --config $config \
91
+ --savedir $res_dir/$name \
92
+ --n_samples 1 \
93
+ --bs 1 --height ${H} --width $1 \
94
+ --unconditional_guidance_scale 7.5 \
95
+ --ddim_steps 50 \
96
+ --ddim_eta 1.0 \
97
+ --prompt_dir $prompt_dir \
98
+ --text_input \
99
+ --video_length 16 \
100
+ --frame_stride ${FS} \
101
+ --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae
102
+ fi