multimodalart (HF staff) committed
Commit 3b16b97
Parent: 2e6eea6

Update app.py

Files changed (1)
  1. app.py +4 -165
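
In short, this commit removes app.py's in-process load_model() helper and the full image-to-video sampling pipeline (image preprocessing, conditioning, latent sampling, decoding, and mp4 writing), delegating generation to the sample() entry point imported from simple_video_sample.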
app.py CHANGED
@@ -19,6 +19,8 @@ from sgm.inference.helpers import embed_watermark
 from sgm.util import default, instantiate_from_config
 from huggingface_hub import hf_hub_download
 
+from simple_video_sample import sample
+
 num_frames = 25
 num_steps = 30
 model_config = "scripts/sampling/configs/svd_xt.yaml"
@@ -30,38 +32,6 @@ css = '''
 .gradio-container{max-width:850px !important}
 '''
 
-def load_model(
-    config: str,
-    device: str,
-    num_frames: int,
-    num_steps: int,
-):
-    config = OmegaConf.load(config)
-    if device == "cuda":
-        config.model.params.conditioner_config.params.emb_models[
-            0
-        ].params.open_clip_embedding_config.params.init_device = device
-
-    config.model.params.sampler_config.params.num_steps = num_steps
-    config.model.params.sampler_config.params.guider_config.params.num_frames = (
-        num_frames
-    )
-    if device == "cuda":
-        with torch.device(device):
-            model = instantiate_from_config(config.model).to(device).eval()
-    else:
-        model = instantiate_from_config(config.model).to(device).eval()
-
-    filter = DeepFloydDataFiltering(verbose=False, device=device)
-    return model, filter
-
-model, filter = load_model(
-    model_config,
-    device,
-    num_frames,
-    num_steps,
-)
-
 def sample(
     input_path: str,
     num_frames: Optional[int] = 25,
@@ -74,139 +44,8 @@ def sample(
     decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
 ):
     output_folder = str(uuid.uuid4())
-    torch.manual_seed(seed)
-    path = Path(input_path)
-    all_img_paths = []
-    if path.is_file():
-        if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
-            all_img_paths = [input_path]
-        else:
-            raise ValueError("Path is not valid image file.")
-    elif path.is_dir():
-        all_img_paths = sorted(
-            [
-                f
-                for f in path.iterdir()
-                if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
-            ]
-        )
-        if len(all_img_paths) == 0:
-            raise ValueError("Folder does not contain any images.")
-    else:
-        raise ValueError
-
-    for input_img_path in all_img_paths:
-        with Image.open(input_img_path) as image:
-            if image.mode == "RGBA":
-                image = image.convert("RGB")
-            w, h = image.size
-
-            if h % 64 != 0 or w % 64 != 0:
-                width, height = map(lambda x: x - x % 64, (w, h))
-                image = image.resize((width, height))
-                print(
-                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
-                )
-
-            image = ToTensor()(image)
-            image = image * 2.0 - 1.0
-
-        image = image.unsqueeze(0).to(device)
-        H, W = image.shape[2:]
-        assert image.shape[1] == 3
-        F = 8
-        C = 4
-        shape = (num_frames, C, H // F, W // F)
-        if (H, W) != (576, 1024):
-            print(
-                "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
-            )
-        if motion_bucket_id > 255:
-            print(
-                "WARNING: High motion bucket! This may lead to suboptimal performance."
-            )
-
-        if fps_id < 5:
-            print("WARNING: Small fps value! This may lead to suboptimal performance.")
-
-        if fps_id > 30:
-            print("WARNING: Large fps value! This may lead to suboptimal performance.")
-
-        value_dict = {}
-        value_dict["motion_bucket_id"] = motion_bucket_id
-        value_dict["fps_id"] = fps_id
-        value_dict["cond_aug"] = cond_aug
-        value_dict["cond_frames_without_noise"] = image
-        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
-        value_dict["cond_aug"] = cond_aug
-
-
-
-        with torch.no_grad():
-            with torch.autocast(device):
-                batch, batch_uc = get_batch(
-                    get_unique_embedder_keys_from_conditioner(model.conditioner),
-                    value_dict,
-                    [1, num_frames],
-                    T=num_frames,
-                    device=device,
-                )
-                c, uc = model.conditioner.get_unconditional_conditioning(
-                    batch,
-                    batch_uc=batch_uc,
-                    force_uc_zero_embeddings=[
-                        "cond_frames",
-                        "cond_frames_without_noise",
-                    ],
-                )
-
-                for k in ["crossattn", "concat"]:
-                    uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
-                    uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
-                    c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
-                    c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
-
-                randn = torch.randn(shape, device=device)
-
-                additional_model_inputs = {}
-                additional_model_inputs["image_only_indicator"] = torch.zeros(
-                    2, num_frames
-                ).to(device)
-                additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
-
-                def denoiser(input, sigma, c):
-                    return model.denoiser(
-                        model.model, input, sigma, c, **additional_model_inputs
-                    )
-
-                samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
-                model.en_and_decode_n_samples_a_time = decoding_t
-                samples_x = model.decode_first_stage(samples_z)
-                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-
-                os.makedirs(output_folder, exist_ok=True)
-                base_count = len(glob(os.path.join(output_folder, "*.mp4")))
-                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
-                writer = cv2.VideoWriter(
-                    video_path,
-                    cv2.VideoWriter_fourcc(*"MP4V"),
-                    fps_id + 1,
-                    (samples.shape[-1], samples.shape[-2]),
-                )
-
-                samples = embed_watermark(samples)
-                samples = filter(samples)
-                vid = (
-                    (rearrange(samples, "t c h w -> t h w c") * 255)
-                    .cpu()
-                    .numpy()
-                    .astype(np.uint8)
-                )
-                for frame in vid:
-                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-                    writer.write(frame)
-                writer.release()
-                return video_path
+    sample(input_path, version, output_folder, decoding_t)
+    return f"{output_folder}/000000.mp4"
 
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list(set([x.input_key for x in conditioner.embedders]))
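
A note on the new wrapper: def sample(...) reuses the very name bound by from simple_video_sample import sample, so the definition shadows the import and the inner call resolves to the wrapper itself, recursing instead of dispatching to the script. Below is a minimal sketch of the presumably intended wiring, using a hypothetical alias; the argument order is copied from the committed call site, version is assumed to be a module-level setting (it is referenced but not defined in the lines this diff shows), and the 000000.mp4 naming follows the committed return statement.

import uuid

# Alias the script's entry point so the Gradio-facing wrapper below
# does not shadow it.
from simple_video_sample import sample as run_sample  # hypothetical alias

version = "svd_xt"  # assumption, inferred from model_config; not shown in this diff

def sample(input_path: str, decoding_t: int = 7):
    # One fresh output folder per request, as in the committed code.
    output_folder = str(uuid.uuid4())
    # Argument order copied verbatim from the committed call; the actual
    # signature of simple_video_sample.sample is not visible in this diff.
    run_sample(input_path, version, output_folder, decoding_t)
    # The script is assumed to write its first clip as 000000.mp4.
    return f"{output_folder}/000000.mp4"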
 
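For context, app.py is a Gradio app (the css block above targets .gradio-container), and the wrapper returns a filepath string, which is the value format a video output component accepts. Here is a sketch of how such a wrapper is typically exposed; the component choices and labels are illustrative assumptions, not taken from this commit:

import gradio as gr

# `sample` here is the Gradio-facing wrapper defined in app.py above.
demo = gr.Interface(
    fn=sample,
    inputs=gr.Image(type="filepath", label="Conditioning frame"),
    outputs=gr.Video(label="Generated video"),
)

if __name__ == "__main__":
    demo.launch()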