import os
import torch

from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments

from logger_config import logger


@dataclass
class Arguments(TrainingArguments):
    model_name_or_path: str = field(
        default='bert-base-uncased',
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    data_dir: str = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    task_type: str = field(
        default='ir', metadata={"help": "task type: ir / qa"}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics on (a jsonlines file)."
        },
    )
    train_n_passages: int = field(
        default=8,
        metadata={"help": "number of passages for each example (including both positive and negative passages)"}
    )
    share_encoder: bool = field(
        default=True,
        metadata={"help": "share encoder weights between query and passage encoders"}
    )
    use_first_positive: bool = field(
        default=False,
        metadata={"help": "Always use the first positive passage"}
    )
    use_scaled_loss: bool = field(
        default=True,
        metadata={"help": "Use scaled loss or not"}
    )
    loss_scale: float = field(
        default=-1.,
        metadata={"help": "loss scale, -1 will use world_size"}
    )
    add_pooler: bool = field(
        default=False,
        metadata={"help": "add a linear pooler on top of the encoder output"}
    )
    out_dimension: int = field(
        default=768,
        metadata={"help": "output dimension for pooler"}
    )
    t: float = field(default=0.05, metadata={"help": "temperature for biencoder training"})
    l2_normalize: bool = field(default=True, metadata={"help": "L2 normalize embeddings or not"})
    t_warmup: bool = field(default=False, metadata={"help": "warmup temperature"})
    full_contrastive_loss: bool = field(default=True, metadata={"help": "use full contrastive loss or not"})
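    # Illustrative (assumed) use of t and l2_normalize above: with
    # l2_normalize=True, biencoder scores are typically computed on
    # unit-normalized embeddings as
    #   scores = (q_emb @ p_emb.T) / t
    # i.e. temperature-scaled cosine similarity.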
    # following arguments are used for encoding documents
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    encode_in_path: str = field(default=None, metadata={"help": "Path to data to encode"})
    encode_save_dir: str = field(default=None, metadata={"help": "where to save the encoded embeddings"})
    encode_shard_size: int = field(default=int(2 * 10**6), metadata={"help": "number of examples per saved shard"})
    encode_batch_size: int = field(default=256, metadata={"help": "batch size for encoding"})
    # used for index search
    do_search: bool = field(default=False, metadata={"help": "run the index search loop"})
    search_split: str = field(default='dev', metadata={"help": "which split to search"})
    search_batch_size: int = field(default=128, metadata={"help": "query batch size for index search"})
    search_topk: int = field(default=200, metadata={"help": "return topk search results"})
    search_out_dir: str = field(default='', metadata={"help": "output directory for writing search results"})
    # used for reranking
    do_rerank: bool = field(default=False, metadata={"help": "run the reranking loop"})
    rerank_max_length: int = field(default=256, metadata={"help": "max length for rerank inputs"})
    rerank_in_path: str = field(default='', metadata={"help": "Path to predictions for rerank"})
    rerank_out_path: str = field(default='', metadata={"help": "Path to write rerank results"})
    rerank_split: str = field(default='dev', metadata={"help": "which split to rerank"})
    rerank_batch_size: int = field(default=128, metadata={"help": "rerank batch size"})
    rerank_depth: int = field(default=1000, metadata={"help": "rerank depth, useful for debugging purposes"})
    rerank_forward_factor: int = field(
        default=1,
        metadata={"help": "forward n passages, then select top n/factor passages for backward"}
    )
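    # Example (illustrative): with train_n_passages=8 and rerank_forward_factor=2,
    # the re-ranker forwards all 8 passages but only the top 8/2 = 4 of them
    # participate in the backward pass.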
    rerank_use_rdrop: bool = field(default=False, metadata={"help": "use R-Drop regularization for re-ranker"})
    # used for knowledge distillation
    do_kd_gen_score: bool = field(default=False, metadata={"help": "run the score generation for distillation"})
    kd_gen_score_split: str = field(default='dev', metadata={
        "help": "which split to use for generating teacher scores"
    })
    kd_gen_score_batch_size: int = field(default=128, metadata={"help": "batch size for teacher score generation"})
    kd_gen_score_n_neg: int = field(default=30, metadata={"help": "number of negatives to compute teacher scores for"})
    do_kd_biencoder: bool = field(default=False, metadata={"help": "knowledge distillation to biencoder"})
    kd_mask_hn: bool = field(default=True, metadata={"help": "mask out hard negatives for distillation"})
    kd_cont_loss_weight: float = field(default=1.0, metadata={"help": "weight for contrastive loss"})
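    # Presumed loss combination when do_kd_biencoder is set (hypothetical sketch,
    # not confirmed by this file):
    #   loss = kd_loss + kd_cont_loss_weight * contrastive_loss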
    rlm_generator_model_name: Optional[str] = field(
        default='google/electra-base-generator',
        metadata={"help": "generator for replaced LM pre-training"}
    )
    rlm_freeze_generator: Optional[bool] = field(
        default=True,
        metadata={'help': 'freeze generator params or not'}
    )
    rlm_generator_mlm_weight: Optional[float] = field(
        default=0.2,
        metadata={'help': 'weight for generator MLM loss'}
    )
    all_use_mask_token: Optional[bool] = field(
        default=False,
        metadata={'help': 'do not use the 80:10:10 masking scheme; use [MASK] for all masked positions'}
    )
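    # For reference: BERT-style 80:10:10 masking replaces a selected token with
    # [MASK] 80% of the time, a random token 10% of the time, and leaves it
    # unchanged the remaining 10%.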
    rlm_num_eval_samples: Optional[int] = field(
        default=4096,
        metadata={"help": "number of evaluation samples for pre-training"}
    )
    rlm_max_length: Optional[int] = field(
        default=144,
        metadata={"help": "max length for MatchLM pre-training"}
    )
    rlm_decoder_layers: Optional[int] = field(
        default=2,
        metadata={"help": "number of transformer layers for the MatchLM decoder"}
    )
    rlm_encoder_mask_prob: Optional[float] = field(
        default=0.3,
        metadata={'help': 'mask rate for encoder'}
    )
    rlm_decoder_mask_prob: Optional[float] = field(
        default=0.5,
        metadata={'help': 'mask rate for decoder'}
    )
    q_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query."
        },
    )
    p_max_len: int = field(
        default=144,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage."
        },
    )
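    # Illustrative tokenization with the limits above (assumed usage in the
    # training code, not shown in this file):
    #   tokenizer(query, truncation=True, max_length=args.q_max_len)
    #   tokenizer(passage, truncation=True, max_length=args.p_max_len)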
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    dry_run: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set dry_run to True for debugging purposes'}
    )

    def __post_init__(self):
        assert os.path.exists(self.data_dir)
        assert torch.cuda.is_available(), 'Only support running on GPUs'
        assert self.task_type in ['ir', 'qa']

        if self.dry_run:
            # shrink everything so a debugging run finishes quickly
            self.logging_steps = 1
            self.max_train_samples = self.max_train_samples or 128
            self.num_train_epochs = 1
            self.per_device_train_batch_size = min(2, self.per_device_train_batch_size)
            self.train_n_passages = min(4, self.train_n_passages)
            self.rerank_forward_factor = 1
            self.gradient_accumulation_steps = 1
            self.rlm_num_eval_samples = min(256, self.rlm_num_eval_samples)
            self.max_steps = 30
            self.save_steps = self.eval_steps = 30
            logger.warning('Dry run: set logging_steps=1')
        if self.do_encode:
            assert self.encode_save_dir
            os.makedirs(self.encode_save_dir, exist_ok=True)
            assert os.path.exists(self.encode_in_path)

        if self.do_search:
            assert os.path.exists(self.encode_save_dir)
            assert self.search_out_dir
            os.makedirs(self.search_out_dir, exist_ok=True)

        if self.do_rerank:
            assert os.path.exists(self.rerank_in_path)
            logger.info('Rerank result will be written to {}'.format(self.rerank_out_path))
            assert self.train_n_passages > 1, 'Having only positive passages does not make sense for training a re-ranker'
            assert self.train_n_passages % self.rerank_forward_factor == 0
        if self.do_kd_gen_score:
            assert os.path.exists('{}/{}.jsonl'.format(self.data_dir, self.kd_gen_score_split))

        if self.do_kd_biencoder:
            if self.use_scaled_loss:
                assert not self.kd_mask_hn, 'Scaled loss only works when hard negatives are not masked out'

        if torch.cuda.device_count() <= 1:
            self.logging_steps = min(10, self.logging_steps)

        super(Arguments, self).__post_init__()
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        self.label_names = ['labels']
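

# Minimal usage sketch (hypothetical): since Arguments is a dataclass extending
# TrainingArguments, it can be parsed from the command line with HfArgumentParser,
# the standard transformers pattern for TrainingArguments subclasses.
if __name__ == '__main__':
    from transformers import HfArgumentParser

    parser = HfArgumentParser(Arguments)
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    logger.info('Loaded arguments for model: {}'.format(args.model_name_or_path))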