File size: 3,019 Bytes
2571f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from enum import Enum
from typing import Optional
from .jit_utils import floor_div
# Module-level alias so annotations (e.g. the `Optional[Tensor]` return type
# of `Bound.transform`) can use the bare name `Tensor`.
Tensor = torch.Tensor


class BoundType(Enum):
    """Boundary conditions, each reachable under two names.

    For every value the first name declared is the canonical member and the
    second is an alias resolving to that same member. The integer values are
    the codes that `Bound` dispatches on.
    """

    zero = 0
    zeros = 0

    replicate = 1
    nearest = 1

    dct1 = 2
    mirror = 2

    dct2 = 3
    reflect = 3

    dst1 = 4
    antimirror = 4

    dst2 = 5
    antireflect = 5

    dft = 6
    wrap = 6


class ExtrapolateType(Enum):
    """Extrapolation modes, declared in value order (no=0, yes=1, hist=2).

    Thresholds noted for each mode:
      - no:   (0, n-1)
      - yes:  none noted
      - hist: (-0.5, n-0.5)

    NOTE(review): presumably these are the coordinate ranges treated as
    in-bounds for a dimension of size n -- confirm against callers.
    """

    no, yes, hist = 0, 1, 2


@torch.jit.script
class Bound:
    """TorchScript-compatible helper applying a boundary condition to indices.

    `bound_type` uses the integer codes of `BoundType`:
    0 = zero, 1 = replicate, 2 = dct1, 3 = dct2, 4 = dst1, 5 = dst2,
    6 = dft. Other codes leave indices untouched (`index`) and produce no
    sign/mask (`transform`).
    """

    def __init__(self, bound_type: int = 3):
        # Integer code from BoundType (default 3 = dct2).
        self.type = bound_type

    def index(self, i, n: int):
        """Fold integer indices into the valid range of a dimension.

        Parameters
        ----------
        i : Tensor
            Integer indices along one dimension; may lie outside [0, n-1].
        n : int
            Size of the dimension.

        Returns
        -------
        Tensor
            For the known boundary codes, indices in [0, n-1]; unknown codes
            return `i` unchanged. The sign/mask to apply to the value read
            at the folded index is given separately by `transform`.
        """
        if self.type in (0, 1):  # zero / replicate
            # Out-of-bounds indices snap to the nearest edge sample.
            return i.clamp(min=0, max=n-1)
        elif self.type in (3, 5):  # dct2 / dst2
            # Half-sample reflection with period 2n:
            #   ... 1 0 | 0 1 ... n-1 | n-1 n-2 ...
            n2 = n * 2
            i = torch.where(i < 0, (-i-1).remainder(n2).neg().add(n2 - 1),
                            i.remainder(n2))
            i = torch.where(i >= n, -i + (n2 - 1), i)
            return i
        elif self.type == 2:  # dct1
            if n == 1:
                # A single sample mirrors onto itself.
                return torch.zeros(i.shape, dtype=i.dtype, device=i.device)
            else:
                # Whole-sample mirror with period 2(n-1):
                #   ... 2 1 | 0 1 ... n-1 | n-2 n-3 ...
                n2 = (n - 1) * 2
                i = i.abs().remainder(n2)
                i = torch.where(i >= n, -i + n2, i)
                return i
        elif self.type == 4:  # dst1
            # Period 2(n+1) with virtual samples at -1 and n. Those virtual
            # positions are redirected to the nearest real sample here;
            # `transform` gives them a zero mask.
            n2 = 2 * (n + 1)
            first = torch.zeros([1], dtype=i.dtype, device=i.device)
            last = torch.full([1], n - 1, dtype=i.dtype, device=i.device)
            i = torch.where(i < 0, -i - 2, i)
            i = i.remainder(n2)
            i = torch.where(i > n, -i + (n2 - 2), i)
            i = torch.where(i == -1, first, i)
            i = torch.where(i == n, last, i)
            return i
        elif self.type == 6:  # dft
            # Periodic wrap-around.
            return i.remainder(n)
        else:
            return i

    def transform(self, i, n: int) -> Optional[Tensor]:
        """Sign/mask to apply to values read at `index(i, n)`.

        Parameters
        ----------
        i : Tensor
            The same (unfolded) integer indices that are passed to `index`.
        n : int
            Size of the dimension.

        Returns
        -------
        Tensor or None
            An int8 tensor with values in {-1, 0, 1} that broadcasts against
            `i`, or None when no sign/mask applies (replicate, dct1, dct2,
            dft, unknown codes, and dst1 on a singleton dimension).
        """
        if self.type == 4:  # dst1
            if n == 1:
                # NOTE(review): a singleton dimension gets no sign/mask, which
                # differs from what the general branch below would compute
                # for n == 1 (0 at odd positions, alternating sign at even
                # ones). Kept for backward compatibility.
                return None
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            n2 = 2 * (n + 1)
            # Reflect negative indices (i -> n-1-i). The sign pattern is
            # symmetric under this reflection, so only i >= 0 remains.
            i = torch.where(i < 0, -i + (n-1), i)
            i = i.remainder(n2)
            # Within one period [0, 2n+1]:
            #   0..n-1   -> real samples, sign +1 (position 0 included)
            #   n, 2n+1  -> virtual zero-valued samples, mask 0
            #   n+1..2n  -> mirrored samples, sign -1
            x = torch.where(i.remainder(n + 1) == n, zero, one)
            x = torch.where(i > n, -x, x)
            return x
        elif self.type == 5:  # dst2
            # Reflect negative indices (i -> n-1-i), then negate every other
            # block of n samples (period 2n).
            i = torch.where(i < 0, n - 1 - i, i)
            x = torch.ones([1], dtype=torch.int8, device=i.device)
            x = torch.where(i.remainder(2 * n) >= n, -x, x)
            return x
        elif self.type == 0:  # zero
            # Mask out-of-bounds indices.
            one = torch.ones([1], dtype=torch.int8, device=i.device)
            zero = torch.zeros([1], dtype=torch.int8, device=i.device)
            outbounds = ((i < 0) | (i >= n))
            x = torch.where(outbounds, zero, one)
            return x
        else:
            return None