File size: 10,287 Bytes
b59223f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# ztrain/tensors.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import torch
from typing import Generator, Tuple

def normalize_to(m1 : torch.Tensor, norm : "torch.Tensor | float") -> tuple[torch.Tensor, float, float]:
    """Rescale ``m1`` (cast to float32) so that its norm equals ``norm``.

    The norm is ``torch.norm`` over all elements (Frobenius / flattened L2).

    Args:
        m1: Tensor to rescale. It is converted to float32 before scaling.
        norm: Target norm, either a single-element tensor or a Python number.

    Returns:
        ``(scaled, target, ratio)``:

        - ``scaled`` is the float32 tensor multiplied by ``ratio``.
        - ``target`` is the target norm as a Python float.
        - ``ratio`` is the scale factor that was applied.

        A zero-norm ``m1`` cannot be rescaled. It is returned unchanged (as
        float32) with ``ratio == 1.0`` rather than being turned into NaNs.
    """
    m1 = m1.to(torch.float32)
    m1_norm = torch.norm(m1)
    # Accept both tensors and plain numbers for the target norm.
    target = norm.item() if isinstance(norm, torch.Tensor) else float(norm)
    if m1_norm.item() == 0:
        # norm / 0 would give inf, and 0 * inf is NaN; leave the zero tensor as-is.
        return m1, target, 1.0
    ratio = (norm / m1_norm).item()
    m1 = m1 * ratio
    return m1, target, ratio

def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """Return the norm of ``m1`` divided by the norm of ``m2`` as a Python float.

    Both norms are ``torch.norm`` over all elements. The two norms and the
    resulting ratio are printed before returning. A zero ``m2`` norm is not
    special-cased.
    """
    numerator = torch.norm(m1)
    denominator = torch.norm(m2)
    quotient = numerator.div(denominator).item()
    print(f"Norms {numerator} {denominator} {quotient}")
    return quotient

def _merge_spectra(fft_v0: torch.Tensor, fft_v1: torch.Tensor, t: float) -> torch.Tensor:
    """Combine two equally-shaped complex spectra coefficient by coefficient.

    Where the real parts of a coefficient pair have the same sign, both the real
    and the imaginary parts are interpolated as ``(1 - t) * a + t * b``.
    Elsewhere, each component keeps whichever input has the larger magnitude for
    that component (ties go to ``fft_v1``). The result has ``fft_v0``'s dtype.
    """
    merged = torch.zeros_like(fft_v0)
    # NOTE(review): the sign test uses only the real parts, and the same mask also
    # drives the imaginary parts (behavior kept from the original implementation).
    # Whether the imaginary parts should get their own sign test is unconfirmed.
    same_sign = fft_v0.real.sign() == fft_v1.real.sign()
    diff_sign = ~same_sign
    for part0, part1, out in (
        (fft_v0.real, fft_v1.real, merged.real),
        (fft_v0.imag, fft_v1.imag, merged.imag),
    ):
        # `out` is a view sharing storage with `merged`, so these indexed writes fill it in place.
        out[same_sign] = (1 - t) * part0[same_sign] + t * part1[same_sign]
        a, b = part0[diff_sign], part1[diff_sign]
        out[diff_sign] = torch.where(a.abs() > b.abs(), a, b)
    return merged


def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float, device: "str | torch.device | None" = None) -> torch.Tensor:
    """
    Merges two tensors using 2D Fourier transform interpolation.

    1-D inputs are transformed with a 1-D FFT. Inputs of higher rank are transformed
    over their last two axes. In frequency space, coefficients whose real parts agree
    in sign are interpolated with weight ``t``. Otherwise the larger-magnitude
    component is kept. The inverse transform's real part is returned.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor (same shape as ``v0``).
    - t (float): Interpolation parameter (0 <= t <= 1).
    - device (str | torch.device | None): Where to compute. Defaults to "cuda:0" when
      CUDA is available, otherwise "cpu".

    Returns:
    - torch.Tensor: The contiguous tensor resulting from the interpolated inverse FFT,
      on ``device``. float16/bfloat16 inputs are computed in float32 and the result is
      cast back to ``v0``'s dtype.
    """
    if device is None:
        # Keep the historical cuda:0 default, but do not crash on CPU-only hosts.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # torch.fft has no bfloat16 support and supports float16 only on CUDA for
    # power-of-two sizes, so half-precision inputs are transformed in float32.
    half_dtypes = (torch.float16, torch.bfloat16)
    out_dtype = v0.dtype
    v0 = v0.to(device=device, dtype=torch.float32 if v0.dtype in half_dtypes else v0.dtype)
    v1 = v1.to(device=device, dtype=torch.float32 if v1.dtype in half_dtypes else v1.dtype)

    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    merged_fft = _merge_spectra(torch.fft.fftn(v0, dim=dims), torch.fft.fftn(v1, dim=dims), t)
    merged_tensor = torch.fft.ifftn(merged_fft, dim=dims).real  # Taking the real part
    del merged_fft

    if out_dtype in half_dtypes:
        merged_tensor = merged_tensor.to(out_dtype)
    # `.real` is a strided view into the complex buffer. A contiguous copy lets that
    # buffer be freed and gives callers a normally laid-out tensor.
    return merged_tensor.contiguous()

