cocktailpeanut committed on
Commit
ad2d8cc
·
1 Parent(s): 7a86a0a
Files changed (3)
  1. app.py +9 -3
  2. preprocess_utils.py +4 -48
  3. tokenflow_pnp.py +2 -2
app.py CHANGED
@@ -7,7 +7,13 @@ from tokenflow_pnp import TokenFlow
 from preprocess_utils import *
 from tokenflow_utils import *
 # load sd model
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
 model_id = "stabilityai/stable-diffusion-2-1-base"
 
 # components for the Preprocessor
@@ -21,7 +27,7 @@ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision
                                             torch_dtype=torch.float16).to(device)
 
 # pipe for TokenFlow
-tokenflow_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+tokenflow_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
 tokenflow_pipe.enable_xformers_memory_efficient_attention()
 
 def randomize_seed_fn():
@@ -371,4 +377,4 @@ with gr.Blocks(css="style.css") as demo:
     )
 
 demo.queue()
-demo.launch()
+demo.launch()
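
The replacement prefers CUDA, falls back to Apple's MPS backend, and only then to CPU. One caveat the diff leaves as-is: `enable_xformers_memory_efficient_attention()` still runs unconditionally, and xformers kernels are CUDA-only. A minimal sketch of the same selection with a hedged guard around the xformers call (the guard is an assumption, not something this commit adds):

import torch
from diffusers import StableDiffusionPipeline

# Sketch of the fallback chain this commit introduces: prefer CUDA, then
# Apple's MPS backend, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
).to(device)

# Assumption, not part of the commit: xformers requires CUDA, so the call
# is guarded here rather than left unconditional as in app.py.
if device == "cuda":
    pipe.enable_xformers_memory_efficient_attention()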
preprocess_utils.py CHANGED
@@ -92,7 +92,7 @@ class Preprocess(nn.Module):
     def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
         depth_maps = []
         midas = torch.hub.load("intel-isl/MiDaS", model_type)
-        midas.to(device)
+        midas.to(self.device)
         midas.eval()
 
         midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
@@ -109,7 +109,7 @@ class Preprocess(nn.Module):
             latent_h = img.shape[0] // 8
             latent_w = img.shape[1] // 8
 
-            input_batch = transform(img).to(device)
+            input_batch = transform(img).to(self.device)
             prediction = midas(input_batch)
 
             depth_map = torch.nn.functional.interpolate(
@@ -167,10 +167,10 @@ class Preprocess(nn.Module):
     def get_text_embeds(self, prompt, negative_prompt, device="cuda"):
         text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                     truncation=True, return_tensors='pt')
-        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
         uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                       return_tensors='pt')
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
@@ -329,47 +329,3 @@ class Preprocess(nn.Module):
         return self.frames, self.latents, self.total_inverted_latents, None
 
 
-def prep(opt):
-    # timesteps to save
-    if opt["sd_version"] == '2.1':
-        model_key = "stabilityai/stable-diffusion-2-1-base"
-    elif opt["sd_version"] == '2.0':
-        model_key = "stabilityai/stable-diffusion-2-base"
-    elif opt["sd_version"] == '1.5' or opt["sd_version"] == 'ControlNet':
-        model_key = "runwayml/stable-diffusion-v1-5"
-    elif opt["sd_version"] == 'depth':
-        model_key = "stabilityai/stable-diffusion-2-depth"
-    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
-    toy_scheduler.set_timesteps(opt["save_steps"])
-    timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt["save_steps"],
-                                                           strength=1.0,
-                                                           device=device)
-
-    seed_everything(opt["seed"])
-    if not opt["frames"]: # original non demo setting
-        save_path = os.path.join(opt["save_dir"],
-                                 f'sd_{opt["sd_version"]}',
-                                 Path(opt["data_path"]).stem,
-                                 f'steps_{opt["steps"]}',
-                                 f'nframes_{opt["n_frames"]}')
-        os.makedirs(os.path.join(save_path, f'latents'), exist_ok=True)
-        add_dict_to_yaml_file(os.path.join(opt["save_dir"], 'inversion_prompts.yaml'), Path(opt["data_path"]).stem, opt["inversion_prompt"])
-        # save inversion prompt in a txt file
-        with open(os.path.join(save_path, 'inversion_prompt.txt'), 'w') as f:
-            f.write(opt["inversion_prompt"])
-    else:
-        save_path = None
-
-    model = Preprocess(device, opt)
-
-    frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
-        num_steps=model.config["steps"],
-        save_path=save_path,
-        batch_size=model.config["batch_size"],
-        timesteps_to_save=timesteps_to_save,
-        inversion_prompt=model.config["inversion_prompt"],
-    )
-
-
-    return frames, latents, total_inverted_latents, rgb_reconstruction
-
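
Every hunk kept in this file makes the same substitution: the method-level `device` parameter (which still defaults to 'cuda') is bypassed in favour of `self.device`, set from the caller's device at construction time. A stripped-down sketch of the pattern (this `Preprocess` is a stand-in, not the real class):

import torch
import torch.nn as nn

# Stand-in illustrating the pattern: the stale `device='cuda'` default is
# ignored, and every tensor follows the device the instance was built with.
class Preprocess(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device  # single source of truth for tensor placement

    def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
        midas = torch.hub.load("intel-isl/MiDaS", model_type)
        midas.to(self.device)  # self.device, not the 'cuda' default above
        midas.eval()
        return midas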
tokenflow_pnp.py CHANGED
@@ -78,7 +78,7 @@ class TokenFlow(nn.Module):
     def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
         depth_maps = []
         midas = torch.hub.load("intel-isl/MiDaS", model_type)
-        midas.to(device)
+        midas.to(self.device)
         midas.eval()
 
         midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
@@ -95,7 +95,7 @@ class TokenFlow(nn.Module):
             latent_h = img.shape[0] // 8
             latent_w = img.shape[1] // 8
 
-            input_batch = transform(img).to(device)
+            input_batch = transform(img).to(self.device)
             prediction = midas(input_batch)
 
             depth_map = torch.nn.functional.interpolate(
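
Same two-line change as in preprocess_utils.py. A self-contained sanity check of the invariant both hunks enforce, that the MiDaS weights and the transformed input share one device (the dummy frame and variable names are assumptions; torch.hub needs network access on first run):

import torch

# Pick a device with the same priority order as app.py.
device = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu")

midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device)
midas.eval()
transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

# Dummy HxWxC uint8 frame standing in for a real video frame.
img = (torch.rand(384, 384, 3).numpy() * 255).astype("uint8")
input_batch = transform(img).to(device)  # inputs must follow the model
with torch.no_grad():
    prediction = midas(input_batch)
print(prediction.shape, prediction.device)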