from typing import Any, Mapping

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_arabichar import ArabiCharModelConfig


class ArabiCharModel(nn.Module):
    """Convolutional classifier for single-channel character images.

    Two stages of three 5x5 convolutions (each followed by ReLU), with each
    stage ending in 2x2 max pooling and batch norm, then three fully connected
    layers with dropout before the output layer.

    ``fc1`` expects a ``conv2_channels x 5 x 5`` feature map; working through
    the padding/kernel/pool arithmetic, a ``(N, 1, 32, 32)`` input produces
    exactly that.

    Args:
        config: object providing ``conv1_channels``, ``conv2_channels``,
            ``fc1_units``, ``dropout_prob`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE: attribute names determine state-dict keys; keep them stable so
        # existing checkpoints still load.
        self.conv1 = nn.Conv2d(1, config.conv1_channels, kernel_size=5, padding=4)
        self.conv2 = nn.Conv2d(config.conv1_channels, config.conv1_channels, kernel_size=5)
        self.conv3 = nn.Conv2d(config.conv1_channels, config.conv1_channels, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2)
        self.bn1 = nn.BatchNorm2d(config.conv1_channels)

        self.conv4 = nn.Conv2d(config.conv1_channels, config.conv2_channels, kernel_size=5, padding=4)
        self.conv5 = nn.Conv2d(config.conv2_channels, config.conv2_channels, kernel_size=5)
        self.conv6 = nn.Conv2d(config.conv2_channels, config.conv2_channels, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2)
        self.bn2 = nn.BatchNorm2d(config.conv2_channels)

        # 5 x 5 is the spatial size left after pool2 for 32x32 inputs.
        self.fc1 = nn.Linear(config.conv2_channels * 5 * 5, config.fc1_units)
        self.fc2 = nn.Linear(config.fc1_units, config.fc1_units)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.fc3 = nn.Linear(config.fc1_units, config.num_classes)

    def forward(self, x, apply_softmax=True):
        """Run the network on a batch of images.

        Args:
            x: input batch with a single channel (``conv1`` takes 1 input
                channel); spatial size must reduce to 5x5 after ``pool2``.
            apply_softmax: when True (the default, matching the original
                behavior) return class probabilities; when False return the
                raw ``fc3`` scores, as needed by ``cross_entropy``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(conv(x))
        x = self.bn1(self.pool1(x))

        for conv in (self.conv4, self.conv5, self.conv6):
            x = torch.relu(conv(x))
        x = self.bn2(self.pool2(x))

        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        logits = self.fc3(x)
        return torch.softmax(logits, dim=1) if apply_softmax else logits


class ArabiCharModelForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`ArabiCharModel`."""

    config_class = ArabiCharModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ArabiCharModel(config)

    def forward(self, tensor, labels=None):
        """Classify a batch of images and optionally compute the training loss.

        Args:
            tensor: input batch, passed unchanged to :class:`ArabiCharModel`.
            labels: optional tensor of target class indices. When given, the
                cross-entropy loss is computed and returned.

        Returns:
            ``{"logits": ...}`` or ``{"loss": ..., "logits": ...}``, where
            ``logits`` are the raw (pre-softmax) class scores of shape
            ``(N, num_classes)``.
        """
        # Raw scores: cross_entropy applies log-softmax internally, so passing
        # probabilities would apply softmax twice and distort the loss.
        logits = self.model(tensor, apply_softmax=False)
        if labels is None:
            return {"logits": logits}
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}

    def load_state_dict(self, model_name, strict=True, **kwargs):
        """Load weights from a checkpoint file or from an in-memory state dict.

        Args:
            model_name: either
                * a ``Mapping`` state dict for this wrapper. This follows the
                  standard ``nn.Module.load_state_dict`` semantics, so keys
                  carry the ``model.`` prefix; or
                * anything accepted by ``torch.load`` (path or file object)
                  whose contents are a state dict for the inner
                  :class:`ArabiCharModel`, with keys without the ``model.``
                  prefix (the original behavior).
            strict: forwarded to ``nn.Module.load_state_dict``.
            **kwargs: forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The ``nn.Module.load_state_dict`` result (missing/unexpected keys).
        """
        if isinstance(model_name, Mapping):
            # Honor the nn.Module contract for callers that pass a state dict.
            return super().load_state_dict(model_name, strict=strict, **kwargs)
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values into the existing parameters.
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints (consider weights_only=True where supported).
        state_dict = torch.load(model_name, map_location="cpu")
        return self.model.load_state_dict(state_dict, strict=strict, **kwargs)