File size: 3,065 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
from joblib import Parallel, delayed
from pytlsd import lsd

from ..base_model import BaseModel


class LSD(BaseModel):
    """Line segment detector built on ``pytlsd.lsd``.

    Configuration:
        min_length: segments whose endpoint-to-endpoint Euclidean distance is
            below this value are discarded.
        max_num_lines: if not None, only the ``max_num_lines`` highest-scoring
            segments are kept, ordered by decreasing score. When None, the
            surviving segments keep the order returned by ``lsd``.
        force_num_lines: if True, each image's output is zero-padded up to
            exactly ``max_num_lines`` entries (so ``max_num_lines`` must be
            set); this lets a multi-image batch always be stacked.
        n_jobs: number of joblib workers used when the batch contains more
            than one image.
    """

    default_conf = {
        "min_length": 15,
        "max_num_lines": None,
        "force_num_lines": False,
        "n_jobs": 4,
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        # Padding to a fixed count needs a target count.
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"

    def detect_lines(self, img):
        """Detect, filter, rank and optionally pad the segments of one image.

        Args:
            img: a single 2-D uint8 grayscale image, as prepared by
                ``_forward``.

        Returns:
            A tuple ``(segs, scores, valid_mask)``:
            segs: (N, 2, 2) array, the two endpoints of each segment.
            scores: (N,) array of segment scores.
            valid_mask: (N,) bool array; False marks zero-padding entries
                added when ``force_num_lines`` is set.
        """
        # Run LSD. Columns 0:2 and 2:4 are used as the two endpoints and the
        # last column as a per-segment detector score.
        segs = lsd(img)

        # Filter out segments that do not meet the minimum length criterion
        lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
        to_keep = lengths >= self.conf.min_length
        segs, lengths = segs[to_keep], lengths[to_keep]

        # Rank by detector score weighted by sqrt(length), favouring longer
        # segments without letting length dominate.
        scores = segs[:, -1] * np.sqrt(lengths)
        segs = segs[:, :4].reshape(-1, 2, 2)
        if self.conf.max_num_lines is not None:
            # Only sort when the ordering is actually used for truncation.
            indices = np.argsort(-scores)[: self.conf.max_num_lines]
            segs = segs[indices]
            scores = scores[indices]

        # Pad if necessary. pad >= 0 because segs was truncated above to
        # max_num_lines (force_num_lines implies max_num_lines is set).
        n = len(segs)
        valid_mask = np.ones(n, dtype=bool)
        if self.conf.force_num_lines:
            pad = self.conf.max_num_lines - n
            segs = np.concatenate(
                [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
            )
            scores = np.concatenate([scores, np.zeros(pad, dtype=np.float32)], axis=0)
            valid_mask = np.concatenate([valid_mask, np.zeros(pad, dtype=bool)], axis=0)

        return segs, scores, valid_mask

    def _forward(self, data):
        """Run LSD on every image of ``data["image"]``.

        ``data["image"]`` is indexed as (B, C, H, W); 3-channel input is
        converted to grayscale. Presumably intensities are floats in [0, 1]
        (they are scaled by 255 before the uint8 cast) -- TODO confirm with
        callers.

        Returns:
            dict with keys ``"lines"``, ``"line_scores"``, ``"valid_lines"``.
            When the batch has a single image or ``force_num_lines`` is set,
            the values are tensors on the input device with shapes
            (B, N, 2, 2), (B, N) and (B, N). Otherwise they are tuples with
            one numpy array per image, since per-image counts may differ.
        """
        # Convert to the right data format
        image = data["image"]
        if image.shape[1] == 3:
            # Convert to grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        device = image.device
        b_size = len(image)
        # detach(): .numpy() refuses tensors that require grad.
        # clamp(): out-of-range intensities saturate instead of wrapping
        # around on the float -> uint8 cast.
        image = (
            image.detach().clamp(0, 1).squeeze(1).cpu().numpy() * 255
        ).astype(np.uint8)

        # LSD detection, in parallel when there is more than one image
        if b_size == 1:
            results = [self.detect_lines(image[0])]
        else:
            results = Parallel(n_jobs=self.conf.n_jobs)(
                delayed(self.detect_lines)(img) for img in image
            )
        lines, line_scores, valid_lines = zip(*results)

        # Batch if possible: all per-image arrays share a shape only when
        # there is a single image or every image was padded to max_num_lines.
        if b_size == 1 or self.conf.force_num_lines:
            # Stack first: building a tensor from a list of ndarrays is slow.
            lines = torch.as_tensor(
                np.stack(lines), dtype=torch.float, device=device
            )
            line_scores = torch.as_tensor(
                np.stack(line_scores), dtype=torch.float, device=device
            )
            valid_lines = torch.as_tensor(
                np.stack(valid_lines), dtype=torch.bool, device=device
            )

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        # This detector is not trainable.
        raise NotImplementedError