|
|
|
|
|
|
|
import torch |
|
from typing import Generator, Tuple |
|
|
|
def normalize_to(m1: torch.Tensor, norm: "torch.Tensor | float") -> tuple[torch.Tensor, float, float]:
    """
    Rescale m1 so that its L2 (Frobenius) norm equals ``norm``.

    Parameters:
    - m1 (torch.Tensor): Tensor to rescale; it is converted to float32 first.
    - norm (torch.Tensor | float): Target norm, given as a 0-d tensor or a Python number.

    Returns:
    - tuple[torch.Tensor, float, float]: (rescaled float32 tensor, target norm as a Python
      float, scale factor that was applied as a Python float).
      If m1 has zero norm it cannot be rescaled. It is then returned unchanged (as float32)
      with a scale factor of 1.0, instead of producing inf/NaN values.
    """
    m1 = m1.to(torch.float32)

    m1_norm = torch.norm(m1)

    if m1_norm == 0:
        # norm / 0 would be inf, and 0 * inf turns every element into NaN.
        return m1, float(norm), 1.0

    # as_tensor keeps tensor inputs as-is (same float32 division as before) and also
    # accepts plain Python numbers.
    ratio = (torch.as_tensor(norm) / m1_norm).item()

    m1 = m1 * ratio

    # float() works for both 0-d tensors and Python numbers, unlike .item().
    return m1, float(norm), ratio
|
|
|
def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """
    Return the ratio ||m1|| / ||m2|| of the two tensors' L2 norms as a Python float.

    Both norms and the resulting ratio are printed to stdout before returning.
    There is no guard against ||m2|| == 0; the tensor division then yields inf or nan.
    """
    m1_norm, m2_norm = (torch.norm(x) for x in (m1, m2))
    ratio = torch.div(m1_norm, m2_norm).item()
    print(f"Norms {m1_norm} {m2_norm} {ratio}")
    return ratio
