JustinLin610
update
10b0761
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
Dictionary,
IdDataset,
ListDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
SortDataset,
data_utils,
encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from . import wsc_utils
@register_task("wsc")
class WSCTask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Winograd Schemas.

    Each example is converted into a "query" sequence (the sentence with the
    pronoun replaced by the query text) and a stack of "candidate" sequences
    (the sentence with the pronoun replaced by other noun phrases). Every
    sequence comes with a boolean mask marking the substituted tokens.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )

    def __init__(self, args, vocab):
        """Store the vocabulary and build the BPE / tokenizer encoders.

        Args:
            args: parsed task arguments (``bpe`` and ``init_token`` are read
                by this class).
            vocab: dictionary used to map tokens to indices.
        """
        super().__init__(args)
        self.vocab = vocab
        # index of the "<mask>" symbol, used to hide spans at inference time
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

        # hack to handle GPT-2 BPE, which includes leading spaces
        if args.bpe == "gpt2":
            self.leading_space = True
            self.trailing_space = False
        else:
            self.leading_space = False
            self.trailing_space = True

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename and register "<mask>".

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build the task; requires ``--criterion=wsc`` and ``<data>/dict.txt``."""
        assert args.criterion == "wsc", "Must set --criterion=wsc"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def binarize(self, s: str, append_eos: bool = False):
        """Encode a raw string into a 1-D long tensor of vocabulary indices.

        The string goes through the tokenizer and the BPE encoder (each only
        if configured) before being looked up in the vocabulary. Unknown
        symbols are never added to the vocabulary.

        Args:
            s: text to encode.
            append_eos: forwarded to ``encode_line``.

        Returns:
            The token tensor, with ``args.init_token`` prepended when set.
        """
        if self.tokenizer is not None:
            s = self.tokenizer.encode(s)
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=append_eos,
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
        """Encode ``prefix + txt + suffix`` and mark the tokens belonging to ``txt``.

        Args:
            txt: text substituted at the pronoun position.
            prefix: sentence text before the substituted span.
            suffix: sentence text after the substituted span.
            leading_space: ``" "`` or ``""`` inserted before ``txt``.
            trailing_space: ``" "`` or ``""`` inserted after ``txt``.

        Returns:
            ``(toks, mask)``: the encoded sequence (with EOS appended) and a
            boolean tensor of the same shape that is set on the span of
            ``txt``.
        """
        toks = self.binarize(
            prefix + leading_space + txt + trailing_space + suffix,
            append_eos=True,
        )
        mask = torch.zeros_like(toks, dtype=torch.bool)
        # The span is located by encoding the prefix and the substituted text
        # separately and counting their tokens.
        mask_start = len(self.binarize(prefix))
        # NOTE(review): binarize() also prepends init_token (when set) to this
        # sub-string, so mask_size then counts one extra position. Left as-is
        # because changing it would alter scores for models trained with this
        # masking -- confirm whether this is intended.
        mask_size = len(self.binarize(leading_space + txt))
        mask[mask_start : mask_start + mask_size] = 1
        return toks, mask

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch: unused here.
            combine: unused here.
            data_path: explicit path to a ``.jsonl`` file; defaults to
                ``<args.data>/<split>.jsonl``.
            return_only: when true, return the dataset without registering it
                in ``self.datasets``.

        Returns:
            The shuffled dataset for the split.

        Raises:
            FileNotFoundError: if the data file does not exist.
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
            prefix = sentence[: pronoun_span.start].text
            suffix = sentence[pronoun_span.end :].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
            )
            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space
                )
                query_len = len(query_toks)
            else:
                # no query given: keep placeholders so indices stay aligned
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        # seeded permutation so the shuffle is reproducible for a given seed
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_dataset_for_inference(self, sample_json):
        """Build a one-off dataset from a single JSON-serialisable sample.

        The sample is written as one JSON line to a temporary file, which is
        then read back through ``load_dataset``.

        Args:
            sample_json: object accepted by ``json.dumps``.

        Returns:
            The dataset produced by ``load_dataset(..., return_only=True)``.
        """
        # The file must be closed before it is reopened by name: reopening a
        # still-open NamedTemporaryFile is not portable (it fails on Windows).
        # Hence delete=False plus explicit cleanup below.
        handle = tempfile.NamedTemporaryFile(delete=False)
        try:
            with handle:
                handle.write((json.dumps(sample_json) + "\n").encode("utf-8"))
            # load_dataset consumes the file fully before returning, so it is
            # safe to remove the file afterwards.
            return self.load_dataset(
                "disambiguate_pronoun",
                data_path=handle.name,
                return_only=True,
            )
        finally:
            os.unlink(handle.name)

    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
        """Score the query (or pick the best candidate) for one sentence.

        Args:
            model: callable invoked as ``model(src_tokens=...)`` and returning
                ``(logits, extra)``.
            sentence: raw sentence, converted by
                ``wsc_utils.convert_sentence_to_json``.
            use_cuda: move the collated sample to CUDA before scoring.

        Returns:
            If the sentence has a query: ``True`` when the query scores at
            least as high as every candidate, else ``False``. Without a query:
            the decoded text of the highest-scoring candidate.
        """
        sample_json = wsc_utils.convert_sentence_to_json(sentence)
        dataset = self.build_dataset_for_inference(sample_json)
        sample = dataset.collater([dataset[0]])
        if use_cuda:
            sample = utils.move_to_cuda(sample)

        def get_masked_input(tokens, mask):
            # replace the marked span with the "<mask>" index on a copy
            masked_tokens = tokens.clone()
            masked_tokens[mask.bool()] = self.mask
            return masked_tokens

        def get_lprobs(tokens, mask):
            # mean log-probability of the original tokens over the masked span
            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
            mask = mask.type_as(scores)
            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
            return scores

        cand_lprobs = get_lprobs(
            sample["candidate_tokens"][0],
            sample["candidate_masks"][0],
        )
        if sample["query_tokens"][0] is not None:
            query_lprobs = get_lprobs(
                sample["query_tokens"][0].unsqueeze(0),
                sample["query_masks"][0].unsqueeze(0),
            )
            return (query_lprobs >= cand_lprobs).all().item() == 1
        else:
            best_idx = cand_lprobs.argmax().item()
            full_cand = sample["candidate_tokens"][0][best_idx]
            mask = sample["candidate_masks"][0][best_idx]
            toks = full_cand[mask.bool()]
            return self.bpe.decode(self.source_dictionary.string(toks)).strip()

    @property
    def source_dictionary(self):
        """Vocabulary used for the input side (same object as the target)."""
        return self.vocab

    @property
    def target_dictionary(self):
        """Vocabulary used for the output side (same object as the source)."""
        return self.vocab
@register_task("winogrande")
class WinograndeTask(WSCTask):
    """
    Task for WinoGrande dataset. Efficient implementation for Winograd schema
    tasks with exactly two candidates, one of which is correct.

    Unlike the parent task, each example carries a single candidate sequence,
    and both query and candidate sequences are right-padded.
    """

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build the task; requires ``--criterion=winogrande`` and ``<data>/dict.txt``."""
        assert args.criterion == "winogrande", "Must set --criterion=winogrande"

        # load data and label dictionaries
        dict_path = os.path.join(args.data, "dict.txt")
        vocab = cls.load_dictionary(dict_path)
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            data_path: explicit ``.jsonl`` path; defaults to
                ``<args.data>/<split>.jsonl``.
            return_only: return the dataset without storing it in
                ``self.datasets``.

        Raises:
            FileNotFoundError: if the data file does not exist.
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        q_toks_all, q_masks_all, q_lens = [], [], []
        c_toks_all, c_masks_all, c_lens = [], [], []

        samples = wsc_utils.winogrande_jsonl_iterator(
            data_path, eval=(split == "test")
        )
        for sentence, pronoun_span, query, cand_text in samples:
            before = sentence[: pronoun_span[0]]
            after = sentence[pronoun_span[1] :]
            # (prefix, suffix, leading_space, trailing_space) shared by the
            # query and the candidate of this sample
            context = (
                before.rstrip(),
                after,
                " " if before.endswith(" ") else "",
                "",
            )

            if query is None:
                q_toks, q_mask, q_len = None, None, 0
            else:
                q_toks, q_mask = self.binarize_with_mask(query, *context)
                q_len = len(q_toks)
            q_toks_all.append(q_toks)
            q_masks_all.append(q_mask)
            q_lens.append(q_len)

            c_toks, c_mask = self.binarize_with_mask(cand_text, *context)
            c_toks_all.append(c_toks)
            c_masks_all.append(c_mask)
            c_lens.append(c_toks.size(0))

        def right_padded(items, sizes, pad_idx):
            # wrap a list of tensors so batches are padded on the right
            listed = ListDataset(items, sizes)
            return PadDataset(listed, pad_idx=pad_idx, left_pad=False)

        pad = self.vocab.pad()

        query_lengths = np.array(q_lens)
        query_tokens = right_padded(q_toks_all, query_lengths, pad)
        query_masks = right_padded(q_masks_all, query_lengths, 0)

        candidate_lengths = np.array(c_lens)
        candidate_tokens = right_padded(c_toks_all, candidate_lengths, pad)
        candidate_masks = right_padded(c_masks_all, candidate_lengths, 0)

        fields = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }
        nested_dataset = NestedDictionaryDataset(fields, sizes=[query_lengths])

        # seeded permutation used as the sort order, i.e. a reproducible shuffle
        with data_utils.numpy_seed(self.args.seed):
            order = np.random.permutation(len(query_tokens))
        shuffled = SortDataset(nested_dataset, sort_order=[order])

        if return_only:
            return shuffled

        self.datasets[split] = shuffled
        return self.datasets[split]