Linoy Tsaban committed on
Commit
1b62550
1 Parent(s): 502ed04

Update app.py

Files changed (1)
  1. app.py +19 -5
app.py CHANGED
@@ -1,16 +1,25 @@
 import gradio as gr
 import torch
-from diffusers import StableDiffusionPipeline, DDIMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, logging
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
 from utils import video_to_frames, add_dict_to_yaml_file, save_video, seed_everything
 # from diffusers.utils import export_to_video
 from tokenflow_pnp import TokenFlow
 from preprocess_utils import *
 from tokenflow_utils import *
+
 # load sd model
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# model_id = "stabilityai/stable-diffusion-2-1-base"
-# inv_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
-# inv_pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+model_id = "stabilityai/stable-diffusion-2-1-base"
+
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision="fp16",
+                                    torch_dtype=torch.float16).to(device)
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision="fp16",
+                                             torch_dtype=torch.float16).to(device)
+unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="fp16",
+                                            torch_dtype=torch.float16).to(device)
 
 def randomize_seed_fn():
     seed = random.randint(0, np.iinfo(np.int32).max)
@@ -65,7 +74,12 @@ def prep(config):
     else:
         save_path = None
 
-    model = Preprocess(device, config)
+    model = Preprocess(device, config,
+                       vae=vae,
+                       text_encoder=text_encoder,
+                       scheduler=scheduler,
+                       tokenizer=tokenizer,
+                       unet=unet)
     print(type(model.config["batch_size"]))
     frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
         num_steps=model.config["steps"],
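Note: the component-wise loading added in this commit can also be expressed by unpacking a single StableDiffusionPipeline, the route the removed commented-out lines hinted at. Below is a minimal sketch of that alternative, assuming Preprocess (from preprocess_utils) only needs the five modules passed above; the pipe variable and this loading path are illustrative, not part of the change:

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from preprocess_utils import *  # provides Preprocess, as in app.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "stabilityai/stable-diffusion-2-1-base"

# One pipeline download yields the same sub-modules the commit loads
# individually: pipe.vae, pipe.text_encoder, pipe.tokenizer, pipe.unet.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
).to(device)

# Recreate the commit's DDIM scheduler from the pipeline's own config
# instead of a second from_pretrained call.
scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

model = Preprocess(device, config,  # `config` as built inside prep(config)
                   vae=pipe.vae,
                   text_encoder=pipe.text_encoder,
                   scheduler=scheduler,
                   tokenizer=pipe.tokenizer,
                   unet=pipe.unet)

Loading the modules one by one, as the commit does, avoids constructing the full pipeline object and fetches only the fp16 weights that Preprocess actually uses; the pipeline route trades a somewhat larger download for fewer from_pretrained calls.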