Mehmet Batuhan Duman commited on
Commit
d8e79b6
·
1 Parent(s): 31179bc

Changed scan func

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +1 -1
  2. app.py +44 -68
.idea/workspace.xml CHANGED
@@ -65,7 +65,7 @@
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
- <workItem from="1686602174110" duration="5060000" />
69
  </task>
70
  <servers />
71
  </component>
 
65
  <workItem from="1683665300392" duration="7649000" />
66
  <workItem from="1683708398011" duration="1235000" />
67
  <workItem from="1684437905081" duration="110000" />
68
+ <workItem from="1686602174110" duration="6264000" />
69
  </task>
70
  <servers />
71
  </component>
app.py CHANGED
@@ -15,7 +15,6 @@ import cv2
15
  import matplotlib.pyplot as plt
16
  import matplotlib.patches as patches
17
  from functools import partial
18
- import tempfile
19
 
20
  class Net2(nn.Module):
21
  def __init__(self):
@@ -123,8 +122,8 @@ class Net(nn.Module):
123
  model = None
124
  model_path = "models1.pth"
125
 
126
- # model2 = None
127
- # model2_path = "model4.pth"
128
 
129
  if os.path.exists(model_path):
130
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
@@ -143,130 +142,107 @@ else:
143
 
144
  # def process_image(input_image):
145
  # image = Image.fromarray(input_image).convert("RGB")
146
- #
147
  # start_time = time.time()
148
  # heatmap = scanmap(np.array(image), model)
149
  # elapsed_time = time.time() - start_time
150
  # heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
151
- #
152
  # heatmap_img = heatmap_img.resize(image.size)
153
- #
154
  # return image, heatmap_img, int(elapsed_time)
155
- #
156
- #
157
  # def scanmap(image_np, model):
158
  # image_np = image_np.astype(np.float32) / 255.0
159
- #
160
  # window_size = (80, 80)
161
  # stride = 10
162
- #
163
  # height, width, channels = image_np.shape
164
- #
165
  # probabilities_map = []
166
- #
167
  # for y in range(0, height - window_size[1] + 1, stride):
168
  # row_probabilities = []
169
  # for x in range(0, width - window_size[0] + 1, stride):
170
  # cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
171
  # cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
172
- #
173
  # with torch.no_grad():
174
  # probabilities = model(cropped_window_torch)
175
- #
176
  # row_probabilities.append(probabilities[0, 1].item())
177
- #
178
  # probabilities_map.append(row_probabilities)
179
- #
180
  # probabilities_map = np.array(probabilities_map)
181
  # return probabilities_map
182
- #
183
  # def gradio_process_image(input_image):
184
  # original, heatmap, elapsed_time = process_image(input_image)
185
  # return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"
186
- #
187
  # inputs = gr.Image(label="Upload Image")
188
  # outputs = [
189
  # gr.Image(label="Original Image"),
190
  # gr.Image(label="Heatmap"),
191
  # gr.Textbox(label="Elapsed Time")
192
  # ]
193
- #
194
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
195
  # iface.launch()
196
-
197
- def scanmap(image_path, model, threshold=0.5):
198
- satellite_image = cv2.imread(image_path)
199
- satellite_image = satellite_image.astype(np.float32) / 255.0
200
 
201
  window_size = (80, 80)
202
  stride = 10
203
 
204
- height, width, channels = satellite_image.shape
205
-
206
- # ensure model is in float32 precision
207
- model.float()
208
 
209
  fig, ax = plt.subplots(1)
210
- ax.imshow(satellite_image)
211
-
212
- ship_images = []
213
 
214
  for y in range(0, height - window_size[1] + 1, stride):
215
  for x in range(0, width - window_size[0] + 1, stride):
216
- cropped_window = satellite_image[y:y + window_size[1], x:x + window_size[0]]
217
- cropped_window_torch = torch.tensor(cropped_window.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
218
 
219
  with torch.no_grad():
220
  probabilities = model(cropped_window_torch)
221
 
222
- # if probability is greater than threshold, draw a bounding box and add to ship_images
223
  if probabilities[0, 1].item() > threshold:
224
- rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r',
225
- facecolor='none')
226
  ax.add_patch(rect)
227
- ship_images.append(cropped_window)
228
 
229
- output_path = "output.png"
230
- plt.savefig(output_path)
231
- plt.close()
 
232
 
233
- return output_path
234
 
 
 
235
 
236
- def process_image(input_image, model, threshold=0.5):
237
  start_time = time.time()
238
- ship_images = scanmap(input_image, model, threshold)
239
- elapsed_time = time.time() - start_time
240
-
241
- return ship_images, int(elapsed_time)
242
-
243
-
244
- def gradio_process_image(input_image, model, threshold=0.5):
245
- start_time = time.time()
246
-
247
- # save numpy array to a temporary file
248
- temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
249
- temp.close()
250
- cv2.imwrite(temp.name, cv2.cvtColor(input_image * 255, cv2.COLOR_RGB2BGR))
251
-
252
- # pass file path to scanmap
253
- output_image_path = scanmap(temp.name, model, threshold)
254
-
255
  elapsed_time = time.time() - start_time
