import datetime import typing import numpy as np import struct import os import getpass import logging import torch import torch.nn as nn from collections import defaultdict import math LOG = logging.getLogger(__name__) def masked_mean(values, mask): assert mask.dtype == torch.bool assert values.shape == mask.shape return (values * mask.float()).sum() / mask.sum().float() def mask_hf_labels(labels, null_token=0): valid_mask = labels != -100 valid_labels = labels.masked_fill(~valid_mask, null_token) return valid_mask, valid_labels def gather_log_probs(logits, labels): assert labels.dim() == logits.dim() - 1 assert labels.shape == logits.shape[:-1] return logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1) def off_diagonal(mat): assert mat.dim() == 2 # assert mat.shape[0] == mat.shape[1] mask = ~torch.eye(max(mat.shape), dtype=torch.bool) mask = mask[:mat.shape[0], :mat.shape[1]] off_d = mat[mask] assert off_d.numel() == mat.shape[0] * mat.shape[1] - min(mat.shape) return off_d def set_dropout(model, p): if p is not None: n_reset = 0 for m in model.modules(): if isinstance(m, nn.Dropout): m.p = p n_reset += 1 if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout if isinstance(m.dropout, float): m.dropout = p n_reset += 1 if hasattr(m, "activation_dropout"): # Requires for BART, which uses F.dropout if isinstance(m.activation_dropout, float): m.activation_dropout = p n_reset += 1 LOG.info(f"Set {n_reset} dropout modules to p={p}") def _inner_params(named_parameters, inner_names): param_dict = dict(named_parameters) return [(n, param_dict[n]) for n in inner_names] def shift_targets(config): return "t5" not in config.model.name.lower() and "blender" not in config.model.name.lower() # https://stackoverflow.com/questions/32871539/integer-factorization-in-python def factorization(n): return [(i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0] def scr(): if os.path.exists("/scr-ssd"): scr_dir = "/scr-ssd/" + getpass.getuser() else: scr_dir = "/scr/" + getpass.getuser() if not os.path.exists(scr_dir): os.makedirs(scr_dir) return scr_dir def uuid(digits=4): if not hasattr(uuid, "uuid_value"): uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits) return uuid.uuid_value def formatted_timestamp(time=None): if time is None: time = datetime.datetime.now() return time.strftime("%d/%m/%Y-%H:%M:%S/%f") def time_delta_seconds(start, finish=None): assert type(start) == str t1 = datetime.datetime.strptime(start, "%d/%m/%Y-%H:%M:%S/%f") if finish is not None: assert type(finish) == str t2 = datetime.datetime.strptime(finish, "%d/%m/%Y-%H:%M:%S/%f") else: t2 = datetime.datetime.now() return (t2 - t1).total_seconds() def dict_to(d, device): new_dict = {} for k, v in d.items(): if isinstance(v, torch.Tensor): new_dict[k] = v.to(device) elif isinstance(v, dict): new_dict[k] = dict_to(v, device) else: new_dict[k] = v return new_dict def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=False): if backward: (loss / accumulate).backward() else: parameters = list(parameters) # Capture the generator output grads = torch.autograd.grad(loss, parameters, allow_unused=allow_unused) nan, inf = False, False for g in grads: if g is not None: nan |= g.isnan().any().item() inf |= g.isinf().any().item() if not (nan or inf): for p, g in zip(parameters, grads): if g is None: continue if p.grad is None: p.grad = g / accumulate else: p.grad += g / accumulate else: LOG.info(f"Skipping grad accumulation because inf: {inf} nan: {nan}") def _logits(x): if hasattr(x, "logits"): return x.logits elif hasattr(x, "scores"): return torch.cat(x.scores).unsqueeze(0) return x def _last_encoder_state(x): if hasattr(x, "encoder_last_hidden_state"): return x.encoder_last_hidden_state elif hasattr(x, "encoder_hidden_states"): return x.encoder_hidden_states[-1] else: return x.hidden_states[-1] def load_archive(path): import torch if not os.path.exists(path): # We've not passed an explicit path, but a part of the filename wd = '/iris/u/clin/code/efk/' directories = ["outputs", "multirun"] matches = [] for d in directories: search = os.path.join(wd, d) for run_dir in os.listdir(search): if path in run_dir: matches.append(os.path.join(search, run_dir)) assert len(matches) == 1, f">1 matches for search {path}; specify exact path" full_run_dir = matches[0] if "0" in os.listdir(full_run_dir): full_run_dir = os.path.join(full_run_dir, "0") models_dir = os.path.join(full_run_dir, "models") models = os.listdir(models_dir) non_bk = [m for m in models if not m.endswith(".bk")] assert ( len(non_bk) == 1 ), f"Expected a single model in {models_dir}, got {len(non_bk)}" path = os.path.join(models_dir, non_bk[0]) LOG.info(f"Loading checkpoint from {path}") archive = torch.load(path, map_location="cpu") LOG.info("Load complete.") return archive, path def flatten_dict(d): to_process = list(d.items()) output = {} while len(to_process): k, v = to_process.pop() if isinstance(v, typing.MutableMapping): to_process.extend([(f"{k}.{k_}", v_) for (k_, v_) in v.items()]) else: assert k not in output.keys(), "Somehow ended up with duplicate keys" output[k] = v return output def add_padding(tokenizer, model): tokenizer.add_special_tokens({'pad_token': '[PAD]'}) model.resize_token_embeddings(len(tokenizer)) model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0) def add_sep(tokenizer, model): tokenizer.add_special_tokens({'sep_token': '[SEP]'}) # model.resize_token_embeddings(len(tokenizer)) # model.lm_head.weight.data[-1, :] = model.lm_head.weight.data.mean(0) class EarlyStopper: def __init__(self, patience: int, key: str, minimize: bool = False): self.best_value = 1e9 if minimize else -1e9 self.best_iter = 0 self.current_iter = 0 self.key = key self.patience = patience self.minimize = minimize self._stop = False def update(self, idx, stats): assert self.key in stats, f"'{self.key}' not in stats dict" value = stats[self.key] new_best = value < self.best_value if self.minimize else value > self.best_value if new_best: self.best_value = value self.best_iter = idx self.current_iter = idx return new_best def should_stop(self): self._stop |= self.current_iter - self.best_iter >= self.patience return self._stop class RunningStatAverager: def __init__(self, suffix="", exclude=["grad/"], compute_ppl: bool = True): self.underlying = None self.suffix = suffix self.exclude = exclude self.compute_ppl = compute_ppl self.reset() def add(self, d: dict): for k, v in d.items(): if not any([k.startswith(prefix) for prefix in self.exclude]): if len(self.suffix): self.underlying[f"{k}_{self.suffix}"].append(v) else: self.underlying[k].append(v) def average(self): average = {} for k, v in self.underlying.items(): if not k.startswith("nll/"): average[k] = sum(v) / len(v) else: assert len(k.split("/")) == 2, f"Invalid key {k}" name = k.split("/")[1] token_counts = self.underlying[f"n_tokens/{name}"] total_nll = sum([nll * c for nll, c in zip(v, token_counts)]) average[k] = total_nll / sum(token_counts) if self.compute_ppl: average[f"perplexity/{name}"] = math.e ** average[k] return {k: v if not isinstance(v, torch.Tensor) else v.item() for k, v in average.items()} def reset(self): self.underlying = defaultdict(list) class EditBatchSampler: def __init__( self, n, memorize_mode=False, loc_disjoint=True, seed=0, hard_neg=False, hard_neg_prob=1.0, loc_distr_matrix=None, loc_idx_matrix=None, keep_probs=None, mutex=None ): self.memorize_mode = memorize_mode self.n = n self.loc_disjoint = loc_disjoint self.rng = np.random.default_rng(seed) self.hard_neg = hard_neg self.hard_neg_prob = hard_neg_prob self.loc_probs = loc_distr_matrix self.loc_idxs = loc_idx_matrix self.keep_probs = np.array(keep_probs)[:self.n] if keep_probs is not None else None self.mutex = mutex[:self.n] if mutex is not None else None self._init() def _init(self): idxs = np.arange(self.n) if self.keep_probs is not None: sample = self.rng.binomial(1, self.keep_probs).astype(np.bool) idxs = idxs[sample] self.perm = self.rng.permutation(idxs) self.edit_position = 0 def get_edit_idxs(self, batch_size): if self.mutex is None: idxs = set([int(idx) for idx in self.perm[self.edit_position: self.edit_position + batch_size]]) self.edit_position += batch_size else: mutexes = [] idxs = [] def notin(x, mutexes): for m in mutexes: if x in m or m in x: return False return True while len(idxs) < batch_size: new_idx = self.perm[self.edit_position] if notin(self.mutex[new_idx], mutexes): mutexes.append(self.mutex[new_idx]) idxs.append(int(new_idx)) self.edit_position += 1 if self.edit_position == self.perm.shape[0]: return None idxs = set(idxs) return idxs def sample(self, batch_size, return_hard_flag=False): if self.memorize_mode: return list(range(batch_size)), list(range(batch_size, batch_size * 2)) if self.edit_position + batch_size >= self.perm.shape[0]: self._init() # Re-start if we end with a partially-sized batch edit_idxs = self.get_edit_idxs(batch_size) if edit_idxs is None: self._init() edit_idxs = self.get_edit_idxs(batch_size) if edit_idxs is None: raise RuntimeError(f"No valid batches of size {batch_size} exist!") if self.hard_neg: assert self.loc_probs is not None, "hard_neg is on, but don't have distance matrix!" def get_loc_idxs(): if self.hard_neg and self.rng.uniform() < self.hard_neg_prob: return [int(self.rng.choice(self.loc_idxs[idx], p=self.loc_probs[idx])) for idx in edit_idxs], True else: # Use deterministic implementation in case edit batches are large non_edit_idxs = list(set(range(self.n)) - set(edit_idxs)) return [int(idx) for idx in self.rng.choice(non_edit_idxs, batch_size)], False loc_idxs, hard = get_loc_idxs() if self.loc_disjoint: steps = 0 while len(edit_idxs.intersection(set(loc_idxs))) > 0: loc_idxs, hard = get_loc_idxs() steps += 1 if steps > 100: raise RuntimeError("Can't find disjoint loc_idxs and edit_idxs!") if return_hard_flag: return list(edit_idxs), loc_idxs, hard else: return list(edit_idxs), loc_idxs def parent_module(model, pname): comps = pname.split('.') parent = model for comp in comps[:-1]: if hasattr(parent, comp): parent = getattr(parent, comp) elif comp.isdigit(): parent = parent[int(comp)] else: raise RuntimeError(f"Couldn't find child module {comp}") assert hasattr(parent, comps[-1]) return parent def build_distr_matrix(edit_qs, config, loc_qs=None, slice_size=1000): n = len(edit_qs) device = "cuda" if torch.cuda.is_available() else "cpu" num_neighbors = config.data.hard_neg_neighbors num_exclude = config.data.hard_neg_exclude temp = config.data.hard_neg_temp from sentence_transformers import SentenceTransformer from sentence_transformers.util import pytorch_cos_sim embedding_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=scr()).to(device) ind_matrix = torch.zeros((n, num_neighbors - num_exclude), dtype=torch.long) distr_matrix = torch.full((n, num_neighbors - num_exclude), float('nan')) edit_encodings = torch.FloatTensor(embedding_model.encode(edit_qs, batch_size=256)).to(device) # If loc_qs is None then build the similarity matrix between edit_qs and itself loc_encodings = edit_encodings if loc_qs is None else embedding_model.encode(loc_qs, batch_size=256) if isinstance(loc_encodings, np.ndarray): loc_encodings = torch.FloatTensor(loc_encodings).to(device) for idx in range(0, n, slice_size): end_idx = idx + slice_size if idx + slice_size <= n else n slice_encodings = edit_encodings[idx:end_idx] sim_rows = pytorch_cos_sim(slice_encodings, loc_encodings) indices = sim_rows.topk(num_neighbors, -1).indices[:, num_exclude:] ind_matrix[idx:end_idx] = indices.cpu() distr_matrix[idx:end_idx] = sim_rows.gather(-1, indices).mul(temp).exp().cpu() assert not torch.isnan(distr_matrix).any() LOG.info(f"Built hard negative distribution matrix of size {distr_matrix.shape}") distr_matrix = distr_matrix.numpy() distr_matrix = distr_matrix / distr_matrix.sum(-1, keepdims=True) return distr_matrix, ind_matrix.numpy()