JuanLozada97 committed on
Commit
c6ccb48
1 Parent(s): d1bc201

first commit

Files changed (6)
  1. app.py +144 -0
  2. examples/img_demo.png +0 -0
  3. examples/truck.jpg +0 -0
  4. medsam_vit_b.pth +3 -0
  5. model.py +7 -0
  6. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import os
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ from typing import Tuple
+ from timeit import default_timer as timer
+ from skimage import transform
+
+ import torch.nn.functional as F
+
+ from model import create_sam_model
+
+ # 1. Setup variables
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = "medsam_vit_b.pth"
+ model_type = "vit_b"
+
+ # 2. Prepare the model and load the saved weights
+ medsam_model = create_sam_model(model_type, checkpoint, device)
+
+ # 3. Plotting helpers and predict fn
+ def show_mask(mask, ax):
+     """Overlay a binary segmentation mask on a matplotlib axis."""
+     color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+ def show_box(box, ax):
+     """Draw a bounding box given as (x0, y0, x1, y1) on a matplotlib axis."""
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(
+         plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
+     )
+
+ @torch.no_grad()
+ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
+     box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
+     if len(box_torch.shape) == 2:
+         box_torch = box_torch[:, None, :]  # (B, 1, 4)
+
+     sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
+         points=None,
+         boxes=box_torch,
+         masks=None,
+     )
+     low_res_logits, _ = medsam_model.mask_decoder(
+         image_embeddings=img_embed,  # (B, 256, 64, 64)
+         image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
+         sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
+         dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
+         multimask_output=False,
+     )
+
+     low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
+
+     low_res_pred = F.interpolate(
+         low_res_pred,
+         size=(H, W),
+         mode="bilinear",
+         align_corners=False,
+     )  # (1, 1, H, W)
+     low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
+     medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
+     return medsam_seg
+
+ def predict(img) -> Tuple[plt.Figure, float]:
+     """Transforms and performs a prediction on img; returns the figure and time taken."""
+     # Start the timer
+     start_time = timer()
+
+     # gr.Image(type="pil") already yields RGB, so no BGR-to-RGB conversion is needed
+     img_np = np.array(img)
+     if len(img_np.shape) == 2:  # grayscale -> replicate to 3 channels
+         img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
+     else:
+         img_3c = img_np
+     H, W, _ = img_3c.shape
+     # Image preprocessing
+     img_1024 = transform.resize(
+         img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
+     ).astype(np.uint8)
+     img_1024 = (img_1024 - img_1024.min()) / np.clip(
+         img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
+     )  # normalize to [0, 1], (H, W, 3)
+     # Convert the shape to (1, 3, H, W)
+     img_1024_tensor = (
+         torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
+     )
+
+     # Put the model into evaluation mode and turn on inference mode
+     medsam_model.eval()
+     with torch.inference_mode():
+         image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)
+     # Define the input box
+     input_box = np.array([[425, 600, 700, 875]])
+     # Transfer the box to the 1024x1024 scale
+     box_1024 = input_box / np.array([W, H, W, H]) * 1024
+
+     medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+     ax[0].imshow(img_3c)
+     show_box(input_box[0], ax[0])
+     ax[0].set_title("Input Image and Bounding Box")
+     ax[1].imshow(img_3c)
+     show_mask(medsam_seg, ax[1])
+     show_box(input_box[0], ax[1])
+     ax[1].set_title("MedSAM Segmentation")
+
+     # Return the prediction figure and the prediction time
+     return fig, pred_time
+
+ # 4. Gradio app
+ # Create title, description and article strings
+ title = "MedSAM"
+ description = "A SAM model fine-tuned for the segmentation of medical images. The app encodes the uploaded image with the image encoder and decodes a segmentation mask for a bounding-box prompt with the mask decoder."
+ article = "Created in gradio-sam-predictor-image-embedding-generator.ipynb."
+
+ # Create the examples list from the "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(
+     fn=predict,  # mapping function from input to output
+     inputs=gr.Image(type="pil"),  # what are the inputs?
+     outputs=[
+         gr.Plot(label="Predictions"),  # what are the outputs?
+         gr.Number(label="Prediction time (s)"),  # predict returns two values, so two outputs
+     ],
+     examples=example_list,
+     title=title,
+     description=description,
+     article=article,
+ )
+
+ # Launch the demo!
+ demo.launch(debug=False,  # print errors locally?
+             share=True)  # generate a publicly shareable URL?
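
Note that predict hard-codes the prompt box at [[425, 600, 700, 875]], so every upload is segmented with the same rectangle. For reference, a minimal sketch (not part of this commit) of driving the same inference path outside Gradio with a caller-chosen box; the image path and box coordinates below are illustrative assumptions:

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import transform

from model import create_sam_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = create_sam_model("vit_b", "medsam_vit_b.pth", device)
model.eval()

img_3c = np.array(Image.open("examples/img_demo.png").convert("RGB"))  # illustrative path
H, W, _ = img_3c.shape

# Same preprocessing as predict(): bicubic resize to 1024x1024, min-max normalize
img_1024 = transform.resize(
    img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
).astype(np.uint8)
img_1024 = (img_1024 - img_1024.min()) / np.clip(
    img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
)
tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)

box = np.array([[50, 50, 200, 200]])  # illustrative (x0, y0, x1, y1) in original pixels
box_1024 = box / np.array([W, H, W, H]) * 1024

with torch.inference_mode():
    embedding = model.image_encoder(tensor)  # (1, 256, 64, 64)
    box_t = torch.as_tensor(box_1024, dtype=torch.float, device=device)[:, None, :]
    sparse, dense = model.prompt_encoder(points=None, boxes=box_t, masks=None)
    logits, _ = model.mask_decoder(
        image_embeddings=embedding,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )
    probs = F.interpolate(
        torch.sigmoid(logits), size=(H, W), mode="bilinear", align_corners=False
    )

mask = (probs.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
print("mask pixels:", int(mask.sum()))
```

Importing from model rather than app avoids triggering demo.launch(), which runs at module level in app.py.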
examples/img_demo.png ADDED
examples/truck.jpg ADDED
medsam_vit_b.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d9ef4acfee5f5a5d9737a32b5ce03f2cfc5d349289c6a2f8270d7d3f63eaf966
+ size 189792256
model.py ADDED
@@ -0,0 +1,7 @@
+ from segment_anything import sam_model_registry
+
+ def create_sam_model(model_type, checkpoint, device: str = "cpu"):
+     """Instantiate a SAM model of the given type and move the loaded weights to device."""
+     medsam_model = sam_model_registry[model_type](checkpoint=checkpoint)
+     medsam_model = medsam_model.to(device)
+     return medsam_model
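
For reference, a short usage sketch of create_sam_model (it assumes the checkpoint from this commit sits in the working directory):

```python
import torch

from model import create_sam_model

device = "cuda" if torch.cuda.is_available() else "cpu"
medsam = create_sam_model("vit_b", "medsam_vit_b.pth", device)
medsam.eval()  # inference only; app.py also calls eval() inside predict
print(f"{sum(p.numel() for p in medsam.parameters()):,} parameters")
```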
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.1.0
+ torchvision==0.16.0
+ gradio==3.50.2
+ numpy
+ matplotlib
+ scikit-image
+ git+https://github.com/facebookresearch/segment-anything.git