| |
| import torch |
| import torchvision |
| import open_clip |
|
|
|
|
class OpenCLIPNetwork:
    """OpenCLIP (ViT-B-16, laion2b_s34b_b88k, fp16) wrapper for text/image scoring.

    Besides the model, tokenizer and an image preprocessing pipeline, an
    instance keeps L2-normalised text embeddings for:

    * ``positives``       -> ``pos_embeds``:      the current query phrases,
    * ``negatives``       -> ``neg_embeds``:      fixed generic phrases used as
      the comparison set by the scoring methods,
    * ``semantic_labels`` -> ``semantic_embeds``: set by ``set_semantics``.
    """

    def __init__(self, device):
        """Load the fp16 model on ``device`` and embed the default phrases.

        Args:
            device: torch device (or device string) for the model and all
                cached text embeddings.
        """
        # Resize to 224x224 and normalise with the mean/std constants below.
        # There is no ToTensor step, so the input is presumably already a float
        # image tensor -- NOTE(review): confirm with callers.
        self.process = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((224, 224)),
                torchvision.transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ]
        )
        self.clip_model_type = "ViT-B-16"
        self.clip_model_pretrained = 'laion2b_s34b_b88k'
        self.clip_n_dims = 512
        model, _, _ = open_clip.create_model_and_transforms(
            self.clip_model_type,
            pretrained=self.clip_model_pretrained,
            precision="fp16",
        )
        model.eval()

        self.tokenizer = open_clip.get_tokenizer(self.clip_model_type)
        self.model = model.to(device)

        self.negatives = ("object", "things", "stuff", "texture")
        self.positives = (" ",)
        self.pos_embeds = self._embed_phrases(self.positives, device)
        self.neg_embeds = self._embed_phrases(self.negatives, device)

    def _embed_phrases(self, phrases, device):
        """Tokenize and encode ``phrases`` without autograd; return L2-normalised rows.

        Rows follow the order of ``phrases``. This assumes the tokenizer returns
        one row per string.
        """
        with torch.no_grad():
            tokens = torch.cat([self.tokenizer(phrase) for phrase in phrases]).to(device)
            embeds = self.model.encode_text(tokens)
            embeds /= embeds.norm(dim=-1, keepdim=True)
        return embeds

    def _relevancy_from_logits(self, output, positive_id):
        """Pairwise positive-vs-negative scoring on precomputed similarities.

        Args:
            output: ``(N, P + K)`` similarities to the ``P`` rows of
                ``pos_embeds`` followed by the ``K`` rows of ``neg_embeds``.
            positive_id: which positive column to score.

        Returns:
            ``(N, 2)``. For each row this is the 2-way softmax (logits scaled by
            10) of (positive, negative) for the negative against which the
            positive scores lowest. Column 0 is the positive probability.
        """
        n_pos = self.pos_embeds.shape[0]
        positive_vals = output[..., positive_id : positive_id + 1]  # (N, 1)
        negative_vals = output[..., n_pos:]  # (N, K)
        repeated_pos = positive_vals.repeat(1, negative_vals.shape[-1])  # (N, K)

        sims = torch.stack((repeated_pos, negative_vals), dim=-1)  # (N, K, 2)
        softmax = torch.softmax(10 * sims, dim=-1)
        # Pick, per row, the negative that gives the positive its lowest probability.
        best_id = softmax[..., 0].argmin(dim=1)  # (N,)
        index = best_id[:, None, None].expand(best_id.shape[0], 1, 2)
        return torch.gather(softmax, 1, index)[:, 0, :]

    @torch.no_grad()
    def get_relevancy(self, embed: torch.Tensor, positive_id: int) -> torch.Tensor:
        """Score each row of ``embed`` for one positive phrase against the negatives.

        Args:
            embed: 2-D ``(N, D)`` image embeddings. ``torch.mm`` requires 2-D.
            positive_id: index of the positive phrase (row of ``pos_embeds``).

        Returns:
            ``(N, 2)`` tensor. Column 0 is the probability of the positive
            phrase and column 1 that of the lowest-scoring ("hardest") negative.
        """
        phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(embed.dtype)
        output = torch.mm(embed, p.T)
        return self._relevancy_from_logits(output, positive_id)

    def encode_image(self, input, mask=None):
        """Preprocess ``input``, cast to fp16 and return the model's image embedding.

        ``mask`` is forwarded to ``model.encode_image``.
        NOTE(review): stock open_clip's ``encode_image`` has no ``mask`` keyword.
        Confirm the installed open_clip build accepts it.
        """
        processed_input = self.process(input).half()
        return self.model.encode_image(processed_input, mask=mask)

    def encode_text(self, text_list, device):
        """Tokenize ``text_list``, move tokens to ``device`` and return text embeddings.

        No normalisation is applied here, and autograd is not disabled.
        """
        text = self.tokenizer(text_list).to(device)
        return self.model.encode_text(text)

    def set_positives(self, text_list):
        """Replace the positive query phrases and recompute ``pos_embeds``."""
        self.positives = text_list
        self.pos_embeds = self._embed_phrases(self.positives, self.neg_embeds.device)

    def set_semantics(self, text_list):
        """Set the semantic label phrases and compute ``semantic_embeds``."""
        self.semantic_labels = text_list
        # Use the same device as the other cached text embeddings instead of
        # assuming CUDA, so CPU / non-default-GPU setups work.
        self.semantic_embeds = self._embed_phrases(
            self.semantic_labels, self.neg_embeds.device
        )

    def get_semantic_map(self, sem_map: torch.Tensor) -> torch.Tensor:
        """Per-pixel argmax over the semantic labels and the negative phrases.

        Args:
            sem_map: ``(n_levels, H, W, D)`` per-pixel embeddings. It may be
                non-contiguous.

        Returns:
            ``(n_levels, H, W)`` int64 tensor on the default (CPU) device. Each
            entry is the index of the best-matching semantic label, or ``-1``
            where a negative phrase scores highest.
        """
        n_levels, h, w, c = sem_map.shape
        pos_num = self.semantic_embeds.shape[0]
        phrases_embeds = torch.cat([self.semantic_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(sem_map.dtype)
        sem_pred = torch.zeros(n_levels, h, w, dtype=torch.long)
        with torch.no_grad():
            for i in range(n_levels):
                # reshape rather than view: sem_map[i] need not be contiguous.
                output = torch.mm(sem_map[i].reshape(-1, c), p.T)
                softmax = torch.softmax(10 * output, dim=-1)
                labels = torch.argmax(softmax, dim=-1).view(h, w)
                # Columns past the semantic labels are negatives -> "no label".
                labels[labels >= pos_num] = -1
                sem_pred[i] = labels
        return sem_pred

    def get_max_across(self, sem_map):
        """Positive-phrase probability for every phrase, level and pixel.

        Args:
            sem_map: ``(n_levels, H, W, D)`` per-pixel embeddings.

        Returns:
            ``(n_levels, n_phrases, H, W)`` tensor. Each entry is column 0 of
            ``get_relevancy`` for that level, phrase and pixel.
        """
        n_phrases = len(self.positives)
        n_levels, h, w, _ = sem_map.shape
        # (H*W, n_levels, D): one row per pixel, row-major over (H, W).
        clip_output = sem_map.permute(1, 2, 0, 3).flatten(0, 1)
        phrases_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0)
        p = phrases_embeds.to(sem_map.dtype)

        n_levels_sims = []
        with torch.no_grad():
            for i in range(n_levels):
                # The similarity matrix does not depend on the phrase, so it is
                # computed once per level instead of once per (level, phrase).
                output = torch.mm(clip_output[:, i, :], p.T)
                n_phrases_sims = [
                    self._relevancy_from_logits(output, j)[..., 0:1]
                    for j in range(n_phrases)
                ]
                n_levels_sims.append(torch.stack(n_phrases_sims))

        return torch.stack(n_levels_sims).view(n_levels, n_phrases, h, w)