maldv's picture
Upload folder using huggingface_hub
b59223f verified
raw
history blame
No virus
10.3 kB
# ztrain/tensors.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
import torch
from typing import Generator, Tuple
def normalize_to(m1: torch.Tensor, norm: "torch.Tensor | float") -> tuple[torch.Tensor, float, float]:
    """
    Rescale a tensor so that its norm equals a target value.

    Parameters:
    - m1 (torch.Tensor): The tensor to rescale. It is converted to float32 first.
    - norm (torch.Tensor | float): The target norm, either a single-element
      tensor or a plain Python number.

    Returns:
    - tuple[torch.Tensor, float, float]: The rescaled float32 tensor, the target
      norm as a Python float, and the scale factor that was applied.

    Note: if ``m1`` has zero norm the scale factor is infinite and the zero
    entries of the result become NaN; callers should guard against that case.
    """
    m1 = m1.to(torch.float32)
    m1_norm = torch.norm(m1)
    # Dividing by the 0-dim norm tensor yields a tensor whether ``norm`` is a
    # tensor or a Python number, so ``.item()`` is valid in both cases.
    ratio = (norm / m1_norm).item()
    # Only tensors have ``.item()``; plain numbers are passed through as float.
    target = norm.item() if isinstance(norm, torch.Tensor) else float(norm)
    return m1 * ratio, target, ratio
def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """
    Report how large the norm of ``m1`` is relative to the norm of ``m2``.

    Both norms and their quotient are printed to stdout before returning.

    Parameters:
    - m1 (torch.Tensor): Tensor whose norm is the numerator.
    - m2 (torch.Tensor): Tensor whose norm is the denominator.

    Returns:
    - float: ``norm(m1) / norm(m2)`` as a Python float.
    """
    m1_norm, m2_norm = m1.norm(), m2.norm()
    ratio = m1_norm.div(m2_norm).item()
    print(f"Norms {m1_norm} {m2_norm} {ratio}")
    return ratio
def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """
    Merges two tensors by blending their Fourier coefficients.

    A 1-D input is transformed along its only axis; any other input is
    transformed over its last two axes. Each coefficient is then merged:
    where the real parts of both spectra have the same sign, the real and
    imaginary parts are linearly interpolated; elsewhere the real part with
    the larger magnitude and, independently, the imaginary part with the
    larger magnitude are kept (``v1`` wins ties). Sign agreement is decided
    from the real parts only and reused for the imaginary parts.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor, same shape as ``v0``.
    - t (float): Interpolation parameter (0 <= t <= 1); the weight given to ``v1``.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged spectrum,
      located on the device the computation ran on.
    """
    # Use the first GPU when there is one; otherwise compute on v0's device
    # instead of failing on hosts without CUDA.
    device = torch.device("cuda:0") if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # One code path for both ranks: only the transformed axes differ.
    dims = (-1,) if len(v0.shape) == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)
    del v0, v1

    same_sign = fft_v0.real.sign() == fft_v1.real.sign()

    def blend(part0: torch.Tensor, part1: torch.Tensor) -> torch.Tensor:
        # Interpolate where the spectra agree in sign, otherwise keep the
        # component with the larger magnitude.
        interpolated = (1 - t) * part0 + t * part1
        dominant = torch.where(part0.abs() > part1.abs(), part0, part1)
        return torch.where(same_sign, interpolated, dominant)

    result_fft = torch.empty_like(fft_v0)
    result_fft.real.copy_(blend(fft_v0.real, fft_v1.real))
    result_fft.imag.copy_(blend(fft_v0.imag, fft_v1.imag))
    # Release the input spectra before the inverse transform allocates.
    del fft_v0, fft_v1

    # The merged spectrum is generally not conjugate-symmetric, so the inverse
    # transform is complex; only its real part is returned.
    return torch.fft.ifftn(result_fft, dim=dims).real
