|
import torch.nn as nn |
|
import torch |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
from transformers.modeling_utils import PreTrainedModel |
|
from alexnet_model.configuration_alexnet import AlexNetConfig |
|
|
|
class AlexNetPneumoniaClassification(PreTrainedModel):
    """AlexNet-style image classifier exposed as a Hugging Face ``PreTrainedModel``.

    Architecture: five ``Conv2d`` layers with ReLU, three 3x3 / stride-2
    max-pools (after conv1, conv2 and conv5), then three ``Linear`` layers
    ending in ``config.num_labels`` outputs. No dropout or normalization
    layers are used.

    ``fc1`` expects ``256 * 6 * 6`` features per sample, i.e. the last pool
    must produce a 6x6 map. With this ``conv1`` (kernel 11, stride 4, no
    padding) that holds for e.g. 3x227x227 inputs; inputs whose final feature
    map is not 6x6 raise a shape error in ``fc1``.
    """

    config_class = AlexNetConfig

    def __init__(self, config):
        """Create the layers.

        Args:
            config: Model configuration (``AlexNetConfig``). Only
                ``config.num_labels`` is read here; the rest is handled by
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        # Submodule attribute names become the state_dict keys; renaming them
        # would break loading previously saved weights.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, config.num_labels)

    def forward(self, pixel_values, labels=None):
        """Run the network on a batch of images.

        Args:
            pixel_values: Image tensor with 3 channels (``conv1`` takes 3 input
                channels); presumably (batch, 3, 227, 227) — TODO confirm
                against the image processor. A single unbatched (3, H, W)
                image is also accepted and treated as a batch of one.
            labels: Optional target tensor passed unchanged to
                ``nn.CrossEntropyLoss`` (typically class indices of shape
                (batch,)).

        Returns:
            ``SequenceClassifierOutput``:
              * with ``labels``: ``loss`` is the cross-entropy loss and
                ``logits`` are the raw, unnormalized scores;
              * without ``labels``: ``loss`` is ``None`` and ``logits`` holds
                softmax *probabilities* over dim 1, not raw scores.
        """
        # nn.Conv2d also accepts an unbatched (C, H, W) image; add a batch
        # dimension of 1 so the per-sample flatten below yields a
        # (1, num_labels) result for it.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        x = nn.functional.relu(self.conv1(pixel_values))
        x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        x = nn.functional.relu(self.conv3(x))
        x = nn.functional.relu(self.conv4(x))
        x = nn.functional.relu(self.conv5(x))
        x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        # Flatten per sample (keep dim 0). Reshaping with view(-1, 256*6*6)
        # can silently regroup values across samples when the feature map is
        # not 6x6; flattening from dim 1 makes fc1 fail loudly instead.
        x = torch.flatten(x, 1)

        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        logits = self.fc3(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels)
            return SequenceClassifierOutput(loss=loss, logits=logits)

        # NOTE(review): this branch returns probabilities in the ``logits``
        # field, whereas the Hugging Face convention is raw scores; callers
        # that apply softmax themselves would apply it twice. Kept as-is for
        # backward compatibility — confirm what downstream code expects.
        return SequenceClassifierOutput(
            logits=torch.softmax(logits, dim=1),
        )