from datetime import datetime
from typing import Dict

import deepspeed
import torch
from torch import Tensor
from transformers import AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled

from .utils import BEGIN_LINE, END_LINE, rank0_print
from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer

MODEL_TYPE = "clip_visual_tokenizer"


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
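    # Config for the CLIP visual tokenizer. If `depths` is given, it must hold a
    # single value, which overrides the backbone's number of hidden layers.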
    model_type = MODEL_TYPE

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.depths:
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


class ClipVisualTokenizer(BaseVisualTokenizer):
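    # Visual tokenizer backed by a CLIP ViT encoder: images are encoded by the
    # backbone and patch features are projected to logits over a visual
    # vocabulary (see forward()).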
    config_class = ClipVisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["CLIPEncoderLayer"]
    _image_processor_class = CLIPImageProcessor
    _image_processor_kwargs = dict(do_center_crop=False)
    _backbone_class = CLIPVisionModel
    _backbone_name_or_path = "openai/clip-vit-large-patch14-336"

    def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        head_dim = self.config.vocab_size
        if self.config.use_indicators:
            head_dim -= 2  # reserved for two image indicator tokens
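        # Linear projection from backbone hidden states to visual-vocabulary logits.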
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
            torch.nn.LayerNorm(head_dim)
        )

    def re_init_layers(self, re_init_layer_begin):
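        # Re-initialize backbone encoder layers from `re_init_layer_begin` onward,
        # gathering ZeRO-3 partitioned parameters so the weights are materialized
        # before printing and re-initialization.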
        layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
        for name, layer in layer_dict.items():
            rank0_print(BEGIN_LINE)
            rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
                if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
                    layer.apply(self.backbone._init_weights)
            rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            rank0_print(END_LINE)

    def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
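        # Collect encoder layers from `re_init_layer_begin` (inclusive) through the last layer.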
        assert re_init_layer_begin >= 0, "negative index is prohibited"
        layer_dict = dict()
        for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
            layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
        return layer_dict

    def get_monitor_tensors(self):
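        # Representative weights from the bottom and top of the backbone plus the
        # head, useful for monitoring training dynamics.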
        return dict(
            backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
            backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
            head=self.head[0].weight
        )

    def get_image_size(self):
        height = self.image_processor.crop_size["height"]
        width = self.image_processor.crop_size["width"]
        return height, width

    def forward(self, pixel_values) -> Tensor:  # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize] (two indicator tokens are added when enabled)
        output = self.backbone(
            pixel_values, output_hidden_states=True, return_dict=True)
        features = output.last_hidden_state
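        # Optionally drop the leading CLS token so only patch features are tokenized.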
        if self.config.drop_cls_token:
            features = features[:, 1:, :]
        logits = self.head(features)
        tokens = self.tokenize(logits)
        if self.config.use_indicators:
            # tokens has shape [BatchSize, #Token, VocabSize-2]; pad it with zeros of shape
            # [BatchSize, #Token, 2] so that it becomes [BatchSize, #Token, VocabSize]
            batch_size, token_len, _ = tokens.shape
            padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
                                         dtype=tokens.dtype,
                                         device=tokens.device,
                                         layout=tokens.layout,
                                         requires_grad=False)
            tokens = torch.cat((tokens, padding_tensor), dim=2)

            # prepend/append the begin/end indicator tokens, giving shape [BatchSize, 1+#Token+1, VocabSize]
            begin_indicator = torch.zeros(size=(batch_size, 1),
                                          dtype=torch.long,
                                          device=tokens.device,
                                          requires_grad=False) + self.config.vocab_size - 2
            begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
                                                                num_classes=self.config.vocab_size).to(
                dtype=tokens.dtype)
            end_indicator = torch.zeros(size=(batch_size, 1),
                                        dtype=torch.long,
                                        device=tokens.device,
                                        requires_grad=False) + self.config.vocab_size - 1
            end_indicator_token = torch.nn.functional.one_hot(end_indicator,
                                                              num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
        return tokens


AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
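
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# that BaseVisualTokenizer builds the CLIP backbone and image processor when
# constructed from a config, and that `vocab_size` / `use_indicators` are valid
# BaseVisualTokenizerConfig fields (as referenced in forward() above).
#
#   config = ClipVisualTokenizerConfig(vocab_size=8192, use_indicators=True)
#   visual_tokenizer = AutoModel.from_config(config)
#   height, width = visual_tokenizer.get_image_size()
#   pixel_values = torch.randn(1, 3, height, width)
#   tokens = visual_tokenizer(pixel_values)  # [1, 1 + #Token + 1, vocab_size]
# ---------------------------------------------------------------------------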