Show-o / models /modeling_showo.py
JosephPai
init
8741abe
raw
history blame
9.23 kB
import torch
import torch.nn.functional as F
from transformers import AutoConfig
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM
try:
import xformers.ops as xops
is_xformers_available = True
except ImportError:
is_xformers_available = False
class Showo(ModelMixin, ConfigMixin):
    """Multimodal transformer built around a ``PhiForCausalLM`` backbone.

    The backbone's token embeddings are resized to ``vocab_size``; the last
    id (``vocab_size - 1``) is registered in the config as ``mask_token_id``.
    ``t2i_generate`` runs iterative masked decoding of image tokens and
    ``mmu_generate`` runs autoregressive sampling one token at a time.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        w_clip_vit,
        vocab_size,
        llm_vocab_size,
        llm_model_path='',
        codebook_size=8192,
        num_vq_tokens=256,
        **kwargs,
    ):
        """Build the language-model backbone and the optional projector.

        Args:
            w_clip_vit: When truthy, an MLP projector (1024 -> 2048 -> 2048)
                is created as ``self.mm_projector``.
            vocab_size: Size the backbone's token embeddings are resized to.
            llm_vocab_size: Not used in the constructor body (captured by
                ``register_to_config``).
            llm_model_path: Passed to ``AutoConfig.from_pretrained`` to obtain
                the backbone configuration.
            codebook_size: Not used in the constructor body.
            num_vq_tokens: Not used in the constructor body.
            **kwargs: Accepted and ignored by the constructor body.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)

        llm_config = AutoConfig.from_pretrained(llm_model_path)
        self.showo = PhiForCausalLM(llm_config)
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        # Use the constructor argument directly: ``self.w_clip_vit`` is never
        # assigned here and would only resolve if the config mixin happens to
        # mirror config keys as instance attributes.
        if w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048),
            )

    def _set_gradient_checkpointing(self, module, value=False):
        """Record the requested gradient-checkpointing state on this model.

        Args:
            module: Unused; kept for interface compatibility with the caller.
            value: Desired checkpointing state. It is honoured as given so
                that a "disable" request actually turns the flag off.
        """
        self.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        input_embeddings=None,
        attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        config=None,
        labels_mask_text=None,
        labels_mask_image=None,
        **kwargs,
    ):
        """Run the backbone and return its logits.

        Args:
            input_ids: Token ids; used only when ``input_embeddings`` is None.
            input_embeddings: Optional precomputed embeddings; when given they
                are fed to the backbone instead of ``input_ids``.
            attention_mask: Forwarded unchanged to the backbone.
            labels: Must be None; loss computation is not implemented.
            label_smoothing, config, labels_mask_text, labels_mask_image,
                **kwargs: Accepted but unused.

        Returns:
            The ``'logits'`` entry of the backbone output.

        Raises:
            NotImplementedError: If ``labels`` is provided.
        """
        # Fail before the (expensive) backbone pass rather than after it.
        if labels is not None:
            raise NotImplementedError(
                "Showo.forward does not implement loss computation; call it with labels=None."
            )

        if input_embeddings is None:
            outputs = self.showo(input_ids=input_ids, attention_mask=attention_mask)
        else:
            outputs = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        return outputs['logits']

    @torch.no_grad()
    def t2i_generate(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        uni_prompting=None,
        config=None,
        **kwargs,
    ):
        """
        Generate 1:1 similar to the original MaskGit repo
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

        Args:
            input_ids: Token ids. The ``num_vq_tokens`` positions immediately
                before the final position are treated as the image tokens;
                positions equal to ``mask_token_id`` are the ones to predict.
                NOTE: this tensor is updated in place on every step.
            uncond_input_ids: Optional unconditional sequence; its first
                ``max_seq_length + 1`` columns are used as the prefix for
                classifier-free guidance.
            attention_mask: Forwarded unchanged to ``forward``.
            temperature: Base temperature of the masking noise; it is scaled
                by ``1 - ratio`` at each step.
            timesteps: Number of decoding iterations (must be >= 1).
            guidance_scale: Guidance is applied only when this is > 0 and
                ``uncond_input_ids`` is provided.
            noise_schedule: Callable mapping the progress ratio (a tensor) to
                the mask ratio for the next round.
            generator: Optional ``torch.Generator`` used for sampling.
            uni_prompting: Unused.
            config: Object exposing ``model.showo.num_vq_tokens``,
                ``model.showo.llm_vocab_size`` and, when ``uncond_input_ids``
                is given, ``dataset.preprocessing.max_seq_length``.

        Returns:
            The ids sampled in the final step, shaped like the image-token
            slice and expressed without the language-vocabulary offset.

        Raises:
            ValueError: If ``input_ids`` or ``config`` is None, or if
                ``timesteps`` is smaller than 1.
        """
        if input_ids is None:
            raise ValueError("t2i_generate requires `input_ids`.")
        if config is None:
            raise ValueError("t2i_generate requires `config`.")
        if timesteps < 1:
            raise ValueError(f"`timesteps` must be >= 1, got {timesteps}.")

        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        seq_len = config.model.showo.num_vq_tokens
        # Offset between full-vocabulary ids and image-token ids; the extra 10
        # is presumably the count of reserved special tokens -- TODO confirm.
        image_token_offset = config.model.showo.llm_vocab_size + 10
        # Image tokens sit right before the last position of the sequence.
        image_slice = slice(-(seq_len + 1), -1)

        input_ids_minus_lm_vocab_size = input_ids[:, image_slice].clone()
        input_ids_minus_lm_vocab_size = torch.where(
            input_ids_minus_lm_vocab_size == mask_token_id,
            mask_token_id,
            input_ids_minus_lm_vocab_size - image_token_offset,
        )

        if uncond_input_ids is not None:
            prefix_len = config.dataset.preprocessing.max_seq_length + 1
            uncond_prefix = uncond_input_ids[:, :prefix_len]
        use_guidance = uncond_input_ids is not None and guidance_scale > 0

        for step in range(timesteps):
            if use_guidance:
                # Rebuild the unconditional sequence from the current image tokens.
                uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, prefix_len:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
                # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                # it seems that muse has different cfg setting
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
            else:
                logits = self(input_ids, attention_mask=attention_mask)
            # Keep image positions only, and only the vocabulary entries from
            # the image-token offset up to (excluding) the final entry.
            logits = logits[:, image_slice, image_token_offset:-1]

            probs = logits.softmax(dim=-1)
            flat_probs = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(flat_probs, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Positions that were already known keep their existing ids.
            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Computes the probabilities of each selected tokens.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1)
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            # Adds noise for randomness. The annealed value is derived from the
            # caller's base temperature each step (as in the MaskGit reference)
            # instead of overwriting `temperature`, which compounded the decay.
            annealed_temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, annealed_temperature, generator=generator)

            # Masks tokens with lower confidence.
            input_ids[:, image_slice] = torch.where(masking, mask_token_id, sampled_ids + image_token_offset)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids

    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: Conditioning token ids; may be None when ``input_embeddings``
                is given.
            input_embeddings: Optional conditioning embeddings. When the
                config flag ``w_clip_vit`` is set, each sampled token is
                embedded and appended to these.
            attention_mask: Optional additive mask. When given it is squeezed
                to two dimensions and extended by one row and one column per
                generated token (this assumes a single sample -- TODO confirm
                with callers).
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Divisor applied to the last-position logits.
            top_k: If set, logits below the k-th largest are set to -inf.
            eot_token: If set, generation stops once this id is sampled (the
                id itself is still included in the result).

        Returns:
            A list with one entry per generated token: the sampled id of the
            first sequence in the batch, as a 0-dim tensor.

        Raises:
            ValueError: If both ``idx`` and ``input_embeddings`` are None.
        """
        if idx is not None:
            device = idx.device
        elif input_embeddings is not None:
            device = input_embeddings.device
        else:
            raise ValueError("mmu_generate requires `idx` or `input_embeddings`.")

        result = []
        for _ in range(max_new_tokens):
            # forward the model to get the logits for the index in the sequence
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            if attention_mask is not None:
                # Grow the (L, L) mask to (L + 1, L + 1) for the token about to
                # be appended: existing rows may not attend to the new column,
                # and the new row copies the last row and attends to itself.
                size = attention_mask.shape[-1]
                mask_2d = attention_mask.squeeze()
                blocked_column = torch.full((size, 1), torch.finfo(logits.dtype).min, device=device)
                widened = torch.hstack([mask_2d, blocked_column])  # L, L+1
                new_row = torch.hstack([mask_2d[-1, :], torch.tensor([0], device=device)]).unsqueeze(0)
                attention_mask = torch.vstack([widened, new_row])

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])

            # append sampled index to the running sequence and continue
            if self.config.w_clip_vit:
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result