glenn-jocher committed on
Commit
e189fa1
1 Parent(s): fa2344c

`intersect_dicts()` in hubconf.py fix (#5542)

Browse files
Files changed (4) hide show
  1. hubconf.py +2 -3
  2. train.py +3 -4
  3. utils/general.py +5 -0
  4. utils/torch_utils.py +0 -5
hubconf.py CHANGED
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
30
  from models.experimental import attempt_load
31
  from models.yolo import Model
32
  from utils.downloads import attempt_download
33
- from utils.general import check_requirements, set_logging
34
  from utils.torch_utils import select_device
35
 
36
  file = Path(__file__).resolve()
@@ -49,9 +49,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
49
  model = Model(cfg, channels, classes) # create model
50
  if pretrained:
51
  ckpt = torch.load(attempt_download(path), map_location=device) # load
52
- msd = model.state_dict() # model state_dict
53
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
54
- csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
55
  model.load_state_dict(csd, strict=False) # load
56
  if len(ckpt['model'].names) == classes:
57
  model.names = ckpt['model'].names # set class names attribute
 
30
  from models.experimental import attempt_load
31
  from models.yolo import Model
32
  from utils.downloads import attempt_download
33
+ from utils.general import check_requirements, intersect_dicts, set_logging
34
  from utils.torch_utils import select_device
35
 
36
  file = Path(__file__).resolve()
 
49
  model = Model(cfg, channels, classes) # create model
50
  if pretrained:
51
  ckpt = torch.load(attempt_download(path), map_location=device) # load
 
52
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
53
+ csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
54
  model.load_state_dict(csd, strict=False) # load
55
  if len(ckpt['model'].names) == classes:
56
  model.names = ckpt['model'].names # set class names attribute
train.py CHANGED
@@ -43,15 +43,14 @@ from utils.datasets import create_dataloader
43
  from utils.downloads import attempt_download
44
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
45
  check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
46
- labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
47
- print_mutation, strip_optimizer)
48
  from utils.loggers import Loggers
49
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
50
  from utils.loss import ComputeLoss
51
  from utils.metrics import fitness
52
  from utils.plots import plot_evolve, plot_labels
53
- from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
54
- torch_distributed_zero_first)
55
 
56
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
57
  RANK = int(os.getenv('RANK', -1))
 
43
  from utils.downloads import attempt_download
44
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
45
  check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
46
+ intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
47
+ print_args, print_mutation, strip_optimizer)
48
  from utils.loggers import Loggers
49
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
50
  from utils.loss import ComputeLoss
51
  from utils.metrics import fitness
52
  from utils.plots import plot_evolve, plot_labels
53
+ from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
 
54
 
55
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
56
  RANK = int(os.getenv('RANK', -1))
utils/general.py CHANGED
@@ -125,6 +125,11 @@ def init_seeds(seed=0):
125
  cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
126
 
127
 
 
 
 
 
 
128
  def get_latest_run(search_dir='.'):
129
  # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
130
  last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
 
125
  cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
126
 
127
 
128
+ def intersect_dicts(da, db, exclude=()):
129
+ # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
130
+ return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
131
+
132
+
133
  def get_latest_run(search_dir='.'):
134
  # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
135
  last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
utils/torch_utils.py CHANGED
@@ -153,11 +153,6 @@ def de_parallel(model):
153
  return model.module if is_parallel(model) else model
154
 
155
 
156
- def intersect_dicts(da, db, exclude=()):
157
- # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
158
- return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
159
-
160
-
161
  def initialize_weights(model):
162
  for m in model.modules():
163
  t = type(m)
 
153
  return model.module if is_parallel(model) else model
154
 
155
 
 
 
 
 
 
156
  def initialize_weights(model):
157
  for m in model.modules():
158
  t = type(m)