|
import re |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from argparse import ArgumentParser |
|
import pytorch_lightning as pl |
|
from .lsegmentation_module import LSegmentationModule |
|
from .models.lseg_net import LSegNet |
|
from encoding.models.sseg.base import up_kwargs |
|
|
|
import os |
|
import clip |
|
import numpy as np |
|
|
|
from scipy import signal |
|
import glob |
|
|
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
|
|
|
|
class LSegModule(LSegmentationModule):
    """Segmentation module that wraps an ``LSegNet`` model.

    On construction it picks the base/crop sizes for the dataset, builds the
    train/val normalization transforms, reads the label names from the ADE20K
    label file, instantiates ``LSegNet`` and stores the training criterion.
    """

    def __init__(self, data_path, dataset, batch_size, base_lr, max_epochs, **kwargs):
        """Set up transforms, the network and the criterion.

        Args:
            data_path, dataset, batch_size, base_lr, max_epochs: forwarded
                unchanged to ``LSegmentationModule.__init__``. ``dataset`` is
                additionally used here to select the base/crop sizes.
            **kwargs: forwarded to the base class and to ``get_criterion``.
                Must contain ``backbone``, ``num_features``, ``arch_option``,
                ``block_depth`` and ``activation``.

        Raises:
            KeyError: if one of the required ``kwargs`` entries is missing.
            AssertionError: if the ADE20K label file cannot be found
                (see ``get_labels``).
        """
        super().__init__(
            data_path, dataset, batch_size, base_lr, max_epochs, **kwargs
        )

        # "citys" uses larger images than every other dataset.
        if dataset == "citys":
            self.base_size = 2048
            self.crop_size = 768
        else:
            self.base_size = 520
            self.crop_size = 480

        norm_mean = [0.5, 0.5, 0.5]
        norm_std = [0.5, 0.5, 0.5]

        print('** Use norm {}, {} as the mean and std **'.format(norm_mean, norm_std))

        # Train and val currently use the same pipeline, but each gets its own
        # transform instances so they can diverge independently later.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std),
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std),
            ]
        )

        # NOTE(review): 255 does not match the label list loaded below (the
        # file name suggests 150 classes) -- confirm how num_classes is used.
        self.num_classes = 255
        # NOTE(review): ``pl.metrics`` was deprecated in PyTorch Lightning 1.3
        # and removed in 1.5 (moved to ``torchmetrics``); this line requires an
        # older Lightning release.
        self.train_accuracy = pl.metrics.Accuracy()

        # The label set is always the ADE20K one, independent of ``dataset``.
        labels = self.get_labels('ade20k')

        self.net = LSegNet(
            labels=labels,
            backbone=kwargs["backbone"],
            features=kwargs["num_features"],
            crop_size=self.crop_size,
            arch_option=kwargs["arch_option"],
            block_depth=kwargs["block_depth"],
            activation=kwargs["activation"],
        )

        # Overwrite the backbone patch-embedding image size with the crop size
        # (presumably so the backbone accepts crop-sized inputs -- confirm).
        self.net.pretrained.model.patch_embed.img_size = (
            self.crop_size,
            self.crop_size,
        )

        self._up_kwargs = up_kwargs
        self.mean = norm_mean
        self.std = norm_std

        # ``get_criterion`` is not defined in this class; presumably inherited
        # from ``LSegmentationModule``.
        self.criterion = self.get_criterion(**kwargs)

    def get_labels(self, dataset):
        """Read the label names for ``dataset`` from its label file.

        The file ``label_files/<dataset>_objectInfo150.txt`` (relative to the
        current working directory) is read line by line. For each line the
        last comma-separated field is taken, and of that field only the part
        before the first ``;`` is kept.

        Args:
            dataset: dataset name used to build the label file path.

        Returns:
            list[str]: one label per line of the file. For ``'ade20k'`` the
            first entry is dropped (presumably the header row -- confirm
            against the label file).

        Raises:
            AssertionError: if the label file does not exist.
        """
        path = 'label_files/{}_objectInfo150.txt'.format(dataset)
        if not os.path.exists(path):
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept as
            # AssertionError for backward compatibility.
            raise AssertionError('*** Error : {} not exist !!!'.format(path))

        # The context manager closes the file even if parsing raises.
        with open(path, 'r') as label_file:
            labels = [
                line.strip().split(',')[-1].split(';')[0] for line in label_file
            ]

        if dataset == 'ade20k':
            labels = labels[1:]
        return labels

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend the base module's argument parser with LSeg options.

        Args:
            parent_parser: parser passed through to
                ``LSegmentationModule.add_model_specific_args``.

        Returns:
            ArgumentParser: a new parser that inherits the base module's
            arguments and adds the model options defined below.
        """
        parser = LSegmentationModule.add_model_specific_args(parent_parser)
        # NOTE(review): argparse raises on a conflicting ``-h`` option if the
        # parent was built with add_help=True; this presumably relies on the
        # base class returning a parser with add_help=False -- confirm.
        parser = ArgumentParser(parents=[parser])

        parser.add_argument(
            "--backbone",
            type=str,
            default="clip_vitl16_384",
            help="backbone network",
        )

        parser.add_argument(
            "--num_features",
            type=int,
            default=256,
            help="number of featurs that go from encoder to decoder",
        )

        parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")

        parser.add_argument(
            "--finetune_weights", type=str, help="load weights to finetune from"
        )

        # NOTE(review): with no explicit ``dest`` this is stored as
        # ``no_scaleinv``, which defaults to True and becomes False when the
        # flag is passed -- verify that consumers expect this polarity.
        parser.add_argument(
            "--no-scaleinv",
            default=True,
            action="store_false",
            help="turn off scaleinv layers",
        )

        parser.add_argument(
            "--no-batchnorm",
            default=False,
            action="store_true",
            help="turn off batchnorm",
        )

        parser.add_argument(
            "--widehead", default=False, action="store_true", help="wider output head"
        )

        parser.add_argument(
            "--widehead_hr",
            default=False,
            action="store_true",
            help="wider output head",
        )

        parser.add_argument(
            "--arch_option",
            type=int,
            default=0,
            help="which kind of architecture to be used",
        )

        parser.add_argument(
            "--block_depth",
            type=int,
            default=0,
            help="how many blocks should be used",
        )

        parser.add_argument(
            "--activation",
            choices=['lrelu', 'tanh'],
            default="lrelu",
            help="use which activation to activate the block",
        )

        return parser
|
|