Ahmed Essam committed on
Commit 402cce1 · verified · Parent: 5a554da

Upload 5 files

Files changed (5)
  1. handler.py +36 -0
  2. model.py +115 -0
  3. preprocessor.py +27 -0
  4. requirements.txt +13 -0
  5. settings.py +17 -0
handler.py ADDED
@@ -0,0 +1,36 @@
+ from typing import Any
+ import base64
+ from io import BytesIO
+
+ import torch
+ from PIL import Image
+
+ from model import Model
+
+ # set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if device.type != "cuda":
+     raise ValueError("This handler needs to run on a GPU")
+
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # load the model (SD 1.5 + line-art ControlNet)
+         self.model = Model()
+
+     def __call__(self, data: Any) -> Any:
+         """
+         Args:
+             data (:obj:`dict`):
+                 includes the input data and the parameters for the inference.
+         Return:
+             A :obj:`dict`: the base64-encoded result image.
+         """
+         inputs = data.pop("image", data)
+
+         # decode the base64-encoded input image
+         image = Image.open(BytesIO(base64.b64decode(inputs)))
+
+         # run the inference pipeline; process_lineart returns
+         # [control_image, generated_image] when DEFAULT_NUM_IMAGES == 1
+         _, res = self.model.process_lineart(image=image)
+
+         # encoding the image as base64 is done by the default toolkit
+         return res
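
For reference, a minimal local smoke test of the handler might look like the sketch below. It is not part of the commit and assumes a CUDA GPU, the five files above on the Python path, and an input.png in the working directory (all hypothetical names):

    import base64

    from handler import EndpointHandler

    handler = EndpointHandler()

    # hypothetical test image; any base64-encoded PNG/JPEG works
    with open("input.png", "rb") as f:
        payload = {"image": base64.b64encode(f.read()).decode("utf-8")}

    # returns a PIL.Image locally; in production the Inference
    # Toolkit handles the base64 encoding of the response
    result = handler(payload)
    result.save("output.png")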
model.py ADDED
@@ -0,0 +1,115 @@
+ from __future__ import annotations
+
+ import gc
+
+ import PIL.Image
+ import torch
+ from diffusers import (
+     ControlNetModel,
+     DiffusionPipeline,
+     StableDiffusionControlNetPipeline,
+     UniPCMultistepScheduler,
+ )
+
+ from preprocessor import Preprocessor
+ from settings import *
+
+
+ class Model:
+     def __init__(self, base_model_id: str = "runwayml/stable-diffusion-v1-5", task_name: str = "lineart"):
+         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.base_model_id = ""
+         self.task_name = ""
+         self.pipe = self.load_pipe(base_model_id, task_name)
+         self.preprocessor = Preprocessor()
+
+     def load_pipe(self, base_model_id: str, task_name: str) -> DiffusionPipeline:
+         # reuse the cached pipeline when nothing has changed
+         if (
+             base_model_id == self.base_model_id
+             and task_name == self.task_name
+             and hasattr(self, "pipe")
+             and self.pipe is not None
+         ):
+             return self.pipe
+         # `model_id` comes from settings.py via the star import
+         controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
+         pipe = StableDiffusionControlNetPipeline.from_pretrained(
+             base_model_id, safety_checker=None, controlnet=controlnet, torch_dtype=torch.float16
+         )
+         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+         if self.device.type == "cuda":
+             pipe.enable_xformers_memory_efficient_attention()
+         pipe.to(self.device)
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.base_model_id = base_model_id
+         self.task_name = task_name
+         return pipe
+
+     def set_base_model(self, base_model_id: str) -> str:
+         if not base_model_id or base_model_id == self.base_model_id:
+             return self.base_model_id
+         del self.pipe
+         torch.cuda.empty_cache()
+         gc.collect()
+         try:
+             self.pipe = self.load_pipe(base_model_id, self.task_name)
+         except Exception:
+             # fall back to the previous base model if loading fails
+             self.pipe = self.load_pipe(self.base_model_id, self.task_name)
+         return self.base_model_id
+
+     def load_controlnet_weight(self, task_name: str) -> None:
+         if task_name == self.task_name:
+             return
+         if self.pipe is not None and hasattr(self.pipe, "controlnet"):
+             del self.pipe.controlnet
+         torch.cuda.empty_cache()
+         gc.collect()
+         controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
+         controlnet.to(self.device)
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.pipe.controlnet = controlnet
+         self.task_name = task_name
+
+     def get_prompt(self, prompt: str, additional_prompt: str) -> str:
+         if not prompt:
+             prompt = additional_prompt
+         else:
+             prompt = f"{prompt}, {additional_prompt}"
+         return prompt
+
+     @torch.autocast("cuda")
+     def run_pipe(
+         self,
+         control_image: PIL.Image.Image,
+     ) -> list[PIL.Image.Image]:
+         # prompt, a_prompt, n_prompt, guidance_scale, num_steps and
+         # randomize_seed all come from settings.py via the star import
+         generator = torch.Generator().manual_seed(randomize_seed)
+         return self.pipe(
+             prompt=prompt + " " + a_prompt,
+             negative_prompt=n_prompt,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=DEFAULT_NUM_IMAGES,
+             num_inference_steps=num_steps,
+             generator=generator,
+             image=control_image,
+         ).images
+
+     def process_lineart(
+         self,
+         image: PIL.Image.Image,
+     ) -> list[PIL.Image.Image]:
+         if image is None:
+             raise ValueError("input image must not be None")
+         self.preprocessor.load("Lineart")
+         control_image = self.preprocessor(
+             image=image,
+             image_resolution=DEFAULT_IMAGE_RESOLUTION,
+             detect_resolution=preprocess_resolution,
+         )
+         self.load_controlnet_weight("lineart")
+         results = self.run_pipe(control_image=control_image)
+         return [control_image] + results
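
A hedged usage sketch for Model, not part of the commit: process_lineart returns the control map first, then the generated images, and set_base_model can swap in another SD 1.5-compatible checkpoint at runtime (the repo id below is a placeholder, not a real model):

    from PIL import Image

    from model import Model

    model = Model()  # loads runwayml/stable-diffusion-v1-5 plus the line-art ControlNet

    # "some-org/sd15-finetune" is a hypothetical SD 1.5-compatible checkpoint;
    # on a failed load, set_base_model falls back to the previous base model
    model.set_base_model("some-org/sd15-finetune")

    images = model.process_lineart(image=Image.open("sketch.png"))
    images[1].save("render.png")  # images[0] is the control map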
preprocessor.py ADDED
@@ -0,0 +1,27 @@
+ import gc
+
+ import PIL.Image
+ import torch
+ from controlnet_aux import LineartDetector
+
+
+ class Preprocessor:
+     MODEL_ID = "lllyasviel/Annotators"
+
+     def __init__(self):
+         self.model = None
+         self.name = ""
+
+     def load(self, name: str) -> None:
+         # skip reloading when the requested detector is already active
+         if name == self.name:
+             return
+         if name == "Lineart":
+             self.model = LineartDetector.from_pretrained(self.MODEL_ID)
+         else:
+             raise ValueError(f"unsupported preprocessor: {name}")
+         torch.cuda.empty_cache()
+         gc.collect()
+         self.name = name
+
+     def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
+         return self.model(image, **kwargs)
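
The preprocessor can also be exercised on its own to inspect the control map before running the full pipeline; a minimal sketch, assuming a sketch.png exists and the lllyasviel/Annotators weights can be downloaded:

    import PIL.Image

    from preprocessor import Preprocessor

    pre = Preprocessor()
    pre.load("Lineart")  # any other name raises ValueError
    control = pre(
        image=PIL.Image.open("sketch.png"),
        image_resolution=768,   # matches DEFAULT_IMAGE_RESOLUTION in settings.py
        detect_resolution=512,  # matches preprocess_resolution in settings.py
    )
    control.save("control_map.png")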
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ accelerate==0.21.0
+ controlnet_aux==0.0.6
+ diffusers==0.18.2
+ einops==0.6.1
+ gradio==3.45.2
+ huggingface-hub==0.16.4
+ mediapipe==0.10.1
+ opencv-python-headless==4.8.0.74
+ safetensors==0.3.1
+ torch==2.0.1
+ torchvision==0.15.2
+ transformers==4.30.2
+ xformers==0.0.20
settings.py ADDED
@@ -0,0 +1,17 @@
+ DEFAULT_MODEL_ID = "runwayml/stable-diffusion-v1-5"
+ DEFAULT_NUM_IMAGES = 1
+ MAX_IMAGE_RESOLUTION = 768
+ DEFAULT_IMAGE_RESOLUTION = 768
+ preprocess_resolution = 512
+ num_steps = 20
+ guidance_scale = 9
+ randomize_seed = 0
+
+ task_name = "lineart"
+ model_id = "lllyasviel/control_v11p_sd15_lineart"
+ prompt = "Architecture, Building, Realistic, 3D Rendering, 2D Elevation, Professional."
+
+ a_prompt = "best quality, extremely detailed"
+ n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
+ preprocessor_name = "lineart"
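
Because model.py binds these names once at import time via `from settings import *`, generation behaviour is tuned by editing settings.py itself before the handler starts, for example (hypothetical values, not part of the commit):

    num_steps = 30        # more denoising steps: slower, often sharper
    guidance_scale = 7.5  # weaker prompt adherence than the default 9
    randomize_seed = 42   # any fixed seed keeps outputs reproducible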