File size: 3,971 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import print_function
# based on https://github.com/jiecaoyu/pytorch_imagenet

import os
import torch
import sys
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import torchvision.transforms
import numpy

def load_places_alexnet(weight_file):
    model = AlexNet()
    state_dict = torch.load(weight_file)
    model.load_state_dict(state_dict)
    return model


class AlexNet(nn.Sequential):

    def __init__(self, num_classes=None,
            include_lrn=True, split_groups=True,
            include_dropout=True):
        w = [3, 96, 256, 384, 384, 256, 4096, 4096, 365]
        if num_classes is not None:
            w[-1] = num_classes
        if split_groups is True:
            groups = [1, 2, 1, 2, 2]
        else:
            groups = [1, 1, 1, 1, 1]
        sequence = OrderedDict()
        for name, module in [
                ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11,
                    stride=4,
                    groups=groups[0], bias=True)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('lrn1', LRN(local_size=5, alpha=0.0001, beta=0.75)),
                ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2,
                    groups=groups[1], bias=True)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('lrn2', LRN(local_size=5, alpha=0.0001, beta=0.75)),
                ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1,
                    groups=groups[2], bias=True)),
                ('relu3', nn.ReLU(inplace=True)),
                ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1,
                    groups=groups[3], bias=True)),
                ('relu4', nn.ReLU(inplace=True)),
                ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1,
                    groups=groups[4], bias=True)),
                ('relu5', nn.ReLU(inplace=True)),
                ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)),
                ('flatten', Vectorize()),
                ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)),
                ('relu6', nn.ReLU(inplace=True)),
                ('dropout6', nn.Dropout()),
                ('fc7', nn.Linear(w[6], w[7], bias=True)),
                ('relu7', nn.ReLU(inplace=True)),
                ('dropout7', nn.Dropout()),
                ('fc8', nn.Linear(w[7], w[8])) ]:
            if not include_lrn and name.startswith('lrn'):
                continue
            if not include_dropout and name.startswith('drop'):
                continue
            sequence[name] = module
        super(AlexNet, self).__init__(sequence)

class LRN(nn.Module):
    def __init__(self, local_size=1, alpha=1.0, beta=0.75,
            ACROSS_CHANNELS=True):
        super(LRN, self).__init__()
        self.ACROSS_CHANNELS = ACROSS_CHANNELS
        if ACROSS_CHANNELS:
            self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1),
                    stride=1,
                    padding=(int((local_size-1.0)/2), 0, 0))
        else:
            self.average=nn.AvgPool2d(kernel_size=local_size,
                    stride=1,
                    padding=int((local_size-1.0)/2))
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        if self.ACROSS_CHANNELS:
            div = x.pow(2).unsqueeze(1)
            div = self.average(div).squeeze(1)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        else:
            div = x.pow(2)
            div = self.average(div)
            div = div.mul(self.alpha).add(1.0).pow(self.beta)
        x = x.div(div)
        return x

class Vectorize(nn.Module):
    def __init__(self):
        super(Vectorize, self).__init__()

    def forward(self, x):
        x = x.view(x.size(0), int(numpy.prod(x.size()[1:])))
        return x