Mehmet Batuhan Duman commited on
Commit
0a2fbf3
1 Parent(s): f6c66e9

Changed scan func

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +1 -1
  2. app.py +20 -16
.idea/workspace.xml CHANGED
@@ -65,7 +65,7 @@
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
- <workItem from="1686602174110" duration="3739000" />
69
  </task>
70
  <servers />
71
  </component>
 
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
+ <workItem from="1686602174110" duration="4678000" />
69
  </task>
70
  <servers />
71
  </component>
app.py CHANGED
@@ -192,7 +192,9 @@ else:
192
  #
193
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
194
  # iface.launch()
195
- def scanmap(satellite_image, model, threshold=0.5):
 
 
196
  satellite_image = satellite_image.astype(np.float32) / 255.0
197
 
198
  window_size = (80, 80)
@@ -200,6 +202,7 @@ def scanmap(satellite_image, model, threshold=0.5):
200
 
201
  height, width, channels = satellite_image.shape
202
 
 
203
  fig, ax = plt.subplots(1)
204
  ax.imshow(satellite_image)
205
 
@@ -209,19 +212,23 @@ def scanmap(satellite_image, model, threshold=0.5):
209
  for x in range(0, width - window_size[0] + 1, stride):
210
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
211
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
 
212
 
213
  with torch.no_grad():
214
  probabilities = model(cropped_window_torch)
215
 
 
216
  if probabilities[0, 1].item() > threshold:
217
- rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r', facecolor='none')
 
218
  ax.add_patch(rect)
219
  ship_images.append(cropped_window)
220
 
221
- plt.show()
222
-
223
- return ship_images
224
 
 
225
 
226
  def process_image(input_image, model, threshold=0.5):
227
  start_time = time.time()
@@ -231,22 +238,19 @@ def process_image(input_image, model, threshold=0.5):
231
  return ship_images, int(elapsed_time)
232
 
233
 
234
- def gradio_process_image(input_image, model, threshold=0.5):
235
- ship_images, elapsed_time = process_image(input_image, model, threshold)
 
 
236
 
237
- if len(ship_images) > 0:
238
- # Convert first image to format compatible with Gradio
239
- output_image = Image.fromarray((ship_images[0] * 255).astype(np.uint8))
240
- else:
241
- output_image = None
242
 
243
  return output_image, f"Elapsed Time (seconds): {elapsed_time}"
244
 
245
-
246
- inputs = gr.Image(label="Upload Image")
247
  outputs = [
248
- gr.Image(label="Detected Ships"),
249
- gr.Textbox(label="Elapsed Time")
250
  ]
251
 
252
  # Use 0.5 as the threshold, but adjust according to your needs
 
192
  #
193
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
194
  # iface.launch()
195
+
196
+ def scanmap(image_path, model, device, threshold=0.5):
197
+ satellite_image = cv2.imread(image_path)
198
  satellite_image = satellite_image.astype(np.float32) / 255.0
199
 
200
  window_size = (80, 80)
 
202
 
203
  height, width, channels = satellite_image.shape
204
 
205
+
206
  fig, ax = plt.subplots(1)
207
  ax.imshow(satellite_image)
208
 
 
212
  for x in range(0, width - window_size[0] + 1, stride):
213
  cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
214
  cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
215
+ cropped_window_torch = cropped_window_torch.to(device) # move data to the same device as model
216
 
217
  with torch.no_grad():
218
  probabilities = model(cropped_window_torch)
219
 
220
+ # if probability is greater than threshold, draw a bounding box and add to ship_images
221
  if probabilities[0, 1].item() > threshold:
222
+ rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r',
223
+ facecolor='none')
224
  ax.add_patch(rect)
225
  ship_images.append(cropped_window)
226
 
227
+ output_path = "output.png"
228
+ plt.savefig(output_path)
229
+ plt.close()
230
 
231
+ return output_path
232
 
233
  def process_image(input_image, model, threshold=0.5):
234
  start_time = time.time()
 
238
  return ship_images, int(elapsed_time)
239
 
240
 
241
+ def gradio_process_image(input_image_path, model, threshold=0.5):
242
+ start_time = time.time()
243
+ output_image_path = scanmap(input_image_path, model, threshold)
244
+ elapsed_time = time.time() - start_time
245
 
246
+ output_image = Image.open(output_image_path) if output_image_path else None
 
 
 
 
247
 
248
  return output_image, f"Elapsed Time (seconds): {elapsed_time}"
249
 
250
+ inputs = gr.inputs.Image(label="Upload Image")
 
251
  outputs = [
252
+ gr.outputs.Image(label="Detected Ships"),
253
+ gr.outputs.Textbox(label="Elapsed Time")
254
  ]
255
 
256
  # Use 0.5 as the threshold, but adjust according to your needs