import torch

from contextlib import contextmanager

from colbert.utils.utils import NullContextManager

# torch.cuda.amp (autocast + GradScaler) is only available from PyTorch 1.6 onward.
# Compare the (major, minor) version numerically; a float comparison would misorder
# releases such as 1.10.
PyTorch_over_1_6 = tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 6)


class MixedPrecisionManager():
    def __init__(self, activated):
        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"

        self.activated = activated

        if self.activated:
            # GradScaler rescales the loss so that float16 gradients do not underflow.
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        # Run the forward pass under autocast when AMP is on; otherwise use a no-op context.
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer):
        if self.activated:
            # Unscale gradients in-place so clipping operates on their true magnitudes.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)

            # GradScaler skips the step if any gradient is inf/NaN, then updates its scale.
            self.scaler.step(optimizer)
            self.scaler.update()
            optimizer.zero_grad()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()
            optimizer.zero_grad()
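

# Illustrative usage only (an assumption, not part of the original module): a minimal
# training-step sketch. `model`, `optimizer`, and `batch` are hypothetical placeholders
# supplied by the caller, and `model(batch)` is assumed to return a scalar loss tensor.
def _example_train_step(model, optimizer, batch, amp):
    # `amp` is a MixedPrecisionManager; with activated=True the forward pass runs under
    # autocast, the backward pass is loss-scaled, and step() unscales, clips, and applies
    # the optimizer update before zeroing the gradients.
    with amp.context():
        loss = model(batch)
    amp.backward(loss)
    amp.step(model, optimizer)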