import torch
from torchvision import transforms
import torch.nn as nn
import numpy as np
from concept_attention.segmentation import SegmentationAbstractClass
import concept_attention.binary_segmentation_baselines.dino_src.vision_transformer as vits
class DINOSegmentationModel(SegmentationAbstractClass):

    def __init__(self, arch="vit_small", patch_size=8, image_size=480, image_path=None, device="cuda"):
        self.device = device
        # Build the model
        self.image_size = image_size
        self.patch_size = patch_size
        self.model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        self.model.to(device)
        # Load up the pretrained weights
        url = None
        if arch == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif arch == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in the DINO paper
        elif arch == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif arch == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print("Loading the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)
        else:
            print("There are no reference weights available for this architecture and patch size => using random weights.")
        # Transforms
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
    def segment_individual_image(self, image, concepts, caption, **kwargs):
        # NOTE: concepts and caption are ignored, as this is not a text conditioned approach.
        if isinstance(image, torch.Tensor):
            # Tensors are assumed to already be preprocessed; only resize them.
            image = transforms.Resize(self.image_size)(image)
        else:
            image = self.transform(image)
        # Crop so that the height and width are divisible by the patch size.
        w, h = image.shape[1] - image.shape[1] % self.patch_size, image.shape[2] - image.shape[2] % self.patch_size
        image = image[:, :w, :h].unsqueeze(0)
        w_featmap = image.shape[-2] // self.patch_size
        h_featmap = image.shape[-1] // self.patch_size
        # Self-attention of the final transformer block.
        attentions = self.model.get_last_selfattention(image.to(self.device))
        nh = attentions.shape[1]  # number of heads
        # Keep only the [CLS] token's attention over the image patches.
        attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
        attentions = attentions.reshape(nh, w_featmap, h_featmap)
        # Upsample the patch-level maps back to pixel resolution.
        attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=self.patch_size, mode="nearest")[0]
        # Average over heads and repeat the same map for every concept,
        # since this baseline is class-agnostic.
        attentions = torch.mean(attentions, dim=0, keepdim=True)
        attentions = attentions.repeat(len(concepts), 1, 1)
        return attentions, None
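
# A minimal usage sketch (not part of the original file): it assumes a local image
# "example.jpg" and an available CUDA device, both of which are illustrative
# placeholders rather than values taken from the repository.
if __name__ == "__main__":
    from PIL import Image

    segmentation_model = DINOSegmentationModel(arch="vit_small", patch_size=8, image_size=480, device="cuda")
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    concepts = ["foreground", "background"]
    attention_maps, _ = segmentation_model.segment_individual_image(image, concepts, caption=None)
    # One (identical, head-averaged) attention map per concept.
    print(attention_maps.shape)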