abhishek HF staff commited on
Commit
ec08a0d
1 Parent(s): 2c60cd7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ from PIL import Image
6
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
7
+ from diffusers import ControlNetModel
8
+ from diffusers import UniPCMultistepScheduler
9
+ from controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
10
+ import colorsys
11
+
12
+ sam_checkpoint = "weights/sam_vit_h_4b8939.pth"
13
+ model_type = "vit_h"
14
+ device = "cuda"
15
+
16
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
17
+ sam.to(device=device)
18
+ predictor = SamPredictor(sam)
19
+ mask_generator = SamAutomaticMaskGenerator(sam)
20
+
21
+ # pipe = StableDiffusionInpaintPipeline.from_pretrained(
22
+ # "stabilityai/stable-diffusion-2-inpainting",
23
+ # torch_dtype=torch.float16,
24
+ # )
25
+ # pipe = pipe.to("cuda")
26
+
27
+ controlnet = ControlNetModel.from_pretrained(
28
+ "lllyasviel/sd-controlnet-seg",
29
+ torch_dtype=torch.float16,
30
+ )
31
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
32
+ "runwayml/stable-diffusion-inpainting",
33
+ controlnet=controlnet,
34
+ torch_dtype=torch.float16,
35
+ )
36
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
37
+ pipe.enable_model_cpu_offload()
38
+ pipe.enable_xformers_memory_efficient_attention()
39
+
40
+
41
+ with gr.Blocks() as demo:
42
+ selected_pixels = gr.State([])
43
+ with gr.Row():
44
+ input_img = gr.Image(label="Input")
45
+ mask_img = gr.Image(label="Mask")
46
+ seg_img = gr.Image(label="Segmentation")
47
+ output_img = gr.Image(label="Output")
48
+
49
+ with gr.Row():
50
+ prompt_text = gr.Textbox(lines=1, label="Prompt")
51
+ negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")
52
+ is_background = gr.Checkbox(label="Background")
53
+
54
+ with gr.Row():
55
+ submit = gr.Button("Submit")
56
+ clear = gr.Button("Clear")
57
+
58
+ def generate_mask(image, bg, sel_pix, evt: gr.SelectData):
59
+ sel_pix.append(evt.index)
60
+ predictor.set_image(image)
61
+ input_point = np.array(sel_pix)
62
+ input_label = np.ones(input_point.shape[0])
63
+ mask, _, _ = predictor.predict(
64
+ point_coords=input_point,
65
+ point_labels=input_label,
66
+ multimask_output=False,
67
+ )
68
+ if bg:
69
+ mask = np.logical_not(mask)
70
+ mask = Image.fromarray(mask[0, :, :])
71
+ segs = mask_generator.generate(image)
72
+ boolean_masks = [s["segmentation"] for s in segs]
73
+ finseg = np.zeros((boolean_masks[0].shape[0], boolean_masks[0].shape[1], 3), dtype=np.uint8)
74
+ # Loop over the boolean masks and assign a unique color to each class
75
+ for class_id, boolean_mask in enumerate(boolean_masks):
76
+ hue = class_id * 1.0 / len(boolean_masks)
77
+ rgb = tuple(int(i * 255) for i in colorsys.hsv_to_rgb(hue, 1, 1))
78
+ rgb_mask = np.zeros((boolean_mask.shape[0], boolean_mask.shape[1], 3), dtype=np.uint8)
79
+ rgb_mask[:, :, 0] = boolean_mask * rgb[0]
80
+ rgb_mask[:, :, 1] = boolean_mask * rgb[1]
81
+ rgb_mask[:, :, 2] = boolean_mask * rgb[2]
82
+ finseg += rgb_mask
83
+
84
+ return mask, finseg
85
+
86
+ def inpaint(image, mask, seg_img, prompt, negative_prompt):
87
+ image = Image.fromarray(image)
88
+ mask = Image.fromarray(mask)
89
+ seg_img = Image.fromarray(seg_img)
90
+
91
+ image = image.resize((512, 512))
92
+ mask = mask.resize((512, 512))
93
+ seg_img = seg_img.resize((512, 512))
94
+
95
+ output = pipe(prompt, image, mask, seg_img, negative_prompt=negative_prompt).images[0]
96
+ return output
97
+
98
+ def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
99
+ sel_pix = []
100
+ img = None
101
+ mask = None
102
+ seg = None
103
+ out = None
104
+ prompt = ""
105
+ neg_prompt = ""
106
+ bg = False
107
+ return img, mask, seg, out, prompt, neg_prompt, bg
108
+
109
+ input_img.select(
110
+ generate_mask,
111
+ [input_img, is_background, selected_pixels],
112
+ [mask_img, seg_img],
113
+ )
114
+ submit.click(
115
+ inpaint,
116
+ inputs=[input_img, mask_img, seg_img, prompt_text, negative_prompt_text],
117
+ outputs=[output_img],
118
+ )
119
+ clear.click(
120
+ _clear,
121
+ inputs=[
122
+ selected_pixels,
123
+ input_img,
124
+ mask_img,
125
+ seg_img,
126
+ output_img,
127
+ prompt_text,
128
+ negative_prompt_text,
129
+ is_background,
130
+ ],
131
+ outputs=[
132
+ input_img,
133
+ mask_img,
134
+ seg_img,
135
+ output_img,
136
+ prompt_text,
137
+ negative_prompt_text,
138
+ is_background,
139
+ ],
140
+ )
141
+
142
+ if __name__ == "__main__":
143
+ demo.queue(concurrency_count=50).launch()