xiexh20's picture
delete unnecessary dependency
9d94b63
raw
history blame
1.31 kB
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import torch
import numpy as np
def show_item(item: Dict):
    """Print a one-line summary for every entry of ``item``.

    Each line holds the key, left-aligned in a 30-character column, followed
    by a compact description of the value:

    * tensors with fewer than 5 elements are printed in full,
    * larger tensors are represented by their shape,
    * strings longer than 52 characters are cut to their last 52 characters
      and prefixed with ``'...'``; shorter strings are printed unchanged,
    * dicts are shown as a mapping from key to value type,
    * anything else is represented by its type.

    Args:
        item: mapping to summarise.
    """
    max_chars = 52  # number of trailing characters kept from long strings
    for key, value in item.items():
        if torch.is_tensor(value):
            # Tiny tensors are readable as-is; big ones only by their shape.
            value_str = value if value.numel() < 5 else value.shape
        elif isinstance(value, str):
            # Only add the ellipsis when characters are actually dropped.
            if len(value) > max_chars:
                value_str = '...' + value[-max_chars:]
            else:
                value_str = value
        elif isinstance(value, dict):
            value_str = str({k: type(v) for k, v in value.items()})
        else:
            value_str = type(value)
        # ``!s`` keeps the alignment spec working for keys whose type does
        # not accept format specs (e.g. tuples).
        print(f"{key!s:<30} {value_str}")
def normalize_to_zero_one(x: torch.Tensor):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole tensor. When all
    elements are equal the value range is zero; in that case the result is
    all zeros instead of the NaNs a plain ``0 / 0`` division would produce.

    Args:
        x: tensor to rescale. The ``min``/``max`` reductions require at
            least one element.

    Returns:
        Tensor with the same shape as ``x``.
    """
    lowest = x.min()
    value_range = x.max() - lowest
    # Swap a zero range for 1 so constant input yields 0 / 1 = 0. torch.where
    # keeps this on-device (no Python-side branching on a tensor value).
    safe_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return (x - lowest) / safe_range
def default(x, d):
    """Return ``x``, falling back to ``d`` only when ``x`` is ``None``.

    Other falsy values (``0``, ``""``, empty containers) are kept as-is.
    """
    if x is None:
        return d
    return x
@dataclass
class DatasetMap:
    """Container grouping up to three optional iterables, one per split.

    Every field defaults to ``None``, so any subset of splits may be given.
    """

    # Iterable for the training split, or None when not provided.
    train: Optional[Iterable] = None
    # Iterable for the validation split, or None when not provided.
    val: Optional[Iterable] = None
    # Iterable for the test split, or None when not provided.
    test: Optional[Iterable] = None
def create_grid_points(bound=1.0, res=128):
    """Build the points of a regular 3D grid over ``[-bound, bound]^3``.

    Each axis is sampled at ``res`` evenly spaced positions including both
    end points. Points are ordered so that the last coordinate varies
    fastest and the first coordinate slowest.

    Args:
        bound: half-extent of the grid along every axis.
        res: number of samples per axis.

    Returns:
        Array of shape ``(res ** 3, 3)`` holding one grid point per row.
    """
    # All three axes share the same sample positions.
    ticks = np.linspace(-bound, bound, res)
    # 'ij' indexing makes grid[i, j, k] == (ticks[i], ticks[j], ticks[k]).
    lattice = np.stack(np.meshgrid(ticks, ticks, ticks, indexing='ij'), axis=-1)
    return lattice.reshape(-1, 3)