# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data import encoders

@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args.bpe)
        self.tokenizer = encoders.build_tokenizer(args.tokenizer)

    def __del__(self):
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        # replace the span being scored with <mask> tokens
        masked_tokens = tokens.clone()
        masked_tokens[mask] = self.task.mask
        return masked_tokens

    def get_lprobs(self, model, tokens, mask):
        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
        # log-probability of each original token under the masked LM
        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
        # average only over the masked positions
        mask = mask.type_as(scores)
        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
        return scores

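    # Illustrative sketch (not in the original source): for tokens
    # [w0, w1, w2] with mask [0, 1, 1], positions 1-2 are replaced by
    # <mask> and the returned score is mean(log P(w1), log P(w2)),
    # i.e. the average masked-LM log-probability of the original span.
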
    def get_loss(self, query_lprobs, cand_lprobs):
        if self.args.wsc_cross_entropy:
            # cross-entropy over [query, candidates], with the query as the target
            return F.cross_entropy(
                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                query_lprobs.new([0]).long(),
            )
        else:
            # margin loss: -log p(query) plus a hinge penalty for every
            # candidate that scores within `beta` of the query
            return (
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

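    # Worked example (illustrative values, not from the original source):
    # with wsc_margin_alpha=1.0, wsc_margin_beta=0.0,
    # query_lprobs=[-1.2] and cand_lprobs=[-1.5, -0.9], the margin loss is
    #     (1.2 + max(0, -1.5 + 1.2)) + (1.2 + max(0, -0.9 + 1.2))
    #   = (1.2 + 0.0) + (1.2 + 0.3) = 2.7,
    # so only candidates scoring close to (or above) the query add a hinge
    # penalty on top of the -log p(query) term.
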
    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            # predict the query if it scores at least as high as every candidate
            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

                if label:
                    # only compute a loss for positive instances
                    nloss += 1
                    loss += self.get_loss(query_lprobs, cand_lprobs)

            id = sample["id"][i].item()
            if self.prediction_h is not None:
                print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)

        if nloss == 0:
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            "loss": loss_sum / sample_size / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output

@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy; unlike WSC, every Winogrande instance is
        # scored in a single batched pass with exactly one candidate per query
        query_lprobs = self.get_lprobs(
            model,
            sample["query_tokens"],
            sample["query_masks"],
        )
        cand_lprobs = self.get_lprobs(
            model,
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )
        pred = query_lprobs >= cand_lprobs
        loss = self.get_loss(query_lprobs, cand_lprobs)

        sample_size = sample["query_tokens"].size(0)
        ncorrect = pred.sum().item()
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": sample_size,
        }
        return loss, sample_size, logging_output
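
# Usage sketch (an assumption, not part of the original file): once registered
# via @register_criterion above, a criterion is selected by name on the
# fairseq command line, e.g. something like
#
#   fairseq-train <WSC-data-dir> \
#       --criterion wsc --wsc-cross-entropy \
#       --save-predictions predictions.tsv \
#       ...  # task, model, and optimizer flags omitted
#
# with --criterion winogrande for the Winogrande variant.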