lllyasviel commited on
Commit
ca685d6
·
1 Parent(s): 0e1497e
Files changed (1) hide show
  1. modules/default_pipeline.py +16 -2
modules/default_pipeline.py CHANGED
@@ -9,6 +9,8 @@ xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0.safetensors')
9
  xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
10
 
11
  xl_base = core.load_model(xl_base_filename)
 
 
12
 
13
 
14
  @torch.no_grad()
@@ -16,16 +18,28 @@ def process(positive_prompt, negative_prompt, width=1024, height=1024, batch_siz
16
  positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt)
17
  negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt)
18
 
 
 
 
19
  empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size)
20
 
21
  sampled_latent = core.ksample(
22
  unet=xl_base.unet,
23
  positive_condition=positive_conditions,
24
  negative_condition=negative_conditions,
25
- latent_image=empty_latent
 
 
 
 
 
 
 
 
 
26
  )
27
 
28
- decoded_latent = core.decode_vae(vae=xl_base.vae, latent_image=sampled_latent)
29
 
30
  images = core.image_to_numpy(decoded_latent)
31
  return images
 
9
  xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
10
 
11
  xl_base = core.load_model(xl_base_filename)
12
+ xl_refiner = core.load_model(xl_refiner_filename)
13
+ del xl_base.vae
14
 
15
 
16
  @torch.no_grad()
 
18
  positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt)
19
  negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt)
20
 
21
+ positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt)
22
+ negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt)
23
+
24
  empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size)
25
 
26
  sampled_latent = core.ksample(
27
  unet=xl_base.unet,
28
  positive_condition=positive_conditions,
29
  negative_condition=negative_conditions,
30
+ latent_image=empty_latent,
31
+ steps=30, start_at_step=0, end_at_step=20, return_with_leftover_noise=True, add_noise=True
32
+ )
33
+
34
+ sampled_latent = core.ksample(
35
+ unet=xl_refiner.unet,
36
+ positive_condition=positive_conditions_refiner,
37
+ negative_condition=negative_conditions_refiner,
38
+ latent_image=sampled_latent,
39
+ steps=30, start_at_step=20, end_at_step=30, return_with_leftover_noise=False, add_noise=False
40
  )
41
 
42
+ decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)
43
 
44
  images = core.image_to_numpy(decoded_latent)
45
  return images