# Spaces:
# No application file
# No application file
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
torch.backends.cudnn.benchmark = False | |
import numpy as np | |
import warnings | |
warnings.filterwarnings('ignore') | |
import time | |
import warnings | |
import datetime | |
from typing import List, Union | |
torch.set_grad_enabled(False) | |
import yaml | |
def interpolate_spherical(p0, p1, fract_mixing: float):
    r"""
    Helper function to correctly mix two random variables using spherical interpolation.
    See https://en.wikipedia.org/wiki/Slerp
    The computation is always carried out in float64 for the sake of extra precision;
    the result is cast back to the dtype of p0 afterwards.
    Args:
        p0:
            First tensor for interpolation
        p1:
            Second tensor for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return in p0
            1 will return in p1
            0.x will return a mix between both preserving angular velocity.
    Returns:
        Tensor with the shape of the inputs. Floating point inputs keep the
        dtype of p0 (float16, bfloat16, float32, float64); any other input
        dtype yields float32.
    """
    # Remember what to cast back to. Non-floating inputs cannot hold the
    # result, so they fall back to float32.
    recast_dtype = p0.dtype if p0.dtype.is_floating_point else torch.float32

    p0 = p0.double()
    p1 = p1.double()

    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    if norm == 0:
        # The angle to a zero vector is undefined and dividing by the norm
        # would fill the result with NaN; blend linearly instead.
        interp = p0 * (1 - fract_mixing) + p1 * fract_mixing
    else:
        # Keep the cosine strictly inside (-1, 1) so that sin(theta_0) below
        # never becomes zero for (anti-)parallel inputs.
        epsilon = 1e-7
        dot = torch.sum(p0 * p1) / norm
        dot = dot.clamp(-1 + epsilon, 1 - epsilon)

        theta_0 = torch.arccos(dot)
        sin_theta_0 = torch.sin(theta_0)
        theta_t = theta_0 * fract_mixing
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = torch.sin(theta_t) / sin_theta_0
        interp = p0 * s0 + p1 * s1

    return interp.to(recast_dtype)
def interpolate_linear(p0, p1, fract_mixing):
    r"""
    Blend two inputs with plain linear interpolation.
    Args:
        p0:
            First tensor / np.ndarray for interpolation
        p1:
            Second tensor / np.ndarray for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return in p0
            1 will return in p1
            0.x will return a linear mix between both.
    Returns:
        (1 - fract_mixing) * p0 + fract_mixing * p1. If either input is a
        uint8 np.ndarray, the blend is computed in float64, clipped to
        [0, 255] and returned as uint8.
    """
    def _is_uint8_array(candidate):
        # Only plain ndarrays holding uint8 data need the float round trip.
        return type(candidate) is np.ndarray and candidate.dtype == 'uint8'

    p0_was_uint8 = _is_uint8_array(p0)
    p1_was_uint8 = _is_uint8_array(p1)

    # Promote before mixing so that uint8 arithmetic cannot wrap around.
    if p0_was_uint8:
        p0 = p0.astype(np.float64)
    if p1_was_uint8:
        p1 = p1.astype(np.float64)

    blended = (1 - fract_mixing) * p0 + fract_mixing * p1

    if p0_was_uint8 or p1_was_uint8:
        return np.clip(blended, 0, 255).astype(np.uint8)
    return blended
def add_frames_linear_interp(
        list_imgs: List[np.ndarray],
        fps_target: Union[float, int, None] = None,
        duration_target: Union[float, int, None] = None,
        nmb_frames_target: Union[int, None] = None):
    r"""
    Helper function to cheaply increase the number of frames given a list of images,
    by virtue of standard linear interpolation.
    The number of inserted frames per gap is adjusted so that the total number
    of frames is hit exactly: every gap receives the same base amount and the
    remainder is handed out to randomly chosen gaps.
    The function allows 1:1 comparisons between transitions as videos.
    Args:
        list_imgs: List[np.ndarray]
            List of images, between each image new frames will be inserted via linear interpolation.
        fps_target:
            OptionA: specify here the desired frames per second.
        duration_target:
            OptionA: specify here the desired duration of the transition in seconds.
        nmb_frames_target:
            OptionB: directly fix the total number of frames of the output.
    Returns:
        List of uint8 images with nmb_frames_target entries. If the input
        already has at least that many frames, or holds fewer than two images
        (nothing to interpolate between), list_imgs is returned unchanged.
    Raises:
        ValueError: if both fps_target and nmb_frames_target are given, or if
            neither nmb_frames_target nor the pair fps_target/duration_target
            is given.
    """
    # Sanity
    if nmb_frames_target is not None and fps_target is not None:
        raise ValueError("You cannot specify both fps_target and nmb_frames_target")
    if nmb_frames_target is None:
        if fps_target is None or duration_target is None:
            raise ValueError("Either specify duration_target and fps_target OR nmb_frames_target")
        nmb_frames_target = fps_target * duration_target
    # A frame count has to be a whole number; fps * duration often is not.
    nmb_frames_target = int(round(nmb_frames_target))

    # Get number of frames that are missing
    nmb_frames_diff = len(list_imgs) - 1  # number of gaps between images
    nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
    if nmb_frames_diff < 1 or nmb_frames_missing < 1:
        return list_imgs

    list_imgs_float = [img.astype(np.float32) for img in list_imgs]

    # Distribute missing frames: every gap gets `base` frames and `extra`
    # randomly selected gaps get one more, so the total is always exact.
    base, extra = divmod(nmb_frames_missing, nmb_frames_diff)
    nmb_frames_to_insert = np.full(nmb_frames_diff, base, dtype=np.int32)
    nmb_frames_to_insert[np.random.permutation(nmb_frames_diff)[:extra]] += 1

    list_imgs_interp = []
    for i, (img0, img1) in enumerate(zip(list_imgs_float[:-1], list_imgs_float[1:])):
        list_imgs_interp.append(img0.astype(np.uint8))
        # Interior points only; the end points are the original images.
        list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
        for fract_linblend in list_fracts_linblend:
            # Plain linear blend; the inputs are already float, so no uint8
            # handling is required here.
            img_blend = (1 - fract_linblend) * img0 + fract_linblend * img1
            list_imgs_interp.append(img_blend.astype(np.uint8))
    list_imgs_interp.append(list_imgs_float[-1].astype(np.uint8))
    return list_imgs_interp
