"""This file contains implementation for MaskGIT model. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Reference: | |
https://github.com/huggingface/open-muse | |
https://github.com/baaivision/MUSE-Pytorch | |
https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py | |
""" | |
import json
import math
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf
from transformers import BertConfig, BertModel

from modeling.modules.base_model import BaseModel
from modeling.modules.blocks import UViTBlock


class ImageBert(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-generation"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
    def __init__(self, config):
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        self.target_codebook_size = config.model.vq_model.codebook_size
        self.condition_num_classes = config.model.generator.condition_num_classes
        self.image_seq_len = config.model.generator.image_seq_len
        self.mask_token_id = self.target_codebook_size
        self.hidden_size = config.model.generator.hidden_size
        self.num_hidden_layers = config.model.generator.num_hidden_layers
        self.num_attention_heads = config.model.generator.num_attention_heads
        self.intermediate_size = config.model.generator.intermediate_size

        self.model = BertModel(BertConfig(
            # Vocabulary: codebook tokens + class tokens + mask token + class-drop token.
            vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act='gelu',
            hidden_dropout_prob=config.model.generator.dropout,
            attention_probs_dropout_prob=config.model.generator.attn_drop,
            # One extra position for the prepended condition token.
            max_position_embeddings=config.model.generator.image_seq_len + 1,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=None,
            position_embedding_type="absolute",
            use_cache=True
        ), add_pooling_layer=False)
        self.model.lm_head = nn.Linear(self.hidden_size, self.target_codebook_size, bias=True)

        self.model.post_init()
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Convert the OmegaConf DictConfig to a plain dictionary so it can be
        # serialized as JSON alongside the weights.
        dict_config = OmegaConf.to_container(self.config)
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)
    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        # Token space:
        #   [0, codebook_size - 1]                      : the learned quantized image tokens
        #   codebook_size                               : the mask token used to mask image tokens
        #   [codebook_size + 1, codebook_size + nclass] : the ImageNet class tokens
        #   codebook_size + 1 + nclass                  : the class-drop (unconditional) label
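        # Worked example (illustrative numbers, not fixed by this file): with
        # codebook_size = 1024 and nclass = 1000, the vocabulary has
        # 1024 + 1000 + 2 = 2026 entries: image tokens 0..1023, mask token 1024,
        # class tokens 1025..2024, and the class-drop label 2025.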
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the class ids into the token space: [0, 999] -> [codebook_size + 1, codebook_size + 1000]
        condition = condition + self.target_codebook_size + 1
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1

        # Prepend the condition token.
        if input_ids is not None:
            input_ids = torch.cat([condition.view(condition.shape[0], -1),
                                   input_ids.view(input_ids.shape[0], -1)], dim=1)
        else:
            # At least the masked tokens should be provided as input_ids.
            raise NotImplementedError

        model_output = self.model(input_ids=input_ids)
        model_output = model_output[0]
        return self.model.lm_head(model_output[:, 1:])  # remove the condition token
    # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
    def generate(self,
                 condition,
                 guidance_scale=3.0,
                 guidance_decay="constant",
                 guidance_scale_pow=3.0,
                 randomize_temperature=4.5,
                 softmax_temperature_annealing=False,
                 num_sample_steps=8):
        if guidance_decay not in ["constant", "linear", "power-cosine"]:
            # constant: constant guidance scale
            # linear: linearly increase the guidance scale, as in MUSE
            # power-cosine: the guidance schedule from MDT
            raise ValueError(f"Unsupported guidance decay {guidance_decay}")
        device = condition.device
        ids = torch.full((condition.shape[0], self.image_seq_len),
                         self.mask_token_id, device=device)
        cfg_scale = guidance_scale if guidance_decay == "constant" else 0.

        for step in range(num_sample_steps):
            ratio = 1. * (step + 1) / num_sample_steps
            annealed_temp = randomize_temperature * (1.0 - ratio)
            is_mask = (ids == self.mask_token_id)

            if guidance_decay == "power-cosine":
                # ref: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L501
                guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow
                scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2
                cfg_scale = (guidance_scale - 1) * scale_step + 1

            if cfg_scale != 0:
                cond_logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )
                uncond_logits = self.forward(
                    ids, condition, cond_drop_prob=1.0
                )
                if guidance_decay == "power-cosine":
                    logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
                else:
                    logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )

            if softmax_temperature_annealing:
                softmax_temperature = 0.5 + 0.8 * (1 - ratio)
                logits = logits / softmax_temperature

            # Sample by adding Gumbel noise to the logits and taking the argmax.
            def log(t, eps=1e-20):
                return torch.log(t.clamp(min=eps))

            def gumbel_noise(t):
                noise = torch.zeros_like(t).uniform_(0, 1)
                return -log(-log(noise))

            def add_gumbel_noise(t, temperature):
                return t + temperature * gumbel_noise(t)

            sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()

            # Masking: keep the most confident predictions and re-mask the rest.
            mask_ratio = np.arccos(ratio) / (math.pi * 0.5)
            mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device)
            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                     torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
                                                   mask_len))[0].squeeze()

            confidence = add_gumbel_noise(sampled_logits, annealed_temp)
            sorted_confidence, _ = torch.sort(confidence, dim=-1)
            cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
            masking = (confidence <= cut_off)

            if step == num_sample_steps - 1:
                ids = sampled_ids
            else:
                ids = torch.where(masking, self.mask_token_id, sampled_ids)

            if guidance_decay == "linear":
                cfg_scale = ratio * guidance_scale
        return ids
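    # ------------------------------------------------------------------
    # Note on the schedule above (added explanation, not original code):
    # mask_ratio = arccos(ratio) / (pi / 2) decays from 1 to 0 over the
    # sampling steps. For example, with image_seq_len = 32 and
    # num_sample_steps = 8, the first step has ratio = 1/8, so
    # arccos(0.125) / (pi / 2) ~= 0.92 and floor(32 * 0.92) = 29 tokens
    # are re-masked; at the final step every sampled token is kept.
    # ------------------------------------------------------------------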
    def masking_input_tokens(self, input_tokens):
        batch_size, seq_len = input_tokens.shape
        device = input_tokens.device

        timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0)
        mask_ratio = torch.acos(timesteps) / (math.pi * 0.5)  # arccos schedule
        mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.)
        num_token_masked = (seq_len * mask_ratio).round().clamp(min=1)
        batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
        masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_tokens = torch.where(masks, self.mask_token_id, input_tokens)
        return masked_tokens, masks
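    # ------------------------------------------------------------------
    # Hedged training sketch (added for illustration; the actual training
    # loop and loss live elsewhere in the repository). A MaskGIT-style step
    # masks ground-truth tokens, predicts them, and applies cross-entropy
    # on the masked positions only:
    #
    #   masked_tokens, masks = model.masking_input_tokens(gt_tokens)  # (B, L)
    #   logits = model(masked_tokens, class_labels)                   # (B, L, codebook_size)
    #   loss = torch.nn.functional.cross_entropy(logits[masks], gt_tokens[masks])
    # ------------------------------------------------------------------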


class UViTBert(ImageBert):
    def __init__(self, config):
        super().__init__(config=config)
        # Replace the BERT backbone with a U-ViT style transformer:
        # in-blocks -> mid-block -> out-blocks with long skip connections.
        del self.model

        self.embeddings = nn.Embedding(
            self.target_codebook_size + self.condition_num_classes + 2,
            self.hidden_size)
        self.pos_embed = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, self.config.model.generator.image_seq_len + 1, self.hidden_size)), 0., 0.02)

        self.in_blocks = nn.ModuleList([
            UViTBlock(
                dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
                qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
            for _ in range(self.num_hidden_layers // 2)])

        self.mid_block = UViTBlock(
            dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
            qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)

        self.out_blocks = nn.ModuleList([
            UViTBlock(
                dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
                qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, skip=True, use_checkpoint=False)
            for _ in range(self.num_hidden_layers // 2)])

        self.norm = nn.LayerNorm(self.hidden_size)
        self.lm_head = nn.Linear(self.hidden_size,
                                 self.target_codebook_size, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Embedding):
            m.weight.data = nn.init.trunc_normal_(m.weight.data, mean=0.0, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        # Token space:
        #   [0, codebook_size - 1]                      : the learned quantized image tokens
        #   codebook_size                               : the mask token used to mask image tokens
        #   [codebook_size + 1, codebook_size + nclass] : the ImageNet class tokens
        #   codebook_size + 1 + nclass                  : the class-drop (unconditional) label
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the class ids into the token space: [0, 999] -> [codebook_size + 1, codebook_size + 1000]
        condition = condition + self.target_codebook_size + 1
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1

        # Prepend the condition token.
        if input_ids is not None:
            input_ids = torch.cat([condition.view(condition.shape[0], -1),
                                   input_ids.view(input_ids.shape[0], -1)], dim=1)
        else:
            # At least the masked tokens should be provided as input_ids.
            raise NotImplementedError

        # U-ViT forward pass with long skip connections.
        embeddings = self.embeddings(input_ids)
        x = embeddings + self.pos_embed[:, :embeddings.shape[1]]

        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop())

        x = self.norm(x)
        return self.lm_head(x[:, 1:])  # remove the condition token
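

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# file). All numeric values below are assumptions; real configurations come
# from the repository's YAML config files. UViTBert exposes the same
# interface and can be swapped in for ImageBert.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = OmegaConf.create({
        "model": {
            "vq_model": {"codebook_size": 1024},
            "generator": {
                "condition_num_classes": 1000,
                "image_seq_len": 32,
                "hidden_size": 512,
                "num_hidden_layers": 8,
                "num_attention_heads": 8,
                "intermediate_size": 2048,
                "dropout": 0.1,
                "attn_drop": 0.1,
            },
        },
    })
    generator = ImageBert(example_config).eval()
    class_labels = torch.tensor([207, 360])  # assumed ImageNet class ids
    with torch.no_grad():
        sampled_tokens = generator.generate(
            class_labels, guidance_scale=3.0, num_sample_steps=8)
    print(sampled_tokens.shape)  # (2, image_seq_len)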