BertChristiaens
commited on
Commit
·
32a644b
1
Parent(s):
ef697d2
flush
Browse files
app.py
CHANGED
@@ -182,11 +182,6 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
|
|
182 |
|
183 |
|
184 |
elif generation_mode == "Re-generate objects":
|
185 |
-
st.write("This mode allows you to choose which objects you want to re-generate in the image. "
|
186 |
-
"Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
|
187 |
-
" to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
|
188 |
-
" the 'move image to input' button."
|
189 |
-
)
|
190 |
canvas = st_canvas(
|
191 |
**canvas_dict,
|
192 |
)
|
@@ -209,6 +204,12 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
|
|
209 |
default=st.session_state['unique_colors'],
|
210 |
format_func=map_colors_rgb,
|
211 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
if st.button("generate image", key='generate_button'):
|
214 |
image = get_image()
|
|
|
182 |
|
183 |
|
184 |
elif generation_mode == "Re-generate objects":
|
|
|
|
|
|
|
|
|
|
|
185 |
canvas = st_canvas(
|
186 |
**canvas_dict,
|
187 |
)
|
|
|
204 |
default=st.session_state['unique_colors'],
|
205 |
format_func=map_colors_rgb,
|
206 |
)
|
207 |
+
with st.expander("Explanation", expanded=False):
|
208 |
+
st.write("This mode allows you to choose which objects you want to re-generate in the image. "
|
209 |
+
"Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
|
210 |
+
" to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
|
211 |
+
" the 'move image to input' button."
|
212 |
+
)
|
213 |
|
214 |
if st.button("generate image", key='generate_button'):
|
215 |
image = get_image()
|
models.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Tuple, Dict
|
|
4 |
|
5 |
import streamlit as st
|
6 |
import torch
|
|
|
7 |
import time
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
@@ -23,22 +24,26 @@ from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNe
|
|
23 |
|
24 |
LOGGING = logging.getLogger(__name__)
|
25 |
|
|
|
|
|
|
|
26 |
|
27 |
class ControlNetPipeline:
|
28 |
def __init__(self):
|
29 |
self.in_use = False
|
30 |
self.controlnet = ControlNetModel.from_pretrained(
|
31 |
-
"BertChristiaens/controlnet-seg-room", torch_dtype=torch.
|
32 |
|
33 |
self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
|
34 |
"runwayml/stable-diffusion-inpainting",
|
35 |
controlnet=self.controlnet,
|
36 |
safety_checker=None,
|
37 |
-
torch_dtype=torch.
|
38 |
)
|
39 |
|
40 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
41 |
self.pipe.enable_xformers_memory_efficient_attention()
|
|
|
42 |
self.pipe = self.pipe.to("cuda")
|
43 |
|
44 |
self.waiting_queue = []
|
@@ -59,8 +64,10 @@ class ControlNetPipeline:
|
|
59 |
# it's your turn, so remove the number from the queue
|
60 |
# and call the function
|
61 |
print("It's the turn of", self.count)
|
|
|
62 |
self.waiting_queue.pop(0)
|
63 |
-
|
|
|
64 |
|
65 |
|
66 |
@contextmanager
|
|
|
4 |
|
5 |
import streamlit as st
|
6 |
import torch
|
7 |
+
import gc
|
8 |
import time
|
9 |
import numpy as np
|
10 |
from PIL import Image
|
|
|
24 |
|
25 |
LOGGING = logging.getLogger(__name__)
|
26 |
|
27 |
+
def flush():
|
28 |
+
gc.collect()
|
29 |
+
torch.cuda.empty_cache()
|
30 |
|
31 |
class ControlNetPipeline:
|
32 |
def __init__(self):
|
33 |
self.in_use = False
|
34 |
self.controlnet = ControlNetModel.from_pretrained(
|
35 |
+
"BertChristiaens/controlnet-seg-room", torch_dtype=torch.float32)
|
36 |
|
37 |
self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
|
38 |
"runwayml/stable-diffusion-inpainting",
|
39 |
controlnet=self.controlnet,
|
40 |
safety_checker=None,
|
41 |
+
torch_dtype=torch.float32
|
42 |
)
|
43 |
|
44 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
45 |
self.pipe.enable_xformers_memory_efficient_attention()
|
46 |
+
self.pipe.enable_attention_slicing("max")
|
47 |
self.pipe = self.pipe.to("cuda")
|
48 |
|
49 |
self.waiting_queue = []
|
|
|
64 |
# it's your turn, so remove the number from the queue
|
65 |
# and call the function
|
66 |
print("It's the turn of", self.count)
|
67 |
+
results = self.pipe(**kwargs)
|
68 |
self.waiting_queue.pop(0)
|
69 |
+
flush()
|
70 |
+
return results
|
71 |
|
72 |
|
73 |
@contextmanager
|