File size: 6,454 Bytes
24e35df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch

def merge_tensors_fft2(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merge two tensors by combining their Fourier coefficients.

    1-D tensors are transformed with a 1-D FFT; tensors with two or more
    dimensions are transformed over their last two dimensions (any leading
    dimensions are carried along unchanged). Coefficients are then combined
    element-wise:

    - Real part: where the real parts of both spectra have the same sign
      (``torch.sign``, so two zeros also count as matching), take the larger
      value; otherwise take the value with the larger magnitude (``v1``'s
      value on ties).
    - Imaginary part: where the *real* parts have the same sign, linearly
      interpolate ``(1 - t) * imag_v0 + t * imag_v1``; otherwise take the
      imaginary value with the larger magnitude (``v1``'s value on ties).

    The inverse FFT of the combined spectrum is computed and only its real
    part is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor (at least 1-D).
    - v1 (torch.Tensor): The second input tensor; must have the same shape as ``v0``.
    - t (float): Interpolation weight for the imaginary part (intended 0 <= t <= 1).
    - device: Device to compute on. Defaults to ``"cuda:0"`` when CUDA is
      available, otherwise ``v0``'s current device.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      same shape as the inputs, on ``device``.

    Raises:
    - ValueError: If the inputs are 0-d or their shapes differ.
    """
    if v0.dim() == 0:
        raise ValueError("merge_tensors_fft2 expects tensors with at least one dimension")
    if v0.shape != v1.shape:
        raise ValueError(
            f"shape mismatch: v0 has shape {tuple(v0.shape)}, v1 has shape {tuple(v1.shape)}"
        )

    if device is None:
        # Keep the historical "cuda:0" target when a GPU exists, but fall
        # back to the input's device so CPU-only hosts still work.
        device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    one_d = v0.dim() == 1
    if one_d:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    # Coefficients whose real parts agree in sign.
    sign_mask = real_v0.sign() == real_v1.sign()

    # Where signs disagree, keep whichever value has the larger magnitude
    # (strict ">" means v1 wins ties).
    real_dominant = torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1)
    imag_dominant = torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1)

    # Matching signs: take the element-wise maximum of the real parts.
    merged_real = torch.where(sign_mask, torch.maximum(real_v0, real_v1), real_dominant)
    # NOTE(review): the imaginary part is gated by the *real*-part sign mask,
    # not by the signs of the imaginary parts; preserved as-is — confirm intent.
    merged_imag = torch.where(sign_mask, (1 - t) * imag_v0 + t * imag_v1, imag_dominant)

    merged_fft = torch.complex(merged_real, merged_imag)

    # Inverse transform over the same dimensions; keep only the real part.
    if one_d:
        return torch.fft.ifft(merged_fft).real
    return torch.fft.ifftn(merged_fft, dim=(-2, -1)).real

def merge_tensors_fft_shell(
    v0: torch.Tensor,
    v1: torch.Tensor,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merge two tensors by taking maximal Fourier coefficients.

    1-D tensors are transformed with a 1-D FFT; tensors with two or more
    dimensions are transformed over their last two dimensions (any leading
    dimensions are carried along unchanged). Coefficients are then combined
    element-wise:

    - Where the real parts of both spectra have the same sign
      (``torch.sign``, so two zeros also count as matching), take the larger
      real value and the larger imaginary value.
    - Otherwise take, separately for the real and imaginary parts, the value
      with the larger magnitude (``v1``'s value on ties).

    The inverse FFT of the combined spectrum is computed and only its real
    part is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor (at least 1-D).
    - v1 (torch.Tensor): The second input tensor; must have the same shape as ``v0``.
    - device: Device to compute on. Defaults to ``"cuda:0"`` when CUDA is
      available, otherwise ``v0``'s current device.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      same shape as the inputs, on ``device``.

    Raises:
    - ValueError: If the inputs are 0-d or their shapes differ.
    """
    if v0.dim() == 0:
        raise ValueError("merge_tensors_fft_shell expects tensors with at least one dimension")
    if v0.shape != v1.shape:
        raise ValueError(
            f"shape mismatch: v0 has shape {tuple(v0.shape)}, v1 has shape {tuple(v1.shape)}"
        )

    if device is None:
        # Keep the historical "cuda:0" target when a GPU exists, but fall
        # back to the input's device so CPU-only hosts still work.
        device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    one_d = v0.dim() == 1
    if one_d:
        fft_v0 = torch.fft.fft(v0)
        fft_v1 = torch.fft.fft(v1)
    else:
        fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
        fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))

    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    # Coefficients whose real parts agree in sign.
    sign_mask = real_v0.sign() == real_v1.sign()

    # Where signs disagree, keep whichever value has the larger magnitude
    # (strict ">" means v1 wins ties).
    real_dominant = torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1)
    imag_dominant = torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1)

    merged_real = torch.where(sign_mask, torch.maximum(real_v0, real_v1), real_dominant)
    # NOTE(review): the imaginary part is gated by the *real*-part sign mask,
    # not by the signs of the imaginary parts; preserved as-is — confirm intent.
    merged_imag = torch.where(sign_mask, torch.maximum(imag_v0, imag_v1), imag_dominant)

    merged_fft = torch.complex(merged_real, merged_imag)

    # Inverse transform over the same dimensions; keep only the real part.
    if one_d:
        return torch.fft.ifft(merged_fft).real
    return torch.fft.ifftn(merged_fft, dim=(-2, -1)).real