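"""Extract DINOv2 (ViT-L/14 with registers) embeddings for a folder of images.

Each embedding is saved as an individual .npy file; the dataset can be sharded
across several parallel jobs via --number_of_splits / --split_index.
"""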
import os
import sys

# Add the project root to sys.path so the local packages imported below resolve
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(root_dir)

import argparse
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from utils.image_processing import CenterCrop
from data.extract_embeddings.dataset_with_path import ImageWithPathDataset

# The dataset can be split into shards so several jobs can extract embeddings
# in parallel; --split_index selects the shard handled by this run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--number_of_splits",
    type=int,
    help="Number of splits to process",
    default=1,
)
parser.add_argument(
    "--split_index",
    type=int,
    help="Index of the split to process",
    default=0,
)
parser.add_argument(
    "--input_path",
    type=str,
    help="Path to the input dataset",
)
parser.add_argument(
    "--output_path",
    type=str,
    help="Path to the output dataset",
)

args = parser.parse_args()

device = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-L/14 with registers, loaded from the official torch.hub entry point.
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
# Compilation is optional; "max-autotune" trades a longer warm-up for faster inference.
model = torch.compile(model, mode="max-autotune")
model.eval()
model.to(device)

input_path = Path(args.input_path)
output_path = Path(args.output_path)

output_path.mkdir(exist_ok=True, parents=True)
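
# Preprocessing: crop to a square, resize to 336 px (a multiple of the ViT-L/14
# patch size of 14), and normalize with the standard ImageNet statistics.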
augmentation = transforms.Compose(
    [
        CenterCrop(ratio="1:1"),
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
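
# Build the (image, output path) dataset and keep only this job's contiguous
# slice; the last split also absorbs the remainder of the integer division.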
dataset = ImageWithPathDataset(input_path, output_path, transform=augmentation)
start = args.split_index * len(dataset) // args.number_of_splits
end = (
    (args.split_index + 1) * len(dataset) // args.number_of_splits
    if args.split_index != args.number_of_splits - 1
    else len(dataset)
)
dataset = torch.utils.data.Subset(dataset, range(start, end))

batch_size = 128
# The custom collate_fn keeps each batch as two tuples (images, output paths)
# rather than trying to collate tensors and path strings together.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=16, collate_fn=lambda x: zip(*x)
)

for images, output_emb_paths in tqdm(dataloader):
    # Stack the per-image tensors into a single batch and move it to the device.
    images = torch.stack(images, dim=0).to(device)
    with torch.no_grad():
        embeddings = model(images)
    # Write one .npy file per image at the output path provided by the dataset.
    numpy_embeddings = embeddings.cpu().numpy()
    for emb, output_emb_path in zip(numpy_embeddings, output_emb_paths):
        np.save(f"{output_emb_path}.npy", emb)
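
# Example invocation (script name and paths are illustrative), processing the
# first of four shards:
#   python extract_embeddings.py \
#       --input_path /data/images --output_path /data/embeddings \
#       --number_of_splits 4 --split_index 0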