|
|
|
def merge_tensors_fft2(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merges two tensors using Fourier-domain interpolation.

    1-D tensors are transformed with a 1-D FFT. Tensors with two or more dimensions are
    transformed over their last two dimensions (a 2-D FFT, batched over leading dims).
    For every frequency coefficient:
    - if the real parts of v0 and v1 have the same sign, the real AND imaginary parts are
      linearly interpolated as (1 - t) * v0 + t * v1;
    - otherwise the real part and the imaginary part are each taken from whichever input
      has the larger magnitude in that component (ties pick v1).
    The result is the real part of the inverse FFT.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - t (float): Interpolation parameter (0 <= t <= 1), the weight given to v1.
    - device (torch.device | str | None): Device to compute on. Defaults to "cuda:0" when
      CUDA is available, otherwise "cpu".

    Returns:
    - torch.Tensor: The real tensor resulting from the interpolated inverse FFT, with the
      same shape as v0, located on `device`.

    Raises:
    - ValueError: if v0 and v1 differ in shape or are 0-dimensional.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    if v0.shape != v1.shape:
        raise ValueError(
            f"v0 and v1 must have the same shape, got {tuple(v0.shape)} and {tuple(v1.shape)}"
        )
    if v0.dim() == 0:
        raise ValueError("merge_tensors_fft2 requires tensors with at least one dimension")

    v0 = v0.to(device)
    v1 = v1.to(device)

    # 1-D inputs get a plain FFT, everything else a 2-D FFT over the last two dims.
    # fftn/ifftn over dim=(-1,) is the same transform as fft/ifft.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)

    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    # Sign agreement of the REAL parts selects the blending mode for both components.
    # The imaginary part deliberately reuses this mask rather than its own signs.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()

    def _blend(c0: torch.Tensor, c1: torch.Tensor, same_sign: torch.Tensor) -> torch.Tensor:
        # Interpolate where signs agree; otherwise keep the larger-magnitude value.
        interpolated = (1 - t) * c0 + t * c1
        larger = torch.where(c0.abs() > c1.abs(), c0, c1)
        return torch.where(same_sign, interpolated, larger)

    result_fft = torch.complex(
        _blend(fft_v0.real, fft_v1.real, sign_mask),
        _blend(fft_v0.imag, fft_v1.imag, sign_mask),
    )

    # Release the large spectra before the inverse transform to limit peak memory.
    del fft_v0, fft_v1, sign_mask

    merged_tensor = torch.fft.ifftn(result_fft, dim=dims).real

    del result_fft

    return merged_tensor
|
|
|
def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """
    Build a symmetric matrix of pairwise cosine similarities between tensors[i] and tensors[j].

    For every pair i < j, torch.nn.functional.cosine_similarity is evaluated along dim 0,
    NaNs are replaced by 0, and the mean of the result is stored at both [i, j] and [j, i].
    The diagonal is left at 0.

    Parameters:
    - tensors (torch.Tensor): Stacked tensor indexed along dim 0, or any sequence of
      same-shaped tensors.
    - work_device (str): Device on which each pair's similarity is computed.
    - store_device (str): Device of the returned matrix.

    Returns:
    - torch.Tensor: An (n, n) matrix on store_device, where n = len(tensors).
    """
    # len() equals shape[0] for tensors and also accepts plain lists/tuples of tensors.
    n = len(tensors)

    matrix = torch.zeros(n, n, device=store_device)

    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            matrix[i, j] = matrix[j, i] = similarity
        # The original ended each iteration with discarded `.to(store_device)` calls.
        # Tensor.to is not in-place, so they only caused wasted device-to-host copies.

    return matrix
|
|
|
def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Greedily pair up indices of a correlation matrix, weakest |correlation| first.

    Only the strict upper triangle (x < y) is considered. On each step the pair with the
    smallest ABSOLUTE correlation among indices not yet used is yielded. Afterwards every
    entry in rows/columns x and y is excluded, so each index appears in at most one tuple.
    Ties resolve to the first pair in row-major order. With an odd n one index is left
    unpaired. If the remaining candidates contain NaN, torch.min returns NaN, nothing
    compares equal to it, and iteration stops.

    Args:

        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.
            It may live on any device.

    Yields:

        Tuple[int, int, float]: (x, y, coefficient), where coefficient is the signed value
        correlation_tensor[x, y].
    """
    n = correlation_tensor.size(0)
    device = correlation_tensor.device

    # Build the mask and the +inf sentinel on the matrix's own device.
    # A CPU-only mask makes torch.where fail for CUDA inputs.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=device), diagonal=1)
    inf = torch.tensor(float('inf'), device=device)

    while torch.any(mask):

        # Excluded (already used / lower-triangle) positions become +inf.
        valid_correlation = torch.where(mask, correlation_tensor, inf)

        candidates = valid_correlation[valid_correlation != float('inf')]
        if candidates.numel() == 0:
            # Every remaining entry is +inf; torch.min would raise on an empty tensor.
            break

        min_val = torch.min(torch.abs(candidates))

        min_indices = torch.nonzero(torch.abs(valid_correlation) == min_val, as_tuple=True)
        if len(min_indices[0]) == 0:
            break

        x, y = min_indices[0][0].item(), min_indices[1][0].item()
        coefficient = correlation_tensor[x, y].item()
        yield (x, y, coefficient)

        # Retire both indices so neither can be paired again.
        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False
|
|
|
|
|
def merge_tensors_fft2_autoscale(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    device: "torch.device | str | None" = None,
) -> tuple[torch.Tensor, float, float]:
    """
    Normalise two tensors to unit L2 norm, then merge them using Fourier-domain interpolation.

    A tensor whose norm is 0 cannot be normalised and is merged as-is.
    1-D tensors are transformed with a 1-D FFT. Tensors with two or more dimensions are
    transformed over their last two dimensions (a 2-D FFT, batched over leading dims).
    For every frequency coefficient:
    - if the real parts have the same sign, the real AND imaginary parts are linearly
      interpolated as (1 - t) * v0 + t * v1;
    - otherwise the real part and the imaginary part are each taken from whichever input
      has the larger magnitude in that component (ties pick v1).
    The merged tensor is the real part of the inverse FFT.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - t (float): Interpolation parameter (0 <= t <= 1), the weight given to v1.
    - device (torch.device | str | None): Device to compute on. Defaults to "cuda:0" when
      CUDA is available, otherwise "cpu".

    Returns:
    - tuple[torch.Tensor, float, float]: (merged, norm_v0, norm_v1).
      merged is computed from the normalised inputs and is NOT rescaled; it has v0's shape
      and lives on `device`. norm_v0 and norm_v1 are the original L2 norms as Python floats.

    Raises:
    - ValueError: if v0 and v1 differ in shape or are 0-dimensional.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    if v0.shape != v1.shape:
        raise ValueError(
            f"v0 and v1 must have the same shape, got {tuple(v0.shape)} and {tuple(v1.shape)}"
        )
    if v0.dim() == 0:
        raise ValueError("merge_tensors_fft2_autoscale requires tensors with at least one dimension")

    v0 = v0.to(device)
    v1 = v1.to(device)

    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()

    # Dividing by a zero norm would produce NaNs, so zero tensors stay unnormalised.
    if norm_v0_t != 0:
        v0 = v0 / norm_v0_t
    if norm_v1_t != 0:
        v1 = v1 / norm_v1_t

    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    del norm_v0_t, norm_v1_t

    # 1-D inputs get a plain FFT, everything else a 2-D FFT over the last two dims.
    # fftn/ifftn over dim=(-1,) is the same transform as fft/ifft.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)

    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    # Sign agreement of the REAL parts selects the blending mode for both components.
    # The imaginary part deliberately reuses this mask rather than its own signs.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()

    def _blend(c0: torch.Tensor, c1: torch.Tensor, same_sign: torch.Tensor) -> torch.Tensor:
        # Interpolate where signs agree; otherwise keep the larger-magnitude value.
        interpolated = (1 - t) * c0 + t * c1
        larger = torch.where(c0.abs() > c1.abs(), c0, c1)
        return torch.where(same_sign, interpolated, larger)

    result_fft = torch.complex(
        _blend(fft_v0.real, fft_v1.real, sign_mask),
        _blend(fft_v0.imag, fft_v1.imag, sign_mask),
    )

    # Release the large spectra before the inverse transform to limit peak memory.
    del fft_v0, fft_v1, sign_mask

    merged_tensor = torch.fft.ifftn(result_fft, dim=dims).real

    del result_fft

    return merged_tensor, norm_v0, norm_v1
|
|