karimbenharrak commited on
Commit
0129d6f
1 Parent(s): 4efc065

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -3
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
@@ -15,7 +15,7 @@ if device.type != 'cuda':
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
  # load StableDiffusionInpaintPipeline pipeline
18
- self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
19
  "runwayml/stable-diffusion-inpainting",
20
  revision="fp16",
21
  torch_dtype=torch.float16,
@@ -25,6 +25,12 @@ class EndpointHandler():
25
  # move to device
26
  self.pipe = self.pipe.to(device)
27
 
 
 
 
 
 
 
28
 
29
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
30
  """
@@ -43,12 +49,41 @@ class EndpointHandler():
43
  else:
44
  image = None
45
  mask_image = None
 
 
46
 
47
  # run inference pipeline
48
  out = self.pipe(prompt=prompt, image=image, mask_image=mask_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # return first generate PIL image
51
- return out.images[0]
52
 
53
  # helper to decode input image
54
  def decode_base64_image(self, image_string):
 
1
  from typing import Dict, List, Any
2
  import torch
3
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, AutoPipelineForImage2Image
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
 
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
  # load StableDiffusionInpaintPipeline pipeline
18
+ self.pipe = AutoPipelineForInpainting.from_pretrained(
19
  "runwayml/stable-diffusion-inpainting",
20
  revision="fp16",
21
  torch_dtype=torch.float16,
 
25
  # move to device
26
  self.pipe = self.pipe.to(device)
27
 
28
+ self.pipe2 = AutoPipelineForInpainting.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
29
+ self.pipe2.to("cuda")
30
+
31
+ self.pipe3 = AutoPipelineForImage2Image.from_pipe(self.pipe2)
32
+
33
+
34
 
35
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
36
  """
 
49
  else:
50
  image = None
51
  mask_image = None
52
+
53
+ self.pipe.enable_xformers_memory_efficient_attention()
54
 
55
  # run inference pipeline
56
  out = self.pipe(prompt=prompt, image=image, mask_image=mask_image)
57
+
58
+ image = out.images[0].resize((1024, 1024))
59
+
60
+ self.pipe2.enable_xformers_memory_efficient_attention()
61
+
62
+ image = pipe(
63
+ prompt=prompt,
64
+ image=image,
65
+ mask_image=mask_image,
66
+ guidance_scale=8.0,
67
+ num_inference_steps=100,
68
+ strength=0.2,
69
+ generator=generator,
70
+ output_type="latent", # let's keep in latent to save some VRAM
71
+ ).images[0]
72
+
73
+ pipe = AutoPipelineForImage2Image.from_pipe(pipe)
74
+ self.pipe3.enable_xformers_memory_efficient_attention()
75
+
76
+ image = pipe(
77
+ prompt=prompt,
78
+ image=image,
79
+ guidance_scale=8.0,
80
+ num_inference_steps=100,
81
+ strength=0.2,
82
+ generator=generator,
83
+ ).images[0]
84
 
85
  # return first generate PIL image
86
+ return image
87
 
88
  # helper to decode input image
89
  def decode_base64_image(self, image_string):