Update for spaces.
gradio_app.py  CHANGED  (+18, -35)
@@ -12,7 +12,6 @@ import gradio as gr
 import numpy as np
 import torch
 import wd14tagger
-import memory_management
 import uuid
 
 from PIL import Image
@@ -24,7 +23,10 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers_vdm.pipeline import LatentVideoDiffusionPipeline
 from diffusers_vdm.utils import resize_and_center_crop, save_bcthw_as_mp4
+import spaces
 
+# Disable gradients globally
+torch.set_grad_enabled(False)
 
 class ModifiedUNet(UNet2DConditionModel):
     @classmethod
@@ -37,9 +39,9 @@ class ModifiedUNet(UNet2DConditionModel):
 
 model_name = 'lllyasviel/paints_undo_single_frame'
 tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16)
-vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16)  # bfloat16 vae
-unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16)
+text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder").to(torch.float16).to("cuda")
+vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(torch.bfloat16).to("cuda")  # bfloat16 vae
+unet = ModifiedUNet.from_pretrained(model_name, subfolder="unet").to(torch.float16).to("cuda")
 
 unet.set_attn_processor(AttnProcessor2_0())
 vae.set_attn_processor(AttnProcessor2_0())
@@ -47,12 +49,7 @@ vae.set_attn_processor(AttnProcessor2_0())
 video_pipe = LatentVideoDiffusionPipeline.from_pretrained(
     'lllyasviel/paints_undo_multi_frame',
     fp16=True
-)
-
-memory_management.unload_all_models([
-    video_pipe.unet, video_pipe.vae, video_pipe.text_encoder, video_pipe.image_projection, video_pipe.image_encoder,
-    unet, vae, text_encoder
-])
+).to("cuda")
 
 k_sampler = KDiffusionSampler(
     unet=unet,
@@ -74,19 +71,16 @@ def find_best_bucket(h, w, options):
     return best_bucket
 
 
-@torch.inference_mode()
 def encode_cropped_prompt_77tokens(txt: str):
-    memory_management.load_models_to_gpu(text_encoder)
     cond_ids = tokenizer(txt,
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
                          truncation=True,
-                         return_tensors="pt").input_ids.to(device=text_encoder.device)
+                         return_tensors="pt").input_ids.to(device="cuda")
     text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
     return text_cond
 
 
-@torch.inference_mode()
 def pytorch2numpy(imgs):
     results = []
     for x in imgs:
@@ -97,7 +91,6 @@ def pytorch2numpy(imgs):
     return results
 
 
-@torch.inference_mode()
 def numpy2pytorch(imgs):
     h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
     h = h.movedim(-1, 1)
@@ -110,29 +103,26 @@ def resize_without_crop(image, target_width, target_height):
     return np.array(resized_image)
 
 
-@torch.inference_mode()
 def interrogator_process(x):
-    return wd14tagger.default_interrogator(x)
+    image_description = wd14tagger.default_interrogator(x)
+    return image_description
 
 
-@torch.inference_mode()
+@spaces.GPU()
 def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed, steps, n_prompt, cfg,
             progress=gr.Progress()):
-    rng = torch.Generator(device=memory_management.gpu).manual_seed(int(seed))
+    rng = torch.Generator(device="cuda").manual_seed(int(seed))
 
-    memory_management.load_models_to_gpu(vae)
     fg = resize_and_center_crop(input_fg, image_width, image_height)
-    concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
+    concat_conds = numpy2pytorch([fg]).clone().detach().to(device="cuda", dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(text_encoder)
     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
-    memory_management.load_models_to_gpu(unet)
-    fs = torch.tensor(input_undo_steps).to(device=unet.device, dtype=torch.long)
+    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
-    concat_conds = concat_conds.to(device=unet.device, dtype=unet.dtype)
+    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
     latents = k_sampler(
         initial_latent=initial_latents,
         strength=1.0,
@@ -147,7 +137,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
         progress_tqdm=functools.partial(progress.tqdm, desc='Generating Key Frames')
     ).to(vae.dtype) / vae.config.scaling_factor
 
-    memory_management.load_models_to_gpu(vae)
     pixels = vae.decode(latents).sample
     pixels = pytorch2numpy(pixels)
     pixels = [fg] + pixels + [np.zeros_like(fg) + 255]
@@ -155,7 +144,6 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     return pixels
 
 
-@torch.inference_mode()
 def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=7.5, fs=3, progress_tqdm=None):
     random.seed(seed)
     np.random.seed(seed)
@@ -174,25 +162,21 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     input_frames = numpy2pytorch([image_1, image_2])
     input_frames = input_frames.unsqueeze(0).movedim(1, 2)
 
-    memory_management.load_models_to_gpu(video_pipe.text_encoder)
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")
 
-    memory_management.load_models_to_gpu([video_pipe.image_projection, video_pipe.image_encoder])
-    input_frames = input_frames.to(device=video_pipe.image_encoder.device, dtype=video_pipe.image_encoder.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
     negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
     negative_image_cond = video_pipe.image_projection(negative_image_cond)
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
-    input_frames = input_frames.to(device=video_pipe.vae.device, dtype=video_pipe.vae.dtype)
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
    last_frame = input_frame_latents[:, :, 1]
     concat_cond = torch.stack([first_frame] + [torch.zeros_like(first_frame)] * (frames - 2) + [last_frame], dim=2)
 
-    memory_management.load_models_to_gpu([video_pipe.unet])
     latents = video_pipe(
         batch_size=1,
         steps=int(steps),
@@ -206,12 +190,11 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
         progress_tqdm=progress_tqdm
     )
 
-    memory_management.load_models_to_gpu([video_pipe.vae])
     video = video_pipe.decode_latents(latents, vae_hidden_states)
     return video, image_1, image_2
 
 
-@torch.inference_mode()
+@spaces.GPU(duration=360)
 def process_video(keyframes, prompt, steps, cfg, fps, seed, progress=gr.Progress()):
     result_frames = []
     cropped_images = []