import random
from statistics import mean
from typing import List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch as th
from jaxtyping import Float, Int
from torch_geometric.nn.conv import GATv2Conv

from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh

# Constants
th.manual_seed(0)
np.random.seed(0)

BATCH_SIZE = 1  # BS
IN_DIM = 1
OUT_DIM = 1
LATENT_DIM = 32
DROPOUT_PROB = 0.1
GRID_SIZE = 128
# NOTE(review): `LR` was referenced by `Model.configure_optimizers` but never
# defined in this file. 1e-3 is Adam's own default learning rate; confirm the
# value that was actually intended for training.
LR = 1e-3


def _grid_edge_array(gs: int) -> np.ndarray:
    """Return the directed 6-neighbour edges of a ``gs**3`` grid as an ``[E, 2]`` int64 array.

    Nodes are numbered ``i + gs*j + gs*gs*k``. Rows are ordered by source node
    and, for each source, by neighbour direction ``-i, +i, -j, +j, -k, +k``.
    """
    if gs < 2:
        # A grid with fewer than two cells per axis has no neighbouring pairs.
        return np.empty((0, 2), dtype=np.int64)

    node = np.arange(gs ** 3, dtype=np.int64)
    i = node % gs
    j = (node // gs) % gs
    k = node // (gs * gs)

    # Flat-index offset of each of the six axis-aligned neighbours.
    offsets = np.array([-1, 1, -gs, gs, -gs * gs, gs * gs], dtype=np.int64)
    # has_neighbour[n, d] is True when node n has a neighbour in direction d.
    has_neighbour = np.stack(
        [i > 0, i < gs - 1, j > 0, j < gs - 1, k > 0, k < gs - 1],
        axis=1,
    )

    targets = node[:, None] + offsets
    sources = np.broadcast_to(node[:, None], targets.shape)
    # Boolean indexing walks the mask row by row, which keeps the edges grouped
    # by source node and ordered by direction within each group.
    return np.stack([sources[has_neighbour], targets[has_neighbour]], axis=1)


def generate_grid_edge_list(gs: int = 128) -> List[List[int]]:
    """Build the directed edge list of a ``gs x gs x gs`` 6-connected grid.

    Args:
        gs: Number of cells along each axis.

    Returns:
        A list of ``[source, target]`` pairs of flat node indices, where a cell
        ``(i, j, k)`` has index ``i + gs*j + gs*gs*k``. Every pair of adjacent
        cells appears twice, once in each direction.
    """
    return _grid_edge_array(gs).tolist()


# Edge index with shape [2, E]. torch_geometric documents `edge_index` as a
# LongTensor, so it is kept as int64. Integer tensors never carry gradients,
# so no `requires_grad` handling is needed.
GRID_EDGE_LIST = th.from_numpy(_grid_edge_array(GRID_SIZE)).t().contiguous()


class FormOptimizer(th.nn.Module):
    """Two GATv2 layers plus a linear head that predict a correction to a grid field."""

    def __init__(self) -> None:
        super().__init__()
        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM,
                                heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM,
                                heads=1, dropout=DROPOUT_PROB)
        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
        # Registered as a non-persistent buffer so the grid connectivity moves
        # with the module across devices without entering the state_dict.
        self.register_buffer("edge_index", GRID_EDGE_LIST, persistent=False)

    def forward(self, field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Refine a scalar grid field.

        Args:
            field: Scalar field holding exactly ``GRID_SIZE**3 * IN_DIM``
                values (it is flattened to one feature row per grid cell).

        Returns:
            ``field`` plus the predicted correction, clamped to
            ``[-0.5, 0.5]``, with shape
            ``[BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE]``.
        """
        vertex_features = field.reshape(GRID_SIZE * GRID_SIZE * GRID_SIZE, IN_DIM)
        vertex_features = self.gconv1(x=vertex_features, edge_index=self.edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=self.edge_index)

        field_delta = self.head(self.actv(vertex_features))
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        # The sum carries the gradient back into the graph layers.
        return th.clamp(field_delta + field, min=-0.5, max=0.5)


class Model(pl.LightningModule):
    """Lightning module chaining ``DPSR`` -> ``FormOptimizer`` -> ``PSR2Mesh``.

    Training and validation minimise the MSE between the refined field and the
    ground-truth field supplied in each batch.
    """

    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()
        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply
        self.metric = th.nn.MSELoss()
        # Per-step losses collected during the current epoch.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Write ``points`` and ``normals`` as float16 gzip datasets named by the frame counter.

        NOTE(review): reads ``self.log_points_file``, ``self.log_normals_file``
        and ``self.h5_frame``, none of which is initialised in this file; they
        must be assigned by the caller before the first call.
        """
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],  # v - vertices
                                                        Int[th.Tensor, "2 E"],       # f - faces
                                                        Float[th.Tensor, "BS N 3"],  # n - vertices normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # field
        """Turn points and normals into a refined field and the mesh extracted from it.

        Args:
            v: Point positions.
            n: Normals of those points.

        Returns:
            ``(vertices, faces, normals, field)`` where ``field`` is the output
            of ``FormOptimizer`` and the other three come from ``PSR2Mesh``.
        """
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, ""]:
        """Run one training step on a random subset of the input points.

        Each point is kept with a probability drawn uniformly from
        ``[0.5, 1.0)`` for this step.
        """
        vertices, vertices_normals, _vertices_gt, _normals_gt, field_gt, _adj = batch

        keep_prob = random.random() / 2.0 + 0.5
        # Build the mask on the batch's own device instead of assuming CUDA.
        mask = th.rand((vertices.shape[1],), device=vertices.device) < keep_prob
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]

        # Call this module, not a global `model` that may not exist.
        _vr, _fr, _nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.train_losses.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        """Log the mean of this epoch's training losses and reset the accumulator."""
        # `statistics.mean` raises on an empty sequence; skip logging then.
        if self.train_losses:
            self.log("mean_train_per_epoch_loss", mean(self.train_losses),
                     on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        """Compute the field MSE for one validation batch, using all points."""
        vertices, vertices_normals, _vertices_gt, _normals_gt, field_gt, _adj = batch
        _vr, _fr, _nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        """Log the mean of this epoch's validation losses and reset the accumulator."""
        # `statistics.mean` raises on an empty sequence; skip logging then.
        if self.val_losses:
            self.log("mean_val_per_epoch_loss", mean(self.val_losses),
                     on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler keyed on the validation loss."""
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If set to `True`, will enforce that the value specified 'monitor'
                # is available when the scheduler is updated, thus stopping
                # training if not found. If set to `False`, it will only produce a warning
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom logged name
                "name": None,
            }
        }