captcha2 / modeling_tjmg.py
jtrecenti's picture
update predict function
1483b7a
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from torchvision import transforms
from PIL import Image
from .configuration_tjmg import CaptchaConfig
# Modelo baseado na arquitetura do CaptchaCNN
class CaptchaModel(PreTrainedModel):
config_class = CaptchaConfig
model_type = "captcha"
def __init__(self, config):
super().__init__(config)
self.vocab = config.vocab
self.output_ndigits = config.output_ndigits
self.output_vocab_size = config.output_vocab_size
self.input_dim = config.input_dim
self.batchnorm0 = nn.BatchNorm2d(3)
self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
self.batchnorm1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
self.batchnorm2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
self.batchnorm3 = nn.BatchNorm2d(64)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
def calc_dim(x):
for _ in range(3):
x = (x - 2) // 2
return x
conv_h = calc_dim(config.input_dim[0])
conv_w = calc_dim(config.input_dim[1])
fc1_in_features = conv_h * conv_w * 64
self.fc1 = nn.Linear(fc1_in_features, 200)
self.batchnorm_dense = nn.BatchNorm1d(200)
self.fc2 = nn.Linear(200, self.output_vocab_size * self.output_ndigits)
def calc_dim_img_one(self, x):
return (x - 2) // 2
def calc_dim_conv(self, x):
# Aplica três vezes: floor((x - 2) / 2)
for _ in range(3):
x = self.calc_dim_img_one(x)
return x
def predict_captcha(self, file_path):
"""
Realiza a predição do captcha para uma imagem específica.
"""
# Carrega a imagem e aplica as transformações
transform = transforms.Compose([
transforms.Resize(self.input_dim),
transforms.ToTensor(),
])
image = Image.open(file_path).convert('RGB')
image = transform(image)
image = image.unsqueeze(0) # Adiciona uma dimensão para o batch
# Realiza a predição
with torch.no_grad():
logits = self.forward(image)
# Obtém a predição (índice da classe com maior probabilidade)
preds = torch.argmax(logits, dim=2)
predicted_label = "".join([self.vocab[i] for i in preds[0].tolist()])
return predicted_label
def forward(self, x):
# Passagem pela rede (observe que não usamos pipes, apenas chamadas sequenciais)
x = self.batchnorm0(x)
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2)
x = self.batchnorm1(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2)
x = self.batchnorm2(x)
x = self.conv3(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2)
x = self.batchnorm3(x)
# Flatten: junta todas as dimensões, exceto o batch
x = torch.flatten(x, start_dim=1)
x = self.dropout1(x)
x = self.fc1(x)
x = F.relu(x)
x = self.batchnorm_dense(x)
x = self.dropout2(x)
x = self.fc2(x)
# Reestrutura para (batch_size, output_ndigits, output_vocab_size)
x = x.view(-1, self.output_ndigits, self.output_vocab_size)
return x