Johannes commited on
Commit
54c5ead
1 Parent(s): a5f6978

test transformers sam

Browse files
Files changed (2) hide show
  1. app.py +28 -17
  2. requirements.txt +1 -1
app.py CHANGED
@@ -17,14 +17,16 @@ import colorsys
17
 
18
  sam_checkpoint = "sam_vit_h_4b8939.pth"
19
  model_type = "vit_h"
20
- device = "cpu"
21
 
22
 
23
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
24
- sam.to(device=device)
25
- predictor = SamPredictor(sam)
26
- mask_generator = SamAutomaticMaskGenerator(sam)
27
 
 
 
28
 
29
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
30
  "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
@@ -70,18 +72,27 @@ with gr.Blocks() as demo:
70
  clear = gr.Button("Clear")
71
 
72
  def generate_mask(image):
73
- predictor.set_image(image)
74
- input_point = np.array([120, 21])
75
- input_label = np.ones(input_point.shape[0])
76
- mask, _, _ = predictor.predict(
77
- point_coords=input_point,
78
- point_labels=input_label,
79
- multimask_output=False,
80
- )
 
 
 
 
 
 
 
 
 
81
 
82
  # clear torch cache
83
- torch.cuda.empty_cache()
84
- mask = Image.fromarray(mask[0, :, :])
85
  # segs = mask_generator.generate(image)
86
  # boolean_masks = [s["segmentation"] for s in segs]
87
  # finseg = np.zeros(
@@ -99,9 +110,9 @@ with gr.Blocks() as demo:
99
  # rgb_mask[:, :, 2] = boolean_mask * rgb[2]
100
  # finseg += rgb_mask
101
 
102
- torch.cuda.empty_cache()
103
 
104
- return mask
105
 
106
  def infer(
107
  image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
 
17
 
18
  sam_checkpoint = "sam_vit_h_4b8939.pth"
19
  model_type = "vit_h"
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
 
23
+ #sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
24
+ #sam.to(device=device)
25
+ #predictor = SamPredictor(sam)
26
+ #mask_generator = SamAutomaticMaskGenerator(sam)
27
 
28
+ generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
29
+ #image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
30
 
31
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
32
  "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
 
72
  clear = gr.Button("Clear")
73
 
74
  def generate_mask(image):
75
+ outputs = generator(image, points_per_batch=256)
76
+
77
+ for mask in outputs["masks"]:
78
+ color = np.concatenate([np.random.random(3), np.array([1.0])], axis=0)
79
+ h, w = mask.shape[-2:]
80
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
81
+
82
+ return mask_image
83
+
84
+ # predictor.set_image(image)
85
+ # input_point = np.array([120, 21])
86
+ # input_label = np.ones(input_point.shape[0])
87
+ # mask, _, _ = predictor.predict(
88
+ # point_coords=input_point,
89
+ # point_labels=input_label,
90
+ # multimask_output=False,
91
+ # )
92
 
93
  # clear torch cache
94
+ # torch.cuda.empty_cache()
95
+ # mask = Image.fromarray(mask[0, :, :])
96
  # segs = mask_generator.generate(image)
97
  # boolean_masks = [s["segmentation"] for s in segs]
98
  # finseg = np.zeros(
 
110
  # rgb_mask[:, :, 2] = boolean_mask * rgb[2]
111
  # finseg += rgb_mask
112
 
113
+ # torch.cuda.empty_cache()
114
 
115
+ # return mask
116
 
117
  def infer(
118
  image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  torch
2
  torchvision
3
  git+https://github.com/facebookresearch/segment-anything.git
4
- transformers
5
  flax
6
  jax[cuda11_pip]
7
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 
1
  torch
2
  torchvision
3
  git+https://github.com/facebookresearch/segment-anything.git
4
+ git+https://github.com/huggingface/transformers@main
5
  flax
6
  jax[cuda11_pip]
7
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html