"""Gradio demo comparing Mean Teacher and M3L Linear Fusion segmentation models.

At import time this script loads two ``mit_b2`` Linear Fusion checkpoints (a
vanilla Mean Teacher teacher and an M3L teacher) onto the CPU, then builds and
launches a Gradio Blocks UI. The UI takes an RGB and/or depth image and renders
the 13-class prediction of each selected model.
"""
import os
import pickle as pkl
import random

import gradio as gr
import numpy as np
import torch
import torch.nn as nn

from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from datasets.preprocessors import RGBDValPre
from utils.constants import Constants as C

# Number of segmentation classes both checkpoints were built with.
NUM_CLASSES = 13


class Arguments:
    """Minimal stand-in for the argument namespace passed to the model constructors."""

    def __init__(self, ratio):
        self.ratio = ratio
        self.masking_ratio = 1.0


# colors.pkl is a trusted, bundled asset. Never point pickle.load at user-supplied files.
with open('./colors.pkl', 'rb') as colors_file:
    colors = pkl.load(colors_file)

args = Arguments(ratio=0.8)

mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=NUM_CLASSES, pretrained=False)
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu')))
mtmodel.eval()

m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=NUM_CLASSES, pretrained=False)
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu')))
m3lmodel.eval()


class MaskStudentTeacher(nn.Module):
    """Student/teacher pair whose teacher tracks the student by EMA.

    The teacher's parameters are detached at construction and are only ever
    updated through ``update_teacher_models`` / ``copy_student_to_teacher``.
    """

    def __init__(self, student, teacher, ema_alpha, mode='train'):
        super().__init__()
        self.student = student
        self.teacher = self._detach_teacher(teacher)
        self.ema_alpha = ema_alpha
        self.mode = mode

    def forward(self, data, student=True, teacher=True, mask=False, range_batches_to_mask=None, **kwargs):
        """Run the student and/or teacher on ``data``.

        Returns:
            A list holding the student output (if ``student``) followed by the
            teacher output (if ``teacher``).

        Raises:
            ValueError: if ``self.mode`` is neither 'train' nor 'val'.
        """
        ret = []
        if student:
            if self.mode == 'train':
                ret.append(self.student(data, mask=mask, range_batches_to_mask=range_batches_to_mask, **kwargs))
            elif self.mode == 'val':
                ret.append(self.student(data, mask=False, **kwargs))
            else:
                raise ValueError('Mode not supported')
        if teacher:
            # The teacher is never masked. No loss is computed for it, but its
            # result is passed along as if a loss were also returned.
            ret.append(self.teacher(data, mask=False, **kwargs))
        return ret

    def _detach_teacher(self, model):
        """Detach every parameter of ``model`` in place and return it."""
        for param in model.parameters():
            param.detach_()
        return model

    def update_teacher_models(self, global_step):
        """EMA-update the teacher: ``t = alpha * t + (1 - alpha) * s``.

        ``alpha`` ramps up as ``1 - 1 / (global_step + 1)`` and is capped at
        ``self.ema_alpha``.
        """
        alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor) overload.
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def copy_student_to_teacher(self):
        """Overwrite every teacher parameter with the matching student parameter."""
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.copy_(param.data)

    def get_params(self):
        """Return the student's parameter groups (the teacher is not optimised)."""
        return self.student.get_params()


def preprocess_data(rgb, depth, dataset_settings):
    """Apply validation preprocessing and convert the results to float tensors.

    Args:
        rgb: RGB image as a numpy array, or None.
        depth: depth image as a numpy array (min-max normalised, *255), or None.
        dataset_settings: size settings forwarded to ``RGBDValPre``.

    Returns:
        ``(rgb, depth)`` as float tensors; an input that was None stays None.
    """
    preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
    rgb, depth = preprocess(rgb, depth)
    if rgb is not None:
        rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
    if depth is not None:
        depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
    return rgb, depth


