|
import random |
|
from statistics import mean |
|
from typing import List, Tuple |
|
|
|
import torch as th |
|
import pytorch_lightning as pl |
|
from jaxtyping import Float, Int |
|
import numpy as np |
|
from torch_geometric.nn.conv import GATv2Conv |
|
|
|
from models.SAP.dpsr import DPSR |
|
from models.SAP.model import PSR2Mesh |
|
|
|
|
|
|
|
# Seed every RNG this module draws from -- torch, NumPy and Python's `random`
# (the latter picks the per-step point keep ratio in `Model.training_step`) --
# so runs are reproducible.
th.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Number of grids per forward pass. `FormOptimizer.forward` flattens its whole
# input into a single grid graph, so only 1 is supported.
BATCH_SIZE = 1

# Per-node input/output feature sizes and hidden width of the graph network.
IN_DIM = 1
OUT_DIM = 1
LATENT_DIM = 32

# Attention dropout probability passed to the GATv2 layers.
DROPOUT_PROB = 0.1
# Resolution (cells per axis) of the cubic field grid.
GRID_SIZE = 128

# Learning rate used by `Model.configure_optimizers`, which reads `LR` as a
# module-level name; without this definition that call raises NameError.
# 1e-3 is torch.optim.Adam's default learning rate.
# TODO(review): confirm the intended value.
LR = 1e-3
|
|
|
def generate_grid_edge_list(gs: int = 128) -> List[List[int]]:
    """Return the directed 6-neighbourhood edges of a ``gs`` x ``gs`` x ``gs`` grid.

    Node ``(i, j, k)`` has flat index ``i + gs*j + gs*gs*k`` (``i`` varies
    fastest). Nodes are visited in flat-index order. For each node, one edge
    ``[node, neighbour]`` is emitted per in-bounds neighbour, in the order
    ``i-1, i+1, j-1, j+1, k-1, k+1``. Every undirected grid edge therefore
    appears once in each direction.

    The edges are computed with vectorized NumPy operations rather than a
    Python triple loop, which matters at the default size (about 12.5M edges).

    Args:
        gs: Number of grid cells along each axis. Values <= 0 yield no edges.

    Returns:
        A list of ``[source, target]`` pairs of Python ints.
    """
    if gs <= 0:
        return []

    # Flat index of every node and its (k, j, i) coordinates; C-order
    # unravelling makes i the fastest-varying coordinate, matching the index
    # formula above.
    node = np.arange(gs ** 3)
    k, j, i = np.unravel_index(node, (gs, gs, gs))

    # Candidate neighbours in the order i-1, i+1, j-1, j+1, k-1, k+1, plus a
    # mask of which candidates lie inside the grid.
    candidates, in_bounds = [], []
    for coord, stride in ((i, 1), (j, gs), (k, gs * gs)):
        for step in (-1, 1):
            candidates.append(node + step * stride)
            in_bounds.append((coord + step >= 0) & (coord + step < gs))
    targets = np.stack(candidates, axis=1)  # (gs**3, 6)
    valid = np.stack(in_bounds, axis=1)  # (gs**3, 6)
    sources = np.broadcast_to(node[:, None], targets.shape)

    # Boolean indexing walks the (node, neighbour) table in row-major order,
    # so edges stay grouped by source node in the neighbour order above.
    return np.stack((sources[valid], targets[valid]), axis=1).tolist()
