# japanese-clip-vit-b-32-roberta-base / modeling_japanese_clip.py
# Hugging Face Hub file header: last commit "update" by hidehisa-arai (revision 391228d, verified).
import collections.abc
import math
from collections import OrderedDict
from itertools import repeat
from typing import Callable, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from transformers import AutoModel, PreTrainedModel
from .configuration_japanese_clip import JapaneseCLIPConfig
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input.

    Normalization is delegated to ``F.layer_norm`` using this module's
    ``normalized_shape``, affine ``weight``/``bias`` and ``eps``; the result
    is then converted to the dtype the caller passed in.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.to(dtype=input_dtype)
class LayerScale(nn.Module):
    """Learnable per-channel scale applied over the last dimension.

    Args:
        dim: number of channels; ``gamma`` has shape ``(dim,)``.
        init_values: initial value of every entry of ``gamma``.
        inplace: if True, ``forward`` scales the input tensor in place
            (``mul_``) instead of allocating a new tensor.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)
class PatchDropout(nn.Module):
    """Randomly keep only a subset of tokens while training.

    Reference: https://arxiv.org/abs/2212.00794

    Args:
        prob: fraction of tokens to drop; must lie in ``[0, 1)``.
        exclude_first_token: if True the first token (CLS) is never dropped
            and is re-attached in front of the surviving tokens.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.0
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        # Identity in eval mode, or when nothing is to be dropped.
        if self.prob == 0. or not self.training:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Unused on this path; keeps ``cls_tokens`` typed as a Tensor on
            # both branches (torch.jit.annotate is a TorchScript type hint).
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]
        n_keep = max(1, int(num_tokens * (1 - self.prob)))

        # Random score per token; the ``n_keep`` highest-scoring positions survive.
        scores = torch.randn(batch, num_tokens)
        keep_idx = scores.topk(n_keep, dim=-1).indices
        rows = torch.arange(batch)[..., None]
        x = x[rows, keep_idx]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)
        return x
