Nunzio committed on
Commit
6a0b93e
·
1 Parent(s): bdf4b96

added files

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv/
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch, torchvision
2
+ import torchvision.transforms.functional
3
+ from model.BiSeNet.build_bisenet import BiSeNet
4
+ import gradio as gr
5
+ from utils.imageHandling import hfImageToTensor, preprocessing
6
+
7
+
8
+
9
+
10
+
11
# %% prediction on an image

def predict(inputImage: torch.Tensor, model: BiSeNet) -> torch.Tensor:
    """
    Run the segmentation model on a single image and return the class mask.

    Args:
        inputImage (torch.Tensor): The input image tensor (3, H, W) — see preprocessing.
        model (BiSeNet): The BiSeNet model for segmentation.

    Returns:
        torch.Tensor: The predicted segmentation mask (1 channel, class index per pixel).
    """
    with torch.no_grad():
        raw = model(preprocessing(inputImage))
    # Training-mode BiSeNet returns a (result, aux1, aux2) tuple; keep the main head.
    logits = raw[0] if isinstance(raw, (tuple, list)) else raw
    # Drop the batch dimension, then take the highest-scoring class per pixel.
    return logits[0].argmax(dim=0, keepdim=True)
28
+
29
+
30
+
31
# %% load model

def loadModel(model: str = 'bisenet', device: str = 'cpu') -> BiSeNet:
    """
    Load the specified model and move it to the given device.

    Args:
        model (str): model to be loaded.
        device (str): Device to load the model onto ('cpu' or 'cuda').

    Returns:
        model (BiSeNet): The loaded BiSeNet model.

    Raises:
        NotImplementedError: if the requested model name is not supported.
    """
    # Normalize the name only when it is actually a string.
    requested = model.lower() if isinstance(model, str) else model
    if requested == 'bisenet':
        return loadBiSeNet(device)
    raise NotImplementedError(f"Model {model} is not implemented. Please choose 'bisenet' .")
49
+
50
+
51
# BiSeNet model loading function
def loadBiSeNet(device: str = 'cpu') -> BiSeNet:
    """
    Load the BiSeNet model with pretrained weights and move it to the specified device.

    Args:
        device (str): Device to load the model onto ('cpu' or 'cuda').

    Returns:
        model (BiSeNet): The loaded BiSeNet model, set to eval mode.
    """
    # BUG FIX: the BiSeNet constructor parameter is `num_classes`, not `n_classes`;
    # the original keyword raised TypeError at call time.
    model = BiSeNet(num_classes=19, context_path='resnet18').to(device)
    model.load_state_dict(torch.load('./weights/BiSeNet/weightADV.pth', map_location=device))
    model.eval()

    return model
