nsfwalex committed
Commit 5314072
1 Parent(s): 635cd94

Update inference_manager.py

Files changed (1):
  1. inference_manager.py  +3 -4
inference_manager.py CHANGED
@@ -514,7 +514,6 @@ class ModelManager:
             faceid_all_embeds.append(faceid_embed)
 
         average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
-        average_embedding = average_embedding.to("cuda")
 
         print("start inference...")
         style_selection = ""
@@ -523,9 +522,9 @@ class ModelManager:
         seed = seed or int(randomize_seed_fn(seed, randomize_seed))
         p = remove_child_related_content(p)
         prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
-        generator = torch.Generator(model.base_model_pipeline.device).manual_seed(seed)
+        #generator = torch.Generator(model.base_model_pipeline.device).manual_seed(seed)
         print(f"generate: p={p}, np={negative_prompt}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
-        print(f"device: embedding={average_embedding.device}, generator={generator.device}, ip_model={ip_model.device}, pipe={model.base_model_pipeline.device}")
+        #print(f"device: embedding={average_embedding.device}, generator={generator.device}, ip_model={ip_model.device}, pipe={model.base_model_pipeline.device}")
         images = ip_model.generate(
             prompt=prompt_str,
             negative_prompt=negative_prompt,
@@ -535,7 +534,7 @@ class ModelManager:
             height=height,
             guidance_scale=face_strength,
             num_inference_steps=steps,
-            generator=generator,
+            #generator=generator,
             num_images_per_prompt=1,
             #output_type="pil",
             #callback_on_step_end=callback_dynamic_cfg,
 
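For reference, the lines commented out here follow the usual pattern of seeding a torch.Generator on the same device as the pipeline and passing it to the generation call. A minimal sketch of that pattern, assuming a CUDA-capable host and a diffusers-style pipeline object named pipe with a .device attribute (both assumptions, not this repo's code):

import torch

def make_generator(device: torch.device, seed: int) -> torch.Generator:
    # torch.Generator accepts a device argument; manual_seed makes runs repeatable.
    return torch.Generator(device=device).manual_seed(seed)

# Hypothetical usage with a diffusers-style pipeline `pipe`:
# generator = make_generator(pipe.device, seed=1234)
# images = pipe(prompt="a portrait photo", generator=generator, num_inference_steps=30).images

Keeping the generator, the prompt embeddings, and the pipeline on the same device avoids the cross-device errors that the removed debug print was inspecting.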