import datetime
import getpass
import logging
import math
import os
import struct
import typing
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn

LOG = logging.getLogger(__name__)

def masked_mean(values, mask):
    assert mask.dtype == torch.bool
    assert values.shape == mask.shape
    return (values * mask.float()).sum() / mask.sum().float()

def mask_hf_labels(labels, null_token=0):
    # -100 is the Hugging Face "ignore" label index; replace it with `null_token`.
    valid_mask = labels != -100
    valid_labels = labels.masked_fill(~valid_mask, null_token)
    return valid_mask, valid_labels

def gather_log_probs(logits, labels):
    assert labels.dim() == logits.dim() - 1
    assert labels.shape == logits.shape[:-1]
    return logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

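# Usage sketch (illustrative shapes): per-token log-likelihoods from a
# language-model output.
#   logits = torch.randn(2, 5, 100)               # [batch, seq, vocab]
#   labels = torch.randint(0, 100, (2, 5))        # [batch, seq]
#   token_lp = gather_log_probs(logits, labels)   # [batch, seq]
#   mean_lp = masked_mean(token_lp, labels != -100)
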
def off_diagonal(mat):
    assert mat.dim() == 2
    # assert mat.shape[0] == mat.shape[1]
    mask = ~torch.eye(max(mat.shape), dtype=torch.bool)
    mask = mask[:mat.shape[0], :mat.shape[1]]
    off_d = mat[mask]
    assert off_d.numel() == mat.shape[0] * mat.shape[1] - min(mat.shape)
    return off_d

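# Usage sketch: off_diagonal flattens a (possibly rectangular) matrix while
# dropping its main-diagonal entries, e.g.
#   off_diagonal(torch.arange(6.).reshape(2, 3))  # -> tensor([1., 2., 3., 5.])
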
def set_dropout(model, p):
    if p is not None:
        n_reset = 0
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = p
                n_reset += 1

            if hasattr(m, "dropout"):  # Required for BART, which uses F.dropout
                if isinstance(m.dropout, float):
                    m.dropout = p
                    n_reset += 1

            if hasattr(m, "activation_dropout"):  # Required for BART, which uses F.dropout
                if isinstance(m.activation_dropout, float):
                    m.activation_dropout = p
                    n_reset += 1

        LOG.info(f"Set {n_reset} dropout modules to p={p}")

def _inner_params(named_parameters, inner_names):
    param_dict = dict(named_parameters)
    return [(n, param_dict[n]) for n in inner_names]

def shift_targets(config):
    # Decoder-only LMs need their targets shifted; seq2seq models (T5, BlenderBot) do not.
    return "t5" not in config.model.name.lower() and "blender" not in config.model.name.lower()

# https://stackoverflow.com/questions/32871539/integer-factorization-in-python
def factorization(n):
    return [(i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0]

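# Usage sketch: factor pairs (i, n // i) with i <= sqrt(n), e.g.
#   factorization(12)  # -> [(1, 12), (2, 6), (3, 4)]
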
def scr():
    if os.path.exists("/scr-ssd"):
        scr_dir = "/scr-ssd/" + getpass.getuser()
    else:
        scr_dir = "/scr/" + getpass.getuser()

    if not os.path.exists(scr_dir):
        os.makedirs(scr_dir)

    return scr_dir

def uuid(digits=4):
    # Cache a random run identifier on first call so later calls return the same value.
    if not hasattr(uuid, "uuid_value"):
        uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)

    return uuid.uuid_value

def formatted_timestamp(time=None):
    if time is None:
        time = datetime.datetime.now()
    return time.strftime("%d/%m/%Y-%H:%M:%S/%f")

def time_delta_seconds(start, finish=None):
    assert isinstance(start, str)
    t1 = datetime.datetime.strptime(start, "%d/%m/%Y-%H:%M:%S/%f")

    if finish is not None:
        assert isinstance(finish, str)
        t2 = datetime.datetime.strptime(finish, "%d/%m/%Y-%H:%M:%S/%f")
    else:
        t2 = datetime.datetime.now()

    return (t2 - t1).total_seconds()

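# Usage sketch: round-trip a timestamp through the shared format.
#   start = formatted_timestamp()
#   ...  # do some work
#   elapsed = time_delta_seconds(start)  # seconds elapsed since `start`
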
def dict_to(d, device):
    new_dict = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            new_dict[k] = v.to(device)
        elif isinstance(v, dict):
            new_dict[k] = dict_to(v, device)
        else:
            new_dict[k] = v
    return new_dict

def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=False):
    if backward:
        (loss / accumulate).backward()
    else:
        parameters = list(parameters)  # Capture the generator output
        grads = torch.autograd.grad(loss, parameters, allow_unused=allow_unused)
        nan, inf = False, False
        for g in grads:
            if g is not None:
                nan |= g.isnan().any().item()
                inf |= g.isinf().any().item()

        if not (nan or inf):
            # Accumulate gradients manually so a bad batch can be skipped entirely
            for p, g in zip(parameters, grads):
                if g is None:
                    continue

                if p.grad is None:
                    p.grad = g / accumulate
                else:
                    p.grad += g / accumulate
        else:
            LOG.info(f"Skipping grad accumulation because inf: {inf} nan: {nan}")

def _logits(x):
    if hasattr(x, "logits"):
        return x.logits
    elif hasattr(x, "scores"):
        return torch.cat(x.scores).unsqueeze(0)

    return x

def _last_encoder_state(x):
    if hasattr(x, "encoder_last_hidden_state"):
        return x.encoder_last_hidden_state
    elif hasattr(x, "encoder_hidden_states"):
        return x.encoder_hidden_states[-1]
    else:
        return x.hidden_states[-1]

def load_archive(path):
    if not os.path.exists(path):
        # We haven't been passed an explicit path, just part of a run directory name
        wd = '/iris/u/clin/code/efk/'
        directories = ["outputs", "multirun"]
        matches = []
        for d in directories:
            search = os.path.join(wd, d)
            for run_dir in os.listdir(search):
                if path in run_dir:
                    matches.append(os.path.join(search, run_dir))
        assert len(matches) == 1, f"Expected exactly one match for search '{path}', got {len(matches)}; specify an exact path"
        full_run_dir = matches[0]
        if "0" in os.listdir(full_run_dir):
            full_run_dir = os.path.join(full_run_dir, "0")
        models_dir = os.path.join(full_run_dir, "models")
        models = os.listdir(models_dir)
        non_bk = [m for m in models if not m.endswith(".bk")]
        assert (
            len(non_bk) == 1
        ), f"Expected a single model in {models_dir}, got {len(non_bk)}"
        path = os.path.join(models_dir, non_bk[0])

    LOG.info(f"Loading checkpoint from {path}")
    archive = torch.load(path, map_location="cpu")
    LOG.info("Load complete.")

    return archive, path

def flatten_dict(d):
    to_process = list(d.items())
    output = {}
    while len(to_process):
        k, v = to_process.pop()
        if isinstance(v, typing.MutableMapping):
            to_process.extend([(f"{k}.{k_}", v_) for (k_, v_) in v.items()])
        else:
            assert k not in output.keys(), "Somehow ended up with duplicate keys"
            output[k] = v

    return output

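# Usage sketch: nested keys are joined with '.', e.g.
#   flatten_dict({"model": {"name": "t5-small", "lr": 1e-4}})
#   # -> {"model.name": "t5-small", "model.lr": 1e-4}
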
def add_padding(tokenizer, model):
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
    # Initialize the new [PAD] embedding to the mean of the existing embeddings
    # (assumes a GPT-2-style model exposing model.transformer.wte)
    model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0)


def add_sep(tokenizer, model):
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # model.resize_token_embeddings(len(tokenizer))
    # model.lm_head.weight.data[-1, :] = model.lm_head.weight.data.mean(0)

class EarlyStopper:
    def __init__(self, patience: int, key: str, minimize: bool = False):
        self.best_value = 1e9 if minimize else -1e9
        self.best_iter = 0
        self.current_iter = 0
        self.key = key
        self.patience = patience
        self.minimize = minimize
        self._stop = False

    def update(self, idx, stats):
        assert self.key in stats, f"'{self.key}' not in stats dict"
        value = stats[self.key]
        new_best = value < self.best_value if self.minimize else value > self.best_value
        if new_best:
            self.best_value = value
            self.best_iter = idx

        self.current_iter = idx
        return new_best

    def should_stop(self):
        self._stop |= self.current_iter - self.best_iter >= self.patience
        return self._stop

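# Usage sketch (the stat key and `all_step_stats` below are illustrative):
#   stopper = EarlyStopper(patience=1000, key="loss/total", minimize=True)
#   for step, stats in enumerate(all_step_stats):
#       stopper.update(step, stats)
#       if stopper.should_stop():
#           break
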
class RunningStatAverager:
    def __init__(self, suffix="", exclude=["grad/"], compute_ppl: bool = True):
        self.underlying = None
        self.suffix = suffix
        self.exclude = exclude
        self.compute_ppl = compute_ppl

        self.reset()

    def add(self, d: dict):
        for k, v in d.items():
            if not any([k.startswith(prefix) for prefix in self.exclude]):
                if len(self.suffix):
                    self.underlying[f"{k}_{self.suffix}"].append(v)
                else:
                    self.underlying[k].append(v)

    def average(self):
        average = {}
        for k, v in self.underlying.items():
            if not k.startswith("nll/"):
                average[k] = sum(v) / len(v)
            else:
                # NLL stats are weighted by token counts rather than simply averaged
                assert len(k.split("/")) == 2, f"Invalid key {k}"
                name = k.split("/")[1]
                token_counts = self.underlying[f"n_tokens/{name}"]
                total_nll = sum([nll * c for nll, c in zip(v, token_counts)])
                average[k] = total_nll / sum(token_counts)
                if self.compute_ppl:
                    average[f"perplexity/{name}"] = math.e ** average[k]

        return {k: v if not isinstance(v, torch.Tensor) else v.item() for k, v in average.items()}

    def reset(self):
        self.underlying = defaultdict(list)

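# Usage sketch: per-step stats are accumulated, then averaged; "nll/<name>" keys
# are token-weighted using the matching "n_tokens/<name>" entries.
#   averager = RunningStatAverager(suffix="val")
#   averager.add({"acc/edit": 0.75, "nll/base": 2.1, "n_tokens/base": 128})
#   averager.add({"acc/edit": 0.85, "nll/base": 1.9, "n_tokens/base": 96})
#   results = averager.average()  # includes "perplexity/base_val"
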
class EditBatchSampler:
    def __init__(
        self,
        n,
        memorize_mode=False,
        loc_disjoint=True,
        seed=0,
        hard_neg=False,
        hard_neg_prob=1.0,
        loc_distr_matrix=None,
        loc_idx_matrix=None,
        keep_probs=None,
        mutex=None
    ):
        self.memorize_mode = memorize_mode
        self.n = n
        self.loc_disjoint = loc_disjoint
        self.rng = np.random.default_rng(seed)
        self.hard_neg = hard_neg
        self.hard_neg_prob = hard_neg_prob
        self.loc_probs = loc_distr_matrix
        self.loc_idxs = loc_idx_matrix
        self.keep_probs = np.array(keep_probs)[:self.n] if keep_probs is not None else None
        self.mutex = mutex[:self.n] if mutex is not None else None

        self._init()
    def _init(self):
        idxs = np.arange(self.n)
        if self.keep_probs is not None:
            sample = self.rng.binomial(1, self.keep_probs).astype(bool)
            idxs = idxs[sample]

        self.perm = self.rng.permutation(idxs)
        self.edit_position = 0
    def get_edit_idxs(self, batch_size):
        if self.mutex is None:
            idxs = set([int(idx) for idx in self.perm[self.edit_position: self.edit_position + batch_size]])
            self.edit_position += batch_size
        else:
            mutexes = []
            idxs = []

            def notin(x, mutexes):
                for m in mutexes:
                    if x in m or m in x:
                        return False
                return True

            while len(idxs) < batch_size:
                new_idx = self.perm[self.edit_position]
                if notin(self.mutex[new_idx], mutexes):
                    mutexes.append(self.mutex[new_idx])
                    idxs.append(int(new_idx))
                self.edit_position += 1
                if self.edit_position == self.perm.shape[0]:
                    return None

            idxs = set(idxs)

        return idxs
    def sample(self, batch_size, return_hard_flag=False):
        if self.memorize_mode:
            return list(range(batch_size)), list(range(batch_size, batch_size * 2))

        if self.edit_position + batch_size >= self.perm.shape[0]:
            self._init()  # Re-start if we end with a partially-sized batch

        edit_idxs = self.get_edit_idxs(batch_size)
        if edit_idxs is None:
            self._init()
            edit_idxs = self.get_edit_idxs(batch_size)
            if edit_idxs is None:
                raise RuntimeError(f"No valid batches of size {batch_size} exist!")

        if self.hard_neg:
            assert self.loc_probs is not None, "hard_neg is on, but don't have distance matrix!"

        def get_loc_idxs():
            if self.hard_neg and self.rng.uniform() < self.hard_neg_prob:
                return [int(self.rng.choice(self.loc_idxs[idx], p=self.loc_probs[idx])) for idx in edit_idxs], True
            else:
                # Use deterministic implementation in case edit batches are large
                non_edit_idxs = list(set(range(self.n)) - set(edit_idxs))
                return [int(idx) for idx in self.rng.choice(non_edit_idxs, batch_size)], False

        loc_idxs, hard = get_loc_idxs()
        if self.loc_disjoint:
            steps = 0
            while len(edit_idxs.intersection(set(loc_idxs))) > 0:
                loc_idxs, hard = get_loc_idxs()
                steps += 1
                if steps > 100:
                    raise RuntimeError("Can't find disjoint loc_idxs and edit_idxs!")

        if return_hard_flag:
            return list(edit_idxs), loc_idxs, hard
        else:
            return list(edit_idxs), loc_idxs

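# Usage sketch (assumes `n_examples` is the size of the edit dataset):
#   sampler = EditBatchSampler(n_examples, loc_disjoint=True, seed=0)
#   edit_idxs, loc_idxs = sampler.sample(batch_size=10)
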
def parent_module(model, pname):
    comps = pname.split('.')
    parent = model
    for comp in comps[:-1]:
        if hasattr(parent, comp):
            parent = getattr(parent, comp)
        elif comp.isdigit():
            parent = parent[int(comp)]
        else:
            raise RuntimeError(f"Couldn't find child module {comp}")
    assert hasattr(parent, comps[-1])
    return parent

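# Usage sketch: fetch the module owning a named parameter, e.g. to edit a weight
# in place (the parameter name below assumes a GPT-2-style model):
#   pname = "transformer.h.0.mlp.c_fc.weight"
#   mod = parent_module(model, pname)  # -> model.transformer.h[0].mlp.c_fc
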
def build_distr_matrix(edit_qs, config, loc_qs=None, slice_size=1000):
    n = len(edit_qs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_neighbors = config.data.hard_neg_neighbors
    num_exclude = config.data.hard_neg_exclude
    temp = config.data.hard_neg_temp

    # Imported here so sentence-transformers is only required when hard negatives are used
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import pytorch_cos_sim

    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=scr()).to(device)
    ind_matrix = torch.zeros((n, num_neighbors - num_exclude), dtype=torch.long)
    distr_matrix = torch.full((n, num_neighbors - num_exclude), float('nan'))
    edit_encodings = torch.FloatTensor(embedding_model.encode(edit_qs, batch_size=256)).to(device)
    # If loc_qs is None, build the similarity matrix between edit_qs and itself
    loc_encodings = edit_encodings if loc_qs is None else embedding_model.encode(loc_qs, batch_size=256)
    if isinstance(loc_encodings, np.ndarray):
        loc_encodings = torch.FloatTensor(loc_encodings).to(device)

    for idx in range(0, n, slice_size):
        end_idx = idx + slice_size if idx + slice_size <= n else n
        slice_encodings = edit_encodings[idx:end_idx]
        sim_rows = pytorch_cos_sim(slice_encodings, loc_encodings)
        indices = sim_rows.topk(num_neighbors, -1).indices[:, num_exclude:]
        ind_matrix[idx:end_idx] = indices.cpu()
        distr_matrix[idx:end_idx] = sim_rows.gather(-1, indices).mul(temp).exp().cpu()

    assert not torch.isnan(distr_matrix).any()
    LOG.info(f"Built hard negative distribution matrix of size {distr_matrix.shape}")
    distr_matrix = distr_matrix.numpy()
    distr_matrix = distr_matrix / distr_matrix.sum(-1, keepdims=True)
    return distr_matrix, ind_matrix.numpy()
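
# Usage sketch: feed the hard-negative matrices into EditBatchSampler
# (assumes `edit_qs` is a list of strings and `config.data.hard_neg_*` are set):
#   loc_probs, loc_idxs = build_distr_matrix(edit_qs, config)
#   sampler = EditBatchSampler(
#       len(edit_qs), hard_neg=True,
#       loc_distr_matrix=loc_probs, loc_idx_matrix=loc_idxs,
#   )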