|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
import json |
|
import linecache |
|
import math |
|
import os |
|
import pickle |
|
import socket |
|
from logging import getLogger |
|
from pathlib import Path |
|
from typing import Callable, Dict, Iterable, List, Tuple, Union |
|
|
|
import git |
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
from rouge_score import rouge_scorer, scoring |
|
from sacrebleu import corpus_bleu |
|
from torch import nn |
|
from torch.utils.data import Dataset, Sampler |
|
|
|
from sentence_splitter import add_newline_to_end_of_each_sentence |
|
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer |
|
from transformers.file_utils import cached_property |
|
from transformers.models.bart.modeling_bart import shift_tokens_right |
|
|
|
|
|
# Optional dependency: fairseq provides `batch_by_size`, which is only needed by
# AbstractSeq2SeqDataset.make_dynamic_sampler. FAIRSEQ_AVAILABLE records whether
# the import succeeded so that method can assert on it before using batch_by_size.
try:
    from fairseq.data.data_utils import batch_by_size

    FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    FAIRSEQ_AVAILABLE = False
|
|
|
|
|
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed negative log-likelihood loss (adapted from fairseq).

    Args:
        lprobs: log-probabilities whose last dimension indexes the vocabulary.
        target: gold token ids, either one rank lower than ``lprobs`` or the
            same rank with a trailing singleton dimension.
        epsilon: label-smoothing factor; ``1 - epsilon`` weights the NLL term and
            ``epsilon / vocab_size`` weights the sum over the whole vocabulary.
        ignore_index: target id whose positions contribute nothing to either
            term. ``None`` disables masking.

    Returns:
        ``(loss, nll_loss)``: the smoothed loss and the plain NLL, both summed
        (not averaged) over all non-ignored positions.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # ignore_index (default -100) is usually not a valid vocabulary id, so
        # gathering at it raises an index error. Gather at id 0 for those
        # positions instead; they are zeroed out just below anyway.
        safe_target = target.masked_fill(pad_mask, 0)
        nll_loss = -lprobs.gather(dim=-1, index=safe_target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    # Summed over positions, not averaged.
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    # The smoothing mass epsilon is spread uniformly over the vocabulary.
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
|
|
|
|
|
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to every element of ``x`` and collect the results in a list."""
    return [f(item) for item in x]
|
|
|
|
|
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
    """Corpus-level BLEU computed with sacrebleu's ``corpus_bleu``.

    ``refs_lns`` is passed as a single reference stream (wrapped in a list).
    Extra keyword arguments are forwarded to ``corpus_bleu``.

    Returns:
        ``{"bleu": score}`` with the score rounded to 4 decimal places.
    """
    bleu_result = corpus_bleu(output_lns, [refs_lns], **kwargs)
    rounded_score = round(bleu_result.score, 4)
    return {"bleu": rounded_score}
|
|
|
|
|
def build_compute_metrics_fn(
    task_name: str, tokenizer: PreTrainedTokenizer
) -> Callable[[EvalPrediction], Dict]:
    """Build a ``compute_metrics`` callback for seq2seq evaluation.

    Returns a ROUGE-based metric function when ``task_name`` contains
    "summarization", otherwise a BLEU-based one. Both also report ``gen_len``:
    the mean count of non-pad tokens per prediction, rounded to 1 decimal.
    """

    def non_pad_len(tokens: np.ndarray) -> int:
        # Number of positions that are not the tokenizer's pad id.
        return np.count_nonzero(tokens != tokenizer.pad_token_id)

    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
        # Decode predictions and labels into whitespace-stripped strings.
        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
        # -100 marks ignored label positions and is not a real token id. Replace
        # it with the pad id in a copy so the caller's ``pred.label_ids`` is
        # left untouched.
        label_ids = np.where(
            pred.label_ids != -100, pred.label_ids, tokenizer.pad_token_id
        )
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        return lmap(str.strip, pred_str), lmap(str.strip, label_str)

    def summarization_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        rouge: Dict = calculate_rouge(pred_str, label_str)
        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        rouge.update({"gen_len": summ_len})
        return rouge

    def translation_metrics(pred: EvalPrediction) -> Dict:
        pred_str, label_str = decode_pred(pred)
        bleu: Dict = calculate_bleu(pred_str, label_str)
        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
        bleu.update({"gen_len": gen_len})
        return bleu

    compute_metrics_fn = (
        summarization_metrics if "summarization" in task_name else translation_metrics
    )
    return compute_metrics_fn
|
|
|
|
|
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Drop every column of ``input_ids`` that holds only ``pad_token_id``.

    Args:
        input_ids: tensor whose first dimension indexes examples; columns are
            kept if any example has a non-pad token there.
        pad_token_id: id treated as padding.
        attention_mask: optional tensor trimmed with the same column mask.

    Returns:
        The trimmed ``input_ids``, or a ``(input_ids, attention_mask)`` tuple
        when ``attention_mask`` is provided.
    """
    non_pad = input_ids.ne(pad_token_id)
    columns_to_keep = non_pad.any(dim=0)
    trimmed_ids = input_ids[:, columns_to_keep]
    if attention_mask is None:
        return trimmed_ids
    return trimmed_ids, attention_mask[:, columns_to_keep]
|
|
|
|
|
class AbstractSeq2SeqDataset(Dataset):
    """Base dataset over line-aligned ``<type_path>.source`` / ``<type_path>.target`` files.

    Subclasses implement ``__getitem__`` and ``collate_fn``. ``src_lens`` (one
    entry per example) drives the sortish and dynamic samplers. It is loaded
    from ``<type_path>.len`` when that file exists, otherwise it is computed
    from the source file with ``get_char_lens``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        prefix="",
        **dataset_kwargs,
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.len_file = Path(data_dir).joinpath(type_path + ".len")
        if os.path.exists(self.len_file):
            # Precomputed lengths, unpickled via pickle_load.
            self.src_lens = pickle_load(self.len_file)
            self.used_char_len = False
        else:
            # Fallback: line lengths in characters.
            self.src_lens = self.get_char_lens(self.src_file)
            self.used_char_len = True
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # NOTE(review): get_char_lens counts the trailing newline, so a blank
        # line has length 1 and this check does not catch it on the
        # character-length path.
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix if prefix is not None else ""

        if n_obs is not None:
            # Restrict the dataset to its first n_obs examples.
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        # Extra keyword arguments that subclasses forward to the tokenizer.
        # self.dataset_kwargs aliases the same dict, so the BART-specific
        # update below is visible through it.
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update(
            {"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}
        )

    def __len__(self):
        return len(self.src_lens)

    @staticmethod
    def get_char_lens(data_file):
        """Return the character length of every line in ``data_file``.

        Lengths include the trailing newline, so a blank line counts as 1.
        """
        # Use a context manager so the file handle is closed deterministically.
        with Path(data_file).open() as f:
            return [len(line) for line in f]

    @cached_property
    def tgt_lens(self):
        """Length in characters of target documents"""
        return self.get_char_lens(self.tgt_file)

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
        """Return a length-aware sampler (DistributedSortishSampler if ``distributed``)."""
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        """Group indices into batches of at most ``max_tokens_per_batch`` tokens.

        Requires fairseq and a precomputed ``.len`` file. Batches are shuffled,
        and the batch with the largest (approximate) token count is moved first.
        """
        assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
        assert (
            not self.used_char_len
        ), "You must call python make_len_file.py before calling make_dynamic_sampler"
        # Deterministic length-descending order (shuffle=False).
        sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))

        def num_tokens_in_example(i):
            # NOTE(review): source lengths are capped by max_target_length here,
            # not max_source_length; confirm this is intended.
            return min(self.src_lens[i], self.max_target_length)

        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_in_example,
            max_tokens=max_tokens_per_batch,
            required_batch_size_multiple=64,
        )
        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]

        # Approximate padded size: longest example times batch size.
        approximate_toks_per_batch = [
            max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches
        ]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )
        return shuffled_batches

    def __getitem__(self, item):
        raise NotImplementedError("You must implement this")

    def collate_fn(self, batch):
        raise NotImplementedError("You must implement this")
|
|
|
|
|
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
    """Dataset that tokenizes each example eagerly in ``__getitem__``.

    Items are padded to the configured max lengths. ``collate_fn`` stacks them
    and removes columns that are padding in every example of the batch.
    """

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Call tokenizer on src and tgt_lines"""
        index = index + 1  # linecache line numbers start at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

        # encode_line tokenizes a batch of one; squeeze drops size-1 dimensions.
        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Tokenize ``line`` as a batch of one (only used by LegacySeq2SeqDataset).

        Always truncates to ``max_length``. Pads to ``max_length`` when
        ``pad_to_max_length`` is true, otherwise applies no padding.
        """
        return tokenizer(
            [line],
            max_length=max_length,
            # ``False`` is the tokenizer's "do not pad" value. ``None`` is not
            # one of the accepted padding values.
            padding="max_length" if pad_to_max_length else False,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack per-example tensors and drop all-pad columns."""
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["labels"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "labels": y,
        }
        return batch
|
|
|
|
|
class Seq2SeqDataset(AbstractSeq2SeqDataset):
    """Dataset yielding raw text; tokenization happens per batch in ``collate_fn``
    via the tokenizer's ``prepare_seq2seq_batch``."""

    def __getitem__(self, index) -> Dict[str, str]:
        # Keep the 0-based position as the example id; linecache counts from 1.
        example_id = index
        index += 1
        src_path = str(self.src_file)
        tgt_path = str(self.tgt_file)
        raw_source = linecache.getline(src_path, index).rstrip("\n")
        source_line = self.prefix + raw_source
        tgt_line = linecache.getline(tgt_path, index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": example_id}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize a list of ``__getitem__`` outputs and attach their ids."""
        sources = [example["src_texts"] for example in batch]
        targets = [example["tgt_texts"] for example in batch]
        encoded = self.tokenizer.prepare_seq2seq_batch(
            sources,
            tgt_texts=targets,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        batch_encoding: Dict[str, torch.Tensor] = encoded.data
        batch_encoding["ids"] = torch.tensor([example["id"] for example in batch])
        return batch_encoding
|
|
|
|
|
class Seq2SeqDataCollator:
    """Collate examples into a batch with ``decoder_input_ids``.

    Handles two kinds of examples:
      * raw text (``src_texts``/``tgt_texts`` keys), tokenized here when the
        tokenizer has ``prepare_seq2seq_batch``;
      * pre-tokenized tensors (``input_ids``/``attention_mask``/``labels`` keys),
        which are stacked and stripped of columns that are padding in every row.
    """

    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        assert (
            self.pad_token_id is not None
        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
        # Read for max_source_length, max_target_length, src_lang and tgt_lang.
        self.data_args = data_args
        # When set, _encode pads to max_length instead of the longest example.
        self.tpu_num_cores = tpu_num_cores
        # Extra keyword arguments forwarded to prepare_seq2seq_batch.
        self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
        if data_args.src_lang is not None:
            self.dataset_kwargs["src_lang"] = data_args.src_lang
        if data_args.tgt_lang is not None:
            self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang

    def __call__(self, batch) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and ``labels``."""
        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
            batch = self._encode(batch)
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
            # Deliberately not trimmed. With tpu_num_cores set, _encode pads to
            # max_length (presumably for fixed shapes), and trimming all-pad
            # columns would undo that.
        else:
            input_ids = torch.stack([x["input_ids"] for x in batch])
            attention_mask = torch.stack([x["attention_mask"] for x in batch])
            labels = torch.stack([x["labels"] for x in batch])
            # Stacked examples may share trailing pad columns; drop them.
            labels = trim_batch(labels, self.pad_token_id)
            input_ids, attention_mask = trim_batch(
                input_ids, self.pad_token_id, attention_mask=attention_mask
            )

        if isinstance(self.tokenizer, T5Tokenizer):
            decoder_input_ids = self._shift_right_t5(labels)
        else:
            decoder_input_ids = shift_tokens_right(
                labels, self.pad_token_id, self.decoder_start_token_id
            )

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels,
        }
        return batch

    def _shift_right_t5(self, input_ids):
        """Shift ``input_ids`` one step right along the last dim, putting pad_token_id first."""
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.pad_token_id
        return shifted_input_ids

    def _encode(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize raw ``src_texts``/``tgt_texts`` examples with prepare_seq2seq_batch."""
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            tgt_texts=[x["tgt_texts"] for x in batch],
            max_length=self.data_args.max_source_length,
            max_target_length=self.data_args.max_target_length,
            padding="max_length" if self.tpu_num_cores is not None else "longest",
            return_tensors="pt",
            **self.dataset_kwargs,
        )
        return batch_encoding.data
|
|
|
|
|
class SortishSampler(Sampler):
    """Yield indices roughly ordered by length with some randomness (after fastai).

    ``data`` holds one sort key (length) per example.
    """

    def __init__(self, data, batch_size, shuffle=True):
        self.data = data
        self.bs = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        ordering = sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)
        return iter(ordering)
|
|
|
|
|
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
    """Order example indices "sortishly" by their keys in ``data`` (after fastai).

    Args:
        data: one sort key (e.g. length) per example.
        bs: batch size used to form chunks.
        shuffle: if False, return indices sorted by descending key. If True:
            randomly permute, sort descending within mega-chunks of ``bs * 50``
            indices, split into ``bs``-sized chunks, put the chunk whose first
            key is largest at the front, then shuffle the remaining chunks.

    Returns:
        1-D integer array containing each index of ``data`` exactly once.
    """
    if not shuffle:
        return np.argsort(np.array(data) * -1)
    if len(data) == 0:
        # np.concatenate and np.argmax below reject empty input.
        return np.array([], dtype=int)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    mega_size = bs * 50
    mega_chunks = [idxs[i : i + mega_size] for i in range(0, len(idxs), mega_size)]
    sort_idx = np.concatenate([sorted(chunk, key=key_fn, reverse=True) for chunk in mega_chunks])
    ck_idx = [sort_idx[i : i + bs] for i in range(0, len(sort_idx), bs)]
    # Chunks are descending-sorted, so ck[0] is each chunk's largest key; move
    # the chunk with the overall largest key to the front (presumably so
    # memory problems surface on the first batch).
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]
    # Shuffle the remaining chunks by permuting their positions. Permuting the
    # list of chunks directly builds a 2-D array, which fails on current NumPy
    # when the last chunk is shorter than ``bs``.
    rest = ck_idx[1:]
    if rest:
        order = np.random.permutation(len(rest))
        tail = np.concatenate([rest[j] for j in order])
    else:
        tail = np.array([], dtype=int)
    return np.concatenate((ck_idx[0], tail))
|
|
|
|
|
class DistributedSortishSampler(Sampler):
    """Distributed sortish sampler (adapted from torch's DistributedSampler).

    Each replica (``rank``) receives the strided subset
    ``indices[rank::num_replicas]`` and orders it with
    ``sortish_sampler_indices`` over ``dataset.src_lens``. With
    ``add_extra_examples`` the index list is padded by repeating indices so
    that every replica gets the same number of samples.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        add_extra_examples=True,
        shuffle=True,
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            # available_indices only needs dataset, rank, num_replicas and total_size.
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        # NOTE: ordering comes from NumPy's global RNG inside
        # sortish_sampler_indices; self.epoch (see set_epoch) does not seed it.
        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> np.array:
        """Dataset indices assigned to this rank, padded to ``total_size`` first."""
        indices = list(range(len(self.dataset)))
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            # Repeat the index list as often as needed: one copy is too short
            # when num_replicas exceeds the dataset size.
            repeats = math.ceil(padding_size / len(indices))
            indices += (indices * repeats)[:padding_size]
        assert len(indices) == self.total_size

        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
|
|
|
|
|
# Module-level logger; use_task_specific_params reports config updates through it.
logger = getLogger(__name__)
|
|
|
|
|
def use_task_specific_params(model, task):
    """Merge ``model.config.task_specific_params[task]`` into ``model.config``.

    No-op when the config has no ``task_specific_params``. For an unknown
    ``task``, an empty dict is passed to ``config.update`` (after logging).
    """
    params_by_task = model.config.task_specific_params
    if params_by_task is None:
        return
    pars = params_by_task.get(task, {})
    logger.info(f"setting model.config to task specific params for {task}:\n {pars}")
    logger.info("note: command line args may override some of these")
    model.config.update(pars)
|
|
|
|
|
def pickle_load(path):
    """Unpickle and return the object stored in the file at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded
|
|
|
|
|
def pickle_save(obj, path):
    """Pickle ``obj`` into ``path``, overwriting it; returns ``pickle.dump``'s result (None)."""
    with open(path, "wb") as handle:
        outcome = pickle.dump(obj, handle)
    return outcome
|
|
|
|
|
def flatten_list(summary_ids: List[List]):
    """Concatenate the sublists of ``summary_ids`` into one flat list (one level deep)."""
    return list(itertools.chain.from_iterable(summary_ids))
|
|
|
|
|
def save_git_info(folder_path: str) -> None:
    """Write the dict from ``get_git_info()`` as JSON to ``<folder_path>/git_log.json``."""
    destination = os.path.join(folder_path, "git_log.json")
    save_json(get_git_info(), destination)
|
|
|
|
|
def save_json(content, path, indent=4, **json_dump_kwargs):
    """Serialize ``content`` to ``path`` as JSON with sorted keys.

    ``indent`` and any extra keyword arguments are passed through to ``json.dump``.
    """
    with open(path, "w") as out:
        dump_options = dict(indent=indent, sort_keys=True, **json_dump_kwargs)
        json.dump(content, out, **dump_options)
|
|
|
|
|
def load_json(path):
    """Parse the JSON document stored at ``path`` and return it."""
    with open(path) as source:
        parsed = json.load(source)
    return parsed
|
|
|
|
|
def get_git_info():
    """Describe the git repository enclosing the current working directory.

    Returns:
        dict with ``repo_id``, ``repo_sha``, ``repo_branch`` and ``hostname``.
        When the repository cannot be described (not inside a repository, no
        commits yet, or a detached HEAD), every value is ``None``.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        repo_infos = {
            "repo_id": str(repo),
            "repo_sha": str(repo.head.object.hexsha),
            "repo_branch": str(repo.active_branch),
            "hostname": str(socket.gethostname()),
        }
        return repo_infos
    except (
        TypeError,  # repo.active_branch on a detached HEAD
        ValueError,  # repo.head.object when the branch has no commits yet
        git.exc.InvalidGitRepositoryError,  # not inside a git repository
    ):
        return {
            "repo_id": None,
            "repo_sha": None,
            "repo_branch": None,
            "hostname": None,
        }
|
|
|
|
|
# Default set of ROUGE variants computed by calculate_rouge.
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
|
|
|
|
def extract_rouge_mid_statistics(dct):
    """Pull the rounded ``mid`` precision/recall/fmeasure out of aggregated ROUGE results.

    ``dct`` maps each rouge key to an object exposing a ``mid`` attribute (as
    returned by ``BootstrapAggregator.aggregate``). Each value in the result is
    ``{"precision", "recall", "fmeasure"}`` rounded to 4 decimal places.
    """
    stat_names = ("precision", "recall", "fmeasure")
    return {
        rouge_key: {name: round(getattr(aggregate.mid, name), 4) for name in stat_names}
        for rouge_key, aggregate in dct.items()
    }
|
|
|
|
|
def calculate_rouge(
    pred_lns: List[str],
    tgt_lns: List[str],
    use_stemmer=True,
    rouge_keys=ROUGE_KEYS,
    return_precision_and_recall=False,
    bootstrap_aggregation=True,
    newline_sep=True,
) -> Dict:
    """Calculate ROUGE using the rouge_score package.

    Args:
        pred_lns: summaries generated by the model.
        tgt_lns: ground-truth summaries (e.g. contents of val.target), aligned
            with ``pred_lns``. ``zip`` is used, so extra items in the longer
            list are ignored.
        use_stemmer: whether the Porter stemmer strips word suffixes to improve
            matching.
        rouge_keys: metrics to compute; defaults to rouge1, rouge2, rougeL and
            rougeLsum.
        return_precision_and_recall: also return precision and recall (only
            with ``bootstrap_aggregation``).
        bootstrap_aggregation: aggregate per-example scores with bootstrap
            resampling. If False, per-example scores are returned instead.
        newline_sep: put each sentence on its own line before scoring. This is
            essential for rougeLsum on multi-sentence summaries (CNN/DM dataset).

    Returns:
        With ``bootstrap_aggregation``: ``{rouge_key: mid F-measure * 100}``
        rounded to 4 decimals, or ``{rouge_key: {precision, recall, fmeasure}}``
        when ``return_precision_and_recall``. Otherwise a
        ``collections.defaultdict(list)`` mapping each rouge key to the list of
        per-example scores returned by ``RougeScorer.score``.
    """
    from collections import defaultdict

    scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()
    # Per-example scores, collected here instead of reading the aggregator's
    # private ``_scores`` attribute.
    per_example_scores = defaultdict(list)
    for pred, tgt in zip(pred_lns, tgt_lns):
        if newline_sep:
            pred = add_newline_to_end_of_each_sentence(pred)
            tgt = add_newline_to_end_of_each_sentence(tgt)
        # RougeScorer.score takes (target, prediction), in that order.
        scores = scorer.score(tgt, pred)
        if bootstrap_aggregation:
            aggregator.add_scores(scores)
        else:
            for rouge_key, score in scores.items():
                per_example_scores[rouge_key].append(score)

    if not bootstrap_aggregation:
        return per_example_scores

    result = aggregator.aggregate()
    if return_precision_and_recall:
        return extract_rouge_mid_statistics(result)
    return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
|
|
|
|
|
|
|
|
|
|
def freeze_params(model: nn.Module):
    """Turn off gradient tracking for every parameter of ``model``, in place."""
    for param in model.parameters():
        param.requires_grad_(False)
|
|
|
|
|
def freeze_embeds(model):
    """Freeze embedding layers according to ``model.config.model_type``.

    * ``t5`` / ``mt5``: ``model.shared`` and encoder/decoder ``embed_tokens``.
    * ``fsmt``: encoder/decoder ``embed_positions`` and ``embed_tokens``.
    * anything else (assumed BART-like): ``model.model.shared`` plus
      encoder/decoder ``embed_positions`` and ``embed_tokens``.
    """
    model_type = model.config.model_type
    if model_type in ["t5", "mt5"]:
        freeze_params(model.shared)
        for stack in (model.encoder, model.decoder):
            freeze_params(stack.embed_tokens)
        return

    if model_type != "fsmt":
        freeze_params(model.model.shared)
    for stack in (model.model.encoder, model.model.decoder):
        freeze_params(stack.embed_positions)
        freeze_params(stack.embed_tokens)
|
|
|
|
|
def grad_status(model: nn.Module) -> Iterable:
    """Return a lazy iterator of ``requires_grad`` flags, one per parameter of ``model``."""
    params = model.parameters()
    return (param.requires_grad for param in params)
|
|
|
|
|
def any_requires_grad(model: nn.Module) -> bool:
    """Return True if at least one parameter of ``model`` is trainable."""
    flags = grad_status(model)
    return any(flags)
|
|
|
|
|
def assert_all_frozen(model):
    """Fail with AssertionError if any parameter of ``model`` still requires grad.

    The message reports the fraction of trainable parameters.
    """
    model_grads: List[bool] = [*grad_status(model)]
    npars = len(model_grads)
    n_require_grad = model_grads.count(True)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
|
|
|
|
|
def assert_not_all_frozen(model):
    """Fail with AssertionError if no parameter of ``model`` requires grad."""
    flags: List[bool] = [*grad_status(model)]
    npars = len(flags)
    assert any(flags), f"none of {npars} weights require grad"
|
|
|
|
|
def parse_numeric_n_bool_cl_kwargs(
    unparsed_args: List[str],
) -> Dict[str, Union[int, float, bool]]:
    """Convert ``["--name", "value", ...]`` pairs into ``{"name": parsed_value}``.

    "true"/"false" (any case) become booleans. Any other value is parsed as
    ``int`` when possible, otherwise as ``float`` (ValueError if neither works).
    """
    assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"

    def parse_value(raw: str) -> Union[int, float, bool]:
        lowered = raw.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        try:
            return int(raw)
        except ValueError:
            return float(raw)

    result = {}
    for flag, raw in zip(unparsed_args[::2], unparsed_args[1::2]):
        assert flag.startswith("--")
        result[flag[2:]] = parse_value(raw)
    return result
|
|
|
|
|
def write_txt_file(ordered_tgt, path):
    """Write each string in ``ordered_tgt`` to ``path`` as its own line.

    The file is overwritten and closed before returning. The original only
    flushed it and never closed the handle.
    """
    with Path(path).open("w") as f:
        f.writelines(ln + "\n" for ln in ordered_tgt)
|
|
|
|
|
def chunks(lst, n):
    """Lazily yield consecutive ``n``-sized slices of ``lst``; the last may be shorter."""
    yield from (lst[start : start + n] for start in range(0, len(lst), n))
|
|
|
|
|
def check_output_dir(args, expected_items=0):
    """Refuse to train into a non-empty output directory unless overwriting is allowed.

    Raises ValueError when ``args.output_dir`` exists, holds more than
    ``expected_items`` entries, ``args.do_train`` is truthy and
    ``args.overwrite_output_dir`` is falsy.

    Args:
        args: object with ``output_dir``, ``do_train`` and ``overwrite_output_dir``.
        expected_items: number of entries tolerated in the directory; normally
            0 (empty dir), but a few files may be expected, e.g. when recovering
            from an OOM.
    """
    if not os.path.exists(args.output_dir):
        return
    too_many_items = len(os.listdir(args.output_dir)) > expected_items
    if too_many_items and args.do_train and not args.overwrite_output_dir:
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and "
            f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
            "Use --overwrite_output_dir to overcome."
        )
|
|