mfidabel committed
Commit: 6447c95
Parent: 6162149

Added Segmentation to the Space
app.py CHANGED
@@ -1,119 +1,161 @@
- import gradio as gr
- import jax
  from PIL import Image
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
- from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
- import jax.numpy as jnp
  import numpy as np
  import gc


- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
-     "mfidabel/controlnet-segment-anything", dtype=jnp.float32
- )

- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
- )

- # Add ControlNet params and Replicate
- params["controlnet"] = controlnet_params
- p_params = replicate(params)

  # Description
  title = "# 🧨 ControlNet on Segment Anything 🤗"
  description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).

- Upload a Segment Anything Segmentation Map, write a prompt, and generate images 🤗 This demo is still a Work in Progress, so don't expect it to work well for now !!
-
- ⌛️ It takes about 30~ seconds to generate 4 samples, to get faster results, don't forget to reduce the Nº Samples to 1.
-
  You can obtain the Segmentation Map of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfidabel/JAX_SPRINT_2023/blob/main/Segment_Anything_JAX_SPRINT.ipynb)
-

- A huge thanks goes out to @Google Cloud, for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
-
  Check out our [Model Card 🧨](https://huggingface.co/mfidabel/controlnet-segment-anything)
  """

  about = """
-
-
  # 👨‍💻 About the model
-
  This [model](https://huggingface.co/mfidabel/controlnet-segment-anything) is based on the [ControlNet Model](https://huggingface.co/blog/controlnet), which allow us to generate Images using some sort of condition image. For this model, we selected the segmentation maps produced by Meta's new segmentation model called [Segment Anything Model](https://github.com/facebookresearch/segment-anything) as the condition image. We then trained the model to generate images based on the structure of the segmentation maps and the text prompts given.

-
  # 💾 About the dataset
-
  For the training, we generated a segmented dataset based on the [COYO-700M](https://huggingface.co/datasets/kakaobrain/coyo-700m) dataset. The dataset provided us with the images, and the text prompts. For the segmented images, we used [Segment Anything Model](https://github.com/facebookresearch/segment-anything). We then created 8k samples to train our model on, which isn't a lot, but as a team, we have been very busy with many other responsibilities and time constraints, which made it challenging to dedicate a lot of time to generating a larger dataset. Despite the constraints we faced, we have still managed to achieve some nice results 🙌
-
  You can check the generated datasets below ⬇️
  - [sam-coyo-2k](https://huggingface.co/datasets/mfidabel/sam-coyo-2k)
  - [sam-coyo-2.5k](https://huggingface.co/datasets/mfidabel/sam-coyo-2.5k)
  - [sam-coyo-3k](https://huggingface.co/datasets/mfidabel/sam-coyo-3k)
-
  """

- examples = [["contemporary living room of a house", "low quality", "examples/condition_image_1.png"],
-             ["new york buildings, Vincent Van Gogh starry night ", "low quality, monochrome", "examples/condition_image_2.png"],
-             ["contemporary living room, high quality, 4k, realistic", "low quality, monochrome, low res", "examples/condition_image_3.png"],
-             ["internal stairs of a japanese house", "low quality, low res, people, kids", "examples/condition_image_4.png"],
-             ["a photo of a girl taking notes", "low quality, low res, painting", "examples/condition_image_5.png"],
-             ["painting of an hot air ballon flying over a valley, The Great Wave off Kanagawa style, blue and white colors", "low quality, low res", "examples/condition_image_6.png"],
-             ["painting of families enjoying the sunset, The Garden of Earthly Delights style, joyful", "low quality, low res", "examples/condition_image_7.png"]]

  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

  # Inference Function
  def infer(prompts, negative_prompts, image, num_inference_steps = 50, seed = 4, num_samples = 4):
      try:
-         rng = jax.random.PRNGKey(int(seed))
          num_inference_steps = int(num_inference_steps)
-         image = Image.fromarray(image, mode="RGB")
-         num_samples = max(jax.device_count(), int(num_samples))
-         p_rng = jax.random.split(rng, jax.device_count())
-
-         prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-         negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
-         processed_image = pipe.prepare_image_inputs([image] * num_samples)
-
-         prompt_ids = shard(prompt_ids)
-         negative_prompt_ids = shard(negative_prompt_ids)
-         processed_image = shard(processed_image)

-         output = pipe(
-             prompt_ids=prompt_ids,
-             image=processed_image,
-             params=p_params,
-             prng_seed=p_rng,
-             num_inference_steps=num_inference_steps,
-             neg_prompt_ids=negative_prompt_ids,
-             jit=True,
-         ).images
-
-         del negative_prompt_ids
-         del processed_image
-         del prompt_ids
-
-         output = output.reshape((num_samples,) + output.shape[-3:])
-         final_image = [np.array(x*255, dtype=np.uint8) for x in output]
-         print(output.shape)
          del output

      except Exception as e:
          print("Error: " + str(e))
-         final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
      finally:
          gc.collect()
-         return final_image


- default_example = examples[5]
-
  cond_img = gr.Image(label="Input", shape=(512, 512), value=default_example[2])\
-             .style(height=200)

  output = gr.Gallery(label="Generated images")\
              .style(height=200, rows=[2], columns=[2], object_fit="contain")
@@ -132,17 +174,16 @@ with gr.Blocks(css=css) as demo:

          with gr.Column():
              # Examples
-             gr.Markdown("Try some of the examples below ⬇️")
-             gr.Examples(examples=examples,
-                         inputs=[prompt, negative_prompt, cond_img],
-                         outputs=output,
-                         fn=infer,
-                         examples_per_page=4)

      # Images
      with gr.Row(variant="panel"):
-         with gr.Column(scale=2):
              cond_img.render()
          with gr.Column(scale=1):
              output.render()

@@ -158,15 +199,29 @@ with gr.Blocks(css=css) as demo:
          seed = gr.Slider(0, 1024, 4, step=1, label="Seed")
          num_samples = gr.Slider(1, 4, 4, step=1, label="Nº Samples")

-         submit = gr.Button("Generate")
          # TODO: Download Button

      with gr.Row():
-         gr.Markdown(about, elem_classes="about")

      submit.click(infer,
                   inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
-                  outputs = output)

  demo.queue()
  demo.launch()
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
  from PIL import Image
+ import gradio as gr
  import numpy as np
+ import requests
+ import torch
  import gc

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Download and Create SAM Model
+
+ print("[Downloading SAM Weights]")
+ SAM_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+
+ r = requests.get(SAM_URL, allow_redirects=True)
+
+ print("[Writing SAM Weights]")
+
+ with open("./sam_vit_h_4b8939.pth", "wb") as sam_weights:
+     sam_weights.write(r.content)
+
+ del r
+ gc.collect()
+
+ sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth").to(device)
+
+ mask_generator = SamAutomaticMaskGenerator(sam)
+ gc.collect()
+
+ # Create ControlNet Pipeline

+ print("Creating ControlNet Pipeline")

+ controlnet = ControlNetModel.from_pretrained(
+     "mfidabel/controlnet-segment-anything", torch_dtype=torch.float16
+ ).to(device)
+
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
+ ).to(device)


  # Description
  title = "# 🧨 ControlNet on Segment Anything 🤗"
  description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).

+ Upload an Image, Segment it with Segment Anything, write a prompt, and generate images 🤗
+
+ ⌛️ It takes about ~20 seconds to generate 4 samples. To get faster results, don't forget to reduce the Nº Samples to 1.
+
  You can obtain the Segmentation Map of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfidabel/JAX_SPRINT_2023/blob/main/Segment_Anything_JAX_SPRINT.ipynb)

+ A huge thanks goes out to @GoogleCloud for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
+
  Check out our [Model Card 🧨](https://huggingface.co/mfidabel/controlnet-segment-anything)
+
  """

  about = """
  # 👨‍💻 About the model
+
  This [model](https://huggingface.co/mfidabel/controlnet-segment-anything) is based on the [ControlNet Model](https://huggingface.co/blog/controlnet), which allow us to generate Images using some sort of condition image. For this model, we selected the segmentation maps produced by Meta's new segmentation model called [Segment Anything Model](https://github.com/facebookresearch/segment-anything) as the condition image. We then trained the model to generate images based on the structure of the segmentation maps and the text prompts given.
+

+
  # 💾 About the dataset
+
  For the training, we generated a segmented dataset based on the [COYO-700M](https://huggingface.co/datasets/kakaobrain/coyo-700m) dataset. The dataset provided us with the images, and the text prompts. For the segmented images, we used [Segment Anything Model](https://github.com/facebookresearch/segment-anything). We then created 8k samples to train our model on, which isn't a lot, but as a team, we have been very busy with many other responsibilities and time constraints, which made it challenging to dedicate a lot of time to generating a larger dataset. Despite the constraints we faced, we have still managed to achieve some nice results 🙌
+
  You can check the generated datasets below ⬇️
  - [sam-coyo-2k](https://huggingface.co/datasets/mfidabel/sam-coyo-2k)
  - [sam-coyo-2.5k](https://huggingface.co/datasets/mfidabel/sam-coyo-2.5k)
  - [sam-coyo-3k](https://huggingface.co/datasets/mfidabel/sam-coyo-3k)
+
  """

+ gif_html = """ <img src="https://github.com/mfidabel/JAX_SPRINT_2023/blob/8632f0fde7388d7a4fc57225c96ef3b8411b3648/EX_1.gif?raw=true" alt="" height="50%" class="about"> """
+
+ examples = [["photo of a futuristic dining table, high quality, tricolor", "low quality, deformed, blurry, points", "examples/condition_image_1.jpeg"],
+             ["a monochrome photo of henry cavill wearing a shirt, high quality", "low quality, low res, deformed", "examples/condition_image_2.jpeg"],
+             ["photo of a japanese living room, high quality, coherent", "low quality, colors, saturation, extreme brightness, blurry, low res", "examples/condition_image_3.jpeg"],
+             ["living room, detailed, high quality", "low quality, low resolution, render, oversaturated, low contrast", "examples/condition_image_4.jpeg"],
+             ["painting of the bodiam castle, Vincent Van Gogh style, Starry Night", "low quality, low resolution, render, oversaturated, low contrast", "examples/condition_image_5.jpeg"],
+             ["painting of food, olive oil can, purple wine, green cabbage, chili peppers, pablo picasso style, high quality", "low quality, low resolution, render, oversaturated, low contrast, realistic", "examples/condition_image_6.jpeg"],
+             ["Katsushika Hokusai painting of mountains, a sky and desert landscape, The Great Wave off Kanagawa style, colorful",
+              "low quality, low resolution, render, oversaturated, low contrast, realistic", "examples/condition_image_7.jpeg"]]
+
+ default_example = examples[4]
+
+ examples = examples[::-1]

  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

  # Inference Function
+ def show_anns(anns):
+     if len(anns) == 0:
+         return
+     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+     h, w = anns[0]['segmentation'].shape
+     final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
+     for ann in sorted_anns:
+         m = ann['segmentation']
+         img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
+         for i in range(3):
+             img[:,:,i] = np.random.randint(255, dtype=np.uint8)
+         final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255)))
+
+     return final_img
+
+ def segment_image(image, seed = 0):
+     # Generate Masks
+     np.random.seed(int(seed))
+     masks = mask_generator.generate(image)
+     torch.cuda.empty_cache()
+     # Create map
+     map = show_anns(masks)
+     del masks
+     gc.collect()
+     torch.cuda.empty_cache()
+     return map
+
  def infer(prompts, negative_prompts, image, num_inference_steps = 50, seed = 4, num_samples = 4):
      try:
+         # Segment Image
+         print("Segmenting Everything")
+         segmented_map = segment_image(image, seed)
+         yield segmented_map, [Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))] * num_samples
+         # Generate
+         rng = torch.Generator(device="cpu").manual_seed(int(seed))
          num_inference_steps = int(num_inference_steps)

+         print(f"Generating Prompt: {prompts} \nNegative Prompt: {negative_prompts} \nSamples:{num_samples}")
+         output = pipe([prompts] * num_samples,
+                       [segmented_map] * num_samples,
+                       negative_prompt = [negative_prompts] * num_samples,
+                       generator = rng,
+                       num_inference_steps = num_inference_steps)
+
+
+         final_image = output.images
          del output

      except Exception as e:
          print("Error: " + str(e))
+         final_image = segmented_map = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
      finally:
          gc.collect()
+         torch.cuda.empty_cache()
+         yield segmented_map, final_image


  cond_img = gr.Image(label="Input", shape=(512, 512), value=default_example[2])\
+             .style(height=400)
+
+ segm_img = gr.Image(label="Segmented Image", shape=(512, 512), interactive=False)\
+             .style(height=400)

  output = gr.Gallery(label="Generated images")\
              .style(height=200, rows=[2], columns=[2], object_fit="contain")


          with gr.Column():
              # Examples
+             gr.Markdown(gif_html)

      # Images
      with gr.Row(variant="panel"):
+         with gr.Column(scale=1):
              cond_img.render()
+
+         with gr.Column(scale=1):
+             segm_img.render()
+
          with gr.Column(scale=1):
              output.render()

          seed = gr.Slider(0, 1024, 4, step=1, label="Seed")
          num_samples = gr.Slider(1, 4, 4, step=1, label="Nº Samples")

+         segment_btn = gr.Button("Segment")
+         submit = gr.Button("Segment & Generate Images")
          # TODO: Download Button

      with gr.Row():
+         with gr.Column():
+             gr.Markdown("Try some of the examples below ⬇️")
+             gr.Examples(examples=examples,
+                         inputs=[prompt, negative_prompt, cond_img],
+                         outputs=output,
+                         fn=infer,
+                         examples_per_page=4)
+
+         with gr.Column():
+             gr.Markdown(about, elem_classes="about")

      submit.click(infer,
                   inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
+                  outputs = [segm_img, output])
+
+     segment_btn.click(segment_image,
+                       inputs=[cond_img, seed],
+                       outputs=segm_img)

  demo.queue()
  demo.launch()
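
For readers reviewing the change, the sketch below strings the new pieces together outside of Gradio: SAM's automatic mask generator produces masks for an input image, the masks are painted into a coloured conditioning map (mirroring show_anns above), and the new PyTorch ControlNet pipeline generates from that map. This is a minimal sketch under stated assumptions, not the Space itself: the example image path, prompt, seed, and output filename are illustrative placeholders, a CUDA GPU is assumed for the float16 weights, and the SAM checkpoint is assumed to have been downloaded as in app.py.

import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

device = "cuda"  # fp16 weights assume a CUDA GPU is available

# 1. Segment the input image with SAM's automatic mask generator
sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth").to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
image = np.array(Image.open("examples/condition_image_5.jpeg").convert("RGB").resize((512, 512)))
masks = mask_generator.generate(image)

# 2. Paint each mask in a random colour to build the conditioning map
h, w = masks[0]["segmentation"].shape
seg_map = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
for ann in sorted(masks, key=lambda a: a["area"], reverse=True):
    layer = np.zeros((h, w, 3), dtype=np.uint8)
    layer[:, :] = np.random.randint(0, 255, size=3, dtype=np.uint8)
    seg_map.paste(Image.fromarray(layer, mode="RGB"), (0, 0),
                  Image.fromarray(np.uint8(ann["segmentation"] * 255)))

# 3. Generate an image conditioned on the segmentation map
controlnet = ControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to(device)
result = pipe(
    ["painting of the bodiam castle, Vincent Van Gogh style, Starry Night"],
    [seg_map],
    negative_prompt=["low quality, low resolution, oversaturated"],
    generator=torch.Generator(device="cpu").manual_seed(4),
    num_inference_steps=50,
)
result.images[0].save("generated.png")

The Space streams two results per click (the segmentation map first, then the gallery) by making infer a generator, which is why the Gradio handler yields twice instead of returning once.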
examples/condition_image_1.jpeg ADDED
examples/condition_image_2.jpeg ADDED
examples/condition_image_3.jpeg ADDED
examples/condition_image_4.jpeg ADDED
examples/condition_image_5.jpeg ADDED
examples/condition_image_6.jpeg ADDED
examples/condition_image_7.jpeg ADDED
requirements.txt CHANGED
@@ -5,4 +5,7 @@ jax[cuda11_pip]
  jaxlib
  git+https://github.com/huggingface/diffusers@main
  opencv-python
- torch
+ torch
+ torchvision
+ git+https://github.com/facebookresearch/segment-anything.git
+ accelerate