vjeronymo2's picture
Adding model and checkpoint
828992f
import torch
from contextlib import contextmanager
from colbert.utils.utils import NullContextManager
# True when the installed PyTorch is at least 1.6 (the first release with native AMP).
# Compare (major, minor) as integers: a float comparison mis-orders two-digit
# minors, e.g. "1.10" would parse as 1.1 and look older than 1.6.
PyTorch_over_1_6 = tuple(int(part) for part in torch.__version__.split('.')[0:2]) >= (1, 6)
class MixedPrecisionManager():
    """Route the backward pass and optimizer step through optional CUDA AMP.

    When ``activated`` is truthy, a ``torch.cuda.amp.GradScaler`` is created
    and used to scale the loss and drive the optimizer step; otherwise the
    plain PyTorch calls are used directly.
    """

    def __init__(self, activated):
        """Record whether AMP is enabled and create a gradient scaler if so.

        Args:
            activated: Truthy to enable mixed precision.

        Raises:
            AssertionError: If AMP is requested while ``PyTorch_over_1_6`` is
                false.
        """
        assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6"

        self.activated = activated

        # The scaler attribute only exists when AMP is enabled.
        if self.activated:
            self.scaler = torch.cuda.amp.GradScaler()

    def context(self):
        """Return the context manager to wrap computation in.

        Returns:
            ``torch.cuda.amp.autocast()`` when AMP is enabled, otherwise a
            ``NullContextManager`` instance.
        """
        return torch.cuda.amp.autocast() if self.activated else NullContextManager()

    def backward(self, loss):
        """Run ``backward()`` on ``loss``, scaling it first when AMP is enabled."""
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, max_grad_norm=2.0):
        """Clip gradients, apply one optimizer step, and clear the gradients.

        Args:
            colbert: Module whose ``parameters()`` are gradient-clipped.
            optimizer: Optimizer to step and then zero.
            max_grad_norm: Threshold passed to
                ``torch.nn.utils.clip_grad_norm_``. Defaults to 2.0, the value
                that was previously hard-coded.
        """
        if self.activated:
            # Unscale before clipping so the threshold is applied to the
            # unscaled gradients rather than the loss-scaled ones.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), max_grad_norm)

            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), max_grad_norm)
            optimizer.step()

        # Both paths finish by clearing the gradients for the next iteration.
        optimizer.zero_grad()