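"""MonoScene: monocular 3D semantic scene completion (SSC), implemented as a
PyTorch Lightning module (model definition, losses, and optimization)."""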
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

from monoscene.loss.CRP_loss import compute_super_CP_multilabel_loss
from monoscene.loss.ssc_loss import CE_ssc_loss, KL_sep, geo_scal_loss, sem_scal_loss
from monoscene.loss.sscMetrics import SSCMetrics
from monoscene.models.flosp import FLoSP
from monoscene.models.unet2d import UNet2D
from monoscene.models.unet3d_kitti import UNet3D as UNet3DKitti
from monoscene.models.unet3d_nyu import UNet3D as UNet3DNYU


class MonoScene(pl.LightningModule):
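    """Monocular 3D semantic scene completion as a LightningModule.

    A 2D UNet encodes the input image, FLoSP back-projects multi-scale 2D
    features into the 3D scene volume, and a dataset-specific 3D UNet decodes
    the volume into per-voxel class logits.
    """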
def __init__(
self,
n_classes,
class_names,
feature,
class_weights,
project_scale,
full_scene_size,
dataset,
n_relations=4,
context_prior=True,
fp_loss=True,
        project_res=(),  # immutable default; avoids the mutable-default-argument pitfall
        frustum_size=4,
        relation_loss=False,
        # The three flags below intentionally share names with the imported
        # loss functions; the module-level functions remain reachable inside
        # step(), where the actual loss computation happens.
        CE_ssc_loss=True,
        geo_scal_loss=True,
        sem_scal_loss=True,
lr=1e-4,
weight_decay=1e-4,
):
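        """Build the network and store the loss/optimization configuration.

        The boolean arguments (relation_loss, CE_ssc_loss, geo_scal_loss,
        sem_scal_loss, fp_loss) toggle the corresponding loss terms in step().
        """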
super().__init__()
self.project_res = project_res
self.fp_loss = fp_loss
self.dataset = dataset
self.context_prior = context_prior
self.frustum_size = frustum_size
self.class_names = class_names
self.relation_loss = relation_loss
self.CE_ssc_loss = CE_ssc_loss
self.sem_scal_loss = sem_scal_loss
self.geo_scal_loss = geo_scal_loss
self.project_scale = project_scale
self.class_weights = class_weights
self.lr = lr
self.weight_decay = weight_decay
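        # One FLoSP projection module per 2D feature scale; each lifts 2D
        # features into the 3D scene volume at the configured project_scale.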
self.projects = {}
self.scale_2ds = [1, 2, 4, 8] # 2D scales
for scale_2d in self.scale_2ds:
self.projects[str(scale_2d)] = FLoSP(
full_scene_size, project_scale=self.project_scale, dataset=self.dataset
)
self.projects = nn.ModuleDict(self.projects)
self.n_classes = n_classes
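        # Dataset-specific 3D decoder: NYU and (Semantic)KITTI use different
        # scene volumes and decoder variants.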
if self.dataset == "NYU":
self.net_3d_decoder = UNet3DNYU(
self.n_classes,
nn.BatchNorm3d,
n_relations=n_relations,
feature=feature,
full_scene_size=full_scene_size,
context_prior=context_prior,
)
        elif self.dataset == "kitti":
            self.net_3d_decoder = UNet3DKitti(
                self.n_classes,
                nn.BatchNorm3d,
                project_scale=project_scale,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        else:
            raise ValueError("Unsupported dataset: {!r}".format(self.dataset))
self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True)
# log hyperparameters
self.save_hyperparameters()
self.train_metrics = SSCMetrics(self.n_classes)
self.val_metrics = SSCMetrics(self.n_classes)
self.test_metrics = SSCMetrics(self.n_classes)
def forward(self, batch):
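        """Encode the image, lift 2D features to 3D with FLoSP, and decode.

        Features from every scale listed in project_res are projected into the
        same 3D volume and summed before the 3D decoder runs.
        """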
img = batch["img"]
bs = len(img)
x_rgb = self.net_rgb(img)
x3ds = []
for i in range(bs):
x3d = None
for scale_2d in self.project_res:
# project features at each 2D scale to target 3D scale
scale_2d = int(scale_2d)
                # Index per-sample tensors and move them to the input's device
                # instead of assuming CUDA is available.
                projected_pix = batch["projected_pix_{}".format(self.project_scale)][i].to(img.device)
                fov_mask = batch["fov_mask_{}".format(self.project_scale)][i].to(img.device)
# Sum all the 3D features
if x3d is None:
x3d = self.projects[str(scale_2d)](
x_rgb["1_" + str(scale_2d)][i],
projected_pix // scale_2d,
fov_mask,
)
else:
x3d += self.projects[str(scale_2d)](
x_rgb["1_" + str(scale_2d)][i],
projected_pix // scale_2d,
fov_mask,
)
x3ds.append(x3d)
input_dict = {
"x3d": torch.stack(x3ds),
}
out = self.net_3d_decoder(input_dict)
return out
def step(self, batch, step_type, metric):
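        """Shared train/val/test logic: accumulate the enabled loss terms,
        log each one, and update the split's SSCMetrics accumulator."""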
bs = len(batch["img"])
loss = 0
out_dict = self(batch)
ssc_pred = out_dict["ssc_logit"]
target = batch["target"]
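        # Context Relation Prior: supervise the predicted relation logits
        # against the ground-truth CP mega matrices with a multilabel loss.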
if self.context_prior:
P_logits = out_dict["P_logits"]
CP_mega_matrices = batch["CP_mega_matrices"]
if self.relation_loss:
loss_rel_ce = compute_super_CP_multilabel_loss(
P_logits, CP_mega_matrices
)
loss += loss_rel_ce
self.log(
step_type + "/loss_relation_ce_super",
loss_rel_ce.detach(),
on_epoch=True,
sync_dist=True,
)
class_weight = self.class_weights.type_as(batch["img"])
if self.CE_ssc_loss:
loss_ssc = CE_ssc_loss(ssc_pred, target, class_weight)
loss += loss_ssc
self.log(
step_type + "/loss_ssc",
loss_ssc.detach(),
on_epoch=True,
sync_dist=True,
)
if self.sem_scal_loss:
loss_sem_scal = sem_scal_loss(ssc_pred, target)
loss += loss_sem_scal
self.log(
step_type + "/loss_sem_scal",
loss_sem_scal.detach(),
on_epoch=True,
sync_dist=True,
)
if self.geo_scal_loss:
loss_geo_scal = geo_scal_loss(ssc_pred, target)
loss += loss_geo_scal
self.log(
step_type + "/loss_geo_scal",
loss_geo_scal.detach(),
on_epoch=True,
sync_dist=True,
)
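        # Frustum Proportion (FP) loss: within each image frustum, align the
        # predicted class probability mass with the ground-truth class
        # distribution via a KL term; skipped at test time.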
if self.fp_loss and step_type != "test":
frustums_masks = torch.stack(batch["frustums_masks"])
frustums_class_dists = torch.stack(
batch["frustums_class_dists"]
).float() # (bs, n_frustums, n_classes)
n_frustums = frustums_class_dists.shape[1]
pred_prob = F.softmax(ssc_pred, dim=1)
batch_cnt = frustums_class_dists.sum(0) # (n_frustums, n_classes)
frustum_loss = 0
frustum_nonempty = 0
for frus in range(n_frustums):
frustum_mask = frustums_masks[:, frus, :, :, :].unsqueeze(1).float()
prob = frustum_mask * pred_prob # bs, n_classes, H, W, D
prob = prob.reshape(bs, self.n_classes, -1).permute(1, 0, 2)
prob = prob.reshape(self.n_classes, -1)
cum_prob = prob.sum(dim=1) # n_classes
total_cnt = torch.sum(batch_cnt[frus])
total_prob = prob.sum()
if total_prob > 0 and total_cnt > 0:
frustum_target_proportion = batch_cnt[frus] / total_cnt
cum_prob = cum_prob / total_prob # n_classes
frustum_loss_i = KL_sep(cum_prob, frustum_target_proportion)
frustum_loss += frustum_loss_i
frustum_nonempty += 1
            # Guard against the degenerate case where every frustum is empty,
            # which would otherwise divide by zero.
            if frustum_nonempty > 0:
                frustum_loss = frustum_loss / frustum_nonempty
                loss += frustum_loss
                self.log(
                    step_type + "/loss_frustums",
                    frustum_loss.detach(),
                    on_epoch=True,
                    sync_dist=True,
                )
y_true = target.cpu().numpy()
y_pred = ssc_pred.detach().cpu().numpy()
y_pred = np.argmax(y_pred, axis=1)
metric.add_batch(y_pred, y_true)
self.log(step_type + "/loss", loss.detach(), on_epoch=True, sync_dist=True)
return loss
def training_step(self, batch, batch_idx):
return self.step(batch, "train", self.train_metrics)
def validation_step(self, batch, batch_idx):
self.step(batch, "val", self.val_metrics)
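    # The *_epoch_end hooks below follow the pytorch_lightning < 2.0 API;
    # Lightning >= 2.0 replaced them with on_validation_epoch_end /
    # on_test_epoch_end, which take no outputs argument.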
def validation_epoch_end(self, outputs):
metric_list = [("train", self.train_metrics), ("val", self.val_metrics)]
for prefix, metric in metric_list:
stats = metric.get_stats()
for i, class_name in enumerate(self.class_names):
self.log(
"{}_SemIoU/{}".format(prefix, class_name),
stats["iou_ssc"][i],
sync_dist=True,
)
self.log("{}/mIoU".format(prefix), stats["iou_ssc_mean"], sync_dist=True)
self.log("{}/IoU".format(prefix), stats["iou"], sync_dist=True)
self.log("{}/Precision".format(prefix), stats["precision"], sync_dist=True)
self.log("{}/Recall".format(prefix), stats["recall"], sync_dist=True)
metric.reset()
def test_step(self, batch, batch_idx):
self.step(batch, "test", self.test_metrics)
def test_epoch_end(self, outputs):
classes = self.class_names
metric_list = [("test", self.test_metrics)]
for prefix, metric in metric_list:
print("{}======".format(prefix))
stats = metric.get_stats()
print(
"Precision={:.4f}, Recall={:.4f}, IoU={:.4f}".format(
stats["precision"] * 100, stats["recall"] * 100, stats["iou"] * 100
)
)
print("class IoU: {}, ".format(classes))
print(
" ".join(["{:.4f}, "] * len(classes)).format(
*(stats["iou_ssc"] * 100).tolist()
)
)
print("mIoU={:.4f}".format(stats["iou_ssc_mean"] * 100))
metric.reset()
    def configure_optimizers(self):
        # NYU and kitti use the identical optimization recipe, so the original
        # duplicated branches collapse into one: AdamW with a 10x LR drop at
        # epoch 20.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = MultiStepLR(optimizer, milestones=[20], gamma=0.1)
        return [optimizer], [scheduler]
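

if __name__ == "__main__":
    # Minimal construction sketch, not part of the original training code.
    # The NYU-style values below (12 classes, feature width 200, a 60x36x60
    # scene volume, project_scale 1) are assumptions for illustration only;
    # the real values live in the project's configuration files.
    n_classes = 12
    model = MonoScene(
        n_classes=n_classes,
        class_names=["class_{}".format(i) for i in range(n_classes)],
        feature=200,
        class_weights=torch.ones(n_classes),
        project_scale=1,
        full_scene_size=(60, 36, 60),
        dataset="NYU",
        project_res=["1", "2", "4", "8"],
    )
    print(sum(p.numel() for p in model.parameters()), "parameters")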