# Spaces:
# Runtime error
# Runtime error
import os | |
import logging | |
from dataclasses import dataclass, field | |
from typing import Dict, List, Optional | |
import torch | |
import nlp | |
from transformers import T5Tokenizer, BartTokenizer, HfArgumentParser | |
# Module-level logger; its output format and level come from the root
# logging configuration installed by logging.basicConfig in main().
logger = logging.getLogger(name=__name__)
# @dataclass is required: HfArgumentParser builds its command-line options from
# the dataclass fields, and without the decorator the `field(...)` values below
# would just be inert class attributes.
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``task`` and ``model_type`` have no default and are therefore required;
    every other option falls back to the default declared on its field.
    """
    task: str = field(
        metadata={"help": "Which task 'qa', 'qg', 'e2e_qg', 'ans_ext', 'multi'. 'multi' means 'qa', 'qg', 'ans_ext' tasks"},
    )
    model_type: str = field(metadata={"help": "One of 't5', 'bart'"})
    dataset_path: Optional[str] = field(
        default="data/squad_multitask",
        metadata={"help": "Path for dataset directory"},
    )
    # When None, main() derives a file name from task, qg_format and model_type.
    train_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached train dataset"},
    )
    # When None, main() derives a file name from task, qg_format and model_type.
    valid_file_name: Optional[str] = field(
        default=None,
        metadata={"help": "name for cached valid dataset"},
    )
    # Only consulted when task == 'multi'.
    valid_for_qg_only: bool = field(
        default=False,
        metadata={"help": "For multitask dataset valid split should contain only qg task or all tasks."}
    )
    # Also passed to nlp.load_dataset as the dataset config name.
    qg_format: Optional[str] = field(
        default='highlight_qg_format',
        metadata={"help": "How to format inputs for que generation, 'highlight_qg_format' or 'prepend_qg_format'"},
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    max_target_length: Optional[int] = field(
        default=32,
        metadata={"help": "Max input length for the target text"},
    )
class DataProcessor:
    """Turn text examples into fixed-length token-id features for seq2seq training.

    Examples are dict-like rows with ``source_text`` and ``target_text`` strings.
    ``{hl_token}`` placeholders in the source and ``{sep_token}`` placeholders in
    the target are replaced with the concrete marker tokens before tokenization.
    """

    def __init__(self, tokenizer, model_type="t5", max_source_length=512, max_target_length=32):
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.model_type = model_type
        self.hl_token = "<hl>"
        # t5 and bart share the "<sep>" marker; any other model type uses "[SEP]".
        self.sep_token = "<sep>" if model_type in ("t5", "bart") else "[SEP]"

    def process(self, dataset):
        """Return *dataset* with placeholders substituted and tokenized features added.

        For T5 an explicit " </s>" is first appended to both texts. The final map
        is batched and adds ``source_ids``, ``target_ids`` and ``attention_mask``.
        """
        if self.model_type == "t5":
            dataset = dataset.map(self._add_eos_examples)
        dataset = dataset.map(self._add_special_tokens)
        dataset = dataset.map(self._convert_to_features, batched=True)
        return dataset

    def _add_eos_examples(self, example):
        """Append " </s>" to both the source and the target text of *example*."""
        example['source_text'] = example['source_text'] + " </s>"
        example['target_text'] = example['target_text'] + " </s>"
        return example

    def _add_special_tokens(self, example):
        """Substitute the highlight/separator placeholders with the real marker tokens."""
        example['source_text'] = example['source_text'].replace("{hl_token}", self.hl_token)
        example['target_text'] = example['target_text'].replace("{sep_token}", self.sep_token)
        return example

    def _encode(self, texts, max_length):
        """Tokenize a batch of strings, truncating/padding every entry to *max_length*."""
        # `padding='max_length'` is the supported spelling; the deprecated
        # `pad_to_max_length=True` that used to accompany it is redundant.
        return self.tokenizer.batch_encode_plus(
            texts,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )

    def _convert_to_features(self, example_batch):
        """Tokenize a batch of examples into source ids, target ids and source mask."""
        source_encoding = self._encode(example_batch['source_text'], self.max_source_length)
        target_encoding = self._encode(example_batch['target_text'], self.max_target_length)
        return {
            'source_ids': source_encoding['input_ids'],
            'target_ids': target_encoding['input_ids'],
            'attention_mask': source_encoding['attention_mask'],
        }
def filter_qa(example):
    """Return True if *example* is a row of the 'qa' task."""
    task = example["task"]
    return task == "qa"


def filter_qg(example):
    """Return True if *example* is a row of the 'qg' task."""
    task = example["task"]
    return task == "qg"


def filter_e2e_qg(example):
    """Return True if *example* is a row of the 'e2e_qg' task."""
    task = example["task"]
    return task == "e2e_qg"


def filter_ans_ext(example):
    """Return True if *example* is a row of the 'ans_ext' task."""
    task = example["task"]
    return task == "ans_ext"


def filter_multi(example):
    """Return True for every row except 'e2e_qg' ones (the 'multi' task mix)."""
    task = example["task"]
    return task != "e2e_qg"


# Maps the `task` argument to the row predicate handed to Dataset.filter.
TASK_TO_FILTER_FN = dict(
    qa=filter_qa,
    qg=filter_qg,
    e2e_qg=filter_e2e_qg,
    ans_ext=filter_ans_ext,
    multi=filter_multi,
)
def _load_tokenizer(model_type):
    """Load the pretrained tokenizer for *model_type* with ``<sep>``/``<hl>`` added.

    Raises:
        ValueError: if *model_type* is neither ``'t5'`` nor ``'bart'``.
    """
    if model_type == 't5':
        tokenizer = T5Tokenizer.from_pretrained("t5-base")
    elif model_type == 'bart':
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    else:
        # Any other value would pair a BART tokenizer with DataProcessor's "[SEP]"
        # fallback marker, which is never added to the vocabulary.
        raise ValueError(f"model_type must be 't5' or 'bart', got {model_type!r}")
    # DataProcessor inserts these markers into the text; registering them keeps
    # each one a single token.
    tokenizer.add_tokens(['<sep>', '<hl>'])
    return tokenizer


def _filter_by_task(train_dataset, valid_dataset, data_args):
    """Keep only the rows belonging to ``data_args.task`` in both splits.

    For the 'multi' task with ``valid_for_qg_only`` set, the validation split is
    restricted to 'qg' rows instead.
    """
    task_filter = TASK_TO_FILTER_FN[data_args.task]
    train_dataset = train_dataset.filter(task_filter)
    if data_args.task == 'multi' and data_args.valid_for_qg_only:
        logger.info("processing valid data only for qg task")
        valid_dataset = valid_dataset.filter(filter_qg)
    else:
        valid_dataset = valid_dataset.filter(task_filter)
    return train_dataset, valid_dataset


def _resolve_output_paths(data_args):
    """Return ``(train_path, valid_path)`` for the cached datasets under ``data/``.

    An explicitly given file name is used as-is; a missing one defaults to
    ``{split}_data_{task}_{qg_format}_{model_type}.pt``. Each split is resolved
    independently, so supplying only one of the two names works.
    """
    suffix = f"{data_args.task}_{data_args.qg_format}_{data_args.model_type}.pt"
    train_file_name = data_args.train_file_name
    if train_file_name is None:
        train_file_name = f"train_data_{suffix}"
    valid_file_name = data_args.valid_file_name
    if valid_file_name is None:
        valid_file_name = f"valid_data_{suffix}"
    return os.path.join("data", train_file_name), os.path.join("data", valid_file_name)


def main():
    """Build and cache the tokenized train/valid datasets and the tokenizer.

    Parses DataTrainingArguments from the command line, filters both dataset
    splits to the requested task, tokenizes them with DataProcessor, saves them
    with torch.save under ``data/``, and writes the tokenizer to
    ``{model_type}_qg_tokenizer``.

    Raises:
        ValueError: for an unknown ``task`` or ``model_type``.
    """
    parser = HfArgumentParser((DataTrainingArguments,))
    data_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

    # Fail before any model/dataset download rather than on the filter lookup.
    if data_args.task not in TASK_TO_FILTER_FN:
        raise ValueError(f"task must be one of {sorted(TASK_TO_FILTER_FN)}, got {data_args.task!r}")

    tokenizer = _load_tokenizer(data_args.model_type)

    train_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.TRAIN)
    valid_dataset = nlp.load_dataset(data_args.dataset_path, name=data_args.qg_format, split=nlp.Split.VALIDATION)

    processor = DataProcessor(
        tokenizer,
        model_type=data_args.model_type,
        max_source_length=data_args.max_source_length,
        max_target_length=data_args.max_target_length
    )

    train_dataset, valid_dataset = _filter_by_task(train_dataset, valid_dataset, data_args)

    train_dataset = processor.process(train_dataset)
    valid_dataset = processor.process(valid_dataset)

    columns = ["source_ids", "target_ids", "attention_mask"]
    train_dataset.set_format(type='torch', columns=columns)
    valid_dataset.set_format(type='torch', columns=columns)

    train_path, valid_path = _resolve_output_paths(data_args)
    # torch.save does not create missing parent directories.
    for path in (train_path, valid_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    torch.save(train_dataset, train_path)
    logger.info("saved train dataset at %s", train_path)
    torch.save(valid_dataset, valid_path)
    logger.info("saved validation dataset at %s", valid_path)

    tokenizer_path = f"{data_args.model_type}_qg_tokenizer"
    os.makedirs(tokenizer_path, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_path)
    logger.info("saved tokenizer at %s", tokenizer_path)


if __name__ == "__main__":
    main()