File size: 7,118 Bytes
8741abe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import math
import random
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Any, List, Tuple, Union
# config utils
def get_config():
    """Build the run configuration from command-line arguments and a YAML file.

    The command line is parsed with ``OmegaConf.from_cli``; its ``config``
    entry is used as the path of the YAML file to load. Both are combined
    with ``OmegaConf.merge``, the command-line values being passed last.

    Returns:
        The merged OmegaConf configuration object.
    """
    overrides = OmegaConf.from_cli()
    base = OmegaConf.load(overrides.config)
    return OmegaConf.merge(base, overrides)
def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]:
    """Flatten a nested OmegaConf container into ``(dotted_key, value)`` pairs.

    Nested ``DictConfig`` / ``ListConfig`` values are expanded recursively and
    their keys (or list indices) are joined to the parent key with ``"."``.
    Any other value is emitted as a leaf under ``str(key)``.

    Args:
        cfg: The ``DictConfig`` or ``ListConfig`` to flatten.
        resolve: Forwarded to ``items_ex`` / ``_iter_ex`` at every level.

    Returns:
        A list of ``(key, value)`` tuples, in iteration order.

    Raises:
        AssertionError: If ``cfg`` is neither a ``DictConfig`` nor a
            ``ListConfig``.
    """

    def expand(key: Any, value: Any) -> List[Tuple[str, Any]]:
        # Prefix every entry of a nested container with its parent key.
        return [(f"{key}.{sub_key}", leaf) for sub_key, leaf in flatten_omega_conf(value, resolve=resolve)]

    if isinstance(cfg, DictConfig):
        entries = cfg.items_ex(resolve=resolve)
    elif isinstance(cfg, ListConfig):
        entries = enumerate(cfg._iter_ex(resolve=resolve))
    else:
        # Raised explicitly (instead of a bare ``assert False``) so the check
        # is not stripped when Python runs with -O.
        raise AssertionError(f"cannot flatten object of type {type(cfg).__name__}")

    ret: List[Tuple[str, Any]] = []
    for key, value in entries:
        if isinstance(value, (DictConfig, ListConfig)):
            ret.extend(expand(key, value))
        else:
            # Only true leaves are recorded under their own key.
            ret.append((str(key), value))
    return ret
# training utils
def soft_target_cross_entropy(logits, targets, soft_targets):
    """Cross entropy against soft target distributions, ignoring padded positions.

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``; position 0 is
            dropped and the vocab axis is truncated to
            ``soft_targets.shape[-1]``.
        targets: Tensor indexed as ``[batch, position]``; position 0 is
            dropped. Entries equal to ``-100`` mark positions that do not
            contribute to the loss.
        soft_targets: Target distributions multiplied elementwise with the
            log-probabilities. It is not sliced here, so it must already
            broadcast against the sliced logits.

    Returns:
        Scalar tensor: the summed per-position loss divided by the number of
        non-padded positions (``0`` when every position is padded).
    """
    # ignore the first token from logits and targets (class id token)
    logits = logits[:, 1:]
    targets = targets[:, 1:]

    # Keep only the leading vocabulary entries that soft_targets covers.
    logits = logits[..., : soft_targets.shape[-1]]

    log_probs = F.log_softmax(logits, dim=-1)
    padding_mask = targets.eq(-100)

    loss = torch.sum(-soft_targets * log_probs, dim=-1)
    loss = loss.masked_fill(padding_mask, 0.0)

    # Divide by the number of active (non-padded) elements. Clamp to 1 so a
    # fully padded batch yields 0 instead of 0 / 0 = NaN.
    num_active_elements = (padding_mask.numel() - padding_mask.long().sum()).clamp(min=1)
    return loss.sum() / num_active_elements
def get_loss_weight(t, mask, min_val=0.3):
    """Return per-token loss weights: 1 where ``mask`` is 1, reduced elsewhere.

    Where ``mask`` is 0 the weight is ``1 - (1 - t) * (1 - min_val)``, i.e. it
    goes from ``min_val`` at ``t == 0`` up to 1 at ``t == 1``.

    Args:
        t: Per-sample values; indexed with ``[:, None]`` so they broadcast
            over the trailing axis of ``mask``.
        mask: 0/1 valued tensor (1 = keep full weight).
        min_val: Weight given to unmasked tokens when ``t`` is 0.
    """
    unmasked = 1 - mask
    per_sample_reduction = (1 - t) * (1 - min_val)
    return 1 - unmasked * per_sample_reduction[:, None]