def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """
    Builds a symmetric matrix of pairwise cosine similarities.

    ``tensors`` is indexed along its first axis. For every pair of entries the
    cosine similarity is taken along each entry's first axis, NaNs are replaced
    by 0, and the mean of what remains is stored at ``[i, j]`` and ``[j, i]``.
    The diagonal is left at 0.

    Parameters:
    - tensors (torch.Tensor): Stacked entries; ``tensors.shape[0]`` is the count.
    - work_device (str): Device on which each similarity is computed.
    - store_device (str): Device on which the result matrix is allocated.

    Returns:
    - torch.Tensor: An ``n x n`` matrix on ``store_device``.
    """
    n = tensors.shape[0]
    matrix = torch.zeros(n, n, device=store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            matrix[i, j] = similarity
            matrix[j, i] = similarity
            # Tensor.to() returns a new tensor rather than moving in place, so
            # dropping the reference is what actually frees the working copy.
            del b
        del a
    return matrix
def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Greedily yields disjoint index pairs in order of increasing absolute correlation.

    On each step the pair ``(x, y)`` with ``x < y`` that has the smallest
    absolute correlation among the still-available indices is yielded, and both
    indices are then removed from further consideration. Ties are resolved in
    row-major order. Iteration stops when no pair is left, when every remaining
    candidate is infinite, or when a NaN is among the remaining candidates.

    Args:
        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.

    Yields:
        Tuple[int, int, float]: The x-index, the y-index, and the signed correlation
        coefficient stored at ``[x, y]``.
    """
    n = correlation_tensor.size(0)
    device = correlation_tensor.device
    # Strict upper triangle: every unordered pair once, diagonal excluded.
    # Built on the input's device so the masked select below cannot mix devices.
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=device), diagonal=1)
    excluded = torch.tensor(float('inf'), device=device)

    while torch.any(mask):
        # Unavailable positions become +inf so they can never be the minimum.
        magnitudes = torch.where(mask, correlation_tensor.abs(), excluded)
        selectable = magnitudes[magnitudes != float('inf')]
        if selectable.numel() == 0:
            # Only infinite correlations remain; nothing meaningful to pair.
            break
        min_val = torch.min(selectable)

        # A NaN minimum compares unequal to everything, leaving this empty.
        min_indices = torch.nonzero(magnitudes == min_val, as_tuple=True)
        if len(min_indices[0]) == 0:
            break

        # Take the first match (greedy) and report the signed coefficient.
        x, y = min_indices[0][0].item(), min_indices[1][0].item()
        coefficient = correlation_tensor[x, y].item()
        yield (x, y, coefficient)

        # Retire both indices: clear their rows and columns.
        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False
def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float) -> tuple[torch.Tensor, float, float]:
    """
    Merges two norm-normalised tensors by blending their Fourier coefficients.

    Each input is first divided by its own norm (inputs with zero norm are left
    unchanged). A 1-D input is transformed along its only axis; any other input
    is transformed over its last two axes. Each coefficient is then merged:
    where the real parts of both spectra have the same sign, the real and
    imaginary parts are linearly interpolated; elsewhere the real part with
    the larger magnitude and, independently, the imaginary part with the
    larger magnitude are kept (``v1`` wins ties). Sign agreement is decided
    from the real parts only and reused for the imaginary parts.

    The merged tensor is returned in the normalised scale; it is NOT multiplied
    back by either norm. The original norms are returned so the caller can
    rescale.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor, same shape as ``v0``.
    - t (float): Interpolation parameter (0 <= t <= 1); the weight given to ``v1``.

    Returns:
    - tuple[torch.Tensor, float, float]: The real part of the inverse FFT of the
      merged spectrum (on the device the computation ran on), the original norm
      of ``v0``, and the original norm of ``v1``.
    """
    # Use the first GPU when there is one; otherwise compute on v0's device
    # instead of failing on hosts without CUDA.
    device = torch.device("cuda:0") if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # Record the original norms, then bring both inputs to unit norm.
    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()
    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    if norm_v0 != 0:
        v0 = v0 / norm_v0_t
    if norm_v1 != 0:
        v1 = v1 / norm_v1_t
    del norm_v0_t, norm_v1_t

    # One code path for both ranks: only the transformed axes differ.
    dims = (-1,) if len(v0.shape) == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)
    del v0, v1

    same_sign = fft_v0.real.sign() == fft_v1.real.sign()

    def blend(part0: torch.Tensor, part1: torch.Tensor) -> torch.Tensor:
        # Interpolate where the spectra agree in sign, otherwise keep the
        # component with the larger magnitude.
        interpolated = (1 - t) * part0 + t * part1
        dominant = torch.where(part0.abs() > part1.abs(), part0, part1)
        return torch.where(same_sign, interpolated, dominant)

    result_fft = torch.empty_like(fft_v0)
    result_fft.real.copy_(blend(fft_v0.real, fft_v1.real))
    result_fft.imag.copy_(blend(fft_v0.imag, fft_v1.imag))
    # Release the input spectra before the inverse transform allocates.
    del fft_v0, fft_v1

    # The merged spectrum is generally not conjugate-symmetric, so the inverse
    # transform is complex; only its real part is returned.
    merged_tensor = torch.fft.ifftn(result_fft, dim=dims).real
    return merged_tensor, norm_v0, norm_v1