"""Caffe-style AlexNet matching the converted MIT Places365 weights.

Mirrors the original Caffe deploy definition (grouped convolutions and
cross-channel LRN) so that a converted state dict loads unmodified.
"""

from collections import OrderedDict

import torch
import torch.nn as nn


def load_places_alexnet(weight_file):
    """Builds the network and loads a converted Places365 state dict."""
    model = AlexNet()
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    state_dict = torch.load(weight_file, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
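
# A minimal usage sketch (the checkpoint file name is hypothetical; any
# Places365 AlexNet state dict converted to this layer layout works).
# Real images also need the preprocessing the weights were trained with,
# typically Caffe-style 227x227 crops with mean subtraction:
#
#     model = load_places_alexnet('alexnet_places365.pth')
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 227, 227))  # -> (1, 365)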


class AlexNet(nn.Sequential):
    """Caffe-style AlexNet expressed as a flat nn.Sequential.

    The default head has 365 outputs to match the Places365 checkpoint;
    pass num_classes to resize it.  split_groups=True reproduces the
    original two-tower grouped convolutions in conv2, conv4, and conv5,
    and fc6 assumes the 6x6 pool5 map produced by 227x227 inputs.
    """

    def __init__(self, num_classes=None,
                 include_lrn=True, split_groups=True,
                 include_dropout=True):
        # Channel widths: input, conv1-conv5, fc6-fc8.
        w = [3, 96, 256, 384, 384, 256, 4096, 4096, 365]
        if num_classes is not None:
            w[-1] = num_classes
        if split_groups:
            groups = [1, 2, 1, 2, 2]
        else:
            groups = [1, 1, 1, 1, 1]
        sequence = OrderedDict()
        for name, module in [
                ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11, stride=4,
                                    groups=groups[0], bias=True)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('lrn1', LRN(local_size=5, alpha=0.0001, beta=0.75)),
                ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2,
                                    groups=groups[1], bias=True)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('lrn2', LRN(local_size=5, alpha=0.0001, beta=0.75)),
                ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1,
                                    groups=groups[2], bias=True)),
                ('relu3', nn.ReLU(inplace=True)),
                ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1,
                                    groups=groups[3], bias=True)),
                ('relu4', nn.ReLU(inplace=True)),
                ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1,
                                    groups=groups[4], bias=True)),
                ('relu5', nn.ReLU(inplace=True)),
                ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('flatten', Vectorize()),
                ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)),
                ('relu6', nn.ReLU(inplace=True)),
                ('dropout6', nn.Dropout()),
                ('fc7', nn.Linear(w[6], w[7], bias=True)),
                ('relu7', nn.ReLU(inplace=True)),
                ('dropout7', nn.Dropout()),
                ('fc8', nn.Linear(w[7], w[8]))]:
            # Optionally skip LRN and dropout layers by name prefix.
            if not include_lrn and name.startswith('lrn'):
                continue
            if not include_dropout and name.startswith('drop'):
                continue
            sequence[name] = module
        super(AlexNet, self).__init__(sequence)
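
# Shape walk-through for the standard 227x227 Caffe crop (the source of
# the 6 * 6 factor in fc6):
#   conv1 (11x11, stride 4): 227 -> 55    pool1: 55 -> 27
#   conv2 (5x5, pad 2):       27 -> 27    pool2: 27 -> 13
#   conv3/4/5 (3x3, pad 1):   13 -> 13    pool5: 13 ->  6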


class LRN(nn.Module):
    """Caffe-style local response normalization.

    Divides each activation by (1 + alpha * mean(x^2 over window))^beta,
    which matches Caffe's (k + (alpha/n) * sum(x^2))^beta with k = 1.
    """

    def __init__(self, local_size=1, alpha=1.0, beta=0.75,
                 ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        if ACROSS_CHANNELS:
            # Pool x^2 over a window of neighboring channels by treating
            # the channel axis as the depth of a 3d average pool.
            self.average = nn.AvgPool3d(
                kernel_size=(local_size, 1, 1), stride=1,
                padding=((local_size - 1) // 2, 0, 0))
        else:
            # Pool x^2 over a spatial window within each channel.
            self.average = nn.AvgPool2d(
                kernel_size=local_size, stride=1,
                padding=(local_size - 1) // 2)
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        if self.ACROSS_CHANNELS:
            # Add a depth dimension so AvgPool3d runs across channels.
            div = x.pow(2).unsqueeze(1)
            div = self.average(div).squeeze(1)
        else:
            div = x.pow(2)
            div = self.average(div)
        div = div.mul(self.alpha).add(1.0).pow(self.beta)
        return x.div(div)
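
# Note: with ACROSS_CHANNELS=True this should agree numerically with
# PyTorch's built-in nn.LocalResponseNorm(size, alpha, beta, k=1.0),
# which also averages x^2 over the channel window before scaling; the
# __main__ block at the bottom of the file checks this.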


class Vectorize(nn.Module):
    """Flattens (N, C, H, W) activations to (N, C*H*W) for the fc layers."""

    def forward(self, x):
        # Equivalent to the original numpy.prod-based reshape.
        return x.view(x.size(0), -1)
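
# A small self-check sketch: run this file directly to verify the output
# shape and that LRN numerically matches the built-in LocalResponseNorm.
if __name__ == '__main__':
    net = AlexNet()
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 227, 227))
    assert out.shape == (1, 365), out.shape
    # LRN above should agree with nn.LocalResponseNorm when k=1.
    x = torch.randn(2, 96, 27, 27)
    ours = LRN(local_size=5, alpha=1e-4, beta=0.75)(x)
    ref = nn.LocalResponseNorm(5, alpha=1e-4, beta=0.75, k=1.0)(x)
    assert torch.allclose(ours, ref, atol=1e-6)
    print('ok: logits shape', tuple(out.shape))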