def mask_or_random_replace_tokens(image_tokens, mask_id, config, mask_schedule, is_train=True):
    """Corrupt a batch of image tokens and build the matching labels.

    Args:
        image_tokens: 2-D tensor of token ids, unpacked as
            ``(batch_size, seq_len)``.
        mask_id: Token id written at selected positions when the noise type
            is ``"mask"``.
        config: Object exposing ``config.training`` (supporting ``.get`` and
            attribute access) and ``config.model.codebook_size``.
        mask_schedule: Callable mapping a tensor of ``batch_size`` timesteps
            drawn from ``torch.rand`` to per-sample mask probabilities.
        is_train: When false and ``training.eval_mask_ratios`` is set, the
            mask probabilities are drawn from that list instead of the
            schedule.

    Returns:
        ``(input_ids, labels, loss_weight, mask_prob)``. ``loss_weight`` is
        ``None`` unless all tokens are predicted (``predict_all_tokens`` or
        ``noise_type == "random_replace"``), in which case ``labels`` is the
        unmodified ``image_tokens``; otherwise unselected positions in
        ``labels`` are ``-100``.

    Raises:
        ValueError: If ``training.noise_type`` is not ``"mask"`` or
            ``"random_replace"``, or if a contiguous region is requested for
            a ``seq_len`` that is not a perfect square.
    """
    batch_size, seq_len = image_tokens.shape
    device = image_tokens.device
    training_cfg = config.training

    # NOTE(review): the ``config.training`` lookups in this function were
    # reconstructed from damaged text; confirm the key names
    # ``eval_mask_ratios`` and ``min_masking_rate`` against the project config.
    if not is_train and training_cfg.get("eval_mask_ratios", None):
        mask_prob = random.choices(training_cfg.eval_mask_ratios, k=batch_size)
        mask_prob = torch.tensor(mask_prob, device=device)
    else:
        # Sample a random timestep for each image
        timesteps = torch.rand(batch_size, device=device)
        # Sample a random mask probability for each image using the timestep and the schedule
        mask_prob = mask_schedule(timesteps)
        mask_prob = mask_prob.clip(training_cfg.min_masking_rate)

    # Create a random mask for each image; always select at least one token.
    num_token_masked = (seq_len * mask_prob).round().clamp(min=1)

    mask_contiguous_region_prob = training_cfg.get("mask_contiguous_region_prob", None)
    if mask_contiguous_region_prob is None:
        mask_contiguous_region = False
    else:
        mask_contiguous_region = random.random() < mask_contiguous_region_prob

    if not mask_contiguous_region:
        # Rank of each position in a random permutation; the lowest ranks are selected.
        batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
        mask = batch_randperm < num_token_masked.unsqueeze(-1)
    else:
        resolution = int(seq_len ** 0.5)
        if resolution * resolution != seq_len:
            raise ValueError(
                f"contiguous region masking needs a square token grid, got seq_len={seq_len}"
            )
        mask = torch.zeros((batch_size, resolution, resolution), dtype=torch.bool, device=device)

        # TODO - would be nice to vectorize
        for batch_idx, num_token_masked_ in enumerate(num_token_masked):
            num_token_masked_ = int(num_token_masked_.item())

            # NOTE: a bit handwavy with the bounds but gets a rectangle of ~num_token_masked_
            num_token_masked_height = random.randint(
                math.ceil(num_token_masked_ / resolution), min(resolution, num_token_masked_)
            )
            num_token_masked_height = min(num_token_masked_height, resolution)

            num_token_masked_width = math.ceil(num_token_masked_ / num_token_masked_height)
            num_token_masked_width = min(num_token_masked_width, resolution)

            start_idx_height = random.randint(0, resolution - num_token_masked_height)
            start_idx_width = random.randint(0, resolution - num_token_masked_width)

            mask[
                batch_idx,
                start_idx_height : start_idx_height + num_token_masked_height,
                start_idx_width : start_idx_width + num_token_masked_width,
            ] = True

        mask = mask.reshape(batch_size, seq_len)

    # Mask images and create input and labels. Compare the configured value
    # explicitly: testing the truthiness of the looked-up string would always
    # take the "mask" branch.
    noise_type = training_cfg.get("noise_type", "mask")
    if noise_type == "mask":
        input_ids = torch.where(mask, mask_id, image_tokens)
    elif noise_type == "random_replace":
        # sample random tokens from the vocabulary
        random_tokens = torch.randint_like(
            image_tokens, low=0, high=config.model.codebook_size, device=device
        )
        input_ids = torch.where(mask, random_tokens, image_tokens)
    else:
        raise ValueError(f"noise_type {noise_type} not supported")

    if training_cfg.get("predict_all_tokens", False) or noise_type == "random_replace":
        labels = image_tokens
        loss_weight = get_loss_weight(mask_prob, mask.long())
    else:
        labels = torch.where(mask, image_tokens, -100)
        loss_weight = None

    return input_ids, labels, loss_weight, mask_prob
# misc
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        # Start from a clean state so the attributes exist before any update.
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times.

        Args:
            val: The new measurement.
            n: Number of samples ``val`` stands for; it weights ``val`` in the sum.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total count (e.g. the first update has n=0).
        self.avg = self.sum / self.count if self.count else 0
from torchvision import transforms
def image_transform(image, resolution=256, normalize=True):
    """Resize, center-crop and tensorize an image, optionally normalizing it.

    The steps run in order: ``Resize(resolution)`` with bilinear
    interpolation, ``CenterCrop((resolution, resolution))``, ``ToTensor()``
    and, when ``normalize`` is true, an in-place ``Normalize`` with mean and
    std of 0.5 on each of three channels.

    Args:
        image: Input accepted by the torchvision transforms above
            (presumably a 3-channel PIL image; confirm against callers).
        resolution: Size passed to both ``Resize`` and ``CenterCrop``.
        normalize: Whether to apply the final ``Normalize`` step.

    Returns:
        The transformed image tensor.
    """
    pipeline = [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((resolution, resolution)),
        transforms.ToTensor(),
    ]
    if normalize:
        pipeline.append(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        )
    for step in pipeline:
        image = step(image)
    return image