# FreeTraj / utils / utils_freetraj.py
# Anonymous
# init
# 2a50f45
import torch
import torch.fft as fft
import math
def get_longpath(BOX_SIZE_H=0.3, BOX_SIZE_W=0.3, input_mode=4):
    """Build a per-frame box trajectory for a long (64-frame) preset path.

    Every preset is a list of keyframes ``[frame, h_start, h_end, w_start,
    w_end]`` covering frames 0..63; the keyframes are handed to ``plan_path``,
    which linearly interpolates the frames in between.

    Args:
        BOX_SIZE_H: box height as a fraction of the frame height.
        BOX_SIZE_W: box width as a fraction of the frame width.
        input_mode: preset selector.
            1 and 2 are multi-segment tours of the frame,
            3 ("||||") alternates the box between the two h extremes while
              w advances in sevenths of the free width,
            4 ("----") alternates the box between the two w extremes while
              h advances in sevenths of the free height.

    Returns:
        Whatever ``plan_path`` returns for the selected keyframes: a list of
        ``[h_start, h_end, w_start, w_end]`` boxes, one per frame.

    Raises:
        ValueError: if ``input_mode`` is not one of 1, 2, 3, 4.
    """
    h, w = BOX_SIZE_H, BOX_SIZE_W

    if input_mode == 1:
        # mode 1
        inputs = [
            [0, 0, 0 + h, 0, 0 + w],
            [7, 1 - h, 1, (1 - w) / 15 * 7, (1 - w) / 15 * 7 + w],
            [8, 1 - h, 1, (1 - w) / 15 * 8, (1 - w) / 15 * 8 + w],
            [15, 0, 0 + h, 1 - w, 1],
            [16, 0.1, 0.1 + h, 0.9 - w, 0.9],
            [25, 0.1, 0.1 + h, 0.1, 0.1 + w],
            [31, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [32, 1 - h, 1, 0, 0 + w],
            [39, 0, 0 + h, (1 - w) / 15 * 7, (1 - w) / 15 * 7 + w],
            [40, 0, 0 + h, (1 - w) / 15 * 8, (1 - w) / 15 * 8 + w],
            [47, 1 - h, 1, 1 - w, 1],
            [48, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [57, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [63, 0.1, 0.1 + h, 0.1, 0.1 + w],
        ]
    elif input_mode == 2:
        # mode 2
        inputs = [
            [0, 0.1, 0.1 + h, 0.1, 0.1 + w],
            [6, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [15, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [16, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [22, 0.1, 0.1 + h, 0.9 - w, 0.9],
            [31, 0.1, 0.1 + h, 0.1, 0.1 + w],
            [32, 0.1, 0.1 + h, 0.1, 0.1 + w],
            [41, 0.1, 0.1 + h, 0.9 - w, 0.9],
            [47, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [48, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [57, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [63, 0.1, 0.1 + h, 0.1, 0.1 + w],
        ]
    elif input_mode == 3:
        # mode 3 ||||
        inputs = [
            [0, 0, 0 + h, 0, 0 + w],
            [9, 1 - h, 1, (1 - w) / 7 * 1, (1 - w) / 7 * 1 + w],
            [18, 0, 0 + h, (1 - w) / 7 * 2, (1 - w) / 7 * 2 + w],
            [27, 1 - h, 1, (1 - w) / 7 * 3, (1 - w) / 7 * 3 + w],
            [36, 0, 0 + h, (1 - w) / 7 * 4, (1 - w) / 7 * 4 + w],
            [45, 1 - h, 1, (1 - w) / 7 * 5, (1 - w) / 7 * 5 + w],
            [54, 0, 0 + h, (1 - w) / 7 * 6, (1 - w) / 7 * 6 + w],
            [63, 1 - h, 1, 1 - w, 1],
        ]
    elif input_mode == 4:
        # mode 4 ----
        inputs = [
            [0, 0, 0 + h, 0, 0 + w],
            [9, (1 - h) / 7 * 1, (1 - h) / 7 * 1 + h, 1 - w, 1],
            [18, (1 - h) / 7 * 2, (1 - h) / 7 * 2 + h, 0, 0 + w],
            [27, (1 - h) / 7 * 3, (1 - h) / 7 * 3 + h, 1 - w, 1],
            [36, (1 - h) / 7 * 4, (1 - h) / 7 * 4 + h, 0, 0 + w],
            [45, (1 - h) / 7 * 5, (1 - h) / 7 * 5 + h, 1 - w, 1],
            [54, (1 - h) / 7 * 6, (1 - h) / 7 * 6 + h, 0, 0 + w],
            [63, 1 - h, 1, 1 - w, 1],
        ]
    else:
        # Raise instead of print + exit(): a library helper must not
        # terminate the interpreter, and exit() is only a `site` convenience.
        raise ValueError(
            f"input_mode must be 1, 2, 3 or 4, got {input_mode!r}"
        )

    return plan_path(inputs)
def get_path(BOX_SIZE_H=0.3, BOX_SIZE_W=0.3, input_mode=0):
    """Build a per-frame box trajectory for a short (16-frame) preset path.

    Every preset is a list of keyframes ``[frame, h_start, h_end, w_start,
    w_end]`` covering frames 0..15; the keyframes are handed to ``plan_path``,
    which linearly interpolates the frames in between.

    Args:
        BOX_SIZE_H: box height as a fraction of the frame height.
        BOX_SIZE_W: box width as a fraction of the frame width.
        input_mode: preset selector.
            0: diagonal, h and w both go from 0 to 1.
            1: reverse diagonal, h goes 0 -> 1 while w goes 1 -> 0.
            2: "L", h moves first (frames 0-6), then w (frames 6-15).
            3: reverse "L", the mode-2 corners visited from the opposite end.
            4: "V", h goes to the far edge and back while w sweeps 0 -> 1.
            5: reverse "V", mirrored in both h and w.
            6: go-and-return along w at a fixed h offset of 0.35.
            7: triangle through three corners, ending where it started.

    Returns:
        Whatever ``plan_path`` returns for the selected keyframes: a list of
        ``[h_start, h_end, w_start, w_end]`` boxes, one per frame.

    Raises:
        ValueError: if ``input_mode`` is not an integer in 0..7.
    """
    h, w = BOX_SIZE_H, BOX_SIZE_W

    if input_mode == 0:
        # diagonal
        inputs = [
            [0, 0, 0 + h, 0, 0 + w],
            [15, 1 - h, 1, 1 - w, 1],
        ]
    elif input_mode == 1:
        # reverse diagonal
        inputs = [
            [0, 0, 0 + h, 1 - w, 1],
            [15, 1 - h, 1, 0, 0 + w],
        ]
    elif input_mode == 2:
        # L
        inputs = [
            [0, 0.1, 0.1 + h, 0.1, 0.1 + w],
            [6, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [15, 0.9 - h, 0.9, 0.9 - w, 0.9],
        ]
    elif input_mode == 3:
        # reverse L
        inputs = [
            [0, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [6, 0.1, 0.1 + h, 0.9 - w, 0.9],
            [15, 0.1, 0.1 + h, 0.1, 0.1 + w],
        ]
    elif input_mode == 4:
        # V
        inputs = [
            [0, 0, 0 + h, 0, 0 + w],
            [7, 1 - h, 1, (1 - w) / 15 * 7, (1 - w) / 15 * 7 + w],
            [8, 1 - h, 1, (1 - w) / 15 * 8, (1 - w) / 15 * 8 + w],
            [15, 0, 0 + h, 1 - w, 1],
        ]
    elif input_mode == 5:
        # reverse V
        inputs = [
            [0, 1 - h, 1, 1 - w, 1],
            [7, 0, 0 + h, (1 - w) / 15 * 8, (1 - w) / 15 * 8 + w],
            [8, 0, 0 + h, (1 - w) / 15 * 7, (1 - w) / 15 * 7 + w],
            [15, 1 - h, 1, 0, 0 + w],
        ]
    elif input_mode == 6:
        # go and come back along w
        inputs = [
            [0, 0.35, 0.35 + h, 0.1, 0.1 + w],
            [7, 0.35, 0.35 + h, 0.9 - w, 0.9],
            [8, 0.35, 0.35 + h, 0.9 - w, 0.9],
            [15, 0.35, 0.35 + h, 0.1, 0.1 + w],
        ]
    elif input_mode == 7:
        # triangle
        inputs = [
            [0, 0.1, 0.1 + h, 0.35, 0.35 + w],
            [5, 0.9 - h, 0.9, 0.9 - w, 0.9],
            [10, 0.9 - h, 0.9, 0.1, 0.1 + w],
            [15, 0.1, 0.1 + h, 0.35, 0.35 + w],
        ]
    else:
        # Without this branch an unknown mode left `inputs` unbound and the
        # call below failed with an unrelated-looking UnboundLocalError.
        raise ValueError(
            f"input_mode must be an integer in 0..7, got {input_mode!r}"
        )

    return plan_path(inputs)
# input: List([frame, h_start, h_end, w_start, w_end], ...)
# return: List([h_start, h_end, w_start, w_end], ...)
def plan_path(input, video_length=16):
    """Expand keyframe boxes into one box per frame by linear interpolation.

    Args:
        input: keyframes ``[frame, h_start, h_end, w_start, w_end]``, ordered
            by increasing ``frame``. Two consecutive keyframes on the same
            frame raise ``ZeroDivisionError``.
        video_length: number of frames the path is expected to cover. Only
            used for padding: if the last keyframe is before frame
            ``video_length - 1`` the path is extended; a path that is already
            longer is returned untrimmed.

    Returns:
        List of ``[h_start, h_end, w_start, w_end]`` boxes. Frames before the
        first keyframe and after the last one are linearly extrapolated from
        the nearest pair of planned boxes (no clamping is applied). With a
        single keyframe there is no motion to extrapolate, so the box is held
        in place.
    """
    path = [input[0][1:]]

    # Interpolate every segment between consecutive keyframes.
    for start, end in zip(input, input[1:]):
        start_frame, end_frame = start[0], end[0]
        span = end_frame - start_frame
        h_start_change = (end[1] - start[1]) / span
        h_end_change = (end[2] - start[2]) / span
        w_start_change = (end[3] - start[3]) / span
        w_end_change = (end[4] - start[4]) / span
        for frame in range(start_frame + 1, end_frame + 1):
            increase_frame = frame - start_frame
            path.append([
                increase_frame * h_start_change + start[1],
                increase_frame * h_end_change + start[2],
                increase_frame * w_start_change + start[3],
                increase_frame * w_end_change + start[4],
            ])

    # Pad backwards when the first keyframe is not on frame 0.
    if input[0][0] > 0:
        if len(path) > 1:
            h_change = path[1][0] - path[0][0]
            w_change = path[1][2] - path[0][2]
        else:
            h_change = w_change = 0
        for _ in range(input[0][0]):
            first = path[0]
            # Insert ONE box (a nested list); concatenating the bare values
            # would splice four floats into the path.
            path.insert(0, [
                first[0] - h_change,
                first[1] - h_change,
                first[2] - w_change,
                first[3] - w_change,
            ])

    # Pad forwards when the last keyframe is before the final frame.
    if input[-1][0] < video_length - 1:
        if len(path) > 1:
            h_change = path[-1][0] - path[-2][0]
            w_change = path[-1][2] - path[-2][2]
        else:
            h_change = w_change = 0
        for _ in range(video_length - 1 - input[-1][0]):
            last = path[-1]
            path.append([
                last[0] + h_change,
                last[1] + h_change,
                last[2] + w_change,
                last[3] + w_change,
            ])

    return path
def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """Evaluate an axis-aligned 2D Gaussian and rescale it so its peak is 1.

    Args:
        x, y: coordinates to evaluate at. At least one of them has to be a
            tensor, because ``torch.exp`` is applied to the exponent.
        mx, my: mean along x / y.
        sx, sy: standard deviation along x / y.

    Returns:
        Tensor of Gaussian weights divided by their own maximum.
    """
    exponent = (x - mx) ** 2 / (2 * sx**2) + (y - my) ** 2 / (2 * sy**2)
    gaussian_map = 1 / (2 * math.pi * sx * sy) * torch.exp(-exponent)
    # Normalise in place so the largest sampled weight is exactly 1.
    peak = gaussian_map.max()
    gaussian_map.div_(peak)
    return gaussian_map
def gaussian_weight(height=32, width=32, KERNEL_DIVISION=3.0):
    """Return a (height, width) half-precision Gaussian weight patch.

    The patch is sampled on ``torch.linspace(0, size, size)`` per axis (both
    endpoints included), centred at ``int(size / 2)`` with a standard
    deviation of ``size / KERNEL_DIVISION``, and normalised by ``gaussian_2d``.

    Args:
        height: number of rows of the patch.
        width: number of columns of the patch.
        KERNEL_DIVISION: divisor applied to each side length to obtain the
            standard deviation; larger values give a narrower bump.

    Returns:
        ``torch.float16`` tensor of shape ``(height, width)``.
    """
    rows = torch.linspace(0, height, height)
    cols = torch.linspace(0, width, width)
    grid_x, grid_y = torch.meshgrid(rows, cols, indexing="ij")

    center_x, center_y = int(height / 2), int(width / 2)
    sigma_x = float(height / KERNEL_DIVISION)
    sigma_y = float(width / KERNEL_DIVISION)

    weights = gaussian_2d(
        grid_x, grid_y, mx=center_x, my=center_y, sx=sigma_x, sy=sigma_y
    )
    return weights.half()
def freq_mix_3d(x, noise, LPF):
    """
    Noise reinitialization: keep the low frequencies of ``x`` and take the
    high frequencies from ``noise``.

    Args:
        x: diffused latent
        noise: randomly sampled noise
        LPF: low pass filter, multiplied onto the centred (fft-shifted)
            spectrum over the last three dimensions

    Returns:
        Real part of the inverse FFT of the mixed spectrum.
    """
    axes = (-3, -2, -1)

    # Forward FFT, shifted so the zero frequency sits in the middle.
    x_spectrum = fft.fftshift(fft.fftn(x, dim=axes), dim=axes)
    noise_spectrum = fft.fftshift(fft.fftn(noise, dim=axes), dim=axes)

    # Complementary high-pass filter; blend the two spectra.
    HPF = 1 - LPF
    mixed_spectrum = x_spectrum * LPF + noise_spectrum * HPF

    # Undo the shift and return to the signal domain.
    mixed_spectrum = fft.ifftshift(mixed_spectrum, dim=axes)
    return fft.ifftn(mixed_spectrum, dim=axes).real
def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
    """
    Build the frequency filter for noise reinitialization and move it to
    ``device``.

    Args:
        shape: shape of latent (B, C, T, H, W)
        device: target device passed to ``Tensor.to``
        filter_type: one of "gaussian", "ideal", "box", "butterworth"
        n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Raises:
        NotImplementedError: for any other ``filter_type``.
    """
    if filter_type == "gaussian":
        mask = gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "ideal":
        mask = ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "box":
        mask = box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t)
    elif filter_type == "butterworth":
        mask = butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t)
    else:
        raise NotImplementedError
    return mask.to(device)
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the gaussian low pass filter mask.

    Each of the last three axes (T, H, W) is mapped to the normalized
    coordinate ``2 * i / N - 1``; the temporal coordinate is additionally
    scaled by ``d_s / d_t`` so a single spatial sigma ``d_s`` can be used.
    The mask value is ``exp(-d^2 / (2 * d_s^2))`` with ``d^2`` the sum of the
    squared coordinates.

    Args:
        shape: shape of the filter (volume); the last three entries are read
            as (T, H, W) and the mask is broadcast over any leading entries
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the requested shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Per-axis coordinates in float64, mirroring the double-precision
    # arithmetic of a per-element Python loop; the result is cast to the
    # mask dtype on copy.
    t_coord = (d_s / d_t) * (2 * torch.arange(T, dtype=torch.float64) / T - 1)
    h_coord = 2 * torch.arange(H, dtype=torch.float64) / H - 1
    w_coord = 2 * torch.arange(W, dtype=torch.float64) / W - 1

    # Broadcast the three axes into a (T, H, W) grid of squared distances.
    d_square = (
        t_coord[:, None, None] ** 2
        + h_coord[None, :, None] ** 2
        + w_coord[None, None, :] ** 2
    )

    # copy_ broadcasts the (T, H, W) grid over the leading dimensions.
    mask.copy_(torch.exp(-1 / (2 * d_s**2) * d_square))
    return mask
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
    """
    Compute the butterworth low pass filter mask.

    Each of the last three axes (T, H, W) is mapped to the normalized
    coordinate ``2 * i / N - 1``; the temporal coordinate is additionally
    scaled by ``d_s / d_t``. The mask value is
    ``1 / (1 + (d^2 / d_s^2)^n)`` with ``d^2`` the sum of the squared
    coordinates, i.e. 0.5 exactly on the cutoff ``d = d_s``.

    Args:
        shape: shape of the filter (volume); the last three entries are read
            as (T, H, W) and the mask is broadcast over any leading entries
        n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the requested shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Per-axis coordinates in float64, mirroring the double-precision
    # arithmetic of a per-element Python loop; the result is cast to the
    # mask dtype on copy.
    t_coord = (d_s / d_t) * (2 * torch.arange(T, dtype=torch.float64) / T - 1)
    h_coord = 2 * torch.arange(H, dtype=torch.float64) / H - 1
    w_coord = 2 * torch.arange(W, dtype=torch.float64) / W - 1

    # Broadcast the three axes into a (T, H, W) grid of squared distances.
    d_square = (
        t_coord[:, None, None] ** 2
        + h_coord[None, :, None] ** 2
        + w_coord[None, None, :] ** 2
    )

    # copy_ broadcasts the (T, H, W) grid over the leading dimensions.
    mask.copy_(1 / (1 + (d_square / d_s**2) ** n))
    return mask
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask.

    Each of the last three axes (T, H, W) is mapped to the normalized
    coordinate ``2 * i / N - 1``; the temporal coordinate is additionally
    scaled by ``d_s / d_t``. The mask is 1 where the distance from the
    centre is at most ``d_s`` and 0 elsewhere.

    Args:
        shape: shape of the filter (volume); the last three entries are read
            as (T, H, W) and the mask is broadcast over any leading entries
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the requested shape holding 0/1 values; all zeros if
        ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Per-axis coordinates in float64, mirroring the double-precision
    # arithmetic of a per-element Python loop.
    t_coord = (d_s / d_t) * (2 * torch.arange(T, dtype=torch.float64) / T - 1)
    h_coord = 2 * torch.arange(H, dtype=torch.float64) / H - 1
    w_coord = 2 * torch.arange(W, dtype=torch.float64) / W - 1

    # Broadcast the three axes into a (T, H, W) grid of squared distances.
    d_square = (
        t_coord[:, None, None] ** 2
        + h_coord[None, :, None] ** 2
        + w_coord[None, None, :] ** 2
    )

    # d_square is a SQUARED distance, so the cutoff is the squared radius
    # d_s**2 (the same normalisation the gaussian and butterworth filters
    # use), not d_s*2.
    mask.copy_(d_square <= d_s**2)
    return mask
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
    """
    Compute the ideal low pass filter mask (approximated version): a box of
    ones centred in the last three dimensions.

    Args:
        shape: shape of the filter (volume); the last three entries are read
            as (T, H, W)
        d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
        d_t: normalized stop frequency for temporal dimension (0.0-1.0)

    Returns:
        Tensor of the requested shape; all zeros if ``d_s`` or ``d_t`` is 0.
    """
    T, H, W = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)
    if d_s == 0 or d_t == 0:
        return mask

    # Half-widths of the box. The spatial half-width is derived from H and
    # reused for the W axis as well.
    half_s = round(int(H // 2) * d_s)
    half_t = round(T // 2 * d_t)

    mid_t, mid_h, mid_w = T // 2, H // 2, W // 2
    t_band = slice(mid_t - half_t, mid_t + half_t)
    h_band = slice(mid_h - half_s, mid_h + half_s)
    w_band = slice(mid_w - half_s, mid_w + half_s)

    mask[..., t_band, h_band, w_band] = 1.0
    return mask