import torch
import torch.nn.functional as F
from transformers import AutoConfig

from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM

try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False


class Showo(ModelMixin, ConfigMixin):
    """Causal-LM wrapper with iterative image-token and autoregressive text sampling.

    The model owns a ``PhiForCausalLM`` (``self.showo``) whose token embeddings
    are resized to ``vocab_size``. The last vocabulary index
    (``vocab_size - 1``) is registered in the config as ``mask_token_id``.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
            self,
            w_clip_vit,
            vocab_size,
            llm_vocab_size,
            llm_model_path='',
            codebook_size=8192,
            num_vq_tokens=256,
            **kwargs,
    ):
        """Build the language model and, optionally, the vision projector.

        Args:
            w_clip_vit: When truthy, a 1024 -> 2048 -> 2048 MLP is created as
                ``self.mm_projector``.
            vocab_size: Size the LLM token embeddings are resized to.
            llm_vocab_size: Not used in the constructor body; kept in the
                signature (presumably recorded by ``register_to_config`` -
                confirm against ``modeling_utils``).
            llm_model_path: Passed to ``AutoConfig.from_pretrained``.
            codebook_size: Not used in the constructor body.
            num_vq_tokens: Not used in the constructor body.
            **kwargs: Accepted and ignored by the constructor body.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.register_to_config(mask_token_id=vocab_size - 1)

        # Only the configuration is loaded here; PhiForCausalLM is built from
        # that config object, not from a checkpoint, in this constructor.
        config = AutoConfig.from_pretrained(llm_model_path)
        self.showo = PhiForCausalLM(config)
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        # Use the constructor argument directly instead of `self.w_clip_vit`,
        # which is never assigned in this class and would only resolve if the
        # mixins forward attribute access to the registered config.
        if w_clip_vit:
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048),
            )

    def _set_gradient_checkpointing(self, module, value=False):
        """Record the requested gradient-checkpointing state on this model.

        Args:
            module: Unused; present for the callback signature.
            value: Desired state. Previously ignored (the flag was always set
                to ``True``), which made it impossible to switch the flag off.
        """
        self.gradient_checkpointing = value

    def forward(
            self,
            input_ids,
            input_embeddings=None,
            attention_mask=None,
            labels=None,
            label_smoothing=0.0,
            config=None,
            labels_mask_text=None,
            labels_mask_image=None,
            **kwargs,
    ):
        """Run the language model and return its logits.

        Args:
            input_ids: Token ids; used only when ``input_embeddings`` is None.
            input_embeddings: Optional pre-computed embeddings; when given they
                are fed as ``inputs_embeds`` and ``input_ids`` is ignored.
            attention_mask: Forwarded unchanged to the language model.
            labels: Must be None; loss computation is not implemented.
            label_smoothing, config, labels_mask_text, labels_mask_image,
            **kwargs: Accepted but unused.

        Returns:
            The ``'logits'`` entry of the language model output.

        Raises:
            NotImplementedError: If ``labels`` is not None.
        """
        # Fail before the (expensive) forward pass rather than after it.
        if labels is not None:
            raise NotImplementedError

        if input_embeddings is None:
            outputs = self.showo(input_ids=input_ids, attention_mask=attention_mask)
        else:
            outputs = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)
        return outputs['logits']

    def t2i_generate(
            self,
            input_ids: torch.LongTensor = None,
            uncond_input_ids: torch.LongTensor = None,
            attention_mask=None,
            temperature=1.0,
            timesteps=18,  # ideal number of steps is 18 in maskgit paper
            guidance_scale=0,
            noise_schedule=cosine_schedule,
            generator: torch.Generator = None,
            uni_prompting=None,
            config=None,
            **kwargs,
    ):
        """Iteratively unmask image tokens, following MaskGit parallel decoding.

        Reference implementation:
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

        The image-token slot is ``input_ids[:, -(seq_len + 1):-1]`` where
        ``seq_len = config.model.showo.num_vq_tokens``. On every step the model
        predicts all slot positions, one token is sampled per position, and the
        least confident predictions are re-masked for the next step.

        Args:
            input_ids: Prompt plus image slot. Assumed 2-D ``(batch, length)``
                - TODO confirm. NOTE: the image slot of this tensor is
                overwritten in place on every step.
            uncond_input_ids: Optional unconditional prompt used for
                classifier-free guidance.
            attention_mask: Forwarded to the model on every step.
            temperature: Scale of the randomness passed to
                ``mask_by_random_topk``.
            timesteps: Number of decoding steps. Must be at least 1; with 0 the
                loop never runs and the return statement fails.
            guidance_scale: Guidance weight; guidance is applied only when
                this is > 0 and ``uncond_input_ids`` is given.
            noise_schedule: Maps the progress ratio in (0, 1] to a mask ratio.
            generator: Optional RNG for sampling and masking.
            uni_prompting: Unused.
            config: Required despite the ``None`` default; must provide
                ``model.showo.num_vq_tokens``, ``model.showo.llm_vocab_size``
                and, when ``uncond_input_ids`` is given,
                ``dataset.preprocessing.max_seq_length``.
            **kwargs: Ignored.

        Returns:
            The token ids sampled on the final step, indexed relative to the
            image sub-vocabulary (i.e. with the offset removed).
        """
        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        seq_len = config.model.showo.num_vq_tokens
        # Image ids live at `llm_vocab_size + 10` and above in the joint
        # vocabulary (the 10 presumably covers special tokens - confirm).
        image_token_offset = config.model.showo.llm_vocab_size + 10

        # Image slot expressed in image-vocabulary ids; mask entries are kept.
        input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(
            input_ids_minus_lm_vocab_size == mask_token_id,
            mask_token_id,
            input_ids_minus_lm_vocab_size - image_token_offset,
        )

        if uncond_input_ids is not None:
            prefix_len = config.dataset.preprocessing.max_seq_length + 1
            uncond_prefix = uncond_input_ids[:, :prefix_len]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                # Unconditional branch shares the current image slot and
                # differs only in its prefix.
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, prefix_len:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)
                # Equivalent to uncond + (1 + guidance_scale) * (cond - uncond);
                # it seems that muse has a different cfg setting than
                # uncond + guidance_scale * (cond - uncond).
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
            else:
                logits = self(input_ids, attention_mask=attention_mask)
            # Keep only image-slot positions and drop the vocabulary entries
            # below the offset as well as the final one.
            logits = logits[:, -(seq_len + 1):-1, image_token_offset:-1]

            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Positions that were already decoded keep their previous token.
            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Computes the probabilities of each selected token.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one prediction in this round and also masks out at
            # least one for the next iteration.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            # Adds noise for randomness.
            # NOTE(review): `temperature` is reassigned, so the annealing factor
            # compounds across steps; the MaskGit reference multiplies a fixed
            # base temperature by (1 - ratio). Kept as is because changing it
            # would alter sampling results - confirm whether this is intended.
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # Masks tokens with lower confidence (writes into the caller's tensor).
            input_ids[:, -(seq_len + 1):-1] = torch.where(
                masking, mask_token_id, sampled_ids + image_token_offset)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids

    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100,
                     temperature=1.0, top_k=None, eot_token=None):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Take a conditioning sequence of indices ``idx`` (LongTensor of shape
        (b, t)) and/or ``input_embeddings`` and complete the sequence, feeding
        each prediction back into the model. Most likely you'll want to be in
        ``model.eval()`` mode for this.

        Only the first sample of the batch is collected in the result and the
        stop check compares the whole ``idx_next`` tensor to ``eot_token``, so
        the method effectively assumes a batch size of 1.

        Args:
            idx: Token ids; extended with each new token when
                ``self.config.w_clip_vit`` is falsy.
            input_embeddings: Embeddings; extended with each new token's
                embedding when ``self.config.w_clip_vit`` is truthy.
            attention_mask: Optional additive mask. When given it is squeezed
                and grown by one row and one column per step (assumed square
                ``(..., L, L)`` - TODO confirm). When None, no mask is
                maintained and None is forwarded to the model.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Logits are divided by this value before sampling.
            top_k: If set, restrict sampling to the ``top_k`` largest logits.
            eot_token: If set, stop as soon as this token is sampled.

        Returns:
            A list of 0-dim tensors, one per generated token.
        """
        # Prefer the device of `idx`; fall back to the embeddings when `idx`
        # is None (or otherwise has no `.device`).
        try:
            device = idx.device
        except AttributeError:
            device = input_embeddings.device

        result = []
        for _ in range(max_new_tokens):
            # forward the model to get the logits for the index in the sequence
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            if attention_mask is not None:
                # Grow the mask for the token about to be appended: existing
                # rows get a new column filled with the dtype minimum, and the
                # new row copies the last existing row plus a 0 for itself.
                L = attention_mask.shape[-1]
                attention_mask = attention_mask.squeeze()
                attention_mask_a = torch.hstack(
                    [
                        attention_mask,  # L, L
                        torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
                    ]
                )
                attention_mask_b = torch.vstack(
                    [
                        attention_mask_a,  # L, L+1
                        torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
                    ]
                )
                attention_mask = attention_mask_b

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])

            # append sampled index to the running sequence and continue
            if self.config.w_clip_vit:
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result