# Spaces: Runtime error  (extraction residue captured with this file; not part of the test code)
import unittest | |
from detectron2.solver.build import _expand_param_groups, reduce_param_groups | |
class TestOptimizer(unittest.TestCase):
    """Tests for the param-group helpers imported from ``detectron2.solver.build``."""

    def testExpandParamsGroups(self):
        """Expect one output group per parameter name.

        The fixture lists overlapping groups; the expected result has a
        single-parameter group for each of p1..p5, where an option set by a
        later input group replaces the value set by an earlier one.
        """
        overlapping_groups = [
            dict(params=["p1", "p2", "p3", "p4"], lr=1.0, weight_decay=3.0),
            dict(params=["p2", "p3", "p5"], lr=2.0, momentum=2.0),
            dict(params=["p1"], weight_decay=4.0),
        ]
        expanded = _expand_param_groups(overlapping_groups)

        expected = [
            {"params": ["p1"], "lr": 1.0, "weight_decay": 4.0},
            {"params": ["p2"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p3"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p4"], "lr": 1.0, "weight_decay": 3.0},
            {"params": ["p5"], "lr": 2.0, "momentum": 2.0},
        ]
        self.assertEqual(expanded, expected)

    def testReduceParamGroups(self):
        """Expect groups that share identical options to be merged.

        The two input groups with lr=2.0, weight_decay=3.0, momentum=2.0 are
        expected to come back as one group holding p2, p6 and p3, in that
        order; every other group is expected back unchanged.
        """
        expected = [
            dict(lr=1.0, weight_decay=4.0, params=["p1"]),
            dict(lr=2.0, weight_decay=3.0, momentum=2.0, params=["p2", "p6", "p3"]),
            dict(lr=1.0, weight_decay=3.0, params=["p4"]),
            dict(lr=2.0, momentum=2.0, params=["p5"]),
        ]

        # Key order inside each input group is kept as in the original fixture.
        input_groups = [
            {"params": ["p1"], "lr": 1.0, "weight_decay": 4.0},
            {"params": ["p2", "p6"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p3"], "lr": 2.0, "weight_decay": 3.0, "momentum": 2.0},
            {"params": ["p4"], "lr": 1.0, "weight_decay": 3.0},
            {"params": ["p5"], "lr": 2.0, "momentum": 2.0},
        ]
        reduced = reduce_param_groups(input_groups)
        self.assertEqual(reduced, expected)