# Provenance (copied from the hosting page header): uploaded by MVV in
# "Upload 19 files", commit f74bb58, file size 7.31 kB, marked "No virus".
import random
from statistics import mean
from typing import List, Tuple
import torch as th
import pytorch_lightning as pl
from jaxtyping import Float, Int
import numpy as np
from torch_geometric.nn.conv import GATv2Conv
from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh
# Constants
# Seed every RNG this module draws from: torch, numpy, and the stdlib
# `random` used to pick the point-subsampling ratio in Model.training_step.
th.manual_seed(0)
np.random.seed(0)
random.seed(0)
BATCH_SIZE = 1  # BS: FormOptimizer reshapes its output to this batch size
IN_DIM = 1  # per-cell input feature size of the grid graph
OUT_DIM = 1  # per-cell output feature size of the grid graph
LATENT_DIM = 32  # hidden feature size of the GATv2 layers
DROPOUT_PROB = 0.1  # attention dropout of the GATv2 layers
GRID_SIZE = 128  # cells per axis; the grid graph has GRID_SIZE**3 nodes
# Learning rate for Adam in Model.configure_optimizers, which references `LR`.
# 1e-3 is torch.optim.Adam's documented default; tune as needed.
LR = 1e-3
def generate_grid_edge_list(gs: int = 128) -> List[List[int]]:
    """Return the directed 6-neighbourhood edges of a ``gs x gs x gs`` grid.

    Cells are numbered ``i + gs*j + gs*gs*k``. For every cell, in increasing
    index order, one ``[cell, neighbour]`` pair is emitted per existing
    neighbour, in the order -i, +i, -j, +j, -k, +k. Each undirected adjacency
    therefore appears twice, once per direction.

    Args:
        gs: Grid resolution along each axis. ``gs <= 1`` yields no edges.

    Returns:
        List of ``[source, target]`` pairs of Python ints.
    """
    if gs <= 1:
        return []
    # Vectorised with numpy: this runs at import time, and a pure-Python loop
    # over gs**3 cells (about 2.1M for gs=128) is very slow.
    stride_j, stride_k = gs, gs * gs
    cells = np.arange(gs ** 3, dtype=np.int64)
    i = cells % gs
    j = (cells // stride_j) % gs
    k = cells // stride_k
    offsets = np.array([-1, 1, -stride_j, stride_j, -stride_k, stride_k], dtype=np.int64)
    # valid[c, d] is True when cell c has a neighbour in direction d.
    # Directions follow the same order as `offsets`.
    valid = np.stack([i > 0, i < gs - 1, j > 0, j < gs - 1, k > 0, k < gs - 1], axis=1)
    targets = cells[:, None] + offsets
    sources = np.broadcast_to(cells[:, None], targets.shape)
    # Boolean-mask indexing walks the (cell, direction) array in row-major
    # order. That reproduces the cell-major, -i/+i/-j/+j/-k/+k edge order.
    edges = np.stack([sources[valid], targets[valid]], axis=1)
    return edges.tolist()
# Grid connectivity as a (2, E) integer tensor. Row 0 holds the source cell
# and row 1 the neighbouring cell, i.e. the transpose of the [src, dst] list.
# The tensor is built on the CPU; it must live on the same device as the node
# features it is used with.
GRID_EDGE_LIST = th.tensor(generate_grid_edge_list(GRID_SIZE), dtype=th.int).T
# Integer tensors can never require grad, so this only makes the intent explicit.
GRID_EDGE_LIST.requires_grad = False
class FormOptimizer(th.nn.Module):
    """Predict a residual update for a dense GRID_SIZE^3 field with a GNN.

    Every grid cell is a graph node with IN_DIM features, connected to its
    6-neighbourhood via GRID_EDGE_LIST. The network applies two GATv2
    convolutions, then a sigmoid and a linear head. The result is added to
    the input field and clamped to [-0.5, 0.5].
    """

    def __init__(self) -> None:
        super().__init__()
        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)
        # A buffer moves with `.to(device)` / Lightning, so the graph follows
        # the weights onto the GPU. The module-level tensor stays on the CPU.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load.
        self.register_buffer("edge_index", GRID_EDGE_LIST, persistent=False)

    def forward(self,
                field: Float[th.Tensor, "GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Refine a field.

        Args:
            field (Tensor): grid field of a single sample. It must hold exactly
                GRID_SIZE**3 * IN_DIM elements, e.g. shape (GS, GS, GS) or
                (1, GS, GS, GS).

        Returns:
            Tensor (BATCH_SIZE, GS, GS, GS): ``clamp(field + delta, -0.5, 0.5)``.
        """
        n_cells = GRID_SIZE * GRID_SIZE * GRID_SIZE
        # One row of node features per grid cell. The convolutions do not
        # modify their input, so no defensive clone is needed.
        vertex_features = field.reshape(n_cells, IN_DIM)
        vertex_features = self.gconv1(x=vertex_features, edge_index=self.edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=self.edge_index)
        field_delta = self.head(self.actv(vertex_features))
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        # Residual connection: the network predicts a correction to `field`.
        refined = field_delta + field
        return th.clamp(refined, min=-0.5, max=0.5)
class Model(pl.LightningModule):
    """Lightning module: oriented points -> DPSR field -> FormOptimizer -> mesh.

    Training and validation compare the refined field with the ground-truth
    field using MSE. Per-step losses are averaged and logged once per epoch.
    """

    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()
        # Maps (vertices, normals) to a GRID_SIZE^3 field (see forward).
        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        # Converts a field into (vertices, faces, normals) (see forward).
        self.field2mesh = PSR2Mesh().apply
        self.metric = th.nn.MSELoss()
        # Per-step loss values, averaged and cleared at the end of each epoch.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Store ``points`` and ``normals`` as new datasets, then advance the frame.

        Each array is written as a gzip-compressed float16 dataset named
        ``str(self.h5_frame)``. The frame counter is incremented afterwards.

        NOTE(review): ``log_points_file``, ``log_normals_file`` and ``h5_frame``
        are not initialised in this class, so callers must set them first.
        They are presumably open h5py files and an int counter; confirm.
        """
        dset = self.log_points_file.create_dataset(
            name=str(self.h5_frame),
            shape=points.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
            name=str(self.h5_frame),
            shape=normals.shape,
            dtype=np.float16,
            compression="gzip")
        dset[:] = normals
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],  # v - vertices
                                                         Int[th.Tensor, "2 E"],  # f - faces
                                                         Float[th.Tensor, "BS N 3"],  # n - vertices normals
                                                         Float[th.Tensor, "BS GR GR GR"]]:  # field
        """Reconstruct a mesh from oriented points.

        Returns:
            ``(v, f, n, field)``: the mesh vertices, faces and vertex normals
            produced by ``field2mesh``, plus the refined field they came from.
        """
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        """Run one training step and return the MSE loss on the field.

        As augmentation, only a random subset of the input points is fed to
        the model.
        """
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        # Keep each point independently with probability keep_ratio. The
        # ratio itself is drawn uniformly from [0.5, 1.0) every step. The mask
        # is created on the batch's device rather than a hard-coded "cuda".
        keep_ratio = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1],), device=vertices.device) < keep_ratio
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]
        # Call this module, not a global `model`.
        _, _, _, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.train_losses.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        # statistics.mean raises on an empty list; skip logging if no steps ran.
        if not self.train_losses:
            return
        mean_train_per_epoch_loss = mean(self.train_losses)
        self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        """Run one validation step on the full point set and return the MSE loss."""
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch
        _, _, _, field_r = self(vertices, vertices_normals)
        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        # statistics.mean raises on an empty list; skip logging if no steps ran.
        if not self.val_losses:
            return
        mean_val_per_epoch_loss = mean(self.val_losses)
        self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        """Set up Adam, with the LR halved after 5 epochs without val-loss improvement."""
        # Requires a module-level `LR` learning-rate constant.
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "mean_val_per_epoch_loss",
                "interval": "epoch",
                "frequency": 1,
                # If set to `True`, will enforce that the value specified 'monitor'
                # is available when the scheduler is updated, thus stopping
                # training if not found. If set to `False`, it will only produce a warning
                "strict": True,
                # If using the `LearningRateMonitor` callback to monitor the
                # learning rate progress, this keyword can be used to specify
                # a custom logged name
                "name": None,
            }
        }