# Spaces:
# Runtime error
# Runtime error
from typing import Callable, Dict, Iterable, List | |
from torch import nn | |
# These functions are taken from the transformers repo.
def grad_status(model: nn.Module) -> Iterable:
    """Return a lazy iterable of the ``requires_grad`` flag of each parameter of ``model``.

    Flags are produced in the same order as ``model.parameters()``.
    """
    parameters = model.parameters()
    flags = (parameter.requires_grad for parameter in parameters)
    return flags
def freeze_params(model: nn.Module):
    """Disable gradient tracking, in place, for every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(False)
def freeze_embeds(model: nn.Module):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.

    Two module layouts are recognised:

    * BART-style: ``model.model.shared`` plus ``embed_positions`` and
      ``embed_tokens`` on both ``model.model.encoder`` and ``model.model.decoder``.
    * T5-style (fallback): ``model.shared`` plus ``embed_tokens`` on both
      ``model.encoder`` and ``model.decoder``.

    Every module of a layout is looked up before anything is frozen. A model
    that only partially matches the BART layout is therefore not left
    half-frozen before the fallback layout is tried.

    Raises:
        AttributeError: if the model matches neither layout.
    """
    try:
        inner = model.model
        to_freeze = [inner.shared]
        for stack in (inner.encoder, inner.decoder):
            to_freeze += [stack.embed_positions, stack.embed_tokens]
    except AttributeError:
        to_freeze = [model.shared]
        to_freeze += [stack.embed_tokens for stack in (model.encoder, model.decoder)]
    # Freeze outside the try so that an AttributeError raised while freezing is
    # not mistaken for "wrong layout" and rerouted to the T5 branch.
    for module in to_freeze:
        freeze_params(module)
def assert_not_all_frozen(model):
    """Check that at least one parameter of ``model`` is trainable.

    Raises:
        AssertionError: if no parameter has ``requires_grad=True``. This
            includes a model with no parameters at all.
    """
    model_grads: List[bool] = list(grad_status(model))
    if not any(model_grads):
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise AssertionError(f"none of {len(model_grads)} weights require grad")
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed NLL loss, averaged per counted token (adapted from fairseq).

    Args:
        lprobs: log-probabilities whose last dimension indexes the classes.
            Presumably produced by ``log_softmax``; this is not checked here.
        target: integer class indices. Either one rank lower than ``lprobs``, or
            the same rank with a trailing dimension of size 1.
        epsilon: label-smoothing factor. The smoothing mass is spread uniformly
            over all ``lprobs.size(-1)`` classes.
        ignore_index: target value whose positions contribute nothing to the
            loss and are excluded from the token count. ``None`` disables
            masking.

    Returns:
        ``(loss, nll_loss)``: the smoothed loss and the plain NLL. Each is summed
        over the counted tokens and divided by the number of those tokens
        (at least 1, so an all-ignored batch yields 0 instead of nan).
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # gather() rejects out-of-range indices such as the default -100, so
        # look up a valid dummy class at ignored positions; zeroed out below.
        target = target.masked_fill(pad_mask, 0)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
        # Normalise by the number of *non-ignored* tokens; clamp avoids 0/0.
        num_tokens = (~pad_mask).sum().clamp(min=1)
    else:
        # Every position counts; nll_loss has one element per token.
        num_tokens = max(nll_loss.numel(), 1)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / num_tokens, nll_loss / num_tokens