def get_spacing(nmb_points: int, scaling: float):
    """
    Return nmb_points values in [0, 1] that are mirror-symmetric around 0.5.
    Args:
        nmb_points: int
            Number of points between [0, 1]
        scaling: float
            Higher values will return higher sampling density around 0.5.
            Values below 1.7 yield plain equidistant spacing.
    Returns:
        1D np.ndarray with nmb_points entries, starting at 0 and ending at 1.
    """
    if scaling < 1.7:
        return np.linspace(0, 1, nmb_points)

    # Build the lower half [0 ... 0.5] from a power-warped ramp, then mirror it.
    half_count = nmb_points // 2 + 1
    ramp = np.linspace(1, 0, half_count)
    lower = np.abs(ramp**scaling / 2 - 0.5)

    has_center_point = np.mod(nmb_points, 2) != 0
    if has_center_point:
        # Odd count: 0.5 is kept once in the lower half and skipped in the mirror.
        upper = 1 - lower[::-1][1:]
    else:
        # Even count: 0.5 itself is not part of the output.
        lower = lower[0:-1]
        upper = 1 - lower[::-1]

    return np.hstack([lower, upper])
def get_time(resolution=None):
    """
    Helper function returning a nicely formatted local time string, e.g. 221117_1620
    Args:
        resolution: str
            One of "day" (yymmdd), "minute" (yymmdd_HHMM),
            "second" (yymmdd_HHMMSS) or "millisecond" (yymmdd_HHMMSS_mmm).
            None defaults to "second".
    Returns:
        The formatted time string.
    Raises:
        ValueError: if resolution is not one of the supported values.
    """
    if resolution is None:
        resolution = "second"

    formats = {
        "day": '%y%m%d',
        "minute": '%y%m%d_%H%M',
        "second": '%y%m%d_%H%M%S',
        "millisecond": '%y%m%d_%H%M%S',
    }
    fmt = formats.get(resolution) if isinstance(resolution, str) else None
    if fmt is None:
        raise ValueError("bad resolution provided: %s" % resolution)

    # Read the clock once so that the seconds and the millisecond suffix
    # belong to the same instant. `datetime` is the module here, hence
    # datetime.datetime.
    now = datetime.datetime.now()
    t = now.strftime(fmt)
    if resolution == "millisecond":
        t += "_"
        t += "{:03d}".format(now.microsecond // 1000)
    return t
def compare_dicts(a, b):
    """
    Compares two dictionaries a and b and returns a dictionary c holding every
    key that is present in both a and b but mapped to different values.
    For each such key the two values are stacked as [value_in_a, value_in_b].
    Keys that exist in only one of the dictionaries are ignored.
    Example:
        a = {}; a['bobo'] = 4
        b = {}; b['bobo'] = 5
        c = compare_dicts(a, b)
        c == {"bobo": [4, 5]}
    """
    return {
        key: [a[key], b[key]]
        for key in a
        if key in b and a[key] != b[key]
    }
def yml_load(fp_yml, print_fields=False):
    """
    Helper function for loading yaml files
    Args:
        fp_yml: str
            Path of the yaml file to read.
        print_fields: bool
            If True, every loaded key/value pair is printed after loading.
    Returns:
        dict with the file content. An empty yaml file yields an empty dict.
    """
    with open(fp_yml) as f:
        data = yaml.load(f, Loader=yaml.loader.SafeLoader)
    # An empty document is parsed as None, which dict() cannot convert.
    dict_data = {} if data is None else dict(data)
    print("load: loaded {}".format(fp_yml))
    if print_fields:
        for key, value in dict_data.items():
            print("{}: {}".format(key, value))
    return dict_data
def yml_save(fp_yml, dict_stuff):
    """
    Write dict_stuff to the yaml file fp_yml and report the path on stdout.
    Args:
        fp_yml: str
            Destination path; an existing file is overwritten.
        dict_stuff: dict
            Content to serialize. Keys keep their insertion order and the
            output uses block style.
    """
    with open(fp_yml, mode='w') as handle:
        yaml.dump(
            dict_stuff,
            handle,
            default_flow_style=False,
            sort_keys=False,
        )
    message = "yml_save: saved {}".format(fp_yml)
    print(message)