Spaces:
Running
on
Zero
Running
on
Zero
from typing import NamedTuple, List, Callable | |
import sys | |
import re | |
import numpy as np | |
import torch | |
from numpy.linalg import norm | |
from itertools import product, groupby | |
from PIL import Image | |
# Do two line segments intersect? Copied from | |
# https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect | |
def ccw(A, B, C):
    """Orientation test for the ordered point triple A, B, C.

    Returns True when the cross product of AB and AC is strictly positive
    (a counter-clockwise turn if the y axis points up). Collinear points
    give False. Each argument only needs ``x`` and ``y`` attributes.
    """
    ab_dx, ab_dy = B.x - A.x, B.y - A.y
    ac_dx, ac_dy = C.x - A.x, C.y - A.y
    return ac_dy * ab_dx > ab_dy * ac_dx
def intersect(A, B, C, D):
    """Do line segments AB and CD intersect?

    They do when A and B fall on opposite sides of CD and, likewise, C and D
    fall on opposite sides of AB; both checks are delegated to ``ccw``.
    """
    # First check: A and B must be split by the line through C and D.
    if ccw(A, C, D) == ccw(B, C, D):
        return False
    # Second check: C and D must be split by the line through A and B.
    return ccw(A, B, C) != ccw(A, B, D)
