import torch
import torch.nn.functional as F
from transformers import AutoConfig
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM
try:
import xformers.ops as xops
is_xformers_available = True
except ImportError:
is_xformers_available = False
class Showo(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
w_clip_vit,
vocab_size,
llm_vocab_size,
llm_model_path='',
codebook_size=8192,
num_vq_tokens=256,
**kwargs,
):
super().__init__()
self.vocab_size = vocab_size
self.register_to_config(mask_token_id=vocab_size - 1)
config = AutoConfig.from_pretrained(llm_model_path)
self.showo = PhiForCausalLM(config)
self.showo.resize_token_embeddings(self.vocab_size)
self.output_size = self.vocab_size
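        # When CLIP-ViT visual features are used (w_clip_vit), project them from the CLIP
        # feature dimension (1024) into the language model's hidden size (2048, presumably
        # the Phi hidden dimension) before they are fed in as input embeddings.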
if self.w_clip_vit:
self.mm_projector = torch.nn.Sequential(
torch.nn.Linear(1024, 2048),
torch.nn.GELU(),
torch.nn.Linear(2048, 2048)
)
    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value
def forward(
self,
input_ids,
input_embeddings=None,
attention_mask=None,
labels=None,
label_smoothing=0.0,
config=None,
labels_mask_text=None,
labels_mask_image=None,
**kwargs,
):
if input_embeddings is None:
logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
else:
logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']
if labels is not None:
raise NotImplementedError
return logits
def t2i_generate(
self,
input_ids: torch.LongTensor = None,
uncond_input_ids: torch.LongTensor = None,
attention_mask=None,
temperature=1.0,
timesteps=18, # ideal number of steps is 18 in maskgit paper
guidance_scale=0,
noise_schedule=cosine_schedule,
generator: torch.Generator = None,
uni_prompting=None,
config=None,
**kwargs,
):
"""
        Generation closely follows the original MaskGit repo:
https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
"""
# begin with all image token ids masked
mask_token_id = self.config.mask_token_id
seq_len = config.model.showo.num_vq_tokens
input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id,
mask_token_id,
input_ids_minus_lm_vocab_size - config.model.showo.llm_vocab_size - 10)
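        # Image token ids sit at the end of the combined vocabulary, offset by the text
        # vocabulary size plus 10 special tokens; this maps them back into [0, codebook_size)
        # while leaving masked positions at mask_token_id.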
if uncond_input_ids is not None:
uncond_prefix = uncond_input_ids[:, :config.dataset.preprocessing.max_seq_length + 1]
for step in range(timesteps):
if uncond_input_ids is not None and guidance_scale > 0:
uncond_input_ids = torch.cat(
[uncond_prefix, input_ids[:, config.dataset.preprocessing.max_seq_length + 1:]], dim=1)
model_input = torch.cat([input_ids, uncond_input_ids])
cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
# logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # note: muse appears to use a different CFG formulation, which is applied below
logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
else:
logits = self(input_ids, attention_mask=attention_mask)
logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
probs = logits.softmax(dim=-1)
sampled = probs.reshape(-1, logits.size(-1))
sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])
unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
# Defines the mask ratio for the next round. The number to mask out is
# determined by mask_ratio * unknown_number_in_the_beginning.
ratio = 1.0 * (step + 1) / timesteps
mask_ratio = noise_schedule(torch.tensor(ratio))
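            # e.g. assuming cosine_schedule(r) = cos(r * pi / 2): after the first of 18 steps,
            # ratio = 1/18 and mask_ratio ≈ 0.996 (almost everything stays masked); by the
            # final step ratio = 1 and mask_ratio = 0.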
            # Computes the probabilities of the selected tokens.
selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
selected_probs = selected_probs.squeeze(-1)
# Ignores the tokens given in the input by overwriting their confidence.
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
# Gets mask lens for each sample in the batch according to the mask ratio.
mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one prediction in this round and also masks out at least
            # one, so there is something left to sample in the next iteration.
mask_len = torch.max(
torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
)
# Adds noise for randomness
temperature = temperature * (1.0 - ratio)
masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
# Masks tokens with lower confidence.
input_ids[:, -(seq_len + 1):-1] = torch.where(masking, mask_token_id,
sampled_ids + config.model.showo.llm_vocab_size + 10)
input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
return sampled_ids
@torch.no_grad()
def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
        device = idx.device if idx is not None else input_embeddings.device
result = []
for _ in range(max_new_tokens):
            # forward the model to get the logits for the current sequence
logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)
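            # Grow the square (L, L) additive attention mask to (L+1, L+1) for the token appended
            # below: the new column is masked out (finfo.min) for all previous positions, and the
            # new row copies the last row plus a 0 so the new token can attend to everything it
            # should, including itself.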
L = attention_mask.shape[-1]
attention_mask = attention_mask.squeeze()
attention_mask_a = torch.hstack(
[
attention_mask, # L, L
torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
]
)
attention_mask_b = torch.vstack(
[
attention_mask_a, # L, L+1
torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
]
)
attention_mask = attention_mask_b
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
result.append(idx_next[0][0])
# append sampled index to the running sequence and continue
if self.config.w_clip_vit:
idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
else:
idx = torch.cat((idx, idx_next), dim=1)
if eot_token is not None and idx_next.cpu() == eot_token:
break
return result
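
# Minimal usage sketch (illustrative only; the argument values, config layout and token ids
# below are assumptions, not values defined in this file):
#
#   model = Showo(w_clip_vit=False, vocab_size=..., llm_vocab_size=..., llm_model_path=...)
#   logits = model(input_ids, attention_mask=attention_mask)          # plain forward pass
#   image_tokens = model.t2i_generate(input_ids, uncond_input_ids,    # masked parallel decoding
#                                     attention_mask=attention_mask,
#                                     guidance_scale=5.0, config=config)
#   answer_ids = model.mmu_generate(idx=input_ids, attention_mask=attention_mask,
#                                   max_new_tokens=100, eot_token=eot_id)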