|
import torch
import torch.nn.functional as F
from transformers import AutoConfig

from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, mask_by_random_topk
from .phi import PhiForCausalLM

try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False

|
class Showo(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        w_clip_vit,
        vocab_size,
        llm_vocab_size,
        llm_model_path='',
        codebook_size=8192,
        num_vq_tokens=256,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        # the last id of the unified (text + image) vocabulary is reserved for the mask token
        self.register_to_config(mask_token_id=vocab_size - 1)
        config = AutoConfig.from_pretrained(llm_model_path)
        self.showo = PhiForCausalLM(config)
        # resize the LLM embedding table (and LM head) to the unified vocabulary size
        self.showo.resize_token_embeddings(self.vocab_size)
        self.output_size = self.vocab_size

        if self.w_clip_vit:
            # projector from CLIP-ViT image features (1024-d) to the LLM hidden size (2048-d)
            self.mm_projector = torch.nn.Sequential(
                torch.nn.Linear(1024, 2048),
                torch.nn.GELU(),
                torch.nn.Linear(2048, 2048)
            )

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        input_embeddings=None,
        attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        config=None,
        labels_mask_text=None,
        labels_mask_image=None,
        **kwargs,
    ):
        # run the LLM backbone either on token ids or on precomputed input embeddings
        if input_embeddings is None:
            logits = self.showo(input_ids=input_ids, attention_mask=attention_mask)['logits']
        else:
            logits = self.showo(inputs_embeds=input_embeddings, attention_mask=attention_mask)['logits']

        if labels is not None:
            raise NotImplementedError

        return logits
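
    # Sketch of the two forward paths (illustrative only; `model`, `input_embeds` and the
    # tokenized prompt are assumptions, not part of this module):
    #   # discrete path: ids over the unified text/image vocabulary
    #   logits = model(input_ids, attention_mask=attention_mask)                          # (B, T, vocab_size)
    #   # continuous path: precomputed embeddings, e.g. CLIP-ViT features passed through
    #   # self.mm_projector and concatenated with token embeddings
    #   logits = model(input_ids, input_embeddings=input_embeds, attention_mask=attention_mask)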
|
|
|
    def t2i_generate(
        self,
        input_ids: torch.LongTensor = None,
        uncond_input_ids: torch.LongTensor = None,
        attention_mask=None,
        temperature=1.0,
        timesteps=18,
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        uni_prompting=None,
        config=None,
        **kwargs,
    ):
        """
        Parallel mask-and-predict decoding of image tokens, kept 1:1 with the original MaskGit repo:
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        """
|
|
|
        mask_token_id = self.config.mask_token_id
        seq_len = config.model.showo.num_vq_tokens

        # slice out the image-token region and shift its ids back into the [0, codebook_size) range;
        # masked positions keep the mask token id
        input_ids_minus_lm_vocab_size = input_ids[:, -(seq_len + 1):-1].clone()
        input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id,
                                                    mask_token_id,
                                                    input_ids_minus_lm_vocab_size - config.model.showo.llm_vocab_size - 10)

        if uncond_input_ids is not None:
            uncond_prefix = uncond_input_ids[:, :config.dataset.preprocessing.max_seq_length + 1]

        for step in range(timesteps):
            if uncond_input_ids is not None and guidance_scale > 0:
                uncond_input_ids = torch.cat(
                    [uncond_prefix, input_ids[:, config.dataset.preprocessing.max_seq_length + 1:]], dim=1)
                model_input = torch.cat([input_ids, uncond_input_ids])
                cond_logits, uncond_logits = self(model_input, attention_mask=attention_mask).chunk(2)

                # classifier-free guidance, then keep only the image-token slice of the vocabulary
                logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
                logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]
            else:
                logits = self(input_ids, attention_mask=attention_mask)
                logits = logits[:, -(seq_len + 1):-1, config.model.showo.llm_vocab_size + 10:-1]

            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # only resample positions that are still masked; already committed tokens are kept
            unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

            # how large a fraction of tokens to re-mask for the next step
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # confidence of each sampled token; committed tokens get maximal confidence
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # mask at least one token, but never more than the number of still-unknown tokens minus one
            mask_len = torch.max(
                torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )

            # re-mask the lowest-confidence tokens, with an annealed sampling temperature
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # write the step's result back into the full sequence, shifted into the image-token id range
            input_ids[:, -(seq_len + 1):-1] = torch.where(masking, mask_token_id,
                                                          sampled_ids + config.model.showo.llm_vocab_size + 10)
            input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids
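
    # Sketch of how t2i_generate is typically driven (assumptions: `input_ids` is the prompt
    # followed by num_vq_tokens mask tokens, `uncond_input_ids` is the empty-prompt variant,
    # and `vq_model` is the image tokenizer used elsewhere in the repo; names and values are
    # illustrative, not an exact API):
    #   gen_token_ids = model.t2i_generate(
    #       input_ids=input_ids,
    #       uncond_input_ids=uncond_input_ids,
    #       attention_mask=attention_mask,
    #       guidance_scale=5.0,
    #       temperature=1.0,
    #       timesteps=18,
    #       noise_schedule=cosine_schedule,
    #       config=config,
    #   )
    #   images = vq_model.decode_code(gen_token_ids)   # (B, 3, H, W)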
|
|
|
    @torch.no_grad()
    def mmu_generate(self, idx=None, input_embeddings=None, attention_mask=None, max_new_tokens=100, temperature=1.0, top_k=None, eot_token=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        device = idx.device if idx is not None else input_embeddings.device

        result = []
        for _ in range(max_new_tokens):
            logits = self(idx, input_embeddings=input_embeddings, attention_mask=attention_mask)

            # grow the additive attention bias by one row and one column for the new token
            # (0 keeps a position visible, finfo.min masks it out)
            L = attention_mask.shape[-1]
            attention_mask = attention_mask.squeeze()
            attention_mask_a = torch.hstack(
                [
                    attention_mask,
                    torch.zeros((L, 1)).to(device) + torch.finfo(logits.dtype).min,
                ]
            )
            attention_mask_b = torch.vstack(
                [
                    attention_mask_a,
                    torch.hstack([attention_mask[-1, :], torch.tensor([0]).to(device)]).unsqueeze(0),
                ]
            )
            attention_mask = attention_mask_b

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # convert logits to probabilities and sample the next token
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            result.append(idx_next[0][0])

            # append the sampled token: as an embedding when conditioning on CLIP-ViT features,
            # otherwise as a token id
            if self.config.w_clip_vit:
                idx_next_embeddings = self.showo.model.embed_tokens(idx_next)
                input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
            else:
                idx = torch.cat((idx, idx_next), dim=1)

            if eot_token is not None and idx_next.cpu() == eot_token:
                break

        return result
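
    # Sketch of a multimodal-understanding call (assumptions: `input_ids`/`attention_mask`
    # encode an image-plus-question prompt, `eot_token_id` is the end-of-text special token,
    # and `tokenizer` is the text tokenizer; names are illustrative):
    #   output_ids = model.mmu_generate(
    #       idx=input_ids,
    #       attention_mask=attention_mask,
    #       max_new_tokens=100,
    #       top_k=1,
    #       eot_token=eot_token_id,
    #   )
    #   answer = tokenizer.decode(torch.stack(output_ids), skip_special_tokens=True)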
|
|