def visualize(colors, pred, num_classes, dataset_settings):
    """Colour a label map with ``colors`` and return it as a float image in [0, 1].

    ``pred`` is transposed with axes (1, 2, 0), so it must be 3-D; the caller
    passes a batch of one label map. The colour channels are reversed at the
    end, presumably because ``colors`` stores BGR triplets (TODO confirm).
    """
    pred = pred.transpose(1, 2, 0)
    predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3))
    for i in range(num_classes):
        predvis = np.where(pred == i, colors[i], predvis)
    predvis /= 255.0
    return predvis[:, :, ::-1]


def _scores_to_vis(scores, dataset_settings):
    """Upsample 4-D class scores to the original size and colour their argmax."""
    scores = torch.nn.functional.interpolate(
        scores,
        size=(dataset_settings["orig_height"], dataset_settings["orig_width"]),
        mode='bilinear',
        align_corners=True,
    )
    _, pred = torch.max(scores.detach(), dim=1)
    return visualize(colors, pred.numpy(), num_classes=NUM_CLASSES, dataset_settings=dataset_settings)


def predict(rgb, depth, check):
    """Run the selected models on an RGB and/or depth image.

    Args:
        rgb: RGB image as a numpy array, or None if not provided.
        depth: depth image as a numpy array, or None if not provided.
        check: selected model names ("Mean Teacher", "M3L"); None is treated
            as an empty selection.

    Returns:
        ``[mean_teacher_vis, m3l_vis, legend_path]``; a visualisation is None
        when its model was not selected.

    Raises:
        gr.Error: if neither an RGB nor a depth image was provided.
    """
    if rgb is None and depth is None:
        raise gr.Error("Please provide an RGB image and/or a depth image.")
    check = check or []

    dataset_settings = {}
    dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
    dataset_settings['orig_height'], dataset_settings['orig_width'] = 540, 540

    rgb, depth = preprocess_data(rgb, depth, dataset_settings)
    if rgb is not None:
        rgb = rgb.unsqueeze(dim=0)
    if depth is not None:
        depth = depth.unsqueeze(dim=0)

    ret = [None, None, './classcolors.png']
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        if "Mean Teacher" in check:
            # The Mean Teacher model is called without masking, so a missing
            # modality is replaced by zeros. Separate locals keep rgb/depth
            # untouched, so the M3L branch below still knows which input is absent.
            mt_rgb = rgb if rgb is not None else torch.zeros_like(depth)
            mt_depth = depth if depth is not None else torch.zeros_like(mt_rgb)
            scores = mtmodel([mt_rgb, mt_depth])[2]
            ret[0] = _scores_to_vis(scores, dataset_settings)

        if "M3L" in check:
            # Mask the missing modality: branch 0 when RGB is absent, 1 when depth is.
            if rgb is None:
                mask, masking_branch = True, 0
            elif depth is None:
                mask, masking_branch = True, 1
            else:
                mask, masking_branch = False, None
            scores = m3lmodel([rgb, depth], mask=mask, masking_branch=masking_branch)[2]
            ret[1] = _scores_to_vis(scores, dataset_settings)
    return ret


imgs = os.listdir('./examples/rgb')
random.shuffle(imgs)
examples = [
    ['./examples/rgb/' + img, './examples/depth/' + img, ["M3L", "Mean Teacher"]]
    for img in imgs
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        gr.Markdown(
            """

M3L

Multi-modal teacher for Masked Modality Learning

Demo to visualize predictions from the Linear Fusion model trained with the vanilla Mean Teacher and the M3L framework when trained with 0.2% (98) labels.
"""
        )
    with gr.Row():
        rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
        depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
    with gr.Row():
        modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    with gr.Row():
        mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
        m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
        classnameoutput = gr.Image(label="Classes").style(height=384, width=384)
    with gr.Row():
        gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
    with gr.Row():
        gr.Markdown(
            """ Read more about [M3L](https://harshm121.github.io/projects/m3l.html)! """
        )

    submit_btn.click(fn=predict, inputs=[rgbinput, depthinput, modelcheck], outputs=[mtoutput, m3loutput, classnameoutput])

demo.queue(concurrency_count=3)
demo.launch()