JuanLozada97 commited on
Commit
af603ca
·
1 Parent(s): dc8e463

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +21 -5
  2. model.py +2 -5
app.py CHANGED
@@ -4,11 +4,10 @@ import torch
4
  import numpy as np
5
  import cv2
6
  import matplotlib.pyplot as plt
7
- import base64
8
- import json
9
-
10
- from segment_anything import sam_model_registry, SamPredictor
11
- from segment_anything.utils.onnx import SamOnnxModel
12
 
13
  import torch.nn.functional as F
14
 
@@ -23,6 +22,23 @@ model_type = "vit_b"
23
  medsam_model = create_sam_model(model_type,checkpoint,device)
24
 
25
  # 3.Predict fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @torch.no_grad()
27
  def medsam_inference(medsam_model, img_embed, box_1024, H, W):
28
  box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
 
4
  import numpy as np
5
  import cv2
6
  import matplotlib.pyplot as plt
7
+ from typing import Tuple, Dict
8
+ from timeit import default_timer as timer
9
+ from skimage import io, transform
10
+ import os
 
11
 
12
  import torch.nn.functional as F
13
 
 
22
  medsam_model = create_sam_model(model_type,checkpoint,device)
23
 
24
  # 3.Predict fn
25
+ def show_mask(mask, ax):
26
+ color = np.array([30/255, 144/255, 255/255, 0.6])
27
+ h, w = mask.shape[-2:]
28
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
29
+ ax.imshow(mask_image)
30
+
31
+ def show_points(coords, labels, ax, marker_size=375):
32
+ pos_points = coords[labels==1]
33
+ neg_points = coords[labels==0]
34
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
35
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
36
+
37
+ def show_box(box, ax):
38
+ x0, y0 = box[0], box[1]
39
+ w, h = box[2] - box[0], box[3] - box[1]
40
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
41
+
42
  @torch.no_grad()
43
  def medsam_inference(medsam_model, img_embed, box_1024, H, W):
44
  box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
model.py CHANGED
@@ -1,9 +1,6 @@
1
- import torch
2
- import numpy as np
3
- import cv2
4
  import matplotlib.pyplot as plt
5
- from segment_anything import sam_model_registry, SamPredictor
6
- from segment_anything.utils.onnx import SamOnnxModel
7
  import torch.nn.functional as F
8
 
9
  def create_sam_model(model_type, checkpoint, device: str = "cpu"):
 
1
+
 
 
2
  import matplotlib.pyplot as plt
3
+ from segment_anything import sam_model_registry
 
4
  import torch.nn.functional as F
5
 
6
  def create_sam_model(model_type, checkpoint, device: str = "cpu"):