import os

import torch
from PIL import Image
from torch.nn import functional as F  # kept: used by the cosine-similarity variant noted in get_distances

from clip_transform import CLIPTransform


class Prototypes:
    """Binary person / no-person classifier over CLIP image embeddings.

    Each class prototype is the L2-normalized mean of the CLIP embeddings of
    example images found in the ``prototypes/`` folder, where filenames are
    prefixed ``person-`` or ``no_person-``.  The stacked prototype matrix is
    cached on disk in ``prototypes.pt`` and reused on subsequent runs.
    """

    # On-disk cache file and the folder of labelled example images.
    _PROTOTYPES_PATH = 'prototypes.pt'
    _EXAMPLES_FOLDER = 'prototypes'
    # Fixed class order: row i of self.prototypes corresponds to key i.
    _PROTOTYPE_KEYS = ["person", "no_person"]

    def __init__(self):
        self._clip_transform = CLIPTransform()
        self._load_prototypes()

    def _mean_prototype(self, image_embeddings, prefix):
        """Return the L2-normalized mean embedding of all entries of
        *image_embeddings* whose filename starts with *prefix*.

        Keys are sorted for a deterministic (order-independent) result.

        Raises:
            ValueError: if no filename matches *prefix* (a bare ``torch.cat``
                on an empty list would otherwise fail with an opaque error).
        """
        keys = sorted(k for k in image_embeddings if k.startswith(prefix))
        if not keys:
            raise ValueError(f"no example images with prefix {prefix!r} found")
        # assumes each embedding is shaped (1, D) so cat -> (N, D) -- TODO confirm
        stacked = torch.cat([image_embeddings[k] for k in keys])
        mean = stacked.mean(dim=0)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _prepare_prototypes(self):
        """Build both class prototypes from example images and cache to disk.

        Sets ``self.prototype_keys`` and ``self.prototypes`` and writes the
        prototype matrix to ``prototypes.pt``.
        """
        image_embeddings = self.load_images_from_folder(self._EXAMPLES_FOLDER)
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not image_embeddings:
            raise ValueError("no image embeddings found")
        # Prefixes do not overlap: 'no_person-...' never startswith 'person-'.
        person_embedding = self._mean_prototype(image_embeddings, 'person-')
        no_person_embedding = self._mean_prototype(image_embeddings, 'no_person-')
        self.prototype_keys = list(self._PROTOTYPE_KEYS)
        self.prototypes = torch.stack([person_embedding, no_person_embedding])
        torch.save(self.prototypes, self._PROTOTYPES_PATH)

    def _load_prototypes(self):
        """Load cached prototypes, building them first if the cache is missing."""
        if not os.path.exists(self._PROTOTYPES_PATH):
            # _prepare_prototypes already sets self.prototypes / prototype_keys
            # in memory, so there is no need to reload the file we just wrote.
            self._prepare_prototypes()
            return
        # NOTE(review): torch.load unpickles arbitrary objects and can execute
        # code; pass weights_only=True (torch >= 1.13) if this file may ever
        # come from an untrusted source.
        self.prototypes = torch.load(self._PROTOTYPES_PATH)
        self.prototype_keys = list(self._PROTOTYPE_KEYS)

    def load_images_from_folder(self, folder):
        """Return ``{filename: CLIP embedding}`` for every supported image in *folder*.

        Non-image files are skipped.  Images are opened with a context manager
        so file handles are released promptly (the original leaked them).
        """
        supported_filetypes = ('.jpg', '.png', '.jpeg')
        image_embeddings = {}
        for filename in os.listdir(folder):
            # str.endswith accepts a tuple -- one call instead of an any() loop.
            if not filename.endswith(supported_filetypes):
                continue
            with Image.open(os.path.join(folder, filename)) as image:
                image_embeddings[filename] = \
                    self._clip_transform.pil_image_to_embeddings(image)
        return image_embeddings

    def get_distances(self, embeddings):
        """Score *embeddings* against the class prototypes.

        Assumes *embeddings* is already L2-normalized, so the dot product with
        the (normalized) prototypes equals cosine similarity.

        Returns:
            tuple: (distances tensor, key of the closest prototype,
            human-readable debug string of per-key scores).
        """
        # Not-normalized variant, kept for reference:
        # distances = F.cosine_similarity(embeddings, self.prototypes)
        # case normalized:
        distances = embeddings @ self.prototypes.T
        closest_item_idx = distances.argmax().item()
        closest_item_key = self.prototype_keys[closest_item_idx]
        # join() instead of repeated += string concatenation.
        debug_str = "".join(
            f"{key}: {value.item():.2f}, "
            for key, value in zip(self.prototype_keys, distances)
        )
        return distances, closest_item_key, debug_str


if __name__ == "__main__":
    prototypes = Prototypes()
    print("prototypes:")
    for key, value in zip(prototypes.prototype_keys, prototypes.prototypes):
        print(f"{key}: {len(value)}")
    # Sanity check: the "person" prototype should be closest to itself.
    embeddings = prototypes.prototypes[0]
    distances, closest_item_key, debug_str = prototypes.get_distances(embeddings)
    print(f"closest_item_key: {closest_item_key}")
    print(f"distances: {debug_str}")
    print("done")