fixed scaling
app.py CHANGED
@@ -74,30 +74,26 @@ def process_images(file, x_min, y_min, x_max, y_max):
     image, H, W = load_image(file)
     image_resized = transform.resize(image, (1024, 1024), anti_aliasing=True)
     image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
-
-    # Check if CUDA is available, and set the device accordingly
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
 
-    #
+    # Initialize the MedSAM model and set the device
     model_checkpoint_path = "medsam_vit_b.pth" # Replace with the correct path to your checkpoint
-
-    # Create the model instance and load the checkpoint
     medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
     medsam_model = medsam_model.to(device)
     medsam_model.eval()
 
-    # Convert image to tensor and move to the correct device
-    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
-
     # Generate image embedding
     with torch.no_grad():
         img_embed = medsam_model.image_encoder(image_tensor)
 
-    # Calculate resized box coordinates
+    # Calculate resized box coordinates
     scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
     box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
+
+    # Perform inference
     mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
 
+    # Visualization
     visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
     return visualization.getvalue()
 
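For reference, a minimal sketch of the preprocessing this hunk performs, assuming the input is an RGB image array in HWC layout: min-max normalization with the same epsilon clip (which keeps a constant image from dividing by zero), followed by the HWC-to-CHW reorder and batch axis the encoder expects. The standalone preprocess helper below is illustrative, not part of app.py.

import numpy as np
import torch

def preprocess(image):
    """Mirror of the preprocessing in process_images: min-max normalize,
    then reorder HWC -> CHW and add a batch axis for the image encoder."""
    # The clip keeps the denominator positive, so a constant image
    # (max == min) normalizes to zeros instead of dividing by zero.
    rng = np.clip(image.max() - image.min(), a_min=1e-8, a_max=None)
    normalized = (image - image.min()) / rng
    # (H, W, C) float array -> (1, C, H, W) float tensor.
    return torch.tensor(normalized).float().permute(2, 0, 1).unsqueeze(0)

# Example: a random 1024 x 1024 RGB image becomes a (1, 3, 1024, 1024) batch.
dummy = np.random.rand(1024, 1024, 3)
print(preprocess(dummy).shape)  # torch.Size([1, 3, 1024, 1024])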
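Likewise, a small sketch of the box scaling the commit title refers to: a prompt box given in original-image pixel coordinates is mapped onto the 1024 x 1024 grid that the resized image occupies, with x coordinates scaled by 1024 / W and y coordinates by 1024 / H, exactly as scale_factors does above. scale_box_to_1024 is a hypothetical helper named here for illustration only.

import numpy as np

def scale_box_to_1024(box_xyxy, H, W):
    """Map an (x_min, y_min, x_max, y_max) box from an H x W image
    onto the 1024 x 1024 grid of the resized image (hypothetical helper)."""
    # x coordinates scale by 1024 / W, y coordinates by 1024 / H,
    # matching scale_factors in process_images.
    scale = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    return np.asarray(box_xyxy, dtype=float) * scale

# Example: a 100 x 100 box in a 2048 x 2048 image halves at 1024 scale.
print(scale_box_to_1024([100, 100, 200, 200], H=2048, W=2048))  # [ 50.  50. 100. 100.]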