alexnet-pneumonia / modeling_alexnet.py
import torch
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel

from alexnet_model.configuration_alexnet import AlexNetConfig


class AlexNetPneumoniaClassification(PreTrainedModel):
    """AlexNet-style image classifier for pneumonia detection."""

    config_class = AlexNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Feature extractor: the five convolutional layers of AlexNet.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)

        # Classifier head: expects a 256 x 6 x 6 feature map (227 x 227 input).
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, config.num_labels)

    def forward(self, pixel_values, labels=None):
        x = torch.relu(self.conv1(pixel_values))
        x = torch.max_pool2d(x, kernel_size=3, stride=2, padding=0)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, kernel_size=3, stride=2, padding=0)
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        x = torch.max_pool2d(x, kernel_size=3, stride=2, padding=0)

        # Flatten to (batch_size, 256 * 6 * 6) and run the fully connected head.
        x = x.view(-1, 256 * 6 * 6)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        logits = self.fc3(x)

        if labels is not None:
            # Training path: return raw logits together with the cross-entropy loss.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels)
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
            )

        # Inference path: return class probabilities via softmax.
        return SequenceClassifierOutput(
            logits=torch.softmax(logits, dim=1),
        )