yasinelh commited on
Commit
9117f01
1 Parent(s): 3aba32d

Upload 7 files

Browse files
Files changed (7) hide show
  1. 01_test.tif +0 -0
  2. 02_test.tif +0 -0
  3. 03_test.tif +0 -0
  4. 04_test.tif +0 -0
  5. app (1).py +93 -0
  6. models (1).py +210 -0
  7. unet_demo.ipynb +445 -0
01_test.tif ADDED
02_test.tif ADDED
03_test.tif ADDED
04_test.tif ADDED
app (1).py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import cv2
4
+ import numpy as np
5
+ import time
6
+ import models
7
+ import torch
8
+
9
+ from torchvision import transforms
10
+ from torchvision import transforms
11
+
12
+ def load_model(path, model):
13
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
14
+ return model
15
+
16
+ def predict(img):
17
+ model = models.unet(3, 1)
18
+ model = load_model('model.pth',model)
19
+
20
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
21
+ img = cv2.resize(img, (512, 512))
22
+ convert_tensor = transforms.ToTensor()
23
+ img = convert_tensor(img).float()
24
+ img = normalize(img)
25
+ img = torch.unsqueeze(img, dim=0)
26
+
27
+ output = model(img)
28
+ result = torch.sigmoid(output)
29
+
30
+ threshold = 0.5
31
+ result = (result >= threshold).float()
32
+ prediction = result[0].cpu() # Move tensor to CPU if it's on GPU
33
+ # Convert tensor to a numpy array
34
+ prediction_array = prediction.numpy()
35
+ # Rescale values to the range [0, 255]
36
+ prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
37
+ cv2.imwrite("test.png",prediction_array)
38
+ return prediction_array
39
+
40
+ def predicjt(img):
41
+ model1 = models.SAunet(3, 1)
42
+ model1 = load_model('saunet.pth',model1)
43
+
44
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
45
+ img = cv2.resize(img, (512, 512))
46
+ convert_tensor = transforms.ToTensor()
47
+ img = convert_tensor(img).float()
48
+ img = normalize(img)
49
+ img = torch.unsqueeze(img, dim=0)
50
+
51
+ output = model1(img)
52
+ result = torch.sigmoid(output)
53
+
54
+ threshold = 0.5
55
+ result = (result >= threshold).float()
56
+ prediction = result[0].cpu() # Move tensor to CPU if it's on GPU
57
+ # Convert tensor to a numpy array
58
+ prediction_array = prediction.numpy()
59
+ # Rescale values to the range [0, 255]
60
+ prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
61
+ cv2.imwrite("test1.png",prediction_array)
62
+ return prediction_array
63
+ def main():
64
+ st.title("Image Segmentation Demo")
65
+
66
+ # Predefined list of image names
67
+ image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]
68
+
69
+ # Create a selection box for the images
70
+ selected_image_name = st.selectbox("Select an Image", image_names)
71
+
72
+ # Load the selected image
73
+ selected_image = cv2.imread(selected_image_name)
74
+
75
+ # Display the selected image
76
+ st.image(selected_image, channels="RGB")
77
+
78
+ # Create a button for segmentation
79
+ if st.button("Segment"):
80
+ # Perform segmentation on the selected image
81
+ segmented_image = predict(selected_image)
82
+ segmented_image1 = predicjt(selected_image)
83
+
84
+
85
+ # Display the segmented image
86
+ st.image(segmented_image, channels="RGB",caption='U-Net segmentation')
87
+ st.image(segmented_image1, channels="RGB",caption='Spatial Attention U-Net segmentation ')
88
+
89
+ # Function to perform segmentation on the selected image
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
models (1).py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+
7
+ class DropBlock(nn.Module):
8
+ def __init__(self, block_size: int = 5, p: float = 0.1):
9
+ super().__init__()
10
+ self.block_size = block_size
11
+ self.p = p
12
+
13
+ def calculate_gamma(self, x: Tensor) -> float:
14
+
15
+
16
+ invalid = (1 - self.p) / (self.block_size ** 2)
17
+ valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)
18
+ return invalid * valid
19
+
20
+ def forward(self, x: Tensor) -> Tensor:
21
+ N, C, H, W = x.size()
22
+ if self.training:
23
+ gamma = self.calculate_gamma(x)
24
+ mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
25
+ mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
26
+ mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
27
+ mask_block = 1 - F.max_pool2d(
28
+ mask,
29
+ kernel_size=(self.block_size, self.block_size),
30
+ stride=(1, 1),
31
+ padding=(self.block_size // 2, self.block_size // 2),
32
+ )
33
+ x = mask_block * x * (mask_block.numel() / mask_block.sum())
34
+ return x
35
+
36
+
37
+ class double_conv(nn.Module):
38
+ def __init__(self,intc,outc):
39
+ super().__init__()
40
+ self.conv1=nn.Conv2d(intc,outc,kernel_size=3,padding=1,stride=1)
41
+ self.drop1= DropBlock(7, 0.9)
42
+ self.bn1=nn.BatchNorm2d(outc)
43
+ self.relu1=nn.ReLU()
44
+ self.conv2=nn.Conv2d(outc,outc,kernel_size=3,padding=1,stride=1)
45
+ self.drop2=DropBlock(7, 0.9)
46
+ self.bn2=nn.BatchNorm2d(outc)
47
+ self.relu2=nn.ReLU()
48
+
49
+ def forward(self,input):
50
+ x=self.relu1(self.bn1(self.drop1(self.conv1(input))))
51
+ x=self.relu2(self.bn2(self.drop2(self.conv2(x))))
52
+
53
+ return x
54
+ class upconv(nn.Module):
55
+ def __init__(self,intc,outc) -> None:
56
+ super().__init__()
57
+ self.up=nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)
58
+ # self.relu=nn.ReLU()
59
+
60
+ def forward(self,input):
61
+ x=self.up(input)
62
+ #x=self.relu(self.up(input))
63
+ return x
64
+ class unet(nn.Module):
65
+ def __init__(self,int,out) -> None:
66
+ 'int: represent the number of image channels'
67
+ 'out: number of the desired final channels'
68
+
69
+ super().__init__()
70
+ 'encoder'
71
+ self.convlayer1=double_conv(int,64)
72
+ self.down1=nn.MaxPool2d((2, 2))
73
+ self.convlayer2=double_conv(64,128)
74
+ self.down2=nn.MaxPool2d((2, 2))
75
+ self.convlayer3=double_conv(128,256)
76
+ self.down3=nn.MaxPool2d((2, 2))
77
+ self.convlayer4=double_conv(256,512)
78
+ self.down4=nn.MaxPool2d((2, 2))
79
+
80
+ 'bridge'
81
+ self.bridge=double_conv(512,1024)
82
+ 'decoder'
83
+ self.up1=upconv(1024,512)
84
+ self.convlayer5=double_conv(1024,512)
85
+ self.up2=upconv(512,256)
86
+ self.convlayer6=double_conv(512,256)
87
+ self.up3=upconv(256,128)
88
+ self.convlayer7=double_conv(256,128)
89
+ self.up4=upconv(128,64)
90
+ self.convlayer8=double_conv(128,64)
91
+ 'output'
92
+ self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
93
+ self.sig=nn.Sigmoid()
94
+ def forward(self,input):
95
+ 'encoder'
96
+ l1=self.convlayer1(input)
97
+ d1=self.down1(l1)
98
+ l2=self.convlayer2(d1)
99
+ d2=self.down2(l2)
100
+ l3=self.convlayer3(d2)
101
+ d3=self.down3(l3)
102
+ l4=self.convlayer4(d3)
103
+ d4=self.down4(l4)
104
+ 'bridge'
105
+ bridge=self.bridge(d4)
106
+ 'decoder'
107
+ up1=self.up1(bridge)
108
+ up1 = torch.cat([up1, l4], axis=1)
109
+ l5=self.convlayer5(up1)
110
+
111
+ up2=self.up2(l5)
112
+ up2 = torch.cat([up2, l3], axis=1)
113
+ l6=self.convlayer6(up2)
114
+
115
+ up3=self.up3(l6)
116
+ up3= torch.cat([up3, l2], axis=1)
117
+ l7=self.convlayer7(up3)
118
+
119
+ up4=self.up4(l7)
120
+ up4 = torch.cat([up4, l1], axis=1)
121
+ l8=self.convlayer8(up4)
122
+ out=self.outputs(l8)
123
+
124
+ #out=self.sig(self.outputs(l8))
125
+ return out
126
+ class spatialAttention(nn.Module):
127
+ def __init__(self) -> None:
128
+ super().__init__()
129
+
130
+ self.conv77=nn.Conv2d(2,1,kernel_size=7,padding=3)
131
+ self.sig=nn.Sigmoid()
132
+ def forward(self,input):
133
+ x=torch.cat( (torch.max(input,1)[0].unsqueeze(1), torch.mean(input,1).unsqueeze(1)), dim=1 )
134
+ x=self.sig(self.conv77(x))
135
+ x=input*x
136
+ return x
137
+ class SAunet(nn.Module):
138
+ def __init__(self,int,out) -> None:
139
+ 'int: represent the number of image channels'
140
+ 'out: number of the desired final channels'
141
+
142
+ super().__init__()
143
+ 'encoder'
144
+ self.convlayer1=double_conv(int,64)
145
+ self.down1=nn.MaxPool2d((2, 2))
146
+ self.convlayer2=double_conv(64,128)
147
+ self.down2=nn.MaxPool2d((2, 2))
148
+ self.convlayer3=double_conv(128,256)
149
+ self.down3=nn.MaxPool2d((2, 2))
150
+ self.convlayer4=double_conv(256,512)
151
+ self.down4=nn.MaxPool2d((2, 2))
152
+
153
+ 'bridge'
154
+ self.attmodule=spatialAttention()
155
+ self.bridge1=nn.Conv2d(512,1024,kernel_size=3,stride=1,padding=1)
156
+ self.bn1=nn.BatchNorm2d(1024)
157
+ self.act1=nn.ReLU()
158
+ self.bridge2=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)
159
+ self.bn2=nn.BatchNorm2d(1024)
160
+ self.act2=nn.ReLU()
161
+ 'decoder'
162
+ self.up1=upconv(1024,512)
163
+ self.convlayer5=double_conv(1024,512)
164
+ self.up2=upconv(512,256)
165
+ self.convlayer6=double_conv(512,256)
166
+ self.up3=upconv(256,128)
167
+ self.convlayer7=double_conv(256,128)
168
+ self.up4=upconv(128,64)
169
+ self.convlayer8=double_conv(128,64)
170
+ 'output'
171
+ self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
172
+ self.sig=nn.Sigmoid()
173
+ def forward(self,input):
174
+ 'encoder'
175
+ l1=self.convlayer1(input)
176
+ d1=self.down1(l1)
177
+ l2=self.convlayer2(d1)
178
+ d2=self.down2(l2)
179
+ l3=self.convlayer3(d2)
180
+ d3=self.down3(l3)
181
+ l4=self.convlayer4(d3)
182
+ d4=self.down4(l4)
183
+ 'bridge'
184
+ b1=self.act1(self.bn1(self.bridge1(d4)))
185
+ att=self.attmodule(b1)
186
+ b2=self.act2(self.bn2(self.bridge2(att)))
187
+ 'decoder'
188
+ up1=self.up1(b2)
189
+ up1 = torch.cat([up1, l4], axis=1)
190
+ l5=self.convlayer5(up1)
191
+
192
+ up2=self.up2(l5)
193
+ up2 = torch.cat([up2, l3], axis=1)
194
+ l6=self.convlayer6(up2)
195
+
196
+ up3=self.up3(l6)
197
+ up3= torch.cat([up3, l2], axis=1)
198
+ l7=self.convlayer7(up3)
199
+
200
+ up4=self.up4(l7)
201
+ up4 = torch.cat([up4, l1], axis=1)
202
+ l8=self.convlayer8(up4)
203
+ out=self.outputs(l8)
204
+
205
+ #out=self.sig(self.outputs(l8))
206
+ return out
207
+
208
+
209
+
210
+
unet_demo.ipynb ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "install streamlit"
21
+ ],
22
+ "metadata": {
23
+ "id": "0Zvkx3gudK6C"
24
+ }
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {
30
+ "id": "rIoHYPsIc_JX"
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "!pip install streamlit -q"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "colab": {
42
+ "base_uri": "https://localhost:8080/"
43
+ },
44
+ "id": "yi09eoT-JgS8",
45
+ "outputId": "24656b94-f2b7-4eb1-c900-e2e3028a5ff6"
46
+ },
47
+ "outputs": [
48
+ {
49
+ "output_type": "stream",
50
+ "name": "stdout",
51
+ "text": [
52
+ "Overwriting models.py\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "%%writefile models.py\n",
58
+ "import torch\n",
59
+ "import torch.nn as nn\n",
60
+ "import torch.nn.functional as F\n",
61
+ "from torch import Tensor\n",
62
+ "\n",
63
+ "\n",
64
+ "class DropBlock(nn.Module):\n",
65
+ " def __init__(self, block_size: int = 5, p: float = 0.1):\n",
66
+ " super().__init__()\n",
67
+ " self.block_size = block_size\n",
68
+ " self.p = p\n",
69
+ "\n",
70
+ " def calculate_gamma(self, x: Tensor) -> float:\n",
71
+ "\n",
72
+ "\n",
73
+ " invalid = (1 - self.p) / (self.block_size ** 2)\n",
74
+ " valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)\n",
75
+ " return invalid * valid\n",
76
+ "\n",
77
+ " def forward(self, x: Tensor) -> Tensor:\n",
78
+ " N, C, H, W = x.size()\n",
79
+ " if self.training:\n",
80
+ " gamma = self.calculate_gamma(x)\n",
81
+ " mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)\n",
82
+ " mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))\n",
83
+ " mask = F.pad(mask, [self.block_size // 2] * 4, value=0)\n",
84
+ " mask_block = 1 - F.max_pool2d(\n",
85
+ " mask,\n",
86
+ " kernel_size=(self.block_size, self.block_size),\n",
87
+ " stride=(1, 1),\n",
88
+ " padding=(self.block_size // 2, self.block_size // 2),\n",
89
+ " )\n",
90
+ " x = mask_block * x * (mask_block.numel() / mask_block.sum())\n",
91
+ " return x\n",
92
+ "\n",
93
+ "\n",
94
+ "class double_conv(nn.Module):\n",
95
+ " def __init__(self,intc,outc):\n",
96
+ " super().__init__()\n",
97
+ " self.conv1=nn.Conv2d(intc,outc,kernel_size=3,padding=1,stride=1)\n",
98
+ " self.drop1= DropBlock(7, 0.9)\n",
99
+ " self.bn1=nn.BatchNorm2d(outc)\n",
100
+ " self.relu1=nn.ReLU()\n",
101
+ " self.conv2=nn.Conv2d(outc,outc,kernel_size=3,padding=1,stride=1)\n",
102
+ " self.drop2=DropBlock(7, 0.9)\n",
103
+ " self.bn2=nn.BatchNorm2d(outc)\n",
104
+ " self.relu2=nn.ReLU()\n",
105
+ "\n",
106
+ " def forward(self,input):\n",
107
+ " x=self.relu1(self.bn1(self.drop1(self.conv1(input))))\n",
108
+ " x=self.relu2(self.bn2(self.drop2(self.conv2(x))))\n",
109
+ "\n",
110
+ " return x\n",
111
+ "class upconv(nn.Module):\n",
112
+ " def __init__(self,intc,outc) -> None:\n",
113
+ " super().__init__()\n",
114
+ " self.up=nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)\n",
115
+ " # self.relu=nn.ReLU()\n",
116
+ "\n",
117
+ " def forward(self,input):\n",
118
+ " x=self.up(input)\n",
119
+ " #x=self.relu(self.up(input))\n",
120
+ " return x\n",
121
+ "class unet(nn.Module):\n",
122
+ " def __init__(self,int,out) -> None:\n",
123
+ " 'int: represent the number of image channels'\n",
124
+ " 'out: number of the desired final channels'\n",
125
+ "\n",
126
+ " super().__init__()\n",
127
+ " 'encoder'\n",
128
+ " self.convlayer1=double_conv(int,64)\n",
129
+ " self.down1=nn.MaxPool2d((2, 2))\n",
130
+ " self.convlayer2=double_conv(64,128)\n",
131
+ " self.down2=nn.MaxPool2d((2, 2))\n",
132
+ " self.convlayer3=double_conv(128,256)\n",
133
+ " self.down3=nn.MaxPool2d((2, 2))\n",
134
+ " self.convlayer4=double_conv(256,512)\n",
135
+ " self.down4=nn.MaxPool2d((2, 2))\n",
136
+ "\n",
137
+ " 'bridge'\n",
138
+ " self.bridge=double_conv(512,1024)\n",
139
+ " 'decoder'\n",
140
+ " self.up1=upconv(1024,512)\n",
141
+ " self.convlayer5=double_conv(1024,512)\n",
142
+ " self.up2=upconv(512,256)\n",
143
+ " self.convlayer6=double_conv(512,256)\n",
144
+ " self.up3=upconv(256,128)\n",
145
+ " self.convlayer7=double_conv(256,128)\n",
146
+ " self.up4=upconv(128,64)\n",
147
+ " self.convlayer8=double_conv(128,64)\n",
148
+ " 'output'\n",
149
+ " self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)\n",
150
+ " self.sig=nn.Sigmoid()\n",
151
+ " def forward(self,input):\n",
152
+ " 'encoder'\n",
153
+ " l1=self.convlayer1(input)\n",
154
+ " d1=self.down1(l1)\n",
155
+ " l2=self.convlayer2(d1)\n",
156
+ " d2=self.down2(l2)\n",
157
+ " l3=self.convlayer3(d2)\n",
158
+ " d3=self.down3(l3)\n",
159
+ " l4=self.convlayer4(d3)\n",
160
+ " d4=self.down4(l4)\n",
161
+ " 'bridge'\n",
162
+ " bridge=self.bridge(d4)\n",
163
+ " 'decoder'\n",
164
+ " up1=self.up1(bridge)\n",
165
+ " up1 = torch.cat([up1, l4], axis=1)\n",
166
+ " l5=self.convlayer5(up1)\n",
167
+ "\n",
168
+ " up2=self.up2(l5)\n",
169
+ " up2 = torch.cat([up2, l3], axis=1)\n",
170
+ " l6=self.convlayer6(up2)\n",
171
+ "\n",
172
+ " up3=self.up3(l6)\n",
173
+ " up3= torch.cat([up3, l2], axis=1)\n",
174
+ " l7=self.convlayer7(up3)\n",
175
+ "\n",
176
+ " up4=self.up4(l7)\n",
177
+ " up4 = torch.cat([up4, l1], axis=1)\n",
178
+ " l8=self.convlayer8(up4)\n",
179
+ " out=self.outputs(l8)\n",
180
+ "\n",
181
+ " #out=self.sig(self.outputs(l8))\n",
182
+ " return out\n",
183
+ "class spatialAttention(nn.Module):\n",
184
+ " def __init__(self) -> None:\n",
185
+ " super().__init__()\n",
186
+ "\n",
187
+ " self.conv77=nn.Conv2d(2,1,kernel_size=7,padding=3)\n",
188
+ " self.sig=nn.Sigmoid()\n",
189
+ " def forward(self,input):\n",
190
+ " x=torch.cat( (torch.max(input,1)[0].unsqueeze(1), torch.mean(input,1).unsqueeze(1)), dim=1 )\n",
191
+ " x=self.sig(self.conv77(x))\n",
192
+ " x=input*x\n",
193
+ " return x\n",
194
+ "class SAunet(nn.Module):\n",
195
+ " def __init__(self,int,out) -> None:\n",
196
+ " 'int: represent the number of image channels'\n",
197
+ " 'out: number of the desired final channels'\n",
198
+ "\n",
199
+ " super().__init__()\n",
200
+ " 'encoder'\n",
201
+ " self.convlayer1=double_conv(int,64)\n",
202
+ " self.down1=nn.MaxPool2d((2, 2))\n",
203
+ " self.convlayer2=double_conv(64,128)\n",
204
+ " self.down2=nn.MaxPool2d((2, 2))\n",
205
+ " self.convlayer3=double_conv(128,256)\n",
206
+ " self.down3=nn.MaxPool2d((2, 2))\n",
207
+ " self.convlayer4=double_conv(256,512)\n",
208
+ " self.down4=nn.MaxPool2d((2, 2))\n",
209
+ "\n",
210
+ " 'bridge'\n",
211
+ " self.attmodule=spatialAttention()\n",
212
+ " self.bridge1=nn.Conv2d(512,1024,kernel_size=3,stride=1,padding=1)\n",
213
+ " self.bn1=nn.BatchNorm2d(1024)\n",
214
+ " self.act1=nn.ReLU()\n",
215
+ " self.bridge2=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)\n",
216
+ " self.bn2=nn.BatchNorm2d(1024)\n",
217
+ " self.act2=nn.ReLU()\n",
218
+ " 'decoder'\n",
219
+ " self.up1=upconv(1024,512)\n",
220
+ " self.convlayer5=double_conv(1024,512)\n",
221
+ " self.up2=upconv(512,256)\n",
222
+ " self.convlayer6=double_conv(512,256)\n",
223
+ " self.up3=upconv(256,128)\n",
224
+ " self.convlayer7=double_conv(256,128)\n",
225
+ " self.up4=upconv(128,64)\n",
226
+ " self.convlayer8=double_conv(128,64)\n",
227
+ " 'output'\n",
228
+ " self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)\n",
229
+ " self.sig=nn.Sigmoid()\n",
230
+ " def forward(self,input):\n",
231
+ " 'encoder'\n",
232
+ " l1=self.convlayer1(input)\n",
233
+ " d1=self.down1(l1)\n",
234
+ " l2=self.convlayer2(d1)\n",
235
+ " d2=self.down2(l2)\n",
236
+ " l3=self.convlayer3(d2)\n",
237
+ " d3=self.down3(l3)\n",
238
+ " l4=self.convlayer4(d3)\n",
239
+ " d4=self.down4(l4)\n",
240
+ " 'bridge'\n",
241
+ " b1=self.act1(self.bn1(self.bridge1(d4)))\n",
242
+ " att=self.attmodule(b1)\n",
243
+ " b2=self.act2(self.bn2(self.bridge2(att)))\n",
244
+ " 'decoder'\n",
245
+ " up1=self.up1(b2)\n",
246
+ " up1 = torch.cat([up1, l4], axis=1)\n",
247
+ " l5=self.convlayer5(up1)\n",
248
+ "\n",
249
+ " up2=self.up2(l5)\n",
250
+ " up2 = torch.cat([up2, l3], axis=1)\n",
251
+ " l6=self.convlayer6(up2)\n",
252
+ "\n",
253
+ " up3=self.up3(l6)\n",
254
+ " up3= torch.cat([up3, l2], axis=1)\n",
255
+ " l7=self.convlayer7(up3)\n",
256
+ "\n",
257
+ " up4=self.up4(l7)\n",
258
+ " up4 = torch.cat([up4, l1], axis=1)\n",
259
+ " l8=self.convlayer8(up4)\n",
260
+ " out=self.outputs(l8)\n",
261
+ "\n",
262
+ " #out=self.sig(self.outputs(l8))\n",
263
+ " return out\n",
264
+ "\n",
265
+ "\n",
266
+ "\n",
267
+ "\n"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "source": [],
273
+ "metadata": {
274
+ "id": "VfBYYfhlejB2"
275
+ },
276
+ "execution_count": null,
277
+ "outputs": []
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "source": [
282
+ "%%writefile app.py\n",
283
+ "import streamlit as st\n",
284
+ "from PIL import Image\n",
285
+ "import cv2\n",
286
+ "import numpy as np\n",
287
+ "import time\n",
288
+ "import models\n",
289
+ "import torch\n",
290
+ "\n",
291
+ "from torchvision import transforms\n",
292
+ "from torchvision import transforms\n",
293
+ "\n",
294
+ "def load_model(path, model):\n",
295
+ " model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))\n",
296
+ " return model\n",
297
+ "\n",
298
+ "def predict(img):\n",
299
+ " model = models.unet(3, 1)\n",
300
+ " model = load_model('model.pth',model)\n",
301
+ "\n",
302
+ " normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])\n",
303
+ " img = cv2.resize(img, (512, 512))\n",
304
+ " convert_tensor = transforms.ToTensor()\n",
305
+ " img = convert_tensor(img).float()\n",
306
+ " img = normalize(img)\n",
307
+ " img = torch.unsqueeze(img, dim=0)\n",
308
+ "\n",
309
+ " output = model(img)\n",
310
+ " result = torch.sigmoid(output)\n",
311
+ "\n",
312
+ " threshold = 0.5\n",
313
+ " result = (result >= threshold).float()\n",
314
+ " prediction = result[0].cpu() # Move tensor to CPU if it's on GPU\n",
315
+ " # Convert tensor to a numpy array\n",
316
+ " prediction_array = prediction.numpy()\n",
317
+ " # Rescale values to the range [0, 255]\n",
318
+ " prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)\n",
319
+ " cv2.imwrite(\"test.png\",prediction_array)\n",
320
+ " return prediction_array\n",
321
+ "\n",
322
+ "def predicjt(img):\n",
323
+ " model1 = models.SAunet(3, 1)\n",
324
+ " model1 = load_model('saunet.pth',model1)\n",
325
+ "\n",
326
+ " normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])\n",
327
+ " img = cv2.resize(img, (512, 512))\n",
328
+ " convert_tensor = transforms.ToTensor()\n",
329
+ " img = convert_tensor(img).float()\n",
330
+ " img = normalize(img)\n",
331
+ " img = torch.unsqueeze(img, dim=0)\n",
332
+ "\n",
333
+ " output = model1(img)\n",
334
+ " result = torch.sigmoid(output)\n",
335
+ "\n",
336
+ " threshold = 0.5\n",
337
+ " result = (result >= threshold).float()\n",
338
+ " prediction = result[0].cpu() # Move tensor to CPU if it's on GPU\n",
339
+ " # Convert tensor to a numpy array\n",
340
+ " prediction_array = prediction.numpy()\n",
341
+ " # Rescale values to the range [0, 255]\n",
342
+ " prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)\n",
343
+ " cv2.imwrite(\"test1.png\",prediction_array)\n",
344
+ " return prediction_array\n",
345
+ "def main():\n",
346
+ " st.title(\"Image Segmentation Demo\")\n",
347
+ "\n",
348
+ " # Predefined list of image names\n",
349
+ " image_names = [\"01_test.tif\", \"02_test.tif\", \"03_test.tif\"]\n",
350
+ "\n",
351
+ " # Create a selection box for the images\n",
352
+ " selected_image_name = st.selectbox(\"Select an Image\", image_names)\n",
353
+ "\n",
354
+ " # Load the selected image\n",
355
+ " selected_image = cv2.imread(selected_image_name)\n",
356
+ "\n",
357
+ " # Display the selected image\n",
358
+ " st.image(selected_image, channels=\"RGB\")\n",
359
+ "\n",
360
+ " # Create a button for segmentation\n",
361
+ " if st.button(\"Segment\"):\n",
362
+ " # Perform segmentation on the selected image\n",
363
+ " segmented_image = predict(selected_image)\n",
364
+ " segmented_image1 = predicjt(selected_image)\n",
365
+ "\n",
366
+ "\n",
367
+ " # Display the segmented image\n",
368
+ " st.image(segmented_image, channels=\"RGB\",caption='U-Net segmentation')\n",
369
+ " st.image(segmented_image1, channels=\"RGB\",caption='Spatial Attention U-Net segmentation ')\n",
370
+ "\n",
371
+ "# Function to perform segmentation on the selected image\n",
372
+ "\n",
373
+ "\n",
374
+ "if __name__ == \"__main__\":\n",
375
+ " main()\n"
376
+ ],
377
+ "metadata": {
378
+ "colab": {
379
+ "base_uri": "https://localhost:8080/"
380
+ },
381
+ "id": "v_1SyQwJ32Cy",
382
+ "outputId": "b88d7f6d-8f25-442a-8c3f-f7e2b1cb7691"
383
+ },
384
+ "execution_count": null,
385
+ "outputs": [
386
+ {
387
+ "output_type": "stream",
388
+ "name": "stdout",
389
+ "text": [
390
+ "Writing app.py\n"
391
+ ]
392
+ }
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "markdown",
397
+ "source": [
398
+ "use this ip"
399
+ ],
400
+ "metadata": {
401
+ "id": "Rkk12rLMdZeb"
402
+ }
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "source": [
407
+ "!wget -q -O - ipv4.icanhazip.com"
408
+ ],
409
+ "metadata": {
410
+ "id": "CfVannfVdJFr"
411
+ },
412
+ "execution_count": null,
413
+ "outputs": []
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "source": [],
418
+ "metadata": {
419
+ "id": "Z2t-PBADddGS"
420
+ },
421
+ "execution_count": null,
422
+ "outputs": []
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "source": [
427
+ "!streamlit run app.py & npx localtunnel --port 8501"
428
+ ],
429
+ "metadata": {
430
+ "id": "hI5bMKCQdVve"
431
+ },
432
+ "execution_count": null,
433
+ "outputs": []
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "source": [],
438
+ "metadata": {
439
+ "id": "69mNAs6EdVtU"
440
+ },
441
+ "execution_count": null,
442
+ "outputs": []
443
+ }
444
+ ]
445
+ }