ifire committed on
Commit a8f08a2
1 Parent(s): adc915a

Update for spaces.

Files changed (1)
  1. gradio_app.py +18 -35
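
Note: this change drops the manual memory_management load/unload calls, keeps the models on "cuda" from import time, and wraps each GPU-using Gradio handler in @spaces.GPU so it runs on Hugging Face Spaces' ZeroGPU hardware. A minimal sketch of that pattern follows, assuming the `spaces` package that Spaces provides and a CUDA device; the model and functions below are stand-ins, not the repo's actual UNet/VAE:

import spaces
import torch

torch.set_grad_enabled(False)  # inference only, mirroring this commit

# The model is placed on the GPU once at startup instead of being swapped in and out.
model = torch.nn.Linear(4, 4).to(torch.float16).to("cuda")

@spaces.GPU()  # ZeroGPU attaches a GPU only while this call runs
def run(x: torch.Tensor) -> torch.Tensor:
    return model(x.to("cuda", torch.float16))

@spaces.GPU(duration=360)  # request a longer allocation for slow jobs such as video generation
def run_long(x: torch.Tensor) -> torch.Tensor:
    return run(x)
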
gradio_app.py CHANGED
@@ -12,7 +12,6 @@ import gradio as gr
 import numpy as np
 import torch
 import wd14tagger
-import memory_management
 import uuid
 
 from PIL import Image
@@ -24,7 +23,10 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
+import spaces
 
+# Disable gradients globally
+torch.set_grad_enabled(False)
 
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
@@ -37,9 +39,9 @@ class ModifiedUNet(UNet2DConditionModel):
 
 model_name = 'lllyasviel/paints_undo_single_frame'
 tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
-vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
-unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
+vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
+unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
@@ -47,12 +49,7 @@ vae.set_attn_processor(AttnProcessor2_0())
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
     fp16=True
-)
-
-memory_management.unload_all_models([
-    video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
-    unet, vae, text_encoder
-])
+).to("cuda")
 
 k_sampler = KDiffusionSampler(
     unet=unet,
@@ -74,19 +71,16 @@ def find_best_bucket(h, w, options):
     return best_bucket
 
 
-@torch.inference_mode()
 def encode_cropped_prompt_77tokens(txt: str):
-    memory_management.load_models_to_gpu(text_encoder)
     cond_ids = tokenizer(txt,
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
                          truncation=True,
-                         return_tensors="pt").input_ids.to(device=text_encoder.device)
+                         return_tensors="pt").input_ids.to(device="cuda")
     text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
     return text_cond
 
 
-@torch.inference_mode()
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -97,7 +91,6 @@ def pytorch2numpy(imgs):
     return results
 
 
-@torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -110,29 +103,26 @@ def resize_without_crop(image, target_width, target_height):
     return np.array(resized_image)
 
 
-@torch.inference_mode()
 def interrogator_process(x):
-    return wd14tagger.default_interrogator(x)
+    image_description = wd14tagger.default_interrogator(x)
+    return image_description
 
 
-@torch.inference_mode()
+@spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
-    rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
+    rng = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    memory_management.load_models_to_gpu(vae)
     fg = resize_and_center_crop(input_fg, image_width, image_height)
-    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(text_encoder)
     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
-    memory_management.load_models_to_gpu(unet)
-    fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
+    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
-    concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
+    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
     latents = k_sampler(
         initial_latent=initial_latents,
         strength=1.0,
@@ -147,7 +137,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
         progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
     ).to(vae.dtype) / vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(vae)
     pixels = vae.decode(latents).sample
     pixels = pytorch2numpy(pixels)
     pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
@@ -155,7 +144,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     return pixels
 
 
-@torch.inference_mode()
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -174,25 +162,21 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     input_frames = numpy2pytorch([image_1, image_2])
     input_frames = input_frames.unsqueeze(0).movedim(1, 2)
 
-    memory_management.load_models_to_gpu(video_pipe.text_encoder)
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
 
-    memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
-    input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
     negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
    negative_image_cond = video_pipe.image_projection(negative_image_cond)
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
-    input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
     last_frame = input_frame_latents[:, :, 1]
     concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
 
-    memory_management.load_models_to_gpu([video_pipe.unet])
     latents = video_pipe(
         batch_size=1,
         steps=int(steps),
@@ -206,12 +190,11 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
         progress_tqdm=progress_tqdm
     )
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     video = video_pipe.decode_latents(latents, vae_hidden_states)
     return video, image_1, image_2
 
 
-@torch.inference_mode()
+@spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
     cropped_images = []
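
For reference, a small before/after sketch of the gradient handling this commit switches to (the function is illustrative, not from the repo; torch.inference_mode() is stricter than merely disabling grad, so the two forms are close but not strictly identical):

import torch

@torch.inference_mode()  # old style: each entry point wrapped individually
def double_old(x: torch.Tensor) -> torch.Tensor:
    return x * 2

torch.set_grad_enabled(False)  # new style: gradients disabled once, globally

def double_new(x: torch.Tensor) -> torch.Tensor:
    return x * 2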