goodfellowliu commited on
Commit
8b8b792
1 Parent(s): 0c4b4b8

There is no need to download extra packages, official bring it with you

Browse files

I submitted it once in your yolov3 project, you seem to accept it? I'm not sure. I'll submit PR again.

Files changed (1) hide show
  1. utils/torch_utils.py +12 -7
utils/torch_utils.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import torch.backends.cudnn as cudnn
8
  import torch.nn as nn
9
  import torch.nn.functional as F
 
10
 
11
 
12
  def init_seeds(seed=0):
@@ -120,18 +121,22 @@ def model_info(model, verbose=False):
120
 
121
  def load_classifier(name='resnet101', n=2):
122
  # Loads a pretrained model reshaped to n-class output
123
- import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
124
- model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
125
 
126
  # Display model properties
127
- for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
 
 
 
 
 
128
  print(x + ' =', eval(x))
129
 
130
  # Reshape output to n classes
131
- filters = model.last_linear.weight.shape[1]
132
- model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
133
- model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
134
- model.last_linear.out_features = n
135
  return model
136
 
137
 
 
7
  import torch.backends.cudnn as cudnn
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
+ import torchvision.models as models
11
 
12
 
13
  def init_seeds(seed=0):
 
121
 
122
  def load_classifier(name='resnet101', n=2):
123
  # Loads a pretrained model reshaped to n-class output
124
+ model = models.__dict__[name](pretrained=True)
 
125
 
126
  # Display model properties
127
+ input_size = [3, 224, 224]
128
+ input_space = 'RGB'
129
+ input_range = [0, 1]
130
+ mean = [0.485, 0.456, 0.406]
131
+ std = [0.229, 0.224, 0.225]
132
+ for x in [input_size, input_space, input_range, mean, std]:
133
  print(x + ' =', eval(x))
134
 
135
  # Reshape output to n classes
136
+ filters = model.fc.weight.shape[1]
137
+ model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
138
+ model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
139
+ model.fc.out_features = n
140
  return model
141
 
142