|
from typing import List, Tuple, Optional, Dict |
|
import math |
|
import numpy as np |
|
|
|
|
|
def get_accumulate_timestamp_idxs(
        timestamps: List[float],
        start_time: float,
        dt: float,
        eps: float = 1e-5,
        next_global_idx: Optional[int] = 0,
        allow_negative=False
) -> Tuple[List[int], List[int], int]:
    """
    Assign timestamps to consecutive fixed-width time slots.

    The slot of a timestamp ``ts`` is ``floor((ts - start_time) / dt + eps)``.
    Timestamps are assumed sorted ascending. Walking through them, each
    timestamp fills every not-yet-filled slot from the current cursor up to
    and including its own slot:
      * a timestamp whose slot is already filled is ignored, so the earliest
        timestamp in a slot wins;
      * when frames were dropped (a jump of several slots), the first
        timestamp after the gap is reused for every skipped slot, so one
        timestamp may be chosen several times.

    Args:
        timestamps: sorted sequence of timestamps.
        start_time: time corresponding to slot 0.
        dt: slot width.
        eps: tolerance added before flooring, so a value that lands
            numerically just below a slot boundary is assigned to that slot.
        next_global_idx: first slot to fill. Normally start at 0 and pass the
            returned value into the next call. Pass None to start at the slot
            of the first accepted timestamp instead; this is what lets a
            caller overwrite previously filled slots.
        allow_negative: if false, timestamps whose slot is negative (i.e.
            before ``start_time``) are skipped.

    Returns:
        local_idxs: for each filled slot, the index into ``timestamps`` that
            was chosen.
        global_idxs: the filled slot indices (consecutive, ascending).
        next_global_idx: the slot to fill next; feed it to the next call.
            If nothing was filled it is the input value, which may be None.
    """
    chosen_local = []
    chosen_global = []
    cursor = next_global_idx

    for position, stamp in enumerate(timestamps):
        slot = math.floor((stamp - start_time) / dt + eps)

        # Guard: ignore frames before start_time unless explicitly allowed.
        if slot < 0 and not allow_negative:
            continue

        # With no starting cursor, begin at the first accepted frame's slot.
        if cursor is None:
            cursor = slot

        count = slot - cursor + 1
        if count <= 0:
            # Slot already filled by an earlier frame.
            continue

        chosen_local.extend([position] * count)
        chosen_global.extend(cursor + k for k in range(count))
        cursor += count

    return chosen_local, chosen_global, cursor
|
|
|
|
|
def align_timestamps(
        timestamps: List[float],
        target_global_idxs: List[int],
        start_time: float,
        dt: float,
        eps: float = 1e-5):
    """
    Pick, for each target slot, the index of the timestamp to use.

    Slots are computed by ``get_accumulate_timestamp_idxs`` starting at
    ``target_global_idxs[0]`` (negative slots allowed). Frames whose slot lies
    before the first target are therefore not chosen by the accumulation
    step. If the accumulation produces more slots than targets, the extra
    ones are dropped. If it produces fewer, the remaining targets are padded
    with the last timestamp (``len(timestamps) - 1``). This includes the case
    where no frame reaches the first target slot at all.

    Args:
        timestamps: sorted timestamps to choose from.
        target_global_idxs: target slot indices (list or numpy array). They
            must be consecutive and ascending, otherwise the final
            consistency assertion fails.
        start_time: time corresponding to slot 0.
        dt: slot width.
        eps: tolerance forwarded to ``get_accumulate_timestamp_idxs``.

    Returns:
        List of indices into ``timestamps``, one per target slot.

    Raises:
        AssertionError: if ``target_global_idxs`` is empty, or the aligned
            slots do not equal ``target_global_idxs``.
        ValueError: if padding is needed but ``timestamps`` is empty.
    """
    if isinstance(target_global_idxs, np.ndarray):
        target_global_idxs = target_global_idxs.tolist()
    assert len(target_global_idxs) > 0
    n_targets = len(target_global_idxs)

    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps=timestamps,
        start_time=start_time,
        dt=dt,
        eps=eps,
        next_global_idx=target_global_idxs[0],
        allow_negative=True
    )

    # More frames than targets: keep only the leading ones.
    local_idxs = local_idxs[:n_targets]
    global_idxs = global_idxs[:n_targets]

    n_missing = n_targets - len(global_idxs)
    if n_missing > 0:
        if len(timestamps) == 0:
            raise ValueError(
                "align_timestamps: timestamps is empty, nothing to align")
        # Pad the missing trailing targets with the last frame. If no frame
        # reached the first target slot, padding starts at that first target
        # (previously this case hit a leftover debugger breakpoint and then
        # an IndexError on global_idxs[-1]).
        first_missing = (global_idxs[-1] + 1) if global_idxs \
            else target_global_idxs[0]
        local_idxs.extend([len(timestamps) - 1] * n_missing)
        global_idxs.extend(first_missing + k for k in range(n_missing))

    assert global_idxs == target_global_idxs
    assert len(local_idxs) == len(global_idxs)
    return local_idxs
