# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import numpy as np
import torch
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("commonsense_qa")
class CommonsenseQATask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Commonsense QA."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )
        parser.add_argument("--num-classes", type=int, default=5)

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        # reserve a <mask> index so the vocab matches RoBERTa's pretraining vocab
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename.

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert (
            args.criterion == "sentence_ranking"
        ), "Must set --criterion=sentence_ranking"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        def binarize(s, append_bos=False):
            # BPE-encode the raw string, then map the BPE tokens to vocab indices
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens
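
        # Illustrative note (assuming --init-token is set to RoBERTa's <s> id):
        #   binarize("Q: ...", append_bos=True) -> [<s>] + BPE ids + [</s>]
        #   binarize("A: ...")                  -> BPE ids + [</s>]
        # so the concatenation below yields `<s> Q: ... </s> A: ... </s>`.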

        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))
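
        # Each line of <split>.jsonl holds one JSON record. Assumed CommonsenseQA
        # shape (values shortened here for illustration; "answerKey" is absent
        # for unlabeled test data):
        #   {"answerKey": "A",
        #    "question": {"stem": "Where would I not want a fox?",
        #                 "choices": [{"label": "A", "text": "hen house"}, ...]}}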

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        labels = []

        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if "answerKey" in example:
                    # map the answer letter ("A", "B", ...) to a 0-based index
                    label = ord(example["answerKey"]) - ord("A")
                    labels.append(label)
                question = example["question"]["stem"]
                assert len(example["question"]["choices"]) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = "Q: " + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example["question"]["choices"]):
                    src = "A: " + choice["text"]
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))

        assert all(
            len(src_tokens[0]) == len(src_tokens[i])
            for i in range(self.args.num_classes)
        )
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])
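        # src_lengths[i] does double duty: as the sizes array backing the token
        # ListDataset (used when ordering and filtering examples by length) and,
        # re-wrapped, as the net_input{i}.src_lengths dataset exposed below.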

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update(
                {
                    "net_input{}".format(i + 1): {
                        "src_tokens": RightPadDataset(
                            src_tokens[i],
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        "src_lengths": src_lengths[i],
                    }
                }
            )
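
        # The sentence_ranking criterion enforced in setup_task runs the model
        # once per net_input{i}, scores each candidate with the classification
        # head, and ranks the num_classes scores against "target".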

        if len(labels) > 0:
            dataset.update({"target": RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            # a sample's size is the length of its longest candidate input
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_model(self, args):
        from fairseq import models

        model = models.build_model(args, self)

        # the head emits a single score per candidate; the ranking criterion
        # compares scores across the num_classes candidates
        model.register_classification_head(
            "sentence_classification_head",
            num_classes=1,
        )

        return model

    @property
    def source_dictionary(self):
        return self.vocab

    @property
    def target_dictionary(self):
        return self.vocab
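

# A minimal usage sketch (assumed flags, not an official recipe): after
# preparing <data> with dict.txt plus train.jsonl/valid.jsonl, fine-tune
# RoBERTa with something like:
#
#   fairseq-train <data> \
#       --user-dir examples/roberta/commonsense_qa \
#       --task commonsense_qa --criterion sentence_ranking --num-classes 5 \
#       --init-token 0 --bpe gpt2 --arch roberta_large
#
# --criterion sentence_ranking satisfies the assert in setup_task, and
# --init-token 0 prepends RoBERTa's <s> (BOS) to every question.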