Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from clip_transform import CLIPTransform | |
from PIL import Image | |
from torch.nn import functional as F | |
class Prototypes:
    """Nearest-prototype classifier over image embeddings.

    Two prototype vectors ("person" and "no_person") are built by averaging
    the embeddings of the example images found in the ``prototypes`` folder.
    The stacked result is cached in ``prototypes.pt`` so that later runs can
    skip the image pass.
    """

    # Folder scanned for example images when the cache has to be built.
    _PROTOTYPES_FOLDER = 'prototypes'
    # Cache file holding the stacked prototype tensor.
    _PROTOTYPES_FILE = 'prototypes.pt'
    # Image extensions accepted by load_images_from_folder (case-insensitive).
    _SUPPORTED_FILETYPES = ('.jpg', '.png', '.jpeg')

    def __init__(self):
        self._clip_transform = CLIPTransform()
        self._load_prototypes()

    def _prepare_prototypes(self):
        """Build the prototypes from the example images and cache them.

        Files named ``person-*`` feed the "person" prototype and files named
        ``no_person-*`` feed the "no_person" prototype. Each prototype is the
        mean of its embeddings, rescaled to unit length.

        Raises:
            AssertionError: if the folder contains no supported image.
            ValueError: if one of the two categories has no example image.
        """
        image_embeddings = self.load_images_from_folder(self._PROTOTYPES_FOLDER)
        assert image_embeddings is not None, "no image embeddings found"
        assert len(image_embeddings) > 0, "no image embeddings found"

        prototype_keys = ["person", "no_person"]
        prototype_vectors = []
        for category in prototype_keys:
            prefix = f"{category}-"
            # Sorted so the concatenation order does not depend on listdir.
            keys = sorted(key for key in image_embeddings if key.startswith(prefix))
            if not keys:
                # torch.cat([]) would otherwise fail with an unrelated message.
                raise ValueError(
                    f"no '{prefix}*' images found in "
                    f"'{self._PROTOTYPES_FOLDER}'"
                )
            # assumes each embedding is (1, dim), so cat gives (n, dim) -- TODO confirm
            category_embeddings = torch.cat([image_embeddings[key] for key in keys])
            mean_embedding = category_embeddings.mean(dim=0)
            prototype_vectors.append(
                mean_embedding / mean_embedding.norm(dim=-1, keepdim=True)
            )

        self.prototype_keys = prototype_keys
        self.prototypes = torch.stack(prototype_vectors)
        # save prototypes to file
        torch.save(self.prototypes, self._PROTOTYPES_FILE)

    def _load_prototypes(self):
        """Load the cached prototypes, building the cache first if needed."""
        if not os.path.exists(self._PROTOTYPES_FILE):
            self._prepare_prototypes()
        # NOTE: torch.load unpickles the file; only point it at a trusted cache.
        self.prototypes = torch.load(self._PROTOTYPES_FILE)
        self.prototype_keys = ["person", "no_person"]

    def load_images_from_folder(self, folder):
        """Embed every supported image in ``folder``.

        Args:
            folder: directory to scan (not recursive).

        Returns:
            dict mapping each image file name to the value returned by
            ``pil_image_to_embeddings`` for that image.
        """
        image_embeddings = {}
        for filename in os.listdir(folder):
            # Lower-cased so files such as "photo.JPG" are not skipped.
            if not filename.lower().endswith(self._SUPPORTED_FILETYPES):
                continue
            # The context manager closes the file handle once it is embedded.
            with Image.open(os.path.join(folder, filename)) as image:
                embeddings = self._clip_transform.pil_image_to_embeddings(image)
            image_embeddings[filename] = embeddings
        return image_embeddings

    def get_distances(self, embeddings):
        """Score one embedding against every prototype.

        Despite the name, the scores are dot-product similarities: the
        prototype with the highest score is reported as the closest one.
        This equals cosine similarity only if ``embeddings`` is unit length
        (the prototypes are normalized when they are built).

        Args:
            embeddings: a single embedding, shaped ``(dim,)`` or ``(1, dim)``.

        Returns:
            Tuple ``(distances, closest_item_key, debug_str)`` where
            ``distances`` is ``embeddings @ prototypes.T``, ``closest_item_key``
            is the key of the best-scoring prototype and ``debug_str`` lists
            every ``key: score`` pair.

        Raises:
            ValueError: if ``embeddings`` holds more than one embedding.
        """
        distances = embeddings @ self.prototypes.T
        if distances.dim() > 1 and distances.numel() != distances.shape[-1]:
            raise ValueError(
                "get_distances expects a single embedding, got scores of "
                f"shape {tuple(distances.shape)}"
            )
        # Flatten so a (1, dim) input yields one scalar score per prototype;
        # iterating the 2-D result would yield a whole row instead.
        scores = distances.reshape(-1)
        closest_item_index = scores.argmax().item()
        closest_item_key = self.prototype_keys[closest_item_index]
        debug_str = "".join(
            f"{key}: {value.item():.2f}, "
            for key, value in zip(self.prototype_keys, scores)
        )
        return distances, closest_item_key, debug_str
if __name__ == "__main__":
    # Smoke test: build (or load) the prototypes, report the length of each
    # one, then score the first prototype against the whole set.
    demo = Prototypes()
    print("prototypes:")
    for name, vector in zip(demo.prototype_keys, demo.prototypes):
        print(f"{name}: {len(vector)}")

    query = demo.prototypes[0]
    _, best_key, summary = demo.get_distances(query)
    print(f"closest_item_key: {best_key}")
    print(f"distances: {summary}")
    print("done")