Spaces:
Sleeping
Sleeping
from dataclasses import dataclass | |
from typing import Dict, Iterable, Optional | |
import torch | |
import numpy as np | |
def show_item(item: Dict) -> None:
    """Print a one-line summary of every entry in ``item``.

    Each line is the key (converted with ``str`` and left-aligned in a
    30-character column) followed by a description of the value:

    * tensor with fewer than 5 elements -> the tensor itself
    * any other tensor                  -> its shape
    * ``str``                           -> the string, or ``'...'`` plus its
      last 52 characters when that is shorter than the full string
    * ``dict``                          -> mapping of its keys to value types
    * anything else                     -> the value's type

    Args:
        item: Mapping whose entries should be summarised.
    """
    marker = "..."
    tail_len = 52
    for key, value in item.items():
        if torch.is_tensor(value):
            # Tiny tensors are cheap to print in full; larger ones are
            # summarised by shape only.
            value_str = value if value.numel() < 5 else value.shape
        elif isinstance(value, str):
            # Only abbreviate when the abbreviated form is actually shorter
            # than the original, so the marker always means "text was cut".
            if len(value) > tail_len + len(marker):
                value_str = marker + value[-tail_len:]
            else:
                value_str = value
        elif isinstance(value, dict):
            value_str = str({k: type(v) for k, v in value.items()})
        else:
            value_str = type(value)
        # ``!s`` lets keys without format-spec support (tuples, None, ...)
        # be aligned instead of raising TypeError.
        print(f"{key!s:<30} {value_str}")
def normalize_to_zero_one(x: torch.Tensor) -> torch.Tensor:
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole tensor. If every element
    is equal (zero span) the result is all zeros rather than NaN.

    Args:
        x: Tensor to rescale.

    Returns:
        Tensor of the same shape as ``x`` holding ``(x - min) / (max - min)``.
    """
    lo = x.min()
    span = x.max() - lo
    # A constant tensor has zero span and would produce 0/0 = NaN; dividing
    # by 1 instead maps it to zeros. torch.where avoids Python-level
    # branching on a tensor value.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lo) / safe_span
def default(x, d):
    """Return ``x`` unless it is ``None``, in which case return ``d``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``""`` are returned unchanged.
    """
    if x is not None:
        return x
    return d
@dataclass
class DatasetMap:
    """Container grouping optional train / val / test iterables.

    Each field defaults to ``None``. As a dataclass, instances can be built
    with keyword arguments, e.g. ``DatasetMap(train=..., val=...)``.
    """

    # Iterable for the training split, or None when absent.
    train: Optional[Iterable] = None
    # Iterable for the validation split, or None when absent.
    val: Optional[Iterable] = None
    # Iterable for the test split, or None when absent.
    test: Optional[Iterable] = None
def create_grid_points(bound=1.0, res=128):
    """Return the points of a regular cubic lattice spanning ``[-bound, bound]``.

    Args:
        bound: Half-width of the cube along every axis.
        res: Number of samples per axis.

    Returns:
        ``numpy`` array of shape ``(res ** 3, 3)``. Rows are ordered so that
        the first column varies slowest and the last column varies fastest.
    """
    # All three axes use the same sample positions.
    ticks = np.linspace(-bound, bound, res)
    # "ij" indexing makes output axis n follow input n, giving the row order
    # described above directly.
    first, second, third = np.meshgrid(ticks, ticks, ticks, indexing="ij")
    return np.stack([first.ravel(), second.ravel(), third.ravel()], axis=-1)