cashtaan commited on
Commit
9f71ed7
1 Parent(s): 5e4890a

minimal version test

Browse files
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. app.py +176 -0
  3. requirements.txt +7 -0
  4. src/ControlNetInpaint +1 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "src/ControlNetInpaint"]
2
+ path = src/ControlNetInpaint
3
+ url = https://github.com/mikonvergence/ControlNetInpaint.git
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # this code is largely inspired by https://huggingface.co/spaces/hysts/ControlNet-with-Anything-v4/blob/main/app_scribble_interactive.py
2
+ # Thank you, hysts!
3
+
4
+ import sys
5
+ sys.path.append('./src/ControlNetInpaint/')
6
+ # functionality based on https://github.com/mikonvergence/ControlNetInpaint
7
+
8
+ import gradio as gr
9
+ #import torch
10
+ #from torch import autocast // only for GPU
11
+
12
+ from PIL import Image
13
+ import numpy as np
14
+ from io import BytesIO
15
+ import os
16
+
17
+ # Usage
18
+ # 1. Upload image or fill with white
19
+ # 2. Sketch the mask (image->[image,mask]
20
+ # 3. Sketch the content of the mask
21
+
22
+ # Global Storage
23
+ CURRENT_IMAGE={'image' : None,
24
+ 'mask' : None,
25
+ 'guide' : None
26
+ }
27
+
28
+ HEIGHT,WIDTH=512,512
29
+
30
+ ## SETUP PIPE
31
+
32
+ from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
33
+ from src.pipeline_stable_diffusion_controlnet_inpaint import *
34
+ from diffusers.utils import load_image
35
+ from controlnet_aux import HEDdetector
36
+
37
+ hed = HEDdetector.from_pretrained('lllyasviel/ControlNet')
38
+
39
+ controlnet = ControlNetModel.from_pretrained(
40
+ "fusing/stable-diffusion-v1-5-controlnet-scribble", torch_dtype=torch.float16
41
+ )
42
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
43
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
44
+ )
45
+
46
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
47
+
48
+ # Remove if you do not have xformers installed
49
+ # see https://huggingface.co/docs/diffusers/v0.13.0/en/optimization/xformers#installing-xformers
50
+ # for installation instructions
51
+ pipe.enable_xformers_memory_efficient_attention()
52
+
53
+ pipe.to('cuda')
54
+
55
+ # Functions
56
+
57
+ def get_guide(image):
58
+ return hed(image,scribble=True)
59
+
60
+ def set_mask(image):
61
+ img=image['image'][...,:3]
62
+ mask=1*(image['mask'][...,:3]>0)
63
+ # save vars
64
+ CURRENT_IMAGE['image']=img
65
+ CURRENT_IMAGE['mask']=mask
66
+
67
+ guide=get_guide(img)
68
+ CURRENT_IMAGE['guide']=np.array(guide)
69
+ guide=255-np.asarray(guide)
70
+
71
+ seg_img = guide*(1-mask) + mask*192
72
+ preview = img * (seg_img==255)
73
+
74
+ vis_image=(preview/2).astype(seg_img.dtype) + seg_img * (seg_img!=255)
75
+
76
+ return vis_image
77
+
78
+ def generate(image,
79
+ prompt,
80
+ num_steps,
81
+ text_scale,
82
+ sketch_scale,
83
+ seed):
84
+
85
+ sketch=(255*(image['mask'][...,:3]>0)).astype(CURRENT_IMAGE['image'].dtype)
86
+ mask=CURRENT_IMAGE['mask']
87
+
88
+ CURRENT_IMAGE['guide']=(CURRENT_IMAGE['guide']*(mask==0) + sketch*(mask!=0)).astype(CURRENT_IMAGE['image'].dtype)
89
+
90
+ mask_img=255*CURRENT_IMAGE['mask'].astype(CURRENT_IMAGE['image'].dtype)
91
+
92
+ new_image = pipe(
93
+ prompt,
94
+ num_inference_steps=num_steps,
95
+ guidance_scale=text_scale,
96
+ generator=torch.manual_seed(seed),
97
+ image=Image.fromarray(CURRENT_IMAGE['image']),
98
+ control_image=Image.fromarray(CURRENT_IMAGE['guide']),
99
+ controlnet_conditioning_scale=sketch_scale,
100
+ mask_image=Image.fromarray(mask_img)
101
+ ).images
102
+
103
+ return new_image
104
+
105
+ def create_demo(max_images=12, default_num_images=3):
106
+
107
+ with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace","monospace"]),
108
+ css=".gradio-container {background-color: #f2faf2}"
109
+ ) as demo:
110
+ gr.Markdown('## Cut and Sketch ✂️▶️✏️')
111
+
112
+ prompt = gr.Textbox(label='Prompt')
113
+
114
+ with gr.Row():
115
+ with gr.Column():
116
+ with gr.Row():
117
+ input_image = gr.Image(source='upload',
118
+ shape=[HEIGHT,WIDTH],
119
+ type='numpy',
120
+ label='Mask Draw',
121
+ tool='sketch',
122
+ brush_radius=70)
123
+ sketch_image = gr.Image(source='upload',
124
+ shape=[HEIGHT,WIDTH],
125
+ type='numpy',
126
+ label='Fill Draw',
127
+ tool='sketch',
128
+ brush_radius=15)
129
+ with gr.Row():
130
+ mask_button = gr.Button(label='Set Mask', value='Set Mask')
131
+ run_button = gr.Button(label='Generate', value='Generate')
132
+ output_image = gr.Gallery(
133
+ label="Generated images",
134
+ show_label=False,
135
+ elem_id="gallery",
136
+ )
137
+
138
+ with gr.Accordion('Advanced options', open=False):
139
+ num_steps = gr.Slider(label='Steps',
140
+ minimum=1,
141
+ maximum=100,
142
+ value=20,
143
+ step=1)
144
+ text_scale = gr.Slider(label='Text Guidance Scale',
145
+ minimum=0.1,
146
+ maximum=30.0,
147
+ value=7.5,
148
+ step=0.1)
149
+ seed = gr.Slider(label='Seed',
150
+ minimum=-1,
151
+ maximum=2147483647,
152
+ step=1,
153
+ randomize=True)
154
+
155
+ sketch_scale = gr.Slider(label='Sketch Guidance Scale',
156
+ minimum=0.0,
157
+ maximum=1.0,
158
+ value=1.0,
159
+ step=0.05)
160
+
161
+ inputs = [
162
+ sketch_image,
163
+ prompt,
164
+ num_steps,
165
+ text_scale,
166
+ sketch_scale,
167
+ seed
168
+ ]
169
+
170
+ mask_button.click(fn=set_mask, inputs=input_image, outputs=sketch_image)
171
+ run_button.click(fn=generate, inputs=inputs, outputs=output_image)
172
+ return demo
173
+
174
+ if __name__ == '__main__':
175
+ demo = create_demo()
176
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ xformers
3
+ transformers
4
+ scipy
5
+ ftfy
6
+ accelerate
7
+ controlnet_aux
src/ControlNetInpaint ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit f33c386fe81d226dc50d4344c509288a1bcca7a2