xiexh20's picture
delete unnecessary dependency
9d94b63
raw
history blame contribute delete
No virus
1.31 kB
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import torch
import numpy as np
def show_item(item: Dict, max_str_len: int = 50):
    """Print a one-line summary of every entry in ``item``.

    Each printed line is the key, left-aligned in a 30-character column,
    followed by a short description of the value:

    * tensors with fewer than 5 elements are printed in full;
    * larger tensors are summarized by their ``shape``;
    * strings longer than ``max_str_len`` characters are shown as ``'...'``
      followed by their last ``max_str_len`` characters;
    * dicts are shown as a mapping from each key to the type of its value;
    * anything else is shown as its type.

    Args:
        item: mapping to summarize (presumably one dataset sample — confirm
            with callers).
        max_str_len: non-negative length above which strings are truncated
            to their tail. Defaults to 50.

    Returns:
        None; output goes to stdout via ``print``.
    """
    for key, value in item.items():
        value_str: object
        if torch.is_tensor(value):
            # Small tensors are cheap to print in full; otherwise show the shape.
            value_str = value if value.numel() < 5 else value.shape
        elif isinstance(value, str):
            # Keep the tail of long strings. The kept length matches the
            # threshold, so '...' only appears when characters were dropped.
            # (Slicing from len - n also behaves correctly for n == 0, where
            # value[-0:] would return the whole string.)
            if len(value) > max_str_len:
                value_str = '...' + value[len(value) - max_str_len:]
            else:
                value_str = value
        elif isinstance(value, dict):
            value_str = str({k: type(v) for k, v in value.items()})
        else:
            value_str = type(value)
        print(f"{key:<30} {value_str}")
def normalize_to_zero_one(x: torch.Tensor) -> torch.Tensor:
    """Min-max rescale ``x`` so its values span ``[0, 1]``.

    Computes ``(x - x.min()) / (x.max() - x.min())`` over all elements of
    ``x``. When every element is equal the range is zero. In that case the
    result is all zeros instead of the ``0 / 0 = NaN`` the plain formula
    would produce.

    Args:
        x: tensor to rescale. It must be non-empty, because ``Tensor.min``
            raises on an empty tensor.

    Returns:
        Tensor with the same shape as ``x``. Integer inputs yield a floating
        point result, following torch true-division promotion.
    """
    lo = x.min()
    span = x.max() - lo
    # Constant input: divide the all-zero numerator by 1 instead of 0. This
    # keeps the result dtype identical to the plain formula but avoids NaN,
    # and torch.where avoids a host sync on accelerator tensors.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lo) / span
def default(x, d):
    """Return ``x`` unless it is ``None``, in which case return ``d``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``,
    ``""`` or ``False`` are returned unchanged.
    """
    if x is not None:
        return x
    return d
@dataclass
class DatasetMap:
    """Bundle of optional iterables keyed by split name: train, val and test.

    All three fields default to ``None``, so any subset of splits can be
    supplied. NOTE(review): the iterables are presumably datasets or data
    loaders — confirm against callers.
    """

    # Iterable for the training split, or None if absent.
    train: Optional[Iterable] = None
    # Iterable for the validation split, or None if absent.
    val: Optional[Iterable] = None
    # Iterable for the test split, or None if absent.
    test: Optional[Iterable] = None
def create_grid_points(bound=1.0, res=128):
    """Return the vertices of a regular ``res**3`` lattice over ``[-bound, bound]^3``.

    Every axis uses the same samples ``t = np.linspace(-bound, bound, res)``.
    The result is a float array of shape ``(res**3, 3)`` in row-major order:
    the first coordinate varies slowest and the third fastest, so row
    ``i*res*res + j*res + k`` is ``(t[i], t[j], t[k])``.
    """
    ticks = np.linspace(-bound, bound, res)
    # With 'ij' indexing, mesh axis n carries coordinate n. Stacking on the
    # last axis and flattening therefore enumerates points with the last
    # coordinate changing fastest.
    lattice = np.stack(np.meshgrid(ticks, ticks, ticks, indexing="ij"), axis=-1)
    return lattice.reshape(-1, 3)