File size: 5,125 Bytes
94f372a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import sys, os
import json
from PIL import Image
from tqdm import tqdm
from os.path import dirname, join

sys.path.append(dirname(dirname(__file__)))

import torch
from transformers import AutoImageProcessor, AutoModel
from transformers import CLIPProcessor, CLIPModel
from transformers import pipeline

from data.data import osv5m
from json_stream import streamable_list

# Run inference on GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_clip():
    """Load the LAION CLIP ViT-L/14 checkpoint and return (processor, model-on-DEVICE)."""
    checkpoint = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return clip_processor, clip_model.to(DEVICE)


def load_model_dino():
    """Load the DINOv2-base checkpoint and return (processor, model-on-DEVICE)."""
    checkpoint = "facebook/dinov2-base"
    dino_model = AutoModel.from_pretrained(checkpoint)
    dino_processor = AutoImageProcessor.from_pretrained(checkpoint)
    return dino_processor, dino_model.to(DEVICE)


def compute_dino(processor, model, x):
    """Yield one [hidden_state, lat, lon, id] row per image in the batch.

    x is a 4-tuple (images, lats, lons, ids) as produced by get_batch.
    NOTE(review): the full last_hidden_state (all tokens) is kept per image,
    not just the CLS token — confirm that is intended downstream.
    """
    images, lats, lons, ids = x[0], x[1], x[2], x[3]
    inputs = processor(images=images, return_tensors="pt", device=DEVICE).to(DEVICE)
    hidden = model(**inputs).last_hidden_state.cpu().numpy()
    for emb, lat, lon, img_id in zip(hidden, lats, lons, ids):
        yield [emb.tolist(), lat, lon, img_id]


def compute_clip(processor, model, x):
    """Yield one [l2-normalized CLIP feature, lat, lon, id] row per image in the batch.

    x is a 4-tuple (images, lats, lons, ids) as produced by get_batch.
    """
    images, lats, lons, ids = x[0], x[1], x[2], x[3]
    inputs = processor(images=images, return_tensors="pt", device=DEVICE).to(DEVICE)
    feats = model.get_image_features(**inputs)
    # Normalize so cosine similarity downstream reduces to a dot product.
    feats = feats / feats.norm(dim=-1, keepdim=True)
    feats = feats.cpu().numpy()
    for emb, lat, lon, img_id in zip(feats, lats, lons, ids):
        yield [emb.tolist(), lat, lon, img_id]


def get_batch(dataset, batch_size):
    """Yield batches (images, lats, lons, ids) of up to batch_size items from dataset.

    Each dataset row (from dataset.df) is expected to be (id, lat, lon); the image
    is opened from dataset.image_folder as "<id>.jpg". The final partial batch is
    yielded as well.

    NOTE(review): the opened PIL images are never explicitly closed here — the
    consumer (the feature extractors) implicitly owns them; confirm nothing keeps
    handles alive past processing.
    """
    data, lats, lons, ids = [], [], [], []
    for i in range(len(dataset)):
        # Renamed from `id` to avoid shadowing the builtin.
        img_id, lat, lon = dataset.df.iloc[i]
        data.append(Image.open(join(dataset.image_folder, f"{int(img_id)}.jpg")))
        lats.append(lat)
        lons.append(lon)
        ids.append(img_id)
        if len(data) == batch_size:
            yield data, lats, lons, ids
            data, lats, lons, ids = [], [], [], []

    # Flush the trailing partial batch (dead reassignment after this yield removed).
    if data:
        yield data, lats, lons, ids


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--compute_features", action="store_true")
    parser.add_argument("--compute_nearest", action="store_true")
    parser.add_argument("--json_path", default="features")
    parser.add_argument("--which", type=str, default="clip", choices=["clip", "dino"])
    args = parser.parse_args()
    json_path = join(args.json_path, args.which)

    os.makedirs(json_path, exist_ok=True)
    if args.compute_features:
        processor, model = (
            load_model_clip() if args.which == "clip" else load_model_dino()
        )
        compute_fn = compute_clip if args.which == "clip" else compute_dino

        for split in ["test"]:  #'train',
            # open existing json and read as dictionary
            json_path_ = join(json_path, f"{split}.json")

            dataset = osv5m(
                "datasets/osv5m", transforms=None, split=split, dont_split=True
            )

            @torch.no_grad()
            def compute(batch_size):
                for data in tqdm(
                    get_batch(dataset, batch_size),
                    total=len(dataset) // batch_size,
                    desc=f"Computing {split} on {args.which}",
                ):
                    features = compute_fn(processor, model, data)
                    for feature, lat, lon, id in features:
                        yield feature, lat, lon, id

            data = streamable_list(compute(args.batch_size))
            json.dump(data, open(json_path_, "w"), indent=4)

    if args.compute_nearest:
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np

        train, test = [
            json.load(open(join(json_path, f"{split}.json"), "r"))
            for split in ["train", "test"]
        ]

        def get_neighbors(k=10):
            for i, test_data in enumerate(tqdm(test)):
                feature, lat, lon, id = test_data
                features_train = np.stack(
                    [np.array(train_data[0]) for train_data in train]
                )
                cs = np.squeeze(
                    cosine_similarity(np.expand_dims(feature, axis=0), features_train),
                    axis=0,
                )
                i = np.argsort(cs)[-k:][::-1].tolist()
                yield [
                    {n: x}
                    for idx in i
                    for n, x in zip(
                        ["feature", "lat", "lon", "id", "distance"],
                        train[idx]
                        + [
                            cs[idx],
                        ],
                    )
                ]

        data = streamable_list(get_neighbors())
        json.dump(data, open(join(json_path, "nearest.json"), "w"), indent=4)