Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import logging | |
import unittest | |
import torch | |
from fairseq.optim.adam import FairseqAdam | |
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer | |
from omegaconf import OmegaConf | |
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestMemoryEfficientFP16(unittest.TestCase):
    """State-dict round-trip test for ``MemoryEfficientFP16Optimizer``.

    The model under test is moved to the GPU with ``.cuda()``, so the whole
    test case is skipped (rather than erroring) when CUDA is unavailable.
    """

    def setUp(self):
        """Silence log output up to and including CRITICAL during each test."""
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Restore normal logging after each test."""
        logging.disable(logging.NOTSET)

    def test_load_state_dict(self):
        """Reload the optimizer's own state and check parameter/state dtypes.

        After one optimization step, the optimizer's state dict is loaded back
        into the same optimizer. Every key of the wrapped optimizer's state
        must still be a float16 tensor and every tensor stored in the
        per-parameter state must be float32.
        """
        # define simple FP16 model
        model = torch.nn.Linear(5, 5).cuda().half()
        params = list(model.parameters())

        # initialize memory efficient FP16 optimizer
        # with pseudo DictConfigs
        optimizer = FairseqAdam(
            cfg=OmegaConf.create(
                vars(
                    argparse.Namespace(
                        adam_betas="(0.9, 0.999)",
                        adam_eps=1e-8,
                        weight_decay=0.0,
                        lr=[0.00001],
                    )
                )
            ),
            params=params,
        )
        me_optimizer = MemoryEfficientFP16Optimizer(
            cfg=OmegaConf.create(
                {
                    "common": vars(
                        argparse.Namespace(
                            fp16_init_scale=1,
                            fp16_scale_window=1,
                            fp16_scale_tolerance=1,
                            threshold_loss_scale=1,
                            min_loss_scale=1e-4,
                        )
                    )
                }
            ),
            params=params,
            optimizer=optimizer,
        )

        # optimizer state is created in the first step
        loss = model(torch.rand(5).cuda().half()).sum()
        me_optimizer.backward(loss)
        me_optimizer.step()

        # reload state
        state = me_optimizer.state_dict()
        me_optimizer.load_state_dict(state)

        # assertEqual (instead of assertTrue(a == b)) reports the offending
        # dtype in the failure message.
        for param, param_state in me_optimizer.optimizer.state.items():
            self.assertEqual(param.dtype, torch.float16)
            for value in param_state.values():
                if torch.is_tensor(value):
                    self.assertEqual(value.dtype, torch.float32)
# Run this module's tests through unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()