|
import os

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from transformers import PretrainedConfig, PreTrainedModel

from .mar import MAR
|
|
class MARConfig(PretrainedConfig):
    """Hugging Face configuration for the MAR (masked autoregressive) image generation model."""

    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',  # kept as a string; forwarded verbatim to MAR's diffusion sampler
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)
|
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing
|
class MARModel(PreTrainedModel):
    """PreTrainedModel wrapper exposing the MAR generator through the Hugging Face API."""

    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)

        # resolve the norm layer class (e.g. "LayerNorm") from torch.nn by name,
        # since config values must be JSON-serializable
        norm_layer = getattr(nn, config.norm_layer)

        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )
|
    def forward(self, imgs, labels):
        # training forward pass; MAR returns its diffusion loss
        return self.model(imgs, labels)

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None,
                      temperature=1.0, progress=False):
        # iterative masked generation; delegates to the inner MAR sampler
        return self.model.sample_tokens(bsz, num_iter=num_iter, cfg=cfg,
                                        cfg_schedule=cfg_schedule, labels=labels,
                                        temperature=temperature, progress=progress)
|
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = cls(config)
        # torch.load cannot parse .safetensors files, so read the weights with
        # safetensors; assumes a local checkpoint directory laid out by
        # save_pretrained below
        weights_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.safetensors")
        state_dict = load_file(weights_path)
        model.model.load_state_dict(state_dict)
        return model
|
    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)

        # serialize the inner MAR weights with safetensors
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # write config.json next to the weights
        self.config.save_pretrained(save_directory)
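

# A minimal end-to-end sketch under stated assumptions: the "mar-checkpoint"
# directory name is hypothetical, the relative import above means this demo must
# be run as a module (python -m <package>.<this_file>), and the upstream MAR
# sampler may allocate tensors on the GPU, in which case sampling needs CUDA.
if __name__ == "__main__":
    config = MARConfig()
    model = MARModel(config)

    # save config.json plus safetensors weights, then reload through the
    # overridden from_pretrained above
    model.save_pretrained("mar-checkpoint")
    reloaded = MARModel.from_pretrained("mar-checkpoint")

    # class-conditional sampling for four random ImageNet-style labels; the
    # result lives in the VAE latent space and still needs a VAE decode step
    labels = torch.randint(0, config.class_num, (4,))
    latents = reloaded.sample_tokens(bsz=4, num_iter=64, cfg=1.5, labels=labels)
    print(latents.shape)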
|
|