lucabarsellotti committed
Commit • 081a09c
Parent(s): 821c060

First commit

- freeda/__init__.py +1 -0
- freeda/configs/dinov2_vitb_clip_vitb.yaml +34 -0
- freeda/configs/dinov2_vitl_clip_vitl.yaml +34 -0
- freeda/configs/dinov2_vitl_clip_vitl_approx.yaml +34 -0
- freeda/models/freeda_model.py +519 -0
- freeda/models/mask_proposer/superpixel.py +66 -0
- freeda/models/vision_backbone.py +54 -0
- freeda/utils/factory.py +111 -0
- main.py +15 -0
- requirements.txt +9 -0

freeda/__init__.py ADDED
@@ -0,0 +1 @@
from .utils.factory import load

freeda/configs/dinov2_vitb_clip_vitb.yaml ADDED
@@ -0,0 +1,34 @@
data:
  collection_url: "https://drive.google.com/uc?id=10v1ZbbVjZQhA43F9JWju5chYnPjR2Ael"
  compression: "zip"
  index_url: "https://drive.google.com/uc?id=1AHE6YpY7sGQGp_wPetcr5-z8mk91kWtF"
  collection_length: 2166945
clip:
  model: ViT-B-16
  weights: openai
  templates:
    - "itap of a {}."
    - "a bad photo of a {}."
    - "a origami {}."
    - "a photo of the large {}."
    - "a {} in a video game."
    - "art of the {}."
    - "a photo of the small {}."

backbone:
  model: "vit_base_patch14_dinov2.lvd142m"
  img_size: 518
mask_proposer:
  use_mask_proposals: true
  method: "superpixel"
  args:
    algorithm: "felzenszwalb"
    scale: 100
    sigma: 1.0
    min_size: 100
freeda:
  global_local_ensemble: 0.8
  k_search: 350
  sliding_window: true
  with_background: false
  background_threshold: 0.48
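
For reference, a minimal sketch of how a config like the one above is consumed (assuming PyYAML and a local checkout with this repo layout; the printed values come straight from the file):

import yaml

# Load the YAML the same way freeda/utils/factory.py does (yaml.FullLoader).
with open("freeda/configs/dinov2_vitb_clip_vitb.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# These are the keys the FreeDA constructor reads.
print(config["clip"]["model"])                    # ViT-B-16
print(config["backbone"]["model"])                # vit_base_patch14_dinov2.lvd142m
print(config["freeda"]["global_local_ensemble"])  # 0.8 -> weight of local vs. global similarities
print(config["mask_proposer"]["args"])            # kwargs forwarded to skimage's felzenszwalb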

freeda/configs/dinov2_vitl_clip_vitl.yaml ADDED
@@ -0,0 +1,34 @@
data:
  collection_url: "https://drive.google.com/uc?id=1U4d0exJuq29b0rLR6iOT20ErW3DAmgw0"
  compression: "tar"
  index_url: "https://drive.google.com/uc?id=1FHjpM0aqPf9OjiuG_341EMlEuq6hsh6L"
  collection_length: 2166946
clip:
  model: ViT-L-14
  weights: openai
  templates:
    - "itap of a {}."
    - "a bad photo of a {}."
    - "a origami {}."
    - "a photo of the large {}."
    - "a {} in a video game."
    - "art of the {}."
    - "a photo of the small {}."

backbone:
  model: "vit_large_patch14_dinov2.lvd142m"
  img_size: 518
mask_proposer:
  use_mask_proposals: true
  method: "superpixel"
  args:
    algorithm: "felzenszwalb"
    scale: 100
    sigma: 1.0
    min_size: 100
freeda:
  global_local_ensemble: 0.8
  k_search: 350
  sliding_window: true
  with_background: false
  background_threshold: 0.48

freeda/configs/dinov2_vitl_clip_vitl_approx.yaml ADDED
@@ -0,0 +1,34 @@
data:
  collection_url: "https://drive.google.com/uc?id=1U4d0exJuq29b0rLR6iOT20ErW3DAmgw0"
  compression: "tar"
  index_url: "https://drive.google.com/uc?id=1LGBnwu8g2PDzIlgyd7gLqkD02wS9r8Hp"
  collection_length: 2166946
clip:
  model: ViT-L-14
  weights: openai
  templates:
    - "itap of a {}."
    - "a bad photo of a {}."
    - "a origami {}."
    - "a photo of the large {}."
    - "a {} in a video game."
    - "art of the {}."
    - "a photo of the small {}."
backbone:
  model: "vit_large_patch14_dinov2.lvd142m"
  img_size: 518
mask_proposer:
  use_mask_proposals: true
  method: "superpixel"
  args:
    algorithm: "felzenszwalb"
    scale: 100
    sigma: 1.0
    min_size: 100
freeda:
  global_local_ensemble: 0.8
  k_search: 350
  ef_search: 4096
  sliding_window: true
  with_background: false
  background_threshold: 0.5
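
The only differences from the exact ViT-L config are the index URL, the extra ef_search key, and a slightly higher background threshold. ef_search only matters when the downloaded index is a faiss HNSW index; a minimal sketch of what the model does with it (assuming a local knn.index file, and using an isinstance check in place of the string comparison in freeda_model.py):

import faiss

index = faiss.read_index("knn.index")  # hypothetical local path to the downloaded index
if isinstance(index, faiss.IndexHNSWFlat):
    # Larger efSearch -> wider graph traversal at query time: better recall, slower search.
    faiss.ParameterSpace().set_index_parameter(index, "efSearch", 4096)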

freeda/models/freeda_model.py ADDED
@@ -0,0 +1,519 @@
import torch
import open_clip
import faiss
import os
import numpy as np
from math import sqrt, ceil
from PIL import Image
from torchvision.transforms import Compose, Resize
from torchvision.transforms.functional import pil_to_tensor
from torch.nn.functional import interpolate
from freeda.models.vision_backbone import VisionBackbone

class FreeDA(torch.nn.Module):
    def __init__(self, config: dict,
                 lazy_init: bool = True,
                 collection_in_gpu: bool = False,
                 collection_path: str = None,
                 index_path: str = None,
                 use_cached_embeddings: bool = True,
                 cache_embeddings: bool = True,
                 embeddings_cache_path: str = None,
                 device: str = 'cuda',
                 verbose: bool = False,
                 max_masks_batch: int = 128):
        """
        Initialize the model.
        Args:
            config (dict): The configuration of the model.
            lazy_init (bool): Whether to lazily load the collection.
            collection_in_gpu (bool): Whether to load the faiss retrieval collection in GPU.
            collection_path (str): The path to the collection.
            index_path (str): The path to the index.
            use_cached_embeddings (bool): Whether to use the cached embeddings for the required model.
            cache_embeddings (bool): Whether to cache the embeddings.
            embeddings_cache_path (str): The path to the embeddings cache.
            device (str): The device to use.
            verbose (bool): Whether to print the progress.
            max_masks_batch (int): The maximum number of masks to process at once.
        """
        super(FreeDA, self).__init__()

        self.clip_model_name = config['clip']['model']
        self.clip_weights = config['clip']['weights']
        self.backbone = VisionBackbone(config['backbone'], device)
        self.collection_path = collection_path
        self.index_path = index_path
        self.collection_in_gpu = collection_in_gpu
        self.lazy_init = lazy_init
        self.use_cached_embeddings = use_cached_embeddings
        self.cache_embeddings = cache_embeddings
        self.embeddings_cache_path = embeddings_cache_path
        self.device = device
        self.verbose = verbose
        self.config = config
        self.max_masks_batch = max_masks_batch

        self.collection_initialized = False

        self.ens = self.config['freeda']['global_local_ensemble']
        if "ef_search" in self.config['freeda']:
            self.ef_search = self.config['freeda']['ef_search']

        self.use_mask_proposer = config['mask_proposer']['use_mask_proposals']
        if self.use_mask_proposer:
            if config['mask_proposer']['method'] == 'superpixel':
                from freeda.models.mask_proposer.superpixel import SuperpixelMaskProposer
                self.mask_proposer = SuperpixelMaskProposer(config['mask_proposer']['args'])

        self.templates = config['clip']['templates']
        self.sliding_window = config['freeda']['sliding_window']
        self.with_background = config['freeda']['with_background']
        if self.with_background:
            self.background_threshold = config['freeda']['background_threshold']

        self.initialized = False
        if not lazy_init:
            self.init_collection()
            self.init_clip()
            self.initialized = True

    def init_collection(self):
        """
        Initialize the collection by reading the faiss index and the list of embeddings. Moves the index to GPU if collection_in_gpu is True.
        """
        self.collection_index = faiss.read_index(self.index_path + "faiss_index/knn.index")
        if str(type(self.collection_index)) == "<class 'faiss.swigfaiss.IndexHNSWFlat'>" \
                or str(type(self.collection_index)) == "<class 'faiss.swigfaiss_avx2.IndexHNSWFlat'>":
            print(f"Setting faiss efSearch to {self.ef_search}")
            faiss.ParameterSpace().set_index_parameter(self.collection_index, 'efSearch', self.ef_search)
        self.collection_embeddings = sorted(os.listdir(self.collection_path))
        if self.collection_in_gpu:
            resources = [faiss.StandardGpuResources()]
            self.collection_index = faiss.index_cpu_to_gpu_multiple_py(resources, self.collection_index)

    @torch.no_grad()
    def init_clip(self):
        """
        Initialize the CLIP model and tokenizer.
        """
        self.clip_model, _, self.clip_image_preprocess = open_clip.create_model_and_transforms(self.clip_model_name,
                                                                                                pretrained=self.clip_weights,
                                                                                                device=self.device)
        self.clip_resize_dim = self.clip_image_preprocess.transforms[0].size
        self.clip_image_preprocess = Compose([
            Resize(
                (self.clip_image_preprocess.transforms[0].size, self.clip_image_preprocess.transforms[0].size),
                interpolation=self.clip_image_preprocess.transforms[0].interpolation, antialias=None),
            lambda x: x / 255,
            self.clip_image_preprocess.transforms[4]
        ])
        self.clip_model.eval()
        self.clip_tokenizer = open_clip.get_tokenizer(self.clip_model_name)

    @torch.no_grad()
    def set_categories(self, categories):
        """
        Set the categories to be used by the model. If use_cached_embeddings is True, it will try to load the embeddings from the cache.
        Otherwise, it will compute the embeddings and cache them if cache_embeddings is True.
        Args:
            categories (list): The list of textual arbitrary categories.
        Return:
            output_prototypes (torch.Tensor): The prototype embeddings. [num_categories, vis_emb_dim]
            output_text_embeddings (torch.Tensor): The text embeddings. [num_categories, txt_emb_dim]
        """
        self.categories = categories
        self.num_categories = len(self.categories) + 1 if self.with_background else len(self.categories)
        num_categories = len(categories)
        already_cached_categories = {}
        not_cached_categories = {}
        if self.use_cached_embeddings:
            for i, category in enumerate(self.categories):
                if os.path.exists(f"{self.embeddings_cache_path}/visual/{category}.npy"):
                    already_cached_categories[i] = category
                else:
                    not_cached_categories[i] = category
        else:
            not_cached_categories = {i: category for i, category in enumerate(self.categories)}
        if len(not_cached_categories.keys()) != 0:
            if self.lazy_init and not self.initialized:
                self.init_clip()
                self.init_collection()
                self.initialized = True
            num_templates = len(self.templates)
            text = [template.format(category) for category in [v for k, v in not_cached_categories.items()] for template in self.templates]
            tokens = self.clip_tokenizer(text).to(self.device)
            context_length = tokens.shape[-1]
            tokens = tokens.reshape(-1, context_length)
            text_embeddings = self.clip_model.encode_text(tokens)
            text_emb_dim = text_embeddings.shape[-1]
            text_embeddings = text_embeddings.reshape(-1, num_templates, text_emb_dim)
            text_embeddings = text_embeddings.mean(dim=1)
            text_embeddings = text_embeddings / (text_embeddings.norm(dim=-1, keepdim=True) + 1e-6)
            _, indices = self.collection_index.search(text_embeddings.cpu().numpy(), self.config['freeda']['k_search'])
            prototypes = []
            for c in indices:
                category_retrieved_embeddings = []
                for k in c:
                    retrieved_embedding = torch.from_numpy(np.load(f"{self.collection_path}/{self.collection_embeddings[k]}"))
                    if len(retrieved_embedding.shape) == 1:
                        retrieved_embedding = retrieved_embedding.unsqueeze(0)
                    category_retrieved_embeddings.append(retrieved_embedding)
                category_retrieved_embeddings = torch.cat(category_retrieved_embeddings, dim=0).to(self.device)
                prototypes.append(category_retrieved_embeddings)
            prototypes = torch.stack(prototypes, dim=0).mean(dim=1)
            prototypes = prototypes / (prototypes.norm(dim=-1, keepdim=True) + 1e-6)
            prototypes_emb_dim = prototypes.shape[-1]

            if self.cache_embeddings:
                os.makedirs(f"{self.embeddings_cache_path}/visual", exist_ok=True)
                os.makedirs(f"{self.embeddings_cache_path}/textual", exist_ok=True)
                for j, (i, category) in enumerate(not_cached_categories.items()):
                    print(f"Caching embeddings for {category}")
                    np.save(f"{self.embeddings_cache_path}/visual/{category}.npy", prototypes[j].cpu().numpy())
                    np.save(f"{self.embeddings_cache_path}/textual/{category}.npy", text_embeddings[j].cpu().numpy())
            output_prototypes = torch.zeros(num_categories, prototypes_emb_dim, device=self.device)
            output_text_embeddings = torch.zeros(num_categories, text_emb_dim, device=self.device)
            for i, category in already_cached_categories.items():
                if self.verbose:
                    print(f"Loading cached embeddings for {category}")
                output_prototypes[i] = torch.tensor(np.load(f"{self.embeddings_cache_path}/visual/{category}.npy"), device=self.device)
                output_text_embeddings[i] = torch.tensor(np.load(f"{self.embeddings_cache_path}/textual/{category}.npy"), device=self.device)
            for j, (i, category) in enumerate(not_cached_categories.items()):  # align fresh prototypes (index j) with their category slot (index i)
                output_prototypes[i] = prototypes[j]
                output_text_embeddings[i] = text_embeddings[j]
        else:
            loaded_prototypes_embeddings = [None for _ in range(num_categories)]
            loaded_text_embeddings = [None for _ in range(num_categories)]
            for i, category in already_cached_categories.items():
                if self.verbose:
                    print(f"Loading cached embeddings for {category}")
                loaded_prototypes_embeddings[i] = torch.tensor(np.load(f"{self.embeddings_cache_path}/visual/{category}.npy"), device=self.device)
                loaded_text_embeddings[i] = torch.tensor(np.load(f"{self.embeddings_cache_path}/textual/{category}.npy"), device=self.device)
            output_prototypes = torch.stack(loaded_prototypes_embeddings, dim=0).to(self.device)
            output_text_embeddings = torch.stack(loaded_text_embeddings, dim=0).to(self.device)

        self.prototype_embeddings = output_prototypes.clone()
        self.text_embeddings = output_text_embeddings.clone()
        return output_prototypes, output_text_embeddings

    @torch.no_grad()
    def set_images(self, images):
        """
        Set the images to be used by the model and extract their backbone features, CLIP features and mask proposals.
        If sliding_window is True, it will split the images into windows.
        It also returns the features and the masks if the caller wants to use them for other purposes.
        Args:
            images (list): The list of PIL images.
        Return:
            backbone_features (list): The list of backbone features. [num_images, num_patches_h, num_patches_w, vis_emb_dim]
            clip_features (torch.Tensor): The visual features from the CLIP model. [num_images, txt_emb_dim]
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
        """
        self.images = images
        self.original_sizes = [image.size for image in images]
        if self.sliding_window:
            new_images = []
            self.window_boxes = []
            for image in images:
                image_windows = self.get_window_boxes(image)
                self.window_boxes.append(image_windows)
                new_images.extend([image.crop(window_box) for window_box in image_windows])
            self.images_pre_sliding = self.images
            self.images = new_images
        backbone_features, clip_features = self.set_backbone_features(self.images)
        masks = self.set_mask_proposals(self.images)
        return backbone_features, clip_features, masks

    def get_window_boxes(self, image):
        """
        Get the window boxes for the sliding window approach.
        Args:
            image (PIL.Image): The image to split into windows.
        Return:
            window_boxes (list): The list of window boxes coordinates. (x1, y1, x2, y2)
        """
        short_side = min(image.size)
        long_side = max(image.size)
        aspect_ratio = long_side / short_side
        num_windows = ceil(aspect_ratio)
        window_shift = (long_side - short_side) / (num_windows - 1) if num_windows > 1 else 0
        window_boxes = []
        current_shift = 0
        for j in range(num_windows):
            if short_side == image.size[0]:
                window_boxes.append((0, round(current_shift), short_side, round(current_shift + short_side)))
            else:
                window_boxes.append((round(current_shift), 0, round(current_shift + short_side), short_side))
            current_shift += window_shift
        return window_boxes

    @torch.no_grad()
    def get_clip_visual_features(self, images):
        """
        Get the normalized visual features from the CLIP model.
        Args:
            images (list or PIL.Image): The list of PIL images or a single PIL image.
        Return:
            clip_features (torch.Tensor): The visual features from the CLIP model. [num_images, vis_emb_dim]
        """
        if type(images) == list:
            if len(images) == 0:
                raise ValueError("Images list is empty")
            images = [pil_to_tensor(image) for image in images]
            images = [self.clip_image_preprocess(image.unsqueeze(0)).squeeze(0) for image in images]
            images = torch.stack(images, dim=0).to(self.device)
        else:
            images = self.clip_image_preprocess(images)
        clip_features = self.clip_model.encode_image(images)
        clip_features = clip_features / (clip_features.norm(dim=-1, keepdim=True) + 1e-6)
        return clip_features

    @torch.no_grad()
    def set_backbone_features(self, images):
        """
        Set the backbone features and the CLIP features of the images.
        If the model is lazy or global similarities are required, it will initialize the CLIP model.
        Args:
            images (list): The list of PIL images.
        Return:
            backbone_features (list): The list of backbone features. [num_images, num_patches_h, num_patches_w, vis_emb_dim]
            clip_features (torch.Tensor): The visual features from the CLIP model. [num_images, vis_emb_dim]
        """
        self.backbone_features = self.backbone(images)
        if self.ens != 1.0:
            if self.lazy_init and not self.initialized:
                self.init_clip()

            self.clip_features = self.get_clip_visual_features(images)
        return self.backbone_features, self.clip_features

    @torch.no_grad()
    def set_mask_proposals(self, images):
        """
        Set the mask proposals of the images.
        Args:
            images (list): The list of PIL images.
        Return:
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
            Each mask proposal is composed of the predicted binary masks, the number of predicted masks, the pixels covered by masks and the assigned masks for each pixel.
            pred_masks: [n_pred_masks, h, w] (bool)
            n_pred_masks: int
            covered_pixels: [h, w] (bool)
            assigned_masks: [h, w] (int)
        """
        self.masks = [self.mask_proposer(image, self.device) for image in images]  # List of tuples (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
        return self.masks

    @torch.no_grad()
    def forward(self):
        """
        Forward pass of the model.
        Return:
            masks (list): The list of output segmentation masks. [num_images, h, w] (int)
        """
        patch_similarities = self.compute_patch_similarities(self.backbone_features, self.prototype_embeddings, self.masks)
        region_embeddings_batch = self.region_pooling(self.backbone_features, self.masks)
        region_similarities = self.compute_region_similarities(region_embeddings_batch, self.prototype_embeddings, self.masks)
        global_similarities = self.compute_global_similarities(self.clip_features, self.text_embeddings)
        similarities = self.compute_final_similarities(region_similarities, patch_similarities, global_similarities, self.masks)
        if self.sliding_window:
            similarities = self.merge_sliding_windows(similarities)
            self.images = self.images_pre_sliding
        masks = [similarity.argmax(dim=0) for similarity in similarities]
        return masks

    @torch.no_grad()
    def merge_sliding_windows(self, similarities):
        """
        Merge the similarities of the sliding windows.
        Args:
            similarities (list): The list of similarities of the sliding windows. [num_windows, num_categories, h, w]
        Return:
            new_similarities (list): The list of merged similarities. [num_images, num_categories, h, w]
        """
        counter = 0
        new_similarities = []
        for original_size, window_boxes in zip(self.original_sizes, self.window_boxes):
            new_similarity = torch.zeros(self.num_categories, original_size[1], original_size[0], device=self.device)
            window_overlaps = torch.zeros(original_size[1], original_size[0], device=self.device)
            for window_box in window_boxes:
                new_similarity[:, window_box[1]:window_box[3], window_box[0]:window_box[2]] += similarities[counter]
                window_overlaps[window_box[1]:window_box[3], window_box[0]:window_box[2]] += 1
                counter += 1
            new_similarity = new_similarity / window_overlaps
            new_similarities.append(new_similarity)
        return new_similarities

    @torch.no_grad()
    def region_pooling(self, features, masks):
        """
        Perform region pooling to get the region embeddings from the patch-level embeddings and the mask proposals.
        Args:
            features (torch.Tensor): The backbone features. [num_images, num_patches_h, num_patches_w, vis_emb_dim]
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
        Return:
            region_embeddings_batch (list): The list of normalized region embeddings. [num_images, num_regions, vis_emb_dim]
        """
        region_embeddings_batch = []
        for i in range(len(masks)):
            pred_masks = interpolate(masks[i][0].unsqueeze(1).float(), size=(features.shape[1], features.shape[2]), mode='bilinear', align_corners=True).type(torch.bool).squeeze(1)
            region_embeddings = torch.zeros(pred_masks.shape[0], features.shape[-1], device=self.device)
            for j in range(0, pred_masks.shape[0], self.max_masks_batch):
                r = min(j + self.max_masks_batch, masks[i][1])
                current_region_embeddings = pred_masks[j:r].unsqueeze(-1) * features[i].unsqueeze(0)
                region_embeddings[j:r] = current_region_embeddings.sum(dim=(1, 2)) / pred_masks[j:r].sum(dim=(1, 2)).unsqueeze(-1)
            region_embeddings = region_embeddings / (region_embeddings.norm(dim=-1, keepdim=True) + 1e-6)
            self.masks[i][2][self.masks[i][0][region_embeddings.isnan().sum(-1).bool()].sum(0).bool()] = False  # Replace pixels covered by too small regions
            region_embeddings_batch.append(region_embeddings)
        return region_embeddings_batch

    @torch.no_grad()
    def compute_region_similarities(self, region_embeddings_batch, prototype_embeddings, masks):
        """
        Compute the region local similarities between the region embeddings and the visual prototypes.
        Args:
            region_embeddings_batch (list): The list of region embeddings. [num_images, num_regions, vis_emb_dim]
            prototype_embeddings (torch.Tensor): The prototype embeddings. [num_categories, vis_emb_dim]
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
        Return:
            similarities (list): The list of region similarities. [num_images, num_categories, h, w]
        """
        similarities = []
        for i in range(len(region_embeddings_batch)):
            output_similarities = torch.zeros(len(self.categories), masks[i][0].shape[1], masks[i][0].shape[2], device=self.device)
            region_embeddings = region_embeddings_batch[i]
            region_similarities = torch.matmul(region_embeddings, prototype_embeddings.T)
            output_similarities[:, masks[i][2]] = region_similarities[masks[i][3]][masks[i][2]].permute(1, 0)
            similarities.append(torch.sigmoid(output_similarities))
        return similarities

    @torch.no_grad()
    def compute_patch_similarities(self, patch_embeddings, prototype_embeddings, masks):
        """
        Compute the per-patch local similarities between the patch-level embeddings and the visual prototypes.
        Args:
            patch_embeddings (torch.Tensor): The patch-level embeddings. [num_images, num_patches_h, num_patches_w, vis_emb_dim]
            prototype_embeddings (torch.Tensor): The prototype embeddings. [num_categories, vis_emb_dim]
            masks (list): The list of mask proposals.
        Return:
            output_similarities (list): The list of patch similarities. [num_images, num_categories, h, w]
        """
        patch_embeddings = patch_embeddings / (patch_embeddings.norm(dim=-1, keepdim=True) + 1e-6)
        similarities = torch.matmul(patch_embeddings, prototype_embeddings.T)
        similarities = torch.sigmoid(similarities)
        image_sizes = [mask[0].shape[1:] for mask in masks]
        output_similarities = []
        for i, image_size in enumerate(image_sizes):
            output_similarities.append(interpolate(similarities[i].permute(2, 0, 1).unsqueeze(1), size=image_size, mode='bilinear', align_corners=True).squeeze(1))
        return output_similarities

    @torch.no_grad()
    def compute_global_similarities(self, clip_features, text_embeddings):
        """
        Compute the global similarities between the CLIP visual features and the text embeddings.
        Args:
            clip_features (torch.Tensor): The visual features from the CLIP model. [num_images, vis_emb_dim]
            text_embeddings (torch.Tensor): The text embeddings. [num_categories, txt_emb_dim]
        Return:
            similarities (list): The list of global similarities. [num_images, num_categories]
        """
        similarities = torch.matmul(clip_features, text_embeddings.T)
        return similarities

    @torch.no_grad()
    def compute_final_similarities(self, region_similarities, patch_similarities, global_similarities, masks):
        """
        Compute the final similarities by combining the region and patch local similarities with the textual global similarities.
        Args:
            region_similarities (list): The list of region local similarities. [num_images, num_categories, h, w]
            patch_similarities (list): The list of per-patch interpolated local similarities. [num_images, num_categories, h, w]
            global_similarities (list): The list of textual global similarities. [num_images, num_categories]
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
        Return:
            new_similarities (list): The list of final similarities. [num_images, num_categories, h, w]
        """
        new_similarities = self.replace_covered_pixel_similarities(masks, patch_similarities, region_similarities)
        new_similarities = [self.ens * new_similarities[i] + (1 - self.ens) * global_similarities[i].reshape(-1, 1, 1) for i in range(len(new_similarities))]
        if self.with_background:
            new_similarities = self.add_backgrounds(new_similarities)
        return new_similarities

    @torch.no_grad()
    def add_backgrounds(self, similarities):
        """
        Add the background class to the similarities by thresholding the maximum similarity of each pixel.
        Args:
            similarities (list): The list of final similarities. [num_images, num_categories, h, w]
        Return:
            similarities (list): The list of final similarities with the background class. [num_images, num_categories + 1, h, w]
        """
        for i in range(len(similarities)):
            background = (similarities[i].max(dim=0).values < self.background_threshold).float().unsqueeze(0)
            similarities[i] = torch.cat([background, similarities[i]], dim=0)
        return similarities

    @torch.no_grad()
    def replace_covered_pixel_similarities(self, masks, patch_similarities, region_similarities):
        """
        Replace the similarities of the covered pixels by the region similarities.
        Args:
            masks (list): The list of mask proposals. (num_images) x (pred_masks, n_pred_masks, covered_pixels, assigned_masks)
            patch_similarities (list): The list of per-patch interpolated local similarities. [num_images, num_categories, h, w]
            region_similarities (list): The list of region local similarities. [num_images, num_categories, h, w]
        Return:
            output_similarities (list): The list of final similarities. [num_images, num_categories, h, w]
        """
        output_similarities = []
        for i in range(len(masks)):
            tmp_patch_similarities = patch_similarities[i].permute(1, 2, 0)
            tmp_new_similarities = region_similarities[i].permute(1, 2, 0)
            tmp_patch_similarities[masks[i][2]] = tmp_new_similarities[masks[i][2]]
            output_similarities.append(tmp_patch_similarities.permute(2, 0, 1))
        return output_similarities

    @torch.no_grad()
    def visualize(self, segmentation_masks, output_paths, legend=True):
        """
        Visualize the segmentation masks saving the plots to the provided output paths.
        Args:
            segmentation_masks (list): The list of segmentation masks. [num_images, h, w] (int)
            output_paths (list): The list of output paths. [num_images]
            legend (bool): Whether to add a legend to the plot.
        """
        if len(segmentation_masks) != len(self.images):
            raise ValueError("Number of segmentation masks and images must be the same")
        if len(segmentation_masks) != len(output_paths):
            raise ValueError("Number of segmentation masks and output paths must be the same")
        from skimage.segmentation import find_boundaries
        import random
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D
        for i in range(len(segmentation_masks)):
            mask = segmentation_masks[i]
            image = self.images[i]
            h, w = mask.shape
            colored_image = torch.zeros((h, w, 3), dtype=torch.int)
            random_colors = []
            new_categories = ["background"] + self.categories if self.with_background else self.categories
            for index in range(len(new_categories)):
                rand_mask = torch.ones(h, w, 3, dtype=torch.int)
                rand_mask[:, :, -1] = (mask == index) * 255
                random_rgb = [0, 0, 0] if self.with_background and index == 0 else [random.randint(64, 255), random.randint(64, 255), random.randint(64, 255)]
                for j in range(3):
                    rand_mask[:, :, j] = torch.ones(h, w, dtype=torch.int) * random_rgb[j] * (mask == index).cpu().int()
                colored_image[(mask == index).cpu()] = rand_mask[(mask == index).cpu()]
                random_colors.append([channel / 255 for channel in random_rgb])
            plt.imshow(image)
            boundaries = find_boundaries(mask.cpu().numpy().astype(np.uint8))
            boundaries_image = np.zeros((h, w, 4), dtype=np.uint8)
            boundaries_image[:, :, 3] = boundaries * 255
            plt.imshow(colored_image.numpy(), alpha=0.5)
            plt.imshow(boundaries_image)
            plt.axis('off')
            plt.tight_layout()
            if legend:
                legend_elements = [Line2D([0], [0], color=random_colors[i], lw=4, label=new_categories[i]) for i in range(len(random_colors))]
                plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1, 1))
            plt.savefig(output_paths[i], bbox_inches='tight')
            plt.close()
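
The core of the forward pass is the fusion in compute_final_similarities: local (region/patch) similarities and the global CLIP image-text similarity are blended as ens * local + (1 - ens) * global, with global_local_ensemble = 0.8 in the configs above. A self-contained sketch with dummy tensors (shapes only, not real features):

import torch

num_categories, h, w = 3, 4, 4
ens = 0.8                                   # config['freeda']['global_local_ensemble']
local = torch.rand(num_categories, h, w)    # per-pixel similarities after sigmoid, in [0, 1]
global_sim = torch.rand(num_categories)     # CLIP image-text cosine similarities

# Broadcast the per-image global score over all pixels, as in compute_final_similarities.
fused = ens * local + (1 - ens) * global_sim.reshape(-1, 1, 1)

# Optional background class, as in add_backgrounds: pixels whose best score falls
# below the threshold are assigned to a prepended background channel.
background_threshold = 0.48
background = (fused.max(dim=0).values < background_threshold).float().unsqueeze(0)
fused_with_bg = torch.cat([background, fused], dim=0)

segmentation = fused_with_bg.argmax(dim=0)  # [h, w] integer mask, 0 = background
print(segmentation.shape)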

freeda/models/mask_proposer/superpixel.py ADDED
@@ -0,0 +1,66 @@
import cv2
import torch
import numpy as np
from torchvision.transforms.functional import pil_to_tensor
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.measure import regionprops
from skimage.segmentation import felzenszwalb, slic, quickshift, watershed

class SuperpixelMaskProposer:
    def __init__(self, config):
        self.config = config
        self.superpixel_method = config['algorithm']
        self.superpixel_conf = config
        self.superpixel_conf.pop('algorithm')
        if self.superpixel_method == 'seeds':
            self.num_iterations = self.superpixel_conf.pop("num_iterations")

    def __call__(self, image, device):
        if type(image) != torch.Tensor:
            image = pil_to_tensor(image)
        pred_masks = []
        n_pred_masks = []
        assigned_masks = []
        covered_pixels = torch.ones(image.shape[1], image.shape[2]).type(torch.bool).to(device)
        if self.superpixel_method == "seeds":
            image = image.permute(1, 2, 0).cpu().numpy()
            image = np.ascontiguousarray(image.astype(np.uint8))
        else:
            image = image.permute(1, 2, 0).cpu().numpy() / 255

        if self.superpixel_method == "felzenszwalb":
            superpixel_mask = felzenszwalb(image, **self.superpixel_conf)

        elif self.superpixel_method == "slic":
            superpixel_mask = slic(image, **self.superpixel_conf)

        elif self.superpixel_method == "quickshift":
            superpixel_mask = quickshift(image, **self.superpixel_conf)

        elif self.superpixel_method == "watershed":
            gradient = sobel(rgb2gray(image))
            superpixel_mask = watershed(gradient, **self.superpixel_conf)

        elif self.superpixel_method == "seeds":
            superpix_seeds = cv2.ximgproc.createSuperpixelSEEDS(**self.superpixel_conf)
            superpix_seeds.iterate(image, self.num_iterations)
            superpixel_mask = superpix_seeds.getLabels()
            num_superpixels = superpix_seeds.getNumberOfSuperpixels()
        else:
            raise NotImplementedError(f"Superpixel algorithm {self.superpixel_method} not implemented.")

        if self.superpixel_method == "seeds":
            superpixel_mask_binary = np.array([superpixel_mask == i for i in np.arange(num_superpixels)])
        else:
            superpixel_mask_binary = np.array([superpixel_mask == i for i in np.unique(superpixel_mask)])
        num_superpixel = superpixel_mask_binary.shape[0]

        pred_masks = torch.from_numpy(superpixel_mask_binary).type(torch.bool).to(device)
        n_pred_masks = num_superpixel
        assigned_masks = torch.from_numpy(superpixel_mask).type(torch.long).to(device)

        if self.superpixel_method == "watershed":
            assigned_masks = assigned_masks - 1

        return pred_masks, n_pred_masks, covered_pixels, assigned_masks
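
A minimal standalone sketch of what the proposer returns for the felzenszwalb settings used in the configs (random image, CPU device, assuming the package above is importable; the shapes are the point here):

import numpy as np
from PIL import Image
from freeda.models.mask_proposer.superpixel import SuperpixelMaskProposer

proposer = SuperpixelMaskProposer({"algorithm": "felzenszwalb", "scale": 100, "sigma": 1.0, "min_size": 100})
image = Image.fromarray(np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8))

pred_masks, n_pred_masks, covered_pixels, assigned_masks = proposer(image, device="cpu")
print(pred_masks.shape)      # [n_pred_masks, 240, 320], boolean superpixel masks
print(n_pred_masks)          # number of superpixels found
print(covered_pixels.shape)  # [240, 320], True where some mask covers the pixel
print(assigned_masks.shape)  # [240, 320], index of the superpixel owning each pixel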

freeda/models/vision_backbone.py ADDED
@@ -0,0 +1,54 @@
import torch
import timm
from math import sqrt
from torchvision.transforms import Compose, Resize
from torchvision.transforms.functional import pil_to_tensor

class VisionBackbone(torch.nn.Module):
    def __init__(self,
                 config: dict,
                 device: str = 'cuda',
                 max_batch_size: int = 16):
        super(VisionBackbone, self).__init__()
        self.backbone_name = config['model']
        self.device = device
        self.model = timm.create_model(
            config['model'],
            pretrained=True,
            img_size=config['img_size'],
        ).to(self.device).eval()

        data_config = timm.data.resolve_model_data_config(config['model'])
        self.transform = timm.data.create_transform(**data_config, is_training=False)

        self.transform = Compose([
            Resize((config['img_size'], config['img_size']), antialias=None),
            lambda x: x / 255,
            self.transform.transforms[-1]
        ])

        self.max_batch_size = max_batch_size

    @torch.no_grad()
    def forward(self, images):
        if type(images) == list:
            if len(images) == 0:
                raise ValueError("Images list is empty")
            images = [pil_to_tensor(image) for image in images]
            images = [self.transform(image.unsqueeze(0)).squeeze(0) for image in images]
            images = torch.stack(images, dim=0)
        else:
            if images.shape[1] != 3:
                images = images.permute(0, 3, 1, 2)
            images = self.transform(images)
        batch_size = images.shape[0]
        if batch_size < self.max_batch_size:
            features = self.model.forward_features(images.to(self.device))
        else:
            features = []
            for i in range(0, batch_size, self.max_batch_size):
                r = min(i + self.max_batch_size, batch_size)
                features.append(self.model.forward_features(images[i:r].to(self.device)))
            features = torch.cat(features, dim=0)
        num_tokens_side = int(sqrt(features.shape[1] - 1))
        return features[:, 1::, :].reshape(batch_size, num_tokens_side, num_tokens_side, features.shape[-1])
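
The forward pass drops the class token and folds the remaining tokens back into a spatial grid; with img_size 518 and patch size 14 that grid is 37x37. A self-contained shape check with a dummy tensor in place of the timm output (768 is the ViT-B width; ViT-L would be 1024):

import torch
from math import sqrt

# forward_features output for a 518x518 input with patch size 14: 1 class token + 37*37 patches.
features = torch.randn(2, 1 + 37 * 37, 768)        # dummy stand-in for the timm output
batch_size = features.shape[0]
num_tokens_side = int(sqrt(features.shape[1] - 1))  # 37
grid = features[:, 1:, :].reshape(batch_size, num_tokens_side, num_tokens_side, features.shape[-1])
print(grid.shape)                                   # torch.Size([2, 37, 37, 768])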

freeda/utils/factory.py ADDED
@@ -0,0 +1,111 @@
import os
import yaml
import gdown
import subprocess
import tarfile
import zipfile
from freeda.models.freeda_model import FreeDA

def load(model_name: str,
         lazy_init: bool = False,
         collection_in_gpu: bool = False,
         use_cache: bool = True,
         force_cache_download: bool = False,
         collection_path: str = None,
         index_path: str = None,
         use_cached_embeddings: bool = True,
         cache_embeddings: bool = True,
         embeddings_cache_path: str = None,
         verbose: bool = True,
         device: str = 'cuda',
         custom_configs_path: str = None):
    """
    Load the model and its configuration.
    Args:
        model_name (str): The name of the model to load.
        lazy_init (bool): Whether to lazily load the collection.
        collection_in_gpu (bool): Whether to load the collection in GPU.
        use_cache (bool): Whether to use the cache. If False, required data will be downloaded.
        force_cache_download (bool): Whether to force downloading and storing in cache.
        collection_path (str): The path to the collection.
        index_path (str): The path to the index.
        use_cached_embeddings (bool): Whether to use the cached embeddings for the required model.
        cache_embeddings (bool): Whether to cache the embeddings.
        embeddings_cache_path (str): The path to the embeddings cache.
        verbose (bool): Whether to print the progress.
        device (str): The device to use.
        custom_configs_path (str): The path to a directory with custom configurations. If None, uses the default configs.
    Return:
        FreeDA: The FreeDA model.
    """
    if custom_configs_path is None:
        current_path = os.path.dirname(os.path.abspath(__file__))
        configs_path = os.path.join(current_path, '../configs')
    else:
        configs_path = custom_configs_path
    if f"{model_name}.yaml" not in os.listdir(configs_path):
        raise ValueError(f"Model {model_name} not available")
    with open(f"{configs_path}/{model_name}.yaml", 'r') as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    cache_root = os.path.expanduser(f"~/.cache/freeda/{model_name}/")

    new_collection_path = f"{cache_root}prototype_embeddings/" if collection_path is None else collection_path
    os.makedirs(new_collection_path, exist_ok=True)

    new_index_path = f"{cache_root}index/" if index_path is None else index_path
    os.makedirs(new_index_path, exist_ok=True)

    if use_cache:
        use_cached_collection = True if len(os.listdir(new_collection_path)) == config['data']['collection_length'] and not force_cache_download else False
        use_cached_index = True if len(os.listdir(new_index_path)) != 0 and not force_cache_download else False
    else:
        use_cached_collection = False
        use_cached_index = False

    if not use_cached_collection:
        if verbose:
            print("Downloading collection...")
        output_collection_tar = f"{cache_root}prototype_embeddings.tar"
        gdown.download(config['data']['collection_url'], output_collection_tar, quiet=not verbose)
        if verbose:
            print("Extracting compressed collection... (it may take a while)")
        if config['data']['compression'] == 'zip':
            with tarfile.open(output_collection_tar, 'r:gz') as tar:
                tar.extractall(new_collection_path)
        elif config['data']['compression'] == 'tar':
            with tarfile.open(output_collection_tar, 'r') as tar:
                tar.extractall(new_collection_path)
        else:
            raise ValueError("Invalid compression format")
    else:
        if verbose:
            print("Using cached collection...")

    if not use_cached_index:
        if verbose:
            print("Downloading index...")
        output_index_zip = f"{cache_root}faiss_index.zip"
        gdown.download(config['data']['index_url'], output_index_zip, quiet=not verbose)
        with zipfile.ZipFile(output_index_zip, 'r') as zip_ref:
            zip_ref.extractall(new_index_path)
    else:
        if verbose:
            print("Using cached index...")

    if embeddings_cache_path is None and use_cached_embeddings:
        embeddings_cache_path = f"{cache_root}embeddings/"
    elif embeddings_cache_path is not None and use_cached_embeddings:
        embeddings_cache_path = os.path.expanduser(embeddings_cache_path)

    return FreeDA(config,
                  lazy_init=lazy_init,
                  collection_in_gpu=collection_in_gpu,
                  collection_path=new_collection_path,
                  index_path=new_index_path,
                  device=device,
                  use_cached_embeddings=use_cached_embeddings,
                  cache_embeddings=cache_embeddings,
                  embeddings_cache_path=embeddings_cache_path,
                  verbose=verbose)
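
Everything load downloads lands under ~/.cache/freeda/<model_name>/ unless collection_path / index_path are overridden. A usage sketch built only from the signature above (the paths are illustrative placeholders):

import freeda

# Lazy init: the faiss index and CLIP weights are only read once categories are first set.
model = freeda.load(
    "dinov2_vitl_clip_vitl_approx",
    lazy_init=True,
    collection_path="/data/freeda/prototype_embeddings/",  # illustrative custom locations
    index_path="/data/freeda/index/",
    embeddings_cache_path="~/.cache/freeda/my_embeddings/",
    verbose=True,
)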

main.py ADDED
@@ -0,0 +1,15 @@
import freeda
from PIL import Image
import requests
from io import BytesIO

if __name__ == "__main__":
    fr = freeda.load("dinov2_vitb_clip_vitb")
    response1 = requests.get("https://farm9.staticflickr.com/8306/7926031760_b313dca06a_z.jpg")
    img1 = Image.open(BytesIO(response1.content))
    response2 = requests.get("https://farm3.staticflickr.com/2207/2157810040_4883738d2d_z.jpg")
    img2 = Image.open(BytesIO(response2.content))
    fr.set_categories(["cat", "table", "pen", "keyboard", "toilet", "wall"])
    fr.set_images([img1, img2])
    segmentation = fr()
    fr.visualize(segmentation, ["plot.png", "plot1.png"])

requirements.txt ADDED
@@ -0,0 +1,9 @@
faiss==1.8.0
gdown==5.2.0
matplotlib==3.9.2
opencv_python==4.10.0.84
Pillow==10.4.0
PyYAML==6.0.2
Requests==2.32.3
scikit-image
timm==1.0.9