Andrey commited on
Commit
b683920
1 Parent(s): 5182668

Initial commit.

Browse files
.flake8 ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore = I001,I002,I004,I005,I101,I201,C101,C403,C901,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D210,D205,D400,T001,W504,D202,E203,W503,B006,T002,T100,P103,C408,F841
3
+ max-line-length = 120
4
+ exclude = outputs/*
5
+ max-complexity = 10
.gitattributes CHANGED
@@ -1,2 +1,3 @@
1
  # Auto detect text files and perform LF normalization
2
  * text=auto
 
 
1
  # Auto detect text files and perform LF normalization
2
  * text=auto
3
+ *.pth filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.3.0
4
+ hooks:
5
+ - id: check-yaml
6
+ - id: end-of-file-fixer
7
+ - id: trailing-whitespace
8
+ - repo: https://github.com/psf/black
9
+ rev: '22.12.0'
10
+ hooks:
11
+ - id: black
12
+ args: [--config=pyproject.toml]
13
+ - repo: https://github.com/pre-commit/mirrors-mypy
14
+ rev: 586b4f0
15
+ hooks:
16
+ - id: mypy
17
+ args: [--ignore-missing-imports, --warn-no-return, --warn-redundant-casts, --disallow-incomplete-defs, --no-namespace-packages ]
18
+ - repo: https://gitlab.com/pycqa/flake8
19
+ rev: '5.0.4'
20
+ hooks:
21
+ - id: flake8
22
+ additional_dependencies: [
23
+ 'flake8-bugbear==22.8.23',
24
+ 'flake8-coding==1.3.2',
25
+ 'flake8-comprehensions==3.10.0',
26
+ 'flake8-debugger==4.1.2',
27
+ 'flake8-deprecated==1.3',
28
+ 'flake8-docstrings==1.6.0',
29
+ 'flake8-isort==4.2.0',
30
+ 'flake8-pep3101==1.3.0',
31
+ 'flake8-polyfill==1.0.2',
32
+ 'flake8-print==5.0.0',
33
+ 'flake8-quotes==3.3.1',
34
+ 'flake8-string-format==0.3.0',
35
+ ]
README.md CHANGED
@@ -1,2 +1,17 @@
1
  # digit-draw-detect
2
  An app for handwritten digit detection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # digit-draw-detect
2
  An app for handwritten digit detection
3
+
4
+ steps:
5
+ * use git lfs for the model +
6
+ * write better code +
7
+ * convert model to onnx or some other format?
8
+ * deploy bare working app, without nice things
9
+ * make better design
10
+ * think about descriptions on the site
11
+
12
+ On using git lfs:
13
+ ```shell
14
+ git lfs install
15
+ git lfs track "*.psd"
16
+ git add .gitattributes
17
+ ```
config.toml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Everything in this section will be available as an environment variable
2
+ db_username = "Jane"
3
+ db_password = "12345qwerty"
4
+
5
+ AWS_ACCESS_KEY_ID = 'AKIAI4JDKBYRCHGT77VQ'
6
+ AWS_SECRET_ACCESS_KEY = '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'
7
+ S3_BUCKET = 'digitdrawdetect'
8
+ S3_BUCKET_NAME = 'digitdrawdetect'
model_files/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:571e937122d5ccafe496d1cc71cea5c0661d385b5a7db4ec977ac8ae5da40680
3
+ size 246698572
mypy.ini ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import numpy
2
+
3
+ [mypy]
4
+ python_version = 3.10
5
+ plugins = numpy.typing.mypy_plugin
pyproject.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 119
3
+ skip-string-normalization = true
4
+ target-version = ['py39', 'py310']
5
+ include = '\.pyi?$'
6
+ exclude = '''
7
+ /(
8
+ \.eggs
9
+ | \.git
10
+ | \.hg
11
+ | \.mypy_cache
12
+ | \.tox
13
+ | \.venv
14
+ | _build
15
+ | buck-out
16
+ | build
17
+ | dist
18
+ | outputs
19
+ )/
20
+ '''
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations==1.3.0
2
+ matplotlib==3.6.1
3
+ numpy==1.23.4
4
+ omegaconf==2.2.1
5
+ opencv_python==4.6.0.66
6
+ pandas==1.5.1
7
+ Pillow==9.2.0
8
+ rich==12.6.0
9
+ streamlit==1.16.0
10
+ streamlit_drawable_canvas==0.9.2
11
+ tomli==2.0.1
12
+ torch==1.12.1
13
+ torchvision
14
+ # need to define some pytorch
15
+ https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp310-cp310-linux_x86_64.whl
src/ml_utils.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import logging
3
+
4
+ import albumentations as A
5
+ import streamlit as st
6
+ import torch
7
+ from albumentations import pytorch
8
+
9
+ from src.model_architecture import Net
10
+
11
+ anchors = torch.tensor(
12
+ [
13
+ [[0.2800, 0.2200], [0.3800, 0.4800], [0.9000, 0.7800]],
14
+ [[0.0700, 0.1500], [0.1500, 0.1100], [0.1400, 0.2900]],
15
+ [[0.0200, 0.0300], [0.0400, 0.0700], [0.0800, 0.0600]],
16
+ ]
17
+ )
18
+
19
+ transforms = A.Compose(
20
+ [
21
+ A.Resize(always_apply=False, p=1, height=192, width=192, interpolation=1),
22
+ A.Normalize(),
23
+ pytorch.transforms.ToTensorV2(),
24
+ ]
25
+ )
26
+
27
+
28
+ def cells_to_bboxes(predictions: torch.tensor, anchors: torch.tensor, s: int, is_preds: bool = True) -> torch.tensor:
29
+ """
30
+ Scale the predictions coming from the model_files to
31
+ be relative to the entire image such that they for example later
32
+ can be plotted or.
33
+ Args:
34
+ predictions: tensor of size (N, 3, S, S, num_classes+5)
35
+ anchors: the anchors used for the predictions
36
+ s: the number of cells the image is divided in on the width (and height)
37
+ is_preds: whether the input is predictions or the true bounding boxes
38
+ Returns:
39
+ converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
40
+ object score, bounding box coordinates
41
+ """
42
+ batch_size = predictions.shape[0]
43
+ num_anchors = len(anchors)
44
+ box_predictions = predictions[..., 1:5]
45
+ if is_preds:
46
+ anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
47
+ box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
48
+ box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
49
+ scores = torch.sigmoid(predictions[..., 0:1])
50
+ best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
51
+ else:
52
+ scores = predictions[..., 0:1]
53
+ best_class = predictions[..., 5:6]
54
+
55
+ cell_indices = torch.arange(s).repeat(predictions.shape[0], 3, s, 1).unsqueeze(-1).to(predictions.device)
56
+ x = 1 / s * (box_predictions[..., 0:1] + cell_indices)
57
+ y = 1 / s * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
58
+ w_h = 1 / s * box_predictions[..., 2:4]
59
+ converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(batch_size, num_anchors * s * s, 6)
60
+ return converted_bboxes.tolist()
61
+
62
+
63
+ def non_max_suppression(
64
+ bboxes: List[List], iou_threshold: float, threshold: float, box_format: str = 'corners'
65
+ ) -> List[List]:
66
+ """
67
+ Apply nms to the bboxes.
68
+
69
+ Video explanation of this function:
70
+ https://youtu.be/YDkjWEN8jNA
71
+ Does Non Max Suppression given bboxes
72
+ Args:
73
+ bboxes (list): list of lists containing all bboxes with each bboxes
74
+ specified as [class_pred, prob_score, x1, y1, x2, y2]
75
+ iou_threshold (float): threshold where predicted bboxes is correct
76
+ threshold (float): threshold to remove predicted bboxes (independent of IoU)
77
+ box_format (str): 'midpoint' or 'corners' used to specify bboxes
78
+ Returns:
79
+ list: bboxes after performing NMS given a specific IoU threshold
80
+ """
81
+
82
+ assert type(bboxes) == list
83
+
84
+ bboxes = [box for box in bboxes if box[1] > threshold]
85
+ bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
86
+ bboxes_after_nms = []
87
+
88
+ while bboxes:
89
+ chosen_box = bboxes.pop(0)
90
+
91
+ bboxes = [
92
+ box
93
+ for box in bboxes
94
+ if box[0] != chosen_box[0]
95
+ or intersection_over_union(
96
+ torch.tensor(chosen_box[2:]),
97
+ torch.tensor(box[2:]),
98
+ box_format=box_format,
99
+ )
100
+ < iou_threshold
101
+ ]
102
+
103
+ bboxes_after_nms.append(chosen_box)
104
+
105
+ return bboxes_after_nms
106
+
107
+
108
+ def intersection_over_union(
109
+ boxes_preds: torch.tensor, boxes_labels: torch.tensor, box_format: str = 'midpoint'
110
+ ) -> torch.tensor:
111
+ """
112
+ Calculate iou.
113
+
114
+ Video explanation of this function:
115
+ https://youtu.be/XXYG5ZWtjj0
116
+ This function calculates intersection over union (iou) given pred boxes
117
+ and target boxes.
118
+ Args:
119
+ boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
120
+ boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
121
+ box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
122
+ Returns:
123
+ tensor: Intersection over union for all examples
124
+ """
125
+
126
+ if box_format == 'midpoint':
127
+ box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
128
+ box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
129
+ box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
130
+ box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
131
+ box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
132
+ box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
133
+ box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
134
+ box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
135
+
136
+ if box_format == 'corners':
137
+ box1_x1 = boxes_preds[..., 0:1]
138
+ box1_y1 = boxes_preds[..., 1:2]
139
+ box1_x2 = boxes_preds[..., 2:3]
140
+ box1_y2 = boxes_preds[..., 3:4]
141
+ box2_x1 = boxes_labels[..., 0:1]
142
+ box2_y1 = boxes_labels[..., 1:2]
143
+ box2_x2 = boxes_labels[..., 2:3]
144
+ box2_y2 = boxes_labels[..., 3:4]
145
+
146
+ x1 = torch.max(box1_x1, box2_x1)
147
+ y1 = torch.max(box1_y1, box2_y1)
148
+ x2 = torch.min(box1_x2, box2_x2)
149
+ y2 = torch.min(box1_y2, box2_y2)
150
+
151
+ intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
152
+ box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
153
+ box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
154
+
155
+ return intersection / (box1_area + box2_area - intersection + 1e-6)
156
+
157
+
158
+ def predict(
159
+ model: torch.nn.Module, image: torch.tensor, iou_threshold: float = 1.0, threshold: float = 0.05
160
+ ) -> List[List]:
161
+ """
162
+ Apply the model_files to the predictions and to postprocessing
163
+ Args:
164
+ model: a trained pytorch model_files.
165
+ image: image as a torch tensor
166
+ iou_threshold: a threshold for intersection_over_union function
167
+ threshold: a threshold for bbox probability
168
+
169
+ Returns:
170
+ predicted bboxes
171
+
172
+ """
173
+ # apply model_files. add a dimension to imitate a batch size of 1
174
+ logits = model(image[None, :])
175
+ logging.info('predicted')
176
+
177
+ # postprocess. In fact, we could remove indexing with idx here, as there is a single image.
178
+ # But I prefer to keep it so that this code could be easier changed for cases with batch size > 1
179
+ bboxes: List[List] = [[] for _ in range(1)]
180
+ for i in range(3):
181
+ S = logits[i].shape[2]
182
+ # it could be better to initialize anchors inside the function, but I don't want to do it for every prediction.
183
+ anchor = anchors[i] * S
184
+ boxes_scale_i = cells_to_bboxes(logits[i], anchor, s=S, is_preds=True)
185
+ for idx, (box) in enumerate(boxes_scale_i):
186
+ bboxes[idx] += box
187
+ logging.info('Starting nms')
188
+ nms_boxes = non_max_suppression(
189
+ bboxes[idx],
190
+ iou_threshold=iou_threshold,
191
+ threshold=threshold,
192
+ box_format='midpoint',
193
+ )
194
+
195
+ return nms_boxes
196
+
197
+
198
+ @st.cache
199
+ def get_model():
200
+
201
+ model_name = 'model_files/best_model.pth'
202
+
203
+ model = Net()
204
+ model.load_state_dict(torch.load(model_name))
205
+ model.eval()
206
+
207
+ return model
src/model_architecture.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class CNNBlock(nn.Module):
6
+ def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
7
+ super().__init__()
8
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
9
+ self.bn = nn.BatchNorm2d(out_channels)
10
+ self.leaky = nn.LeakyReLU(0.1)
11
+ self.use_bn_act = bn_act
12
+
13
+ def forward(self, x):
14
+ if self.use_bn_act:
15
+ return self.leaky(self.bn(self.conv(x)))
16
+ else:
17
+ return self.conv(x)
18
+
19
+
20
+ class ResidualBlock(nn.Module):
21
+ def __init__(self, channels, use_residual=True, num_repeats=1):
22
+ super().__init__()
23
+ self.layers = nn.ModuleList()
24
+ for _ in range(num_repeats):
25
+ self.layers += [
26
+ nn.Sequential(
27
+ CNNBlock(channels, channels // 2, kernel_size=1),
28
+ CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
29
+ )
30
+ ]
31
+
32
+ self.use_residual = use_residual
33
+ self.num_repeats = num_repeats
34
+
35
+ def forward(self, x):
36
+ for layer in self.layers:
37
+ if self.use_residual:
38
+ x = x + layer(x)
39
+ else:
40
+ x = layer(x)
41
+
42
+ return x
43
+
44
+
45
+ class ScalePrediction(nn.Module):
46
+ def __init__(self, in_channels, num_classes):
47
+ super().__init__()
48
+ self.pred = nn.Sequential(
49
+ CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
50
+ CNNBlock(2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1),
51
+ )
52
+ self.num_classes = num_classes
53
+
54
+ def forward(self, x):
55
+ return self.pred(x).reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3]).permute(0, 1, 3, 4, 2)
56
+
57
+
58
+ class Net(nn.Module):
59
+ def __init__(self):
60
+ super().__init__()
61
+ self.num_classes = 12
62
+ self.in_channels = 3
63
+ # self.config = cfg.model_files.params.config
64
+ # self.config = [i if i[0] != '(' else literal_eval(i) for i in self.config]
65
+ self.config = [
66
+ (32, 3, 1),
67
+ (64, 3, 2),
68
+ ['B', 1],
69
+ (128, 3, 2),
70
+ ['B', 2],
71
+ (256, 3, 2),
72
+ ['B', 8],
73
+ (512, 3, 2),
74
+ ['B', 8],
75
+ (1024, 3, 2),
76
+ ['B', 4],
77
+ (512, 1, 1),
78
+ (1024, 3, 1),
79
+ 'S',
80
+ (256, 1, 1),
81
+ 'U',
82
+ (256, 1, 1),
83
+ (512, 3, 1),
84
+ 'S',
85
+ (128, 1, 1),
86
+ 'U',
87
+ (128, 1, 1),
88
+ (256, 3, 1),
89
+ 'S',
90
+ ]
91
+ # print('self.config', self.config)
92
+ self.layers = self._create_conv_layers()
93
+
94
+ def forward(self, x):
95
+ outputs = [] # for each scale
96
+ route_connections = []
97
+ for layer in self.layers:
98
+ if isinstance(layer, ScalePrediction):
99
+ outputs.append(layer(x))
100
+ continue
101
+ # print(layer, x.shape)
102
+ x = layer(x)
103
+
104
+ if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
105
+ route_connections.append(x)
106
+
107
+ elif isinstance(layer, nn.Upsample):
108
+ x = torch.cat([x, route_connections[-1]], dim=1)
109
+ route_connections.pop()
110
+
111
+ return outputs
112
+
113
+ def _create_conv_layers(self):
114
+ layers = nn.ModuleList()
115
+ in_channels = self.in_channels
116
+
117
+ for module in self.config:
118
+ # print(module, type(module))
119
+ if isinstance(module, tuple):
120
+ out_channels, kernel_size, stride = module
121
+ layers.append(
122
+ CNNBlock(
123
+ in_channels,
124
+ out_channels,
125
+ kernel_size=kernel_size,
126
+ stride=stride,
127
+ padding=1 if kernel_size == 3 else 0,
128
+ )
129
+ )
130
+ in_channels = out_channels
131
+
132
+ elif isinstance(module, list):
133
+ num_repeats = module[1]
134
+ layers.append(
135
+ ResidualBlock(
136
+ in_channels,
137
+ num_repeats=num_repeats,
138
+ )
139
+ )
140
+
141
+ elif isinstance(module, str):
142
+ if module == 'S':
143
+ layers += [
144
+ ResidualBlock(in_channels, use_residual=False, num_repeats=1),
145
+ CNNBlock(in_channels, in_channels // 2, kernel_size=1),
146
+ ScalePrediction(in_channels // 2, num_classes=self.num_classes),
147
+ ]
148
+ in_channels = in_channels // 2
149
+
150
+ elif module == 'U':
151
+ layers.append(
152
+ nn.Upsample(scale_factor=2),
153
+ )
154
+ in_channels = in_channels * 3
155
+
156
+ return layers
src/utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+
3
+ import matplotlib
4
+ import matplotlib.patches as patches
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import tomli as tomllib
8
+
9
+
10
+ def plot_img_with_rects(
11
+ img: np.array, boxes: List[List], threshold: float = 0.5, coef: int = 400
12
+ ) -> matplotlib.figure.Figure:
13
+ """
14
+ Plot image with rectangles.
15
+
16
+ Args:
17
+ img: image as a numpy array
18
+ boxes: the list of the bboxes
19
+ threshold: threshold for bbox probability
20
+ coef: coefficient to multiply images. Can be changed when the original image is a different size
21
+
22
+ Returns:
23
+ image with bboxes
24
+ """
25
+ fig, ax = plt.subplots(1, figsize=(4, 4))
26
+
27
+ # Display the image
28
+ ax.imshow(img)
29
+
30
+ # Create a Rectangle patch
31
+ for _, rect in enumerate([b for b in boxes if b[1] > threshold]):
32
+ label, _, xc, yc, w, h = rect
33
+ xc, yc, w, h = xc * coef, yc * coef, w * coef, h * coef
34
+ # the coordinates from center-based to left top corner
35
+ x = xc - w / 2
36
+ y = yc - h / 2
37
+ label = int(label)
38
+ label = label if label != 10 else 'penis'
39
+ label = label if label != 11 else 'junk'
40
+ rect = [x, y, x + w, y + h]
41
+
42
+ rect_ = patches.Rectangle(
43
+ (rect[0], rect[1]), rect[2] - rect[0], rect[3] - rect[1], linewidth=2, edgecolor='blue', facecolor='none'
44
+ )
45
+ plt.text(rect[2], rect[1], f'{label}', color='blue')
46
+ # Add the patch to the Axes
47
+ ax.add_patch(rect_)
48
+ return fig
49
+
50
+
51
+ def get_config() -> Dict:
52
+ """
53
+ Get dict from config.
54
+
55
+ Returns:
56
+ config
57
+ """
58
+ with open('config.toml', 'rb') as f:
59
+ config = tomllib.load(f)
60
+
61
+ return config
st_app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import streamlit as st
5
+ import tomli as tomllib
6
+ from PIL import Image
7
+ from streamlit_drawable_canvas import st_canvas
8
+
9
+ from src.ml_utils import predict, get_model, transforms
10
+ from src.utils import plot_img_with_rects, get_config
11
+
12
+ logging.info('Starting')
13
+
14
+ col1, col2 = st.columns(2)
15
+
16
+ with col1:
17
+ # Create a canvas component
18
+ canvas_result = st_canvas(
19
+ fill_color='#fff',
20
+ stroke_width=5,
21
+ stroke_color='#000',
22
+ background_color='#fff',
23
+ update_streamlit=True,
24
+ height=400,
25
+ width=400,
26
+ drawing_mode='freedraw',
27
+ key='canvas',
28
+ )
29
+ with col2:
30
+ data = get_config()
31
+ logging.info('canvas ready')
32
+ if canvas_result.image_data is not None:
33
+ # convert a drawn image into numpy array with RGB from a canvas image with RGBA
34
+ img = np.array(Image.fromarray(np.uint8(canvas_result.image_data)).convert('RGB'))
35
+ image = transforms(image=img)['image']
36
+ logging.info('image augmented')
37
+ model = get_model()
38
+ logging.info('model ready')
39
+ pred = predict(model, image)
40
+ logging.info('prediction done')
41
+ threshold = st.slider('Bbox probability slider', min_value=0.0, max_value=1.0, value=0.5)
42
+
43
+ fig = plot_img_with_rects(image.permute(1, 2, 0).numpy(), pred, threshold, coef=192)
44
+ fig.savefig('figure_name1.png')
45
+ image = Image.open('figure_name1.png')
46
+ st.image(image)