import torch
import torch.nn as nn
from typing import Callable, Optional, Union
from torchtyping import patch_typeguard
from einops import rearrange
import timm
import clip
from functools import partial

# ----------------------------- Utils --------------------------------------

clip.model.LayerNorm = nn.LayerNorm  # we need to patch this for clip to work with deepspeed

patch_typeguard()  # needed for torchtyping typechecks to work

class Lambda(torch.nn.Module):
    def __init__(self, fn: Callable):
        super().__init__()
        assert callable(fn)
        self.fn = fn

    def forward(self, x):
        return self.fn(x)
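
# A minimal usage sketch for Lambda (illustrative, not part of this module's API):
# wrapping a callable lets it sit inside nn.Sequential or replace a submodule.
#
#   flatten = Lambda(partial(rearrange, pattern="b d h w -> b (h w) d"))
#   flatten(torch.randn(2, 512, 7, 7)).shape  # torch.Size([2, 49, 512])
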
# ------------------------- Image encoders ----------------------------------


def nfresnet50(
    device: Optional[Union[torch.device, str]] = None, pretrained: bool = True
) -> nn.Module:
    """
    Loads an nf_resnet50 backbone from timm, dropping the classifier head and
    replacing the pooling with adaptive average pooling.
    """
    encoder = torch.nn.Sequential(
        *list(timm.create_model("nf_resnet50", pretrained=pretrained).children())[:-1]
    )
    pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
    encoder = torch.nn.Sequential(encoder, pooling)
    if device is not None:
        encoder = encoder.to(device)
    return encoder
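
# Hedged usage sketch: with timm's nf_resnet50 (whose final feature map is
# 2048-dim per timm's model definition), a 224x224 input pools to one vector:
#
#   enc = nfresnet50(device="cpu", pretrained=False)
#   enc(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 2048, 1, 1])
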
def clip_encoder(
    device: Optional[Union[torch.device, str]] = None, name: str = "clip"
) -> nn.Module:
    """
    Loads CLIP's image encoder module, discarding the language model component.
    If the variant is a resnet model, we also remove the attention pooling.
    """
    if name in ["clip", "ViT-B/32"]:
        name = "ViT-B/32"
    elif name in ["clip_resnet", "RN50x4"]:
        name = "RN50x4"
    elif name in ["clip_resnet_large", "RN50x16"]:
        name = "RN50x16"
    else:
        raise ValueError(f"encoder {name} not recognized")

    encoder = clip.load(name, device=device)[0].visual

    if device is not None:
        encoder = encoder.to(device)

    if "RN" in name:
        # replace attention pooling with a reshape, so the encoder returns the
        # full grid of patch features rather than a single pooled embedding
        encoder.attnpool = Lambda(partial(rearrange, pattern="b d h w -> b (h w) d"))

    return encoder
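
# Hedged usage sketch: with attnpool replaced, a resnet variant yields a
# sequence of patch features of shape (b, h*w, d). The exact sizes below are
# assumptions for RN50x4 (288x288 native input, stride-32 feature grid):
#
#   enc = clip_encoder(device="cpu", name="clip_resnet")  # RN50x4 variant
#   enc(torch.randn(1, 3, 288, 288)).shape  # ~ torch.Size([1, 81, 2560])
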
def get_image_encoder(
    name: str, device: Optional[Union[torch.device, str]] = None, pretrained: bool = False
) -> torch.nn.Module:
    """
    Loads an image encoder module by name.
    """
    if name == "nfresnet50":
        encoder = nfresnet50(device=device, pretrained=pretrained)
    elif "clip" in name:
        encoder = clip_encoder(device=device, name=name)
    else:
        raise ValueError(f"image encoder {name} not recognized")
    return encoder
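

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: it assumes the timm and clip packages
    # imported above are installed. pretrained=False keeps it offline; with
    # pretrained=True, weights are downloaded on first use.
    encoder = get_image_encoder("nfresnet50", device="cpu", pretrained=False)
    features = encoder(torch.randn(1, 3, 224, 224))
    print(features.shape)  # expected: torch.Size([1, 2048, 1, 1])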