File size: 6,454 Bytes
24e35df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import torch
def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """
    Merges two tensors in the Fourier domain and transforms the result back.

    Both tensors are transformed (along the only axis for 1-D input, along the
    last two axes otherwise) and their coefficients are combined element-wise.
    The decision mask is derived from the signs of the REAL parts only:

    - Real part, signs equal: the element-wise maximum of the two values.
    - Real part, signs differ: the value with the larger magnitude (v1 on ties).
    - Imaginary part, real signs equal: linear interpolation
      ``(1 - t) * imag_v0 + t * imag_v1``.
    - Imaginary part, real signs differ: the value with the larger magnitude
      (v1 on ties).

    Note that ``t`` only influences the imaginary part of the coefficients.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor.
    - t (float): Interpolation parameter (0 <= t <= 1).
    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged
      coefficients. It lives on "cuda:0" when CUDA is available, otherwise on
      the device of ``v0``.
    """
    # Bring both tensors onto one device. Prefer the first GPU, but fall back
    # to v0's current device so the function also works on CPU-only hosts.
    device = torch.device("cuda:0") if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # A 1-D tensor only has one axis to transform; everything else is treated
    # as a (batch of) 2-D signal(s) over its last two axes.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    real_v0, real_v1 = fft_v0.real, fft_v1.real
    imag_v0, imag_v1 = fft_v0.imag, fft_v1.imag

    # True where the real parts of both spectra share the same sign.
    sign_mask = real_v0.sign() == real_v1.sign()

    merged_real = torch.where(
        sign_mask,
        torch.max(real_v0, real_v1),
        torch.where(real_v0.abs() > real_v1.abs(), real_v0, real_v1),
    )
    merged_imag = torch.where(
        sign_mask,
        (1 - t) * imag_v0 + t * imag_v1,
        torch.where(imag_v0.abs() > imag_v1.abs(), imag_v0, imag_v1),
    )

    # Write into a buffer shaped/typed like fft_v0 so the result keeps the
    # complex dtype of v0's spectrum.
    result_fft = torch.empty_like(fft_v0)
    result_fft.real.copy_(merged_real)
    result_fft.imag.copy_(merged_imag)

    # Back to the signal domain; the imaginary residue is discarded.
    return torch.fft.ifftn(result_fft, dim=dims).real
def merge_tensors_fft_shell(v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
"""
Merges two tensors using 2D Fourier transform interpolation.
Parameters:
- v0 (torch.Tensor): The first input tensor.
- v1 (torch.Tensor): The second input tensor.
Returns:
- torch.Tensor: The tensor resulting from the maximal interpolated inverse FFT.
"""
# Ensure the input tensors are on the same device and dtype
v0 = v0.to("cuda:0")
v1 = v1.to("cuda:0")
if len(v0.shape) == 1:
fft_v0 = torch.fft.fft(v0)
fft_v1 = torch.fft.fft(v1)
result_fft = torch.zeros_like(fft_v0)
real_v0 = fft_v0.real
real_v1 = fft_v1.real
abs_real_v0 = real_v0.abs()
abs_real_v1 = real_v1.abs()
sign_mask = real_v0.sign() == real_v1.sign()
larger_values_mask = abs_real_v0 > abs_real_v1
result_fft.real[sign_mask] = torch.max(real_v0[sign_mask], real_v1[sign_mask])
result_fft.real[~sign_mask] = torch.where(larger_values_mask[~sign_mask], real_v0[~sign_mask], real_v1[~sign_mask])
imag_v0 = fft_v0.imag
imag_v1 = fft_v1.imag
abs_imag_v0 = imag_v0.abs()
abs_imag_v1 = imag_v1.abs()
larger_values_mask_imag = abs_imag_v0 > abs_imag_v1
result_fft.imag[sign_mask] = torch.max(imag_v0[sign_mask], imag_v1[sign_mask])
result_fft.imag[~sign_mask] = torch.where(larger_values_mask_imag[~sign_mask], imag_v0[~sign_mask], imag_v1[~sign_mask])
merged_tensor = torch.fft.ifft(result_fft).real # Taking the real part
del v0, v1, fft_v0, fft_v1, result_fft
return merged_tensor
# Perform the 2D FFT on both tensors
fft_v0 = torch.fft.fftn(v0, dim=(-2, -1))
fft_v1 = torch.fft.fftn(v1, dim=(-2, -1))
# Initialize the result FFT tensor
result_fft = torch.zeros_like(fft_v0)
# Compare real parts of the coefficients
real_v0 = fft_v0.real
real_v1 = fft_v1.real
abs_real_v0 = real_v0.abs()
abs_real_v1 = real_v1.abs()
# Create masks for where signs match and where they do not
sign_mask = real_v0.sign() == real_v1.sign()
larger_values_mask = abs_real_v0 > abs_real_v1
# Where signs match, interpolate; where signs do not match, take the larger by magnitude
result_fft.real[sign_mask] = torch.max(real_v0[sign_mask], real_v1[sign_mask])
result_fft.real[~sign_mask] = torch.where(larger_values_mask[~sign_mask], real_v0[~sign_mask], real_v1[~sign_mask])
# Assuming the imaginary part should be treated similarly, adjust this if not
imag_v0 = fft_v0.imag
imag_v1 = fft_v1.imag
abs_imag_v0 = imag_v0.abs()
abs_imag_v1 = imag_v1.abs()
larger_values_mask_imag = abs_imag_v0 > abs_imag_v1
result_fft.imag[sign_mask] = torch.max(imag_v0[sign_mask], imag_v1[sign_mask])
result_fft.imag[~sign_mask] = torch.where(larger_values_mask_imag[~sign_mask], imag_v0[~sign_mask], imag_v1[~sign_mask])
# Perform the inverse FFT to go back to the spatial domain
merged_tensor = torch.fft.ifftn(result_fft, dim=(-2, -1)).real # Taking the real part
return merged_tensor
|