| """This file contains implementation for MaskGIT model. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| Reference: | |
| https://github.com/huggingface/open-muse | |
| https://github.com/baaivision/MUSE-Pytorch | |
| https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py | |
| """ | |
import torch
from torch import nn
import numpy as np
import math
import torch.utils.checkpoint
from transformers import BertConfig, BertModel
from einops import rearrange
import json
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf
from pathlib import Path

from modeling.modules.base_model import BaseModel
from modeling.modules.blocks import UViTBlock


class ImageBert(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-generation"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
    def __init__(self, config):
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        self.target_codebook_size = config.model.vq_model.codebook_size
        self.condition_num_classes = config.model.generator.condition_num_classes
        self.image_seq_len = config.model.generator.image_seq_len
        self.mask_token_id = self.target_codebook_size
        self.hidden_size = config.model.generator.hidden_size
        self.num_hidden_layers = config.model.generator.num_hidden_layers
        self.num_attention_heads = config.model.generator.num_attention_heads
        self.intermediate_size = config.model.generator.intermediate_size

        self.model = BertModel(BertConfig(
            vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act='gelu',
            hidden_dropout_prob=config.model.generator.dropout,
            attention_probs_dropout_prob=config.model.generator.attn_drop,
            max_position_embeddings=config.model.generator.image_seq_len + 1,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=None,
            position_embedding_type="absolute",
            use_cache=True
        ), add_pooling_layer=False)
        self.model.lm_head = nn.Linear(self.hidden_size, self.target_codebook_size, bias=True)

        self.model.post_init()

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Assume 'self.config' is your DictConfig object
        # Convert to a regular dictionary
        dict_config = OmegaConf.to_container(self.config)
        # Save as JSON
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)

        super()._save_pretrained(save_directory)

    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        # Token space:
        #  [0, codebook_size - 1]                      : those are the learned quantized image tokens
        #  codebook_size                               : the mask token used to mask image tokens
        #  [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
        #  codebook_size + 1 + nclass                  : the class drop label
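        # Worked example (illustrative values, not taken from any config): with
        # codebook_size = 4096 and nclass = 1000, image tokens occupy ids 0-4095,
        # the mask token is 4096, class tokens are 4097-5096, the drop label is
        # 5097, and vocab_size = 4096 + 1000 + 2 = 5098.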
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the classes
        condition = condition + self.target_codebook_size + 1  # [0, 999] -> [codebook_size + 1, codebook_size + 1000]
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
        # prepend condition token
        if input_ids is not None:
            input_ids = torch.cat([condition.view(condition.shape[0], -1),
                                   input_ids.view(input_ids.shape[0], -1),], dim=1)
        else:
            # at least the masked tokens should be provided
            raise NotImplementedError

        model_output = self.model(input_ids=input_ids)
        model_output = model_output[0]
        return self.model.lm_head(model_output[:, 1:])  # remove cond

    # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
    def generate(self,
                 condition,
                 guidance_scale=3.0,
                 guidance_decay="constant",
                 guidance_scale_pow=3.0,
                 randomize_temperature=4.5,
                 softmax_temperature_annealing=False,
                 num_sample_steps=8):
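        # Iterative decoding: start from an all-mask canvas and, over
        # `num_sample_steps` rounds, sample every masked position, keep the most
        # confident predictions, and re-mask the rest following the arccos
        # schedule below. Worked example (illustrative, assuming
        # image_seq_len = 256 and 8 steps): the number of positions kept masked
        # going into the next step is floor(256 * arccos(ratio) / (pi / 2)),
        # i.e. roughly 235, 214, 193, 170, 145, 117, 82, and 0 after the final step.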
        if guidance_decay not in ["constant", "linear", "power-cosine"]:
            # constant: constant guidance scale
            # linear: linearly increase the guidance scale, as in MUSE
            # power-cosine: the guidance schedule from MDT
            raise ValueError(f"Unsupported guidance decay {guidance_decay}")
        device = condition.device
        ids = torch.full((condition.shape[0], self.image_seq_len),
                         self.mask_token_id, device=device)

        cfg_scale = guidance_scale if guidance_decay == "constant" else 0.

        for step in range(num_sample_steps):
            ratio = 1. * (step + 1) / num_sample_steps
            annealed_temp = randomize_temperature * (1.0 - ratio)
            is_mask = (ids == self.mask_token_id)

            if guidance_decay == "power-cosine":
                # ref: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L501
                guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow
                scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2
                cfg_scale = (guidance_scale - 1) * scale_step + 1

            if cfg_scale != 0:
                cond_logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )
                uncond_logits = self.forward(
                    ids, condition, cond_drop_prob=1.0
                )
                if guidance_decay == "power-cosine":
                    logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
                else:
                    logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )

            if softmax_temperature_annealing:
                softmax_temperature = 0.5 + 0.8 * (1 - ratio)
                logits = logits / softmax_temperature

            # Add gumbel noise
            def log(t, eps=1e-20):
                return torch.log(t.clamp(min=eps))

            def gumbel_noise(t):
                noise = torch.zeros_like(t).uniform_(0, 1)
                return -log(-log(noise))

            def add_gumbel_noise(t, temperature):
                return t + temperature * gumbel_noise(t)

            sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()
            # masking: keep the least confident predictions masked for the next step
            mask_ratio = np.arccos(ratio) / (math.pi * 0.5)
            mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device)
            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                     torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
                                                   mask_len))[0].squeeze()
            confidence = add_gumbel_noise(sampled_logits, annealed_temp)
            sorted_confidence, _ = torch.sort(confidence, axis=-1)
            cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
            masking = (confidence <= cut_off)
            if step == num_sample_steps - 1:
                ids = sampled_ids
            else:
                ids = torch.where(masking, self.mask_token_id, sampled_ids)

            if guidance_decay == "linear":
                cfg_scale = ratio * guidance_scale

        return ids

    def masking_input_tokens(self, input_tokens):
        batch_size, seq_len = input_tokens.shape
        device = input_tokens.device

        timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0)
        mask_ratio = torch.acos(timesteps) / (math.pi * 0.5)  # arccos schedule
        mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.)
        num_token_masked = (seq_len * mask_ratio).round().clamp(min=1)
        batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
        masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_tokens = torch.where(masks, self.mask_token_id, input_tokens)
        return masked_tokens, masks
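

# Hedged usage sketch (not part of the original file): a minimal training step
# combining `masking_input_tokens` and `forward`. The helper name, the inputs,
# and the plain cross-entropy on masked positions are assumptions for
# illustration; the actual training loop lives elsewhere in the repository.
def _example_maskgit_training_step(generator: ImageBert,
                                   image_tokens: torch.Tensor,
                                   class_labels: torch.Tensor) -> torch.Tensor:
    # image_tokens: (batch, image_seq_len) quantized token ids from the VQ tokenizer
    # class_labels: (batch,) class ids in [0, condition_num_classes - 1]
    masked_tokens, masks = generator.masking_input_tokens(image_tokens)
    logits = generator(masked_tokens, class_labels, cond_drop_prob=0.1)  # (batch, seq_len, codebook_size)
    # Supervise only the masked positions, MaskGIT/MUSE-style.
    loss = nn.functional.cross_entropy(logits[masks], image_tokens[masks])
    return loss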


class UViTBert(ImageBert):
    def __init__(self, config):
        super().__init__(config=config)
        del self.model

        self.embeddings = nn.Embedding(
            self.target_codebook_size + self.condition_num_classes + 2,
            self.hidden_size)
        self.pos_embed = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, self.config.model.generator.image_seq_len + 1, self.hidden_size)), 0., 0.02)

        self.in_blocks = nn.ModuleList([
            UViTBlock(
                dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
                qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
            for _ in range(self.num_hidden_layers // 2)])

        self.mid_block = UViTBlock(
            dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
            qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)

        self.out_blocks = nn.ModuleList([
            UViTBlock(
                dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
                qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, skip=True, use_checkpoint=False)
            for _ in range(self.num_hidden_layers // 2)])

        self.norm = nn.LayerNorm(self.hidden_size)
        self.lm_head = nn.Linear(self.hidden_size,
                                 self.target_codebook_size, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Embedding):
            m.weight.data = nn.init.trunc_normal_(m.weight.data, mean=0.0, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        # Token space:
        #  [0, codebook_size - 1]                      : those are the learned quantized image tokens
        #  codebook_size                               : the mask token used to mask image tokens
        #  [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
        #  codebook_size + 1 + nclass                  : the class drop label
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the classes
        condition = condition + self.target_codebook_size + 1  # [0, 999] -> [codebook_size + 1, codebook_size + 1000]
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
        # prepend condition token
        if input_ids is not None:
            input_ids = torch.cat([condition.view(condition.shape[0], -1),
                                   input_ids.view(input_ids.shape[0], -1),], dim=1)
        else:
            # at least the masked tokens should be provided
            raise NotImplementedError

        # UViT forward: the first half of the blocks stores skip activations,
        # the second half consumes them in reverse order (U-Net-style long skips).
        embeddings = self.embeddings(input_ids)
        x = embeddings + self.pos_embed[:, :embeddings.shape[1]]

        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop())

        x = self.norm(x)
        return self.lm_head(x[:, 1:])  # remove cond
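

# Hedged usage sketch (not part of the original file): class-conditional
# sampling with a trained generator. The class ids below are arbitrary
# illustrative labels; the returned ids index the VQ codebook and still need
# to be decoded to pixels by the tokenizer's decoder.
def _example_maskgit_sampling(generator: ImageBert, device: str = "cuda") -> torch.Tensor:
    class_labels = torch.tensor([207, 360, 388, 974], device=device)  # hypothetical class ids
    with torch.no_grad():
        sampled_ids = generator.generate(
            condition=class_labels,
            guidance_scale=3.0,
            guidance_decay="constant",
            randomize_temperature=4.5,
            softmax_temperature_annealing=False,
            num_sample_steps=8)
    # sampled_ids: (batch, image_seq_len) token indices in [0, codebook_size - 1]
    return sampled_ids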