serdaryildiz committed (verified)
Commit 9b6af3b · 1 Parent(s): a2071be

Upload 31 files

Model/__init__.py ADDED
@@ -0,0 +1 @@
from Model.model import TTR
Model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (184 Bytes)

Model/__pycache__/attention.cpython-38.pyc ADDED
Binary file (4.32 kB)

Model/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (9.63 kB)

Model/__pycache__/model.cpython-38.pyc ADDED
Binary file (1.28 kB)

Model/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (6.75 kB)
Model/attention.py ADDED
@@ -0,0 +1,114 @@
import math
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where pos is the word position and i is the embed index.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))


def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
    align_corners = None if mode == 'nearest' else True
    return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
                                     mode=mode, align_corners=align_corners),
                         nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))


class PositionAttention(nn.Module):
    def __init__(self, max_length, in_channels=512, num_channels=64,
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2))
        )
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
        )

        self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)

        # calculate key vector
        features = []
        for i in range(0, len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(0, len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # calculate query vector
        zeros = x.new_zeros((self.max_length, N, E))  # (T, N, E)
        q = self.pos_encoder(zeros)  # (T, N, E)
        q = q.permute(1, 0, 2)  # (N, T, E)
        q = self.project(q)  # (N, T, E)

        # calculate attention
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)

        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        return attn_vecs, attn_scores.view(N, -1, H, W)
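
A minimal usage sketch for PositionAttention (an illustrative snippet, assuming a dummy backbone feature map with the module's default 512 channels on an 8x32 grid):

import torch
from Model.attention import PositionAttention

# dummy backbone feature map: (N, E, H, W) = (2, 512, 8, 32)
attn = PositionAttention(max_length=26, in_channels=512, num_channels=64, h=8, w=32)
feat = torch.randn(2, 512, 8, 32)
attn_vecs, attn_scores = attn(feat)
print(attn_vecs.shape)    # torch.Size([2, 26, 512])  -> one 512-d vector per decoding step
print(attn_scores.shape)  # torch.Size([2, 26, 8, 32]) -> one attention map per decoding step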
Model/backbone.py ADDED
@@ -0,0 +1,295 @@
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from typing import Optional, Callable


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class VITBatchNorm(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features=num_features)

    def forward(self, x):
        return self.bn(x)


class Attention(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        with torch.cuda.amp.autocast(True):
            batch_size, num_token, embed_dim = x.shape
            # qkv is [3, batch_size, num_heads, num_token, embed_dim // num_heads]
            qkv = self.qkv(x).reshape(
                batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        with torch.cuda.amp.autocast(False):
            # attention is computed in full precision
            q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
        with torch.cuda.amp.autocast(True):
            x = self.proj(x)
            x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 num_patches: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 act_layer: Callable = nn.ReLU6,
                 norm_layer: str = "ln",
                 patch_n: int = 144):
        super().__init__()

        if norm_layer == "bn":
            self.norm1 = VITBatchNorm(num_features=num_patches)
            self.norm2 = VITBatchNorm(num_features=num_patches)
        elif norm_layer == "ln":
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)

        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000 ** 3)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        with torch.cuda.amp.autocast(True):
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    def __init__(self, img_size=[108, 108], patch_size=[9, 9], in_channels=3, embed_dim=768):
        super().__init__()
        # img_size = to_2tuple(img_size)
        # patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * \
                      (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        batch_size, channels, height, width = x.shape
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size: list = [112, 112],
                 patch_size: list = [16, 16],
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 hybrid_backbone: Optional[nn.Module] = None,
                 norm_layer: str = "ln",
                 mask_ratio=0.1,
                 using_checkpoint=False,
                 ):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        if hybrid_backbone is not None:
            raise ValueError("hybrid_backbone is not supported")
        else:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels,
                                          embed_dim=embed_dim)
        self.mask_ratio = mask_ratio
        self.using_checkpoint = using_checkpoint
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.patchWSize = img_size[1] // patch_size[1]
        self.patchHSize = img_size[0] // patch_size[0]
        patch_n = self.patchWSize * self.patchHSize

        self.blocks = nn.ModuleList(
            [
                Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                      num_patches=num_patches, patch_n=patch_n)
                for i in range(depth)]
        )
        self.extra_gflops = 0.0
        for _block in self.blocks:
            self.extra_gflops += _block.extra_gflops

        if norm_layer == "ln":
            self.norm = nn.LayerNorm(embed_dim)
        elif norm_layer == "bn":
            self.norm = VITBatchNorm(self.num_patches)

        # features head
        self.feature = nn.Sequential(
            nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
            nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
            nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
        )

        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        torch.nn.init.normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        # trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def random_masking(self, x, mask_ratio=0.1):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.size()  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed

        if self.training and self.mask_ratio > 0:
            x, _, ids_restore = self.random_masking(x, mask_ratio=self.mask_ratio)

        for func in self.blocks:
            if self.using_checkpoint and self.training:
                from torch.utils.checkpoint import checkpoint
                x = checkpoint(func, x)
            else:
                x = func(x)
        x = self.norm(x.float())

        if self.training and self.mask_ratio > 0:
            # pad the kept tokens with learned mask tokens and unshuffle back to the original patch order
            mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
            x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
            x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
            x = x_
        return x.permute(0, 2, 1).view(B, self.embed_dim, self.patchHSize, self.patchWSize)

    def forward(self, x):
        x = self.forward_features(x)
        return x


if __name__ == '__main__':
    img = torch.zeros((1, 3, 32, 128))
    model = VisionTransformer(img_size=[32, 128],
                              patch_size=[4, 4],
                              in_channels=3,
                              embed_dim=512,
                              num_heads=8)
    model.eval()
    features = model(img)
    print(features.size())
Model/model.py ADDED
@@ -0,0 +1,32 @@
import torch.nn as nn

from Model.attention import PositionAttention
from Model.backbone import VisionTransformer


class TTR(nn.Module):
    def __init__(self, args: dict):
        super().__init__()
        self.args = args

        self.backbone = VisionTransformer(img_size=args["img_size"],
                                          patch_size=args["patch_size"],
                                          in_channels=3,
                                          embed_dim=args["embed_dim"],
                                          num_heads=args["num_heads"],
                                          mask_ratio=args["mask_ratio"])

        self.positionAttention = PositionAttention(max_length=26,
                                                   in_channels=args["embed_dim"],
                                                   num_channels=args["position_attention_hidden"],
                                                   h=args["img_size"][0] // args["patch_size"][0],
                                                   w=args["img_size"][1] // args["patch_size"][1],
                                                   mode='nearest')
        self.cls = nn.Linear(args["embed_dim"], 43)
        return

    def forward(self, image):
        features = self.backbone(image)
        attn_vecs, attn_scores = self.positionAttention(features)
        logits = self.cls(attn_vecs)
        return logits
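
A short end-to-end shape sketch for TTR (the config values below are assumptions that mirror the ones used in app.py further down; the 26 decoding steps and 43 character classes come from the hard-coded values above):

import torch
from Model import TTR

# hypothetical config, matching the one app.py uses
args = {"img_size": [32, 128], "patch_size": [4, 4], "embed_dim": 512,
        "num_heads": 8, "position_attention_hidden": 64, "mask_ratio": 0.0}
model = TTR(args).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 32, 128))
print(logits.shape)  # torch.Size([1, 26, 43])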
Model/trainer.py ADDED
@@ -0,0 +1,264 @@
import os

import torch
import tqdm
from torch import nn
from torch.nn import functional as F

from torch.utils.data import DataLoader, RandomSampler, Dataset

from metrics import getAcc
from torch.cuda.amp import autocast, GradScaler


class Trainer:

    def __init__(self, args, tb_logger, logger):
        self.args = args

        self.gpu = torch.device(args.gpu)
        self.model = None
        self.it = 0
        self.best_eval_acc, self.best_it = 0.0, 0

        # init dataset
        self.trainDataset = None
        self.trainDataloader = None
        self.evalDataset = None
        self.evalDataloader = None

        # optimizer and scheduler
        self.scheduler = None
        self.optimizer = None

        # loss
        self.loss_fn = None
        self.weight = None
        self.setLoss(args.loss)
        self.ignore_index = args.model["letter_size"]

        # gradient clipping
        if args.clip_grad is not None:
            self.clip_grad = True
            self.clip_value = args.clip_grad
        else:
            self.clip_grad = False

        if hasattr(args, "label_smoothing") and args.label_smoothing is not None:
            self.label_smoothing = float(args.label_smoothing)
        else:
            self.label_smoothing = 0.0

        # logging (tb_logger may be None, in which case tensorboard logging is skipped)
        self.tb_log = tb_logger
        self.print_fn = print if logger is None else logger.info

        return

    def train(self):
        """
        Train the model.
        """
        self.model.train()

        # for gpu profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)

        scaler = GradScaler()

        start_batch.record()
        # evaluate once before resuming training
        if self.args.resume:
            eval_dict = self.evaluate()
            print(eval_dict)

        tbar = tqdm.tqdm(total=len(self.trainDataloader), colour='BLUE')

        for samples, targets, _ in self.trainDataloader:
            tbar.update(1)
            self.it += 1

            end_batch.record()
            torch.cuda.synchronize()
            start_run.record()

            samples, targets = samples.to(self.gpu), targets.to(self.gpu).long()

            with autocast():
                logits = self.model(samples)
                loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(),
                                       ignore_index=self.ignore_index,
                                       label_smoothing=self.label_smoothing)

            scaler.scale(loss).backward()

            if self.clip_grad:
                # unscale first so the clip threshold is applied to the true gradient norm
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)

            scaler.step(self.optimizer)
            scaler.update()
            if self.scheduler is not None:
                self.scheduler.step()
            self.model.zero_grad()

            end_run.record()
            torch.cuda.synchronize()

            # tensorboard_dict update
            tb_dict = {}
            tb_dict['train/loss'] = loss.detach().cpu().item()

            tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
            tb_dict['GPU/prefetch_time'] = start_batch.elapsed_time(end_batch) / 1000.
            tb_dict['GPU/run_time'] = start_run.elapsed_time(end_run) / 1000.

            if self.it % self.args.num_eval_iter == 0:
                eval_dict = self.evaluate()
                tb_dict.update(eval_dict)
                save_path = self.args.save_path
                if tb_dict['Word/Acc'] > self.best_eval_acc:
                    self.best_eval_acc = tb_dict['Word/Acc']
                    self.best_it = self.it

                self.print_fn(
                    f"\n {self.it} iteration, {tb_dict}, \n BEST_EVAL_ACC: {self.best_eval_acc}, at {self.best_it} iters")
                self.print_fn(
                    f" {self.it} iteration, ACC: {tb_dict['Word/Acc']}\n")
                if self.it == self.best_it:
                    self.save_model('model_best.pth', save_path)

            if self.tb_log is not None:
                self.tb_log.update(tb_dict, self.it)
            del tb_dict
            start_batch.record()

        eval_dict = self.evaluate()
        eval_dict.update({'eval/best_acc': self.best_eval_acc, 'eval/best_it': self.best_it})
        return eval_dict

    @torch.no_grad()
    def evaluate(self, model: nn.Module = None, evalDataset: Dataset = None):
        self.print_fn("\n Evaluation!!!")

        if model is None:
            model = self.model
        if evalDataset is not None:
            evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0)
        else:
            evalDataloader = self.evalDataloader

        eval_dict = {}

        model.eval()

        preds_arr = None
        targets_arr = None
        lengths_arr = None
        for samples, targets, lengths in evalDataloader:
            samples, targets = samples.to(self.gpu), targets.to(self.gpu)

            outputs = model(samples)

            preds = torch.max(outputs, dim=2)[1]

            if preds_arr is None:
                preds_arr = preds.detach().cpu()
                targets_arr = targets.detach().cpu()
                lengths_arr = lengths.detach().cpu()
            else:
                preds_arr = torch.concat((preds_arr, preds.detach().cpu()))
                targets_arr = torch.concat((targets_arr, targets.detach().cpu()))
                lengths_arr = torch.concat((lengths_arr, lengths.detach().cpu()))

        wordAcc, charAcc = getAcc(preds_arr, targets_arr, lengths_arr)
        eval_dict.update({"Word/Acc": wordAcc,
                          "Char/Acc": charAcc})
        model.train()
        return eval_dict

    def save_model(self, save_name, save_path):
        save_filename = os.path.join(save_path, save_name)
        self.model.eval()
        save_dict = {"model": self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                     'it': self.it}
        torch.save(save_dict, save_filename)
        self.model.train()
        self.print_fn(f"model saved: {save_filename}\n")

    def save_baseLearner(self, save_name, save_path, trainIndexes):
        save_filename = os.path.join(save_path, save_name)
        self.model.eval()
        save_dict = {"model": self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                     'trainIndexes': trainIndexes,
                     'it': self.it}
        torch.save(save_dict, save_filename)
        self.model.train()
        self.print_fn(f"model saved: {save_filename}\n")

    def load_model(self, load_dir, load_name):
        """
        Load a saved checkpoint.
        :param load_dir: directory of the saved model
        :param load_name: model file name
        """
        load_path = os.path.join(load_dir, load_name)
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if checkpoint['scheduler'] is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.it = checkpoint['it']
        self.print_fn(f'model loaded from {load_path}')

    def set_optimizer(self, optimizer, scheduler=None):
        """
        Set optimizer and scheduler.
        :param optimizer: optimizer
        :param scheduler: scheduler
        """
        self.optimizer = optimizer
        self.scheduler = scheduler

    def setModel(self, model):
        """
        Set the model and move it to the training device.
        :param model: model
        """
        self.model = model.cuda(self.gpu)

    def setDatasets(self, trainDataset, evalDataset):
        """
        Set train and evaluation datasets and dataloaders.
        :param trainDataset: train dataset
        :param evalDataset: evaluation dataset
        """
        self.print_fn(f"\n Num Train Labeled Sample : {len(trainDataset)}\n Num Val Sample : {len(evalDataset)}")
        self.trainDataset = trainDataset
        self.evalDataset = evalDataset

        self.trainDataloader = DataLoader(trainDataset, batch_size=self.args.batch_size,
                                          sampler=RandomSampler(data_source=trainDataset,
                                                                replacement=True,
                                                                num_samples=self.args.iter * self.args.batch_size),
                                          num_workers=self.args.num_workers, drop_last=True, pin_memory=True)

        self.evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0,
                                         pin_memory=True)

    def setLoss(self, loss_function: dict):
        """
        Set the loss function.
        :param loss_function: loss function arguments
        """
        if loss_function["name"] == 'CrossEntropyLoss':
            self.loss_fn = nn.CrossEntropyLoss(label_smoothing=loss_function["label_smoothing"]).cuda(self.gpu)
        else:
            raise Exception(f"Unknown Loss Function : {loss_function}")
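
A wiring sketch for the Trainer above. The args namespace and all of its values are hypothetical placeholders; only the field names are taken from what Trainer, setDatasets, and setLoss actually read. It assumes a CUDA device, the repository's metrics module (imported by the Trainer), and a TS-TR data folder laid out as TurkishSceneTextDataset expects:

from types import SimpleNamespace

import torch

from Model import TTR
from Model.trainer import Trainer
from dataset.TurkishSceneTextDataset import TurkishSceneTextDataset

# hypothetical run configuration; every field is one the Trainer reads, values are placeholders
args = SimpleNamespace(gpu=0, resume=False, clip_grad=1.0, label_smoothing=0.1,
                       loss={"name": "CrossEntropyLoss", "label_smoothing": 0.1},
                       model={"letter_size": 43},
                       batch_size=128, eval_batch_size=256, iter=100000,
                       num_eval_iter=1000, num_workers=4, save_path="./experiments/demo")

# hypothetical dataset location; letters/maxLength match the defaults used elsewhere in the repo
dataset_args = {"root": "./data/TS-TR",
                "letters": "0123456789abcçdefgğhıijklmnoöpqrsştuüvwxyz",
                "maxLength": 25}

trainer = Trainer(args, tb_logger=None, logger=None)
trainer.setModel(TTR({"img_size": [32, 128], "patch_size": [4, 4], "embed_dim": 512,
                      "num_heads": 8, "position_attention_hidden": 64, "mask_ratio": 0.1}))
trainer.setDatasets(TurkishSceneTextDataset(dataset_args, train=True),
                    TurkishSceneTextDataset(dataset_args, train=False))
trainer.set_optimizer(torch.optim.AdamW(trainer.model.parameters(), lr=3e-4))
trainer.train()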
app.py ADDED
@@ -0,0 +1,67 @@
import torch
from torchvision import transforms
import gradio as gr

from Model import TTR
from dataset.charMapper import CharMapper

# arguments
model_path = "./experiments/real_train/model_best.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def getTransforms():
    return transforms.Compose([
        transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5)
    ])


mapper = CharMapper()
model = TTR({"img_size": [32, 128],
             "patch_size": [4, 4],
             "embed_dim": 512,
             "num_heads": 8,
             "position_attention_hidden": 64,
             "mask_ratio": 0.0
             })
model.load_state_dict(torch.load(model_path, map_location=device)["model"])
model.eval()
model = model.to(device)

preprocess = getTransforms()


def inference(raw_image):
    batch = preprocess(raw_image).unsqueeze(0).to(device)
    outputs = model(batch)
    preds = torch.max(outputs, dim=2)[1]
    pred_text = mapper.reverseMapper(preds[0])
    return pred_text


inputs = [gr.Image(type='pil', interactive=True)]
outputs = gr.components.Textbox(label="Caption")
title = "MViT-TR"
paper_link = "https://www.sciencedirect.com/science/article/pii/S2215098624002672"
github_link = "https://github.com/serdaryildiz/MViT-TR"
description = f"<p style='text-align: center'><a href='{github_link}' target='_blank'>MViT-TR</a> : Masked Vision Transformer for Text Recognition</p>"
examples = [
    ["fig/0.jpg"],
    ["fig/145.jpg"],
    ["fig/195.jpg"],
    ["fig/270.jpg"],
]
article = f"<p style='text-align: center'><a href='{paper_link}' target='_blank'>Paper</a> | <a href='{github_link}' target='_blank'>Github Repo</a></p>"
css = ".output-image, .input-image, .image-preview {height: 600px !important}"

iface = gr.Interface(fn=inference,
                     inputs=inputs,
                     outputs=outputs,
                     title=title,
                     description=description,
                     examples=examples,
                     article=article,
                     css=css)
iface.launch()
dataset/TurkishSceneTextDataset.py ADDED
@@ -0,0 +1,70 @@
import os
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from dataset.augmentations import CVGeometry, CVDeterioration, CVColorJitter
from dataset.charMapper import CharMapper


class TurkishSceneTextDataset(Dataset):
    def __init__(self, args: dict, train=True):
        self.root = args["root"]
        self.train = train

        txt_path = os.path.join(self.root, "train.txt") if self.train else os.path.join(self.root, "test.txt")
        with open(txt_path, "r") as fp:
            lines = fp.readlines()

        self.gt = {}
        for l in lines:
            img_name, label = l.strip().split('\t')
            self.gt[img_name] = label

        self.root = os.path.join(self.root, "train") if self.train else os.path.join(self.root, "test")
        self.imgPaths = [os.path.join(self.root, p) for p in self.gt.keys()]

        self.transforms = self._getTransforms()
        self.mapper = CharMapper(letters=args["letters"], maxLength=args["maxLength"])

    def __getitem__(self, item):
        imgPath = self.imgPaths[item]
        image = Image.open(imgPath)
        image = self.transforms(image)

        label = self.gt[os.path.basename(imgPath)]

        label, length = self.mapper(label, return_length=True)
        return image, label, torch.tensor(length)

    def __len__(self):
        return len(self.imgPaths)

    def _getTransforms(self):
        if self.train:
            return transforms.Compose([
                CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
                CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
                CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
                transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(0.5, 0.5)
            ])
        else:
            return transforms.Compose([
                transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(0.5, 0.5)
            ])


if __name__ == '__main__':
    args = {
        "name": "TurkishSceneTextDataset",
        "root": "../data/TS-TR",
        "letters": "0123456789abcçdefgğhıijklmnoöpqrsştuüvwxyz",
        "maxLength": 25
    }
    dataset = TurkishSceneTextDataset(args, False)
dataset/__init__.py ADDED
File without changes
dataset/__pycache__/TurkishSceneTextDataset.cpython-38.pyc ADDED
Binary file (2.75 kB)

dataset/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (147 Bytes)

dataset/__pycache__/augmentations.cpython-38.pyc ADDED
Binary file (12.9 kB)

dataset/__pycache__/charMapper.cpython-38.pyc ADDED
Binary file (2.12 kB)

dataset/__pycache__/strit.cpython-38.pyc ADDED
Binary file (1.72 kB)

dataset/__pycache__/syntheticTurkishStyleText.cpython-38.pyc ADDED
Binary file (2.37 kB)
dataset/augmentations.py ADDED
@@ -0,0 +1,352 @@
"""
Source : https://github.com/FangShancheng/ABINet/blob/main/transforms.py
"""

import math
import numbers
import random

import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose


def sample_asym(magnitude, size=None):
    return np.random.beta(1, 4, size) * magnitude


def sample_sym(magnitude, size=None):
    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude


def sample_uniform(low, high, size=None):
    return np.random.uniform(low, high, size=size)


def get_interpolation(type='random'):
    if type == 'random':
        choice = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA]
        interpolation = choice[random.randint(0, len(choice) - 1)]
    elif type == 'nearest':
        interpolation = cv2.INTER_NEAREST
    elif type == 'linear':
        interpolation = cv2.INTER_LINEAR
    elif type == 'cubic':
        interpolation = cv2.INTER_CUBIC
    elif type == 'area':
        interpolation = cv2.INTER_AREA
    else:
        raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!')
    return interpolation


class CVRandomRotation(object):
    def __init__(self, degrees=15):
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

    @staticmethod
    def get_params(degrees):
        return sample_sym(degrees)

    def __call__(self, img):
        angle = self.get_params(self.degrees)
        src_h, src_w = img.shape[:2]
        M = cv2.getRotationMatrix2D(center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
        abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
        dst_w = int(src_h * abs_sin + src_w * abs_cos)
        dst_h = int(src_h * abs_cos + src_w * abs_sin)
        M[0, 2] += (dst_w - src_w) / 2
        M[1, 2] += (dst_h - src_h) / 2

        flags = get_interpolation()
        return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)


class CVRandomAffine(object):
    def __init__(self, degrees, translate=None, scale=None, shear=None):
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = [shear]
            else:
                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
                    "shear should be a list or tuple and it must be of length 2."
                self.shear = shear
        else:
            self.shear = shear

    def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
        # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
        from numpy import sin, cos, tan

        if isinstance(shear, numbers.Number):
            shear = [shear, 0]

        if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
            raise ValueError(
                "Shear should be a single value or a tuple/list containing " +
                "two values. Got {}".format(shear))

        rot = math.radians(angle)
        sx, sy = [math.radians(s) for s in shear]

        cx, cy = center
        tx, ty = translate

        # RSS without scaling
        a = cos(rot - sy) / cos(sy)
        b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
        c = sin(rot - sy) / cos(sy)
        d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)

        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        M = [d, -b, 0,
             -c, a, 0]
        M = [x / scale for x in M]

        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
        M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)

        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        M[2] += cx
        M[5] += cy
        return M

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, height):
        angle = sample_sym(degrees)
        if translate is not None:
            max_dx = translate[0] * height
            max_dy = translate[1] * height
            translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = sample_uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 1:
                shear = [sample_sym(shears[0]), 0.]
            elif len(shears) == 2:
                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img):
        src_h, src_w = img.shape[:2]
        angle, translate, scale, shear = self.get_params(
            self.degrees, self.translate, self.scale, self.shear, src_h)

        M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle, (0, 0), scale, shear)
        M = np.array(M).reshape(2, 3)

        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
        project = lambda x, y, a, b, c: int(a * x + b * y + c)
        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]

        rect = cv2.minAreaRect(np.array(endpoints))
        bbox = cv2.boxPoints(rect).astype(int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()

        dst_w = int(max_x - min_x)
        dst_h = int(max_y - min_y)
        M[0, 2] += (dst_w - src_w) / 2
        M[1, 2] += (dst_h - src_h) / 2

        # add translate
        dst_w += int(abs(translate[0]))
        dst_h += int(abs(translate[1]))
        if translate[0] < 0: M[0, 2] += abs(translate[0])
        if translate[1] < 0: M[1, 2] += abs(translate[1])

        flags = get_interpolation()
        return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)


class CVRandomPerspective(object):
    def __init__(self, distortion=0.5):
        self.distortion = distortion

    def get_params(self, width, height, distortion):
        offset_h = sample_asym(distortion * height / 2, size=4).astype(int)
        offset_w = sample_asym(distortion * width / 2, size=4).astype(int)
        topleft = (offset_w[0], offset_h[0])
        topright = (width - 1 - offset_w[1], offset_h[1])
        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
        botleft = (offset_w[3], height - 1 - offset_h[3])

        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
        endpoints = [topleft, topright, botright, botleft]
        return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)

    def __call__(self, img):
        height, width = img.shape[:2]
        startpoints, endpoints = self.get_params(width, height, self.distortion)
        M = cv2.getPerspectiveTransform(startpoints, endpoints)

        # TODO: more robust way to crop image
        rect = cv2.minAreaRect(endpoints)
        bbox = cv2.boxPoints(rect).astype(int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
        min_x, min_y = max(min_x, 0), max(min_y, 0)

        flags = get_interpolation()
        img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
        img = img[min_y:, min_x:]
        return img


class CVRescale(object):

    def __init__(self, factor=4, base_size=(128, 512)):
        """ Define image scales using a gaussian pyramid and rescale the image to the target scale.

        Args:
            factor: the decay factor from the base size; factor=4 keeps the target scale by default.
            base_size: base size used to build the bottom layer of the pyramid
        """
        if isinstance(factor, numbers.Number):
            self.factor = round(sample_uniform(0, factor))
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            self.factor = round(sample_uniform(factor[0], factor[1]))
        else:
            raise Exception('factor must be number or list with length 2')
        # assert factor is valid
        self.base_h, self.base_w = base_size[:2]

    def __call__(self, img):
        if self.factor == 0: return img
        src_h, src_w = img.shape[:2]
        cur_w, cur_h = self.base_w, self.base_h
        scale_img = cv2.resize(img, (cur_w, cur_h), interpolation=get_interpolation())
        for _ in range(self.factor):
            scale_img = cv2.pyrDown(scale_img)
        scale_img = cv2.resize(scale_img, (src_w, src_h), interpolation=get_interpolation())
        return scale_img


class CVGaussianNoise(object):
    def __init__(self, mean=0, var=20):
        self.mean = mean
        if isinstance(var, numbers.Number):
            self.var = max(int(sample_asym(var)), 1)
        elif isinstance(var, (tuple, list)) and len(var) == 2:
            self.var = int(sample_uniform(var[0], var[1]))
        else:
            raise Exception('degree must be number or list with length 2')

    def __call__(self, img):
        noise = np.random.normal(self.mean, self.var ** 0.5, img.shape)
        img = np.clip(img + noise, 0, 255).astype(np.uint8)
        return img


class CVMotionBlur(object):
    def __init__(self, degrees=12, angle=90):
        if isinstance(degrees, numbers.Number):
            self.degree = max(int(sample_asym(degrees)), 1)
        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
            self.degree = int(sample_uniform(degrees[0], degrees[1]))
        else:
            raise Exception('degree must be number or list with length 2')
        self.angle = sample_uniform(-angle, angle)

    def __call__(self, img):
        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2), self.angle, 1)
        motion_blur_kernel = np.zeros((self.degree, self.degree))
        motion_blur_kernel[self.degree // 2, :] = 1
        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
        motion_blur_kernel = motion_blur_kernel / self.degree
        img = cv2.filter2D(img, -1, motion_blur_kernel)
        img = np.clip(img, 0, 255).astype(np.uint8)
        return img


class CVGeometry(object):
    def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
                 shear=(45, 15), distortion=0.5, p=0.5):
        self.p = p
        type_p = random.random()
        if type_p < 0.33:
            self.transforms = CVRandomRotation(degrees=degrees)
        elif type_p < 0.66:
            self.transforms = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
        else:
            self.transforms = CVRandomPerspective(distortion=distortion)

    def __call__(self, img):
        if random.random() < self.p:
            img = np.array(img)
            return Image.fromarray(self.transforms(img))
        else:
            return img


class CVDeterioration(object):
    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        transforms = []
        if var is not None:
            transforms.append(CVGaussianNoise(var=var))
        if degrees is not None:
            transforms.append(CVMotionBlur(degrees=degrees))
        if factor is not None:
            transforms.append(CVRescale(factor=factor))

        random.shuffle(transforms)
        transforms = Compose(transforms)
        self.transforms = transforms

    def __call__(self, img):
        if random.random() < self.p:
            img = np.array(img)
            return Image.fromarray(self.transforms(img))
        else:
            return img


class CVColorJitter(object):
    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
        self.p = p
        self.transforms = transforms.ColorJitter(brightness=brightness, contrast=contrast,
                                                 saturation=saturation, hue=hue)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transforms(img)
        else:
            return img
dataset/charMapper.py ADDED
@@ -0,0 +1,55 @@
import re
import torch


class CharMapper:
    lower2upper = {
        ord(u"i"): u"İ",
        ord(u"ı"): u"I"
    }

    upper2lower = {
        ord(u"İ"): u"i",
        ord(u"I"): u"ı"
    }

    def __init__(self, letters: str = "0123456789abcçdefgğhıijklmnoöpqrsştuüvwxyz", maxLength: int = 25):
        self.letters = letters
        self.maxLength = maxLength
        self.map = {"[END]": 0}
        self.reverseMap = {0: "[END]"}
        i = 1
        for l in self.letters:
            self.map[l] = i
            self.reverseMap[i] = l
            i += 1
        self.map["[PAD]"] = i
        self.reverseMap[i] = "[PAD]"
        return

    def __call__(self, text: str, return_length=False):
        text = self.text2label(text)
        length = len(text) + 1
        mappedText = torch.tensor([self.map[l] for l in text] + [self.map["[END]"]])
        text = torch.ones((self.maxLength + 1,)) * self.map["[PAD]"]
        text[:len(mappedText)] = mappedText
        if return_length:
            return text, length
        else:
            return text

    def reverseMapper(self, label: torch.Tensor):
        label = label.cpu()
        text = "".join([self.reverseMap[l] for l in label.numpy()])
        return text.split("[END]")[0]

    def text2label(self, text):
        text = re.sub('[^0-9a-zA-ZğüşöçıİĞÜŞÖÇ]+', '', text)
        text = text.translate(self.upper2lower).lower()
        return text


if __name__ == '__main__':
    mapper = CharMapper()
    mapped = mapper("!MA-PİŞ$Z")
    print(mapped)
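
A small round-trip sketch with illustrative input (it highlights the Turkish-specific İ->i and I->ı folding done in text2label, and the [END]/[PAD] layout of the label tensor):

from dataset.charMapper import CharMapper

mapper = CharMapper()
label, length = mapper("İstanbul!", return_length=True)  # punctuation stripped, case folded
print(length)                       # 9  -> 8 characters + [END]
print(label.shape)                  # torch.Size([26]) -> maxLength + 1, padded with [PAD]
print(mapper.reverseMapper(label))  # istanbul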
dataset/strit.py ADDED
@@ -0,0 +1,37 @@
import os

import torch
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms
from dataset.charMapper import CharMapper


class STRIT(Dataset):
    def __init__(self, args: dict):
        self.root = args["root"]
        self.imgPaths = [os.path.join(self.root, p) for p in os.listdir(self.root)]
        self.transforms = self._getTransforms()

        self.mapper = CharMapper(letters=args["letters"], maxLength=args["maxLength"])

    def __getitem__(self, item):
        imgPath = self.imgPaths[item]
        image = Image.open(imgPath)
        image = self.transforms(image)
        label = imgPath.split('_')[-1].split('.')[0]
        length = torch.tensor(len(label) + 1)
        label = self.mapper(label)
        return image, label, length

    def __len__(self):
        return len(self.imgPaths)

    @staticmethod
    def _getTransforms():
        return transforms.Compose([
            transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])
dataset/syntheticTurkishStyleText.py ADDED
@@ -0,0 +1,51 @@
from PIL import Image
import lmdb
import six

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from dataset.augmentations import CVGeometry, CVDeterioration, CVColorJitter
from dataset.charMapper import CharMapper


class SyntheticTurkishStyleText(Dataset):
    def __init__(self, args: dict):
        self.root = args["root"]
        self.args = args
        self.transforms = self._getTransforms()
        self.mapper = CharMapper(letters=args["letters"], maxLength=args["maxLength"])

        self.env = lmdb.open(self.root, readonly=True, lock=False, readahead=False, meminit=False)

        with self.env.begin(write=False) as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def __getitem__(self, index):
        key = self.keys[index]
        with self.env.begin(write=False) as txn:
            label = key.decode().split("-*-")[0]
            imgbuf = txn.get(key)
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            image = Image.open(buf).convert("RGB")
            image = self.transforms(image)
            length = torch.tensor(len(label) + 1)
            label = self.mapper(label)
        return image, label, length

    def __len__(self):
        return len(self.keys)

    @staticmethod
    def _getTransforms():
        return transforms.Compose([
            CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
            CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
            CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25),
            transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5)
        ])
fig/0.jpg ADDED
fig/145.jpg ADDED
fig/195.jpg ADDED
fig/270.jpg ADDED
fig/MViT-TR-arch.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
torch==1.12.1
torchvision==0.12.1
opencv-python==4.6.0.66
transformers==4.27.3
ftfy==6.1.1
gradio==3.48.0
gdown==4.6.0