Spaces:
Running
on
Zero
Running
on
Zero
resolve deps
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +11 -9
- app.py +271 -0
- app_full.py +243 -0
- environment.yml +402 -0
- gifs_filter.py +68 -0
- invert_utils.py +89 -0
- read_vids.py +27 -0
- requirements.txt +44 -0
- static/app_tmp/gif_logs/vid_sketch10-rand0_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand0_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand0_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand0_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand1_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand1_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand1_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand1_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand2_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand2_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand2_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand2_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand3_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand3_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand3_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand3_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand4_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand4_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand4_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand4_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand5_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand6_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand7_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand8_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch10-rand9_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch3-rand0_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch3-rand1_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch3-rand2_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch3-rand3_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch3-rand4_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand0_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand0_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand0_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand1_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand1_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand1_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand2_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand2_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand2_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand3_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand3_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
- static/app_tmp/gif_logs/vid_sketch8-rand3_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
README.md
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
---
|
2 |
-
title: FlipSketch
|
3 |
-
emoji: 🚀
|
4 |
-
colorFrom:
|
5 |
-
colorTo: green
|
6 |
-
sdk:
|
7 |
-
|
8 |
-
|
9 |
-
short_description: Sketch Animations
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: FlipSketch
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
|
|
9 |
---
|
10 |
|
11 |
+
|
12 |
+
# FlipSketch
|
13 |
+
|
14 |
+
FlipSketch: Flipping assets Drawings to Text-Guided Sketch Animations
|
app.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import subprocess
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import warnings
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image, ImageSequence
|
10 |
+
from moviepy.editor import VideoFileClip
|
11 |
+
import imageio
|
12 |
+
import uuid
|
13 |
+
|
14 |
+
from diffusers import (
|
15 |
+
TextToVideoSDPipeline,
|
16 |
+
AutoencoderKL,
|
17 |
+
DDPMScheduler,
|
18 |
+
DDIMScheduler,
|
19 |
+
UNet3DConditionModel,
|
20 |
+
)
|
21 |
+
import time
|
22 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
23 |
+
|
24 |
+
from diffusers.utils import export_to_video
|
25 |
+
from gifs_filter import filter
|
26 |
+
from invert_utils import ddim_inversion as dd_inversion
|
27 |
+
from text2vid_modded import TextToVideoSDPipelineModded
|
28 |
+
|
29 |
+
|
30 |
+
def run_setup():
|
31 |
+
try:
|
32 |
+
# Step 1: Install Git LFS
|
33 |
+
subprocess.run(["git", "lfs", "install"], check=True)
|
34 |
+
|
35 |
+
# Step 2: Clone the repository
|
36 |
+
repo_url = "https://huggingface.co/Hmrishav/t2v_sketch-lora"
|
37 |
+
subprocess.run(["git", "clone", repo_url], check=True)
|
38 |
+
|
39 |
+
# Step 3: Move the checkpoint file
|
40 |
+
source = "t2v_sketch-lora/checkpoint-2500"
|
41 |
+
destination = "./checkpoint-2500/"
|
42 |
+
os.rename(source, destination)
|
43 |
+
|
44 |
+
print("Setup completed successfully!")
|
45 |
+
except subprocess.CalledProcessError as e:
|
46 |
+
print(f"Error during setup: {e}")
|
47 |
+
except FileNotFoundError as e:
|
48 |
+
print(f"File operation error: {e}")
|
49 |
+
except Exception as e:
|
50 |
+
print(f"Unexpected error: {e}")
|
51 |
+
|
52 |
+
# Automatically run setup during app initialization
|
53 |
+
run_setup()
|
54 |
+
|
55 |
+
|
56 |
+
# Flask app setup
|
57 |
+
app = Flask(__name__)
|
58 |
+
app.config['UPLOAD_FOLDER'] = 'static/uploads'
|
59 |
+
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
60 |
+
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
61 |
+
|
62 |
+
# Environment setup
|
63 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
|
64 |
+
LORA_CHECKPOINT = "checkpoint-2500"
|
65 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
66 |
+
dtype = torch.bfloat16
|
67 |
+
|
68 |
+
# Helper functions
|
69 |
+
|
70 |
+
def cleanup_old_files(directory, age_in_seconds = 600):
|
71 |
+
"""
|
72 |
+
Deletes files older than a certain age in the specified directory.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
directory (str): The directory to clean up.
|
76 |
+
age_in_seconds (int): The age in seconds; files older than this will be deleted.
|
77 |
+
"""
|
78 |
+
now = time.time()
|
79 |
+
for filename in os.listdir(directory):
|
80 |
+
file_path = os.path.join(directory, filename)
|
81 |
+
# Only delete files (not directories)
|
82 |
+
if os.path.isfile(file_path):
|
83 |
+
file_age = now - os.path.getmtime(file_path)
|
84 |
+
if file_age > age_in_seconds:
|
85 |
+
try:
|
86 |
+
os.remove(file_path)
|
87 |
+
print(f"Deleted old file: {file_path}")
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error deleting file {file_path}: {e}")
|
90 |
+
|
91 |
+
def load_frames(image: Image, mode='RGBA'):
|
92 |
+
return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
|
93 |
+
|
94 |
+
def save_gif(frames, path):
|
95 |
+
imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1/10)
|
96 |
+
|
97 |
+
def load_image(imgname, target_size=None):
|
98 |
+
pil_img = Image.open(imgname).convert('RGB')
|
99 |
+
if target_size:
|
100 |
+
if isinstance(target_size, int):
|
101 |
+
target_size = (target_size, target_size)
|
102 |
+
pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
|
103 |
+
return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0) # Add batch dimension
|
104 |
+
|
105 |
+
def prepare_latents(pipe, x_aug):
|
106 |
+
with torch.cuda.amp.autocast():
|
107 |
+
batch_size, num_frames, channels, height, width = x_aug.shape
|
108 |
+
x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
|
109 |
+
latents = pipe.vae.encode(x_aug).latent_dist.sample()
|
110 |
+
latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
|
111 |
+
latents = latents.permute(0, 2, 1, 3, 4)
|
112 |
+
return pipe.vae.config.scaling_factor * latents
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
116 |
+
input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
|
117 |
+
input_img = torch.cat(input_img, dim=1)
|
118 |
+
latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
|
119 |
+
inv.set_timesteps(25)
|
120 |
+
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
|
121 |
+
return torch.mean(id_latents, dim=2, keepdim=True)
|
122 |
+
|
123 |
+
def load_primary_models(pretrained_model_path):
|
124 |
+
return (
|
125 |
+
DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
|
126 |
+
CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
|
127 |
+
CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
|
128 |
+
AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
|
129 |
+
UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def initialize_pipeline(model: str, device: str = "cuda"):
|
134 |
+
with warnings.catch_warnings():
|
135 |
+
warnings.simplefilter("ignore")
|
136 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
137 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
138 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
139 |
+
scheduler=scheduler,
|
140 |
+
tokenizer=tokenizer,
|
141 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
|
142 |
+
vae=vae.to(device=device, dtype=torch.bfloat16),
|
143 |
+
unet=unet.to(device=device, dtype=torch.bfloat16),
|
144 |
+
)
|
145 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
146 |
+
return pipe, pipe.scheduler
|
147 |
+
|
148 |
+
pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
|
149 |
+
pipe = TextToVideoSDPipelineModded.from_pretrained(
|
150 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
151 |
+
scheduler=pipe_inversion.scheduler,
|
152 |
+
tokenizer=pipe_inversion.tokenizer,
|
153 |
+
text_encoder=pipe_inversion.text_encoder,
|
154 |
+
vae=pipe_inversion.vae,
|
155 |
+
unet=pipe_inversion.unet,
|
156 |
+
).to(device)
|
157 |
+
|
158 |
+
@torch.no_grad()
|
159 |
+
def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
|
160 |
+
pipe_inversion.to(device)
|
161 |
+
id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
|
162 |
+
latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
|
163 |
+
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
|
164 |
+
video_frames = pipe(
|
165 |
+
prompt=caption,
|
166 |
+
negative_prompt="",
|
167 |
+
num_frames=num_frames,
|
168 |
+
num_inference_steps=25,
|
169 |
+
inv_latents=latents,
|
170 |
+
guidance_scale=9,
|
171 |
+
generator=generator,
|
172 |
+
lambda_=lambda_,
|
173 |
+
).frames
|
174 |
+
try:
|
175 |
+
load_name = load_name.split("/")[-1]
|
176 |
+
except:
|
177 |
+
pass
|
178 |
+
gifs = []
|
179 |
+
for seed in range(num_seeds):
|
180 |
+
vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
|
181 |
+
gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
|
182 |
+
video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
|
183 |
+
VideoFileClip(vid_name).write_gif(gif_name)
|
184 |
+
with Image.open(gif_name) as im:
|
185 |
+
frames = load_frames(im)
|
186 |
+
|
187 |
+
frames_collect = np.empty((0, 1024, 1024), int)
|
188 |
+
for frame in frames:
|
189 |
+
frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
|
190 |
+
frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
|
191 |
+
|
192 |
+
_, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
193 |
+
|
194 |
+
frames_collect = np.append(frames_collect, [frame], axis=0)
|
195 |
+
|
196 |
+
save_gif(frames_collect, gif_name)
|
197 |
+
gifs.append(gif_name)
|
198 |
+
|
199 |
+
return gifs
|
200 |
+
|
201 |
+
|
202 |
+
def generate_gifs(filepath, prompt, num_seeds=5, lambda_=0):
|
203 |
+
exp_dir = "static/app_tmp"
|
204 |
+
os.makedirs(exp_dir, exist_ok=True)
|
205 |
+
gifs = process(
|
206 |
+
num_frames=10,
|
207 |
+
num_seeds=num_seeds,
|
208 |
+
generator=None,
|
209 |
+
exp_dir=exp_dir,
|
210 |
+
load_name=filepath,
|
211 |
+
caption=prompt,
|
212 |
+
lambda_=lambda_
|
213 |
+
)
|
214 |
+
return gifs
|
215 |
+
|
216 |
+
@app.route('/')
|
217 |
+
def index():
|
218 |
+
return render_template('index.html')
|
219 |
+
|
220 |
+
@app.route('/generate', methods=['POST'])
|
221 |
+
def generate():
|
222 |
+
|
223 |
+
directories_to_clean = [
|
224 |
+
app.config['UPLOAD_FOLDER'],
|
225 |
+
'static/app_tmp/mp4_logs',
|
226 |
+
'static/app_tmp/gif_logs',
|
227 |
+
'static/app_tmp/png_logs'
|
228 |
+
]
|
229 |
+
|
230 |
+
# Perform cleanup
|
231 |
+
os.makedirs('static/app_tmp', exist_ok=True)
|
232 |
+
for directory in directories_to_clean:
|
233 |
+
os.makedirs(directory, exist_ok=True) # Ensure the directory exists
|
234 |
+
cleanup_old_files(directory)
|
235 |
+
|
236 |
+
prompt = request.form.get('prompt', '')
|
237 |
+
num_gifs = int(request.form.get('seeds', 3))
|
238 |
+
lambda_value = 1 - float(request.form.get('lambda', 0.5))
|
239 |
+
selected_example = request.form.get('selected_example', None)
|
240 |
+
file = request.files.get('image')
|
241 |
+
|
242 |
+
if not file and not selected_example:
|
243 |
+
return jsonify({'error': 'No image file provided or example selected'}), 400
|
244 |
+
|
245 |
+
if selected_example:
|
246 |
+
# Use the selected example image
|
247 |
+
filepath = os.path.join('static', 'examples', selected_example)
|
248 |
+
unique_id = None # No need for unique ID
|
249 |
+
else:
|
250 |
+
# Save the uploaded image
|
251 |
+
unique_id = str(uuid.uuid4())
|
252 |
+
filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{unique_id}_uploaded_image.png")
|
253 |
+
file.save(filepath)
|
254 |
+
|
255 |
+
generated_gifs = generate_gifs(filepath, prompt, num_seeds=num_gifs, lambda_=lambda_value)
|
256 |
+
|
257 |
+
unique_id = str(uuid.uuid4())
|
258 |
+
# Append unique id to each gif path
|
259 |
+
for i in range(len(generated_gifs)):
|
260 |
+
os.rename(generated_gifs[i], f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif")
|
261 |
+
generated_gifs[i] = f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif"
|
262 |
+
# Move the generated gifs to the static folder
|
263 |
+
|
264 |
+
|
265 |
+
filtered_gifs = filter(generated_gifs, filepath)
|
266 |
+
return jsonify({'gifs': filtered_gifs, 'prompt': prompt})
|
267 |
+
|
268 |
+
if __name__ == '__main__':
|
269 |
+
|
270 |
+
|
271 |
+
app.run(debug=True)
|
app_full.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
import warnings
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image, ImageSequence
|
9 |
+
from moviepy.editor import VideoFileClip
|
10 |
+
import imageio
|
11 |
+
import uuid
|
12 |
+
|
13 |
+
from diffusers import (
|
14 |
+
TextToVideoSDPipeline,
|
15 |
+
AutoencoderKL,
|
16 |
+
DDPMScheduler,
|
17 |
+
DDIMScheduler,
|
18 |
+
UNet3DConditionModel,
|
19 |
+
)
|
20 |
+
import time
|
21 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
22 |
+
|
23 |
+
from diffusers.utils import export_to_video
|
24 |
+
from gifs_filter import filter
|
25 |
+
from invert_utils import ddim_inversion as dd_inversion
|
26 |
+
from text2vid_modded_full import TextToVideoSDPipelineModded
|
27 |
+
|
28 |
+
# Flask app setup
|
29 |
+
app = Flask(__name__)
|
30 |
+
app.config['UPLOAD_FOLDER'] = 'static/uploads'
|
31 |
+
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
32 |
+
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
33 |
+
|
34 |
+
# Environment setup
|
35 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
|
36 |
+
LORA_CHECKPOINT = "checkpoint-2500"
|
37 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
38 |
+
dtype = torch.bfloat16
|
39 |
+
|
40 |
+
# Helper functions
|
41 |
+
|
42 |
+
def cleanup_old_files(directory, age_in_seconds = 600):
|
43 |
+
"""
|
44 |
+
Deletes files older than a certain age in the specified directory.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
directory (str): The directory to clean up.
|
48 |
+
age_in_seconds (int): The age in seconds; files older than this will be deleted.
|
49 |
+
"""
|
50 |
+
now = time.time()
|
51 |
+
for filename in os.listdir(directory):
|
52 |
+
file_path = os.path.join(directory, filename)
|
53 |
+
# Only delete files (not directories)
|
54 |
+
if os.path.isfile(file_path):
|
55 |
+
file_age = now - os.path.getmtime(file_path)
|
56 |
+
if file_age > age_in_seconds:
|
57 |
+
try:
|
58 |
+
os.remove(file_path)
|
59 |
+
print(f"Deleted old file: {file_path}")
|
60 |
+
except Exception as e:
|
61 |
+
print(f"Error deleting file {file_path}: {e}")
|
62 |
+
|
63 |
+
def load_frames(image: Image, mode='RGBA'):
|
64 |
+
return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
|
65 |
+
|
66 |
+
def save_gif(frames, path):
|
67 |
+
imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1/10)
|
68 |
+
|
69 |
+
def load_image(imgname, target_size=None):
|
70 |
+
pil_img = Image.open(imgname).convert('RGB')
|
71 |
+
if target_size:
|
72 |
+
if isinstance(target_size, int):
|
73 |
+
target_size = (target_size, target_size)
|
74 |
+
pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
|
75 |
+
return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0) # Add batch dimension
|
76 |
+
|
77 |
+
def prepare_latents(pipe, x_aug):
|
78 |
+
with torch.cuda.amp.autocast():
|
79 |
+
batch_size, num_frames, channels, height, width = x_aug.shape
|
80 |
+
x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
|
81 |
+
latents = pipe.vae.encode(x_aug).latent_dist.sample()
|
82 |
+
latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
|
83 |
+
latents = latents.permute(0, 2, 1, 3, 4)
|
84 |
+
return pipe.vae.config.scaling_factor * latents
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
88 |
+
input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
|
89 |
+
input_img = torch.cat(input_img, dim=1)
|
90 |
+
latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
|
91 |
+
inv.set_timesteps(25)
|
92 |
+
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
|
93 |
+
return torch.mean(id_latents, dim=2, keepdim=True)
|
94 |
+
|
95 |
+
def load_primary_models(pretrained_model_path):
|
96 |
+
return (
|
97 |
+
DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
|
98 |
+
CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
|
99 |
+
CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
|
100 |
+
AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
|
101 |
+
UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
def initialize_pipeline(model: str, device: str = "cuda"):
|
106 |
+
with warnings.catch_warnings():
|
107 |
+
warnings.simplefilter("ignore")
|
108 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
109 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
110 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
111 |
+
scheduler=scheduler,
|
112 |
+
tokenizer=tokenizer,
|
113 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
|
114 |
+
vae=vae.to(device=device, dtype=torch.bfloat16),
|
115 |
+
unet=unet.to(device=device, dtype=torch.bfloat16),
|
116 |
+
)
|
117 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
118 |
+
return pipe, pipe.scheduler
|
119 |
+
|
120 |
+
pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
|
121 |
+
pipe = TextToVideoSDPipelineModded.from_pretrained(
|
122 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
123 |
+
scheduler=pipe_inversion.scheduler,
|
124 |
+
tokenizer=pipe_inversion.tokenizer,
|
125 |
+
text_encoder=pipe_inversion.text_encoder,
|
126 |
+
vae=pipe_inversion.vae,
|
127 |
+
unet=pipe_inversion.unet,
|
128 |
+
).to(device)
|
129 |
+
|
130 |
+
@torch.no_grad()
|
131 |
+
def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
|
132 |
+
pipe_inversion.to(device)
|
133 |
+
id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
|
134 |
+
latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
|
135 |
+
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
|
136 |
+
video_frames = pipe(
|
137 |
+
prompt=caption,
|
138 |
+
negative_prompt="",
|
139 |
+
num_frames=num_frames,
|
140 |
+
num_inference_steps=25,
|
141 |
+
inv_latents=latents,
|
142 |
+
guidance_scale=9,
|
143 |
+
generator=generator,
|
144 |
+
lambda_=lambda_,
|
145 |
+
).frames
|
146 |
+
try:
|
147 |
+
load_name = load_name.split("/")[-1]
|
148 |
+
except:
|
149 |
+
pass
|
150 |
+
gifs = []
|
151 |
+
for seed in range(num_seeds):
|
152 |
+
vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
|
153 |
+
gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
|
154 |
+
video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
|
155 |
+
VideoFileClip(vid_name).write_gif(gif_name)
|
156 |
+
with Image.open(gif_name) as im:
|
157 |
+
frames = load_frames(im)
|
158 |
+
|
159 |
+
frames_collect = np.empty((0, 1024, 1024), int)
|
160 |
+
for frame in frames:
|
161 |
+
frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
|
162 |
+
frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
|
163 |
+
|
164 |
+
_, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
165 |
+
|
166 |
+
frames_collect = np.append(frames_collect, [frame], axis=0)
|
167 |
+
|
168 |
+
save_gif(frames_collect, gif_name)
|
169 |
+
gifs.append(gif_name)
|
170 |
+
|
171 |
+
return gifs
|
172 |
+
|
173 |
+
|
174 |
+
def generate_gifs(filepath, prompt, num_seeds=5, lambda_=0):
|
175 |
+
exp_dir = "static/app_tmp"
|
176 |
+
os.makedirs(exp_dir, exist_ok=True)
|
177 |
+
gifs = process(
|
178 |
+
num_frames=10,
|
179 |
+
num_seeds=num_seeds,
|
180 |
+
generator=None,
|
181 |
+
exp_dir=exp_dir,
|
182 |
+
load_name=filepath,
|
183 |
+
caption=prompt,
|
184 |
+
lambda_=lambda_
|
185 |
+
)
|
186 |
+
return gifs
|
187 |
+
|
188 |
+
@app.route('/')
|
189 |
+
def index():
|
190 |
+
return render_template('index.html')
|
191 |
+
|
192 |
+
@app.route('/generate', methods=['POST'])
|
193 |
+
def generate():
|
194 |
+
|
195 |
+
directories_to_clean = [
|
196 |
+
app.config['UPLOAD_FOLDER'],
|
197 |
+
'static/app_tmp/mp4_logs',
|
198 |
+
'static/app_tmp/gif_logs',
|
199 |
+
'static/app_tmp/png_logs'
|
200 |
+
]
|
201 |
+
|
202 |
+
# Perform cleanup
|
203 |
+
os.makedirs('static/app_tmp', exist_ok=True)
|
204 |
+
for directory in directories_to_clean:
|
205 |
+
os.makedirs(directory, exist_ok=True) # Ensure the directory exists
|
206 |
+
cleanup_old_files(directory)
|
207 |
+
|
208 |
+
prompt = request.form.get('prompt', '')
|
209 |
+
num_gifs = int(request.form.get('seeds', 3))
|
210 |
+
lambda_value = 1 - float(request.form.get('lambda', 0.5))
|
211 |
+
selected_example = request.form.get('selected_example', None)
|
212 |
+
file = request.files.get('image')
|
213 |
+
|
214 |
+
if not file and not selected_example:
|
215 |
+
return jsonify({'error': 'No image file provided or example selected'}), 400
|
216 |
+
|
217 |
+
if selected_example:
|
218 |
+
# Use the selected example image
|
219 |
+
filepath = os.path.join('static', 'examples', selected_example)
|
220 |
+
unique_id = None # No need for unique ID
|
221 |
+
else:
|
222 |
+
# Save the uploaded image
|
223 |
+
unique_id = str(uuid.uuid4())
|
224 |
+
filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{unique_id}_uploaded_image.png")
|
225 |
+
file.save(filepath)
|
226 |
+
|
227 |
+
generated_gifs = generate_gifs(filepath, prompt, num_seeds=num_gifs, lambda_=lambda_value)
|
228 |
+
|
229 |
+
unique_id = str(uuid.uuid4())
|
230 |
+
# Append unique id to each gif path
|
231 |
+
for i in range(len(generated_gifs)):
|
232 |
+
os.rename(generated_gifs[i], f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif")
|
233 |
+
generated_gifs[i] = f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif"
|
234 |
+
# Move the generated gifs to the static folder
|
235 |
+
|
236 |
+
|
237 |
+
filtered_gifs = filter(generated_gifs, filepath)
|
238 |
+
return jsonify({'gifs': filtered_gifs, 'prompt': prompt})
|
239 |
+
|
240 |
+
if __name__ == '__main__':
|
241 |
+
|
242 |
+
|
243 |
+
app.run(debug=True)
|
environment.yml
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: flipsketch
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- _libgcc_mutex=0.1=main
|
9 |
+
- _openmp_mutex=5.1=1_gnu
|
10 |
+
- asttokens=2.4.1=pyhd8ed1ab_0
|
11 |
+
- blas=1.0=mkl
|
12 |
+
- brotli-python=1.0.9=py310hd8f1fbe_7
|
13 |
+
- bzip2=1.0.8=h7f98852_4
|
14 |
+
- ca-certificates=2024.2.2=hbcca054_0
|
15 |
+
- certifi=2024.2.2=pyhd8ed1ab_0
|
16 |
+
- charset-normalizer=2.0.4=pyhd8ed1ab_0
|
17 |
+
- comm=0.2.2=pyhd8ed1ab_0
|
18 |
+
- cuda=11.6.1=0
|
19 |
+
- cuda-cccl=11.6.55=hf6102b2_0
|
20 |
+
- cuda-command-line-tools=11.6.2=0
|
21 |
+
- cuda-compiler=11.6.2=0
|
22 |
+
- cuda-cudart=11.6.55=he381448_0
|
23 |
+
- cuda-cudart-dev=11.6.55=h42ad0f4_0
|
24 |
+
- cuda-cuobjdump=11.6.124=h2eeebcb_0
|
25 |
+
- cuda-cupti=11.6.124=h86345e5_0
|
26 |
+
- cuda-cuxxfilt=11.6.124=hecbf4f6_0
|
27 |
+
- cuda-driver-dev=11.6.55=0
|
28 |
+
- cuda-gdb=12.4.127=0
|
29 |
+
- cuda-libraries=11.6.1=0
|
30 |
+
- cuda-libraries-dev=11.6.1=0
|
31 |
+
- cuda-memcheck=11.8.86=0
|
32 |
+
- cuda-nsight=12.4.127=0
|
33 |
+
- cuda-nsight-compute=12.4.1=0
|
34 |
+
- cuda-nvcc=11.6.124=hbba6d2d_0
|
35 |
+
- cuda-nvdisasm=12.4.127=0
|
36 |
+
- cuda-nvml-dev=11.6.55=haa9ef22_0
|
37 |
+
- cuda-nvprof=12.4.127=0
|
38 |
+
- cuda-nvprune=11.6.124=he22ec0a_0
|
39 |
+
- cuda-nvrtc=11.6.124=h020bade_0
|
40 |
+
- cuda-nvrtc-dev=11.6.124=h249d397_0
|
41 |
+
- cuda-nvtx=11.6.124=h0630a44_0
|
42 |
+
- cuda-nvvp=12.4.127=0
|
43 |
+
- cuda-runtime=11.6.1=0
|
44 |
+
- cuda-samples=11.6.101=h8efea70_0
|
45 |
+
- cuda-sanitizer-api=12.4.127=0
|
46 |
+
- cuda-toolkit=11.6.1=0
|
47 |
+
- cuda-tools=11.6.1=0
|
48 |
+
- cuda-visual-tools=11.6.1=0
|
49 |
+
- debugpy=1.6.7=py310h6a678d5_0
|
50 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
51 |
+
- exceptiongroup=1.2.0=pyhd8ed1ab_2
|
52 |
+
- executing=2.0.1=pyhd8ed1ab_0
|
53 |
+
- ffmpeg=4.3=hf484d3e_0
|
54 |
+
- freetype=2.12.1=h4a9f257_0
|
55 |
+
- gds-tools=1.9.1.3=0
|
56 |
+
- gmp=6.2.1=h58526e2_0
|
57 |
+
- gnutls=3.6.15=he1e5248_0
|
58 |
+
- idna=3.4=pyhd8ed1ab_0
|
59 |
+
- intel-openmp=2023.1.0=hdb19cb5_46306
|
60 |
+
- ipykernel=6.29.3=pyhd33586a_0
|
61 |
+
- jedi=0.19.1=pyhd8ed1ab_0
|
62 |
+
- jpeg=9e=h166bdaf_1
|
63 |
+
- jupyter_client=7.3.4=pyhd8ed1ab_0
|
64 |
+
- jupyter_core=5.7.2=pyh31011fe_1
|
65 |
+
- lame=3.100=h7f98852_1001
|
66 |
+
- lcms2=2.12=h3be6417_0
|
67 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
68 |
+
- lerc=3.0=h9c3ff4c_0
|
69 |
+
- libcublas=11.9.2.110=h5e84587_0
|
70 |
+
- libcublas-dev=11.9.2.110=h5c901ab_0
|
71 |
+
- libcufft=10.7.1.112=hf425ae0_0
|
72 |
+
- libcufft-dev=10.7.1.112=ha5ce4c0_0
|
73 |
+
- libcufile=1.9.1.3=0
|
74 |
+
- libcufile-dev=1.9.1.3=0
|
75 |
+
- libcurand=10.3.5.147=0
|
76 |
+
- libcurand-dev=10.3.5.147=0
|
77 |
+
- libcusolver=11.3.4.124=h33c3c4e_0
|
78 |
+
- libcusparse=11.7.2.124=h7538f96_0
|
79 |
+
- libcusparse-dev=11.7.2.124=hbbe9722_0
|
80 |
+
- libdeflate=1.17=h5eee18b_1
|
81 |
+
- libffi=3.4.4=h6a678d5_1
|
82 |
+
- libgcc-ng=11.2.0=h1234567_1
|
83 |
+
- libgomp=11.2.0=h1234567_1
|
84 |
+
- libiconv=1.16=h516909a_0
|
85 |
+
- libidn2=2.3.4=h5eee18b_0
|
86 |
+
- libnpp=11.6.3.124=hd2722f0_0
|
87 |
+
- libnpp-dev=11.6.3.124=h3c42840_0
|
88 |
+
- libnvjpeg=11.6.2.124=hd473ad6_0
|
89 |
+
- libnvjpeg-dev=11.6.2.124=hb5906b9_0
|
90 |
+
- libpng=1.6.39=h5eee18b_0
|
91 |
+
- libsodium=1.0.18=h36c2ea0_1
|
92 |
+
- libstdcxx-ng=11.2.0=he4da1e4_16
|
93 |
+
- libtasn1=4.19.0=h5eee18b_0
|
94 |
+
- libtiff=4.5.1=h6a678d5_0
|
95 |
+
- libunistring=0.9.10=h7f98852_0
|
96 |
+
- libuuid=1.41.5=h5eee18b_0
|
97 |
+
- libwebp-base=1.3.2=h5eee18b_0
|
98 |
+
- lz4-c=1.9.4=h6a678d5_1
|
99 |
+
- mkl=2023.1.0=h213fc3f_46344
|
100 |
+
- mkl-service=2.4.0=py310h5eee18b_1
|
101 |
+
- mkl_fft=1.3.8=py310h5eee18b_0
|
102 |
+
- mkl_random=1.2.4=py310hdb19cb5_0
|
103 |
+
- ncurses=6.4=h6a678d5_0
|
104 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_0
|
105 |
+
- nettle=3.7.3=hbbd107a_1
|
106 |
+
- nsight-compute=2024.1.1.4=0
|
107 |
+
- numpy-base=1.26.4=py310hb5e798b_0
|
108 |
+
- openh264=2.1.1=h780b84a_0
|
109 |
+
- openjpeg=2.4.0=h9ca470c_2
|
110 |
+
- openssl=3.0.13=h7f8727e_2
|
111 |
+
- packaging=24.0=pyhd8ed1ab_0
|
112 |
+
- parso=0.8.4=pyhd8ed1ab_0
|
113 |
+
- pexpect=4.9.0=pyhd8ed1ab_0
|
114 |
+
- pickleshare=0.7.5=py_1003
|
115 |
+
- pip=23.3.1=pyhd8ed1ab_0
|
116 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
117 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
118 |
+
- pygments=2.17.2=pyhd8ed1ab_0
|
119 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
120 |
+
- python=3.10.14=h955ad1f_0
|
121 |
+
- python_abi=3.10=2_cp310
|
122 |
+
- pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
|
123 |
+
- pytorch-cuda=11.6=h867d48c_1
|
124 |
+
- pytorch-mutex=1.0=cuda
|
125 |
+
- pyzmq=25.1.2=py310h6a678d5_0
|
126 |
+
- readline=8.2=h5eee18b_0
|
127 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
128 |
+
- setuptools=68.2.2=pyhd8ed1ab_0
|
129 |
+
- six=1.16.0=pyh6c4a22f_0
|
130 |
+
- sqlite=3.41.2=h5eee18b_0
|
131 |
+
- tbb=2021.8.0=hdb19cb5_0
|
132 |
+
- tk=8.6.12=h1ccaba5_0
|
133 |
+
- torchaudio=0.13.1=py310_cu116
|
134 |
+
- tornado=6.1=py310h5764c6d_3
|
135 |
+
- typing_extensions=4.9.0=pyha770c72_0
|
136 |
+
- tzdata=2024a=h8827d51_1
|
137 |
+
- urllib3=2.1.0=pyhd8ed1ab_0
|
138 |
+
- wcwidth=0.2.13=pyhd8ed1ab_0
|
139 |
+
- wheel=0.41.2=pyhd8ed1ab_0
|
140 |
+
- xz=5.4.6=h5eee18b_1
|
141 |
+
- zeromq=4.3.5=h6a678d5_0
|
142 |
+
- zlib=1.2.13=h5eee18b_1
|
143 |
+
- zstd=1.5.5=hc292b87_2
|
144 |
+
- pip:
|
145 |
+
- absl-py==2.1.0
|
146 |
+
- accelerate==0.29.2
|
147 |
+
- addict==2.4.0
|
148 |
+
- aiofiles==23.2.1
|
149 |
+
- aiohttp==3.9.3
|
150 |
+
- aiosignal==1.3.1
|
151 |
+
- albumentations==1.3.0
|
152 |
+
- aliyun-python-sdk-core==2.15.1
|
153 |
+
- aliyun-python-sdk-kms==2.16.2
|
154 |
+
- annotated-types==0.7.0
|
155 |
+
- antlr4-python3-runtime==4.8
|
156 |
+
- anyio==4.6.2.post1
|
157 |
+
- appdirs==1.4.4
|
158 |
+
- async-timeout==4.0.3
|
159 |
+
- attrs==23.2.0
|
160 |
+
- basicsr==1.4.2
|
161 |
+
- beautifulsoup4==4.12.3
|
162 |
+
- bitsandbytes==0.35.4
|
163 |
+
- black==21.4b2
|
164 |
+
- blinker==1.8.2
|
165 |
+
- blis==0.7.11
|
166 |
+
- boto3==1.34.97
|
167 |
+
- botocore==1.34.97
|
168 |
+
- bresenham==0.2.1
|
169 |
+
- cachetools==5.3.3
|
170 |
+
- captum==0.7.0
|
171 |
+
- catalogue==2.0.10
|
172 |
+
- cffi==1.16.0
|
173 |
+
- chardet==5.2.0
|
174 |
+
- click==8.1.7
|
175 |
+
- clip==0.1.0
|
176 |
+
- cloudpickle==3.0.0
|
177 |
+
- cmake==3.25.2
|
178 |
+
- compel==2.0.3
|
179 |
+
- confection==0.1.4
|
180 |
+
- contourpy==1.2.1
|
181 |
+
- controlnet-aux==0.0.6
|
182 |
+
- crcmod==1.7
|
183 |
+
- cryptography==42.0.7
|
184 |
+
- cssselect2==0.7.0
|
185 |
+
- cycler==0.12.1
|
186 |
+
- cymem==2.0.8
|
187 |
+
- cython==3.0.10
|
188 |
+
- datasets==2.18.0
|
189 |
+
- decorator==4.4.2
|
190 |
+
- decord==0.6.0
|
191 |
+
- deepspeed==0.8.0
|
192 |
+
- diffdist==0.1
|
193 |
+
- diffusers==0.27.2
|
194 |
+
- dill==0.3.8
|
195 |
+
- docker-pycreds==0.4.0
|
196 |
+
- easydict==1.10
|
197 |
+
- einops==0.3.0
|
198 |
+
- fairscale==0.4.13
|
199 |
+
- faiss-cpu==1.8.0
|
200 |
+
- fastapi==0.115.4
|
201 |
+
- ffmpy==0.3.0
|
202 |
+
- filelock==3.13.4
|
203 |
+
- flask==3.0.3
|
204 |
+
- flatbuffers==24.3.25
|
205 |
+
- fonttools==4.51.0
|
206 |
+
- frozenlist==1.4.1
|
207 |
+
- fsspec==2024.2.0
|
208 |
+
- ftfy==6.1.1
|
209 |
+
- future==1.0.0
|
210 |
+
- fvcore==0.1.5.post20221221
|
211 |
+
- gast==0.5.4
|
212 |
+
- gdown==5.1.0
|
213 |
+
- gitdb==4.0.11
|
214 |
+
- gitpython==3.1.43
|
215 |
+
- google-auth==2.29.0
|
216 |
+
- google-auth-oauthlib==0.4.6
|
217 |
+
- gradio==5.5.0
|
218 |
+
- gradio-client==1.4.2
|
219 |
+
- grpcio==1.62.1
|
220 |
+
- h11==0.14.0
|
221 |
+
- hjson==3.1.0
|
222 |
+
- httpcore==1.0.6
|
223 |
+
- httpx==0.27.2
|
224 |
+
- huggingface-hub==0.25.2
|
225 |
+
- hydra-core==1.1.1
|
226 |
+
- imageio==2.25.1
|
227 |
+
- imageio-ffmpeg==0.4.8
|
228 |
+
- importlib-metadata==7.1.0
|
229 |
+
- inquirerpy==0.3.4
|
230 |
+
- iopath==0.1.9
|
231 |
+
- ipdb==0.13.13
|
232 |
+
- ipympl==0.9.4
|
233 |
+
- ipython==8.23.0
|
234 |
+
- ipython-genutils==0.2.0
|
235 |
+
- ipywidgets==8.1.2
|
236 |
+
- itsdangerous==2.2.0
|
237 |
+
- jax==0.4.26
|
238 |
+
- jaxlib==0.4.26
|
239 |
+
- jinja2==3.1.3
|
240 |
+
- jmespath==0.10.0
|
241 |
+
- joblib==1.4.2
|
242 |
+
- jupyterlab-widgets==3.0.10
|
243 |
+
- kiwisolver==1.4.5
|
244 |
+
- kornia==0.6.0
|
245 |
+
- lightning-utilities==0.11.2
|
246 |
+
- lmdb==1.4.1
|
247 |
+
- loguru==0.7.2
|
248 |
+
- loralib==0.1.2
|
249 |
+
- lvis==0.5.3
|
250 |
+
- lxml==5.2.1
|
251 |
+
- markdown==3.6
|
252 |
+
- markdown-it-py==3.0.0
|
253 |
+
- markupsafe==2.1.5
|
254 |
+
- matplotlib==3.8.4
|
255 |
+
- matplotlib-inline==0.1.6
|
256 |
+
- mdurl==0.1.2
|
257 |
+
- mediapipe==0.10.11
|
258 |
+
- ml-dtypes==0.4.0
|
259 |
+
- modelcards==0.1.6
|
260 |
+
- modelscope==1.14.0
|
261 |
+
- motion-vector-extractor==1.0.6
|
262 |
+
- moviepy==1.0.3
|
263 |
+
- mpmath==1.3.0
|
264 |
+
- multidict==6.0.5
|
265 |
+
- multiprocess==0.70.16
|
266 |
+
- murmurhash==1.0.10
|
267 |
+
- mypy-extensions==1.0.0
|
268 |
+
- networkx==3.3
|
269 |
+
- ninja==1.11.1.1
|
270 |
+
- nltk==3.8.1
|
271 |
+
- numpy==1.24.2
|
272 |
+
- nvidia-cublas-cu11==11.10.3.66
|
273 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
274 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
275 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
276 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
277 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
278 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
279 |
+
- nvidia-cufft-cu12==11.0.2.54
|
280 |
+
- nvidia-curand-cu12==10.3.2.106
|
281 |
+
- nvidia-nccl-cu12==2.20.5
|
282 |
+
- nvidia-nvjitlink-cu12==12.6.77
|
283 |
+
- nvidia-nvtx-cu12==12.1.105
|
284 |
+
- oauthlib==3.2.2
|
285 |
+
- omegaconf==2.1.1
|
286 |
+
- open-clip-torch==2.0.2
|
287 |
+
- opencv-contrib-python==4.9.0.80
|
288 |
+
- opencv-python==4.6.0.66
|
289 |
+
- opencv-python-headless==4.9.0.80
|
290 |
+
- opt-einsum==3.3.0
|
291 |
+
- orjson==3.10.11
|
292 |
+
- oss2==2.18.5
|
293 |
+
- pandas==1.5.3
|
294 |
+
- pathspec==0.12.1
|
295 |
+
- pathtools==0.1.2
|
296 |
+
- peft==0.10.0
|
297 |
+
- pfzy==0.3.4
|
298 |
+
- pillow==9.5.0
|
299 |
+
- pkgconfig==1.5.5
|
300 |
+
- platformdirs==4.2.0
|
301 |
+
- portalocker==2.8.2
|
302 |
+
- preshed==3.0.9
|
303 |
+
- proglog==0.1.10
|
304 |
+
- prompt-toolkit==3.0.43
|
305 |
+
- protobuf==3.20.3
|
306 |
+
- psutil==5.9.8
|
307 |
+
- py-cpuinfo==9.0.0
|
308 |
+
- pyarrow==15.0.2
|
309 |
+
- pyarrow-hotfix==0.6
|
310 |
+
- pyasn1==0.6.0
|
311 |
+
- pyasn1-modules==0.4.0
|
312 |
+
- pyav==12.0.5
|
313 |
+
- pycocotools==2.0.7
|
314 |
+
- pycparser==2.22
|
315 |
+
- pycryptodome==3.20.0
|
316 |
+
- pydantic==2.9.2
|
317 |
+
- pydantic-core==2.23.4
|
318 |
+
- pydeprecate==0.3.1
|
319 |
+
- pydot==2.0.0
|
320 |
+
- pydub==0.25.1
|
321 |
+
- pynvml==11.5.3
|
322 |
+
- pyparsing==3.1.2
|
323 |
+
- pyre-extensions==0.0.23
|
324 |
+
- python-dateutil==2.9.0.post0
|
325 |
+
- python-multipart==0.0.12
|
326 |
+
- pytorch-lightning==1.4.2
|
327 |
+
- pytz==2024.1
|
328 |
+
- pywavelets==1.6.0
|
329 |
+
- pyyaml==6.0.1
|
330 |
+
- qudida==0.0.4
|
331 |
+
- regex==2024.4.16
|
332 |
+
- reportlab==4.1.0
|
333 |
+
- requests-oauthlib==2.0.0
|
334 |
+
- rich==13.9.4
|
335 |
+
- rsa==4.9
|
336 |
+
- ruff==0.7.2
|
337 |
+
- s3transfer==0.10.1
|
338 |
+
- safehttpx==0.1.1
|
339 |
+
- safetensors==0.4.2
|
340 |
+
- scikit-image==0.19.3
|
341 |
+
- scikit-learn==1.4.2
|
342 |
+
- scikit-video==1.1.11
|
343 |
+
- scipy==1.10.1
|
344 |
+
- semantic-version==2.10.0
|
345 |
+
- sentry-sdk==1.44.1
|
346 |
+
- setproctitle==1.3.3
|
347 |
+
- shapely==2.0.3
|
348 |
+
- shellingham==1.5.4
|
349 |
+
- simplejson==3.19.2
|
350 |
+
- smmap==5.0.1
|
351 |
+
- sniffio==1.3.1
|
352 |
+
- sortedcontainers==2.4.0
|
353 |
+
- sounddevice==0.4.6
|
354 |
+
- soupsieve==2.5
|
355 |
+
- srsly==2.4.8
|
356 |
+
- stable-diffusion-sdkit==2.1.3
|
357 |
+
- stack-data==0.6.3
|
358 |
+
- starlette==0.41.2
|
359 |
+
- svg-path==6.3
|
360 |
+
- svglib==1.5.1
|
361 |
+
- svgpathtools==1.6.1
|
362 |
+
- svgwrite==1.4.3
|
363 |
+
- sympy==1.13.3
|
364 |
+
- tabulate==0.9.0
|
365 |
+
- tb-nightly==2.17.0a20240408
|
366 |
+
- tensorboard==2.12.0
|
367 |
+
- tensorboard-data-server==0.7.0
|
368 |
+
- tensorboard-plugin-wit==1.8.1
|
369 |
+
- termcolor==2.2.0
|
370 |
+
- test-tube==0.7.5
|
371 |
+
- thinc==8.1.10
|
372 |
+
- threadpoolctl==3.5.0
|
373 |
+
- tifffile==2024.2.12
|
374 |
+
- timm==0.6.11
|
375 |
+
- tinycss2==1.2.1
|
376 |
+
- tokenizers==0.20.1
|
377 |
+
- toml==0.10.2
|
378 |
+
- tomli==2.0.1
|
379 |
+
- tomlkit==0.12.0
|
380 |
+
- torch==1.13.1
|
381 |
+
- torchmetrics==0.6.0
|
382 |
+
- torchsummary==1.5.1
|
383 |
+
- torchvision==0.14.1
|
384 |
+
- tqdm==4.64.1
|
385 |
+
- traitlets==5.14.2
|
386 |
+
- transformers==4.45.2
|
387 |
+
- triton==2.3.0
|
388 |
+
- typer==0.12.5
|
389 |
+
- typing-inspect==0.9.0
|
390 |
+
- uvicorn==0.32.0
|
391 |
+
- wandb==0.16.6
|
392 |
+
- wasabi==1.1.2
|
393 |
+
- webencodings==0.5.1
|
394 |
+
- websockets==12.0
|
395 |
+
- werkzeug==3.0.2
|
396 |
+
- widgetsnbextension==4.0.10
|
397 |
+
- xformers==0.0.16
|
398 |
+
- xxhash==3.4.1
|
399 |
+
- yacs==0.1.8
|
400 |
+
- yapf==0.40.2
|
401 |
+
- yarl==1.9.4
|
402 |
+
- zipp==3.18.1
|
gifs_filter.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# filter images
|
2 |
+
from PIL import Image, ImageSequence
|
3 |
+
import requests
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPProcessor, CLIPModel
|
8 |
+
|
9 |
+
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
10 |
+
converted_len = int(clip_len * frame_sample_rate)
|
11 |
+
end_idx = np.random.randint(converted_len, seg_len)
|
12 |
+
start_idx = end_idx - converted_len
|
13 |
+
indices = np.linspace(start_idx, end_idx, num=clip_len)
|
14 |
+
indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
15 |
+
return indices
|
16 |
+
|
17 |
+
def load_frames(image: Image, mode='RGBA'):
|
18 |
+
return np.array([
|
19 |
+
np.array(frame.convert(mode))
|
20 |
+
for frame in ImageSequence.Iterator(image)
|
21 |
+
])
|
22 |
+
|
23 |
+
img_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
24 |
+
img_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def filter(gifs, input_image):
|
29 |
+
max_cosine = 0.9
|
30 |
+
max_gif = []
|
31 |
+
|
32 |
+
for gif in tqdm(gifs, total=len(gifs)):
|
33 |
+
with Image.open(gif) as im:
|
34 |
+
frames = load_frames(im)
|
35 |
+
|
36 |
+
frames = np.array(frames)
|
37 |
+
frames = frames[:, :, :, :3]
|
38 |
+
frames = np.transpose(frames, (0, 3, 1, 2))[1:]
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
image = Image.open(input_image)
|
43 |
+
|
44 |
+
|
45 |
+
inputs = img_processor(images=frames, return_tensors="pt", padding=False)
|
46 |
+
inputs_base = img_processor(images=image, return_tensors="pt", padding=False)
|
47 |
+
|
48 |
+
with torch.no_grad():
|
49 |
+
feat_img_base = img_model.get_image_features(pixel_values=inputs_base["pixel_values"])
|
50 |
+
feat_img_vid = img_model.get_image_features(pixel_values=inputs["pixel_values"])
|
51 |
+
cos_avg = 0
|
52 |
+
avg_score_for_vid = 0
|
53 |
+
for i in range(len(feat_img_vid)):
|
54 |
+
|
55 |
+
cosine_similarity = torch.nn.functional.cosine_similarity(
|
56 |
+
feat_img_base,
|
57 |
+
feat_img_vid[0].unsqueeze(0),
|
58 |
+
dim=1)
|
59 |
+
# print(cosine_similarity)
|
60 |
+
cos_avg += cosine_similarity.item()
|
61 |
+
|
62 |
+
cos_avg /= len(feat_img_vid)
|
63 |
+
print("Current cosine similarity: ", cos_avg)
|
64 |
+
print("Max cosine similarity: ", max_cosine)
|
65 |
+
if cos_avg > max_cosine:
|
66 |
+
# max_cosine = cos_avg
|
67 |
+
max_gif.append(gif)
|
68 |
+
return max_gif
|
invert_utils.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
|
13 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
|
14 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
15 |
+
outputs = []
|
16 |
+
for x in videos:
|
17 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
18 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
19 |
+
if rescale:
|
20 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
21 |
+
x = (x * 255).numpy().astype(np.uint8)
|
22 |
+
outputs.append(x)
|
23 |
+
|
24 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
25 |
+
imageio.mimsave(path, outputs, fps=fps)
|
26 |
+
|
27 |
+
|
28 |
+
# DDIM Inversion
|
29 |
+
@torch.no_grad()
|
30 |
+
def init_prompt(prompt, pipeline):
|
31 |
+
uncond_input = pipeline.tokenizer(
|
32 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
33 |
+
return_tensors="pt"
|
34 |
+
)
|
35 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
36 |
+
text_input = pipeline.tokenizer(
|
37 |
+
[prompt],
|
38 |
+
padding="max_length",
|
39 |
+
max_length=pipeline.tokenizer.model_max_length,
|
40 |
+
truncation=True,
|
41 |
+
return_tensors="pt",
|
42 |
+
)
|
43 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
44 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
45 |
+
|
46 |
+
return context
|
47 |
+
|
48 |
+
|
49 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
50 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
51 |
+
timestep, next_timestep = min(
|
52 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
53 |
+
# try:
|
54 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
55 |
+
# except:
|
56 |
+
# alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] #if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
57 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
58 |
+
beta_prod_t = 1 - alpha_prod_t
|
59 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
60 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
61 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
62 |
+
return next_sample
|
63 |
+
|
64 |
+
|
65 |
+
def get_noise_pred_single(latents, t, context, unet):
|
66 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
67 |
+
return noise_pred
|
68 |
+
|
69 |
+
|
70 |
+
@torch.no_grad()
|
71 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
72 |
+
context = init_prompt(prompt, pipeline)
|
73 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
74 |
+
all_latent = [latent]
|
75 |
+
latent = latent.clone().detach()
|
76 |
+
for i in tqdm(range(num_inv_steps)):
|
77 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
78 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
79 |
+
noise_pred_unc = get_noise_pred_single(latent, t, uncond_embeddings, pipeline.unet)
|
80 |
+
noise_pred = noise_pred_unc + 9.0 * (noise_pred_unc - noise_pred)
|
81 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
82 |
+
all_latent.append(latent)
|
83 |
+
return all_latent
|
84 |
+
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
88 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
89 |
+
return ddim_latents
|
read_vids.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import imageio.v3 as iio
|
2 |
+
import os
|
3 |
+
from sys import argv
|
4 |
+
video_name = argv[1]
|
5 |
+
|
6 |
+
video = video_name
|
7 |
+
video_id = video.split("/")[-1].replace(".mp4","")
|
8 |
+
|
9 |
+
|
10 |
+
png_base = "png_logs"
|
11 |
+
try:
|
12 |
+
os.mkdir(png_base)
|
13 |
+
except:
|
14 |
+
pass
|
15 |
+
|
16 |
+
video_id = os.path.join(png_base, video_id)
|
17 |
+
all_frames = list(iio.imiter(video))
|
18 |
+
|
19 |
+
ctr = 0
|
20 |
+
try:
|
21 |
+
os.makedirs(video_id)
|
22 |
+
except:
|
23 |
+
pass
|
24 |
+
for idx, frame in enumerate(all_frames):
|
25 |
+
|
26 |
+
iio.imwrite(f"{video_id}/{ctr:03d}.jpg", frame)
|
27 |
+
ctr += 1
|
requirements.txt
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.29.2
|
2 |
+
blinker==1.9.0
|
3 |
+
certifi==2024.8.30
|
4 |
+
charset-normalizer==3.4.0
|
5 |
+
click==8.1.7
|
6 |
+
decorator==4.4.2
|
7 |
+
diffusers==0.27.2
|
8 |
+
einops==0.8.0
|
9 |
+
filelock==3.16.1
|
10 |
+
Flask==3.0.3
|
11 |
+
fsspec==2024.10.0
|
12 |
+
huggingface-hub==0.25.2
|
13 |
+
idna==3.10
|
14 |
+
imageio==2.36.0
|
15 |
+
imageio-ffmpeg==0.5.1
|
16 |
+
importlib_metadata==8.5.0
|
17 |
+
itsdangerous==2.2.0
|
18 |
+
Jinja2==3.1.4
|
19 |
+
MarkupSafe==3.0.2
|
20 |
+
moviepy==1.0.3
|
21 |
+
numpy==1.24.2
|
22 |
+
nvidia-cublas-cu11==11.10.3.66
|
23 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
24 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
25 |
+
nvidia-cudnn-cu11==8.5.0.96
|
26 |
+
opencv-python==4.10.0.84
|
27 |
+
packaging==24.2
|
28 |
+
pillow==10.4.0
|
29 |
+
proglog==0.1.10
|
30 |
+
psutil==6.1.0
|
31 |
+
python-dotenv==1.0.1
|
32 |
+
PyYAML==6.0.2
|
33 |
+
regex==2024.11.6
|
34 |
+
requests==2.32.3
|
35 |
+
safetensors==0.4.5
|
36 |
+
tokenizers==0.20.3
|
37 |
+
torch==1.13.1
|
38 |
+
torchvision==0.14.1
|
39 |
+
tqdm==4.67.0
|
40 |
+
transformers==4.45.2
|
41 |
+
typing_extensions==4.12.2
|
42 |
+
urllib3==2.2.3
|
43 |
+
Werkzeug==3.1.3
|
44 |
+
zipp==3.21.0
|
static/app_tmp/gif_logs/vid_sketch10-rand0_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_dfcba486-0d8c-4d68-9689-97f1fb889213.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_dfcba486-0d8c-4d68-9689-97f1fb889213.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_dfcba486-0d8c-4d68-9689-97f1fb889213.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_dfcba486-0d8c-4d68-9689-97f1fb889213.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_dfcba486-0d8c-4d68-9689-97f1fb889213.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand5_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand6_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand7_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand8_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch10-rand9_508fa599-d685-462e-ad06-11ca4fd15d6f.gif
ADDED
static/app_tmp/gif_logs/vid_sketch3-rand0_875203a1-f830-46e7-a287-4a0bc2c3a648.gif
ADDED
static/app_tmp/gif_logs/vid_sketch3-rand1_875203a1-f830-46e7-a287-4a0bc2c3a648.gif
ADDED
static/app_tmp/gif_logs/vid_sketch3-rand2_875203a1-f830-46e7-a287-4a0bc2c3a648.gif
ADDED
static/app_tmp/gif_logs/vid_sketch3-rand3_875203a1-f830-46e7-a287-4a0bc2c3a648.gif
ADDED
static/app_tmp/gif_logs/vid_sketch3-rand4_875203a1-f830-46e7-a287-4a0bc2c3a648.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_47fc0372-4688-4a2a-abb3-817ccfee8816.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_77158110-9239-4771-bb44-a83c3aa47567.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_47fc0372-4688-4a2a-abb3-817ccfee8816.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_77158110-9239-4771-bb44-a83c3aa47567.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_47fc0372-4688-4a2a-abb3-817ccfee8816.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_77158110-9239-4771-bb44-a83c3aa47567.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_47fc0372-4688-4a2a-abb3-817ccfee8816.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_77158110-9239-4771-bb44-a83c3aa47567.gif
ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif
ADDED