Johannes commited on
Commit
a5f6978
1 Parent(s): 1a93fb5

update generate mask method

Browse files
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -8,10 +8,10 @@ from flax.training.common_utils import shard
8
  from PIL import Image
9
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
10
  from diffusers import (
11
- UniPCMultistepScheduler,
12
  FlaxStableDiffusionControlNetPipeline,
13
  FlaxControlNetModel,
14
  )
 
15
 
16
  import colorsys
17
 
@@ -69,7 +69,7 @@ with gr.Blocks() as demo:
69
  submit = gr.Button("Submit")
70
  clear = gr.Button("Clear")
71
 
72
- def generate_mask(image, evt: gr.SelectData):
73
  predictor.set_image(image)
74
  input_point = np.array([120, 21])
75
  input_label = np.ones(input_point.shape[0])
@@ -82,26 +82,26 @@ with gr.Blocks() as demo:
82
  # clear torch cache
83
  torch.cuda.empty_cache()
84
  mask = Image.fromarray(mask[0, :, :])
85
- segs = mask_generator.generate(image)
86
- boolean_masks = [s["segmentation"] for s in segs]
87
- finseg = np.zeros(
88
- (boolean_masks[0].shape[0], boolean_masks[0].shape[1], 3), dtype=np.uint8
89
- )
90
- # Loop over the boolean masks and assign a unique color to each class
91
- for class_id, boolean_mask in enumerate(boolean_masks):
92
- hue = class_id * 1.0 / len(boolean_masks)
93
- rgb = tuple(int(i * 255) for i in colorsys.hsv_to_rgb(hue, 1, 1))
94
- rgb_mask = np.zeros(
95
- (boolean_mask.shape[0], boolean_mask.shape[1], 3), dtype=np.uint8
96
- )
97
- rgb_mask[:, :, 0] = boolean_mask * rgb[0]
98
- rgb_mask[:, :, 1] = boolean_mask * rgb[1]
99
- rgb_mask[:, :, 2] = boolean_mask * rgb[2]
100
- finseg += rgb_mask
101
 
102
  torch.cuda.empty_cache()
103
 
104
- return mask, finseg
105
 
106
  def infer(
107
  image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
 
8
  from PIL import Image
9
  from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
10
  from diffusers import (
 
11
  FlaxStableDiffusionControlNetPipeline,
12
  FlaxControlNetModel,
13
  )
14
+ from transformers import pipeline
15
 
16
  import colorsys
17
 
 
69
  submit = gr.Button("Submit")
70
  clear = gr.Button("Clear")
71
 
72
+ def generate_mask(image):
73
  predictor.set_image(image)
74
  input_point = np.array([120, 21])
75
  input_label = np.ones(input_point.shape[0])
 
82
  # clear torch cache
83
  torch.cuda.empty_cache()
84
  mask = Image.fromarray(mask[0, :, :])
85
+ # segs = mask_generator.generate(image)
86
+ # boolean_masks = [s["segmentation"] for s in segs]
87
+ # finseg = np.zeros(
88
+ # (boolean_masks[0].shape[0], boolean_masks[0].shape[1], 3), dtype=np.uint8
89
+ # )
90
+ # # Loop over the boolean masks and assign a unique color to each class
91
+ # for class_id, boolean_mask in enumerate(boolean_masks):
92
+ # hue = class_id * 1.0 / len(boolean_masks)
93
+ # rgb = tuple(int(i * 255) for i in colorsys.hsv_to_rgb(hue, 1, 1))
94
+ # rgb_mask = np.zeros(
95
+ # (boolean_mask.shape[0], boolean_mask.shape[1], 3), dtype=np.uint8
96
+ # )
97
+ # rgb_mask[:, :, 0] = boolean_mask * rgb[0]
98
+ # rgb_mask[:, :, 1] = boolean_mask * rgb[1]
99
+ # rgb_mask[:, :, 2] = boolean_mask * rgb[2]
100
+ # finseg += rgb_mask
101
 
102
  torch.cuda.empty_cache()
103
 
104
+ return mask
105
 
106
  def infer(
107
  image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4