Spaces:
Paused
Paused
from importlib.metadata import requires | |
import torch | |
import torch.nn as nn | |
from .registry import register_model | |
from .vlpencoder import LanguageEncoder | |
class FixLanguageEncoder(LanguageEncoder):
    """``LanguageEncoder`` variant whose text-encoding entry points run without autograd.

    Construction is delegated unchanged to ``LanguageEncoder``. Afterwards,
    ``logit_scale`` is overwritten with a non-trainable scalar parameter equal
    to 1.0.

    The four text-encoding methods below forward to the base-class
    implementations but execute under ``torch.no_grad()``. As a result, no
    autograd graph is recorded for those calls and their outputs do not
    propagate gradients back into the encoder.
    """

    def __init__(self, *args, **kwargs):
        """Build the base encoder, then pin ``logit_scale`` to a frozen 1.0.

        All positional and keyword arguments are passed straight through to
        ``LanguageEncoder.__init__``.
        """
        super().__init__(*args, **kwargs)
        # This assignment runs after the base constructor, so it replaces any
        # ``logit_scale`` the base class may have registered.
        # ``torch.ones([])`` is a 0-d tensor holding 1.0, and
        # ``requires_grad=False`` excludes it from gradient updates.
        # NOTE(review): how logit_scale is consumed (e.g. whether it is
        # exponentiated) is defined outside this file. Confirm that 1.0 is
        # the intended fixed value for that usage.
        self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False)

    @torch.no_grad()
    def get_text_embeddings(self, *args, **kwargs):
        """Return ``LanguageEncoder.get_text_embeddings`` computed without autograd."""
        return super().get_text_embeddings(*args, **kwargs)

    @torch.no_grad()
    def get_text_token_embeddings(self, *args, **kwargs):
        """Return ``LanguageEncoder.get_text_token_embeddings`` computed without autograd."""
        return super().get_text_token_embeddings(*args, **kwargs)

    @torch.no_grad()
    def forward_language(self, *args, **kwargs):
        """Return ``LanguageEncoder.forward_language`` computed without autograd."""
        return super().forward_language(*args, **kwargs)

    @torch.no_grad()
    def forward_language_token(self, *args, **kwargs):
        """Return ``LanguageEncoder.forward_language_token`` computed without autograd."""
        return super().forward_language_token(*args, **kwargs)
# NOTE(review): assumes ``register_model`` returns the decorated function
# unchanged (the usual registry-decorator contract). Confirm in .registry.
@register_model
def get_language_model(cfg, **kwargs):
    """Build a ``FixLanguageEncoder`` from ``cfg``.

    Decorated with ``register_model`` (imported from ``.registry`` in this
    module, where it was otherwise unused). Presumably this lets the factory
    be looked up through the model registry.

    Args:
        cfg: Configuration object, passed unchanged as the sole positional
            argument to ``FixLanguageEncoder``.
        **kwargs: Accepted but deliberately ignored. They are NOT forwarded
            to the encoder; presumably this keeps a uniform factory
            signature for registry callers.

    Returns:
        A new ``FixLanguageEncoder`` instance.
    """
    return FixLanguageEncoder(cfg)