Tony Lian commited on
Commit
1d6e0a9
β€’
1 Parent(s): ec7f11c

generator seed fix for baseline

Browse files
Files changed (1) hide show
  1. baseline.py +1 -2
baseline.py CHANGED
@@ -22,12 +22,11 @@ bg_negative = DEFAULT_OVERALL_NEGATIVE_PROMPT
22
  # Using dpm scheduler by default
23
  def run(prompt, scheduler_key='dpm_scheduler', bg_seed=1, num_inference_steps=20):
24
  print(f"prompt: {prompt}")
25
- generator = torch.Generator(models.torch_device).manual_seed(bg_seed)
26
 
27
  prompts = [prompt]
28
  input_embeddings = models.encode_prompts(prompts=prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=bg_negative)
29
 
30
- generator = torch.manual_seed(1) # Seed generator to create the inital latent noise
31
  latents = models.get_unscaled_latents(batch_size, unet.config.in_channels, height, width, generator, dtype)
32
 
33
  latents = latents * scheduler.init_noise_sigma
 
22
  # Using dpm scheduler by default
23
  def run(prompt, scheduler_key='dpm_scheduler', bg_seed=1, num_inference_steps=20):
24
  print(f"prompt: {prompt}")
25
+ generator = torch.manual_seed(bg_seed)
26
 
27
  prompts = [prompt]
28
  input_embeddings = models.encode_prompts(prompts=prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=bg_negative)
29
 
 
30
  latents = models.get_unscaled_latents(batch_size, unet.config.in_channels, height, width, generator, dtype)
31
 
32
  latents = latents * scheduler.init_noise_sigma