import torch
import torch.nn as nn
from typing import Callable, Optional, Union
from torchtyping import patch_typeguard
from einops import rearrange
import timm
import clip
from functools import partial
# ----------------------------- Utils --------------------------------------
clip.model.LayerNorm = nn.LayerNorm  # we need to patch this for clip to work with deepspeed
patch_typeguard() # needed for torchtyping typechecks to work
class Lambda(torch.nn.Module):
    """Wraps an arbitrary callable as an nn.Module, e.g. for use inside nn.Sequential."""

    def __init__(self, fn: Callable):
        super().__init__()
        assert callable(fn)
        self.fn = fn

    def forward(self, x):
        return self.fn(x)
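
# Usage sketch (illustrative, not part of the original module): Lambda lets a
# plain function slot into a module pipeline, e.g. a flatten step.
#
#   flatten = Lambda(lambda x: x.flatten(1))
#   flatten(torch.zeros(2, 3, 4)).shape  # -> torch.Size([2, 12])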
# ------------------------- Image encoders ----------------------------------
def nfresnet50(
    device: Optional[Union[torch.device, str]] = None, pretrained: bool = True
) -> nn.Module:
    """
    Loads an nf_resnet50 model from timm, removing the classifier head and
    appending an adaptive average pooling layer in its place.
    """
    encoder = torch.nn.Sequential(
        *list(timm.create_model("nf_resnet50", pretrained=pretrained).children())[:-1]
    )
    pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
    encoder = torch.nn.Sequential(encoder, pooling)
    if device is not None:
        encoder = encoder.to(device)
    return encoder
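
# Usage sketch (the device, input resolution, and shapes below are illustrative
# assumptions): the encoder maps a batch of images to pooled feature maps of
# shape (b, C, 1, 1), which can be flattened to (b, C) feature vectors.
#
#   enc = nfresnet50(device="cpu", pretrained=False)
#   feats = enc(torch.randn(1, 3, 224, 224)).flatten(1)  # -> (1, C)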
def clip_encoder(
    device: Optional[Union[torch.device, str]] = None, name: str = "clip",
) -> nn.Module:
    """
    Loads CLIP's image encoder module, discarding the language model component.
    If the variant is a resnet model, the attention pooling is also removed.
    """
    if name in ["clip", "ViT-B/32"]:
        name = "ViT-B/32"
    elif name in ["clip_resnet", "RN50x4"]:
        name = "RN50x4"
    elif name in ["clip_resnet_large", "RN50x16"]:
        name = "RN50x16"
    else:
        raise ValueError(f"encoder {name} not recognized")

    encoder = clip.load(name, device=device)[0].visual

    if device is not None:
        encoder = encoder.to(device)

    if "RN" in name:
        # replace attention pooling with a reshape, so the encoder returns the
        # full grid of spatial features: b d h w -> b (h w) d
        encoder.attnpool = Lambda(partial(rearrange, pattern="b d h w -> b (h w) d"))

    return encoder
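
# Usage sketch (device is an illustrative assumption; output shapes depend on
# the variant): the ViT variant returns pooled embeddings, while the resnet
# variants, with attnpool replaced above, return a grid of spatial features.
#
#   enc = clip_encoder(device="cpu", name="clip")         # ViT-B/32 -> (b, 512)
#   enc = clip_encoder(device="cpu", name="clip_resnet")  # RN50x4 -> (b, h*w, d)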
def get_image_encoder(
    name: str, device: Optional[Union[torch.device, str]] = None, pretrained: bool = False
) -> torch.nn.Module:
    """
    Loads an image encoder module by name.
    """
    if name == "nfresnet50":
        encoder = nfresnet50(device=device, pretrained=pretrained)
    elif "clip" in name:
        encoder = clip_encoder(device=device, name=name)
    else:
        raise ValueError(f"image encoder {name} not recognized")
    return encoder
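
# Minimal smoke test, a sketch assuming this file is run directly (the 224x224
# input resolution is an illustrative choice; the clip variants expect their
# own preprocessing and input sizes).
if __name__ == "__main__":
    encoder = get_image_encoder("nfresnet50", device="cpu", pretrained=False)
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out = encoder(dummy)
    print(out.shape)  # pooled features of shape (2, C, 1, 1)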