|
|
|
# Connectivity of the GRID_SIZE^3 grid graph in PyTorch Geometric's COO
# `edge_index` layout: shape (2, E), row 0 = source node, row 1 = target node.
# PyG's convention for edge indices is dtype torch.long. The transpose is made
# contiguous so that rows 0 and 1 are dense. Integer tensors can never require
# grad, so no `requires_grad` handling is needed.
GRID_EDGE_LIST = th.tensor(generate_grid_edge_list(GRID_SIZE), dtype=th.long).T.contiguous()
|
|
|
|
|
class FormOptimizer(th.nn.Module):
    """Graph network that predicts a residual update for a grid field.

    Every cell of the GRID_SIZE^3 grid is a graph node, connected to its grid
    neighbours via ``GRID_EDGE_LIST``. The network applies two GATv2 layers,
    then a sigmoid, then a linear head, producing one value per cell. That
    value is added to the input field and the result is clamped to
    ``[-0.5, 0.5]``.
    """

    def __init__(self) -> None:
        super().__init__()

        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)

        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)

        # The grid connectivity is stored as a buffer so that `.to(device)`
        # (and Lightning's device placement) moves it together with the
        # parameters; otherwise a CPU edge index would be fed to GPU layers.
        # It is non-persistent, so it is neither written to nor expected in
        # state_dicts/checkpoints. `.long()` enforces PyG's integer index dtype.
        self.register_buffer("edge_index", GRID_EDGE_LIST.long(), persistent=False)

    def forward(self,
                field: Float[th.Tensor, "... GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Refine ``field`` with a learned, clamped residual.

        Args:
            field (Tensor): grid field with ``GRID_SIZE**3 * IN_DIM`` elements,
                e.g. shape ``[1, GS, GS, GS]``. The whole tensor is flattened
                into a single grid graph (one node per cell), so only
                ``BATCH_SIZE == 1`` is supported.

        Returns:
            Tensor of shape ``[BATCH_SIZE, GS, GS, GS]``: ``field`` plus the
            predicted residual, clamped to ``[-0.5, 0.5]``.
        """
        # One graph node per grid cell, IN_DIM features per node.
        vertex_features = field.reshape(GRID_SIZE * GRID_SIZE * GRID_SIZE, IN_DIM)

        # NOTE(review): there is no non-linearity between the two attention
        # layers; the sigmoid is only applied before the head.
        vertex_features = self.gconv1(x=vertex_features, edge_index=self.edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=self.edge_index)
        field_delta = self.head(self.actv(vertex_features))

        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)

        # Residual update of the input field, kept within [-0.5, 0.5].
        return th.clamp(field_delta + field, min=-0.5, max=0.5)
|
|
|
class Model(pl.LightningModule):
    """Lightning module: oriented points -> grid field -> refined field -> mesh.

    ``DPSR`` turns input points and normals into a field on a
    GRID_SIZE^3 grid. ``FormOptimizer`` refines that field, and ``PSR2Mesh``
    extracts a mesh from it. Training minimises the MSE between the refined
    field and the ground-truth field from the batch. Adam's learning rate is
    halved after 5 epochs without improvement of ``mean_val_per_epoch_loss``.
    """

    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()

        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        # NOTE(review): if PSR2Mesh is a torch.autograd.Function, `PSR2Mesh.apply`
        # avoids instantiating it (which newer PyTorch versions deprecate).
        self.field2mesh = PSR2Mesh().apply

        self.metric = th.nn.MSELoss()

        # Per-step losses collected during an epoch and averaged (then reset)
        # in the matching *_epoch_end hook.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Write ``points`` and ``normals`` as one new gzip-compressed float16 frame.

        Creates a dataset named ``str(self.h5_frame)`` in
        ``self.log_points_file`` and another in ``self.log_normals_file``, then
        increments ``self.h5_frame``. None of these attributes is set in
        ``__init__``; the caller must assign them before calling (presumably
        open h5py files and an int counter -- TODO confirm).
        """
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],
                                                       Int[th.Tensor, "2 E"],
                                                       Float[th.Tensor, "BS N 3"],
                                                       Float[th.Tensor, "BS GR GR GR"]]:
        """Reconstruct a mesh from points ``v`` and normals ``n``.

        Returns:
            ``(vertices, faces, normals, field)`` where the first three come
            from ``PSR2Mesh`` and ``field`` is the refined grid field.
        """
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, ""]:
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        # Augmentation: keep a random subset of the input points. Each step
        # draws a keep ratio uniformly from [0.5, 1.0). The mask lives on the
        # inputs' device so both CPU and GPU runs work.
        keep_ratio = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1],), device=vertices.device) < keep_ratio
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]

        # Call through `self` so the module actually being trained is used.
        _, _, _, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        self.train_losses.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        # Skip logging when no training batch ran (mean() of [] would raise).
        if self.train_losses:
            mean_train_per_epoch_loss = mean(self.train_losses)
            self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        _, _, _, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        # Skip logging when no validation batch ran (mean() of [] would raise).
        if self.val_losses:
            mean_val_per_epoch_loss = mean(self.val_losses)
            self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        # `LR` is read as a module-level name; it is not defined by this class.
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Metric logged in on_validation_epoch_end.
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                "strict": True,
                "name": None,
            },
        }
|
|