class Box(NamedTuple):
    """Axis-aligned rectangle stored as a top-left corner plus width and height.

    ``w`` and ``h`` default to 0, so ``Box(x, y)`` doubles as a plain point;
    ``corners`` and ``center`` report coordinates that way.

    ``left``, ``right``, ``top``, ``bottom`` and ``area`` are properties: the
    rest of this module reads them as plain attributes (``box.area``,
    ``b1.right + b1.left``), which fails with ``TypeError`` when they are
    ordinary methods. ``center`` and ``corners`` remain regular methods.
    """

    x: int
    y: int
    w: int = 0
    h: int = 0

    @property
    def left(self):
        """x coordinate of the left edge."""
        return self.x

    @property
    def right(self):
        """x coordinate of the right edge (``x + w``)."""
        return self.x + self.w

    @property
    def top(self):
        """y coordinate of the top edge."""
        return self.y

    @property
    def bottom(self):
        """y coordinate of the bottom edge (``y + h``)."""
        return self.y + self.h

    def center(self):
        """Return the midpoint as a zero-size ``Box`` (floor division is used)."""
        return Box(self.x + self.w // 2, self.y + self.h // 2)

    def corners(self):
        """Yield the four corners as zero-size boxes.

        Order: (x, y), (x + w, y), (x + w, y + h), (x, y + h).
        """
        yield Box(self.x, self.y)
        yield Box(self.x + self.w, self.y)
        yield Box(self.x + self.w, self.y + self.h)
        yield Box(self.x, self.y + self.h)

    @property
    def area(self):
        """Width times height."""
        return self.w * self.h

    def intersect(self, other: "Box") -> "Box":
        """Return the overlap of this box and ``other``.

        When the boxes do not overlap, the result has zero width and/or
        height (never negative), so its area is 0.
        """
        x1 = max(self.x, other.x)
        y1 = max(self.y, other.y)
        # Clamp the far edges so they can never precede the near edges.
        x2 = max(x1, min(self.x + self.w, other.x + other.w))
        y2 = max(y1, min(self.y + self.h, other.y + other.h))
        return Box(x=x1, y=y1, w=x2 - x1, h=y2 - y1)

    def min_bounding(self, other: "Box") -> "Box":
        """Return the smallest box that contains both this box and ``other``."""
        points = [*self.corners(), *other.corners()]
        min_x = min(point.x for point in points)
        min_y = min(point.y for point in points)
        max_x = max(point.x for point in points)
        max_y = max(point.y for point in points)
        return Box(min_x, min_y, max_x - min_x, max_y - min_y)

    def expand(self, growth: float = .1) -> "Box":
        """Return a copy scaled by ``1 + growth`` about the box's center.

        The added width and height are split evenly between both sides, so
        the center stays put. Coordinates of the result are floats.
        """
        factor = 1 + growth
        w = factor * self.w
        h = factor * self.h
        # Shift the origin by half of the added size. The original code used
        # undefined names (min_x / min_y) here and raised NameError.
        return Box(self.x - (w - self.w) / 2, self.y - (h - self.h) / 2, w, h)
def iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Both arguments must expose ``x``, ``y``, ``w``, ``h`` and ``area``
    attributes. The result is 0.0 when the boxes do not overlap, and also
    when the union has no area (both boxes degenerate), where the plain
    ratio would divide by zero.
    """
    x1 = max(box1.x, box2.x)
    y1 = max(box1.y, box2.y)
    # Clamp the far edges so a disjoint pair yields a zero-size overlap
    # instead of a negative width or height.
    x2 = max(x1, min(box1.x + box1.w, box2.x + box2.w))
    y2 = max(y1, min(box1.y + box1.h, box2.y + box2.h))
    intersection_area = (x2 - x1) * (y2 - y1)
    union_area = box1.area + box2.area - intersection_area
    if union_area == 0:
        return 0.0
    return intersection_area / union_area
def all_equal(iterable):
    """Return True when every element equals the others (or there are none).

    ``groupby`` collapses consecutive equal items into runs, so the elements
    are all equal exactly when there is no second run.
    """
    runs = groupby(iterable)
    next(runs, None)  # Skip the first run, if there is one.
    return next(runs, None) is None
class spatial:
    """Decorator turning a predicate over boxes into a function over an environment.

    The wrapped function evaluates the predicate on every ordered tuple of
    ``arity`` boxes taken from ``env.boxes`` and returns the results as an
    array with ``arity`` axes, each of length ``len(env.boxes)``.
    """

    def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False):
        # Number of boxes the predicate takes, i.e. the rank of the output.
        self.arity = arity
        # When set, entries whose index tuple repeats a box are forced to 0
        # instead of being computed by the predicate.
        self.enforce_antisymmetry = enforce_antisymmetry

    def __call__(self, predicate: Callable[..., float]) -> Callable[["Environment"], np.ndarray]:
        def _rel(env):
            count = len(env.boxes)
            scores = np.empty((count,) * self.arity)
            indexed_boxes = list(enumerate(env.boxes))
            for combo in product(indexed_boxes, repeat=self.arity):
                # combo is a tuple of (index, box) pairs; split it into the
                # index tuple (array position) and the predicate arguments.
                indices, boxes = zip(*combo)
                has_repeat = len(set(indices)) < len(indices)
                if self.enforce_antisymmetry and has_repeat:
                    scores[indices] = 0.
                else:
                    scores[indices] = predicate(*boxes)
            return scores

        return _rel
class Environment:
    """An image, its candidate boxes, and an executor that scores captions against them."""

    def __init__(self, image: "Image.Image", boxes: "List[Box]", executor: "Executor" = None,
                 freeform_boxes: bool = False, image_name: str = None, image_pth: str = None):
        # Only ``image.width`` and ``image.height`` are read in this class.
        self.image = image
        self.boxes = boxes
        # An object or callback that can query CLIP with captions/images. It
        # is called as executor(caption, image, boxes, image_name=..., image_pth=...).
        self.executor = executor
        # When True the executor returns (scores, boxes) and ``filter``
        # replaces ``self.boxes`` with the returned boxes.
        self.freeform_boxes = freeform_boxes
        self.image_name = image_name
        self.image_pth = image_pth

    def uniform(self) -> np.ndarray:
        """Return the uniform distribution over the boxes (empty when there are none)."""
        n_boxes = len(self.boxes)
        # max(..., 1) keeps the no-box case from dividing by zero; the result
        # is then simply an empty array.
        return np.ones(n_boxes) / max(n_boxes, 1)

    def filter(self,
               caption: str,
               temperature: float = 1.,
               area_threshold: float = 0.0,
               softmax: bool = False,
               expand: float = None
               ) -> np.ndarray:
        """Return a new distribution reflecting the likelihood that `caption` describes the content of each box.

        Args:
            caption: Text handed to the executor.
            temperature: Multiplier applied to the standardized scores before
                the sigmoid. Unused when ``softmax`` is True.
            area_threshold: Boxes rejected by ``filter_area`` with this
                threshold are not scored and get 0. If every box is rejected,
                all boxes are scored instead.
            softmax: If True, normalize the scores with a softmax; otherwise
                standardize them (zero mean, unit std) and apply a sigmoid.
            expand: If not None, each candidate box is replaced by
                ``box.expand(expand)`` before being passed to the executor.

        Returns:
            A 1-D array with one entry per box in ``self.boxes``.

        Side effects:
            When ``self.freeform_boxes`` is set, ``self.boxes`` is replaced by
            the boxes returned by the executor, whose rows are read as
            (x1, y1, x2, y2).
        """
        keep_mask = self.filter_area(area_threshold)
        candidate_indices = [i for i, keep in enumerate(keep_mask) if keep > 0]
        if not candidate_indices:
            # Nothing passed the area filter: fall back to scoring every box.
            candidate_indices = list(range(len(self.boxes)))
        boxes = [self.boxes[i] for i in candidate_indices]
        if expand is not None:
            boxes = [box.expand(expand) for box in boxes]

        scores = self.executor(caption, self.image, boxes,
                               image_name=self.image_name, image_pth=self.image_pth)
        if self.freeform_boxes:
            scores, coords = scores
            new_boxes = []
            for i in range(coords.shape[0]):
                x1, y1, x2, y2 = (coords[i, k].item() for k in range(4))
                new_boxes.append(Box(x=x1, y=y1, w=x2 - x1, h=y2 - y1))
            self.boxes = new_boxes
            candidate_indices = list(range(len(self.boxes)))

        scores = scores.float()
        if softmax:
            # Softmax result.
            probabilities = scores.softmax(dim=-1)
        else:
            centered = scores - scores.mean()
            # The std of a single score is NaN under torch's default Bessel
            # correction, which would turn the whole result into NaN. Treat
            # the spread as 0 so a lone candidate ends up at sigmoid(0) = 0.5.
            if scores.numel() > 1:
                spread = scores.std()
            else:
                spread = scores.new_zeros(())
            standardized = centered / (spread + 1e-9)
            probabilities = (temperature * standardized).sigmoid()

        # Boxes that were never scored keep probability 0.
        result = torch.zeros(len(self.boxes), device=scores.device)
        result[candidate_indices] = probabilities
        return result.cpu().numpy()

    def filter_area(self, area_threshold: float) -> np.ndarray:
        """Return a 0/1 mask over the boxes based on their share of the image area.

        An entry is 1 when the box's area divided by the image area is
        strictly greater than ``area_threshold`` and 0 otherwise.
        """
        image_area = self.image.width * self.image.height
        return np.array([1 if box.area / image_area > area_threshold else 0 for box in self.boxes])
def left_of(b1, b2):
    """True when b1's horizontal midpoint is strictly left of b2's."""
    mid1 = (b1.left + b1.right) / 2
    mid2 = (b2.left + b2.right) / 2
    return mid1 < mid2
def right_of(b1, b2):
    """True when b1's horizontal midpoint is strictly right of b2's."""
    mid1 = (b1.left + b1.right) / 2
    mid2 = (b2.left + b2.right) / 2
    return mid1 > mid2
def above(b1, b2):
    """True when b1's vertical midpoint has a strictly smaller y than b2's.

    The sums of top and bottom are compared directly; halving both sides
    would not change the outcome.
    """
    span1 = b1.top + b1.bottom
    span2 = b2.top + b2.bottom
    return span1 < span2
def below(b1, b2):
    """True when b1's vertical midpoint has a strictly larger y than b2's.

    The sums of top and bottom are compared directly; halving both sides
    would not change the outcome.
    """
    span1 = b1.top + b1.bottom
    span2 = b2.top + b2.bottom
    return span1 > span2
def bigger_than(b1, b2):
    """True when b1 covers strictly more area than b2."""
    first_area, second_area = b1.area, b2.area
    return first_area > second_area
def smaller_than(b1, b2):
    """True when b1 covers strictly less area than b2."""
    first_area, second_area = b1.area, b2.area
    return first_area < second_area
def within(box1, box2):
    """Return the fraction of box1's area that lies inside box2.

    A box1 with zero area yields 0.0 rather than dividing by zero.
    """
    own_area = box1.area
    if own_area == 0:
        return 0.0
    overlap = box1.intersect(box2)
    return overlap.area / own_area
def between(box1, box2, box3):
    """How much of box1 lies in min bounding box over box2 and box3?

    Returns the overlapping share of box1's area as a fraction. A box1 with
    zero area yields 0.0 rather than dividing by zero.
    """
    own_area = box1.area
    if own_area == 0:
        return 0.0
    # Local names avoid shadowing the module-level ``intersect`` function.
    hull = box2.min_bounding(box3)
    overlap = box1.intersect(hull)
    return overlap.area / own_area