# OFA-OCR/fairseq/examples/roberta/wsc/wsc_criterion.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
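"""Criterions for fine-tuning a masked LM (e.g. RoBERTa) on the Winograd
Schema Challenge (WSC) and on WinoGrande: candidate spans are scored by the
average masked-LM log-probability of their tokens, and training minimizes a
margin ranking loss (or, with --wsc-cross-entropy, a cross-entropy loss)
between the query span and competing candidate spans."""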
import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders


@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
        # optionally stream per-example predictions to a file for inspection
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args.bpe)
self.tokenizer = encoders.build_tokenizer(args.tokenizer)

    def __del__(self):
if self.prediction_h is not None:
self.prediction_h.close()

    @staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
parser.add_argument(
"--wsc-cross-entropy",
action="store_true",
help="use cross entropy formulation instead of margin loss",
)
parser.add_argument(
"--save-predictions", metavar="FILE", help="file to save predictions to"
)

    def get_masked_input(self, tokens, mask):
        """Replace the positions selected by `mask` with the <mask> token index."""
masked_tokens = tokens.clone()
masked_tokens[mask] = self.task.mask
return masked_tokens
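        # Illustration with hypothetical values: tokens = [9, 7, 7, 2] and
        # mask = [F, T, T, F] yield [9, <mask>, <mask>, 2], where <mask> is
        # the task's mask index (self.task.mask).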

    def get_lprobs(self, model, tokens, mask):
logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
mask = mask.type_as(scores)
scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
return scores
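        # Note: despite the name, this returns one score per sequence -- the
        # mean masked-LM log-probability over the masked positions -- rather
        # than the full lprobs tensor.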

    def get_loss(self, query_lprobs, cand_lprobs):
if self.args.wsc_cross_entropy:
return F.cross_entropy(
torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
query_lprobs.new([0]).long(),
)
else:
return (
-query_lprobs
+ self.args.wsc_margin_alpha
* (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
).sum()
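        # Worked example with hypothetical numbers (alpha=1.0, beta=0.0):
        # query_lprobs = [-1.0], cand_lprobs = [-0.5, -2.0]
        #   margins = clamp([-0.5 + 1.0, -2.0 + 1.0], min=0) = [0.5, 0.0]
        #   loss    = (1.0 + 0.5) + (1.0 + 0.0) = 2.5
        # i.e. the query score is pushed up, and any candidate scoring within
        # `beta` of it (or above it) is pushed down.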

    def forward(self, model, sample, reduce=True):
# compute loss and accuracy
loss, nloss = 0.0, 0
ncorrect, nqueries = 0, 0
for i, label in enumerate(sample["labels"]):
query_lprobs = self.get_lprobs(
model,
sample["query_tokens"][i].unsqueeze(0),
sample["query_masks"][i].unsqueeze(0),
)
cand_lprobs = self.get_lprobs(
model,
sample["candidate_tokens"][i],
sample["candidate_masks"][i],
)
            # predict True only if the query span outscores every candidate
            pred = (query_lprobs >= cand_lprobs).all().item()
if label is not None:
label = 1 if label else 0
ncorrect += 1 if pred == label else 0
nqueries += 1
if label:
# only compute a loss for positive instances
nloss += 1
loss += self.get_loss(query_lprobs, cand_lprobs)
            example_id = sample["id"][i].item()
            if self.prediction_h is not None:
                # one tab-separated line per example: <id> <pred> <label>
                print(
                    "{}\t{}\t{}".format(example_id, pred, label),
                    file=self.prediction_h,
                )
        if nloss == 0:
            # no loss-contributing instances in this batch; emit a dummy
            # differentiable zero so the backward pass still works
            loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": nqueries,
}
return loss, sample_size, logging_output

    @staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
"loss": loss_sum / sample_size / math.log(2),
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
if nqueries > 0:
agg_output["accuracy"] = ncorrect / float(nqueries)
return agg_output
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
sample["query_tokens"],
sample["query_masks"],
)
cand_lprobs = self.get_lprobs(
model,
sample["candidate_tokens"],
sample["candidate_masks"],
)
        # WinoGrande pairs each query with a single candidate, so this is an
        # elementwise comparison over the batch
        pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)
sample_size = sample["query_tokens"].size(0)
ncorrect = pred.sum().item()
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"ncorrect": ncorrect,
"nqueries": sample_size,
}
return loss, sample_size, logging_output
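

# Usage sketch (hypothetical command): only --criterion wsc/winogrande and the
# --wsc-* / --save-predictions flags are defined in this file; the remaining
# flags follow the usual fairseq fine-tuning conventions and may differ.
#
#   fairseq-train WSC/ \
#       --criterion wsc \
#       --wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
#       --save-predictions wsc_predictions.tsv \
#       ...  # plus the standard task/model/optimizer flags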