import torch

from colbert.utils.utils import NullContextManager

# Compare (major, minor) as integers; the previous float-based comparison
# breaks for versions like "1.10", since float("1.10") == 1.1 < 1.6.
PyTorch_over_1_6 = tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 6)


class MixedPrecisionManager():
    """Thin wrapper around torch.cuda.amp that degrades to plain FP32 when deactivated."""

    def __init__(self, activated):
        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"

        self.activated = activated

        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        # Autocast the forward pass under AMP; otherwise return a no-op context manager.
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        if self.activated:
            # Scale the loss to avoid FP16 gradient underflow during backprop.
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer):
        if self.activated:
            # Unscale gradients first so clipping operates on their true magnitudes.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)

            self.scaler.step(optimizer)  # Skips the step if gradients contain inf/NaN.
            self.scaler.update()
            optimizer.zero_grad()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()
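

# Minimal usage sketch (an illustrative assumption, not part of the original
# module): `model`, `optimizer`, and the dummy batch below are hypothetical
# stand-ins for the real ColBERT model and its training data. It exercises the
# context -> backward -> step call pattern this class expects.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(8, 1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    amp = MixedPrecisionManager(activated=torch.cuda.is_available())

    x = torch.randn(4, 8, device=device)
    with amp.context():         # autocast the forward pass when AMP is on
        loss = model(x).mean()
    amp.backward(loss)          # scaled backward under AMP, plain backward otherwise
    amp.step(model, optimizer)  # unscale, clip, step, update, zero_grad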