File size: 2,416 Bytes
e4147e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Any, Mapping
from .configuration_arabichar import ArabiCharModelConfig
from transformers import PreTrainedModel
import torch
import torch.nn as nn

class ArabiCharModel(nn.Module):
    """Convolutional classifier made of two conv stages and a dense head.

    Each stage is three 5x5 convolutions (the first one padded by 4) with
    ReLU activations, followed by 2x2 max pooling and batch normalisation.
    The head is two ReLU dense layers, dropout, and a final dense layer whose
    output is passed through softmax.

    The network takes single-channel input. The first dense layer expects
    ``conv2_channels * 5 * 5`` flattened features, which is what the layer
    arithmetic yields for 32x32 inputs.

    Args:
        config: Object exposing ``conv1_channels``, ``conv2_channels``,
            ``fc1_units``, ``dropout_prob`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        stage1_width = config.conv1_channels
        stage2_width = config.conv2_channels
        hidden_units = config.fc1_units

        # Stage 1: 1 -> stage1_width channels.
        self.conv1 = nn.Conv2d(1, stage1_width, kernel_size=5, padding=4)
        self.conv2 = nn.Conv2d(stage1_width, stage1_width, kernel_size=5)
        self.conv3 = nn.Conv2d(stage1_width, stage1_width, kernel_size=5)
        self.pool1 = nn.MaxPool2d(2)
        self.bn1 = nn.BatchNorm2d(stage1_width)

        # Stage 2: stage1_width -> stage2_width channels.
        self.conv4 = nn.Conv2d(stage1_width, stage2_width, kernel_size=5, padding=4)
        self.conv5 = nn.Conv2d(stage2_width, stage2_width, kernel_size=5)
        self.conv6 = nn.Conv2d(stage2_width, stage2_width, kernel_size=5)
        self.pool2 = nn.MaxPool2d(2)
        self.bn2 = nn.BatchNorm2d(stage2_width)

        # Dense head; the 5 * 5 factor is the spatial size left after stage 2.
        self.fc1 = nn.Linear(stage2_width * 5 * 5, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.fc3 = nn.Linear(hidden_units, config.num_classes)

    def forward(self, x):
        """Return per-class probabilities (softmax over dim 1) for a batch."""
        # Stage 1: conv/ReLU x3, then pool followed by batch norm.
        feats = self.conv1(x).relu()
        feats = self.conv2(feats).relu()
        feats = self.conv3(feats).relu()
        feats = self.bn1(self.pool1(feats))

        # Stage 2: same layout with the second channel width.
        feats = self.conv4(feats).relu()
        feats = self.conv5(feats).relu()
        feats = self.conv6(feats).relu()
        feats = self.bn2(self.pool2(feats))

        # Flatten everything except the batch dimension for the dense head.
        flat = feats.view(feats.shape[0], -1)
        hidden = self.fc1(flat).relu()
        hidden = self.dropout(self.fc2(hidden).relu())
        return self.fc3(hidden).softmax(dim=1)

class ArabiCharModelForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`ArabiCharModel`.

    The wrapped network is stored as ``self.model``. ``forward`` returns a
    dict with the network output under ``"logits"`` and, when labels are
    given, the classification loss under ``"loss"``.
    """

    config_class = ArabiCharModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ArabiCharModel(config)

    def forward(self, tensor, labels=None):
        """Run the wrapped network and optionally compute the loss.

        Args:
            tensor: Input batch forwarded unchanged to ``self.model``.
            labels: Optional class indices; when given, a loss is returned.

        Returns:
            ``{"logits": output}`` or ``{"loss": loss, "logits": output}``.
            The ``"logits"`` entry holds whatever ``self.model`` returns,
            which is the softmax output of the network.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}

        # ``self.model`` already ends in softmax, so its output is a
        # probability distribution. ``cross_entropy`` would apply a second
        # softmax; take the log of the probabilities and use NLL instead,
        # which is the cross-entropy for probability inputs. Clamping keeps
        # log() finite when a probability underflows to zero.
        # NOTE(review): if the inner model is changed to return raw scores,
        # switch this to ``nn.functional.cross_entropy(logits, labels)``.
        smallest = torch.finfo(logits.dtype).tiny
        log_probs = torch.log(logits.clamp_min(smallest))
        loss = nn.functional.nll_loss(log_probs, labels)
        return {"loss": loss, "logits": logits}

    def load_state_dict(self, model_name, strict=True, **kwargs):
        """Load weights from a state-dict mapping or from a checkpoint file.

        This method shadows ``nn.Module.load_state_dict``, so it has to keep
        accepting what that method accepts as well as the checkpoint path
        this class originally took.

        Args:
            model_name: Either a mapping of parameter names to tensors for
                this wrapper (handled like ``nn.Module.load_state_dict``),
                or a path / file object readable by ``torch.load`` holding
                the state dict of the inner ``self.model``.
            strict: Whether the keys must match exactly.
            **kwargs: Extra keyword arguments forwarded to
                ``nn.Module.load_state_dict`` in the mapping case.

        Returns:
            The result of the underlying ``load_state_dict`` call.
        """
        if isinstance(model_name, Mapping):
            return super().load_state_dict(model_name, strict=strict, **kwargs)

        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        # map_location="cpu" lets checkpoints saved on an accelerator load on
        # hosts without one; load_state_dict copies into the existing
        # parameters, so the model's own device is unaffected.
        state_dict = torch.load(model_name, map_location="cpu")
        return self.model.load_state_dict(state_dict, strict=strict)