from utils import *
from modules import *
from data import *
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.multiprocessing
import seaborn as sns
import unet


class LitUnsupervisedSegmenter(pl.LightningModule):
    """Lightning module for unsupervised segmentation with linear and cluster probes.

    The backbone ``self.net`` maps an image batch to a ``(feats, code)`` pair.
    ``code`` is trained with the objectives enabled in ``cfg``:
    contrastive correlation, reconstruction, augmentation alignment and CRF.
    Two probes are trained on a detached copy of ``code``:

    * a UNet "linear" probe, supervised with the ground-truth labels;
    * a cluster probe.

    Optimization is manual (``automatic_optimization = False``) and uses three
    optimizers, one for the backbone and one per probe.
    """

    def __init__(self, n_classes, cfg):
        """Build the backbone, probes, losses and metrics.

        Args:
            n_classes: Number of ground-truth classes.
            cfg: Experiment config. Reads, among others, ``continuous``,
                ``dim``, ``output_root``, ``arch``, ``granularity``,
                ``model_type``, ``extra_clusters``, the CRF loss settings and
                ``dataset_name``.

        Raises:
            ValueError: If ``cfg.arch`` is neither ``"feature-pyramid"`` nor ``"dino"``.
        """
        super().__init__()
        self.name = "LitUnsupervisedSegmenter"
        self.cfg = cfg
        self.n_classes = n_classes

        # Discrete codes use one channel per class; continuous codes use cfg.dim.
        dim = cfg.dim if cfg.continuous else n_classes

        data_dir = join(cfg.output_root, "data")
        if cfg.arch == "feature-pyramid":
            cut_model = load_model(cfg.model_type, data_dir).cuda()
            self.net = FeaturePyramidNet(
                cfg.granularity, cut_model, dim, cfg.continuous
            )
        elif cfg.arch == "dino":
            self.net = DinoFeaturizer(dim, cfg)
        else:
            raise ValueError("Unknown arch {}".format(cfg.arch))

        self.train_cluster_probe = ClusterLookup(dim, n_classes)
        self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)

        # The UNet probe gets the RGB image (3 encoder input channels) and the
        # `dim`-channel code as its auxiliary input (see training_step).
        # aux_ch used to be hard-coded to 70, which only matched a 70-dim code.
        self.linear_probe = unet.AuxUNet(
            enc_chs=(3, 32, 64, 128, 256),
            dec_chs=(256, 128, 64, 32),
            aux_ch=dim,
            num_class=n_classes,
        )

        self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))

        self.cluster_metrics = UnsupervisedMetrics(
            "test/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
        self.test_cluster_metrics = UnsupervisedMetrics(
            "final/cluster/", n_classes, cfg.extra_clusters, True
        )
        self.test_linear_metrics = UnsupervisedMetrics(
            "final/linear/", n_classes, 0, False
        )

        self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
        self.crf_loss_fn = ContrastiveCRFLoss(
            cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
        )
        self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
        # The correlation loss acts as a fixed criterion: freeze any parameters it owns.
        for p in self.contrastive_corr_loss_fn.parameters():
            p.requires_grad = False

        # All three optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        if self.cfg.dataset_name.startswith("cityscapes"):
            self.label_cmap = create_cityscapes_colormap()
        else:
            self.label_cmap = create_pascal_label_colormap()

        self.val_steps = 0
        self.save_hyperparameters()

    def forward(self, x):
        """Return the segmentation code for ``x`` (the second output of ``self.net``)."""
        return self.net(x)[1]

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step over the backbone and both probes.

        Args:
            batch: Dict containing ``img``, ``img_aug``, ``coord_aug``,
                ``img_pos``, ``label`` and ``label_pos``. When
                ``cfg.use_salience`` is set it must also contain ``mask`` and
                ``mask_pos``.
            batch_idx: Index of the batch (unused).

        Returns:
            The total weighted loss that was back-propagated.
        """
        net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()

        net_optim.zero_grad()
        linear_probe_optim.zero_grad()
        cluster_probe_optim.zero_grad()

        img = batch["img"]
        img_aug = batch["img_aug"]
        coord_aug = batch["coord_aug"]
        img_pos = batch["img_pos"]
        label = batch["label"]
        label_pos = batch["label_pos"]

        feats, code = self.net(img)
        # The positive-pair forward pass only runs when the correspondence loss
        # is enabled. Default to None so the signal selection below never reads
        # an unbound name when that loss is off.
        feats_pos = code_pos = None
        if self.cfg.correspondence_weight > 0:
            feats_pos, code_pos = self.net(img_pos)

        log_args = dict(sync_dist=False, rank_zero_only=True)

        if self.cfg.use_true_labels:
            signal = one_hot_feats(label + 1, self.n_classes + 1)
            signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
        else:
            signal = feats
            signal_pos = feats_pos

        loss = 0

        should_log_hist = (
            (self.cfg.hist_freq is not None)
            and (self.global_step % self.cfg.hist_freq == 0)
            and (self.global_step > 0)
        )
        if self.cfg.use_salience:
            salience = batch["mask"].to(torch.float32).squeeze(1)
            salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
        else:
            salience = None
            salience_pos = None

        if self.cfg.correspondence_weight > 0:
            (
                pos_intra_loss,
                pos_intra_cd,
                pos_inter_loss,
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd,
            ) = self.contrastive_corr_loss_fn(
                signal,
                signal_pos,
                salience,
                salience_pos,
                code,
                code_pos,
            )

            if should_log_hist:
                self.logger.experiment.add_histogram(
                    "intra_cd", pos_intra_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "inter_cd", pos_inter_cd, self.global_step
                )
                self.logger.experiment.add_histogram(
                    "neg_cd", neg_inter_cd, self.global_step
                )
            neg_inter_loss = neg_inter_loss.mean()
            pos_intra_loss = pos_intra_loss.mean()
            pos_inter_loss = pos_inter_loss.mean()
            self.log("loss/pos_intra", pos_intra_loss, **log_args)
            self.log("loss/pos_inter", pos_inter_loss, **log_args)
            self.log("loss/neg_inter", neg_inter_loss, **log_args)
            self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
            self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
            self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)

            loss += (
                self.cfg.pos_inter_weight * pos_inter_loss
                + self.cfg.pos_intra_weight * pos_intra_loss
                + self.cfg.neg_inter_weight * neg_inter_loss
            ) * self.cfg.correspondence_weight

        if self.cfg.rec_weight > 0:
            rec_feats = self.decoder(code)
            # Negative channel-wise dot product of the normalized tensors
            # (a cosine similarity, assuming `norm` L2-normalizes along dim 1).
            rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
            self.log("loss/rec", rec_loss, **log_args)
            loss += self.cfg.rec_weight * rec_loss

        if self.cfg.aug_alignment_weight > 0:
            orig_feats_aug, orig_code_aug = self.net(img_aug)
            # Bring the augmentation coordinates to the code's spatial size.
            downsampled_coord_aug = resize(
                coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
            ).permute(0, 2, 3, 1)
            aug_alignment = -torch.einsum(
                "bkhw,bkhw->bhw",
                norm(sample(code, downsampled_coord_aug)),
                norm(orig_code_aug),
            ).mean()
            self.log("loss/aug_alignment", aug_alignment, **log_args)
            loss += self.cfg.aug_alignment_weight * aug_alignment

        if self.cfg.crf_weight > 0:
            crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
            self.log("loss/crf", crf, **log_args)
            loss += self.cfg.crf_weight * crf

        flat_label = label.reshape(-1)
        # Only pixels carrying a valid class index contribute to the probe loss.
        mask = (flat_label >= 0) & (flat_label < self.n_classes)

        # The probes see a detached copy, so their losses do not reach the backbone.
        detached_code = torch.clone(code.detach())

        linear_logits = self.linear_probe(img, detached_code)
        linear_logits = F.interpolate(
            linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
        )
        linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
        linear_loss = self.linear_probe_loss_fn(
            linear_logits[mask], flat_label[mask]
        ).mean()
        loss += linear_loss
        self.log("loss/linear", linear_loss, **log_args)

        cluster_loss, _ = self.cluster_probe(detached_code, None)
        loss += cluster_loss
        self.log("loss/cluster", cluster_loss, **log_args)
        self.log("loss/total", loss, **log_args)

        self.manual_backward(loss)
        net_optim.step()
        cluster_probe_optim.step()
        linear_probe_optim.step()

        if (
            self.cfg.reset_probe_steps is not None
            and self.global_step == self.cfg.reset_probe_steps
        ):
            print("RESETTING PROBES")
            # NOTE(review): assumes both probe modules implement reset_parameters().
            self.linear_probe.reset_parameters()
            self.cluster_probe.reset_parameters()
            # Fresh optimizers, so Adam state from before the reset is not reused.
            self.trainer.optimizers[1] = torch.optim.Adam(
                list(self.linear_probe.parameters()), lr=5e-3
            )
            self.trainer.optimizers[2] = torch.optim.Adam(
                list(self.cluster_probe.parameters()), lr=5e-3
            )

        if self.global_step % 2000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            # Close the writer and re-open it through the (private) SummaryWriter
            # hook so that logging continues in a new tfevent file.
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss

    def on_train_start(self):
        """Register the hyperparameters together with the initial metric values."""
        tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
        self.logger.log_hyperparams(self.cfg, tb_metrics)

    def validation_step(self, batch, batch_idx):
        """Update validation metrics for one batch and return samples for plotting.

        Args:
            batch: Dict containing ``img`` and ``label``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict of CPU tensors ``img``, ``linear_preds``, ``cluster_preds`` and
            ``label``, each truncated to the first ``cfg.n_images`` samples.
        """
        img = batch["img"]
        label = batch["label"]
        self.net.eval()

        with torch.no_grad():
            feats, code = self.net(img)

            linear_logits = self.linear_probe(img, code)
            # Resize to the label resolution exactly as training_step does.
            # This is an identity when the probe already outputs that size.
            linear_logits = F.interpolate(
                linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
            )
            linear_preds = linear_logits.argmax(1)
            self.linear_metrics.update(linear_preds, label)

            code = F.interpolate(
                code, label.shape[-2:], mode="bilinear", align_corners=False
            )
            cluster_loss, cluster_preds = self.cluster_probe(code, None)
            cluster_preds = cluster_preds.argmax(1)
            self.cluster_metrics.update(cluster_preds, label)

            return {
                "img": img[: self.cfg.n_images].detach().cpu(),
                "linear_preds": linear_preds[: self.cfg.n_images].detach().cpu(),
                "cluster_preds": cluster_preds[: self.cfg.n_images].detach().cpu(),
                "label": label[: self.cfg.n_images].detach().cpu(),
            }

    def validation_epoch_end(self, outputs) -> None:
        """Log validation metrics and (on rank zero) diagnostic plots, then reset metrics.

        Args:
            outputs: List of dicts returned by :meth:`validation_step`.
        """
        super().validation_epoch_end(outputs)
        with torch.no_grad():
            tb_metrics = {
                **self.linear_metrics.compute(),
                **self.cluster_metrics.compute(),
            }

            if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
                # random.randint(0, -1) raises, so skip the sample grid for an empty epoch.
                if outputs:
                    output_num = random.randint(0, len(outputs) - 1)
                    output = {
                        k: v.detach().cpu() for k, v in outputs[output_num].items()
                    }
                    self._plot_prediction_grid(output)

                if self.cfg.has_labels:
                    self._plot_cluster_statistics()

            if self.global_step > 2:
                self.log_dict(tb_metrics)

                if self.trainer.is_global_zero and self.cfg.azureml_logging:
                    from azureml.core.run import Run

                    run_logger = Run.get_context()
                    for metric, value in tb_metrics.items():
                        run_logger.log(metric, value)

            self.linear_metrics.reset()
            self.cluster_metrics.reset()

    def _plot_prediction_grid(self, output):
        """Log a grid of inputs, labels and raw/retouched probe predictions.

        Rows (top to bottom): image, label overlay, UNet probe, retouched UNet
        probe, cluster probe, retouched cluster probe. There is one column per
        sample, up to ``cfg.n_images``.
        """
        alpha = 0.4
        n_rows = 6
        # squeeze=False keeps `ax` two-dimensional even when cfg.n_images == 1.
        fig, ax = plt.subplots(
            n_rows,
            self.cfg.n_images,
            figsize=(self.cfg.n_images * 3, n_rows * 3),
            squeeze=False,
        )
        # NOTE(review): `cmap` and `norm` are not defined in this file; they
        # presumably come from one of the star imports. `norm` is also called as
        # a tensor-normalization function in training_step. Confirm that the
        # name resolves to something `imshow(norm=...)` accepts.
        for i in range(self.cfg.n_images):
            try:
                rgb_img = prep_for_plot(output["img"][i])
                # Index the sample before squeezing, so a batch of one keeps its
                # spatial shape. Clone so the stored outputs are not modified.
                true_label = output["label"][i].squeeze().clone()
                # Paint ignored pixels (-1) with index 7 for display.
                true_label[true_label == -1] = 7
            except Exception:
                # Best effort: skip columns without a sample (e.g. a batch
                # smaller than n_images).
                continue

            ax[0, i].imshow(rgb_img)

            ax[1, i].imshow(rgb_img)
            ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)

            ax[2, i].imshow(rgb_img)
            pred_label = output["linear_preds"][i]
            ax[2, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)

            ax[3, i].imshow(rgb_img)
            retouched_label = retouch_label(pred_label.numpy(), true_label)
            ax[3, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)

            ax[4, i].imshow(rgb_img)
            pred_label = self.cluster_metrics.map_clusters(output["cluster_preds"][i])
            ax[4, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)

            ax[5, i].imshow(rgb_img)
            retouched_label = retouch_label(pred_label.numpy(), true_label)
            ax[5, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)

        ax[0, 0].set_ylabel("Image", fontsize=16)
        ax[1, 0].set_ylabel("Label", fontsize=16)
        ax[2, 0].set_ylabel("UNet Probe", fontsize=16)
        ax[3, 0].set_ylabel("Retouched UNet Probe", fontsize=16)
        ax[4, 0].set_ylabel("Cluster Probe", fontsize=16)
        ax[5, 0].set_ylabel("Retouched cluster Probe", fontsize=16)
        remove_axes(ax)
        plt.tight_layout()
        add_plot(self.logger.experiment, "plot_labels", self.global_step)

    def _plot_cluster_statistics(self):
        """Log the normalized cluster/label confusion matrix and frequency bar charts."""
        fig = plt.figure(figsize=(13, 10))
        ax = fig.gca()
        hist = self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
        # Normalize each column to sum to 1; the clamp avoids division by zero.
        hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
        sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
        ax.set_xlabel("Predicted labels")
        ax.set_ylabel("True labels")

        names = get_class_labels(self.cfg.dataset_name)
        if self.cfg.extra_clusters:
            names = names + ["Extra"]
        ax.set_xticks(np.arange(0, len(names)) + 0.5)
        ax.set_yticks(np.arange(0, len(names)) + 0.5)
        ax.xaxis.tick_top()
        ax.xaxis.set_ticklabels(names, fontsize=14)
        ax.yaxis.set_ticklabels(names, fontsize=14)

        colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
        for tick_label, color in zip(ax.xaxis.get_ticklabels(), colors):
            tick_label.set_color(color)
        for tick_label, color in zip(ax.yaxis.get_ticklabels(), colors):
            tick_label.set_color(color)

        plt.xticks(rotation=90)
        plt.yticks(rotation=0)
        grid_color = [0.5, 0.5, 0.5]
        ax.vlines(np.arange(0, len(names) + 1), *ax.get_xlim(), color=grid_color)
        ax.hlines(np.arange(0, len(names) + 1), *ax.get_ylim(), color=grid_color)
        plt.tight_layout()
        add_plot(self.logger.experiment, "conf_matrix", self.global_step)

        label_counts = self.cluster_metrics.histogram.sum(0).cpu()
        cluster_counts = self.cluster_metrics.histogram.sum(1).cpu()
        all_bars = torch.cat([label_counts, cluster_counts], dim=0)
        # Shared y-range for both panels (log scale, so keep the floor at >= 1).
        ymin = max(all_bars.min() * 0.8, 1)
        ymax = all_bars.max() * 1.2

        fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
        n_bars = self.n_classes + self.cfg.extra_clusters
        panels = (
            (ax[0], label_counts, "Label Frequency"),
            (ax[1], cluster_counts, "Cluster Frequency"),
        )
        for panel, counts, title in panels:
            panel.bar(range(n_bars), counts, tick_label=names, color=colors)
            panel.set_ylim(ymin, ymax)
            panel.set_title(title)
            panel.set_yscale("log")
            panel.tick_params(axis="x", labelrotation=90)
        plt.tight_layout()
        add_plot(self.logger.experiment, "label frequency", self.global_step)

    def configure_optimizers(self):
        """Create the backbone optimizer and one Adam optimizer per probe.

        The decoder's parameters join the backbone optimizer only when the
        reconstruction loss is enabled.

        Returns:
            ``(net_optim, linear_probe_optim, cluster_probe_optim)``, the order
            unpacked in training_step.
        """
        main_params = list(self.net.parameters())
        if self.cfg.rec_weight > 0:
            main_params.extend(self.decoder.parameters())

        net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
        linear_probe_optim = torch.optim.Adam(
            list(self.linear_probe.parameters()), lr=5e-3
        )
        cluster_probe_optim = torch.optim.Adam(
            list(self.cluster_probe.parameters()), lr=5e-3
        )
        return net_optim, linear_probe_optim, cluster_probe_optim