# media_encoder.py: image and video feature encoders for LongVILA-R1-7B.
from functools import partial
from typing import Any, Dict, List, Optional
import torch
from torch import nn


class BaseEncoder(nn.Module):
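    """Base class for media encoders that borrow components from a parent VLM."""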
def __init__(self, parent: nn.Module) -> None:
super().__init__()
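        # Hold the parent in a plain list so PyTorch does not register it as a
        # submodule; this avoids a parent <-> encoder ownership cycle and keeps
        # the parent's parameters out of this module's state dict.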
self._parent = [parent]
@property
def parent(self) -> nn.Module:
return self._parent[0]


class BasicImageEncoder(BaseEncoder):
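    """Encode a batch of images and wrap each feature sequence with optional
    embedded start/end marker tokens (a trailing newline by default)."""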
def __init__(
self,
        parent: nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
) -> None:
super().__init__(parent)
self.start_tokens = start_tokens
self.end_tokens = end_tokens
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
if tokens is None:
return None
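        # Embed the marker string with the LLM's input-embedding table so the
        # markers live in the same space as the visual features.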
token_ids = self.parent.tokenizer(tokens).input_ids
token_ids = torch.tensor(token_ids, device=self.parent.device)
return self.parent.llm_model_embed_tokens(token_ids)
def _process_features(
self,
features: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
) -> torch.Tensor:
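        # features: (tokens, hidden) for one image; prepend/append the marker
        # embeddings along the token dimension.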
if start_token_embeds is not None:
features = torch.cat([start_token_embeds, features], dim=0)
if end_token_embeds is not None:
features = torch.cat([features, end_token_embeds], dim=0)
return features
    def forward(
        self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device
    ) -> List[torch.Tensor]:
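        # Batch the images through the vision tower, then wrap each image's
        # features with the (pre-embedded) start/end markers.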
images = torch.stack(images, dim=0)
features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
)
return [process_features(f).to(device) for f in features]


class BasicVideoEncoder(BaseEncoder):
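    """Video counterpart of BasicImageEncoder: frames are encoded as one image
    batch, markers are attached per frame, and the per-frame sequences are
    flattened into a single token sequence per video."""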
def __init__(
self,
        parent: nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
) -> None:
super().__init__(parent)
self.start_tokens = start_tokens
self.end_tokens = end_tokens
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
if tokens is None:
return None
token_ids = self.parent.tokenizer(tokens).input_ids
token_ids = torch.tensor(token_ids, device=self.parent.device)
return self.parent.llm_model_embed_tokens(token_ids)
def _process_features(
self,
features: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
) -> torch.Tensor:
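        # features: (frames, tokens_per_frame, hidden). Markers are replicated
        # for every frame, then the frame axis is flattened away.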
if start_token_embeds is not None:
start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
features = torch.cat([start_embeds, features], dim=1)
if end_token_embeds is not None:
end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
features = torch.cat([features, end_embeds], dim=1)
return features.flatten(0, 1)
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
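        # Concatenate all videos' frames into one batch for a single vision
        # tower pass, then split the features back per video.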
num_frames = [video.shape[0] for video in videos]
images = torch.cat(videos, dim=0)
features = self.parent.encode_images(images)
features = torch.split(features, num_frames)
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
)
return [process_features(f) for f in features]


def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
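    """Average-pool groups of `size` consecutive entries along dimension `dim`.

    The length of `dim` must be divisible by `size`.
    """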
return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)


class TSPVideoEncoder(BasicVideoEncoder):
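    """BasicVideoEncoder with temporal pooling: frame features are averaged in
    groups (8 consecutive frames by default, i.e. pool_sizes = [[8, 1, 1]] over
    the (time, height, width) axes) before markers are attached, and an
    optional separator embedding is appended after each pooled variant."""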
def __init__(
self,
        parent: nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
sep_tokens: Optional[str] = None,
) -> None:
super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
self.pool_sizes = [[8, 1, 1]]
self.sep_tokens = sep_tokens
def _process_features(
self,
inputs: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
sep_token_embeds: Optional[torch.Tensor],
) -> torch.Tensor:
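        # inputs: (frames, patches, hidden), where the patches form an nl x nl
        # grid. Each pool_size holds (time, height, width) pooling factors; the
        # frame count must be divisible by the temporal factor.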
nt, ns = inputs.shape[:2]
nl = int(ns**0.5)
outputs = []
for pool_size in self.pool_sizes:
features = inputs.view(nt, nl, nl, -1)
for dim, p in enumerate(pool_size):
features = pool(features, p, dim=dim)
features = features.flatten(1, 2)
features = super()._process_features(
features,
start_token_embeds=start_token_embeds,
end_token_embeds=end_token_embeds,
)
if sep_token_embeds is not None:
features = torch.cat([features, sep_token_embeds], dim=0)
outputs.append(features)
return torch.cat(outputs, dim=0)
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
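        # Same flow as BasicVideoEncoder.forward, but the separator tokens are
        # embedded as well and appended by _process_features.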
num_frames = [video.shape[0] for video in videos]
images = torch.cat(videos, dim=0)
features = self.parent.encode_images(images)
features = torch.split(features, num_frames)
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
sep_token_embeds=self.embed_tokens(self.sep_tokens),
)
return [process_features(f) for f in features]
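

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test with a
# dummy parent. The real parent is the surrounding VLM; everything below (the
# toy tokenizer, the random "vision tower", the vocab/hidden sizes) is a
# stand-in invented purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _DummyParent(nn.Module):
        def __init__(self, vocab: int = 128, hidden: int = 32) -> None:
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)

        @property
        def device(self) -> torch.device:
            return self.embed.weight.device

        def tokenizer(self, text: str) -> SimpleNamespace:
            # Toy tokenizer: one id per character.
            return SimpleNamespace(input_ids=[ord(c) % 128 for c in text])

        def llm_model_embed_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
            return self.embed(token_ids)

        def encode_images(self, images: torch.Tensor, block_sizes=None) -> torch.Tensor:
            # Toy vision tower: 16 random patch features per image.
            return torch.randn(images.shape[0], 16, self.embed.embedding_dim)

    parent = _DummyParent()

    image_encoder = BasicImageEncoder(parent)
    image_feats = image_encoder([torch.randn(3, 8, 8) for _ in range(2)], {}, parent.device)
    print([tuple(f.shape) for f in image_feats])  # [(17, 32), (17, 32)]: 16 patches + "\n"

    video_encoder = TSPVideoEncoder(parent)  # default pooling averages 8 frames into 1
    video_feats = video_encoder([torch.randn(8, 3, 8, 8)], {})
    print([tuple(f.shape) for f in video_feats])  # [(17, 32)]: 16 pooled patches + "\n"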