JuanLozada97 committed on
Commit
28833ba
1 Parent(s): 70ee8a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -72,7 +72,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
72
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
73
  return medsam_seg
74
 
75
- def predict(img,x1,y1,x2,y2) -> Tuple[Dict, float]:
76
  """Transforms and performs a prediction on img and returns prediction and time taken.
77
  """
78
  # Start the timer
@@ -106,11 +106,9 @@ def predict(img,x1,y1,x2,y2) -> Tuple[Dict, float]:
106
  with torch.inference_mode():
107
  image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64)
108
  # define the inputbox
109
- input_box = np.array([[x1,y1,x2,y2]])
110
- input_box = np.nan_to_num(input_box, nan=0)
111
  # transfer box_np t0 1024x1024 scale
112
- scaling_factor = 1/np.array([W, H, W, H])
113
- box_1024 = input_box.astype(int) * scaling_factor * 1024
114
 
115
  medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
116
  pred_time = round(timer() - start_time, 5)
@@ -143,11 +141,7 @@ example_list = [["examples/" + example] for example in os.listdir("examples")]
143
 
144
  # Create the Gradio demo
145
  demo = gr.Interface(fn=predict, # mapping function from input to output
146
- inputs=[gr.Image(type="pil"),
147
- gr.Slider(0, 512, randomize=True,step=1, label="X1",info="top-left point"),
148
- gr.Slider(0, 512, randomize=True,step=1, label="Y1"),
149
- gr.Slider(0, 512, randomize=True,step=1, label="X2",info="bottom-right point"),
150
- gr.Slider(0, 512, randomize=True,step=1, label="Y2"),], # what are the inputs?
151
  outputs=[gr.Plot(label="Predictions"), # what are the outputs?
152
  gr.Number(label="Prediction time (s)"),
153
  gr.JSON(label="Embedding Image")], # our fn has two outputs, therefore we have two outputs
 
72
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
73
  return medsam_seg
74
 
75
+ def predict(img) -> Tuple[Dict, float]:
76
  """Transforms and performs a prediction on img and returns prediction and time taken.
77
  """
78
  # Start the timer
 
106
  with torch.inference_mode():
107
  image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64)
108
  # define the inputbox
109
+ input_box = np.array([[125, 275, 190, 350]])
 
110
  # transfer box_np t0 1024x1024 scale
111
+ box_1024 = input_box.astype(int) / np.array([W, H, W, H])* 1024
 
112
 
113
  medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
114
  pred_time = round(timer() - start_time, 5)
 
141
 
142
  # Create the Gradio demo
143
  demo = gr.Interface(fn=predict, # mapping function from input to output
144
+ inputs=gr.Image(type="pil"), # what are the inputs?
 
 
 
 
145
  outputs=[gr.Plot(label="Predictions"), # what are the outputs?
146
  gr.Number(label="Prediction time (s)"),
147
  gr.JSON(label="Embedding Image")], # our fn has two outputs, therefore we have two outputs