# biomap/train.py
from utils import *    # expected to provide prep_args, get_transform and common aliases (np, os, join, T)
from modules import *  # expected to provide LitUnsupervisedSegmenter
from data import *     # expected to provide ContrastiveSegDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datetime import datetime
import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
import torch.multiprocessing
import seaborn as sns
from pytorch_lightning.callbacks import ModelCheckpoint
import sys
import pdb
import matplotlib as mpl
from skimage import measure
from scipy.stats import mode as statsmode
from collections import OrderedDict
import unet
torch.multiprocessing.set_sharing_strategy("file_system")  # avoid "too many open files" errors with many DataLoader workers
# 7-class land-cover palette and matching discrete colormap for label maps
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
"Buildings",
"Cultivation",
"Natural green",
"Wetland",
"Water",
"Infrastructure",
"Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)  # class boundaries [1, 2, ..., 8]
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)  # values 1..7 map one-to-one onto the 7 colors
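# Hedged usage sketch (not in the original script; `label_map` is a
# hypothetical (H, W) array of 0-indexed class ids, and the
# matplotlib.pyplot import is an assumption):
#
#   import matplotlib.pyplot as plt
#   plt.imshow(label_map + 1, cmap=cmap, norm=norm)  # +1 because bounds start at 1
#   plt.colorbar()
#   plt.show()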
def retouch_label(pred_label, true_label):
    """Replace every connected blob of the predicted label map with the
    most frequent ground-truth class inside that blob."""
    retouched_label = pred_label + 0  # cheap copy of the array
    blobs = measure.label(retouched_label)  # connected components of equal-valued pixels
    for idx in np.unique(blobs):
        # most frequent ground-truth class in this blob
        retouched_label[blobs == idx] = statsmode(true_label[blobs == idx])[0][0]
    return retouched_label
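# Hedged toy example (not from the original repo):
#
#   pred = np.array([[1, 1], [2, 2]])
#   true = np.array([[1, 3], [2, 2]])
#   retouch_label(pred, true)  # each predicted blob takes the majority true class
#
# Note that the [0][0] indexing on statsmode assumes the older SciPy
# return convention; newer SciPy versions changed mode()'s result shape.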
def get_class_labels(dataset_name):
if dataset_name.startswith("cityscapes"):
return [
"road",
"sidewalk",
"parking",
"rail track",
"building",
"wall",
"fence",
"guard rail",
"bridge",
"tunnel",
"pole",
"polegroup",
"traffic light",
"traffic sign",
"vegetation",
"terrain",
"sky",
"person",
"rider",
"car",
"truck",
"bus",
"caravan",
"trailer",
"train",
"motorcycle",
"bicycle",
]
elif dataset_name == "cocostuff27":
return [
"electronic",
"appliance",
"food",
"furniture",
"indoor",
"kitchen",
"accessory",
"animal",
"outdoor",
"person",
"sports",
"vehicle",
"ceiling",
"floor",
"food",
"furniture",
"rawmaterial",
"textile",
"wall",
"window",
"building",
"ground",
"plant",
"sky",
"solid",
"structural",
"water",
]
elif dataset_name == "voc":
return [
"background",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
elif dataset_name == "potsdam":
return ["roads and cars", "buildings and clutter", "trees and vegetation"]
else:
raise ValueError("Unknown Dataset {}".format(dataset_name))
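# Hedged usage sketch:
#
#   labels = get_class_labels("potsdam")
#   assert len(labels) == 3  # any name not listed above raises ValueError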
@hydra.main(config_path="configs", config_name="train_config.yml")
def my_app(cfg: DictConfig) -> None:
OmegaConf.set_struct(cfg, False)
print(OmegaConf.to_yaml(cfg))
pytorch_data_dir = cfg.pytorch_data_dir
data_dir = join(cfg.output_root, "data")
log_dir = join(cfg.output_root, "logs")
checkpoint_dir = join(cfg.output_root, "checkpoints")
prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
name = "{}_date_{}".format(prefix, datetime.now().strftime("%b%d_%H-%M-%S"))
cfg.full_name = prefix
os.makedirs(data_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
seed_everything(seed=0)
print(data_dir)
print(cfg.output_root)
geometric_transforms = T.Compose(
[T.RandomHorizontalFlip(), T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))]
)
photometric_transforms = T.Compose(
[
T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
T.RandomGrayscale(0.2),
T.RandomApply([T.GaussianBlur((5, 5))]),
]
)
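    # Hedged sanity-check sketch (the PIL import is an assumption; the script
    # itself only passes these transforms into ContrastiveSegDataset below):
    #
    #   from PIL import Image
    #   _ = photometric_transforms(Image.new("RGB", (cfg.res, cfg.res)))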
sys.stdout.flush()
train_dataset = ContrastiveSegDataset(
pytorch_data_dir=pytorch_data_dir,
dataset_name=cfg.dataset_name,
crop_type=cfg.crop_type,
image_set="train",
transform=get_transform(cfg.res, False, cfg.loader_crop_type),
target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
cfg=cfg,
aug_geometric_transform=geometric_transforms,
aug_photometric_transform=photometric_transforms,
num_neighbors=cfg.num_neighbors,
mask=True,
pos_images=True,
pos_labels=True,
)
if cfg.dataset_name == "voc":
val_loader_crop = None
else:
val_loader_crop = "center"
val_dataset = ContrastiveSegDataset(
pytorch_data_dir=pytorch_data_dir,
dataset_name=cfg.dataset_name,
crop_type=None,
image_set="val",
transform=get_transform(320, False, val_loader_crop),
target_transform=get_transform(320, True, val_loader_crop),
mask=True,
cfg=cfg,
)
# val_dataset = MaterializedDataset(val_dataset)
train_loader = DataLoader(
train_dataset,
cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
pin_memory=True,
)
if cfg.submitting_to_aml:
val_batch_size = 16
else:
val_batch_size = cfg.batch_size
val_loader = DataLoader(
val_dataset,
val_batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=True,
)
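    # Hedged inspection sketch: the dataset yields dicts; the key names below
    # ("img", "label") are assumptions, not verified against ContrastiveSegDataset:
    #
    #   batch = next(iter(train_loader))
    #   print(batch["img"].shape, batch["label"].shape)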
model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)
tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)
    if cfg.submitting_to_aml:
        gpu_args = dict(gpus=1, val_check_interval=250)
        # Lightning requires val_check_interval <= the number of training
        # batches; dropping it falls back to validating once per epoch
        if gpu_args["val_check_interval"] > len(train_loader):
            gpu_args.pop("val_check_interval")
    else:
        gpu_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
        # gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)
        if gpu_args["val_check_interval"] > len(train_loader) // 4:
            gpu_args.pop("val_check_interval")
trainer = Trainer(
log_every_n_steps=cfg.scalar_log_freq,
logger=tb_logger,
max_steps=cfg.max_steps,
callbacks=[
ModelCheckpoint(
dirpath=join(checkpoint_dir, name),
every_n_train_steps=400,
save_top_k=2,
monitor="test/cluster/mIoU",
mode="max",
)
],
**gpu_args
)
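    # Hedged note: gpus/accelerator="ddp" is the PyTorch Lightning 1.x API.
    # Under that API, resuming from a saved checkpoint would be one extra
    # Trainer argument, e.g.:
    #
    #   trainer = Trainer(resume_from_checkpoint="<path to .ckpt>", **gpu_args)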
trainer.fit(model, train_loader, val_loader)
if __name__ == "__main__":
prep_args()
my_app()