Spaces:
Running
Running
import math, torch | |
from functools import partial | |
from torch import nn, Tensor | |
from torchvision.transforms.functional import normalize | |
from transformers import AutoModel | |
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD | |
from .configuration_live import LiveConfigMixin | |
def _siglip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), rescale_factor=0.00392156862745098, **kwargs):
    """Encode frames with a SigLIP vision tower into a sequence of visual tokens per frame.

    Args:
        vision_model: SigLIP vision tower; its output must expose ``last_hidden_state``
            (a 3-D tensor of patch tokens, no class token) and ``pooler_output``.
        frames: pixel tensor. It is multiplied by ``rescale_factor`` (1/255 by default),
            then channel-normalized with ``mean``/``std``. Presumably shaped
            (num_frames, 3, H, W) with values in [0, 255] — TODO confirm with callers.
        frame_token_cls: if true, emit ``pooler_output`` as one leading token per frame.
        frame_token_pooled: ``output_size`` for ``adaptive_avg_pool2d`` over the square
            patch grid (e.g. ``(3, 3)``). A falsy value disables spatial tokens.
        mean, std: per-channel normalization statistics (tuples to avoid shared mutable defaults).
        rescale_factor: multiplier applied to raw pixel values before normalization.
        **kwargs: accepted and ignored, so callers can pass extra options uniformly.

    Returns:
        Tensor of shape (B, T, D). T is 1 (cls only), h*w (pooled only), or 1 + h*w
        (cls token first, then pooled spatial tokens).

    Raises:
        ValueError: if neither token type is requested, or the patch count is not a
            perfect square.
    """
    # The original silently returned None in this case; fail loudly instead.
    if not frame_token_cls and not frame_token_pooled:
        raise ValueError('At least one of frame_token_cls or frame_token_pooled must be enabled')
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    # torch.autocast('cuda') is the non-deprecated equivalent of torch.cuda.amp.autocast().
    with torch.autocast('cuda'):
        vision_outputs = vision_model(frames)
        tokens = []
        if frame_token_cls:
            tokens.append(vision_outputs.pooler_output[:, None])  # (B, 1, D)
        if frame_token_pooled:
            hidden = vision_outputs.last_hidden_state
            batch, num_patches, dim = hidden.shape
            # SigLIP has no class token, so all N tokens form the side x side patch grid.
            side = math.isqrt(num_patches)
            if side * side != num_patches:
                raise ValueError(f'Expected a square patch grid, got {num_patches} patch tokens')
            grid = hidden.reshape(batch, side, side, dim).permute(0, 3, 1, 2)  # (B, D, s, s)
            pooled = torch.nn.functional.adaptive_avg_pool2d(grid, frame_token_pooled)
            tokens.append(pooled.flatten(2, 3).permute(0, 2, 1))  # (B, h*w, D)
        return tokens[0] if len(tokens) == 1 else torch.cat(tokens, dim=1)
def _clip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
                        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, rescale_factor=0.00392156862745098, **kwargs):
    """Encode frames with a CLIP vision tower into a sequence of visual tokens per frame.

    Args:
        vision_model: CLIP vision tower. Its ``last_hidden_state`` is a 3-D tensor whose
            token 0 is the class token and whose remaining tokens form a square patch grid.
        frames: pixel tensor. It is multiplied by ``rescale_factor`` (1/255 by default),
            then channel-normalized with ``mean``/``std``. Presumably shaped
            (num_frames, 3, H, W) with values in [0, 255] — TODO confirm with callers.
        frame_token_cls: if true, emit the class token (token 0 of ``last_hidden_state``)
            as one leading token per frame.
        frame_token_pooled: ``output_size`` for ``adaptive_avg_pool2d`` over the patch
            grid (e.g. ``(3, 3)``). A falsy value disables spatial tokens.
        mean, std: per-channel normalization statistics (OpenAI CLIP values by default).
        rescale_factor: multiplier applied to raw pixel values before normalization.
        **kwargs: accepted and ignored, so callers can pass extra options uniformly.

    Returns:
        Tensor of shape (B, T, D). T is 1 (cls only), h*w (pooled only), or 1 + h*w
        (cls token first, then pooled spatial tokens). The cls token is kept 3-D,
        matching the SigLIP encoder; the original returned (B, D) here, and its
        cls + pooled path crashed in ``torch.cat``.

    Raises:
        ValueError: if neither token type is requested, or the patch tokens do not form
            a square grid.
    """
    # The original silently returned None in this case; fail loudly instead.
    if not frame_token_cls and not frame_token_pooled:
        raise ValueError('At least one of frame_token_cls or frame_token_pooled must be enabled')
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    # torch.autocast('cuda') is the non-deprecated equivalent of torch.cuda.amp.autocast().
    with torch.autocast('cuda'):
        vision_outputs = vision_model(frames)
        hidden = vision_outputs.last_hidden_state
        tokens = []
        if frame_token_cls:
            # Slice (not index) so the token dim is kept: (B, 1, D), concatenable on dim=1.
            tokens.append(hidden[:, :1])
        if frame_token_pooled:
            patches = hidden[:, 1:]  # drop the class token
            batch, num_patches, dim = patches.shape
            side = math.isqrt(num_patches)
            if side * side != num_patches:
                raise ValueError(f'Expected a square patch grid, got {num_patches} patch tokens')
            grid = patches.reshape(batch, side, side, dim).permute(0, 3, 1, 2)  # (B, D, s, s)
            pooled = torch.nn.functional.adaptive_avg_pool2d(grid, frame_token_pooled)
            tokens.append(pooled.flatten(2, 3).permute(0, 2, 1))  # (B, h*w, D)
        return tokens[0] if len(tokens) == 1 else torch.cat(tokens, dim=1)
def build_live_vision(config: LiveConfigMixin):
    """Load the vision tower named by ``config.vision_pretrained`` and select its encode function.

    Only checkpoints verified to work with this module are accepted. The name is checked
    before anything is downloaded.

    Args:
        config: provides ``vision_pretrained`` (checkpoint name), ``frame_token_cls`` and
            ``frame_token_pooled``.

    Returns:
        ``(vision_model, encode_fn)``:
        - ``vision_model`` is the ``.vision_model`` submodule of the model loaded with
          ``AutoModel.from_pretrained``.
        - ``encode_fn`` is a ``functools.partial`` with ``frame_token_cls`` and
          ``frame_token_pooled`` bound as keywords. The caller still supplies
          ``vision_model`` and ``frames``.

    Raises:
        ValueError: if ``config.vision_pretrained`` is not a verified checkpoint.
    """
    siglip_checkpoints = ('google/siglip-large-patch16-384',)
    clip_checkpoints = ('laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k', 'openai/clip-vit-large-patch14-336')
    if config.vision_pretrained in siglip_checkpoints:
        encode = _siglip_vision_encode
    elif config.vision_pretrained in clip_checkpoints:
        encode = _clip_vision_encode
    else:
        raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}')
    model = AutoModel.from_pretrained(config.vision_pretrained).vision_model
    # Bind the config options by keyword for both families. The previous CLIP branch passed
    # `config` positionally, which landed in the `vision_model` slot and broke every call.
    return model, partial(encode, frame_token_cls=config.frame_token_cls,
                          frame_token_pooled=config.frame_token_pooled)