dennistrujillo committed on
Commit
5bb4fec
·
verified ·
1 Parent(s): 201e3ec

fixed scaling

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -74,30 +74,26 @@ def process_images(file, x_min, y_min, x_max, y_max):
74
  image, H, W = load_image(file)
75
  image_resized = transform.resize(image, (1024, 1024), anti_aliasing=True)
76
  image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
77
-
78
- # Check if CUDA is available, and set the device accordingly
79
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
80
 
81
- # Define the checkpoint path
82
  model_checkpoint_path = "medsam_vit_b.pth" # Replace with the correct path to your checkpoint
83
-
84
- # Create the model instance and load the checkpoint
85
  medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
86
  medsam_model = medsam_model.to(device)
87
  medsam_model.eval()
88
 
89
- # Convert image to tensor and move to the correct device
90
- image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
91
-
92
  # Generate image embedding
93
  with torch.no_grad():
94
  img_embed = medsam_model.image_encoder(image_tensor)
95
 
96
- # Calculate resized box coordinates and perform inference
97
  scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
98
  box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
 
 
99
  mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
100
 
 
101
  visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
102
  return visualization.getvalue()
103
 
 
74
  image, H, W = load_image(file)
75
  image_resized = transform.resize(image, (1024, 1024), anti_aliasing=True)
76
  image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
77
+ image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
 
 
78
 
79
+ # Initialize the MedSAM model and set the device
80
  model_checkpoint_path = "medsam_vit_b.pth" # Replace with the correct path to your checkpoint
 
 
81
  medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
82
  medsam_model = medsam_model.to(device)
83
  medsam_model.eval()
84
 
 
 
 
85
  # Generate image embedding
86
  with torch.no_grad():
87
  img_embed = medsam_model.image_encoder(image_tensor)
88
 
89
+ # Calculate resized box coordinates
90
  scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
91
  box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
92
+
93
+ # Perform inference
94
  mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
95
 
96
+ # Visualization
97
  visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
98
  return visualization.getvalue()
99