Mehmet Batuhan Duman committed on
Commit 7943976
1 Parent(s): 40d71ff
Files changed (2)
  1. app.py +99 -12
  2. model4.pth +3 -0
app.py CHANGED
@@ -1,6 +1,7 @@
 import cv2
 import numpy as np
 import gradio as gr
+from gradio import Interface, Input, Output, Image
 from PIL import Image, ImageOps
 import matplotlib.pyplot as plt
 import torch
@@ -8,11 +9,63 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import transforms
 import os
-
-# Add your model classes (Net and Net2) here.
-
-# Loading model
-model = None
+import time
+import io
+import base64
+
+class Net2(nn.Module):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.pool1 = nn.MaxPool2d(2, 2)
+        self.dropout1 = nn.Dropout(0.25)
+
+        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+        self.pool2 = nn.MaxPool2d(2, 2)
+        self.dropout2 = nn.Dropout(0.25)
+
+        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn3 = nn.BatchNorm2d(64)
+        self.pool3 = nn.MaxPool2d(2, 2)
+        self.dropout3 = nn.Dropout(0.25)
+
+        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
+        self.bn4 = nn.BatchNorm2d(64)
+        self.pool4 = nn.MaxPool2d(2, 2)
+        self.dropout4 = nn.Dropout(0.25)
+
+        self.flatten = nn.Flatten()
+
+        self.fc1 = nn.Linear(64 * 5 * 5, 200)
+        self.fc2 = nn.Linear(200, 150)
+        self.fc3 = nn.Linear(150, 2)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = self.pool1(x)
+        x = self.dropout1(x)
+
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = self.pool2(x)
+        x = self.dropout2(x)
+
+        x = F.relu(self.bn3(self.conv3(x)))
+        x = self.pool3(x)
+        x = self.dropout3(x)
+
+        x = F.relu(self.bn4(self.conv4(x)))
+        x = self.pool4(x)
+        x = self.dropout4(x)
+
+        x = self.flatten(x)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = F.softmax(self.fc3(x), dim=1)
+        return x
+
+
 model2 = None
 model2_path = "model4.pth"
 
@@ -30,23 +83,57 @@ if os.path.exists(model2_path):
 else:
     print("Model file not found at", model2_path)
 
-# Add the scanmap function here.
-
-def process_image(image: Image.Image):
-    image_np = np.array(image)
+
+def process_image(input_image):
+    image = Image.open(io.BytesIO(input_image)).convert("RGB")
+
     start_time = time.time()
-    heatmap = scanmap(image_np, model)
+    heatmap = scanmap(np.array(image), model)
     elapsed_time = time.time() - start_time
     heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
+
     heatmap_img = heatmap_img.resize(image.size)
 
-    return heatmap_img, elapsed_time
+    return image, heatmap_img, int(elapsed_time)
 
-inputs = gr.inputs.Image(label="Upload Image")
+def scanmap(image_np, model):
+    image_np = image_np.astype(np.float32) / 255.0
+
+    window_size = (80, 80)
+    stride = 10
+
+    height, width, channels = image_np.shape
+
+    probabilities_map = []
+
+    for y in range(0, height - window_size[1] + 1, stride):
+        row_probabilities = []
+        for x in range(0, width - window_size[0] + 1, stride):
+            cropped_window = image_np[y:y + window_size[1], x:x + window_size[0]]
+            cropped_window_torch = transforms.ToTensor()(cropped_window).unsqueeze(0)
+
+            with torch.no_grad():
+                probabilities = model(cropped_window_torch)
+
+            row_probabilities.append(probabilities[0, 1].item())
+
+        probabilities_map.append(row_probabilities)
+
+    probabilities_map = np.array(probabilities_map)
+    return probabilities_map
+
+def gradio_process_image(input_image):
+    original, heatmap, elapsed_time = process_image(input_image)
+    return original, heatmap, f"Elapsed Time (seconds): {elapsed_time}"
+
+inputs = gr.Image(label="Upload Image")
 outputs = [
-    gr.outputs.Image(label="Heatmap"),
-    gr.outputs.Textbox(label="Elapsed Time (seconds)")
+    gr.Image(label="Original Image"),
+    gr.Image(label="Heatmap"),
+    gr.Textbox(label="Elapsed Time")
 ]
 
-iface = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="ShipNet Heatmap")
+iface = gr.Interface(fn=gradio_process_image, inputs=inputs, outputs=outputs)
 iface.launch()
+
+
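Note on the added architecture (not part of the commit): fc1's input size of 64 * 5 * 5 lines up with the 80x80 windows that scanmap crops, since each of the four MaxPool2d(2, 2) layers halves the spatial size (80 -> 40 -> 20 -> 10 -> 5) while the channel count stays at 64. A minimal sanity-check sketch, assuming Net2 is defined as in the diff above:

import torch

net = Net2()                        # class exactly as added in app.py
dummy = torch.zeros(1, 3, 80, 80)   # one RGB window of the size scanmap crops
out = net(dummy)
print(out.shape)                    # torch.Size([1, 2]); scanmap reads index 1 as the positive-class probability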
model4.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90d2dcf5a7c630275f3238f399b59b5de6da5688bc9dd10c95476cffc675e342
+size 1867947
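The diff context above elides the branch that actually loads the weights from model4.pth (only the else: fallback is shown), and the file itself is stored through Git LFS, so only the pointer appears here. A minimal loading sketch, assuming the checkpoint is a plain Net2 state_dict; map_location and the eval() call are assumptions made here, not taken from the commit:

import os
import torch

model2_path = "model4.pth"            # path used in app.py
model2 = None
if os.path.exists(model2_path):
    model2 = Net2()                   # architecture from the diff above
    state = torch.load(model2_path, map_location="cpu")   # assumed to hold a state_dict
    model2.load_state_dict(state)
    model2.eval()                     # disable dropout for inference
else:
    print("Model file not found at", model2_path)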