|
|
from functools import partial |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class BaseEncoder(nn.Module):
    """Common base for encoders that need access to the module that owns them.

    The owning module is kept inside a one-element list instead of being
    assigned directly: ``nn.Module`` registers any module-valued attribute as
    a submodule, whereas a plain Python list is stored as an ordinary
    attribute. (Presumably this avoids a parent <-> encoder registration
    cycle -- confirm against the owning model.)
    """

    def __init__(self, parent: nn.Module) -> None:
        """Store a reference to ``parent`` without registering it as a child."""
        super(BaseEncoder, self).__init__()
        # Wrapped in a list on purpose; see the class docstring.
        self._parent = [parent]

    @property
    def parent(self) -> nn.Module:
        """Return the owning module passed to the constructor."""
        holder = self._parent
        return holder[0]
|
|
|
|
|
|
|
|
class BasicImageEncoder(BaseEncoder):
    """Encode images through the parent model and wrap each result in token embeddings.

    Each per-image feature tensor is optionally prefixed with the embeddings
    of ``start_tokens`` and suffixed with the embeddings of ``end_tokens``.
    """

    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        """
        Args:
            parent: Owning module; must provide ``tokenizer``, ``device``,
                ``llm_model_embed_tokens`` and ``encode_images``.
            start_tokens: Text embedded and prepended to every image's
                features, or ``None`` to prepend nothing.
            end_tokens: Text embedded and appended to every image's
                features, or ``None`` to append nothing.
        """
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        """Tokenize ``tokens`` with the parent's tokenizer and embed the ids.

        Returns ``None`` when ``tokens`` is ``None``; otherwise the result of
        ``parent.llm_model_embed_tokens`` applied to the token ids, which are
        placed on ``parent.device``.
        """
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm_model_embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Concatenate the optional start/end embeddings around ``features`` along dim 0.

        ``features`` is returned untouched when both embeddings are ``None``.
        """
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

    def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
        """Encode a batch of images and return one feature tensor per image.

        Args:
            images: Image tensors; they are stacked along a new dim 0, so they
                must all share one shape.
            config: Options mapping; only ``"block_sizes"`` is read and it is
                forwarded to ``parent.encode_images``.
            device: Device every returned tensor is moved to.

        Returns:
            A list with one tensor per image, each wrapped with the start/end
            token embeddings. An empty input yields an empty list.
        """
        # torch.stack rejects an empty sequence, so answer the trivial case
        # directly. len() (rather than truthiness) also stays valid if a
        # caller hands in an already-batched tensor.
        if len(images) == 0:
            return []
        images = torch.stack(images, dim=0)
        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
        # Embed the wrapper tokens once and reuse them for every image.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f).to(device) for f in features]
|
|
|
|
|
|
|
|
class BasicVideoEncoder(BaseEncoder):
    """Encode videos frame by frame and wrap every frame in token embeddings.

    All frames of all videos are encoded in one ``parent.encode_images``
    call; the result is split back per video, each frame is optionally
    surrounded by the start/end token embeddings, and the frames of a video
    are flattened into a single sequence.
    """

    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        """
        Args:
            parent: Owning module; must provide ``tokenizer``, ``device``,
                ``llm_model_embed_tokens`` and ``encode_images``.
            start_tokens: Text embedded and prepended to every frame's
                features, or ``None`` to prepend nothing.
            end_tokens: Text embedded and appended to every frame's
                features, or ``None`` to append nothing.
        """
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        """Tokenize ``tokens`` with the parent's tokenizer and embed the ids.

        Returns ``None`` when ``tokens`` is ``None``; otherwise the result of
        ``parent.llm_model_embed_tokens`` applied to the token ids, which are
        placed on ``parent.device``.
        """
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm_model_embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Wrap each entry along dim 0 with the token embeddings, then merge dims 0 and 1.

        The embeddings are repeated once per entry of ``features`` along a new
        leading dim and concatenated along dim 1 (assumes ``features`` is
        laid out as (frames, tokens, hidden) -- confirm against
        ``parent.encode_images``).
        """
        num_entries = features.shape[0]
        if start_token_embeds is not None:
            # expand() broadcasts without copying and, unlike stacking a
            # repeated list, also works when there are zero entries.
            start_embeds = start_token_embeds.unsqueeze(0).expand(num_entries, *start_token_embeds.shape)
            features = torch.cat([start_embeds, features], dim=1)
        if end_token_embeds is not None:
            end_embeds = end_token_embeds.unsqueeze(0).expand(num_entries, *end_token_embeds.shape)
            features = torch.cat([features, end_embeds], dim=1)
        return features.flatten(0, 1)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        """Encode a batch of videos and return one flattened sequence per video.

        Args:
            videos: Video tensors whose dim 0 indexes frames; videos may have
                different frame counts.
            config: Accepted for interface compatibility; not read here.

        Returns:
            A list with one tensor per video. An empty input yields an empty
            list.
        """
        # torch.cat rejects an empty sequence, so answer the trivial case
        # directly.
        if len(videos) == 0:
            return []
        num_frames = [video.shape[0] for video in videos]
        # Encode every frame of every video in a single batch, then cut the
        # result back into per-video chunks.
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        # Embed the wrapper tokens once and reuse them for every video.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]
|
|
|
|
|
def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    """Average-pool ``x`` along ``dim`` over consecutive windows of ``size``.

    Args:
        x: Input tensor (floating point, as required by ``Tensor.mean``).
        size: Window length; must be positive.
        dim: Axis to pool over. Negative values count from the last axis.

    Returns:
        A tensor shaped like ``x`` except that ``dim`` has length
        ``ceil(x.shape[dim] / size)``. When the length is not a multiple of
        ``size``, the final output entry is the mean of the remaining
        (fewer than ``size``) elements only.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f"pool size must be positive, got {size}")
    # Normalize so the shape arithmetic below is valid for negative axes.
    if dim < 0:
        dim += x.dim()

    length = x.shape[dim]
    remainder = length % size
    full = length - remainder

    # Split the pooled axis into (windows, size) and average over `size`.
    grouped_shape = list(x.shape)
    grouped_shape[dim : dim + 1] = [full // size, size]
    pooled = x.narrow(dim, 0, full).reshape(grouped_shape).mean(dim + 1)

    if remainder:
        # Trailing partial window: average just the elements that exist.
        tail = x.narrow(dim, full, remainder).mean(dim, keepdim=True)
        pooled = torch.cat([pooled, tail], dim=dim)
    return pooled
|
|
|
|
|
class TSPVideoEncoder(BasicVideoEncoder):
    """Video encoder that average-pools frame features before wrapping them.

    For every entry of ``self.pool_sizes`` the per-video features are viewed
    as a (frames, side, side, hidden) grid, pooled along the first three axes
    with the configured window sizes, wrapped with start/end token embeddings
    by the base class, and optionally followed by separator embeddings. The
    results for all pool sizes are concatenated.
    """

    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
        sep_tokens: Optional[str] = None,
    ) -> None:
        """
        Args:
            parent: Owning module, forwarded to the base class.
            start_tokens: Text embedded and prepended to each pooled frame.
            end_tokens: Text embedded and appended to each pooled frame.
            sep_tokens: Text embedded and appended after the output of each
                pool-size pass, or ``None`` for no separator.
        """
        super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
        # One [frames, rows, cols] window triple per pooling pass.
        self.pool_sizes = [[8, 1, 1]]
        self.sep_tokens = sep_tokens

    def _process_features(
        self,
        inputs: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
        sep_token_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool, wrap and concatenate the features of a single video.

        Args:
            inputs: Features whose first two dims are (frames, tokens); the
                token count must be a perfect square so it can be viewed as a
                square grid.
            start_token_embeds: Embeddings prepended to each pooled frame, or ``None``.
            end_token_embeds: Embeddings appended to each pooled frame, or ``None``.
            sep_token_embeds: Embeddings appended after each pooling pass, or ``None``.

        Returns:
            The outputs of all pooling passes concatenated along dim 0.

        Raises:
            ValueError: If the token count per frame is not a perfect square.
        """
        nt, ns = inputs.shape[:2]
        # round() instead of int() so float error cannot truncate an exact
        # root; then verify, because a non-square count could otherwise be
        # silently reshaped into a wrong grid by view(..., -1).
        nl = round(ns**0.5)
        if nl * nl != ns:
            raise ValueError(f"expected a square number of tokens per frame, got {ns}")

        # The grid view is identical for every pooling pass; pool() returns
        # new tensors, so sharing it is safe.
        grid = inputs.view(nt, nl, nl, -1)

        outputs = []
        for pool_size in self.pool_sizes:
            features = grid
            for dim, p in enumerate(pool_size):
                features = pool(features, p, dim=dim)
            # Merge the two grid axes back into one token axis.
            features = features.flatten(1, 2)
            features = super()._process_features(
                features,
                start_token_embeds=start_token_embeds,
                end_token_embeds=end_token_embeds,
            )
            if sep_token_embeds is not None:
                features = torch.cat([features, sep_token_embeds], dim=0)
            outputs.append(features)
        return torch.cat(outputs, dim=0)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        """Encode a batch of videos and return one pooled sequence per video.

        Args:
            videos: Video tensors whose dim 0 indexes frames.
            config: Accepted for interface compatibility; not read here.

        Returns:
            A list with one tensor per video. An empty input yields an empty
            list.
        """
        # torch.cat rejects an empty sequence, so answer the trivial case
        # directly.
        if len(videos) == 0:
            return []
        num_frames = [video.shape[0] for video in videos]
        # Encode every frame of every video in a single batch, then cut the
        # result back into per-video chunks.
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        # Embed the wrapper and separator tokens once for the whole batch.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
            sep_token_embeds=self.embed_tokens(self.sep_tokens),
        )
        return [process_features(f) for f in features]
|
|
|