# app.py: Gradio demo for M3L (Hugging Face Space by harshm121)
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import os
import random
import pickle as pkl
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from datasets.preprocessors import RGBDValPre
from utils.constants import Constants as C
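# Minimal stand-in for the training-time argument object; only the fields the model constructors use are set here.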
class Arguments:
def __init__(self, ratio):
self.ratio = ratio
self.masking_ratio = 1.0
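# Pickled per-class colors used to render segmentation maps (channel order is flipped to RGB in visualize()).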
with open('./colors.pkl', 'rb') as f:
    colors = pkl.load(f)
args = Arguments(ratio = 0.8)
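# Mean Teacher baseline: linear-fusion model with a MiT-B2 backbone, teacher checkpoint loaded on CPU for inference.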
mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False)
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu')))
mtmodel.eval()
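# M3L teacher: same backbone, trained with the masked-consistency mixed-batch (M3L) framework.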
m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False)
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu')))
m3lmodel.eval()
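# Student/teacher wrapper from training: the teacher is an exponential-moving-average (EMA) copy of the student
# and receives no gradients. It is kept here for reference and is not instantiated in this demo.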
class MaskStudentTeacher(nn.Module):
def __init__(self, student, teacher, ema_alpha, mode = 'train'):
super(MaskStudentTeacher, self).__init__()
self.student = student
self.teacher = teacher
self.teacher = self._detach_teacher(self.teacher)
self.ema_alpha = ema_alpha
self.mode = mode
def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
ret = []
if student:
if self.mode == 'train':
ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
elif self.mode == 'val':
ret.append(self.student(data, mask = False, **kwargs))
else:
raise Exception('Mode not supported')
if teacher:
            ret.append(self.teacher(data, mask = False, **kwargs))  # The teacher's loss is never used; its outputs are returned in the same format as the student's
return ret
def _detach_teacher(self, model):
for param in model.parameters():
param.detach_()
return model
def update_teacher_models(self, global_step):
alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
return
def copy_student_to_teacher(self):
for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.copy_(param.data)
return
    def get_params(self):
        # Only the student's parameter groups are optimized; the teacher is updated via EMA.
        student_params = self.student.get_params()
        return student_params
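# Sketch (an assumption, not exercised in this demo) of how the wrapper would be driven during training:
#   model = MaskStudentTeacher(student, teacher, ema_alpha=0.99, mode='train')
#   ...backprop through the student, optimizer.step()...
#   model.update_teacher_models(global_step)
# Convert the raw inputs into normalized float tensors; either modality may be None if the user omits it.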
def preprocess_data(rgb, depth, dataset_settings):
    # rgb: np.array in RGB channel order
    # depth: np.array, min-max normalized and scaled by 255
preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
rgb, depth = preprocess(rgb, depth)
if rgb is not None:
rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
if depth is not None:
depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
return rgb, depth
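# Map per-pixel class indices to display colors; returns an RGB image scaled to [0, 1].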
def visualize(colors, pred, num_classes, dataset_settings):
pred = pred.transpose(1, 2, 0)
predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3))
for i in range(num_classes):
color = colors[i]
predvis = np.where(pred == i, color, predvis)
predvis /= 255.0
predvis = predvis[:,:,::-1]
return predvis
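# Run the selected model(s) on the (possibly single-modality) input pair and return
# [Mean Teacher prediction, M3L prediction, class-color legend]; unselected outputs stay None.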
def predict(rgb, depth, check):
dataset_settings = {}
dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
dataset_settings['orig_height'], dataset_settings['orig_width'] = 540,540
rgb, depth = preprocess_data(rgb, depth, dataset_settings)
if rgb is not None:
rgb = rgb.unsqueeze(dim = 0)
if depth is not None:
depth = depth.unsqueeze(dim = 0)
ret = [None, None, './classcolors.png']
if "Mean Teacher" in check:
if rgb is None:
rgb = torch.zeros_like(depth)
if depth is None:
depth = torch.zeros_like(rgb)
scores = mtmodel([rgb, depth])[2]
scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
prob = scores.detach()
_, pred = torch.max(prob, dim=1)
pred = pred.numpy()
predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
ret[0] = predvis
if "M3L" in check:
mask = False
masking_branch = None
if rgb is None:
mask = True
masking_branch = 0
if depth is None:
mask = True
masking_branch = 1
scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2]
scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
prob = scores.detach()
_, pred = torch.max(prob, dim=1)
pred = pred.numpy()
predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
ret[1] = predvis
return ret
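# Build the example gallery from paired RGB/depth files that share a filename.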
imgs = os.listdir('./examples/rgb')
random.shuffle(imgs)
examples = []
for img in imgs:
examples.append([
'./examples/rgb/'+img, './examples/depth/'+img, ["M3L", "Mean Teacher"]
])
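# Gradio UI: RGB + depth inputs, model selection, and side-by-side predictions with the class legend.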
with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Row():
gr.Markdown(
"""
<center><h2>M3L</h2></center>
<center>Multi-modal teacher for Masked Modality Learning</center>
<br>
<center>Demo to visualize predictions from the Linear Fusion model trained with 0.2% (98) labels, using either the vanilla Mean Teacher or the <a href='https://harshm121.github.io/projects/m3l.html'>M3L</a> framework.</center>
"""
)
with gr.Row():
rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
with gr.Row():
modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
with gr.Row():
submit_btn = gr.Button("Submit")
with gr.Row():
mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
        classnameoutput = gr.Image(label="Classes").style(height=384, width=384)
with gr.Row():
examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
with gr.Row():
gr.Markdown(
"""
Read more about [M3L](https://harshm121.github.io/projects/m3l.html)!
"""
)
    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameoutput])
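# Queue requests so that at most three predictions run concurrently.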
demo.queue(concurrency_count=3)
demo.launch()