# mar/modeling.py
import os

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from timm.models.vision_transformer import Block
from transformers import PretrainedConfig, PreTrainedModel

from .mar import MAR

class MARConfig(PretrainedConfig):
    """Hugging Face config holding the hyperparameters of the MAR model."""

    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)

        # store parameters in the config
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing
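
# A quick sketch of how this config round-trips through the standard
# PretrainedConfig JSON serialization; the "./mar-config" path below is
# illustrative, not part of this repo:
#
#     cfg = MARConfig(img_size=256, diffloss_w=1024)
#     cfg.save_pretrained("./mar-config")              # writes config.json
#     cfg = MARConfig.from_pretrained("./mar-config")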

class MARModel(PreTrainedModel):
    """Hugging Face PreTrainedModel wrapper around the MAR module from .mar."""

    # links to the MARConfig class
    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # convert norm_layer from a string (e.g. "LayerNorm") to the nn class
        norm_layer = getattr(nn, config.norm_layer)

        # init the MAR model using the parameters from the config
        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,  # pass the actual class, not the string
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )

    def forward(self, imgs, labels):
        # call the forward method of the MAR class, passing imgs and labels
        return self.model(imgs, labels)
    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear",
                      labels=None, temperature=1.0, progress=False):
        # call the sample_tokens method of the MAR class
        return self.model.sample_tokens(
            bsz=bsz, num_iter=num_iter, cfg=cfg, cfg_schedule=cfg_schedule,
            labels=labels, temperature=temperature, progress=progress,
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = MARConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = cls(config)
        # safetensors checkpoints must be read with load_file, not torch.load;
        # the weights are expected in a local checkpoint-last.safetensors file
        state_dict = load_file('./checkpoint-last.safetensors')
        model.model.load_state_dict(state_dict)
        return model
    def save_pretrained(self, save_directory):
        # we save the weights to safetensors
        os.makedirs(save_directory, exist_ok=True)
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # save the configuration as usual
        self.config.save_pretrained(save_directory)
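

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the file lives in the `mar` package, so it
# should be run as `python -m mar.modeling` for the relative import to resolve;
# the batch size, label draw, and output directory below are illustrative, not
# a tested recipe).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = MARConfig()      # default hyperparameters defined above
    model = MARModel(config)  # randomly initialized wrapper around MAR

    # class-conditional sampling: one class label per image in the batch
    labels = torch.randint(0, config.class_num, (4,))
    tokens = model.sample_tokens(bsz=4, num_iter=64, cfg=1.0,
                                 cfg_schedule="linear", labels=labels,
                                 temperature=1.0, progress=False)

    # persist the weights (safetensors) and config.json to a local folder
    model.save_pretrained("./mar-checkpoint")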