class AttentionalPooler(nn.Module):
    """Pool a token sequence into ``n_queries`` vectors via learned-query attention.

    A learned query matrix of shape ``(n_queries, d_model)`` cross-attends to
    the (normalized) input tokens. Input is batch-first ``(N, L, context_dim)``;
    the output is batch-first ``(N, n_queries, d_model)``.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        n_head: int = 8,
        n_queries: int = 256,
        norm_layer: Callable = LayerNorm
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        # nn.MultiheadAttention is sequence-first here: (N, L, C) -> (L, N, C).
        context = self.ln_k(x).permute(1, 0, 2)
        batch_size = context.shape[1]
        # Share the same learned queries across the batch: (Q, D) -> (Q, N, D).
        queries = self.ln_q(self.query).unsqueeze(1).expand(-1, batch_size, -1)
        pooled, _ = self.attn(queries, context, context, need_weights=False)
        return pooled.permute(1, 0, 2)  # (Q, N, D) -> (N, Q, D)
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: residual attention followed by a residual MLP.

    ``x + ls_1(attn(ln_1(x)))`` then ``x + ls_2(mlp(ln_2(x)))``. The attention
    layer is ``nn.MultiheadAttention`` with its default (sequence-first)
    layout. When built with ``is_cross_attention=True`` an extra ``ln_1_kv``
    norm is created and ``forward`` uses the given keys/values; otherwise any
    keys/values passed to ``forward`` are ignored and the block self-attends.

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        ls_init_value: LayerScale init value; ``None`` disables LayerScale.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization layer class.
        is_cross_attention: create ``ln_1_kv`` and enable cross-attention.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        is_cross_attention: bool = False,
    ):
        super().__init__()
        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            # Its presence (checked with hasattr in forward) switches on cross-attention.
            self.ln_1_kv = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run multi-head attention and return only the attended output.

        Keys/values default to the queries (self-attention). Floating-point
        (additive) masks are cast to the query dtype; other masks, e.g. boolean
        masks where True marks a blocked position, are passed through unchanged
        so ``nn.MultiheadAttention`` interprets them correctly.
        """
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        if attn_mask is not None and attn_mask.is_floating_point():
            # Casting a bool mask to float would turn "blocked" (True) into +1.0
            # added to the scores, silently disabling the mask; cast floats only.
            attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(
            q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
        )[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        # Keys/values are only normalized and used when ln_1_kv exists
        # (cross-attention block); otherwise they are discarded.
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
# From PyTorch internals
def _ntuple(n):
    """Return a converter that broadcasts a scalar to an ``n``-tuple.

    Any iterable (tuples, lists, even strings) is returned unchanged; any
    other value ``x`` becomes ``(x, x, ..., x)`` with ``n`` entries.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return (x,) * n
        return x

    return parse


to_2tuple = _ntuple(2)
def _expand_token(token, batch_size: int):
    """Broadcast ``token`` to shape ``(batch_size, 1, token.numel())`` as a view.

    The token is flattened with ``view`` (so it must be view-compatible) and
    then expanded along a new batch dimension without copying data.
    """
    flat = token.view(1, 1, -1)
    return flat.expand(batch_size, -1, -1)
class Transformer(nn.Module):
    """A stack of ``layers`` identical :class:`ResidualAttentionBlock` modules.

    The blocks use ``nn.MultiheadAttention`` in its default sequence-first
    layout, so ``forward`` expects ``(seq, batch, width)`` input. Gradient
    checkpointing can be toggled with the ``grad_checkpointing`` flag.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's MLP input projection.

        If that layer carries an ``int8_original_dtype`` attribute (presumably
        set by an int8 conversion utility — confirm with its caller), that
        recorded dtype is returned instead of the current weight dtype.
        """
        first_fc = self.resblocks[0].mlp.c_fc
        if hasattr(first_fc, 'int8_original_dtype'):
            return first_fc.int8_original_dtype
        return first_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        for block in self.resblocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # checkpoint() is given positional args only: (q_x, k_x=None, v_x=None, attn_mask).
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(block, x, None, None, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
class JapaneseCLIPVisionTransformer(nn.Module):
    """OpenCLIP-style Vision Transformer used as the image tower.

    Pipeline: strided ``conv1`` patch embedding -> prepend a learned class
    token -> add learned positional embeddings -> optional patch dropout ->
    ``ln_pre`` -> :class:`Transformer` -> pooling -> ``ln_post`` -> ``proj``.

    Unknown keyword arguments are accepted and ignored (``**kwargs``), which
    lets a whole config dict be splatted into the constructor.
    """

    output_tokens: torch.jit.Final[bool]

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.,
        no_ln_pre: bool = False,
        pool_type: str = 'tok',
        final_ln_after_pool: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = LayerNorm,
        output_tokens: bool = False,
        **kwargs,
    ):
        """
        Args:
            image_size: input resolution; an int or a ``(height, width)`` pair.
            patch_size: patch size; an int or a ``(height, width)`` pair.
            width: transformer hidden size.
            layers: number of transformer blocks.
            heads: attention heads per block.
            mlp_ratio: MLP hidden size as a multiple of ``width``.
            ls_init_value: LayerScale init value; ``None`` disables LayerScale.
            attentional_pool: falsy for token/average pooling, ``True`` for the
                single attentional pooler (OpenCLIP CoCa setup), or
                ``'parallel'`` / ``'cascade'`` for the two-pooler variants.
            attn_pooler_queries: learned queries of the main attentional pooler.
            attn_pooler_heads: attention heads of the attentional pooler(s).
            output_dim: size of the output embedding.
            patch_dropout: fraction of patch tokens dropped while training
                (``0`` disables patch dropout).
            no_ln_pre: skip the normalization applied before the transformer.
            pool_type: ``'tok'`` (class token), ``'avg'`` (mean of patch
                tokens) or ``'none'`` (keep the whole sequence).
            final_ln_after_pool: apply ``ln_post`` after pooling instead of
                before it (ignored when attentional pooling is enabled).
            act_layer: activation class used in the transformer MLPs.
            norm_layer: normalization layer class.
            output_tokens: if True, ``forward`` also returns the token sequence.

        Raises:
            ValueError: if ``pool_type`` is unsupported, or ``attentional_pool``
                is a non-empty string other than ``'parallel'``/``'cascade'``.
        """
        super().__init__()
        # Explicit exceptions rather than ``assert``: asserts are stripped under
        # ``python -O`` and an invalid value would then silently take the wrong path.
        if pool_type not in ('tok', 'avg', 'none'):
            raise ValueError(
                f"pool_type must be one of 'tok', 'avg', 'none'; got {pool_type!r}"
            )
        self.output_tokens = output_tokens
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool enabled
        self.output_dim = output_dim

        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False
        )

        # class embeddings and positional embeddings (one extra position for the class token)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        if attentional_pool:
            if isinstance(attentional_pool, str):
                if attentional_pool not in ('parallel', 'cascade'):
                    raise ValueError(
                        "attentional_pool must be True/False, 'parallel' or 'cascade'; "
                        f"got {attentional_pool!r}"
                    )
                self.attn_pool_type = attentional_pool
                self.pool_type = 'none'
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=1,
                )
            else:
                self.attn_pool_type = ''
                self.pool_type = pool_type
                self.attn_pool = AttentionalPooler(
                    output_dim,
                    width,
                    n_head=attn_pooler_heads,
                    n_queries=attn_pooler_queries,
                )
                self.attn_pool_contrastive = None
            pool_dim = output_dim
        else:
            self.attn_pool = None
            pool_dim = width
            self.pool_type = pool_type

        self.ln_post = norm_layer(pool_dim)
        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze every parameter, then unfreeze the last ``unlocked_groups`` groups.

        Groups, first to last: the stem (``conv1``, class/positional
        embeddings, ``ln_pre``), each transformer block except the last, the
        last block together with ``ln_post``, and finally ``proj``.
        ``freeze_bn_stats`` is accepted for interface compatibility but unused.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                # Recurse into lists of groups; a group is a bare Parameter or a Module.
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # Intentionally a no-op: parameters keep the scaled-randn init from
        # __init__ and the PyTorch module defaults.
        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
        # TODO experiment if default PyTorch init, below, or alternate init is best.

        # nn.init.normal_(self.class_embedding, std=self.scale)
        # nn.init.normal_(self.positional_embedding, std=self.scale)
        #
        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        # attn_std = self.transformer.width ** -0.5
        # fc_std = (2 * self.transformer.width) ** -0.5
        # for block in self.transformer.resblocks:
        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        #
        # if self.text_projection is not None:
        #     nn.init.normal_(self.text_projection, std=self.scale)
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing in the transformer blocks."""
        self.transformer.grad_checkpointing = enable

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a ``(N, L, D)`` sequence into ``(pooled, tokens)`` per ``pool_type``.

        ``'avg'``: mean of tokens 1.., ``'tok'``: token 0; in both cases
        ``tokens`` excludes token 0. ``'none'``: both are the full sequence.
        """
        if self.pool_type == 'avg':
            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
        elif self.pool_type == 'tok':
            pooled, tokens = x[:, 0], x[:, 1:]
        else:
            pooled = tokens = x

        return pooled, tokens

    def forward(self, x: torch.Tensor):
        """Encode a batch of images.

        Args:
            x: image batch with 3 channels; assumed ``(N, 3, H, W)`` with
                ``(H, W) == image_size``, since the positional embedding table
                is sized for that patch grid — TODO confirm preprocessing.

        Returns:
            ``pooled`` (``(N, output_dim)`` for the default 'tok'/'avg' pooling;
            other pooling setups may keep a sequence dimension), or
            ``(pooled, tokens)`` when ``output_tokens`` is set.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # prepend the class token, then add positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        x = self.patch_dropout(x)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if self.attn_pool is not None:
            if self.attn_pool_contrastive is not None:
                # This is untested, WIP pooling that should match paper
                x = self.ln_post(x)  # TBD LN first or separate one after each pool?
                tokens = self.attn_pool(x)
                if self.attn_pool_type == 'parallel':
                    pooled = self.attn_pool_contrastive(x)
                else:
                    # Invariant: __init__ only allows 'parallel' or 'cascade' here.
                    assert self.attn_pool_type == 'cascade'
                    pooled = self.attn_pool_contrastive(tokens)
            else:
                # this is the original OpenCLIP CoCa setup, does not match paper
                x = self.attn_pool(x)
                x = self.ln_post(x)
                pooled, tokens = self._global_pool(x)
        elif self.final_ln_after_pool:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
        else:
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)

        if self.proj is not None:
            pooled = pooled @ self.proj

        if self.output_tokens:
            return pooled, tokens

        return pooled