model/BiSeNet/build_bisenet.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .build_contextpath import build_contextpath
4
+ import warnings
5
+ warnings.filterwarnings(action='ignore')
6
+
7
+
8
class ConvBlock(torch.nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU building block."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        # Bias is omitted because the batch norm that follows has its own affine shift.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        normalized = self.bn(self.conv1(input))
        return self.relu(normalized)
19
+
20
+
21
class Spatial_path(torch.nn.Module):
    """Three stride-2 ConvBlocks: 3 -> 64 -> 128 -> 256 channels, downsampling by 8x."""

    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        out = input
        # Apply the three downsampling blocks in sequence.
        for block in (self.convblock1, self.convblock2, self.convblock3):
            out = block(out)
        return out
33
+
34
+
35
class AttentionRefinementModule(torch.nn.Module):
    """Channel attention: scale the input by sigmoid(BN(1x1 conv(global-avg-pool(input))))."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # Squeeze spatial dimensions to 1x1 via global average pooling.
        pooled = self.avgpool(input)
        assert self.in_channels == pooled.size(1), 'in_channels and out_channels should all be {}'.format(pooled.size(1))
        # Per-channel attention weights in (0, 1).
        gate = self.sigmoid(self.bn(self.conv(pooled)))
        # Re-weight the input channels; the 1x1 gate broadcasts over H and W.
        return input * gate
54
+
55
+
56
class FeatureFusionModule(torch.nn.Module):
    """Fuse spatial-path and context-path features with a squeeze-and-excitation style gate."""

    def __init__(self, num_classes, in_channels):
        super().__init__()
        # Expected concatenated channel count:
        #   resnet101: 3328 = 256 (spatial) + 1024 + 2048 (context)
        #   resnet18:  1024 = 256 (spatial) + 256 + 512 (context)
        self.in_channels = in_channels

        self.convblock = ConvBlock(in_channels=self.in_channels, out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        merged = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == merged.size(1), 'in_channels of ConvBlock should be {}'.format(merged.size(1))
        feature = self.convblock(merged)

        # Build a per-channel gate from the globally pooled fused feature.
        gate = self.avgpool(feature)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))

        # Residual combination: gated feature plus the feature itself.
        return feature * gate + feature
82
+
83
+
84
class BiSeNet(torch.nn.Module):
    """
    Bilateral Segmentation Network: a shallow spatial path preserves detail while
    a ResNet context path provides semantics; the two are fused into per-pixel
    class scores.
    """

    def __init__(self, num_classes, context_path):
        super().__init__()
        # Build spatial path. NOTE(review): the attribute keeps the historical
        # 'saptial_path' typo so existing checkpoints' state_dict keys still match.
        self.saptial_path = Spatial_path()

        # build context path (ResNet backbone wrapper)
        self.context_path = build_contextpath(name=context_path)

        if context_path == 'resnet101':
            # attention refinement modules for the two deepest feature maps
            self.attention_refinement_module1 = AttentionRefinementModule(1024, 1024)
            self.attention_refinement_module2 = AttentionRefinementModule(2048, 2048)
            # supervision blocks: training-time auxiliary heads
            self.supervision1 = nn.Conv2d(in_channels=1024, out_channels=num_classes, kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=2048, out_channels=num_classes, kernel_size=1)
            # feature fusion module (256 spatial + 1024 + 2048 context channels)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 3328)

        elif context_path == 'resnet18':
            # attention refinement modules for resnet18 feature maps
            self.attention_refinement_module1 = AttentionRefinementModule(256, 256)
            self.attention_refinement_module2 = AttentionRefinementModule(512, 512)
            # supervision blocks: training-time auxiliary heads
            self.supervision1 = nn.Conv2d(in_channels=256, out_channels=num_classes, kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1)
            # feature fusion module (256 spatial + 256 + 512 context channels)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 1024)
        else:
            # BUG FIX: the original only printed a (misspelled) warning and then
            # crashed later with AttributeError; fail fast with a clear error.
            raise ValueError(
                f"Unsupported context_path network: {context_path!r}; choose 'resnet18' or 'resnet101'."
            )

        # final 1x1 convolution over the upsampled fused features
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # modules eligible for a multiplied learning rate (everything except the
        # pretrained context-path backbone)
        self.mul_lr = [
            self.saptial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Kaiming-initialize convs and reset batch norms, skipping the pretrained context path."""
        for name, m in self.named_modules():
            if 'context_path' not in name:
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-5
                    m.momentum = 0.1
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): batch of images, shape (N, 3, H, W).

        Returns:
            torch.Tensor: per-pixel class scores in eval mode; in training mode
            a tuple (result, cx1_sup, cx2_sup) including the two auxiliary
            supervision outputs.
        """
        # output of spatial path (downsampled 8x relative to input)
        sx = self.saptial_path(input)

        # output of context path: 1/16 and 1/32 feature maps plus pooled tail
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        # upsample both context features to the spatial-path resolution
        cx1 = torch.nn.functional.interpolate(cx1, size=sx.size()[-2:], mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2, size=sx.size()[-2:], mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)

        if self.training:
            # auxiliary supervision heads, upsampled to input resolution
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = torch.nn.functional.interpolate(cx1_sup, size=input.size()[-2:], mode='bilinear')
            cx2_sup = torch.nn.functional.interpolate(cx2_sup, size=input.size()[-2:], mode='bilinear')

        # output of feature fusion module
        result = self.feature_fusion_module(sx, cx)

        # upsample back by 8x (the spatial path's downsampling factor)
        result = torch.nn.functional.interpolate(result, scale_factor=8, mode='bilinear')
        result = self.conv(result)

        if self.training:
            return result, cx1_sup, cx2_sup

        return result
model/BiSeNet/build_contextpath.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models
3
+
4
+
5
+ class resnet18(torch.nn.Module):
6
+ def __init__(self, pretrained=True):
7
+ super().__init__()
8
+ self.features = models.resnet18(pretrained=pretrained)
9
+ self.conv1 = self.features.conv1
10
+ self.bn1 = self.features.bn1
11
+ self.relu = self.features.relu
12
+ self.maxpool1 = self.features.maxpool
13
+ self.layer1 = self.features.layer1
14
+ self.layer2 = self.features.layer2
15
+ self.layer3 = self.features.layer3
16
+ self.layer4 = self.features.layer4
17
+
18
+ def forward(self, input):
19
+ x = self.conv1(input)
20
+ x = self.relu(self.bn1(x))
21
+ x = self.maxpool1(x)
22
+ feature1 = self.layer1(x) # 1 / 4
23
+ feature2 = self.layer2(feature1) # 1 / 8
24
+ feature3 = self.layer3(feature2) # 1 / 16
25
+ feature4 = self.layer4(feature3) # 1 / 32
26
+ # global average pooling to build tail
27
+ tail = torch.mean(feature4, 3, keepdim=True)
28
+ tail = torch.mean(tail, 2, keepdim=True)
29
+ return feature3, feature4, tail
30
+
31
+
32
+ class resnet101(torch.nn.Module):
33
+ def __init__(self, pretrained=True):
34
+ super().__init__()
35
+ self.features = models.resnet101(pretrained=pretrained)
36
+ self.conv1 = self.features.conv1
37
+ self.bn1 = self.features.bn1
38
+ self.relu = self.features.relu
39
+ self.maxpool1 = self.features.maxpool
40
+ self.layer1 = self.features.layer1
41
+ self.layer2 = self.features.layer2
42
+ self.layer3 = self.features.layer3
43
+ self.layer4 = self.features.layer4
44
+
45
+ def forward(self, input):
46
+ x = self.conv1(input)
47
+ x = self.relu(self.bn1(x))
48
+ x = self.maxpool1(x)
49
+ feature1 = self.layer1(x) # 1 / 4
50
+ feature2 = self.layer2(feature1) # 1 / 8
51
+ feature3 = self.layer3(feature2) # 1 / 16
52
+ feature4 = self.layer4(feature3) # 1 / 32
53
+ # global average pooling to build tail
54
+ tail = torch.mean(feature4, 3, keepdim=True)
55
+ tail = torch.mean(tail, 2, keepdim=True)
56
+ return feature3, feature4, tail
57
+
58
+
59
def build_contextpath(name, pretrained=True):
    """
    Construct the requested context-path backbone.

    Args:
        name (str): Backbone name, 'resnet18' or 'resnet101'.
        pretrained (bool): Whether to load ImageNet weights (default True,
            matching the original behavior).

    Returns:
        torch.nn.Module: The backbone wrapper.

    Raises:
        KeyError: If `name` is not a supported backbone.
    """
    # BUG FIX: the original dict instantiated BOTH backbones (downloading both
    # pretrained checkpoints) just to return one; build only the requested one.
    builders = {
        'resnet18': resnet18,
        'resnet101': resnet101,
    }
    return builders[name](pretrained=pretrained)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
utils/imageHandling.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+
3
+ # %% image loading
4
+ def hfImageToTensor(image, width:int=1024, height:int=512)->torch.Tensor:
5
+ """
6
+ Convert an input image (PIL.Image or numpy array) from Hugging Face to a torch tensor
7
+ of shape (3, height, width) and type float32.
8
+
9
+ Args:
10
+ image: Input image (PIL.Image or numpy array).
11
+ width (int): Target width.
12
+ height (int): Target height.
13
+
14
+ Returns:
15
+ torch.Tensor: Image tensor of shape (3, height, width).
16
+ """
17
+ image = image if isinstance(image, torch.Tensor) else torchvision.transforms.functional.to_tensor(image)
18
+ return torchvision.transforms.functional.resize(image, [height, width])
19
+
20
+ # %% preprocessing
21
+ def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
22
+ """
23
+ Standardize the image tensor and add batch dimension.
24
+
25
+ Args:
26
+ image_tensor (torch.Tensor): Image tensor of shape (3, H, W).
27
+
28
+ Returns:
29
+ torch.Tensor: Preprocessed tensor of shape (1, 3, H, W).
30
+ """
31
+ return torchvision.transforms.functional.normalize(
32
+ image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
33
+ ).unsqueeze(0)
utils2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def print_mask(mask: torch.Tensor, numClasses: int = 19) -> torch.Tensor:
    """
    Convert a class-index segmentation mask into an RGB color image.

    Uses the Cityscapes train-id palette; pixels with the ignore value 255
    (and any class id >= numClasses) stay black.

    Args:
        mask (torch.Tensor): Segmentation mask of shape (H, W) holding class indices.
        numClasses (int, optional): Number of classes to colorize. Defaults to 19.

    Returns:
        torch.Tensor: uint8 color image of shape (3, H, W).
    """
    # BUG FIX: the original annotated the return type as None although it
    # returns the colorized tensor; the explicit `mask == 255` zeroing was
    # also redundant because the canvas is initialized to black.
    colors = [
        (128, 64, 128),   # 0: road
        (244, 35, 232),   # 1: sidewalk
        (70, 70, 70),     # 2: building
        (102, 102, 156),  # 3: wall
        (190, 153, 153),  # 4: fence
        (153, 153, 153),  # 5: pole
        (250, 170, 30),   # 6: traffic light
        (220, 220, 0),    # 7: traffic sign
        (107, 142, 35),   # 8: vegetation
        (152, 251, 152),  # 9: terrain
        (70, 130, 180),   # 10: sky
        (220, 20, 60),    # 11: person
        (255, 0, 0),      # 12: rider
        (0, 0, 142),      # 13: car
        (0, 0, 70),       # 14: truck
        (0, 60, 100),     # 15: bus
        (0, 80, 100),     # 16: train
        (0, 0, 230),      # 17: motorcycle
        (119, 11, 32)     # 18: bicycle
    ]

    # (H, W, 3) canvas initialized to black, which also covers ignore=255.
    colorized = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
    for class_id in range(numClasses):
        colorized[mask == class_id] = torch.tensor(colors[class_id], dtype=torch.uint8)
    # Return channels-first (3, H, W) to match torch image conventions.
    return colorized.permute(2, 0, 1)
weights/BiSeNet/weightADV.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:880db4160f20c87aecc13845ad691b1963fbce3d713b1dda1964457b9e0d8f0a
3
+ size 121015606