x-decoder-video / xdecoder /language /fixvlpencoder.py
fffiloni's picture
Upload 6 files
3bbe5bd
raw
history blame
1.05 kB
from importlib.metadata import requires
import torch
import torch.nn as nn
from .registry import register_model
from .vlpencoder import LanguageEncoder
class FixLanguageEncoder(LanguageEncoder):
    """``LanguageEncoder`` variant whose text paths run with autograd disabled.

    Constructor arguments are passed through unchanged to
    ``LanguageEncoder.__init__``. Once the parent is built, ``logit_scale``
    is set to a scalar parameter holding 1 with ``requires_grad=False``.
    Each of the four text-encoding entry points below simply delegates to
    the parent implementation inside a ``torch.no_grad()`` context.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent encoder, then install a non-trainable logit scale."""
        super().__init__(*args, **kwargs)
        # NOTE(review): presumably this replaces a learnable ``logit_scale``
        # created by the parent -- confirm against ``LanguageEncoder``.
        unit_scale = torch.ones([])
        self.logit_scale = nn.Parameter(unit_scale, requires_grad=False)

    def get_text_embeddings(self, *args, **kwargs):
        """Return the parent's ``get_text_embeddings`` result, computed under no-grad."""
        with torch.no_grad():
            return super().get_text_embeddings(*args, **kwargs)

    def get_text_token_embeddings(self, *args, **kwargs):
        """Return the parent's ``get_text_token_embeddings`` result, computed under no-grad."""
        with torch.no_grad():
            return super().get_text_token_embeddings(*args, **kwargs)

    def forward_language(self, *args, **kwargs):
        """Return the parent's ``forward_language`` result, computed under no-grad."""
        with torch.no_grad():
            return super().forward_language(*args, **kwargs)

    def forward_language_token(self, *args, **kwargs):
        """Return the parent's ``forward_language_token`` result, computed under no-grad."""
        with torch.no_grad():
            return super().forward_language_token(*args, **kwargs)
@register_model
def get_language_model(cfg, **kwargs):
    """Registry entry point: construct a ``FixLanguageEncoder`` from ``cfg``.

    Args:
        cfg: Configuration object handed directly to ``FixLanguageEncoder``.
        **kwargs: Accepted but not forwarded to the encoder.

    Returns:
        The newly constructed ``FixLanguageEncoder``.
    """
    encoder = FixLanguageEncoder(cfg)
    return encoder