class JapaneseCLIPModel(PreTrainedModel):
    """CLIP dual encoder pairing a ViT image tower with a Hugging Face text tower.

    Both towers project into a shared space of size ``projection_dim``
    (equal to the image encoder's ``output_dim``). ``forward`` returns the
    unnormalized image and text features together with the learnable
    ``logit_scale``; computing logits/loss is left to the caller.
    """

    config_class = JapaneseCLIPConfig

    def __init__(self, config: JapaneseCLIPConfig):
        super().__init__(config)
        text_config = config.text_config
        vision_config = config.vision_config

        self.image_encoder = JapaneseCLIPVisionTransformer(
            **vision_config.to_dict()
        )
        # The first-token hidden state is pooled manually in get_text_features,
        # so the text model's own pooling layer is not created.
        self.text_encoder = AutoModel.from_config(text_config, add_pooling_layer=False)
        hidden_size = text_config.hidden_size

        self.projection_dim = self.image_encoder.output_dim
        self.text_projection = nn.Linear(hidden_size, self.projection_dim, bias=False)
        # Learnable temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        self.max_length = config.max_length
        # Kept for backward compatibility; position ids are now built on the fly.
        self.position_ids = list(range(0, self.max_length))

    def _create_position_id_tensor(
        self,
        batch_size: int,
        seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> torch.LongTensor:
        """Build explicit 0-based position ids: one row ``[0, ..., seq_len - 1]`` per sample.

        rinna/japanese-roberta-base requires providing custom position ids
        see: https://huggingface.co/rinna/japanese-roberta-base#note-3-provide-position_ids-as-an-argument-explicitly

        Args:
            batch_size: number of rows.
            seq_len: row length; defaults to ``max_length`` (the previous behavior).
            device: device to create the tensor on (``None`` = default tensor device).

        Returns:
            int64 tensor of shape ``(batch_size, seq_len)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_length``.
        """
        length = self.max_length if seq_len is None else seq_len
        if length > self.max_length:
            raise ValueError(
                f"sequence length {length} exceeds max_length {self.max_length}"
            )
        row = torch.arange(length, dtype=torch.long, device=device)
        # repeat (not expand) so the result is a dense tensor like the former list-built one
        return row.unsqueeze(0).repeat(batch_size, 1)

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Encode images with the vision tower.

        Returns ``(batch_size, projection_dim)`` for the default pooling; a
        ``(pooled, tokens)`` tuple if the vision config enables ``output_tokens``.
        """
        return self.image_encoder(pixel_values)

    def get_text_features(
        self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None
    ) -> torch.FloatTensor:
        """Encode token ids; the projected first-token hidden state is the text embedding.

        Args:
            input_ids: ``(batch_size, seq_len)`` token ids with ``seq_len <= max_length``.
            position_ids: optional explicit position ids; when omitted, 0-based
                ids matching ``input_ids``' length are created on its device.

        Returns:
            ``(batch_size, projection_dim)`` text features.
        """
        if position_ids is None:
            # Sized to the actual sequence length so inputs shorter than
            # max_length work instead of failing with a shape mismatch.
            position_ids = self._create_position_id_tensor(
                input_ids.size(0), seq_len=input_ids.size(1), device=input_ids.device
            )
        last_hidden_state = self.text_encoder(
            input_ids=input_ids,
            position_ids=position_ids,
            output_hidden_states=True,
            return_dict=True,
        ).hidden_states[-1]  # (batch_size, tokens, embed_dim)
        pooled_output = last_hidden_state[:, 0, :]  # (batch_size, embed_dim)
        return self.text_projection(pooled_output)  # (batch_size, projection_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.Tensor]:
        """Encode images and texts in a single call.

        When training with DDP, features must be computed through this method:
        gradients obtained via the other methods are not synchronized across GPUs.

        Returns:
            ``(image_features, text_features, logit_scale)``.
        """
        image_features = self.get_image_features(pixel_values)
        text_features = self.get_text_features(input_ids, position_ids)
        return image_features, text_features, self.logit_scale