def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """Build a symmetric matrix of pairwise cosine similarities between ``tensors[i]``.

    For each pair ``i < j``, ``cosine_similarity(tensors[i], tensors[j], dim=0)`` is
    computed on ``work_device``. NaNs are replaced by 0 and the result is averaged to a
    scalar, which is stored at both ``[i, j]`` and ``[j, i]``. The diagonal stays 0.

    Args:
        tensors: Stack of items indexed along dim 0.
        work_device: Device used for the similarity computation.
        store_device: Device of the returned (n, n) float matrix.

    Returns:
        The (n, n) similarity matrix on ``store_device``.
    """
    n = tensors.shape[0]
    matrix = torch.zeros(n, n, device=store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0)
            matrix[i, j] = matrix[j, i] = similarity.nan_to_num(0).mean().item()
        # The original `.to(store_device)` calls here discarded their results, so each
        # one was only a wasted device-to-host copy. Rebinding `a`/`b` already frees them.
    return matrix

def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Greedily pair up indices with the smallest absolute correlation, using each index at most once.

    Only the strict upper triangle is considered, so the matrix is treated as
    symmetric and self-pairs are skipped. On each step the pair with the smallest
    ``|correlation|`` among still-unused indices is yielded. Zero values are included
    and ties go to the first pair in row-major order. Both of its indices are then
    retired. With an odd number of indices, one index is never yielded.

    Args:
        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.
            It may live on any device.

    Yields:
        Tuple[int, int, float]: ``(x, y, coefficient)`` with ``x < y``, where ``coefficient``
                                is the signed value ``correlation_tensor[x, y]``.

    Note:
        If a NaN remains among the candidate pairs, the minimum is NaN, no pair matches,
        and the generator stops early.
    """
    n = correlation_tensor.size(0)
    device = correlation_tensor.device
    # Build the mask and sentinel on the input's device; a CPU mask with a CUDA
    # matrix makes torch.where fail.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=device), diagonal=1)
    inf = torch.tensor(float('inf'), device=device)

    while torch.any(mask):
        # Retired / invalid pairs become +inf so they can never be the minimum.
        valid_correlation = torch.where(mask, correlation_tensor, inf)

        # Smallest absolute correlation among the remaining pairs (zero included).
        min_val = torch.min(torch.abs(valid_correlation[valid_correlation != inf]))

        # Locate the indices with the minimum correlation
        min_indices = torch.nonzero(torch.abs(valid_correlation) == min_val, as_tuple=True)
        if len(min_indices[0]) == 0:
            break

        # Yield the first index pair (greedy approach) along with the signed coefficient.
        x, y = min_indices[0][0].item(), min_indices[1][0].item()
        coefficient = correlation_tensor[x, y].item()
        yield (x, y, coefficient)

        # Retire both indices: no future pair may use them.
        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False


def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float) -> tuple[torch.Tensor, float, float]:
    """
    Normalize two tensors to unit norm, then merge them with ``merge_tensors_fft2``.

    Each input is divided by its own norm (``Tensor.norm()`` over all elements) before
    merging. A zero-norm input is passed through unscaled. The FFT merge itself is
    delegated to ``merge_tensors_fft2`` instead of duplicating its body.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation parameter (0 <= t <= 1).

    Returns:
    - tuple[torch.Tensor, float, float]: ``(merged, norm_v0, norm_v1)``, where ``merged``
      is the FFT merge of the normalized inputs. ``norm_v0`` and ``norm_v1`` are the
      original norms as Python floats, presumably so callers can restore a scale.
    """
    # Compute the norms on the GPU when one exists (as before), else fall back to CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    v0 = v0.to(device)
    v1 = v1.to(device)

    # Calculate norms of each tensor
    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()

    # Scale tensors to unit norm; skip the division for all-zero tensors.
    v0 = v0 / norm_v0_t if norm_v0_t != 0 else v0
    v1 = v1 / norm_v1_t if norm_v1_t != 0 else v1

    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    del norm_v0_t, norm_v1_t

    merged_tensor = merge_tensors_fft2(v0, v1, t)
    return merged_tensor, norm_v0, norm_v1