File size: 2,181 Bytes
9ae1b1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
#!/usr/bin/env python
import argparse
import os
import sys
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from sinder import (
get_tokens,
load_model,
load_visual_data,
pca_array,
)
# Set XFORMERS_DISABLED=1 in this process's environment, presumably so the
# DINOv2 code falls back to its non-xformers attention path -- TODO confirm.
# NOTE(review): this assignment runs after `from sinder import ...` above; if
# that import (or anything it pulls in) reads the variable at import time,
# setting it here is too late and it should move above the imports.
os.environ['XFORMERS_DISABLED'] = '1'
def parse_args(argv=None):
    """Parse the command-line options of the visualization script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behaviour; passing a list makes the parser usable
            from tests or other Python code.

    Returns:
        argparse.Namespace with the attributes ``imgs`` (list of str),
        ``model``, ``workdir``, ``checkpoint`` (str or None) and
        ``visual_size`` (int).
    """
    parser = argparse.ArgumentParser(description='Visualize')
    parser.add_argument(
        'imgs', nargs='+', type=str, help='path to image/images'
    )
    parser.add_argument(
        '--model', type=str, default='dinov2_vitg14', help='model name'
    )
    parser.add_argument(
        '--workdir',
        type=str,
        default='visualize',
        help='directory the visualizations are written to',
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help='Path to checkpoint. Default is None, which loads the official pretrained weights',
    )
    parser.add_argument(
        '--visual_size',
        type=int,
        default=518,
        help='short side size of input image',
    )
    # argv=None makes argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)
def visualize(args, model, visual_dataset):
    """Save a token-norm heatmap and a PCA image for every item of a dataset.

    For item ``d``, the last entry of the token list returned by
    ``get_tokens`` is rendered twice. Both files are written into
    ``args.folder`` and named after the stem of ``visual_dataset.files[d]``:

    * ``<stem>_norm.png`` -- per-token L2 norm, scaled so the largest norm
      maps to 255, upsampled 14x with nearest-neighbour resampling and
      coloured with OpenCV's JET colormap;
    * ``<stem>_pca.png`` -- the image returned by ``pca_array``.

    Args:
        args: Parsed options; only ``args.folder`` (a ``pathlib.Path``) is read.
        model: Model forwarded to ``get_tokens``; put into eval mode first.
        visual_dataset: Supports ``len()`` and integer indexing, and exposes
            a ``files`` sequence parallel to its items.

    Raises:
        OSError: If OpenCV reports that the norm heatmap could not be written.
    """
    # Pixels per token when upsampling; presumably the patch size of the
    # dinov2 *14 models -- NOTE(review): confirm for other --model values.
    patch_px = 14
    model.eval()
    for d in tqdm(range(len(visual_dataset))):
        visual_image = visual_dataset[d]
        visual_tokens_all = get_tokens(model, visual_image)
        # Each entry is unpacked as a pair; the second element (the "cls"
        # tokens, judging by the original variable name) is not used here.
        visual_tokens, _ = zip(*visual_tokens_all)
        last_tokens = visual_tokens[-1]
        filename = Path(visual_dataset.files[d]).stem

        # t is unpacked as (rows, cols, channels); the norm runs over channels.
        t = last_tokens.detach().cpu()
        h, w, _ = t.shape
        norms = t.norm(dim=-1)  # computed once instead of twice
        peak = norms.max()
        # Guard the degenerate all-zero map, which would otherwise become NaN.
        scaled = norms / peak if peak > 0 else norms
        norm = (scaled * 255).byte().numpy()
        # Resample filter 0 is Pillow's NEAREST: one flat block per token.
        norm_img = Image.fromarray(norm).resize((w * patch_px, h * patch_px), 0)
        norm_color = cv2.applyColorMap(np.array(norm_img), cv2.COLORMAP_JET)

        norm_path = args.folder / f'{filename}_norm.png'
        # cv2.imwrite takes a str path (older OpenCV builds reject pathlib.Path)
        # and can signal failure via its return value instead of raising.
        if not cv2.imwrite(str(norm_path), norm_color):
            raise OSError(f'cv2.imwrite failed to write {norm_path}')

        pca_img = pca_array(last_tokens)
        pca_img.save(args.folder / f'{filename}_pca.png')
def main():
    """Script entry point: parse options, prepare the output folder, render.

    The resolved output directory is stored on the namespace as
    ``args.folder`` so that ``visualize`` knows where to write its images.
    """
    args = parse_args()
    # "~" is expanded, so e.g. ``--workdir ~/vis`` works as expected.
    args.folder = Path(args.workdir).expanduser()
    args.folder.mkdir(parents=True, exist_ok=True)

    # Log the parsed options and the exact command line that produced them.
    print(args)
    print(' '.join(sys.argv))

    model = load_model(args.model, args.checkpoint)
    visualize(args, model, load_visual_data(args, model))
# Run the script only when executed directly, not when imported as a module.
if __name__ == '__main__':
    main()
|