Mehmet Batuhan Duman commited on
Commit
31179bc
·
1 Parent(s): 1c0fcfa

Changed scan func

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -194,7 +194,7 @@ else:
194
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
195
  # iface.launch()
196
 
197
- def scanmap(image_path, model, device, threshold=0.5):
198
  satellite_image = cv2.imread(image_path)
199
  satellite_image = satellite_image.astype(np.float32) / 255.0
200
 
@@ -203,6 +203,8 @@ def scanmap(image_path, model, device, threshold=0.5):
203
 
204
  height, width, channels = satellite_image.shape
205
 
 
 
206
 
207
  fig, ax = plt.subplots(1)
208
  ax.imshow(satellite_image)
@@ -213,7 +215,6 @@ def scanmap(image_path, model, device, threshold=0.5):
213
  for x in range(0, width - window_size[0] + 1, stride):
214
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
215
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
216
- cropped_window_torch = cropped_window_torch.to(device) # move data to the same device as model
217
 
218
  with torch.no_grad():
219
  probabilities = model(cropped_window_torch)
@@ -231,6 +232,7 @@ def scanmap(image_path, model, device, threshold=0.5):
231
 
232
  return output_path
233
 
 
234
  def process_image(input_image, model, threshold=0.5):
235
  start_time = time.time()
236
  ship_images = scanmap(input_image, model, threshold)
 
194
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
195
  # iface.launch()
196
 
197
+ def scanmap(image_path, model, threshold=0.5):
198
  satellite_image = cv2.imread(image_path)
199
  satellite_image = satellite_image.astype(np.float32) / 255.0
200
 
 
203
 
204
  height, width, channels = satellite_image.shape
205
 
206
+ # ensure model is in float32 precision
207
+ model.float()
208
 
209
  fig, ax = plt.subplots(1)
210
  ax.imshow(satellite_image)
 
215
  for x in range(0, width - window_size[0] + 1, stride):
216
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
217
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
 
218
 
219
  with torch.no_grad():
220
  probabilities = model(cropped_window_torch)
 
232
 
233
  return output_path
234
 
235
+
236
  def process_image(input_image, model, threshold=0.5):
237
  start_time = time.time()
238
  ship_images = scanmap(input_image, model, threshold)