Hmrishav committed on
Commit
b1350bf
1 Parent(s): a7b9df8

resolve deps

Files changed (50)
  1. README.md +11 -9
  2. app.py +271 -0
  3. app_full.py +243 -0
  4. environment.yml +402 -0
  5. gifs_filter.py +68 -0
  6. invert_utils.py +89 -0
  7. read_vids.py +27 -0
  8. requirements.txt +44 -0
  9. static/app_tmp/gif_logs/vid_sketch10-rand0_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
  10. static/app_tmp/gif_logs/vid_sketch10-rand0_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  11. static/app_tmp/gif_logs/vid_sketch10-rand0_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
  12. static/app_tmp/gif_logs/vid_sketch10-rand0_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
  13. static/app_tmp/gif_logs/vid_sketch10-rand1_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
  14. static/app_tmp/gif_logs/vid_sketch10-rand1_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  15. static/app_tmp/gif_logs/vid_sketch10-rand1_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
  16. static/app_tmp/gif_logs/vid_sketch10-rand1_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
  17. static/app_tmp/gif_logs/vid_sketch10-rand2_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
  18. static/app_tmp/gif_logs/vid_sketch10-rand2_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  19. static/app_tmp/gif_logs/vid_sketch10-rand2_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
  20. static/app_tmp/gif_logs/vid_sketch10-rand2_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
  21. static/app_tmp/gif_logs/vid_sketch10-rand3_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
  22. static/app_tmp/gif_logs/vid_sketch10-rand3_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  23. static/app_tmp/gif_logs/vid_sketch10-rand3_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
  24. static/app_tmp/gif_logs/vid_sketch10-rand3_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
  25. static/app_tmp/gif_logs/vid_sketch10-rand4_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif +0 -0
  26. static/app_tmp/gif_logs/vid_sketch10-rand4_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  27. static/app_tmp/gif_logs/vid_sketch10-rand4_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif +0 -0
  28. static/app_tmp/gif_logs/vid_sketch10-rand4_dfcba486-0d8c-4d68-9689-97f1fb889213.gif +0 -0
  29. static/app_tmp/gif_logs/vid_sketch10-rand5_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  30. static/app_tmp/gif_logs/vid_sketch10-rand6_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  31. static/app_tmp/gif_logs/vid_sketch10-rand7_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  32. static/app_tmp/gif_logs/vid_sketch10-rand8_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  33. static/app_tmp/gif_logs/vid_sketch10-rand9_508fa599-d685-462e-ad06-11ca4fd15d6f.gif +0 -0
  34. static/app_tmp/gif_logs/vid_sketch3-rand0_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
  35. static/app_tmp/gif_logs/vid_sketch3-rand1_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
  36. static/app_tmp/gif_logs/vid_sketch3-rand2_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
  37. static/app_tmp/gif_logs/vid_sketch3-rand3_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
  38. static/app_tmp/gif_logs/vid_sketch3-rand4_875203a1-f830-46e7-a287-4a0bc2c3a648.gif +0 -0
  39. static/app_tmp/gif_logs/vid_sketch8-rand0_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
  40. static/app_tmp/gif_logs/vid_sketch8-rand0_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
  41. static/app_tmp/gif_logs/vid_sketch8-rand0_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
  42. static/app_tmp/gif_logs/vid_sketch8-rand1_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
  43. static/app_tmp/gif_logs/vid_sketch8-rand1_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
  44. static/app_tmp/gif_logs/vid_sketch8-rand1_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
  45. static/app_tmp/gif_logs/vid_sketch8-rand2_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
  46. static/app_tmp/gif_logs/vid_sketch8-rand2_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
  47. static/app_tmp/gif_logs/vid_sketch8-rand2_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
  48. static/app_tmp/gif_logs/vid_sketch8-rand3_47fc0372-4688-4a2a-abb3-817ccfee8816.gif +0 -0
  49. static/app_tmp/gif_logs/vid_sketch8-rand3_77158110-9239-4771-bb44-a83c3aa47567.gif +0 -0
  50. static/app_tmp/gif_logs/vid_sketch8-rand3_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif +0 -0
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
-title: FlipSketch
-emoji: 🚀
-colorFrom: purple
-colorTo: green
-sdk: docker
-pinned: false
-license: mit
-short_description: Sketch Animations
+title: FlipSketch
+emoji: 🚀
+colorFrom: blue
+colorTo: green
+sdk: gradio
+app_file: app.py
+pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# FlipSketch
+
+FlipSketch: Flipping Static Drawings to Text-Guided Sketch Animations
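
For reference (not part of the commit): the metadata now selects `sdk: gradio` with `app_file: app.py`. On Spaces, the Gradio SDK runs app_file directly and expects it to define and launch a Gradio demo; a minimal, hypothetical entrypoint of that shape is sketched below — note that the app.py added in this commit is actually a Flask app, so this is only the shape the config implies, not the committed code.

import gradio as gr

def animate(sketch_path, prompt):
    # placeholder: the real app would run the sketch through the T2V pipeline
    return sketch_path

demo = gr.Interface(
    fn=animate,
    inputs=[gr.Image(type="filepath"), gr.Textbox(label="Prompt")],
    outputs=gr.Image(),
)

if __name__ == "__main__":
    demo.launch()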
app.py ADDED
@@ -0,0 +1,271 @@
+from flask import Flask, render_template, request, jsonify
+import os
+import cv2
+import subprocess
+import torch
+import torchvision
+import warnings
+import numpy as np
+from PIL import Image, ImageSequence
+from moviepy.editor import VideoFileClip
+import imageio
+import uuid
+
+from diffusers import (
+    TextToVideoSDPipeline,
+    AutoencoderKL,
+    DDPMScheduler,
+    DDIMScheduler,
+    UNet3DConditionModel,
+)
+import time
+from transformers import CLIPTokenizer, CLIPTextModel
+
+from diffusers.utils import export_to_video
+from gifs_filter import filter
+from invert_utils import ddim_inversion as dd_inversion
+from text2vid_modded import TextToVideoSDPipelineModded
+
+
+def run_setup():
+    try:
+        # Step 1: Install Git LFS
+        subprocess.run(["git", "lfs", "install"], check=True)
+
+        # Step 2: Clone the repository
+        repo_url = "https://huggingface.co/Hmrishav/t2v_sketch-lora"
+        subprocess.run(["git", "clone", repo_url], check=True)
+
+        # Step 3: Move the checkpoint file
+        source = "t2v_sketch-lora/checkpoint-2500"
+        destination = "./checkpoint-2500/"
+        os.rename(source, destination)
+
+        print("Setup completed successfully!")
+    except subprocess.CalledProcessError as e:
+        print(f"Error during setup: {e}")
+    except FileNotFoundError as e:
+        print(f"File operation error: {e}")
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+
+# Automatically run setup during app initialization
+run_setup()
+
+
+# Flask app setup
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'static/uploads'
+app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+
+# Environment setup
+os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
+LORA_CHECKPOINT = "checkpoint-2500"
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+dtype = torch.bfloat16
+
+# Helper functions
+
+def cleanup_old_files(directory, age_in_seconds=600):
+    """
+    Deletes files older than a certain age in the specified directory.
+
+    Args:
+        directory (str): The directory to clean up.
+        age_in_seconds (int): The age in seconds; files older than this will be deleted.
+    """
+    now = time.time()
+    for filename in os.listdir(directory):
+        file_path = os.path.join(directory, filename)
+        # Only delete files (not directories)
+        if os.path.isfile(file_path):
+            file_age = now - os.path.getmtime(file_path)
+            if file_age > age_in_seconds:
+                try:
+                    os.remove(file_path)
+                    print(f"Deleted old file: {file_path}")
+                except Exception as e:
+                    print(f"Error deleting file {file_path}: {e}")
+
+def load_frames(image: Image, mode='RGBA'):
+    return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
+
+def save_gif(frames, path):
+    imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1/10)
+
+def load_image(imgname, target_size=None):
+    pil_img = Image.open(imgname).convert('RGB')
+    if target_size:
+        if isinstance(target_size, int):
+            target_size = (target_size, target_size)
+        pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
+    return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)  # Add batch dimension
+
+def prepare_latents(pipe, x_aug):
+    with torch.cuda.amp.autocast():
+        batch_size, num_frames, channels, height, width = x_aug.shape
+        x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
+        latents = pipe.vae.encode(x_aug).latent_dist.sample()
+        latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
+        latents = latents.permute(0, 2, 1, 3, 4)
+        return pipe.vae.config.scaling_factor * latents
+
+@torch.no_grad()
+def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
+    input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
+    input_img = torch.cat(input_img, dim=1)
+    latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
+    inv.set_timesteps(25)
+    id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
+    return torch.mean(id_latents, dim=2, keepdim=True)
+
+def load_primary_models(pretrained_model_path):
+    return (
+        DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
+        CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
+        CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
+        AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
+        UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
+    )
+
+
+def initialize_pipeline(model: str, device: str = "cuda"):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
+        pipe = TextToVideoSDPipeline.from_pretrained(
+            pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+            scheduler=scheduler,
+            tokenizer=tokenizer,
+            text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
+            vae=vae.to(device=device, dtype=torch.bfloat16),
+            unet=unet.to(device=device, dtype=torch.bfloat16),
+        )
+    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    return pipe, pipe.scheduler
+
+pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
+pipe = TextToVideoSDPipelineModded.from_pretrained(
+    pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+    scheduler=pipe_inversion.scheduler,
+    tokenizer=pipe_inversion.tokenizer,
+    text_encoder=pipe_inversion.text_encoder,
+    vae=pipe_inversion.vae,
+    unet=pipe_inversion.unet,
+).to(device)
+
+@torch.no_grad()
+def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
+    pipe_inversion.to(device)
+    id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
+    latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
+    generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
+    video_frames = pipe(
+        prompt=caption,
+        negative_prompt="",
+        num_frames=num_frames,
+        num_inference_steps=25,
+        inv_latents=latents,
+        guidance_scale=9,
+        generator=generator,
+        lambda_=lambda_,
+    ).frames
+    try:
+        load_name = load_name.split("/")[-1]
+    except:
+        pass
+    gifs = []
+    for seed in range(num_seeds):
+        vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
+        gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
+        video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
+        VideoFileClip(vid_name).write_gif(gif_name)
+        with Image.open(gif_name) as im:
+            frames = load_frames(im)
+
+        frames_collect = np.empty((0, 1024, 1024), int)
+        for frame in frames:
+            frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
+            frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
+
+            _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+            frames_collect = np.append(frames_collect, [frame], axis=0)
+
+        save_gif(frames_collect, gif_name)
+        gifs.append(gif_name)
+
+    return gifs
+
+
+def generate_gifs(filepath, prompt, num_seeds=5, lambda_=0):
+    exp_dir = "static/app_tmp"
+    os.makedirs(exp_dir, exist_ok=True)
+    gifs = process(
+        num_frames=10,
+        num_seeds=num_seeds,
+        generator=None,
+        exp_dir=exp_dir,
+        load_name=filepath,
+        caption=prompt,
+        lambda_=lambda_
+    )
+    return gifs
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+@app.route('/generate', methods=['POST'])
+def generate():
+
+    directories_to_clean = [
+        app.config['UPLOAD_FOLDER'],
+        'static/app_tmp/mp4_logs',
+        'static/app_tmp/gif_logs',
+        'static/app_tmp/png_logs'
+    ]
+
+    # Perform cleanup
+    os.makedirs('static/app_tmp', exist_ok=True)
+    for directory in directories_to_clean:
+        os.makedirs(directory, exist_ok=True)  # Ensure the directory exists
+        cleanup_old_files(directory)
+
+    prompt = request.form.get('prompt', '')
+    num_gifs = int(request.form.get('seeds', 3))
+    lambda_value = 1 - float(request.form.get('lambda', 0.5))
+    selected_example = request.form.get('selected_example', None)
+    file = request.files.get('image')
+
+    if not file and not selected_example:
+        return jsonify({'error': 'No image file provided or example selected'}), 400
+
+    if selected_example:
+        # Use the selected example image
+        filepath = os.path.join('static', 'examples', selected_example)
+        unique_id = None  # No need for unique ID
+    else:
+        # Save the uploaded image
+        unique_id = str(uuid.uuid4())
+        filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{unique_id}_uploaded_image.png")
+        file.save(filepath)
+
+    generated_gifs = generate_gifs(filepath, prompt, num_seeds=num_gifs, lambda_=lambda_value)
+
+    unique_id = str(uuid.uuid4())
+    # Append unique id to each gif path
+    for i in range(len(generated_gifs)):
+        os.rename(generated_gifs[i], f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif")
+        generated_gifs[i] = f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif"
+    # Move the generated gifs to the static folder
+
+
+    filtered_gifs = filter(generated_gifs, filepath)
+    return jsonify({'gifs': filtered_gifs, 'prompt': prompt})
+
+if __name__ == '__main__':
+
+
+    app.run(debug=True)
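
For reference (not part of the commit): a minimal client sketch for the /generate route above. The form fields mirror the keys app.py reads (prompt, seeds, lambda, and either an image upload or selected_example); it assumes the Flask app is serving locally on its default port 5000, and the sketch path is only an illustration.

import requests

with open("static/examples/sketch10.png", "rb") as f:  # illustrative path
    resp = requests.post(
        "http://127.0.0.1:5000/generate",
        data={"prompt": "a cat stretching", "seeds": "3", "lambda": "0.5"},
        files={"image": ("sketch.png", f, "image/png")},
    )
print(resp.json())  # expected shape: {"gifs": [...], "prompt": "..."}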
app_full.py ADDED
@@ -0,0 +1,243 @@
+from flask import Flask, render_template, request, jsonify
+import os
+import cv2
+import torch
+import torchvision
+import warnings
+import numpy as np
+from PIL import Image, ImageSequence
+from moviepy.editor import VideoFileClip
+import imageio
+import uuid
+
+from diffusers import (
+    TextToVideoSDPipeline,
+    AutoencoderKL,
+    DDPMScheduler,
+    DDIMScheduler,
+    UNet3DConditionModel,
+)
+import time
+from transformers import CLIPTokenizer, CLIPTextModel
+
+from diffusers.utils import export_to_video
+from gifs_filter import filter
+from invert_utils import ddim_inversion as dd_inversion
+from text2vid_modded_full import TextToVideoSDPipelineModded
+
+# Flask app setup
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'static/uploads'
+app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+
+# Environment setup
+os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
+LORA_CHECKPOINT = "checkpoint-2500"
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+dtype = torch.bfloat16
+
+# Helper functions
+
+def cleanup_old_files(directory, age_in_seconds=600):
+    """
+    Deletes files older than a certain age in the specified directory.
+
+    Args:
+        directory (str): The directory to clean up.
+        age_in_seconds (int): The age in seconds; files older than this will be deleted.
+    """
+    now = time.time()
+    for filename in os.listdir(directory):
+        file_path = os.path.join(directory, filename)
+        # Only delete files (not directories)
+        if os.path.isfile(file_path):
+            file_age = now - os.path.getmtime(file_path)
+            if file_age > age_in_seconds:
+                try:
+                    os.remove(file_path)
+                    print(f"Deleted old file: {file_path}")
+                except Exception as e:
+                    print(f"Error deleting file {file_path}: {e}")
+
+def load_frames(image: Image, mode='RGBA'):
+    return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
+
+def save_gif(frames, path):
+    imageio.mimsave(path, [frame.astype(np.uint8) for frame in frames], format='GIF', duration=1/10)
+
+def load_image(imgname, target_size=None):
+    pil_img = Image.open(imgname).convert('RGB')
+    if target_size:
+        if isinstance(target_size, int):
+            target_size = (target_size, target_size)
+        pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
+    return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)  # Add batch dimension
+
+def prepare_latents(pipe, x_aug):
+    with torch.cuda.amp.autocast():
+        batch_size, num_frames, channels, height, width = x_aug.shape
+        x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
+        latents = pipe.vae.encode(x_aug).latent_dist.sample()
+        latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
+        latents = latents.permute(0, 2, 1, 3, 4)
+        return pipe.vae.config.scaling_factor * latents
+
+@torch.no_grad()
+def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
+    input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
+    input_img = torch.cat(input_img, dim=1)
+    latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
+    inv.set_timesteps(25)
+    id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
+    return torch.mean(id_latents, dim=2, keepdim=True)
+
+def load_primary_models(pretrained_model_path):
+    return (
+        DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
+        CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
+        CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
+        AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
+        UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
+    )
+
+
+def initialize_pipeline(model: str, device: str = "cuda"):
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
+        pipe = TextToVideoSDPipeline.from_pretrained(
+            pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+            scheduler=scheduler,
+            tokenizer=tokenizer,
+            text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
+            vae=vae.to(device=device, dtype=torch.bfloat16),
+            unet=unet.to(device=device, dtype=torch.bfloat16),
+        )
+    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    return pipe, pipe.scheduler
+
+pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
+pipe = TextToVideoSDPipelineModded.from_pretrained(
+    pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
+    scheduler=pipe_inversion.scheduler,
+    tokenizer=pipe_inversion.tokenizer,
+    text_encoder=pipe_inversion.text_encoder,
+    vae=pipe_inversion.vae,
+    unet=pipe_inversion.unet,
+).to(device)
+
+@torch.no_grad()
+def process(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
+    pipe_inversion.to(device)
+    id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
+    latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
+    generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
+    video_frames = pipe(
+        prompt=caption,
+        negative_prompt="",
+        num_frames=num_frames,
+        num_inference_steps=25,
+        inv_latents=latents,
+        guidance_scale=9,
+        generator=generator,
+        lambda_=lambda_,
+    ).frames
+    try:
+        load_name = load_name.split("/")[-1]
+    except:
+        pass
+    gifs = []
+    for seed in range(num_seeds):
+        vid_name = f"{exp_dir}/mp4_logs/vid_{load_name[:-4]}-rand{seed}.mp4"
+        gif_name = f"{exp_dir}/gif_logs/vid_{load_name[:-4]}-rand{seed}.gif"
+        video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
+        VideoFileClip(vid_name).write_gif(gif_name)
+        with Image.open(gif_name) as im:
+            frames = load_frames(im)
+
+        frames_collect = np.empty((0, 1024, 1024), int)
+        for frame in frames:
+            frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
+            frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
+
+            _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+            frames_collect = np.append(frames_collect, [frame], axis=0)
+
+        save_gif(frames_collect, gif_name)
+        gifs.append(gif_name)
+
+    return gifs
+
+
+def generate_gifs(filepath, prompt, num_seeds=5, lambda_=0):
+    exp_dir = "static/app_tmp"
+    os.makedirs(exp_dir, exist_ok=True)
+    gifs = process(
+        num_frames=10,
+        num_seeds=num_seeds,
+        generator=None,
+        exp_dir=exp_dir,
+        load_name=filepath,
+        caption=prompt,
+        lambda_=lambda_
+    )
+    return gifs
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+@app.route('/generate', methods=['POST'])
+def generate():
+
+    directories_to_clean = [
+        app.config['UPLOAD_FOLDER'],
+        'static/app_tmp/mp4_logs',
+        'static/app_tmp/gif_logs',
+        'static/app_tmp/png_logs'
+    ]
+
+    # Perform cleanup
+    os.makedirs('static/app_tmp', exist_ok=True)
+    for directory in directories_to_clean:
+        os.makedirs(directory, exist_ok=True)  # Ensure the directory exists
+        cleanup_old_files(directory)
+
+    prompt = request.form.get('prompt', '')
+    num_gifs = int(request.form.get('seeds', 3))
+    lambda_value = 1 - float(request.form.get('lambda', 0.5))
+    selected_example = request.form.get('selected_example', None)
+    file = request.files.get('image')
+
+    if not file and not selected_example:
+        return jsonify({'error': 'No image file provided or example selected'}), 400
+
+    if selected_example:
+        # Use the selected example image
+        filepath = os.path.join('static', 'examples', selected_example)
+        unique_id = None  # No need for unique ID
+    else:
+        # Save the uploaded image
+        unique_id = str(uuid.uuid4())
+        filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{unique_id}_uploaded_image.png")
+        file.save(filepath)
+
+    generated_gifs = generate_gifs(filepath, prompt, num_seeds=num_gifs, lambda_=lambda_value)
+
+    unique_id = str(uuid.uuid4())
+    # Append unique id to each gif path
+    for i in range(len(generated_gifs)):
+        os.rename(generated_gifs[i], f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif")
+        generated_gifs[i] = f"{generated_gifs[i].split('.')[0]}_{unique_id}.gif"
+    # Move the generated gifs to the static folder
+
+
+    filtered_gifs = filter(generated_gifs, filepath)
+    return jsonify({'gifs': filtered_gifs, 'prompt': prompt})
+
+if __name__ == '__main__':
+
+
+    app.run(debug=True)
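
Note: app_full.py mirrors app.py with two differences — it imports TextToVideoSDPipelineModded from text2vid_modded_full rather than text2vid_modded, and it omits the run_setup() step that clones Hmrishav/t2v_sketch-lora and moves checkpoint-2500 into place, so the LoRA checkpoint must already exist locally before it is launched.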
environment.yml ADDED
@@ -0,0 +1,402 @@
+name: flipsketch
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - asttokens=2.4.1=pyhd8ed1ab_0
+  - blas=1.0=mkl
+  - brotli-python=1.0.9=py310hd8f1fbe_7
+  - bzip2=1.0.8=h7f98852_4
+  - ca-certificates=2024.2.2=hbcca054_0
+  - certifi=2024.2.2=pyhd8ed1ab_0
+  - charset-normalizer=2.0.4=pyhd8ed1ab_0
+  - comm=0.2.2=pyhd8ed1ab_0
+  - cuda=11.6.1=0
+  - cuda-cccl=11.6.55=hf6102b2_0
+  - cuda-command-line-tools=11.6.2=0
+  - cuda-compiler=11.6.2=0
+  - cuda-cudart=11.6.55=he381448_0
+  - cuda-cudart-dev=11.6.55=h42ad0f4_0
+  - cuda-cuobjdump=11.6.124=h2eeebcb_0
+  - cuda-cupti=11.6.124=h86345e5_0
+  - cuda-cuxxfilt=11.6.124=hecbf4f6_0
+  - cuda-driver-dev=11.6.55=0
+  - cuda-gdb=12.4.127=0
+  - cuda-libraries=11.6.1=0
+  - cuda-libraries-dev=11.6.1=0
+  - cuda-memcheck=11.8.86=0
+  - cuda-nsight=12.4.127=0
+  - cuda-nsight-compute=12.4.1=0
+  - cuda-nvcc=11.6.124=hbba6d2d_0
+  - cuda-nvdisasm=12.4.127=0
+  - cuda-nvml-dev=11.6.55=haa9ef22_0
+  - cuda-nvprof=12.4.127=0
+  - cuda-nvprune=11.6.124=he22ec0a_0
+  - cuda-nvrtc=11.6.124=h020bade_0
+  - cuda-nvrtc-dev=11.6.124=h249d397_0
+  - cuda-nvtx=11.6.124=h0630a44_0
+  - cuda-nvvp=12.4.127=0
+  - cuda-runtime=11.6.1=0
+  - cuda-samples=11.6.101=h8efea70_0
+  - cuda-sanitizer-api=12.4.127=0
+  - cuda-toolkit=11.6.1=0
+  - cuda-tools=11.6.1=0
+  - cuda-visual-tools=11.6.1=0
+  - debugpy=1.6.7=py310h6a678d5_0
+  - entrypoints=0.4=pyhd8ed1ab_0
+  - exceptiongroup=1.2.0=pyhd8ed1ab_2
+  - executing=2.0.1=pyhd8ed1ab_0
+  - ffmpeg=4.3=hf484d3e_0
+  - freetype=2.12.1=h4a9f257_0
+  - gds-tools=1.9.1.3=0
+  - gmp=6.2.1=h58526e2_0
+  - gnutls=3.6.15=he1e5248_0
+  - idna=3.4=pyhd8ed1ab_0
+  - intel-openmp=2023.1.0=hdb19cb5_46306
+  - ipykernel=6.29.3=pyhd33586a_0
+  - jedi=0.19.1=pyhd8ed1ab_0
+  - jpeg=9e=h166bdaf_1
+  - jupyter_client=7.3.4=pyhd8ed1ab_0
+  - jupyter_core=5.7.2=pyh31011fe_1
+  - lame=3.100=h7f98852_1001
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h9c3ff4c_0
+  - libcublas=11.9.2.110=h5e84587_0
+  - libcublas-dev=11.9.2.110=h5c901ab_0
+  - libcufft=10.7.1.112=hf425ae0_0
+  - libcufft-dev=10.7.1.112=ha5ce4c0_0
+  - libcufile=1.9.1.3=0
+  - libcufile-dev=1.9.1.3=0
+  - libcurand=10.3.5.147=0
+  - libcurand-dev=10.3.5.147=0
+  - libcusolver=11.3.4.124=h33c3c4e_0
+  - libcusparse=11.7.2.124=h7538f96_0
+  - libcusparse-dev=11.7.2.124=hbbe9722_0
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.16=h516909a_0
+  - libidn2=2.3.4=h5eee18b_0
+  - libnpp=11.6.3.124=hd2722f0_0
+  - libnpp-dev=11.6.3.124=h3c42840_0
+  - libnvjpeg=11.6.2.124=hd473ad6_0
+  - libnvjpeg-dev=11.6.2.124=hb5906b9_0
+  - libpng=1.6.39=h5eee18b_0
+  - libsodium=1.0.18=h36c2ea0_1
+  - libstdcxx-ng=11.2.0=he4da1e4_16
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.1=h6a678d5_0
+  - libunistring=0.9.10=h7f98852_0
+  - libuuid=1.41.5=h5eee18b_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - lz4-c=1.9.4=h6a678d5_1
+  - mkl=2023.1.0=h213fc3f_46344
+  - mkl-service=2.4.0=py310h5eee18b_1
+  - mkl_fft=1.3.8=py310h5eee18b_0
+  - mkl_random=1.2.4=py310hdb19cb5_0
+  - ncurses=6.4=h6a678d5_0
+  - nest-asyncio=1.6.0=pyhd8ed1ab_0
+  - nettle=3.7.3=hbbd107a_1
+  - nsight-compute=2024.1.1.4=0
+  - numpy-base=1.26.4=py310hb5e798b_0
+  - openh264=2.1.1=h780b84a_0
+  - openjpeg=2.4.0=h9ca470c_2
+  - openssl=3.0.13=h7f8727e_2
+  - packaging=24.0=pyhd8ed1ab_0
+  - parso=0.8.4=pyhd8ed1ab_0
+  - pexpect=4.9.0=pyhd8ed1ab_0
+  - pickleshare=0.7.5=py_1003
+  - pip=23.3.1=pyhd8ed1ab_0
+  - ptyprocess=0.7.0=pyhd3deb0d_0
+  - pure_eval=0.2.2=pyhd8ed1ab_0
+  - pygments=2.17.2=pyhd8ed1ab_0
+  - pysocks=1.7.1=pyha2e5f31_6
+  - python=3.10.14=h955ad1f_0
+  - python_abi=3.10=2_cp310
+  - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
+  - pytorch-cuda=11.6=h867d48c_1
+  - pytorch-mutex=1.0=cuda
+  - pyzmq=25.1.2=py310h6a678d5_0
+  - readline=8.2=h5eee18b_0
+  - requests=2.31.0=pyhd8ed1ab_0
+  - setuptools=68.2.2=pyhd8ed1ab_0
+  - six=1.16.0=pyh6c4a22f_0
+  - sqlite=3.41.2=h5eee18b_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.12=h1ccaba5_0
+  - torchaudio=0.13.1=py310_cu116
+  - tornado=6.1=py310h5764c6d_3
+  - typing_extensions=4.9.0=pyha770c72_0
+  - tzdata=2024a=h8827d51_1
+  - urllib3=2.1.0=pyhd8ed1ab_0
+  - wcwidth=0.2.13=pyhd8ed1ab_0
+  - wheel=0.41.2=pyhd8ed1ab_0
+  - xz=5.4.6=h5eee18b_1
+  - zeromq=4.3.5=h6a678d5_0
+  - zlib=1.2.13=h5eee18b_1
+  - zstd=1.5.5=hc292b87_2
+  - pip:
+    - absl-py==2.1.0
+    - accelerate==0.29.2
+    - addict==2.4.0
+    - aiofiles==23.2.1
+    - aiohttp==3.9.3
+    - aiosignal==1.3.1
+    - albumentations==1.3.0
+    - aliyun-python-sdk-core==2.15.1
+    - aliyun-python-sdk-kms==2.16.2
+    - annotated-types==0.7.0
+    - antlr4-python3-runtime==4.8
+    - anyio==4.6.2.post1
+    - appdirs==1.4.4
+    - async-timeout==4.0.3
+    - attrs==23.2.0
+    - basicsr==1.4.2
+    - beautifulsoup4==4.12.3
+    - bitsandbytes==0.35.4
+    - black==21.4b2
+    - blinker==1.8.2
+    - blis==0.7.11
+    - boto3==1.34.97
+    - botocore==1.34.97
+    - bresenham==0.2.1
+    - cachetools==5.3.3
+    - captum==0.7.0
+    - catalogue==2.0.10
+    - cffi==1.16.0
+    - chardet==5.2.0
+    - click==8.1.7
+    - clip==0.1.0
+    - cloudpickle==3.0.0
+    - cmake==3.25.2
+    - compel==2.0.3
+    - confection==0.1.4
+    - contourpy==1.2.1
+    - controlnet-aux==0.0.6
+    - crcmod==1.7
+    - cryptography==42.0.7
+    - cssselect2==0.7.0
+    - cycler==0.12.1
+    - cymem==2.0.8
+    - cython==3.0.10
+    - datasets==2.18.0
+    - decorator==4.4.2
+    - decord==0.6.0
+    - deepspeed==0.8.0
+    - diffdist==0.1
+    - diffusers==0.27.2
+    - dill==0.3.8
+    - docker-pycreds==0.4.0
+    - easydict==1.10
+    - einops==0.3.0
+    - fairscale==0.4.13
+    - faiss-cpu==1.8.0
+    - fastapi==0.115.4
+    - ffmpy==0.3.0
+    - filelock==3.13.4
+    - flask==3.0.3
+    - flatbuffers==24.3.25
+    - fonttools==4.51.0
+    - frozenlist==1.4.1
+    - fsspec==2024.2.0
+    - ftfy==6.1.1
+    - future==1.0.0
+    - fvcore==0.1.5.post20221221
+    - gast==0.5.4
+    - gdown==5.1.0
+    - gitdb==4.0.11
+    - gitpython==3.1.43
+    - google-auth==2.29.0
+    - google-auth-oauthlib==0.4.6
+    - gradio==5.5.0
+    - gradio-client==1.4.2
+    - grpcio==1.62.1
+    - h11==0.14.0
+    - hjson==3.1.0
+    - httpcore==1.0.6
+    - httpx==0.27.2
+    - huggingface-hub==0.25.2
+    - hydra-core==1.1.1
+    - imageio==2.25.1
+    - imageio-ffmpeg==0.4.8
+    - importlib-metadata==7.1.0
+    - inquirerpy==0.3.4
+    - iopath==0.1.9
+    - ipdb==0.13.13
+    - ipympl==0.9.4
+    - ipython==8.23.0
+    - ipython-genutils==0.2.0
+    - ipywidgets==8.1.2
+    - itsdangerous==2.2.0
+    - jax==0.4.26
+    - jaxlib==0.4.26
+    - jinja2==3.1.3
+    - jmespath==0.10.0
+    - joblib==1.4.2
+    - jupyterlab-widgets==3.0.10
+    - kiwisolver==1.4.5
+    - kornia==0.6.0
+    - lightning-utilities==0.11.2
+    - lmdb==1.4.1
+    - loguru==0.7.2
+    - loralib==0.1.2
+    - lvis==0.5.3
+    - lxml==5.2.1
+    - markdown==3.6
+    - markdown-it-py==3.0.0
+    - markupsafe==2.1.5
+    - matplotlib==3.8.4
+    - matplotlib-inline==0.1.6
+    - mdurl==0.1.2
+    - mediapipe==0.10.11
+    - ml-dtypes==0.4.0
+    - modelcards==0.1.6
+    - modelscope==1.14.0
+    - motion-vector-extractor==1.0.6
+    - moviepy==1.0.3
+    - mpmath==1.3.0
+    - multidict==6.0.5
+    - multiprocess==0.70.16
+    - murmurhash==1.0.10
+    - mypy-extensions==1.0.0
+    - networkx==3.3
+    - ninja==1.11.1.1
+    - nltk==3.8.1
+    - numpy==1.24.2
+    - nvidia-cublas-cu11==11.10.3.66
+    - nvidia-cuda-cupti-cu12==12.1.105
+    - nvidia-cuda-nvrtc-cu11==11.7.99
+    - nvidia-cuda-nvrtc-cu12==12.1.105
+    - nvidia-cuda-runtime-cu11==11.7.99
+    - nvidia-cuda-runtime-cu12==12.1.105
+    - nvidia-cudnn-cu11==8.5.0.96
+    - nvidia-cufft-cu12==11.0.2.54
+    - nvidia-curand-cu12==10.3.2.106
+    - nvidia-nccl-cu12==2.20.5
+    - nvidia-nvjitlink-cu12==12.6.77
+    - nvidia-nvtx-cu12==12.1.105
+    - oauthlib==3.2.2
+    - omegaconf==2.1.1
+    - open-clip-torch==2.0.2
+    - opencv-contrib-python==4.9.0.80
+    - opencv-python==4.6.0.66
+    - opencv-python-headless==4.9.0.80
+    - opt-einsum==3.3.0
+    - orjson==3.10.11
+    - oss2==2.18.5
+    - pandas==1.5.3
+    - pathspec==0.12.1
+    - pathtools==0.1.2
+    - peft==0.10.0
+    - pfzy==0.3.4
+    - pillow==9.5.0
+    - pkgconfig==1.5.5
+    - platformdirs==4.2.0
+    - portalocker==2.8.2
+    - preshed==3.0.9
+    - proglog==0.1.10
+    - prompt-toolkit==3.0.43
+    - protobuf==3.20.3
+    - psutil==5.9.8
+    - py-cpuinfo==9.0.0
+    - pyarrow==15.0.2
+    - pyarrow-hotfix==0.6
+    - pyasn1==0.6.0
+    - pyasn1-modules==0.4.0
+    - pyav==12.0.5
+    - pycocotools==2.0.7
+    - pycparser==2.22
+    - pycryptodome==3.20.0
+    - pydantic==2.9.2
+    - pydantic-core==2.23.4
+    - pydeprecate==0.3.1
+    - pydot==2.0.0
+    - pydub==0.25.1
+    - pynvml==11.5.3
+    - pyparsing==3.1.2
+    - pyre-extensions==0.0.23
+    - python-dateutil==2.9.0.post0
+    - python-multipart==0.0.12
+    - pytorch-lightning==1.4.2
+    - pytz==2024.1
+    - pywavelets==1.6.0
+    - pyyaml==6.0.1
+    - qudida==0.0.4
+    - regex==2024.4.16
+    - reportlab==4.1.0
+    - requests-oauthlib==2.0.0
+    - rich==13.9.4
+    - rsa==4.9
+    - ruff==0.7.2
+    - s3transfer==0.10.1
+    - safehttpx==0.1.1
+    - safetensors==0.4.2
+    - scikit-image==0.19.3
+    - scikit-learn==1.4.2
+    - scikit-video==1.1.11
+    - scipy==1.10.1
+    - semantic-version==2.10.0
+    - sentry-sdk==1.44.1
+    - setproctitle==1.3.3
+    - shapely==2.0.3
+    - shellingham==1.5.4
+    - simplejson==3.19.2
+    - smmap==5.0.1
+    - sniffio==1.3.1
+    - sortedcontainers==2.4.0
+    - sounddevice==0.4.6
+    - soupsieve==2.5
+    - srsly==2.4.8
+    - stable-diffusion-sdkit==2.1.3
+    - stack-data==0.6.3
+    - starlette==0.41.2
+    - svg-path==6.3
+    - svglib==1.5.1
+    - svgpathtools==1.6.1
+    - svgwrite==1.4.3
+    - sympy==1.13.3
+    - tabulate==0.9.0
+    - tb-nightly==2.17.0a20240408
+    - tensorboard==2.12.0
+    - tensorboard-data-server==0.7.0
+    - tensorboard-plugin-wit==1.8.1
+    - termcolor==2.2.0
+    - test-tube==0.7.5
+    - thinc==8.1.10
+    - threadpoolctl==3.5.0
+    - tifffile==2024.2.12
+    - timm==0.6.11
+    - tinycss2==1.2.1
+    - tokenizers==0.20.1
+    - toml==0.10.2
+    - tomli==2.0.1
+    - tomlkit==0.12.0
+    - torch==1.13.1
+    - torchmetrics==0.6.0
+    - torchsummary==1.5.1
+    - torchvision==0.14.1
+    - tqdm==4.64.1
+    - traitlets==5.14.2
+    - transformers==4.45.2
+    - triton==2.3.0
+    - typer==0.12.5
+    - typing-inspect==0.9.0
+    - uvicorn==0.32.0
+    - wandb==0.16.6
+    - wasabi==1.1.2
+    - webencodings==0.5.1
+    - websockets==12.0
+    - werkzeug==3.0.2
+    - widgetsnbextension==4.0.10
+    - xformers==0.0.16
+    - xxhash==3.4.1
+    - yacs==0.1.8
+    - yapf==0.40.2
+    - yarl==1.9.4
+    - zipp==3.18.1
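
For reference (not part of the commit): the environment pins CUDA 11.6 builds of pytorch 1.13.1, so it assumes an NVIDIA driver compatible with CUDA 11.6. A bootstrap sketch, assuming conda is on PATH:

import subprocess

# Create the pinned env defined above, then launch the Flask app inside it.
subprocess.run(["conda", "env", "create", "-f", "environment.yml"], check=True)
subprocess.run(["conda", "run", "-n", "flipsketch", "python", "app.py"], check=True)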
gifs_filter.py ADDED
@@ -0,0 +1,68 @@
+# filter images
+from PIL import Image, ImageSequence
+import requests
+from tqdm import tqdm
+import numpy as np
+import torch
+from transformers import CLIPProcessor, CLIPModel
+
+def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+    converted_len = int(clip_len * frame_sample_rate)
+    end_idx = np.random.randint(converted_len, seg_len)
+    start_idx = end_idx - converted_len
+    indices = np.linspace(start_idx, end_idx, num=clip_len)
+    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+    return indices
+
+def load_frames(image: Image, mode='RGBA'):
+    return np.array([
+        np.array(frame.convert(mode))
+        for frame in ImageSequence.Iterator(image)
+    ])
+
+img_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+img_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+
+def filter(gifs, input_image):
+    max_cosine = 0.9  # similarity threshold for keeping a GIF
+    max_gif = []
+
+    for gif in tqdm(gifs, total=len(gifs)):
+        with Image.open(gif) as im:
+            frames = load_frames(im)
+
+        frames = np.array(frames)
+        frames = frames[:, :, :, :3]
+        frames = np.transpose(frames, (0, 3, 1, 2))[1:]
+
+        image = Image.open(input_image)
+
+        inputs = img_processor(images=frames, return_tensors="pt", padding=False)
+        inputs_base = img_processor(images=image, return_tensors="pt", padding=False)
+
+        with torch.no_grad():
+            feat_img_base = img_model.get_image_features(pixel_values=inputs_base["pixel_values"])
+            feat_img_vid = img_model.get_image_features(pixel_values=inputs["pixel_values"])
+        cos_avg = 0
+        for i in range(len(feat_img_vid)):
+            # compare each video frame's CLIP feature against the input sketch
+            cosine_similarity = torch.nn.functional.cosine_similarity(
+                feat_img_base,
+                feat_img_vid[i].unsqueeze(0),
+                dim=1)
+            cos_avg += cosine_similarity.item()
+
+        cos_avg /= len(feat_img_vid)
+        print("Current cosine similarity: ", cos_avg)
+        print("Max cosine similarity: ", max_cosine)
+        if cos_avg > max_cosine:
+            max_gif.append(gif)
+    return max_gif
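
For reference (not part of the commit): filter() keeps a GIF only when the mean CLIP cosine similarity between its frames and the input sketch exceeds the hard-coded 0.9 threshold. A usage sketch with illustrative paths:

from gifs_filter import filter

# Paths are hypothetical examples; any list of generated GIFs plus the source sketch works.
kept = filter(
    ["static/app_tmp/gif_logs/vid_sketch10-rand0.gif"],
    "static/examples/sketch10.png",
)
print(kept)  # the subset of GIFs that passed the 0.9 threshold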
invert_utils.py ADDED
@@ -0,0 +1,89 @@
+import os
+import imageio
+import numpy as np
+from typing import Union
+
+import torch
+import torchvision
+
+from tqdm import tqdm
+from einops import rearrange
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
+    videos = rearrange(videos, "b c t h w -> t b c h w")
+    outputs = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = (x * 255).numpy().astype(np.uint8)
+        outputs.append(x)
+
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    imageio.mimsave(path, outputs, fps=fps)
+
+
+# DDIM Inversion
+@torch.no_grad()
+def init_prompt(prompt, pipeline):
+    uncond_input = pipeline.tokenizer(
+        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
+        return_tensors="pt"
+    )
+    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
+    text_input = pipeline.tokenizer(
+        [prompt],
+        padding="max_length",
+        max_length=pipeline.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
+    context = torch.cat([uncond_embeddings, text_embeddings])
+
+    return context
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+    timestep, next_timestep = min(
+        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+    # try:
+    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+    # except:
+    #     alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep]  # if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+    beta_prod_t = 1 - alpha_prod_t
+    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+    return next_sample
+
+
+def get_noise_pred_single(latents, t, context, unet):
+    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
+    return noise_pred
+
+
+@torch.no_grad()
+def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
+    context = init_prompt(prompt, pipeline)
+    uncond_embeddings, cond_embeddings = context.chunk(2)
+    all_latent = [latent]
+    latent = latent.clone().detach()
+    for i in tqdm(range(num_inv_steps)):
+        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
+        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
+        noise_pred_unc = get_noise_pred_single(latent, t, uncond_embeddings, pipeline.unet)
+        noise_pred = noise_pred_unc + 9.0 * (noise_pred_unc - noise_pred)
+        latent = next_step(noise_pred, t, latent, ddim_scheduler)
+        all_latent.append(latent)
+    return all_latent
+
+
+@torch.no_grad()
+def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
+    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+    return ddim_latents
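
For reference (not part of the commit): next_step() above is the deterministic DDIM inversion update. From the latent x_t at a lower timestep it reconstructs the clean estimate and re-noises toward the next, higher timestep t+\Delta:

\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \qquad
x_{t+\Delta} = \sqrt{\bar\alpha_{t+\Delta}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t+\Delta}}\,\epsilon_\theta(x_t, t)

Here \bar\alpha are the scheduler's cumulative alpha products (alphas_cumprod) and \epsilon_\theta is the UNet noise prediction; ddim_loop() combines the conditional and unconditional predictions with a fixed weight of 9.0 before each step.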
read_vids.py ADDED
@@ -0,0 +1,27 @@
+import imageio.v3 as iio
+import os
+from sys import argv
+video_name = argv[1]
+
+video = video_name
+video_id = video.split("/")[-1].replace(".mp4", "")
+
+
+png_base = "png_logs"
+try:
+    os.mkdir(png_base)
+except:
+    pass
+
+video_id = os.path.join(png_base, video_id)
+all_frames = list(iio.imiter(video))
+
+ctr = 0
+try:
+    os.makedirs(video_id)
+except:
+    pass
+for idx, frame in enumerate(all_frames):
+
+    iio.imwrite(f"{video_id}/{ctr:03d}.jpg", frame)
+    ctr += 1
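
For reference: read_vids.py is a small CLI helper, invoked as python read_vids.py path/to/video.mp4; it dumps every frame of the clip as zero-padded JPEGs under png_logs/<video name>/.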
requirements.txt ADDED
@@ -0,0 +1,44 @@
+accelerate==0.29.2
+blinker==1.9.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+decorator==4.4.2
+diffusers==0.27.2
+einops==0.8.0
+filelock==3.16.1
+Flask==3.0.3
+fsspec==2024.10.0
+huggingface-hub==0.25.2
+idna==3.10
+imageio==2.36.0
+imageio-ffmpeg==0.5.1
+importlib_metadata==8.5.0
+itsdangerous==2.2.0
+Jinja2==3.1.4
+MarkupSafe==3.0.2
+moviepy==1.0.3
+numpy==1.24.2
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+opencv-python==4.10.0.84
+packaging==24.2
+pillow==10.4.0
+proglog==0.1.10
+psutil==6.1.0
+python-dotenv==1.0.1
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+safetensors==0.4.5
+tokenizers==0.20.3
+torch==1.13.1
+torchvision==0.14.1
+tqdm==4.67.0
+transformers==4.45.2
+typing_extensions==4.12.2
+urllib3==2.2.3
+Werkzeug==3.1.3
+zipp==3.21.0
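
Note: the core pins here (torch==1.13.1, torchvision==0.14.1, diffusers==0.27.2, transformers==4.45.2, huggingface-hub==0.25.2) match the conda environment above, so requirements.txt offers a lighter pip-only path via pip install -r requirements.txt.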
static/app_tmp/gif_logs/vid_sketch10-rand0_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand0_dfcba486-0d8c-4d68-9689-97f1fb889213.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand1_dfcba486-0d8c-4d68-9689-97f1fb889213.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand2_dfcba486-0d8c-4d68-9689-97f1fb889213.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand3_dfcba486-0d8c-4d68-9689-97f1fb889213.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_4e766a8e-9d22-4818-8991-e884ce17e5e5.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_9e7e07af-2adc-47b0-8aa4-716a934690e8.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand4_dfcba486-0d8c-4d68-9689-97f1fb889213.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand5_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand6_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand7_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand8_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch10-rand9_508fa599-d685-462e-ad06-11ca4fd15d6f.gif ADDED
static/app_tmp/gif_logs/vid_sketch3-rand0_875203a1-f830-46e7-a287-4a0bc2c3a648.gif ADDED
static/app_tmp/gif_logs/vid_sketch3-rand1_875203a1-f830-46e7-a287-4a0bc2c3a648.gif ADDED
static/app_tmp/gif_logs/vid_sketch3-rand2_875203a1-f830-46e7-a287-4a0bc2c3a648.gif ADDED
static/app_tmp/gif_logs/vid_sketch3-rand3_875203a1-f830-46e7-a287-4a0bc2c3a648.gif ADDED
static/app_tmp/gif_logs/vid_sketch3-rand4_875203a1-f830-46e7-a287-4a0bc2c3a648.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_47fc0372-4688-4a2a-abb3-817ccfee8816.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_77158110-9239-4771-bb44-a83c3aa47567.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand0_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_47fc0372-4688-4a2a-abb3-817ccfee8816.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_77158110-9239-4771-bb44-a83c3aa47567.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand1_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_47fc0372-4688-4a2a-abb3-817ccfee8816.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_77158110-9239-4771-bb44-a83c3aa47567.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand2_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_47fc0372-4688-4a2a-abb3-817ccfee8816.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_77158110-9239-4771-bb44-a83c3aa47567.gif ADDED
static/app_tmp/gif_logs/vid_sketch8-rand3_fd1dace5-80a2-4a0f-afb1-c6aa0943c91a.gif ADDED