Spaces:
Configuration error
Configuration error
| """ | |
| Contains all the interaction types used by the GUI. | |
| These are not related to the automatic evaluation on the DAVIS dataset. | |
| You can inherit from the Interaction class to create new interaction types. | |
| Undo is supported, though sometimes only partially. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| import time | |
| from .interactive_utils import color_map, index_numpy_to_one_hot_torch | |
def aggregate_sbg(prob, keep_bg=False, hard=False):
    """Normalise object probabilities against a constant 0.5 background.

    A channel filled with 0.5 is placed in front of ``prob``; every channel
    is clamped to [1e-7, 1 - 1e-7], converted to logits and passed through a
    softmax over the channel axis.

    Args:
        prob: tensor whose shape unpacks as (k, h, w).
        keep_bg: when true the background channel (index 0) is kept in the
            result, giving k + 1 channels; otherwise only the k object
            channels are returned.
        hard: when true the logits are multiplied by 1000 before the softmax,
            which pushes the result towards a one-hot assignment.

    Returns:
        The softmax output, with or without the background channel.
    """
    num_objects, height, width = prob.shape

    # The working buffer uses torch's default dtype on the device of `prob`.
    stacked = torch.full((num_objects + 1, height, width), 0.5, device=prob.device)
    stacked[1:] = prob
    stacked = stacked.clamp(1e-7, 1 - 1e-7)

    logits = torch.log(stacked / (1 - stacked))
    if hard:
        # Very low temperature: sharpens the softmax.
        logits = logits * 1000

    merged = F.softmax(logits, dim=0)
    return merged if keep_bg else merged[1:]
def aggregate_wbg(prob, keep_bg=False, hard=False):
    """Normalise object probabilities against a derived background channel.

    The background probability is the product over objects of ``1 - prob``.
    It is placed in front of ``prob``; every channel is clamped to
    [1e-7, 1 - 1e-7], converted to logits and passed through a softmax over
    the channel axis.

    Args:
        prob: tensor whose shape unpacks as (k, h, w).
        keep_bg: when true the background channel (index 0) is kept in the
            result; otherwise only the object channels are returned.
        hard: when true the logits are multiplied by 1000 before the softmax,
            which pushes the result towards a one-hot assignment.

    Returns:
        The softmax output, with or without the background channel.
    """
    # The values are unused; the unpacking rejects inputs that are not 3-D.
    _num_objects, _height, _width = prob.shape

    background = torch.prod(1 - prob, dim=0, keepdim=True)
    stacked = torch.cat((background, prob), dim=0).clamp(1e-7, 1 - 1e-7)

    logits = torch.log(stacked / (1 - stacked))
    if hard:
        # Very low temperature: sharpens the softmax.
        logits = logits * 1000

    merged = F.softmax(logits, dim=0)
    if not keep_bg:
        merged = merged[1:]
    return merged
class Interaction:
    """Base class shared by every GUI interaction type.

    It stores the inputs, records the creation time and provides empty
    output slots. Subclasses are expected to override ``predict``.
    """

    def __init__(self, image, prev_mask, true_size, controller):
        """Store the inputs of the interaction.

        Args:
            image: the image the interaction refers to.
            prev_mask: the mask that existed before this interaction.
            true_size: a pair unpacked into ``self.h`` and ``self.w``.
            controller: object used by subclasses to run a prediction; may
                be None for interactions that do not need one.
        """
        # Wall-clock time at which the interaction was created.
        self.start_time = time.time()

        self.image = image
        self.prev_mask = prev_mask
        self.controller = controller
        self.h, self.w = true_size

        # Output slots, empty until a subclass fills them.
        self.out_prob = None
        self.out_mask = None

    def predict(self):
        """Do nothing and return None; subclasses supply the real logic."""
        return None
class FreeInteraction(Interaction):
    """Free-hand drawing applied to a copy of the previous mask.

    No controller is used (the base class receives None). Each stroke
    segment is drawn into ``drawn_map`` with the object id as its value.
    """

    def __init__(self, image, prev_mask, true_size, num_objects):
        """Prepare the drawing canvas and the per-label stroke buffers.

        prev_mask should be an index-format numpy array; drawing is done on
        a copy of it.
        """
        super().__init__(image, prev_mask, true_size, None)
        self.K = num_objects
        self.drawn_map = self.prev_mask.copy()
        # One list of points per label: K objects plus one extra slot
        # (index 0 is presumably the background -- confirm with the caller).
        self.curr_path = [[] for _ in range(self.K + 1)]
        # Line thickness; stays None until set_size() is called.
        self.size = None

    def set_size(self, size):
        """Set the line thickness used for subsequent strokes."""
        self.size = size

    def push_point(self, x, y, k, vis=None):
        """Add a point to the stroke of label ``k`` and draw the new segment.

        k - object id
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        Nothing is drawn while the stroke holds a single point. Returns the
        (vis_map, vis_alpha) pair when ``vis`` was given, otherwise None.
        """
        with_vis = vis is not None
        if with_vis:
            vis_map, vis_alpha = vis

        stroke = self.curr_path[k]
        stroke.append((x, y))

        if len(stroke) > 1:
            # Endpoints of the newest segment, rounded to integer pixels.
            (x0, y0), (x1, y1) = stroke[-2], stroke[-1]
            start = (int(round(x0)), int(round(y0)))
            end = (int(round(x1)), int(round(y1)))

            cv2.line(self.drawn_map, start, end, k, thickness=self.size)

            if with_vis:
                # Coloured stroke on the visualization map.
                vis_map = cv2.line(vis_map, start, end,
                                   color_map[k], thickness=self.size)
                # Visualization on/off filter.
                vis_alpha = cv2.line(vis_alpha, start, end,
                                     0.75, thickness=self.size)

        return (vis_map, vis_alpha) if with_vis else None

    def end_path(self):
        """Finish the current drawing by clearing every stroke buffer."""
        self.curr_path = [[] for _ in range(self.K + 1)]

    def predict(self):
        """Encode the drawn map, move it to CUDA, store and return it.

        The encoding is done by index_numpy_to_one_hot_torch with K + 1
        classes (one-hot, judging by the helper's name -- confirm).
        """
        encoded = index_numpy_to_one_hot_torch(self.drawn_map, self.K + 1)
        self.out_prob = encoded.cuda()
        return self.out_prob
class ScribbleInteraction(Interaction):
    """Scribble strokes that are handed to the controller for prediction.

    Strokes are drawn into a fresh uint8 map of size (h, w) with the object
    id as the drawn value.
    """

    def __init__(self, image, prev_mask, true_size, controller, num_objects):
        """Create the scribble canvas and the per-label stroke buffers.

        prev_mask should be in an indexed form.
        """
        super().__init__(image, prev_mask, true_size, controller)
        self.K = num_objects
        # Canvas starts at 255 everywhere (presumably "no scribble" -- confirm
        # against the controller).
        self.drawn_map = np.full((self.h, self.w), 255, dtype=np.uint8)
        # background + k
        self.curr_path = [[] for _ in range(self.K + 1)]
        # Fixed line thickness for scribbles.
        self.size = 3

    def push_point(self, x, y, k, vis=None):
        """Add a point to the stroke of label ``k`` and draw the new segment.

        k - object id
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        Nothing is drawn while the stroke holds a single point. Returns the
        (vis_map, vis_alpha) pair when ``vis`` was given, otherwise None.
        """
        with_vis = vis is not None
        if with_vis:
            vis_map, vis_alpha = vis

        stroke = self.curr_path[k]
        stroke.append((x, y))

        if len(stroke) > 1:
            # Endpoints of the newest segment, rounded to integer pixels.
            (x0, y0), (x1, y1) = stroke[-2], stroke[-1]
            start = (int(round(x0)), int(round(y0)))
            end = (int(round(x1)), int(round(y1)))

            self.drawn_map = cv2.line(self.drawn_map, start, end,
                                      k, thickness=self.size)

            if with_vis:
                # Coloured stroke on the visualization map.
                vis_map = cv2.line(vis_map, start, end,
                                   color_map[k], thickness=self.size)
                # Visualization on/off filter.
                vis_alpha = cv2.line(vis_alpha, start, end,
                                     0.75, thickness=self.size)

        # Optional vis return
        return (vis_map, vis_alpha) if with_vis else None

    def end_path(self):
        """Finish the current drawing by clearing every stroke buffer."""
        self.curr_path = [[] for _ in range(self.K + 1)]

    def predict(self):
        """Run the controller on the scribbles and return the merged result.

        The controller receives the image with an added leading dimension,
        the previous mask and the drawn map. Its output is merged by
        aggregate_wbg with keep_bg=True and hard=True, then stored in
        ``out_prob``.
        """
        batched_image = self.image.unsqueeze(0)
        self.out_prob = self.controller.interact(
            batched_image, self.prev_mask, self.drawn_map)
        self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
        return self.out_prob
class ClickInteraction(Interaction):
    """Positive/negative clicks that refine the mask of one target object."""

    def __init__(self, image, prev_mask, true_size, controller, tar_obj):
        """Store the target object and start from a clone of the previous mask.

        prev_mask in a prob. form
        """
        super().__init__(image, prev_mask, true_size, controller)
        self.tar_obj = tar_obj
        # Click history, kept separately for positive and negative clicks.
        self.pos_clicks = []
        self.neg_clicks = []
        self.out_prob = self.prev_mask.clone()

    def push_point(self, x, y, neg, vis=None):
        """Record a click, query the controller and optionally draw the click.

        neg - Negative interaction or not
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        The controller result is stored in ``obj_mask``. Returns the
        (vis_map, vis_alpha) pair when ``vis`` was given, otherwise None.
        """
        # Clicks
        clicks = self.neg_clicks if neg else self.pos_clicks
        clicks.append((x, y))

        # Do the prediction
        self.obj_mask = self.controller.interact(
            self.image.unsqueeze(0), x, y, not neg)

        if vis is None:
            return None

        # Plot visualization: a filled circle of radius 2 at the click.
        vis_map, vis_alpha = vis
        centre = (int(round(x)), int(round(y)))
        # Negative clicks use colour 0, positive ones the target's colour.
        colour = color_map[0] if neg else color_map[self.tar_obj]
        vis_map = cv2.circle(vis_map, centre, 2, colour, thickness=-1)
        vis_alpha = cv2.circle(vis_alpha, centre, 2, 1, thickness=-1)
        return vis_map, vis_alpha

    def predict(self):
        """Merge the clicked object's mask into the previous probabilities.

        Requires a prior push_point() call, since that is what sets
        ``obj_mask``.
        """
        # a small hack to allow the interacting object to overwrite existing masks
        # without remembering all the object probabilities
        self.out_prob = self.prev_mask.clone().clamp(max=0.9)
        self.out_prob[self.tar_obj] = self.obj_mask
        # Channel 0 is dropped here; aggregate_wbg rebuilds a background channel.
        self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
        return self.out_prob