# ------------------------------------------------------------------------
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA) and S^2 (https://github.com/bfshi/scaling_on_scales)
# Copyright 2024 Jiachen Li
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
import torch.nn.functional as F
from transformers.activations import ACT2FN
import math
from einops import rearrange
from .clip import CLIPVisionTransformer
from .clip_smoe import CLIPSMoEVisionTransformer


class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.clip_smoe = args.clip_smoe
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.scales = args.scales
        if args.clip_smoe:
            self.vision_model = CLIPSMoEVisionTransformer(self.cfg_only, num_experts=args.num_experts, num_selected=args.num_selected)
        else:
            self.vision_model = CLIPVisionTransformer(self.cfg_only)
        self.is_loaded = True

    def feature_select(self, image_features):
        # image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    def split_chessboard(self, x, num_split):
        """
        x: b * c * h * w
        Divide x into num_split**2 sub-squares and concatenate the sub-squares along the batch dimension.
        """
        B, C, H, W = x.shape
        assert H % num_split == 0 and W % num_split == 0
        h, w = H // num_split, W // num_split
        x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)], dim=0)
        return x_split

    def merge_chessboard(self, x, num_split):
        """
        x: b * c * h * w
        Assuming x contains num_split**2 sub-squares concatenated along the batch dimension, merge the sub-squares back into the original whole square
        (inverse of split_chessboard).
        """
        B, C, H, W = x.shape
        assert B % (num_split**2) == 0
        b = B // (num_split**2)
        x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b] for j in range(num_split)], dim=-1) for i in range(num_split)], dim=-2)
        return x_merge

    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_model(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            input_size = images.shape[3]
            img_sizes = [int(input_size * scale) for scale in self.scales]
            num_splits = [math.ceil(size / input_size) for size in img_sizes]
            image_pyramids = [images]
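            # Build a multi-scale pyramid (S^2-style): scale index 0 keeps the original
            # input, while each larger scale is upsampled with bicubic interpolation and
            # then split into input-sized tiles, so the backbone always sees its native
            # resolution.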
            for i, (size, num_split) in enumerate(zip(img_sizes, num_splits)):
                if i > 0:
                    x = F.interpolate(images.to(torch.float32), size=size, mode='bicubic').to(images.dtype)
                    x = self.split_chessboard(x, num_split=num_split)
                    image_pyramids.append(x)
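            # Encode each pyramid level separately; tiles from the larger scales are
            # merged back into a full feature map, pooled (area interpolation) down to
            # the base-scale token grid, and the per-scale features are concatenated
            # along the channel dimension. The SMoE variant also returns per-level
            # auxiliary routing losses.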
            if self.clip_smoe:
                image_features = []
                balance_losses = []
                router_z_losses = []
                for i, (x, num_split) in enumerate(zip(image_pyramids, num_splits)):
                    out_x, balance_loss, router_z_loss = self.vision_model(x)
                    out_x = self.feature_select(out_x)
                    if i > 0:
                        out_x = rearrange(out_x, 'b (h w) c -> b c h w', h=int(out_x.shape[1] ** 0.5), w=int(out_x.shape[1] ** 0.5))
                        out_x = self.merge_chessboard(out_x, num_split=num_split)
                        out_x = F.interpolate(out_x.to(torch.float32), size=int(image_features[0].shape[1] ** 0.5), mode='area').to(x.dtype)
                        out_x = rearrange(out_x, 'b c h w -> b (h w) c')
                    image_features.append(out_x)
                    balance_losses.append(balance_loss)
                    router_z_losses.append(router_z_loss)
                image_features = torch.cat(image_features, dim=-1)
                return image_features, torch.stack(balance_losses).mean(), torch.stack(router_z_losses).mean()
            else:
                image_features = []
                for i, (x, num_split) in enumerate(zip(image_pyramids, num_splits)):
                    out_x = self.vision_model(x)
                    out_x = self.feature_select(out_x)
                    if i > 0:
                        out_x = rearrange(out_x, 'b (h w) c -> b c h w', h=int(out_x.shape[1] ** 0.5), w=int(out_x.shape[1] ** 0.5))
                        out_x = self.merge_chessboard(out_x, num_split=num_split)
                        out_x = F.interpolate(out_x.to(torch.float32), size=int(image_features[0].shape[1] ** 0.5), mode='area').to(x.dtype)
                        out_x = rearrange(out_x, 'b c h w -> b (h w) c')
                    image_features.append(out_x)
                image_features = torch.cat(image_features, dim=-1)
                return image_features, None, None

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_model.dtype

    @property
    def device(self):
        return self.vision_model.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_model.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
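

# Minimal usage sketch (illustrative only; the checkpoint name, scale values, and
# expert counts below are assumptions, not part of this file):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch',
#                          clip_smoe=True, scales=[1.0, 2.0], num_experts=4, num_selected=2)
#   tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args)
#   feats, balance_loss, router_z_loss = tower(images)  # images: (B, 3, H, W) tensor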