import os
from pprint import pprint
from urllib.parse import urlparse

from tqdm import tqdm
import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def drop_sequence_mask(N, S, device, p=0.1, training=True):
    """Build a random 0/1 keep-mask over a batch of token sequences.

    In training mode every one of the ``N * S`` positions is kept when a
    uniform sample from ``torch.rand`` is strictly greater than ``p``. Then
    one randomly chosen position per row is forced to be kept, so that no
    row is masked out entirely. In eval mode every position is kept.

    Args:
        N: Number of rows (sequences).
        S: Number of columns (sequence length).
        device: Device on which the mask is created.
        p: Drop threshold, used only when ``training`` is true.
        training: Whether to apply random dropping.

    Returns:
        An ``(N, S)`` tensor of dtype ``torch.long`` holding 0/1 values,
        where 1 means "keep".
    """
    if not training:
        # Create directly on the target device instead of allocating on CPU
        # and copying with .to(device).
        return torch.ones((N, S), dtype=torch.long, device=device)

    mask = torch.rand((N, S), device=device) > p
    # Force one random column per row to True so every row keeps >= 1 token.
    mask[torch.arange(N), torch.randint(S, (N,))] = True
    mask = mask.long()
    assert torch.all(torch.sum(mask, dim=1) > 0)
    return mask


def cat_pad(x, cat_dim, pad_dim, pad_val=0):
    """Concatenate tensors along ``cat_dim`` after right-padding ``pad_dim``.

    Each tensor shorter than the longest one along ``pad_dim`` is extended
    at the end of that dimension with ``pad_val``. The padding uses the
    tensor's own dtype and device. All tensors are then concatenated along
    ``cat_dim``.

    The input sequence is left unmodified. Padded tensors are collected in
    a new list, so ``x`` may also be a tuple.

    Args:
        x: Non-empty sequence of tensors. Apart from ``pad_dim``, their
            shapes must already be compatible for ``torch.cat`` along
            ``cat_dim``.
        cat_dim: Dimension to concatenate along.
        pad_dim: Dimension to pad to a common length.
        pad_val: Fill value used for the padding.

    Returns:
        The concatenated tensor.

    Raises:
        ValueError: If ``x`` is empty (raised by ``max``).
    """
    l_max = max(xi.shape[pad_dim] for xi in x)
    padded = []
    for xi in x:
        l_diff = l_max - xi.shape[pad_dim]
        if l_diff > 0:
            shape = list(xi.shape)
            shape[pad_dim] = l_diff
            pad = torch.full(shape, pad_val, dtype=xi.dtype, device=xi.device)
            xi = torch.cat([xi, pad], dim=pad_dim)
        padded.append(xi)
    return torch.cat(padded, dim=cat_dim)


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    # Short-circuit: is_initialized() is not consulted when unavailable.
    return dist.is_available() and dist.is_initialized()


def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """Return True on rank 0 (always True when not running distributed)."""
    return get_rank() == 0


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Args:
        url: URL of the file. NOTE(review): the local path computation
            below assumes ``url`` is a plain string.
        check_hash: Forwarded to ``timm.models.hub.download_cached_file``.
        progress: Forwarded to ``timm.models.hub.download_cached_file``.

    Returns:
        The path ``<timm cache dir>/<basename of the URL path>``.
    """

    def get_cached_file_path():
        # Recompute the cache path on every rank rather than relying on the
        # download call's return value, because non-main ranks skip that call.
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        return os.path.join(timm_hub.get_cache_dir(), filename)

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        # Rank 0 only reaches this barrier after the download call returns,
        # so the other ranks cannot proceed before the file is in place.
        # NOTE(review): if the download raises on rank 0, the remaining
        # ranks may block here until the process-group timeout.
        dist.barrier()

    return get_cached_file_path()


def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
    """Save a reduced copy of a checkpoint containing only selected params.

    The function performs these steps:
      1. Load ``ckpt_input`` onto CPU and read its ``"module"`` mapping.
      2. Keep only the entries whose name contains one of the default
         substrings (``llm_proj``, ``knwl``, ``qformer``, ``ln_vision``,
         ``query_tokens``) or one of ``extra_keys``.
      3. Strip the first two dot-separated components from each kept name.
      4. Print the kept names.
      5. Save the resulting dict to ``ckpt_output``.

    Args:
        ckpt_input: Source accepted by ``torch.load``.
        ckpt_output: Destination accepted by ``torch.save``.
        extra_keys: Additional name substrings to keep. This may be any
            iterable of strings, or a single string.
    """
    if isinstance(extra_keys, str):
        # A single string would otherwise be split into its characters.
        extra_keys = (extra_keys,)
    kept_keys = (
        "llm_proj", "knwl", "qformer", "ln_vision", "query_tokens",
    ) + tuple(extra_keys)

    # SECURITY: torch.load unpickles the file; only use on trusted checkpoints.
    ckpt = torch.load(ckpt_input, map_location="cpu")
    # NOTE(review): dropping the first two name components presumably removes
    # a wrapper prefix (e.g. "module.<model>.") -- confirm against the writer.
    # Distinct names that collide after stripping silently overwrite each other.
    ckpt = {
        ".".join(name.split(".")[2:]): value
        for name, value in tqdm(ckpt["module"].items(), dynamic_ncols=True)
        if any(key in name for key in kept_keys)
    }
    print("Kept params:")
    pprint(list(ckpt.keys()))
    torch.save(ckpt, ckpt_output)