import torch
from torchvision import transforms
import torch.nn as nn
import numpy as np

from concept_attention.segmentation import SegmentationAbstractClass
import concept_attention.binary_segmentation_baselines.dino_src.vision_transformer as vits

class DINOSegmentationModel(SegmentationAbstractClass):

    def __init__(self, arch="vit_small", patch_size=8, image_size=480, image_path=None, device="cuda"):
        self.device = device
        # Build the DINO ViT backbone with no classification head and freeze it.
        self.image_size = image_size
        self.patch_size = patch_size
        self.model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        self.model.to(device)
        # Select the reference pretrained checkpoint for this architecture.
        url = None
        if arch == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif arch == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in our paper
        elif arch == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif arch == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)
        else:
            print("There are no reference weights available for this model, so we use random weights.")
        # Preprocessing: resize, convert to tensor, normalize with ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
    def segment_individual_image(self, image, concepts, caption, **kwargs):
        # NOTE: concepts and caption are unused, as this is not a text-conditioned approach.
        if isinstance(image, torch.Tensor):
            # Tensors are assumed to be preprocessed already; only resize them.
            image = transforms.Resize(self.image_size)(image)
        else:
            image = self.transform(image)
        # Crop so that both spatial dimensions are divisible by the patch size.
        # For a (C, H, W) tensor, shape[1] is the height and shape[2] the width.
        h = image.shape[1] - image.shape[1] % self.patch_size
        w = image.shape[2] - image.shape[2] % self.patch_size
        image = image[:, :h, :w].unsqueeze(0)
        h_featmap = image.shape[-2] // self.patch_size
        w_featmap = image.shape[-1] // self.patch_size
        # Self-attention of the last transformer block: (1, num_heads, N, N).
        attentions = self.model.get_last_selfattention(image.to(self.device))
        nh = attentions.shape[1]  # number of heads
        # Keep only the CLS token's attention over the patch tokens.
        attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
        attentions = attentions.reshape(nh, h_featmap, w_featmap)
        # Upsample the per-head patch maps back to pixel resolution.
        attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=self.patch_size, mode="nearest")[0]
        # Average over heads, then replicate the same saliency map for every
        # concept, since DINO attention is not conditioned on the concepts.
        attentions = torch.mean(attentions, dim=0, keepdim=True)
        attentions = attentions.repeat(len(concepts), 1, 1)
        return attentions, None
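

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): one way this
# baseline might be invoked. The image path and concept list are hypothetical,
# and we assume a CUDA device is available; pass device="cpu" otherwise.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    model = DINOSegmentationModel(arch="vit_small", patch_size=8, device="cuda")
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    concepts = ["dog", "grass", "sky"]  # ignored by DINO; only sets the number of output maps
    heatmaps, _ = model.segment_individual_image(image, concepts, caption=None)
    print(heatmaps.shape)  # (len(concepts), H, W): identical CLS-attention saliency maps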