Arulkumar03's picture
Upload 1389 files
fcd8cdd
raw
history blame contribute delete
No virus
4.25 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import tempfile
import unittest
from collections import OrderedDict
import torch
from iopath.common.file_io import PathHandler, PathManager
from torch import nn
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.checkpoint.c2_model_loading import (
_longest_common_prefix_str,
align_and_update_state_dicts,
)
from detectron2.utils.logger import setup_logger
class TestCheckpointer(unittest.TestCase):
    """Tests for loading checkpoints whose keys only partially match the model."""

    def setUp(self):
        """Call detectron2's ``setup_logger`` before each test."""
        setup_logger()

    def create_complex_model(self):
        """Build a small nested model and a checkpoint whose keys only partially match it.

        The model's parameters are named ``block1.layer1.*``, ``layer2.*`` and
        ``res.layer2.*``. The returned ``state_dict`` stores the first module under
        the shorter name ``layer1.*`` (so it cannot be matched by exact key), while
        ``layer2.*`` and ``res.layer2.*`` match the model exactly. Entries are added
        in the same order as the model's own parameters. All values come from
        ``torch.rand`` and therefore lie in ``[0, 1)``.

        Returns:
            tuple: ``(model, state_dict)``.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)
        return m, state_dict

    def test_complex_model_loaded(self):
        """``align_and_update_state_dicts`` maps every checkpoint tensor onto the model,
        both with and without an ``nn.DataParallel`` wrapper."""
        for add_data_parallel in [False, True]:
            # subTest reports which variant failed and still runs the other one.
            with self.subTest(data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    model = nn.DataParallel(model)
                # Captured before loading; these tensors are expected to share storage
                # with the parameters, so they observe what load_state_dict writes.
                model_sd = model.state_dict()
                sd_to_load = align_and_update_state_dicts(model_sd, state_dict)
                model.load_state_dict(sd_to_load)

                # zip() silently truncates, so make sure every entry gets compared.
                self.assertEqual(len(model_sd), len(state_dict))
                # The pairing relies on both dicts listing parameters in the same order.
                for loaded, stored in zip(model_sd.values(), state_dict.values()):
                    # different tensor objects: values were copied, not aliased
                    self.assertIsNot(loaded, stored)
                    # same content
                    self.assertTrue(loaded.to(stored).equal(stored))

    def test_load_with_matching_heuristics(self):
        """Keys that differ from the model's names are only loaded when the
        ``?matching_heuristics=True`` query is appended to the checkpoint path."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            model, state_dict = self.create_complex_model()
            ckpt_path = os.path.join(d, "checkpoint.pth")
            torch.save({"model": state_dict}, ckpt_path)
            checkpointer = DetectionCheckpointer(model, save_dir=d)
            with torch.no_grad():
                # use a different weight from the `state_dict`, since torch.rand is less than 1
                model.block1.layer1.weight.fill_(1)
                # load checkpoint without matching_heuristics: "layer1.*" is not a key of
                # the model, so block1.layer1 must keep its filled value
                checkpointer.load(ckpt_path)
                self.assertTrue(model.block1.layer1.weight.equal(torch.ones(3, 2)))
                # load checkpoint with matching_heuristics: now "layer1.*" is matched
                # onto "block1.layer1.*" and the stored value is loaded
                checkpointer.load(ckpt_path + "?matching_heuristics=True")
                self.assertFalse(model.block1.layer1.weight.equal(torch.ones(3, 2)))
                self.assertTrue(
                    model.block1.layer1.weight.equal(state_dict["layer1.weight"])
                )

    def test_custom_path_manager_handler(self):
        """Checkpoints can be loaded through a custom ``PathManager`` handler, with and
        without the ``matching_heuristics`` query string."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:

            class CustomPathManagerHandler(PathHandler):
                # Resolves "detectron2_test://<name>" to "<d>/<name>".
                PREFIX = "detectron2_test://"

                def _get_supported_prefixes(self):
                    return [self.PREFIX]

                def _get_local_path(self, path, **kwargs):
                    name = path[len(self.PREFIX) :]
                    return os.path.join(d, name)

                def _open(self, path, mode="r", **kwargs):
                    return open(self._get_local_path(path), mode, **kwargs)

            pathmgr = PathManager()
            pathmgr.register_handler(CustomPathManagerHandler())

            model, state_dict = self.create_complex_model()
            torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth"))
            checkpointer = DetectionCheckpointer(model, save_dir=d)
            checkpointer.path_manager = pathmgr

            checkpointer.load("detectron2_test://checkpoint.pth")
            # Keys that match exactly are loaded even without heuristics.
            self.assertTrue(model.layer2.weight.equal(state_dict["layer2.weight"]))

            checkpointer.load("detectron2_test://checkpoint.pth?matching_heuristics=True")
            # With heuristics, "layer1.*" is also matched onto "block1.layer1.*".
            self.assertTrue(
                model.block1.layer1.weight.equal(state_dict["layer1.weight"])
            )

    def test_lcp(self):
        """``_longest_common_prefix_str`` returns the leading substring shared by all inputs."""
        self.assertEqual(_longest_common_prefix_str(["class", "dlaps_model"]), "")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB"]), "class")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB", "clab"]), "cla")
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")