glenn-jocher commited on
Commit
d08575e
1 Parent(s): 9b91db6

PyTorch Hub load directly when possible (#2986)

Browse files
Files changed (1) hide show
  1. hubconf.py +23 -19
hubconf.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
 
10
  import torch
11
 
12
- from models.yolo import Model
13
  from utils.general import check_requirements, set_logging
14
  from utils.google_utils import attempt_download
15
  from utils.torch_utils import select_device
@@ -26,33 +26,37 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
26
  pretrained (bool): load pretrained weights into the model
27
  channels (int): number of input channels
28
  classes (int): number of model classes
 
 
29
 
30
  Returns:
31
- pytorch model
32
  """
 
 
33
  try:
34
- set_logging(verbose=verbose)
35
-
36
- cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
37
- model = Model(cfg, channels, classes)
38
- if pretrained:
39
- fname = f'{name}.pt' # checkpoint filename
40
- attempt_download(fname) # download if not found locally
41
- ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
42
- msd = model.state_dict() # model state_dict
43
- csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
44
- csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
45
- model.load_state_dict(csd, strict=False) # load
46
- if len(ckpt['model'].names) == classes:
47
- model.names = ckpt['model'].names # set class names attribute
48
- if autoshape:
49
- model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
50
  device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
51
  return model.to(device)
52
 
53
  except Exception as e:
54
  help_url = 'https://github.com/ultralytics/yolov5/issues/36'
55
- s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
56
  raise Exception(s) from e
57
 
58
 
 
9
 
10
  import torch
11
 
12
+ from models.yolo import Model, attempt_load
13
  from utils.general import check_requirements, set_logging
14
  from utils.google_utils import attempt_download
15
  from utils.torch_utils import select_device
 
26
  pretrained (bool): load pretrained weights into the model
27
  channels (int): number of input channels
28
  classes (int): number of model classes
29
+ autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
30
+ verbose (bool): print all information to screen
31
 
32
  Returns:
33
+ YOLOv5 pytorch model
34
  """
35
+ set_logging(verbose=verbose)
36
+ fname = f'{name}.pt' # checkpoint filename
37
  try:
38
+ if pretrained and channels == 3 and classes == 80:
39
+ model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
40
+ else:
41
+ cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
42
+ model = Model(cfg, channels, classes) # create model
43
+ if pretrained:
44
+ attempt_download(fname) # download if not found locally
45
+ ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
46
+ msd = model.state_dict() # model state_dict
47
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
48
+ csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
49
+ model.load_state_dict(csd, strict=False) # load
50
+ if len(ckpt['model'].names) == classes:
51
+ model.names = ckpt['model'].names # set class names attribute
52
+ if autoshape:
53
+ model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
54
  device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
55
  return model.to(device)
56
 
57
  except Exception as e:
58
  help_url = 'https://github.com/ultralytics/yolov5/issues/36'
59
+ s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
60
  raise Exception(s) from e
61
 
62