256
 
257
- # delete temporary file after processing
258
- os.unlink(temp.name)
259
 
260
- return output_image_path, f"Elapsed Time (seconds): {elapsed_time}"
 
 
261
 
262
- inputs = gr.inputs.Image(label="Upload Image")
263
  outputs = [
264
- gr.outputs.Image(type='filepath', label="Detected Ships"),
265
- gr.outputs.Textbox(label="Elapsed Time")
 
266
  ]
267
 
268
- # Use 0.5 as the threshold, but adjust according to your needs
269
- gradio_process_image_partial = partial(gradio_process_image, model=model, threshold=0.5)
270
 
271
- iface = gr.Interface(fn=gradio_process_image_partial, inputs=inputs, outputs=outputs)
272
- iface.launch()
 
15
  import matplotlib.pyplot as plt
16
  import matplotlib.patches as patches
17
  from functools import partial
 
18
 
19
  class Net2(nn.Module):
20
  def __init__(self):
 
122
  model = None
123
  model_path = "models1.pth"
124
 
125
+ model2 = None
126
+ model2_path = "model4.pth"
127
 
128
  if os.path.exists(model_path):
129
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
 
142
 
143
  # def process_image(input_image):
144
  # image = Image.fromarray(input_image).convert("RGB")
145
+ #
146
  # start_time = time.time()
147
  # heatmap = scanmap(np.array(image), model)
148
  # elapsed_time = time.time() - start_time
149
  # heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
150
+ #
151
  # heatmap_img = heatmap_img.resize(image.size)
152
+ #
153
  # return image, heatmap_img, int(elapsed_time)
154
+ #
155
+ #
156
  # def scanmap(image_np, model):
157
  # image_np = image_np.astype(np.float32) / 255.0
158
+ #
159
  # window_size = (80, 80)
160
  # stride = 10
161
+ #
162
  # height, width, channels = image_np.shape
163
+ #
164
  # probabilities_map = []
165
+ #
166
  # for y in range(0, height - window_size[1] + 1, stride):
167
  # row_probabilities = []
168
  # for x in range(0, width - window_size[0] + 1, stride):
169
  # cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
170
  # cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
171
+ #
172
  # with torch.no_grad():
173
  # probabilities = model(cropped_window_torch)
174
+ #
175
  # row_probabilities.append(probabilities[0, 1].item())
176
+ #
177
  # probabilities_map.append(row_probabilities)
178
+ #
179
  # probabilities_map = np.array(probabilities_map)
180
  # return probabilities_map
181
+ #
182
  # def gradio_process_image(input_image):
183
  # original, heatmap, elapsed_time = process_image(input_image)
184
  # return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"
185
+ #
186
  # inputs = gr.Image(label="Upload Image")
187
  # outputs = [
188
  # gr.Image(label="Original Image"),
189
  # gr.Image(label="Heatmap"),
190
  # gr.Textbox(label="Elapsed Time")
191
  # ]
192
+ #
193
  # iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
194
  # iface.launch()
195
+ def scanmap(image_np, model, threshold=0.5):
196
+ image_np = image_np.astype(np.float32) / 255.0
 
 
197
 
198
  window_size = (80, 80)
199
  stride = 10
200
 
201
+ height, width, channels = image_np.shape
 
 
 
202
 
203
  fig, ax = plt.subplots(1)
204
+ ax.imshow(image_np)
 
 
205
 
206
  for y in range(0, height - window_size[1] + 1, stride):
207
  for x in range(0, width - window_size[0] + 1, stride):
208
+ cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
209
+ cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
210
 
211
  with torch.no_grad():
212
  probabilities = model(cropped_window_torch)
213
 
214
+ # if probability is greater than threshold, draw a bounding box
215
  if probabilities[0, 1].item() > threshold:
216
+ rect = patches.Rectangle((x, y), window_size[0], window_size[1], linewidth=1, edgecolor='r', facecolor='none')
 
217
  ax.add_patch(rect)
 
218
 
219
+ # Save the image to a byte buffer
220
+ buf = io.BytesIO()
221
+ plt.savefig(buf, format='png')
222
+ buf.seek(0)
223
 
224
+ return Image.open(buf) # return PIL Image
225
 
226
+ def process_image(input_image):
227
+ image = Image.fromarray(input_image).convert("RGB")
228
 
 
229
  start_time = time.time()
230
+ detected_ships_image = scanmap(np.array(image), model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  elapsed_time = time.time() - start_time
232
 
233
+ return image, detected_ships_image, int(elapsed_time)
 
234
 
235
+ def gradio_process_image(input_image):
236
+ original, detected_ships_image, elapsed_time = process_image(input_image)
237
+ return original, detected_ships_image, f"Elapsed Time (seconds): {elapsed_time}"
238
 
239
+ inputs = gr.Image(label="Upload Image")
240
  outputs = [
241
+ gr.Image(label="Original Image"),
242
+ gr.Image(label="Heatmap"),
243
+ gr.Textbox(label="Elapsed Time")
244
  ]
245
 
246
+ iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
247
+ iface.launch()
248