import random
from statistics import mean
from typing import Tuple

import numpy as np
import torch as th
import pytorch_lightning as pl
from jaxtyping import Float, Int
from torch_geometric.nn.conv import GATv2Conv

from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh
# Constants
th.manual_seed(0)
np.random.seed(0)

BATCH_SIZE = 1     # BS
IN_DIM = 1         # scalar field value per grid node
OUT_DIM = 1        # predicted field correction per grid node
LATENT_DIM = 32
DROPOUT_PROB = 0.1
GRID_SIZE = 128    # GS: resolution of the PSR grid
LR = 1e-3          # learning rate used in configure_optimizers (assumed value; not defined in the original)
def generate_grid_edge_list(gs: int = 128):
    """Build the 6-connected neighbourhood edge list of a gs x gs x gs grid.

    Node (i, j, k) is flattened to index i + gs*j + k*gs*gs; every node is
    connected to its existing neighbours along the +/-i, +/-j and +/-k axes.
    """
    grid_edge_list = []
    for k in range(gs):
        for j in range(gs):
            for i in range(gs):
                current_idx = i + gs*j + k*gs*gs
                if (i - 1) >= 0:
                    grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])
                if (i + 1) < gs:
                    grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])
                if (j - 1) >= 0:
                    grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])
                if (j + 1) < gs:
                    grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])
                if (k - 1) >= 0:
                    grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])
                if (k + 1) < gs:
                    grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])
    return grid_edge_list
GRID_EDGE_LIST = generate_grid_edge_list(GRID_SIZE)
GRID_EDGE_LIST = th.tensor(GRID_EDGE_LIST, dtype=th.long)  # edge_index must be int64 for torch_geometric
GRID_EDGE_LIST = GRID_EDGE_LIST.t().contiguous()           # [2, E]
# GRID_EDGE_LIST = GRID_EDGE_LIST.to(th.device("cuda"))
GRID_EDGE_LIST.requires_grad = False  # edge indices are fixed; no gradients flow through them
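
# Sanity-check sketch for the flattening convention above (not part of the
# original pipeline): on a 2x2x2 grid every node is a corner with exactly
# three neighbours, and node (i=1, j=1, k=1) maps to index 1 + 2*1 + 1*2*2 = 7
# with its -i neighbour at index 6.
_demo_edges = generate_grid_edge_list(gs=2)
assert len(_demo_edges) == 8 * 3   # 8 nodes, 3 neighbours each
assert [7, 6] in _demo_edges       # (1,1,1) -> (0,1,1)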
class FormOptimizer(th.nn.Module):
    """Graph network that predicts a correction to a PSR indicator field."""

    def __init__(self) -> None:
        super().__init__()
        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
    def forward(self,
                field: Float[th.Tensor, "BS GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """
        Args:
            field (Tensor [BS, GS, GS, GS]): PSR indicator field on the regular grid.
        """
        # One scalar feature per grid node; the reshape assumes BATCH_SIZE == 1.
        vertex_features = field.clone()
        vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)
        vertex_features = self.gconv1(x=vertex_features, edge_index=GRID_EDGE_LIST)
        vertex_features = self.gconv2(x=vertex_features, edge_index=GRID_EDGE_LIST)
        field_delta = self.head(self.actv(vertex_features))
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        field_delta = field_delta + field  # residual connection; field_delta carries the gradient
        field_delta = th.clamp(field_delta, min=-0.5, max=0.5)
        return field_delta
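

def _gatv2_usage_sketch():
    """Usage sketch, never called by the pipeline: it mirrors the GATv2Conv
    call pattern of FormOptimizer.forward on a tiny 2x2x2 grid so it runs on
    CPU in milliseconds. Names and sizes here are illustrative assumptions."""
    tiny_gs = 2
    tiny_edges = th.tensor(generate_grid_edge_list(tiny_gs), dtype=th.long).t().contiguous()
    tiny_x = th.rand(tiny_gs ** 3, IN_DIM)                  # one scalar feature per grid node
    tiny_conv = GATv2Conv(in_channels=IN_DIM, out_channels=4, heads=1)
    tiny_out = tiny_conv(x=tiny_x, edge_index=tiny_edges)   # [num_nodes, out_channels]
    assert tiny_out.shape == (tiny_gs ** 3, 4)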
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()
        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply
        self.metric = th.nn.MSELoss()
        self.val_losses = []
        self.train_losses = []
    def log_h5(self, points, normals):
        # Expects `self.log_points_file` and `self.log_normals_file` to be open
        # h5py files and `self.h5_frame` to be an int counter, all set up
        # externally before this method is called.
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1
    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],        # v - vertices
                                                        Int[th.Tensor, "2 E"],              # f - faces
                                                        Float[th.Tensor, "BS N 3"],         # n - vertex normals
                                                        Float[th.Tensor, "BS GS GS GS"]]:   # field
        field = self.dpsr(v, n)             # differentiable Poisson surface reconstruction -> indicator field
        field = self.form_optimizer(field)  # GNN-predicted correction of the field
        v, f, n = self.field2mesh(field)    # mesh extraction from the field (PSR2Mesh)
        return v, f, n, field
    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        # Augmentation: randomly drop up to 50% of the input points.
        keep_prob = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1], ), device=vertices.device) < keep_prob
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        train_per_step_loss = loss.item()
        self.train_losses.append(train_per_step_loss)
        return loss
    def on_train_epoch_end(self):
        mean_train_per_epoch_loss = mean(self.train_losses)
        self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []
    def validation_step(self, batch, batch_idx):
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        vr, fr, nr, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        val_per_step_loss = loss.item()
        self.val_losses.append(val_per_step_loss)
        return loss
    def on_validation_epoch_end(self):
        mean_val_per_epoch_loss = mean(self.val_losses)
        self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []
    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If `True`, enforce that the value specified in `monitor` is
                # available when the scheduler is updated, stopping training
                # if it is not found. If `False`, only produce a warning.
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom name for the logged value.
                "name": None,
            }
        }
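

if __name__ == "__main__":
    # Minimal wiring sketch (an assumption, not part of the original training
    # script): fit the model on random placeholder data just to exercise the
    # forward/backward path. At GRID_SIZE = 128 this is compute- and
    # memory-heavy, so `fast_dev_run=True` limits it to one train/val batch.
    from torch.utils.data import DataLoader, TensorDataset

    num_samples, num_points = 2, 1024
    dummy_data = TensorDataset(
        th.rand(num_samples, num_points, 3),                    # vertices
        th.rand(num_samples, num_points, 3),                    # vertices_normals
        th.rand(num_samples, num_points, 3),                    # vertices_gt (unused by the loss)
        th.rand(num_samples, num_points, 3),                    # vertices_normals_gt (unused by the loss)
        th.rand(num_samples, GRID_SIZE, GRID_SIZE, GRID_SIZE),  # field_gt
        th.zeros(num_samples, 1),                               # adj placeholder
    )
    loader = DataLoader(dummy_data, batch_size=BATCH_SIZE)

    model = Model()
    trainer = pl.Trainer(fast_dev_run=True, accelerator="auto", devices=1)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)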