pwaldron commited on
Commit
8330e27
1 Parent(s): 763964a

Upload handler and requirements

Browse files
Files changed (2) hide show
  1. handler.py +87 -0
  2. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import base64
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL
7
+ from controlnet_aux.pidi import PidiNetDetector
8
+
9
+ # set device
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+ if device.type != 'cuda':
13
+ raise ValueError("need to run on GPU")
14
+
15
+ class EndpointHandler():
16
+ def __init__(self, path=""):
17
+ # Preload all the elements you are going to need at inference.
18
+ # pseudo:
19
+ # self.model= load_model(path)
20
+
21
+ adapter = T2IAdapter.from_pretrained(
22
+ "Adapter/t2iadapter",
23
+ subfolder="sketch_sdxl_1.0",
24
+ torch_dtype=torch.float16,
25
+ adapter_type="full_adapter_xl"
26
+ )
27
+
28
+ vae = AutoencoderKL.from_pretrained(
29
+ "madebyollin/sdxl-vae-fp16-fix",
30
+ torch_dtype=torch.float16,
31
+ use_safetensors=True
32
+ )
33
+
34
+ self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
35
+ "stabilityai/stable-diffusion-xl-base-1.0",
36
+ adapter=adapter,
37
+ vae=vae,
38
+ torch_dtype=torch.float16,
39
+ variant="fp16"
40
+ ).to("cuda")
41
+ self.pipeline.enable_sequential_cpu_offload()
42
+
43
+ self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
44
+
45
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
46
+ """
47
+ data args:
48
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
49
+ kwargs
50
+ Return:
51
+ A :obj:`list` | `dict`: will be serialized and returned
52
+ """
53
+
54
+ # pseudo
55
+ # self.model(input)
56
+
57
+ # get inputs
58
+ inputs = data.pop("inputs", "")
59
+ encoded_image = data.pop("image", None)
60
+
61
+ # Decode image and convert to black and white sketch
62
+ decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
63
+ sketch_image = self.pidinet(
64
+ decoded_image,
65
+ detect_resolution=1024,
66
+ image_resolution=1024,
67
+ apply_filter=True
68
+ ).convert('L')
69
+
70
+ # sketch_image.save("./output1.png")
71
+
72
+ output_image = self.pipeline(
73
+ prompt=inputs,
74
+ negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
75
+ image=sketch_image,
76
+ guidance_scale=7.5,
77
+ ).images[0]
78
+
79
+ # output_image.save("./output2.png")
80
+ return output_image
81
+
82
+ # helper to decode input image
83
+ def decode_base64_image(self, image_string):
84
+ base64_image = base64.b64decode(image_string)
85
+ buffer = BytesIO(base64_image)
86
+ image = Image.open(buffer)
87
+ return image
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ controlnet-aux