glenn-jocher
commited on
Commit
•
e189fa1
1
Parent(s):
fa2344c
`intersect_dicts()` in hubconf.py fix (#5542)
Browse files- hubconf.py +2 -3
- train.py +3 -4
- utils/general.py +5 -0
- utils/torch_utils.py +0 -5
hubconf.py
CHANGED
@@ -30,7 +30,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
30 |
from models.experimental import attempt_load
|
31 |
from models.yolo import Model
|
32 |
from utils.downloads import attempt_download
|
33 |
-
from utils.general import check_requirements, set_logging
|
34 |
from utils.torch_utils import select_device
|
35 |
|
36 |
file = Path(__file__).resolve()
|
@@ -49,9 +49,8 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
49 |
model = Model(cfg, channels, classes) # create model
|
50 |
if pretrained:
|
51 |
ckpt = torch.load(attempt_download(path), map_location=device) # load
|
52 |
-
msd = model.state_dict() # model state_dict
|
53 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
54 |
-
csd =
|
55 |
model.load_state_dict(csd, strict=False) # load
|
56 |
if len(ckpt['model'].names) == classes:
|
57 |
model.names = ckpt['model'].names # set class names attribute
|
|
|
30 |
from models.experimental import attempt_load
|
31 |
from models.yolo import Model
|
32 |
from utils.downloads import attempt_download
|
33 |
+
from utils.general import check_requirements, intersect_dicts, set_logging
|
34 |
from utils.torch_utils import select_device
|
35 |
|
36 |
file = Path(__file__).resolve()
|
|
|
49 |
model = Model(cfg, channels, classes) # create model
|
50 |
if pretrained:
|
51 |
ckpt = torch.load(attempt_download(path), map_location=device) # load
|
|
|
52 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
53 |
+
csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
|
54 |
model.load_state_dict(csd, strict=False) # load
|
55 |
if len(ckpt['model'].names) == classes:
|
56 |
model.names = ckpt['model'].names # set class names attribute
|
train.py
CHANGED
@@ -43,15 +43,14 @@ from utils.datasets import create_dataloader
|
|
43 |
from utils.downloads import attempt_download
|
44 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
45 |
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
|
46 |
-
labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
|
47 |
-
print_mutation, strip_optimizer)
|
48 |
from utils.loggers import Loggers
|
49 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
50 |
from utils.loss import ComputeLoss
|
51 |
from utils.metrics import fitness
|
52 |
from utils.plots import plot_evolve, plot_labels
|
53 |
-
from utils.torch_utils import
|
54 |
-
torch_distributed_zero_first)
|
55 |
|
56 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
57 |
RANK = int(os.getenv('RANK', -1))
|
|
|
43 |
from utils.downloads import attempt_download
|
44 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
45 |
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
|
46 |
+
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
|
47 |
+
print_args, print_mutation, strip_optimizer)
|
48 |
from utils.loggers import Loggers
|
49 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
50 |
from utils.loss import ComputeLoss
|
51 |
from utils.metrics import fitness
|
52 |
from utils.plots import plot_evolve, plot_labels
|
53 |
+
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
|
|
|
54 |
|
55 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
56 |
RANK = int(os.getenv('RANK', -1))
|
utils/general.py
CHANGED
@@ -125,6 +125,11 @@ def init_seeds(seed=0):
|
|
125 |
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
126 |
|
127 |
|
|
|
|
|
|
|
|
|
|
|
128 |
def get_latest_run(search_dir='.'):
|
129 |
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
|
130 |
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
|
|
125 |
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
126 |
|
127 |
|
128 |
+
def intersect_dicts(da, db, exclude=()):
|
129 |
+
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
130 |
+
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
|
131 |
+
|
132 |
+
|
133 |
def get_latest_run(search_dir='.'):
|
134 |
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
|
135 |
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
utils/torch_utils.py
CHANGED
@@ -153,11 +153,6 @@ def de_parallel(model):
|
|
153 |
return model.module if is_parallel(model) else model
|
154 |
|
155 |
|
156 |
-
def intersect_dicts(da, db, exclude=()):
|
157 |
-
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
158 |
-
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
|
159 |
-
|
160 |
-
|
161 |
def initialize_weights(model):
|
162 |
for m in model.modules():
|
163 |
t = type(m)
|
|
|
153 |
return model.module if is_parallel(model) else model
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
156 |
def initialize_weights(model):
|
157 |
for m in model.modules():
|
158 |
t = type(m)
|