|
|
|
|
|
class TimestampObsAccumulator:
    """
    Accumulate observations onto the regular grid ``start_time + k * dt``.

    Slots are filled in order using ``get_accumulate_timestamp_idxs``, and the
    fill cursor is carried across ``put`` calls. Each slot is therefore
    written once: with the earliest observation in that slot, or, after
    dropped frames, with the first observation following the gap. Storage is
    allocated on the first non-empty batch and grown to at least double its
    size when more room is needed.
    """

    def __init__(self,
                 start_time: float,
                 dt: float,
                 eps: float = 1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        self.obs_buffer = {}          # key -> array whose first axis is the slot index
        self.timestamp_buffer = None  # actual timestamp chosen for each slot
        self.next_global_idx = 0      # number of slots filled so far

    def __len__(self):
        return self.next_global_idx

    @property
    def data(self):
        """Filled slots of each observation stream (views into the buffers)."""
        if self.timestamp_buffer is None:
            return {}
        n_filled = len(self)
        return {key: buf[:n_filled] for key, buf in self.obs_buffer.items()}

    @property
    def actual_timestamps(self):
        """Original timestamp of the observation stored in each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        n_filled = len(self)
        return self.timestamp_buffer[:n_filled]

    @property
    def timestamps(self):
        """Ideal grid time of each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        n_filled = len(self)
        return self.start_time + self.dt * np.arange(n_filled)

    def put(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        """
        Add a batch of observations.

        Args:
            data: mapping key -> array shaped (T, *), whose first axis lines
                up with ``timestamps``.
            timestamps: length-T array of observation times, assumed sorted.
        """
        local_idxs, global_idxs, self.next_global_idx = \
            get_accumulate_timestamp_idxs(
                timestamps=timestamps,
                start_time=self.start_time,
                dt=self.dt,
                eps=self.eps,
                next_global_idx=self.next_global_idx
            )

        # Nothing landed in a new slot: leave the buffers untouched.
        if not global_idxs:
            return

        if self.timestamp_buffer is None:
            # Lazily size storage from the first non-empty batch.
            self.obs_buffer = {
                key: np.zeros_like(value) for key, value in data.items()
            }
            self.timestamp_buffer = np.zeros(
                (len(timestamps),), dtype=np.float64)

        required = global_idxs[-1] + 1
        capacity = len(self.timestamp_buffer)
        if required > capacity:
            grown = max(required, capacity * 2)
            for key, buf in list(self.obs_buffer.items()):
                self.obs_buffer[key] = np.resize(
                    buf, (grown,) + buf.shape[1:])
            self.timestamp_buffer = np.resize(self.timestamp_buffer, (grown,))

        for key, buf in self.obs_buffer.items():
            buf[global_idxs] = data[key][local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
|
|
|
|
|
class TimestampActionAccumulator:
    """
    Accumulate actions onto the regular grid ``start_time + k * dt``.

    Different from the Obs accumulator, the action accumulator allows
    overwriting previous values. Every ``put`` starts filling at the slot of
    its own first (non-negative) timestamp, so a later batch replaces slots
    that an earlier batch wrote. Because of this, gaps can appear between
    batches. Slots inside ``[0, len(self))`` that were never written read as
    zero.
    """

    def __init__(self,
                 start_time: float,
                 dt: float,
                 eps: float = 1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        self.action_buffer = None     # allocated on the first non-empty put
        self.timestamp_buffer = None  # actual timestamp chosen for each slot
        self.size = 0                 # highest written slot + 1

    def __len__(self):
        return self.size

    @property
    def actions(self):
        """Actions for slots ``0 .. len(self) - 1``."""
        if self.action_buffer is None:
            return np.array([])
        return self.action_buffer[:len(self)]

    @property
    def actual_timestamps(self):
        """Original timestamp of the action stored in each slot (0 for unwritten slots)."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Ideal grid time of each slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    @staticmethod
    def _grow_zero_filled(buffer: np.ndarray, new_len: int) -> np.ndarray:
        """Return a copy of ``buffer`` with its first axis extended to ``new_len``; new rows are zero."""
        # np.resize would fill the new rows with repeated copies of existing
        # data, which would surface as stale duplicated actions in gap slots.
        grown = np.zeros((new_len,) + buffer.shape[1:], dtype=buffer.dtype)
        n_keep = min(len(buffer), new_len)
        grown[:n_keep] = buffer[:n_keep]
        return grown

    def put(self, actions: np.ndarray, timestamps: np.ndarray):
        """
        Add a batch of actions, overwriting any slots they map to.

        Note: timestamps is the time when the action will be issued,
        not when the action will be completed (target_timestamp).

        Args:
            actions: array (or array-like) whose first axis lines up with
                ``timestamps``.
            timestamps: issue times, assumed sorted.
        """
        # Accept array-likes; fancy indexing below requires ndarrays.
        actions = np.asarray(actions)
        timestamps = np.asarray(timestamps)

        local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            # None: start at this batch's first slot, enabling overwrites.
            next_global_idx=None
        )

        if len(global_idxs) == 0:
            return

        if self.timestamp_buffer is None:
            # Lazily size storage from the first non-empty batch.
            self.action_buffer = np.zeros_like(actions)
            self.timestamp_buffer = np.zeros((len(actions),), dtype=np.float64)

        this_max_size = global_idxs[-1] + 1
        if this_max_size > len(self.timestamp_buffer):
            # Grow at least geometrically; new (possibly never-written) slots are zero.
            new_size = max(this_max_size, len(self.timestamp_buffer) * 2)
            self.action_buffer = self._grow_zero_filled(
                self.action_buffer, new_size)
            self.timestamp_buffer = self._grow_zero_filled(
                self.timestamp_buffer, new_size)

        self.action_buffer[global_idxs] = actions[local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
        self.size = max(self.size, this_max_size)
|
|