| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from torchvision import transforms |
| from PIL import Image |
| from .configuration_tjmg import CaptchaConfig |
|
|
| |
class CaptchaModel(PreTrainedModel):
    """Convolutional classifier that reads a fixed-length captcha from an image.

    The network is three ``Conv2d(kernel_size=3) -> ReLU -> MaxPool2d(2) ->
    BatchNorm2d`` stages followed by two fully connected layers. The last
    layer emits ``output_ndigits * output_vocab_size`` values which are
    reshaped to one score vector per character position.

    The configuration object must provide ``vocab``, ``output_ndigits``,
    ``output_vocab_size`` and ``input_dim`` (indexed as ``[0]`` and ``[1]``
    for the two spatial sizes fed through the convolution stack).
    """

    config_class = CaptchaConfig
    model_type = "captcha"

    def __init__(self, config):
        """Build the layers from ``config``.

        Raises:
            ValueError: If ``config.input_dim`` is too small to survive the
                three convolution/pooling stages.
        """
        super().__init__(config)
        self.vocab = config.vocab
        self.output_ndigits = config.output_ndigits
        self.output_vocab_size = config.output_vocab_size
        self.input_dim = config.input_dim

        self.batchnorm0 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
        self.batchnorm3 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Spatial size left after the three conv/pool stages; reuse the
        # public helper instead of duplicating the formula locally.
        conv_h = self.calc_dim_conv(config.input_dim[0])
        conv_w = self.calc_dim_conv(config.input_dim[1])
        if conv_h <= 0 or conv_w <= 0:
            raise ValueError(
                f"input_dim {tuple(config.input_dim)} is too small for three "
                "conv(kernel_size=3) + max_pool(2) stages"
            )
        fc1_in_features = conv_h * conv_w * 64  # 64 = conv3 output channels

        self.fc1 = nn.Linear(fc1_in_features, 200)
        self.batchnorm_dense = nn.BatchNorm1d(200)
        self.fc2 = nn.Linear(200, self.output_vocab_size * self.output_ndigits)

    def calc_dim_img_one(self, x):
        """Return the size of one spatial dimension after a single stage.

        A 3-wide convolution without padding removes 2 pixels and the
        2-wide max-pool halves the remainder (rounding down).
        """
        return (x - 2) // 2

    def calc_dim_conv(self, x):
        """Return the size of one spatial dimension after all three stages."""
        for _ in range(3):
            x = self.calc_dim_img_one(x)
        return x

    def predict_captcha(self, file_path):
        """Predict the captcha text for a single image file.

        The image is converted to RGB, resized to ``self.input_dim`` and run
        through the network in evaluation mode with gradients disabled. The
        module's previous training flag is restored afterwards.

        Args:
            file_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            The predicted string, one ``self.vocab`` entry per position.
        """
        transform = transforms.Compose([
            transforms.Resize(self.input_dim),
            transforms.ToTensor(),
        ])
        # Context manager closes the underlying file once pixels are loaded.
        with Image.open(file_path) as img:
            image = transform(img.convert('RGB'))

        # Add the batch dimension and match the model's device and dtype so
        # the call also works after ``model.to("cuda")`` or ``model.half()``.
        reference = next(self.parameters())
        image = image.unsqueeze(0).to(
            device=reference.device, dtype=reference.dtype
        )

        # Inference must run in eval mode: dropout has to be disabled and
        # BatchNorm1d cannot compute batch statistics from a single sample.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(image)
        finally:
            self.train(was_training)

        preds = torch.argmax(logits, dim=2)
        return "".join(self.vocab[i] for i in preds[0].tolist())

    def forward(self, x):
        """Compute per-position character scores for a batch of images.

        Args:
            x: Image batch with 3 channels; its spatial size must match the
                ``input_dim`` the model was built with so the flattened
                features fit ``fc1``.

        Returns:
            Tensor of shape ``(batch, output_ndigits, output_vocab_size)``
            holding unnormalised scores.
        """
        x = self.batchnorm0(x)

        # Stage 1
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = self.batchnorm1(x)

        # Stage 2
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = self.batchnorm2(x)

        # Stage 3
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = self.batchnorm3(x)

        # Classifier head
        x = torch.flatten(x, start_dim=1)
        x = self.dropout1(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.batchnorm_dense(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        # One score vector per character position.
        x = x.view(-1, self.output_ndigits, self.output_vocab_size)
        return x