Spaces:
Running
Running
import numpy as np | |
import torch | |
from joblib import Parallel, delayed | |
from pytlsd import lsd | |
from ..base_model import BaseModel | |
class LSD(BaseModel):
    """Line segment detector wrapping ``pytlsd.lsd``.

    For each image of a batch, segments are detected, segments shorter than
    ``min_length`` are discarded, and the rest are ranked by
    ``detector_score * sqrt(length)``. Optionally only the top
    ``max_num_lines`` are kept, and the output is zero-padded to exactly
    ``max_num_lines`` segments.

    Configuration:
        min_length: minimum Euclidean distance between a segment's two
            endpoints (in the image coordinates returned by ``lsd``) for the
            segment to be kept.
        max_num_lines: if not None, keep at most this many top-ranked
            segments per image.
        force_num_lines: if True, pad every image's output to exactly
            ``max_num_lines`` segments so that batches can always be stacked
            into tensors. Requires ``max_num_lines``.
        n_jobs: number of joblib workers used when the batch contains more
            than one image.
    """

    default_conf = {
        "min_length": 15,
        "max_num_lines": None,
        "force_num_lines": False,
        "n_jobs": 4,
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        # Padding to a fixed count requires a target count.
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"

    def detect_lines(self, img):
        """Run LSD on a single image, then filter, rank and optionally pad.

        Args:
            img: single-channel ``uint8`` image, as prepared by ``_forward``.

        Returns:
            segs: array of shape (N, 2, 2) holding the two endpoints of each
                kept segment, sorted by decreasing score.
            scores: array of shape (N,) with the ranking score
                ``detector_score * sqrt(length)`` (0 for padded rows).
            valid_mask: boolean array of shape (N,). It is False only for
                padded rows.
            When ``force_num_lines`` is set, N == ``max_num_lines``.
        """
        # Run LSD. The code below treats columns 0:2 and 2:4 as the two
        # segment endpoints and the last column as a per-segment detector
        # score. NOTE(review): this is the assumed pytlsd output layout -
        # confirm against the pytlsd version in use.
        segs = lsd(img)

        # Filter out segments that do not meet the minimum length criterion.
        lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
        to_keep = lengths >= self.conf.min_length
        segs, lengths = segs[to_keep], lengths[to_keep]

        # Rank by detector score weighted by sqrt(length), then keep the best.
        scores = segs[:, -1] * np.sqrt(lengths)
        segs = segs[:, :4].reshape(-1, 2, 2)
        indices = np.argsort(-scores)
        if self.conf.max_num_lines is not None:
            indices = indices[: self.conf.max_num_lines]
        segs = segs[indices]
        scores = scores[indices]

        # Zero-pad (and mark invalid) up to max_num_lines if requested.
        # pad >= 0 because force_num_lines implies max_num_lines is set
        # (checked in _init), so `indices` was truncated above.
        n = len(segs)
        valid_mask = np.ones(n, dtype=bool)
        if self.conf.force_num_lines:
            pad = self.conf.max_num_lines - n
            segs = np.concatenate(
                [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
            )
            scores = np.concatenate(
                [scores, np.zeros(pad, dtype=np.float32)], axis=0
            )
            valid_mask = np.concatenate(
                [valid_mask, np.zeros(pad, dtype=bool)], axis=0
            )
        return segs, scores, valid_mask

    def _forward(self, data):
        """Detect line segments on every image of the batch.

        Args:
            data: dict whose "image" entry is a tensor with the channel
                dimension at dim 1 (1 or 3 channels). NOTE(review): values
                are assumed to lie in [0, 1], since they are scaled by 255
                before the uint8 conversion. Confirm against callers.

        Returns:
            dict with keys "lines", "line_scores" and "valid_lines".
            - When the batch size is 1 or ``force_num_lines`` is set, these
              are stacked tensors of shape (B, N, 2, 2), (B, N) and (B, N),
              placed on the input's device.
            - Otherwise they are tuples of per-image numpy arrays, because
              the number of lines can differ between images.
        """
        # Convert to the right data format.
        image = data["image"]
        if image.shape[1] == 3:
            # Grayscale as a weighted channel sum with the BT.601 luma
            # coefficients (assumes RGB channel order).
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        device = image.device
        b_size = len(image)
        # Clip before casting. Out-of-range floats (e.g. interpolation
        # overshoot above 1.0) would otherwise wrap around or hit an
        # undefined float->uint8 cast. In-range values are truncated exactly
        # as before.
        image = np.clip(
            image.squeeze(1).detach().cpu().numpy() * 255, 0, 255
        ).astype(np.uint8)

        # LSD detection, in parallel when there is more than one image.
        if b_size == 1:
            lines, line_scores, valid_lines = self.detect_lines(image[0])
            lines = [lines]
            line_scores = [line_scores]
            valid_lines = [valid_lines]
        else:
            lines, line_scores, valid_lines = zip(
                *Parallel(n_jobs=self.conf.n_jobs)(
                    delayed(self.detect_lines)(img) for img in image
                )
            )

        # Batch if possible (all per-image arrays have the same shape).
        # Stack into one ndarray first: torch.tensor() on a sequence of
        # ndarrays is very slow and emits a UserWarning.
        if b_size == 1 or self.conf.force_num_lines:
            lines = torch.as_tensor(
                np.stack(lines), dtype=torch.float, device=device
            )
            line_scores = torch.as_tensor(
                np.stack(line_scores), dtype=torch.float, device=device
            )
            valid_lines = torch.as_tensor(
                np.stack(valid_lines), dtype=torch.bool, device=device
            )
        return {
            "lines": lines,
            "line_scores": line_scores,
            "valid_lines": valid_lines,
        }

    def loss(self, pred, data):
        # This detector is not trainable; no loss is defined.
        raise NotImplementedError