BertChristiaens commited on
Commit
be0162b
1 Parent(s): e36ef6a
Files changed (1) hide show
  1. models.py +57 -48
models.py CHANGED
@@ -4,6 +4,7 @@ from typing import List, Tuple, Dict
4
 
5
  import streamlit as st
6
  import torch
 
7
  import numpy as np
8
  from PIL import Image
9
  from time import perf_counter
@@ -23,6 +24,45 @@ from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNe
23
  LOGGING = logging.getLogger(__name__)
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @contextmanager
27
  def catchtime(message: str) -> float:
28
  """Context manager to measure time
@@ -81,22 +121,8 @@ def get_controlnet() -> ControlNetModel:
81
  Returns:
82
  ControlNetModel: controlnet model
83
  """
84
- controlnet = ControlNetModel.from_pretrained(
85
- "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
86
-
87
- pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
88
- "runwayml/stable-diffusion-inpainting",
89
- controlnet=controlnet,
90
- safety_checker=None,
91
- torch_dtype=torch.float16
92
- )
93
-
94
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
95
- pipe.enable_xformers_memory_efficient_attention()
96
- pipe = pipe.to("cuda")
97
-
98
- compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
99
- return pipe, compel_proc
100
 
101
 
102
  @st.experimental_singleton(max_entries=5)
@@ -126,9 +152,7 @@ def get_inpainting_pipeline() -> StableDiffusionInpaintPipeline:
126
  pipe.enable_xformers_memory_efficient_attention()
127
  pipe = pipe.to("cuda")
128
 
129
- compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
130
-
131
- return pipe, compel_proc
132
 
133
 
134
  def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
@@ -185,27 +209,18 @@ def make_image_controlnet(image: np.ndarray,
185
  """
186
 
187
  with catchtime("get controlnet"):
188
- pipe, proc = get_controlnet()
189
 
190
  torch.cuda.empty_cache()
191
  images = []
192
 
193
- if '+' in positive_prompt or '-' in positive_prompt or '+' in negative_prompt or '-' in negative_prompt:
194
- common_parameters = {'prompt_embeds': proc(positive_prompt),
195
- 'negative_prompt_embeds': proc(negative_prompt),
196
- 'num_inference_steps': 30,
197
- 'controlnet_conditioning_scale': 1.1,
198
- 'controlnet_conditioning_scale_decay': 0.96,
199
- 'controlnet_steps': 28,
200
- }
201
- else:
202
- common_parameters = {'prompt': positive_prompt,
203
- 'negative_prompt': negative_prompt,
204
- 'num_inference_steps': 30,
205
- 'controlnet_conditioning_scale': 1.1,
206
- 'controlnet_conditioning_scale_decay': 0.96,
207
- 'controlnet_steps': 28,
208
- }
209
 
210
  grid_search = {'strength': [1.00, ],
211
  'guidance_scale': [7.0],
@@ -253,18 +268,12 @@ def make_inpainting(positive_prompt: str,
253
  """
254
 
255
  with catchtime("Get inpainting pipeline"):
256
- pipe, proc = get_inpainting_pipeline()
257
-
258
- if '+' in positive_prompt or '-' in positive_prompt or '+' in negative_prompt or '-' in negative_prompt:
259
- common_parameters = {'prompt_embeds': proc(positive_prompt),
260
- 'negative_prompt_embeds': proc(negative_prompt),
261
- 'num_inference_steps': 20,
262
- }
263
- else:
264
- common_parameters = {'prompt': positive_prompt,
265
- 'negative_prompt': negative_prompt,
266
- 'num_inference_steps': 20,
267
- }
268
 
269
  torch.cuda.empty_cache()
270
  images = []
 
4
 
5
  import streamlit as st
6
  import torch
7
+ import time
8
  import numpy as np
9
  from PIL import Image
10
  from time import perf_counter
 
24
  LOGGING = logging.getLogger(__name__)
25
 
26
 
27
+ class ControlNetPipeline:
28
+ def __init__(self):
29
+ self.in_use = False
30
+ self.controlnet = ControlNetModel.from_pretrained(
31
+ "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
32
+
33
+ self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
34
+ "runwayml/stable-diffusion-inpainting",
35
+ controlnet=self.controlnet,
36
+ safety_checker=None,
37
+ torch_dtype=torch.float16
38
+ )
39
+
40
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
41
+ self.pipe.enable_xformers_memory_efficient_attention()
42
+ self.pipe = self.pipe.to("cuda")
43
+
44
+ self.waiting_queue = []
45
+ self.count = 0
46
+
47
+ def __call__(self, **kwargs):
48
+ self.count += 1
49
+ number = self.count
50
+
51
+ self.waiting_queue.append(number)
52
+
53
+ # wait until the next number in the queue is the current number
54
+ while self.waiting_queue[0] != number:
55
+ print(f"Wait for your turn {number} in queue {self.waiting_queue}")
56
+ time.sleep(0.5)
57
+ pass
58
+
59
+ # it's your turn, so remove the number from the queue
60
+ # and call the function
61
+ self.waiting_queue.pop(0)
62
+ print("It's the turn of", self.count)
63
+ return self.pipe(**kwargs)
64
+
65
+
66
  @contextmanager
67
  def catchtime(message: str) -> float:
68
  """Context manager to measure time
 
121
  Returns:
122
  ControlNetModel: controlnet model
123
  """
124
+ pipe = ControlNetPipeline()
125
+ return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  @st.experimental_singleton(max_entries=5)
 
152
  pipe.enable_xformers_memory_efficient_attention()
153
  pipe = pipe.to("cuda")
154
 
155
+ return pipe
 
 
156
 
157
 
158
  def make_grid_parameters(grid_search: Dict, params: Dict) -> List[Dict]:
 
209
  """
210
 
211
  with catchtime("get controlnet"):
212
+ pipe = get_controlnet()
213
 
214
  torch.cuda.empty_cache()
215
  images = []
216
 
217
+ common_parameters = {'prompt': positive_prompt,
218
+ 'negative_prompt': negative_prompt,
219
+ 'num_inference_steps': 30,
220
+ 'controlnet_conditioning_scale': 1.1,
221
+ 'controlnet_conditioning_scale_decay': 0.96,
222
+ 'controlnet_steps': 28,
223
+ }
 
 
 
 
 
 
 
 
 
224
 
225
  grid_search = {'strength': [1.00, ],
226
  'guidance_scale': [7.0],
 
268
  """
269
 
270
  with catchtime("Get inpainting pipeline"):
271
+ pipe = get_inpainting_pipeline()
272
+
273
+ common_parameters = {'prompt': positive_prompt,
274
+ 'negative_prompt': negative_prompt,
275
+ 'num_inference_steps': 20,
276
+ }
 
 
 
 
 
 
277
 
278
  torch.cuda.empty_cache()
279
  images = []