p1atdev committed
Commit
9844a09
•
1 Parent(s): 3ae75a5

Upload 2 files

Files changed (2)
  1. configuration_mle.py +46 -0
  2. modeling_mle.py +413 -0
configuration_mle.py ADDED
@@ -0,0 +1,46 @@
+ from collections import OrderedDict
+ from typing import Any, List, Mapping, Optional
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.onnx import OnnxConfigWithPast, PatchingSpec
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class MLEConfig(PretrainedConfig):
+     model_type = "mle"
+
+     def __init__(
+         self,
+         in_channels=1,
+         num_encoder_layers=[2, 3, 5, 7, 12],
+         num_decoder_layers=[7, 5, 3, 2, 2],
+         last_hidden_channels=16,
+         block_stride_size=4,
+         block_kernel_size=3,
+         block_patch_size=24,
+         upsample_ratio=2,
+         batch_norm_eps=1e-3,
+         hidden_act="leaky_relu",
+         negative_slope=0.2,
+         **kwargs,
+     ):
+         self.in_channels = in_channels
+         self.num_encoder_layers = num_encoder_layers
+         self.num_decoder_layers = num_decoder_layers
+         self.last_hidden_channels = last_hidden_channels
+
+         self.block_stride_size = block_stride_size
+         # if isinstance(block_kernel_size, int):
+         #     self.block_kernel_size = (block_kernel_size, block_kernel_size)
+         self.block_kernel_size = block_kernel_size
+         self.block_patch_size = block_patch_size
+
+         self.upsample_ratio = upsample_ratio
+         self.batch_norm_eps = batch_norm_eps
+         self.hidden_act = hidden_act
+         self.negative_slope = negative_slope
+
+         super().__init__(**kwargs)
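For reference, a minimal usage sketch for the configuration (not part of the commit; it assumes configuration_mle.py is importable from the working directory):

from configuration_mle import MLEConfig

config = MLEConfig()  # all defaults, as listed above
assert config.num_encoder_layers == [2, 3, 5, 7, 12]
assert config.num_decoder_layers == [7, 5, 3, 2, 2]

# individual fields can be overridden at construction time
custom = MLEConfig(last_hidden_channels=32, negative_slope=0.1)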
modeling_mle.py ADDED
@@ -0,0 +1,413 @@
+ """PyTorch MLE (Manga Line Extraction) model"""
+
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutput
+ from transformers.activations import ACT2FN
+
+ from .configuration_mle import MLEConfig
+
+
+ @dataclass
+ class MLEModelOutput(ModelOutput):
+     last_hidden_state: torch.FloatTensor | None = None
+
+
+ @dataclass
+ class MLEForAnimeLineExtractionOutput(ModelOutput):
+     last_hidden_state: torch.FloatTensor | None = None
+     pixel_values: torch.Tensor | None = None
+
+
+ class MLEBatchNorm(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+         in_features: int,
+     ):
+         super().__init__()
+
+         self.norm = nn.BatchNorm2d(in_features, eps=config.batch_norm_eps)
+         # the original model uses leaky_relu
+         if config.hidden_act == "leaky_relu":
+             self.act_fn = nn.LeakyReLU(negative_slope=config.negative_slope)
+         else:
+             self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.norm(hidden_states)
+         hidden_states = self.act_fn(hidden_states)
+
+         return hidden_states
+
+
+ class MLEResBlock(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+         in_channels: int,
+         out_channels: int,
+         stride_size: int,
+     ):
+         super().__init__()
+
+         self.norm1 = MLEBatchNorm(config, in_channels)
+         self.conv1 = nn.Conv2d(
+             in_channels,
+             out_channels,
+             config.block_kernel_size,
+             stride=stride_size,
+             padding=config.block_kernel_size // 2,
+         )
+
+         self.norm2 = MLEBatchNorm(config, out_channels)
+         self.conv2 = nn.Conv2d(
+             out_channels,
+             out_channels,
+             config.block_kernel_size,
+             stride=1,
+             padding=config.block_kernel_size // 2,
+         )
+
+         if in_channels != out_channels or stride_size != 1:
+             self.resize = nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=1,
+                 stride=stride_size,
+             )
+         else:
+             self.resize = None
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         output = self.norm1(hidden_states)
+         output = self.conv1(output)
+         output = self.norm2(output)
+         output = self.conv2(output)
+
+         if self.resize is not None:
+             resized_input = self.resize(hidden_states)
+             output += resized_input
+         else:
+             output += hidden_states
+
+         return output
+
+
+ class MLEEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+         in_features: int,
+         out_features: int,
+         num_layers: int,
+         stride_sizes: list[int],
+     ):
+         super().__init__()
+
+         self.blocks = nn.ModuleList(
+             [
+                 MLEResBlock(
+                     config,
+                     in_channels=in_features if i == 0 else out_features,
+                     out_channels=out_features,
+                     stride_size=stride_sizes[i],
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             hidden_states = block(hidden_states)
+         return hidden_states
+
+
+ class MLEEncoder(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+     ):
+         super().__init__()
+
+         self.layers = nn.ModuleList(
+             [
+                 MLEEncoderLayer(
+                     config,
+                     in_features=(
+                         config.in_channels
+                         if i == 0
+                         else config.in_channels
+                         * config.block_patch_size
+                         * (config.upsample_ratio ** (i - 1))
+                     ),
+                     out_features=config.in_channels
+                     * config.block_patch_size
+                     * (config.upsample_ratio**i),
+                     num_layers=num_layers,
+                     stride_sizes=(
+                         [
+                             1 if i_layer < num_layers - 1 else 2
+                             for i_layer in range(num_layers)
+                         ]
+                         if i > 0
+                         else [1 for _ in range(num_layers)]
+                     ),
+                 )
+                 for i, num_layers in enumerate(config.num_encoder_layers)
+             ]
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
+         all_hidden_states: tuple[torch.Tensor, ...] = ()
+         for layer in self.layers:
+             hidden_states = layer(hidden_states)
+             all_hidden_states += (hidden_states,)
+         return hidden_states, all_hidden_states
+
+
+ class MLEUpsampleBlock(nn.Module):
+     def __init__(self, config: MLEConfig, in_features: int, out_features: int):
+         super().__init__()
+
+         self.norm = MLEBatchNorm(config, in_features=in_features)
+         self.conv = nn.Conv2d(
+             in_features,
+             out_features,
+             config.block_kernel_size,
+             stride=1,
+             padding=config.block_kernel_size // 2,
+         )
+         self.upsample = nn.Upsample(scale_factor=config.upsample_ratio)
+
+     def forward(self, hidden_states: torch.Tensor):
+         output = self.norm(hidden_states)
+         output = self.conv(output)
+         output = self.upsample(output)
+
+         return output
+
+
+ class MLEUpsampleResBlock(nn.Module):
+     def __init__(self, config: MLEConfig, in_features: int, out_features: int):
+         super().__init__()
+
+         self.upsample = MLEUpsampleBlock(
+             config, in_features=in_features, out_features=out_features
+         )
+
+         self.norm = MLEBatchNorm(config, in_features=out_features)
+         self.conv = nn.Conv2d(
+             out_features,
+             out_features,
+             config.block_kernel_size,
+             stride=1,
+             padding=config.block_kernel_size // 2,
+         )
+
+         if in_features != out_features:
+             self.resize = nn.Sequential(
+                 nn.Conv2d(
+                     in_features,
+                     out_features,
+                     kernel_size=1,
+                     stride=1,
+                 ),
+                 nn.Upsample(scale_factor=config.upsample_ratio),
+             )
+         else:
+             self.resize = None
+
+     def forward(self, hidden_states: torch.Tensor):
+         output = self.upsample(hidden_states)
+         output = self.norm(output)
+         output = self.conv(output)
+
+         if self.resize is not None:
+             output += self.resize(hidden_states)
+
+         return output
+
+
+ class MLEDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+         in_features: int,
+         out_features: int,
+         num_layers: int,
+     ):
+         super().__init__()
+
+         self.blocks = nn.ModuleList(
+             [
+                 (
+                     MLEResBlock(
+                         config,
+                         in_channels=out_features,
+                         out_channels=out_features,
+                         stride_size=1,
+                     )
+                     if i > 0
+                     else MLEUpsampleResBlock(
+                         config,
+                         in_features=in_features,
+                         out_features=out_features,
+                     )
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor, shortcut_states: torch.Tensor
+     ) -> torch.Tensor:
+         for block in self.blocks:
+             hidden_states = block(hidden_states)
+
+         hidden_states += shortcut_states
+
+         return hidden_states
+
+
+ class MLEDecoderHead(nn.Module):
+     def __init__(self, config: MLEConfig, num_layers: int):
+         super().__init__()
+
+         self.layer = MLEEncoderLayer(
+             config,
+             in_features=config.block_patch_size,
+             out_features=config.last_hidden_channels,
+             stride_sizes=[1 for _ in range(num_layers)],
+             num_layers=num_layers,
+         )
+         self.norm = MLEBatchNorm(config, in_features=config.last_hidden_channels)
+         self.conv = nn.Conv2d(
+             config.last_hidden_channels,
+             out_channels=1,
+             kernel_size=1,
+             stride=1,
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.layer(hidden_states)
+         hidden_states = self.norm(hidden_states)
+         pixel_values = self.conv(hidden_states)
+         return pixel_values
+
+
+ class MLEDecoder(nn.Module):
+     def __init__(
+         self,
+         config: MLEConfig,
+     ):
+         super().__init__()
+
+         encoder_output_channels = (
+             config.in_channels
+             * config.block_patch_size
+             * (config.upsample_ratio ** (len(config.num_encoder_layers) - 1))
+         )
+         upsample_ratio = config.upsample_ratio
+         num_decoder_layers = config.num_decoder_layers
+
+         self.layers = nn.ModuleList(
+             [
+                 (
+                     MLEDecoderLayer(
+                         config,
+                         in_features=encoder_output_channels // (upsample_ratio**i),
+                         out_features=encoder_output_channels
+                         // (upsample_ratio ** (i + 1)),
+                         num_layers=num_layers,
+                     )
+                     if i < len(num_decoder_layers) - 1
+                     else MLEDecoderHead(
+                         config,
+                         num_layers=num_layers,
+                     )
+                 )
+                 for i, num_layers in enumerate(num_decoder_layers)
+             ]
+         )
+
+     def forward(
+         self,
+         last_hidden_states: torch.Tensor,
+         encoder_hidden_states: tuple[torch.Tensor, ...],
+     ) -> torch.Tensor:
+         hidden_states = last_hidden_states
+         num_encoder_hidden_states = len(encoder_hidden_states)  # 5
+
+         for i, layer in enumerate(self.layers):
+             if i < len(self.layers) - 1:
+                 hidden_states = layer(
+                     hidden_states,
+                     # 0, 1, 2, 3, 4
+                     # ↓  ↓  ↓  ↓  ↓
+                     # 8, 7, 6, 5, 5
+                     encoder_hidden_states[num_encoder_hidden_states - 2 - i],
+                 )
+             else:
+                 # decoder head
+                 hidden_states = layer(hidden_states)
+
+         return hidden_states
+
+
+ class MLEPretrainedModel(PreTrainedModel):
+     config_class = MLEConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+
+
+ class MLEModel(MLEPretrainedModel):
+     def __init__(self, config: MLEConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.encoder = MLEEncoder(config)
+         self.decoder = MLEDecoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         encoder_output, all_hidden_states = self.encoder(pixel_values)
+         decoder_output = self.decoder(encoder_output, all_hidden_states)
+
+         return decoder_output
+
+
+ class MLEForAnimeLineExtraction(MLEPretrainedModel):
+     def __init__(self, config: MLEConfig):
+         super().__init__(config)
+
+         self.model = MLEModel(config)
+
+     def postprocess(self, output_tensor: torch.Tensor, input_shape: torch.Size):
+         pixel_values = output_tensor[0, 0, :, :]
+         pixel_values = torch.clip(pixel_values, 0, 255)
+
+         pixel_values = pixel_values[0 : input_shape[2], 0 : input_shape[3]]
+         return pixel_values
+
+     def forward(
+         self, pixel_values: torch.Tensor, return_dict: bool = True
+     ) -> tuple[torch.Tensor, ...] | MLEForAnimeLineExtractionOutput:
+         model_output = self.model(pixel_values)
+
+         if not return_dict:
+             return (model_output, self.postprocess(model_output, pixel_values.shape))
+
+         else:
+             return MLEForAnimeLineExtractionOutput(
+                 last_hidden_state=model_output,
+                 pixel_values=self.postprocess(model_output, pixel_values.shape),
+             )
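A minimal inference sketch (not part of the commit). It assumes the two uploaded files are importable as a package named mle (any package name works, since modeling_mle.py imports its configuration relatively) and uses randomly initialized weights; a real checkpoint would instead be loaded with from_pretrained:

import torch

from mle.configuration_mle import MLEConfig
from mle.modeling_mle import MLEForAnimeLineExtraction

config = MLEConfig()
model = MLEForAnimeLineExtraction(config).eval()

# single grayscale page, values in [0, 255]; height and width should be
# multiples of 16, since the encoder halves the resolution in its last
# four stages and the decoder upsamples it back
pixel_values = torch.zeros(1, 1, 256, 256)

with torch.no_grad():
    output = model(pixel_values)

print(output.last_hidden_state.shape)  # torch.Size([1, 1, 256, 256])
print(output.pixel_values.shape)       # torch.Size([256, 256]), clipped to [0, 255] and cropped to the input size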