| import math |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from .configuration_captcha import CaptchaConfig |
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a ``(1, max_len, d_model)`` buffer ``pe`` once at
    construction; ``forward`` adds its first ``x.size(1)`` positions to
    the input. Registered as a buffer (not a parameter) so it follows
    ``.to(device)`` / ``state_dict`` without being trained.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        """
        Args:
            d_model: embedding dimension; even and odd values are supported.
            max_len: maximum sequence length the encoding can cover.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency progression: wavelengths from 2*pi to 10000*2*pi.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cos column than div_term entries;
        # slicing avoids a shape-mismatch error while leaving even d_model
        # behavior unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (batch, seq_len, d_model)."""
        return x + self.pe[:, : x.size(1)]
|
|
class CaptchaConvolutionalTransformer(PreTrainedModel):
    """CNN feature extractor + Transformer encoder for captcha recognition.

    A stack of conv blocks downsamples a single-channel image, the width
    axis becomes the sequence axis, and a Transformer encoder plus a linear
    head produce per-position character logits of shape
    ``(batch, seq_len, num_chars)``.
    """

    config_class = CaptchaConfig

    def __init__(self, config):
        super().__init__(config)

        # CNN front-end. Height is halved three times (the third pool is
        # (2, 1) so width is only halved twice); channels end at 256, so
        # config.d_model must equal 256 * final_feature_height.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
            # Pool height only: preserve width resolution for the sequence axis.
            nn.MaxPool2d(kernel_size=(2, 1)),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.SiLU(),
        )

        self.positional_encoding = PositionalEncoding(config.d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # pre-LN: more stable training than post-LN
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Per-position character classifier.
        self.classifier = nn.Linear(config.d_model, config.num_chars)

        # HF weight init + any config-driven post-processing.
        self.post_init()

    def forward(self, pixel_values, labels=None):
        """Run the model.

        Args:
            pixel_values: float tensor of shape ``(batch, 1, H, W)``.
            labels: optional integer class-index tensor; assumed to be
                ``(batch, seq_len)`` aligned with the logits' sequence axis
                (NOTE(review): confirm against the data collator). Positions
                set to ``-100`` are ignored by the loss.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape
            ``(batch, seq_len, num_chars)`` and, when ``labels`` is given,
            a cross-entropy ``loss``.
        """
        x = self.conv(pixel_values)   # (b, 256, h, w)

        # Move width to the sequence position: (b, w, 256, h).
        x = x.permute(0, 3, 1, 2)
        b, t, c, h = x.size()

        # Flatten channels*height into the model dimension; requires
        # c * h == config.d_model.
        x = x.reshape(b, t, c * h)

        x = self.positional_encoding(x)
        x = self.transformer(x)

        logits = self.classifier(x)   # (b, t, num_chars)

        # Previously `labels` was accepted but silently ignored; compute the
        # loss so Trainer-style callers get a trainable output.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
            loss = loss_fct(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        return SequenceClassifierOutput(loss=loss, logits=logits)