Suraj Narayanan Sasikumar commited on
Commit
2671954
1 Parent(s): 7aa9aed

add refiner

Browse files
Files changed (1) hide show
  1. handler.py +50 -15
handler.py CHANGED
@@ -16,40 +16,75 @@ if device.type != "cuda":
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
  # load StableDiffusionInpaintPipeline pipeline
19
- self.pipe = StableDiffusionXLPipeline.from_pretrained(
20
  path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
21
  )
22
  # use DPMSolverMultistepScheduler
23
- self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
24
- self.pipe.scheduler.config
25
  )
26
  # move to device
27
- self.pipe = self.pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
30
  """
31
  :param data: A dictionary contains `inputs` and optional `image` field.
32
  :return: A dictionary with `image` field contains image in base64.
33
  """
34
- prompt = data.pop("inputs", data)
 
 
 
 
35
 
36
  # hyperparamters
 
37
  num_inference_steps = data.pop("num_inference_steps", 30)
38
  guidance_scale = data.pop("guidance_scale", 8)
39
  negative_prompt = data.pop("negative_prompt", None)
 
40
  height = data.pop("height", None)
41
  width = data.pop("width", None)
42
 
43
- # run inference pipeline
44
- out = self.pipe(
45
- prompt,
46
- num_inference_steps=num_inference_steps,
47
- guidance_scale=guidance_scale,
48
- num_images_per_prompt=1,
49
- negative_prompt=negative_prompt,
50
- height=height,
51
- width=width,
52
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # encode image as base 64
55
  buffered = BytesIO()
 
16
  class EndpointHandler:
17
  def __init__(self, path=""):
18
  # load StableDiffusionInpaintPipeline pipeline
19
+ self.base = StableDiffusionXLPipeline.from_pretrained(
20
  path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
21
  )
22
  # use DPMSolverMultistepScheduler
23
+ self.base.scheduler = DPMSolverMultistepScheduler.from_config(
24
+ self.base.scheduler.config
25
  )
26
  # move to device
27
+ self.base = self.base.to(device)
28
+ self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
29
+
30
+ self.refiner = StableDiffusionXLPipeline.from_pretrained(
31
+ "socialtrait/stable-diffusion-xl-refiner-1.0-infendpoint",
32
+ text_encoder_2=self.base.text_encoder_2,
33
+ vae=self.base.vae,
34
+ torch_dtype=torch.float16,
35
+ use_safetensors=True,
36
+ variant="fp16",
37
+ )
38
+ # use DPMSolverMultistepScheduler
39
+ self.refiner.scheduler = DPMSolverMultistepScheduler.from_config(
40
+ self.refiner.scheduler.config
41
+ )
42
+ self.refiner = self.refiner.to(device)
43
+ self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)
44
 
45
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
46
  """
47
  :param data: A dictionary contains `inputs` and optional `image` field.
48
  :return: A dictionary with `image` field contains image in base64.
49
  """
50
+ prompt = data.pop("inputs", None)
51
+
52
+ if prompt is None:
53
+ return {"error": "Please provide a prompt"}
54
+
55
 
56
  # hyperparamters
57
+ use_refiner = True if data.pop("use_refiner", False) else False
58
  num_inference_steps = data.pop("num_inference_steps", 30)
59
  guidance_scale = data.pop("guidance_scale", 8)
60
  negative_prompt = data.pop("negative_prompt", None)
61
+ high_noise_frac = data.pop("high_noise_frac", 0.8)
62
  height = data.pop("height", None)
63
  width = data.pop("width", None)
64
 
65
+ if use_refiner:
66
+ image = self.base(
67
+ prompt=prompt,
68
+ num_inference_steps=num_inference_steps,
69
+ denoising_end=high_noise_frac,
70
+ output_type="latent",
71
+ ).images
72
+ out = self.refiner(
73
+ prompt=prompt,
74
+ num_inference_steps=num_inference_steps,
75
+ denoising_start=high_noise_frac,
76
+ image=image,
77
+ )
78
+ else:
79
+ out = self.pipe(
80
+ prompt,
81
+ num_inference_steps=num_inference_steps,
82
+ guidance_scale=guidance_scale,
83
+ num_images_per_prompt=1,
84
+ negative_prompt=negative_prompt,
85
+ height=height,
86
+ width=width,
87
+ )
88
 
89
  # encode image as base 64
90
  buffered = BytesIO()