File size: 1,313 Bytes
9d94b63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import torch
import numpy as np

def show_item(item: Dict) -> None:
    """Print a one-line summary of every entry in ``item``.

    Each printed line holds the key (converted with ``str`` and left-aligned
    in a 30-character column) followed by a short description of the value:

    * tensor with fewer than 5 elements -> the tensor itself
    * any other tensor -> its shape
    * string -> the string itself, or ``'...'`` plus its last 52 characters
      when it is longer than 52 characters
    * dict -> a mapping from its keys to the types of its values
    * anything else -> the type of the value

    Args:
        item: Mapping whose entries are summarised on standard output.
    """
    # Number of trailing characters kept when a long string is shortened.
    tail_len = 52

    for key, value in item.items():
        if torch.is_tensor(value):
            # Tiny tensors are cheap to show in full; larger ones only by shape.
            summary = value if value.numel() < 5 else value.shape
        elif isinstance(value, str):
            # Only add the ellipsis when characters are actually dropped.
            summary = ('...' + value[-tail_len:]) if len(value) > tail_len else value
        elif isinstance(value, dict):
            summary = str({k: type(v) for k, v in value.items()})
        else:
            summary = type(value)
        # ``!s`` keeps the alignment spec valid for keys such as tuples or
        # None, whose own ``__format__`` rejects a non-empty format spec.
        print(f"{key!s:<30} {summary}")


def normalize_to_zero_one(x: torch.Tensor) -> torch.Tensor:
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole tensor. When every
    element has the same value the range is zero; in that case the result is
    all zeros rather than the NaNs a 0/0 division would produce.

    Args:
        x: Tensor to rescale. It must contain at least one element, because
            ``x.min()`` / ``x.max()`` are evaluated on it.

    Returns:
        A tensor of the same shape as ``x`` computed as
        ``(x - min) / (max - min)``.
    """
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        # Constant input: divide by one instead of zero so the output is
        # zeros with the same dtype the regular path would produce.
        span = torch.ones_like(span)
    return (x - lowest) / span
    

def default(x, d):
    """Return ``x`` unless it is ``None``, in which case return ``d``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``,
    ``""`` or ``[]`` are returned unchanged.
    """
    if x is not None:
        return x
    return d


@dataclass
class DatasetMap:
    """Group up to three optional iterables under ``train``, ``val`` and ``test``.

    Every field defaults to ``None``, so any subset may be supplied. The
    field names suggest dataset splits, but nothing here constrains what the
    iterables contain -- confirm the expected element types against callers.
    """

    # Iterable stored under ``train``; None when not provided.
    train: Optional[Iterable] = None
    # Iterable stored under ``val``; None when not provided.
    val: Optional[Iterable] = None
    # Iterable stored under ``test``; None when not provided.
    test: Optional[Iterable] = None



def create_grid_points(bound=1.0, res=128):
    """Return the points of a regular 3-D grid spanning ``[-bound, bound]``.

    Each axis is sampled at ``res`` evenly spaced positions (both end points
    included), giving ``res ** 3`` points in total.

    Args:
        bound: Half-extent of the cube; every coordinate lies in
            ``[-bound, bound]``.
        res: Number of samples along each axis.

    Returns:
        A NumPy array of shape ``(res ** 3, 3)``. Rows are ordered so that
        the last column changes fastest and the first column slowest.
    """
    # All three axes share the same sample positions.
    ticks = np.linspace(-bound, bound, res)

    # With "ij" indexing the n-th output varies along the n-th array axis, so
    # flattening in C order makes the first column the slowest-changing one.
    first, second, third = np.meshgrid(ticks, ticks, ticks, indexing="ij")
    columns = [first.ravel(), second.ravel(), third.ravel()]
    return np.stack(columns, axis=-1)