rynmurdock commited on
Commit
385fb5f
1 Parent(s): 93b9a94

device changes

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -34,18 +34,18 @@ start_time = time.time()
34
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
  sdxl_lightening = "ByteDance/SDXL-Lightning"
36
  ckpt = "sdxl_lightning_2step_unet.safetensors"
37
- unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
38
- unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
39
 
40
- image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to("cuda")
41
- pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to("cuda")
42
  pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
43
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
44
  pipe.register_modules(image_encoder = image_encoder)
45
 
46
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
47
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
48
- pipe.to(device='cuda')
49
 
50
 
51
  output_hidden_state = False
@@ -60,13 +60,13 @@ def predict(
60
  """Run a single prediction on the model"""
61
  with torch.no_grad():
62
  if im_emb == None:
63
- im_emb = torch.zeros(1, 1024, dtype=torch.float16, device='cuda')
64
 
65
- im_emb = [im_emb.to('cuda').unsqueeze(0)]
66
  if prompt == '':
67
  image = pipe(
68
- prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device='cuda'),
69
- pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device='cuda'),
70
  ip_adapter_image_embeds=im_emb,
71
  height=1024,
72
  width=1024,
@@ -83,9 +83,9 @@ def predict(
83
  guidance_scale=0,
84
  ).images[0]
85
  im_emb, _ = pipe.encode_image(
86
- image, 'cuda', 1, output_hidden_state
87
  )
88
- return image, im_emb.to(DEVICE)
89
 
90
  # TODO add to state instead of shared across all
91
  glob_idx = 0
@@ -128,7 +128,7 @@ def next_image(embs, ys, calibrate_prompts):
128
  if has_0 and has_1:
129
  break
130
 
131
- feature_embs = np.array(torch.cat([embs[i] for i in indices]).to('cpu'))
132
  scaler = preprocessing.StandardScaler().fit(feature_embs)
133
  feature_embs = scaler.transform(feature_embs)
134
 
@@ -138,7 +138,7 @@ def next_image(embs, ys, calibrate_prompts):
138
 
139
  rng_prompt = random.choice(prompt_list)
140
  w = 1# if len(embs) % 2 == 0 else 0
141
- im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
142
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
143
  print(prompt, len(ys))
144
  image, im_emb = predict(prompt, im_emb)
 
34
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
  sdxl_lightening = "ByteDance/SDXL-Lightning"
36
  ckpt = "sdxl_lightning_2step_unet.safetensors"
37
+ unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16)
38
+ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE))
39
 
40
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to(DEVICE)
41
+ pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to(DEVICE)
42
  pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
43
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
44
  pipe.register_modules(image_encoder = image_encoder)
45
 
46
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
47
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
48
+ pipe.to(device=DEVICE)
49
 
50
 
51
  output_hidden_state = False
 
60
  """Run a single prediction on the model"""
61
  with torch.no_grad():
62
  if im_emb == None:
63
+ im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
64
 
65
+ im_emb = [im_emb.to(DEVICE).unsqueeze(0)]
66
  if prompt == '':
67
  image = pipe(
68
+ prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
69
+ pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
70
  ip_adapter_image_embeds=im_emb,
71
  height=1024,
72
  width=1024,
 
83
  guidance_scale=0,
84
  ).images[0]
85
  im_emb, _ = pipe.encode_image(
86
+ image, DEVICE, 1, output_hidden_state
87
  )
88
+ return image, im_emb.to('cpu')
89
 
90
  # TODO add to state instead of shared across all
91
  glob_idx = 0
 
128
  if has_0 and has_1:
129
  break
130
 
131
+ feature_embs = np.array(torch.cat([embs[i].to('cpu') for i in indices]).to('cpu'))
132
  scaler = preprocessing.StandardScaler().fit(feature_embs)
133
  feature_embs = scaler.transform(feature_embs)
134
 
 
138
 
139
  rng_prompt = random.choice(prompt_list)
140
  w = 1# if len(embs) % 2 == 0 else 0
141
+ im_emb = w * lin_class.coef_.to(dtype=torch.float16)
142
  prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
143
  print(prompt, len(ys))
144
  image, im_emb = predict(prompt, im_emb)