Spaces:
Sleeping
Sleeping
File size: 6,383 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import torch
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.c2_model_loading import load_c2_format
from maskrcnn_benchmark.utils.big_model_loading import load_big_format
from maskrcnn_benchmark.utils.pretrain_model_loading import load_pretrain_format
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.model_zoo import cache_url
class Checkpointer(object):
    """Save and restore model, optimizer and scheduler state under ``save_dir``.

    The file name of the most recently saved checkpoint is recorded in a small
    text file named ``last_checkpoint`` inside ``save_dir``; ``load`` reads that
    tag to decide whether a previous run can be resumed.
    """

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        """Store the objects whose state is checkpointed.

        Args:
            model: object exposing ``state_dict()``; its state is always saved.
            optimizer: optional object exposing ``state_dict()`` /
                ``load_state_dict()``.
            scheduler: optional scheduler, or a list of schedulers, exposing
                ``state_dict()`` / ``load_state_dict()``.
            save_dir: directory holding checkpoints and the ``last_checkpoint``
                tag. When empty, ``save`` does nothing.
            save_to_disk: when falsy, ``save`` does nothing.
            logger: logger to report to; defaults to
                ``logging.getLogger(__name__)``.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name, **kwargs):
        """Write ``<save_dir>/<name>.pth`` and tag it as the last checkpoint.

        Args:
            name: checkpoint base name, without the ``.pth`` extension.
            **kwargs: extra entries merged into the saved dictionary after the
                model / optimizer / scheduler entries (so they may override
                them).
        """
        if not self.save_dir:
            return
        if not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            if isinstance(self.scheduler, list):
                data["scheduler"] = [scheduler.state_dict() for scheduler in self.scheduler]
            else:
                data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)
        # use relative path name to save the checkpoint, so that save_dir can
        # be moved without invalidating the tag
        self.tag_last_checkpoint("{}.pth".format(name))

    def load(self, f=None, force=False, keyword="model", skip_optimizer=False, skip_scheduler=False):
        """Load model weights and, when resuming, optimizer / scheduler state.

        Args:
            f: checkpoint to load when no resumable checkpoint is tagged in
                ``save_dir`` (or when ``force`` is true).
            force: when true, ignore the ``last_checkpoint`` tag and load ``f``.
            keyword: key of the loaded dictionary that holds the model state.
            skip_optimizer: when true, neither optimizer nor scheduler state is
                restored and ``{}`` is returned.
            skip_scheduler: when true, scheduler state is not restored (its
                entry stays in the returned dictionary).

        Returns:
            When resuming from the tagged checkpoint with ``skip_optimizer``
            false: the loaded dictionary minus the entries that were consumed.
            In every other case: an empty dictionary.
        """
        resume = False
        if self.has_checkpoint() and not force:
            last_saved = self.get_checkpoint_file()
            # The tag may be empty, or may have vanished between the existence
            # check and the read. Joining "" onto save_dir would yield the
            # directory itself (a truthy path), so only override the argument
            # and mark this as a resume when the tag actually names a file.
            if last_saved:
                # get the absolute path
                f = os.path.join(self.save_dir, last_saved)
                resume = True
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint, keyword=keyword)

        # if resume training, load optimizer and scheduler,
        # otherwise use the specified LR in config yaml for fine-tuning
        if not resume or skip_optimizer:
            return {}

        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler and not skip_scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            if isinstance(self.scheduler, list):
                for scheduler, state_dict in zip(self.scheduler, checkpoint.pop("scheduler")):
                    scheduler.load_state_dict(state_dict)
            else:
                self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        """Return True when the ``last_checkpoint`` tag exists in ``save_dir``."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        """Return the stripped content of the tag file, or ``""`` if unreadable."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        """Record ``last_filename`` as the most recent checkpoint in ``save_dir``."""
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        """Deserialize checkpoint ``f`` with all tensors mapped to the CPU."""
        # NOTE(security): torch.load is pickle-based; only load checkpoints
        # that come from a trusted source.
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint, keyword="model"):
        """Pop ``checkpoint[keyword]`` and load it into ``self.model``."""
        load_state_dict(self.model, checkpoint.pop(keyword))
class DetectronCheckpointer(Checkpointer):
    """Checkpointer that also resolves catalog entries, URLs and converted formats."""

    def __init__(
        self,
        cfg,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        """Initialise the base checkpointer and keep a private clone of ``cfg``."""
        super(DetectronCheckpointer, self).__init__(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            save_dir=save_dir,
            save_to_disk=save_to_disk,
            logger=logger,
        )
        self.cfg = cfg.clone()

    def _load_file(self, f):
        """Resolve ``f`` to a local file and return its content as a dictionary.

        ``catalog://`` names are looked up in the model catalog, anything
        starting with ``http`` is downloaded through the URL cache, and files
        ending in ``.pkl`` / ``.big`` / ``.pretrain`` are handed to their
        dedicated loaders. Everything else goes through the parent loader; a
        result that has no ``"model"`` key is wrapped as ``{"model": result}``.
        """
        catalog_prefix = "catalog://"

        # catalog lookup
        if f.startswith(catalog_prefix):
            catalog_module = import_file(
                "maskrcnn_benchmark.config.paths_catalog", self.cfg.PATHS_CATALOG, True
            )
            resolved = catalog_module.ModelCatalog.get(f[len(catalog_prefix):])
            self.logger.info("{} points to {}".format(f, resolved))
            f = resolved

        # url paths are downloaded once and served from the cache afterwards
        if f.startswith("http"):
            local_copy = cache_url(f)
            self.logger.info("url {} cached in {}".format(f, local_copy))
            f = local_copy

        # formats that need conversion, checked in the same order as before
        converters = (
            (".pkl", load_c2_format),
            (".big", load_big_format),
            (".pretrain", load_pretrain_format),
        )
        for suffix, convert in converters:
            if f.endswith(suffix):
                return convert(self.cfg, f)

        # load native detectron.pytorch checkpoint
        state = super(DetectronCheckpointer, self)._load_file(f)
        if "model" in state:
            return state
        return {"model": state}
|