# Provenance (copied from a hosted file page): author franchesoni,
# commit e1b51e5 ("v0"), file size 4.72 kB, page flagged "No virus".
import torch
from torch import nn
import numpy as np
from cv2 import resize
import cv2
from pathlib import Path
from network import EfficientViT_l1_r224
from losses import IISLoss, activate
from utils import minmaxnorm, load_from_ckpt
class Busam:
    """Click-to-mask segmenter built on an EfficientViT feature network.

    ``process_image`` turns an RGB uint8 image into a feature map of shape
    (1, F, pH, pW). ``get_mask`` queries that map at a clicked pixel to get a
    boolean mask at the original resolution. ``sobel_from_pred`` and
    ``canny_from_pred`` derive edge maps from spatial gradients of the features.
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network, load its weights and move it to ``device``.

        Args:
            checkpoint: passed unchanged to ``load_from_ckpt``.
            device: torch device for the network and all prepared inputs.
            side: images are resized to ``side x side`` before inference.
        """
        out_channels = 16  # presumably the feature dimension F of the output map
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels, use_norm_params=use_norm_params, pretrained=False
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """
        assume H, W, 3 image

        Validate a uint8 (H, W, 3) image, resize it to (side, side) and return a
        float32 tensor of shape (1, 3, side, side) with values in [-0.5, 0.5],
        placed on ``self.device``.
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)  # HWC -> CHW, centre
            .float()[None]  # add batch dim
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on ``img`` without tracking gradients.

        Args:
            img: uint8 (H, W, 3) image (see ``prepare_img``).
            do_activate: if True, pass the raw output through ``activate``
                with the "symlog" option before returning it.

        Returns:
            ``(pred, (H, W))``. ``pred`` has shape (1, F, pH, pW). (H, W) is the
            original image size, which ``get_mask`` needs to map clicks and masks
            between image and feature resolution.
        """
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
        H, W = img.shape[:2]
        if do_activate:
            B, F, pH, pW = pred.shape
            # ``activate`` is fed a flat (F, pH*pW) matrix. This view only works
            # for B == 1, which prepare_img guarantees by adding a single batch dim.
            features, _, _, _ = activate(
                pred.view(F, pH * pW), None, "symlog", False, False, False
            )
            pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Return a boolean (oH, oW) mask for the region selected by ``click``.

        Args:
            aux: the ``(pred, (oH, oW))`` tuple returned by ``process_image``.
            click: ``(row, col)`` in original-image pixel coordinates.
        """
        pred = aux[0][0]  # remove batch dim -> (F, H, W)
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # Map the click from image resolution to feature resolution. int() accepts
        # float clicks. Clamping keeps border or out-of-range clicks on the grid
        # instead of producing an out-of-range or negative flat index.
        row = min(max(int(click[0] * H // oH), 0), H - 1)
        col = min(max(int(click[1] * W // oW), 0), W - 1)
        sindex = row * W + col  # flat index into the H*W feature columns
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # Upsample to the original size (cv2 dsize is (width, height)) and
        # binarize. This presumably expects mask values in [0, 1]; confirm
        # against IISLoss.
        mask_u8 = (mask.cpu().numpy() * 255).astype(np.uint8)
        return resize(mask_u8, (oW, oH)) > 100

    def get_gradients(self, pred, size):
        """Channel-combined Sobel gradient magnitudes of a (1, F, H, W) map.

        Each of the F channels is filtered independently (grouped convolution).
        The per-channel responses are then combined with an L2 norm over F,
        giving two non-negative (H, W) tensors for the x and y directions.

        ``size`` is unused and kept only for backward compatibility.
        """
        F, H, W = pred[0].shape
        # Build the kernels in pred's dtype/device so non-float32 maps also work.
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device
        )
        sobel_y = sobel_x.T
        # Grouped conv2d expects weights of shape (F, 1, 3, 3).
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        conv2d = torch.nn.functional.conv2d
        edge_x = conv2d(pred, sobel_x, padding=1, groups=F).view(F, H, W)
        edge_y = conv2d(pred, sobel_y, padding=1, groups=F).view(F, H, W)
        edge_x = torch.norm(edge_x, dim=0, p=2)  # (H, W)
        edge_y = torch.norm(edge_y, dim=0, p=2)  # (H, W)
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Return the (H, W) Sobel edge magnitude sqrt(gx**2 + gy**2) of ``pred``."""
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run ``cv2.Canny`` on the gradients of a (1, F, H, W) feature map.

        The magnitudes from ``get_gradients`` are jointly min-max normalised to
        [0, 1], scaled to int16 and passed to ``cv2.Canny`` as precomputed x/y
        derivatives. If one threshold is falsy (None or 0), the other is used
        for both.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low
        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        span = amax - amin
        if span <= 0:
            # Constant maps: avoid 0/0 = NaN, whose int16 cast is undefined.
            # Dividing by 1 gives all-zero inputs, i.e. no edges.
            span = torch.ones_like(span)
        edge_x = (edge_x - amin) / span
        edge_y = (edge_y - amin) / span
        # NOTE(review): both inputs are non-negative magnitudes (torch.norm in
        # get_gradients), so Canny never sees negative derivative signs when
        # estimating edge direction. Confirm this is intended.
        return cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
def cast_to_int16(x):
    """Scale values in [-1, 1] by 32767 and truncate to an int16 NumPy array.

    Accepts a NumPy array (or array-like) or a torch tensor; tensors are
    detached and moved to CPU first. Out-of-range values saturate at +/-32767
    and NaNs map to 0, instead of hitting NumPy's undefined float->int16 cast.
    Arithmetic is done in float64 so that 1.0 maps exactly to 32767, even for
    float16 input.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    x = np.clip(np.nan_to_num(np.asarray(x, dtype=np.float64), nan=0.0), -1.0, 1.0)
    return (x * 32767).astype(np.int16)
# from segment_anything import sam_model_registry, SamPredictor
# class SAM:
# sam_checkpoint = "sam_vit_b_01ec64.pth"
# model_type = "vit_b"
# def __init__(self, device):
# sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint)
# sam.to(device=device)
# self.predictor = SamPredictor(sam)
# def process_image(self, img):
# self.predictor.set_image(img)
# return None
# def get_mask(self, aux, click):
# input_point = np.array([[click[1], click[0]]])
# input_label = np.array([1])
# masks, scores, logits = self.predictor.predict(
# point_coords=input_point, point_labels=input_label, multimask_output=False
# )
# return masks[0]