Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import einops | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
class Normalize(nn.Module):
    """
    Module that L2-normalizes its input along a fixed dimension.
    """

    def __init__(self, dim: int) -> None:
        """
        Args:
            dim: Dimension index handed to ``normalize`` on every forward call.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` normalized with ``p=2`` along ``self.dim``."""
        l2_normalize = nn.functional.normalize
        return l2_normalize(x, p=2, dim=self.dim)
class LearnableLogitScaling(nn.Module):
    """
    Multiplies its input by a scale that is stored as a logarithm and capped.

    The stored value is ``log(logit_scale_init)``; it is kept as an
    ``nn.Parameter`` when ``learnable`` is true and as a registered buffer
    otherwise.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        """
        Args:
            logit_scale_init: Initial scale; its natural log is what gets stored.
            learnable: Store the log-scale as a parameter (True) or a buffer (False).
            max_logit_scale: Upper bound applied to the exponentiated scale in ``forward``.
        """
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable

        # 0-dim tensor holding log(logit_scale_init).
        initial_log_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if not learnable:
            self.register_buffer("log_logit_scale", initial_log_scale)
        else:
            self.log_logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, x):
        """Return ``x`` multiplied by ``exp(log_logit_scale)`` clipped at ``max_logit_scale``."""
        scale = self.log_logit_scale.exp()
        capped_scale = torch.clip(scale, max=self.max_logit_scale)
        return capped_scale * x

    def extra_repr(self):
        """Return the constructor settings as a single line for the module repr."""
        return (
            f"logit_scale_init={self.logit_scale_init},"
            f"learnable={self.learnable},"
            f" max_logit_scale={self.max_logit_scale}"
        )
class EinOpsRearrange(nn.Module):
    """
    Module wrapper that applies ``einops.rearrange`` with a fixed expression.
    """

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        """
        Args:
            rearrange_expr: Pattern string passed to ``einops.rearrange``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``einops.rearrange`` on every call.
        """
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        """Rearrange the tensor ``x``; a non-tensor input fails the assertion."""
        assert isinstance(x, torch.Tensor)
        pattern, extra = self.rearrange_expr, self.kwargs
        return einops.rearrange(x, pattern, **extra)
class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    # Must be a staticmethod: the signature has no ``self`` and ``extra_repr``
    # calls it through the instance with exactly two arguments.
    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: "tuple[str, torch.Tensor]") -> str:
        """
        Format one line of the form ``(name): tensor(shape, requires_grad=flag)``.

        Args:
            name: Label printed inside the leading parentheses.
            tensor: A ``(qualified_name, tensor)`` pair as yielded by
                ``named_parameters()`` / ``named_buffers()``. Only element 1
                (the tensor itself) is read.

        Returns:
            The formatted line, terminated by a newline.
        """
        value = tensor[1]
        return (
            f"({name}): tensor({tuple(value.shape)}, "
            f"requires_grad={value.requires_grad})\n"
        )

    def extra_repr(self) -> str:
        """
        List the parameters and buffers registered on this module.

        Parameters whose top-level name matches a module name are skipped.
        Buffers are not filtered, so each one is listed under the first
        component of its qualified name.

        Returns:
            One line per listed tensor, concatenated; empty if there are none.
        """
        module_names = {module_name for module_name, _ in self.named_modules()}

        lines = []
        for entry in self.named_parameters():
            top_level = entry[0].split(".")[0]
            if top_level not in module_names:
                lines.append(self.get_readable_tensor_repr(top_level, entry))
        for entry in self.named_buffers():
            top_level = entry[0].split(".")[0]
            lines.append(self.get_readable_tensor_repr(top_level, entry))
        return "".join(lines)
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """
    Convert ``tensor`` to ``tgt_dtype`` only when its dtype equals ``src_dtype``.

    Returns:
        A ``(tensor, updated)`` pair. ``updated`` is True when the conversion
        was applied; otherwise the input tensor is returned as-is with False.
    """
    if tensor.dtype != src_dtype:
        return tensor, False
    return tensor.to(dtype=tgt_dtype), True
class QuickGELU(nn.Module):
    """
    Activation computing ``x * sigmoid(1.702 * x)``.
    """

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        """Return ``x`` multiplied elementwise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class SelectElement(nn.Module):
    """
    Picks ``index`` along dimension 1 of the input, keeping all other dimensions.
    """

    def __init__(self, index) -> None:
        """
        Args:
            index: Index applied to dimension 1 in ``forward``.
        """
        super().__init__()
        self.index = index

    def forward(self, x):
        """Return ``x[:, self.index, ...]``; ``x`` must have at least 3 dimensions."""
        assert x.ndim >= 3
        # Same selection as x[:, self.index, ...], spelled as an explicit index tuple.
        selector = (slice(None), self.index, Ellipsis)
        return x[selector]
class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP
    """

    def __init__(self, proj: nn.Module) -> None:
        """
        Args:
            proj: Module applied to the selected per-sequence features.
        """
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        """
        Select one position per batch row and project it.

        Args:
            x: 3-dimensional input (B x L x D, per the shape comment below).
            seq_len: Per-row index into dimension 1 of ``x``.
        """
        assert x.ndim == 3
        # x is of shape B x L x D
        # For each batch row, take the features at position seq_len[row]: the
        # eot embedding (eot_token is the highest number in each sequence).
        batch_rows = torch.arange(x.shape[0])
        pooled = x[batch_rows, seq_len]
        return self.proj(pooled)