from datetime import datetime
from typing import Dict, Union, Optional

import deepspeed
import torch
import PIL.Image
from torch.nn.functional import softmax, gumbel_softmax
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled

from .utils import BEGIN_LINE, END_LINE, rank0_print

MODEL_TYPE = "clip_visual_tokenizer"


class BaseVisualTokenizerConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size=16384,
                 tokenize_function="softmax",
                 tau=1.0,
                 depths=None,
                 use_indicators=False,
                 drop_cls_token=False,
                 backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
                 hidden_stride: int = 1,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        if isinstance(depths, str):
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        self.backbone_kwargs = {}
        self.use_indicators = use_indicators
        self.drop_cls_token = drop_cls_token
        if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
            if not isinstance(backbone_config, PretrainedConfig):
                model_type = backbone_config['model_type']
                backbone_config.pop('model_type')
                backbone_config = AutoConfig.for_model(model_type, **backbone_config)
        self.backbone_config = backbone_config
        self.hidden_stride = hidden_stride


class BaseVisualTokenizer(PreTrainedModel):
    base_model_prefix = "backbone"
    main_input_name = None
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None
    _backbone_name_or_path = None

    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        if kwargs.get('train_from_scratch'):
            self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
                                                                               **self._image_processor_kwargs)
            self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
                                                                 **self.config.backbone_kwargs)
            self.config.backbone_config = self.backbone.config
        else:
            self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
            self.backbone = AutoModel.from_config(self.config.backbone_config)
        self.head = None  # the tokenization head is defined by subclasses

        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        return self.backbone

    def get_monitor_tensors(self):
        raise NotImplementedError

    def get_image_processor(self):
        return self.image_processor

    def get_head(self):
        return self.head

    def get_image_size(self):
        raise NotImplementedError

    def preprocess_image(self, image: PIL.Image.Image, convert_to_rgb=True):
        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        # first resize (keeping the aspect ratio) and preprocess
        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]

        width, height = image.size
        if width == height:
            new_width = new_height = side
        elif width > height:
            new_width = side
            new_height = int(height / width * new_width)
        else:
            new_height = side
            new_width = int(width / height * new_height)
        new_size = dict(height=new_height, width=new_width)
        pixel_values = self.image_processor.preprocess(image, size=new_size, return_tensors='pt')['pixel_values']

        # then zero-pad to a square of shape [1, 3, side, side]
        square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
        new_height, new_width = pixel_values.shape[2:]
        if new_height == new_width:
            square_values[:, :, :, :] = pixel_values
        elif new_height > new_width:
            from_index = (side - new_width) // 2
            square_values[:, :, :, from_index:from_index + new_width] = pixel_values
        else:
            from_index = (side - new_height) // 2
            square_values[:, :, from_index:from_index + new_height, :] = pixel_values

        return square_values

    def get_layer_norm(self):
        return self.layer_norm

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through argmax: hard one-hot forward, soft gradients backward
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax, '
                f'but got {self.config.tokenize_function}')
        return tokens
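

# Illustrative sketch (hypothetical helper, not referenced by the classes in this
# module): it demonstrates the straight-through argmax trick used by
# `BaseVisualTokenizer.tokenize` when `tokenize_function == 'st_argmax'`: the forward
# pass emits a hard one-hot token while gradients flow through the soft probabilities.
def _st_argmax_sketch():
    logits = torch.randn(2, 5, requires_grad=True)  # [BatchSize, VocabSize], toy values
    y_soft = softmax(logits, dim=-1)
    index = y_soft.max(dim=-1, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    tokens = y_hard - y_soft.detach() + y_soft  # one-hot values, differentiable w.r.t. logits
    tokens.sum().backward()  # gradients reach `logits` despite the hard argmax
    return tokens, logits.grad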


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    model_type = MODEL_TYPE

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.depths:
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


class ClipVisualTokenizer(BaseVisualTokenizer):
    config_class = ClipVisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["CLIPEncoderLayer"]
    _image_processor_class = CLIPImageProcessor
    _image_processor_kwargs = dict(do_center_crop=False)
    _backbone_class = CLIPVisionModel
    _backbone_name_or_path = "openai/clip-vit-large-patch14-336"

    def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        head_dim = self.config.vocab_size
        if self.config.use_indicators:
            head_dim -= 2  # reserved for two image indicator tokens
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
            torch.nn.LayerNorm(head_dim)
        )

    def re_init_layers(self, re_init_layer_begin):
        layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
        for name, layer in layer_dict.items():
            rank0_print(BEGIN_LINE)
            rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
                if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
                    layer.apply(self.backbone._init_weights)
            rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            rank0_print(END_LINE)

    def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
        assert re_init_layer_begin >= 0, "negative index is prohibited"
        layer_dict = dict()
        for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
            layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
        return layer_dict

    def get_monitor_tensors(self):
        return dict(
            backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
            backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
            head=self.head[0].weight
        )

    def get_image_size(self):
        height = self.image_processor.crop_size["height"]
        width = self.image_processor.crop_size["width"]
        return height, width

    def forward(self, pixel_values) -> Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
        output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
        features = output.last_hidden_state
        if self.config.drop_cls_token:
            features = features[:, 1:, :]
        logits = self.head(features)
        tokens = self.tokenize(logits)
        if self.config.use_indicators:
            # tokens' shape is [BatchSize, #Token, VocabSize-2]; pad with zeros of shape
            # [BatchSize, #Token, 2] so that tokens' shape becomes [BatchSize, #Token, VocabSize]
            batch_size, token_len, _ = tokens.shape
            padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
                                         dtype=tokens.dtype,
                                         device=tokens.device,
                                         layout=tokens.layout,
                                         requires_grad=False)
            tokens = torch.cat((tokens, padding_tensor), dim=2)

            # add the begin/end indicator tokens, after which tokens' shape becomes
            # [BatchSize, 1+#Token+1, VocabSize]
            begin_indicator = torch.zeros(size=(batch_size, 1),
                                          dtype=torch.long,
                                          device=tokens.device,
                                          requires_grad=False) + self.config.vocab_size - 2
            begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
                                                                num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            end_indicator = torch.zeros(size=(batch_size, 1),
                                        dtype=torch.long,
                                        device=tokens.device,
                                        requires_grad=False) + self.config.vocab_size - 1
            end_indicator_token = torch.nn.functional.one_hot(end_indicator,
                                                              num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
        return tokens


AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
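

# Minimal usage sketch, assuming the CLIP checkpoint named above is reachable and that
# `train_from_scratch=True` is passed so the backbone and image processor are loaded
# from `_backbone_name_or_path`; "example.jpg" is a placeholder path. Because of the
# relative import at the top, run this as a module (`python -m <package>.<this_module>`).
if __name__ == "__main__":
    demo_config = ClipVisualTokenizerConfig(vocab_size=16384, tokenize_function="softmax", use_indicators=True)
    visual_tokenizer = ClipVisualTokenizer(demo_config, train_from_scratch=True)
    demo_image = PIL.Image.open("example.jpg")  # placeholder image path
    demo_pixel_values = visual_tokenizer.preprocess_image(demo_image)  # [1, 3, 336, 336]
    with torch.no_grad():
        demo_tokens = visual_tokenizer(demo_pixel_values)  # [1, 1 + #Token + 1, vocab_size]
    print(demo_tokens.shape)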