File size: 12,215 Bytes
3e0c00e cc1340e e352a62 3e0c00e e352a62 cc1340e 3e0c00e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.hooks import add_hook_to_module
from einops import rearrange
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.siglip import SiglipVisionModel
from s2wrapper import forward as multiscale_forward
# from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
def is_deepspeed_zero3_enabled() -> bool:
    """Local stand-in for the transformers DeepSpeed helper of the same name.

    The real import is commented out in this module; this stub unconditionally
    reports that ZeRO-3 is *not* active, so every ZeRO-3-only code path in this
    file (parameter gathering via ``deepspeed.zero``) is skipped.
    """
    zero3_active = False
    return zero3_active
class VisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = getattr(args, "mm_vision_select_layer", -2)
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.cfg_only = None
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def _maybe_resize_pos_embeds(
self,
model: PreTrainedModel,
image_processor: BaseImageProcessor,
resolution: int = -1,
interpolate_mode: str = "linear",
):
if resolution in [model.config.image_size, -1]:
return
print(
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
)
embeddings = model.vision_model.embeddings
patch_size = embeddings.patch_size
num_new_tokens = int((resolution // patch_size) ** 2)
old_embeddings = embeddings.position_embedding
match interpolate_mode:
case "linear":
## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
import torch
import torch.nn as nn
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
new_embeddings = nn.Embedding(
num_new_tokens,
old_embedding_dim,
dtype=old_embeddings.weight.dtype,
device=old_embeddings.weight.device,
)
mapped_indices = (
torch.arange(num_new_tokens).to(old_embeddings.weight.device)
/ (num_new_tokens - 1)
* (old_num_tokens - 1)
)
floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
if is_deepspeed_zero3_enabled():
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
ceil_indices, :
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
else:
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
ceil_indices, :
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
new_embeddings.weight.data = interpolated_embeds
case _:
raise NotImplementedError
if hasattr(old_embeddings, "_hf_hook"):
hook = old_embeddings._hf_hook
add_hook_to_module(new_embeddings, hook)
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
## update vision encoder's configurations
model.config.image_size = resolution
if hasattr(image_processor, "crop_size"):
# CLIP vision tower
image_processor.crop_size = resolution
else:
# SIGLIP vision tower
assert hasattr(image_processor, "size")
image_processor.size = {"height": resolution, "width": resolution}
## TODO define a '_reinitialize' method for VisionTower
embeddings.position_embedding = new_embeddings
embeddings.image_size = resolution
embeddings.num_patches = embeddings.num_positions = num_new_tokens
embeddings.position_ids = (
torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
)
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class VisionTowerS2(VisionTower):
    """Vision tower that encodes images at several scales via ``multiscale_forward``.

    Reads ``args.s2_scales`` (comma-separated integers, sorted ascending here),
    ``args.s2_max_split_size`` and optionally ``args.s2_resize_output_to_scale_idx``
    (default 0), and passes them to ``multiscale_forward``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)
        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        """Single encoder pass used as the per-scale callback of ``multiscale_forward``."""
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def _multiscale(self, images):
        """Run ``forward_feature`` across all configured scales for one batched tensor."""
        return multiscale_forward(
            self.forward_feature,
            images,
            img_sizes=self.scales,
            max_split_size=self.max_split_size,
            resize_output_to_idx=self.resize_output_to_scale_idx,
        )

    def forward(self, images):
        """Encode images at every configured scale.

        A list input is processed one tensor at a time (each gets a leading batch
        dim via ``unsqueeze(0)``) and a list of feature tensors is returned; any
        other input is processed as a single batch.
        """
        if type(images) is list:
            # Accumulate into the list that is returned (previously this appended to
            # an unbound name and raised on every list input).
            image_features = []
            for image in images:
                image_features.append(self._multiscale(image.unsqueeze(0)))
            return image_features
        return self._multiscale(images)

    @property
    def hidden_size(self):
        # Base hidden size times the number of scales (presumably because per-scale
        # features are concatenated along the channel dim — confirm in s2wrapper).
        return self.config.hidden_size * len(self.scales)
class VisionTowerDynamicS2(VisionTower):
    """Vision tower holding S2 scale settings but running a single encoder pass.

    Parses the same ``args.s2_*`` settings as the multi-scale tower, yet ``forward``
    only accepts one batched tensor (lists are rejected) and encodes it once;
    ``hidden_size`` still reports base hidden size times the number of scales.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)
        parsed_scales = [int(token) for token in args.s2_scales.split(",")]
        parsed_scales.sort()
        self.scales = parsed_scales
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        """Encode ``images`` once and return the selected features in the input dtype."""
        moved = images.to(device=self.device, dtype=self.dtype)
        encoder_outputs = self.vision_tower(moved, output_hidden_states=True)
        return self.feature_select(encoder_outputs).to(images.dtype)

    def forward(self, images):
        # Only batched tensors are supported here.
        assert type(images) is not list
        return self.forward_feature(images)

    @property
    def hidden_size(self):
        scale_count = len(self.scales)
        return self.config.hidden_size * scale_count
class SiglipVisionTower(VisionTower):
    """Single-scale SigLIP tower whose encoder and processor load in ``__init__``."""

    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        # TODO(ligengl): why pass config here leading to errors?
        # NOTE(review): ``eval(config.model_dtype)`` executes arbitrary code if the
        # config is untrusted; a lookup like ``getattr(torch, name.removeprefix("torch."))``
        # would be safer — confirm the accepted value format before replacing.
        load_kwargs = {
            "attn_implementation": config._attn_implementation,
            "torch_dtype": eval(config.model_dtype),
        }
        self.vision_tower = SiglipVisionModel.from_pretrained(model_name_or_path, **load_kwargs)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.is_loaded = True
class SiglipVisionTowerS2(VisionTowerS2):
    """Multi-scale SigLIP tower; the processor targets the largest configured scale."""

    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        # NOTE(review): ``eval`` on a config value executes arbitrary code if the
        # config is untrusted — consider a safe dtype lookup.
        load_kwargs = {
            "attn_implementation": config._attn_implementation,
            "torch_dtype": eval(config.model_dtype),
        }
        self.vision_tower = SiglipVisionModel.from_pretrained(model_name_or_path, **load_kwargs)
        processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # ``self.scales`` is sorted ascending by VisionTowerS2, so the last entry is the
        # largest; resizing/cropping to it keeps high-resolution information.
        largest_scale = self.scales[-1]
        processor.size["height"] = largest_scale
        processor.size["width"] = largest_scale
        self.image_processor = processor
        self.is_loaded = True
class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
    """SigLIP tower for the dynamic-S2 path; the processor targets the smallest scale.

    NOTE(review): unlike the other SigLIP towers, this hard-codes
    ``attn_implementation="flash_attention_2"`` instead of reading
    ``config._attn_implementation``; loading presumably requires flash-attn.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        # NOTE(review): ``eval`` on a config value executes arbitrary code if the
        # config is untrusted — consider a safe dtype lookup.
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation="flash_attention_2",
            torch_dtype=eval(config.model_dtype),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Resize/crop inputs to the SMALLEST scale: VisionTowerDynamicS2 sorts
        # ``self.scales`` ascending, so ``scales[0]`` is the minimum. Presumably
        # high-resolution detail comes from tiling done upstream — confirm.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
        self.is_loaded = True
|