Vincentqyw
fix: roma
c74a070
raw
history blame contribute delete
No virus
2.46 kB
import torch
from PIL import Image
import torch.nn as nn
import torchvision.models as tvm
import torch.nn.functional as F
import numpy as np
class DeDoDeDescriptor(nn.Module):
def __init__(self, encoder, decoder, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.encoder = encoder
self.decoder = decoder
import torchvision.transforms as transforms
self.normalizer = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
def forward(
self,
batch,
):
if "im_A" in batch:
images = torch.cat((batch["im_A"], batch["im_B"]))
else:
images = batch["image"]
features, sizes = self.encoder(images)
descriptor = 0
context = None
scales = self.decoder.scales
for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
delta_descriptor, context = self.decoder(
feature_map, scale=scale, context=context
)
descriptor = descriptor + delta_descriptor
if idx < len(scales) - 1:
size = sizes[-(idx + 2)]
descriptor = F.interpolate(
descriptor, size=size, mode="bilinear", align_corners=False
)
context = F.interpolate(
context, size=size, mode="bilinear", align_corners=False
)
return {"description_grid": descriptor}
@torch.inference_mode()
def describe_keypoints(self, batch, keypoints):
self.train(False)
description_grid = self.forward(batch)["description_grid"]
described_keypoints = F.grid_sample(
description_grid.float(),
keypoints[:, None],
mode="bilinear",
align_corners=False,
)[:, :, 0].mT
return {"descriptions": described_keypoints}
def read_image(self, im_path, H=560, W=560):
return (
self.normalizer(
torch.from_numpy(
np.array(Image.open(im_path).resize((W, H))) / 255.0
).permute(2, 0, 1)
)
.cuda()
.float()[None]
)
def describe_keypoints_from_path(self, im_path, keypoints, H=768, W=768):
batch = {"image": self.read_image(im_path, H=H, W=W)}
return self.describe_keypoints(batch, keypoints)