import os
from os.path import join

import PIL.Image
import numpy as np
import pandas as pd
import reverse_geocoder
import torch
from torch.utils.data import Dataset


class GeoDataset(Dataset):
    """Image folder plus a CSV of (image id, latitude, longitude) annotations.

    `transformation` is an `operate(dataset, idx, filepath)` callable such as the
    ones returned by the backbone loaders below; it is applied in `__getitem__`.
    """

    def __init__(self, image_folder, annotation_file, transformation, tag="image_id"):
        self.image_folder = image_folder
        gt = pd.read_csv(annotation_file, dtype={tag: str})
        # Keep only annotation rows whose image is actually present on disk.
        files = {f.replace(".jpg", "") for f in os.listdir(image_folder)}
        gt = gt[gt[tag].isin(files)]
        self.processor = transformation
        self.gt = [
            (row[tag], row["latitude"], row["longitude"]) for _, row in gt.iterrows()
        ]
        self.tag = tag

    def fid(self, i):
        return self.gt[i][0]

    def latlon(self, i):
        # Return the (latitude, longitude) pair for item i.
        return self.gt[i][1], self.gt[i][2]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
        return self.processor(self, idx, fp)

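# Construction sketch (illustrative; the folder and CSV paths are placeholders):
# pair the dataset with one of the backbone-specific `operate` callables below.
#
#     model, operate, inference, collate_fn = load_clip("openai/clip-vit-base-patch32")
#     dataset = GeoDataset("street_images/", "annotations.csv", operate)
#     sample = dataset[0]  # backbone-specific processed item, fed to `inference`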

def load_plonk(path):
    from hydra import initialize, compose
    from hydra.utils import instantiate
    from models.module import DiffGeolocalizer
    from omegaconf import open_dict

    # Compose the Hydra config shipped with the checkpoint
    # (config_path is resolved relative to this file).
    with initialize(version_base=None, config_path="osv5m__best_model"):
        cfg = compose(config_name="config", overrides=[])

    # Recent transformers releases no longer register `position_ids` as a buffer in
    # the CLIP embeddings; drop the stale key and re-save to a temporary checkpoint.
    checkpoint = torch.load(join(path, "last.ckpt"))
    del checkpoint["state_dict"][
        "model.backbone.clip.vision_model.embeddings.position_ids"
    ]
    torch.save(checkpoint, join(path, "last2.ckpt"))

    with open_dict(cfg):
        cfg.checkpoint = join(path, "last2.ckpt")

    cfg.num_classes = 11399
    cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
    cfg.model.network.head.final_dim = cfg.num_classes * 3
    cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")

    cfg.dataset.train_dataset.path = ""
    cfg.dataset.val_dataset.path = ""
    cfg.dataset.test_dataset.path = ""
    cfg.logger.save_dir = ""
    cfg.data_dir = ""
    cfg.root_dir = ""
    cfg.mode = "test"
    cfg.model.network.backbone.instance.path = (
        "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    )
    transform = instantiate(cfg.dataset.test_transform)
    model = DiffGeolocalizer.load_from_checkpoint(
        join(path, "last2.ckpt"), cfg=cfg.model
    )
    os.remove(join(path, "last2.ckpt"))

    @torch.no_grad()
    def inference(model, x):
        return x[0], model.model.backbone({"img": x[1].to(model.device)})[:, 0, :].cpu()

    def collate_fn(batch):
        return [b[0] for b in batch], torch.stack([b[1] for b in batch], dim=0)

    def operate(self, idx, fp):
        # Apply the Hydra-instantiated test transform captured from the enclosing
        # scope (calling `self.processor` here would invoke `operate` itself and recurse).
        proc = transform(PIL.Image.open(fp))
        return self.gt[idx][0], proc

    return model, operate, inference, collate_fn

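# Checkpoint layout assumed by `load_plonk` (sketch): `path` is presumably the same
# `osv5m__best_model` directory referenced above, containing at least
#
#     config.yaml            # Hydra config composed above
#     last.ckpt              # Lightning checkpoint for DiffGeolocalizer
#     quadtree_10_1000.csv   # quadtree cell definitions (11,399 classes)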

def load_clip(which):
    # We evaluate on:
    # - "openai/clip-vit-base-patch32"
    # - "openai/clip-vit-large-patch14-336"
    # - "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    # - "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    # - "geolocal/StreetCLIP"
    from transformers import CLIPProcessor, CLIPModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        features = model.get_image_features(**image_input)
        features /= features.norm(dim=-1, keepdim=True)
        return image_ids, features.cpu()

    processor = CLIPProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

    return CLIPModel.from_pretrained(which), operate, inference, None


def load_dino(which):
    # We evaluate on:
    # - 'facebook/dinov2-large'
    from transformers import AutoImageProcessor, AutoModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        features = model(**image_input).last_hidden_state[:, 0]
        features /= features.norm(dim=-1, keepdim=True)
        return image_ids, features.cpu()

    processor = AutoImageProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

    return AutoModel.from_pretrained(which), operate, inference, None


def get_backbone(name):
    # A local directory is treated as a PLONK checkpoint folder (cf. `load_plonk`);
    # otherwise dispatch on the HuggingFace model name.
    if os.path.isdir(name):
        return load_plonk(name)
    elif "clip" in name.lower():
        return load_clip(name)
    elif "dino" in name.lower():
        return load_dino(name)
    raise ValueError(f"Unsupported backbone: {name}")
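

# End-to-end sketch (illustrative, not part of the original script): extract
# features for a folder of images with any of the backbones above. The folder,
# CSV path, backbone name, and batch size are placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model, operate, inference, collate_fn = get_backbone("geolocal/StreetCLIP")
    model = model.eval()

    dataset = GeoDataset("street_images/", "annotations.csv", operate)
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

    ids, feats = [], []
    for batch in loader:
        batch_ids, batch_feats = inference(model, batch)
        ids.extend(batch_ids)
        feats.append(batch_feats)
    features = torch.cat(feats, dim=0)  # (num_images, feature_dim)
    print(len(ids), features.shape)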