import os
import torch
from STTNet import STTNet
import torch.nn.functional as F
from Utils.Datasets import get_data_loader
from Utils.Utils import make_numpy_img, inv_normalize_img, encode_onehot_to_mask, get_metrics, Logger
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict
from torch.optim.lr_scheduler import MultiStepLR

if __name__ == '__main__':
    model_infos = {
        # model
        'backbone': 'resnet50',
        'pretrained': True,
        'out_keys': ['block4'],
        'in_channel': 3,
        'n_classes': 2,
        'top_k_s': 64,
        'top_k_c': 16,
        'encoder_pos': True,
        'decoder_pos': True,
        'model_pattern': ['X', 'A', 'S', 'C'],

        # training
        'BATCH_SIZE': 8,
        'IS_SHUFFLE': True,
        'NUM_WORKERS': 0,
        'DATASET': 'Tools/generate_dep_info/train_data.csv',
        'model_path': 'Checkpoints',
        'log_path': 'Results',

        # validation
        'IS_VAL': True,
        'VAL_BATCH_SIZE': 4,
        'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv',

        # test
        'IS_TEST': True,
        'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',
        'IMG_SIZE': [512, 512],
        'PHASE': 'seg',

        # per-channel mean/std of the training images; used for normalization
        # and for de-normalizing images when saving visualizations below
        'PRIOR_MEAN': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667],
        'PRIOR_STD': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095],

        # checkpoint whose weights are loaded before training (skipped if the path does not exist)
        'load_checkpoint_path': r'E:\BuildingExtractionDataset\INRIA_ckpt_latest.pt',
        # checkpoint to resume training from (also restores the epoch counter)
        'resume_checkpoint_path': '',
    }
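    # Note: 'top_k_s' / 'top_k_c' are forwarded to STTNet and, judging by their
    # names, presumably control how many sparse tokens are kept along the spatial
    # and channel dimensions; 'model_pattern' selects which branches of the model
    # are enabled. See STTNet.py for the authoritative meaning of these flags.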

    os.makedirs(model_infos['model_path'], exist_ok=True)
    if model_infos['IS_VAL']:
        os.makedirs(model_infos['log_path'] + '/val', exist_ok=True)
    if model_infos['IS_TEST']:
        os.makedirs(model_infos['log_path'] + '/test', exist_ok=True)
    logger = Logger(model_infos['log_path'] + '/log.log')

    # get_data_loader returns one loader per phase ('train', plus 'val'/'test'
    # when the corresponding IS_* flags are set)
    data_loaders = get_data_loader(model_infos)
    # weight of the auxiliary attention-branch loss relative to the master loss
    loss_weight = 0.1
    model = STTNet(**model_infos)

    epoch_start = 0
    if model_infos['load_checkpoint_path'] is not None and os.path.exists(model_infos['load_checkpoint_path']):
        logger.write(f'load checkpoint from {model_infos["load_checkpoint_path"]}\n')
        state_dict = torch.load(model_infos['load_checkpoint_path'], map_location='cpu')
        model_dict = state_dict['model_state_dict']
        # strip the 'module.' prefix that DataParallel adds to parameter names;
        # fall back to the raw state dict if the renamed one does not match
        try:
            model.load_state_dict(OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()}))
        except Exception:
            model.load_state_dict(model_dict)
    if model_infos['resume_checkpoint_path'] is not None and os.path.exists(model_infos['resume_checkpoint_path']):
        logger.write(f'resume checkpoint path from {model_infos["resume_checkpoint_path"]}\n')
        state_dict = torch.load(model_infos['resume_checkpoint_path'], map_location='cpu')
        epoch_start = state_dict['epoch_id']
        model_dict = state_dict['model_state_dict']
        logger.write(f'resume checkpoint from epoch {epoch_start}\n')
        try:
            model.load_state_dict(OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()}))
        except Exception:
            model.load_state_dict(model_dict)

    model = model.cuda()
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.write(f'Use GPUs: {device_ids}\n')
    else:
        logger.write('Use GPUs: 1\n')

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    max_epoch = 300
    # halve the learning rate at 2/3 and 5/6 of training
    scheduler = MultiStepLR(optimizer, [int(max_epoch * 2 / 3), int(max_epoch * 5 / 6)], 0.5)
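    # With max_epoch = 300 the milestones evaluate to epochs 200 and 250, so the
    # learning rate steps 1e-4 -> 5e-5 -> 2.5e-5 over the course of training.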

    for epoch_id in range(epoch_start, max_epoch):
        # training phase
        pattern = 'train'
        model.train()
        for batch_id, batch in enumerate(data_loaders[pattern]):
            img_batch = batch['img'].cuda()
            label_batch = batch['label'].cuda()

            optimizer.zero_grad()
            logits, att_branch_output = model(img_batch)

            # deep supervision: the auxiliary attention branch is trained against
            # labels downsampled to its own output resolution
            label_downs = F.interpolate(label_batch, att_branch_output.size()[2:], mode='nearest')
            loss_branch = F.binary_cross_entropy_with_logits(att_branch_output, label_downs)
            loss_master = F.binary_cross_entropy_with_logits(logits, label_batch)
            loss = loss_master + loss_weight * loss_branch

            loss.backward()
            optimizer.step()

            # log every 20 batches
            if batch_id % 20 == 1:
                logger.write(
                    f'{pattern}: {epoch_id}/{max_epoch} {batch_id}/{len(data_loaders[pattern])} loss: {loss.item():.4f}\n')

        scheduler.step()
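        # After each training epoch, run the enabled evaluation phases below:
        # collect predictions, dump qualitative results to <log_path>/<pattern>/,
        # and log IoU/OA/F1.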
        # validation / test phase
        patterns = ['val', 'test']
        for pattern_id, is_pattern in enumerate([model_infos['IS_VAL'], model_infos['IS_TEST']]):
            if not is_pattern:
                continue
            collect_result = {'pred': [], 'target': []}
            pattern = patterns[pattern_id]
            model.eval()
            for batch_id, batch in enumerate(data_loaders[pattern]):
                img_batch = batch['img'].cuda()
                label_batch = batch['label'].cuda()
                img_names = batch['img_name']
                collect_result['target'].append(label_batch.data.cpu())

                with torch.no_grad():
                    logits, att_branch_output = model(img_batch)

                collect_result['pred'].append(logits.data.cpu())

                # class-index map scaled to 0/255 for visualization
                pred_label = torch.argmax(logits, 1)
                pred_label *= 255

                # save qualitative results: every test batch, every 5th val batch
                if pattern == 'test' or batch_id % 5 == 1:
                    batch_size = pred_label.size(0)
                    for img_id in range(batch_size):
                        img = img_batch[img_id].detach().cpu()
                        target = label_batch[img_id].detach().cpu()
                        pred = pred_label[img_id].detach().cpu()
                        img_name = img_names[img_id]

                        # de-normalize the input and convert everything to numpy images
                        img = make_numpy_img(
                            inv_normalize_img(img, model_infos['PRIOR_MEAN'], model_infos['PRIOR_STD']))
                        target = make_numpy_img(encode_onehot_to_mask(target)) * 255
                        pred = make_numpy_img(pred)

                        # stack input / ground truth / prediction vertically
                        vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0)
                        vis = np.clip(vis, a_min=0, a_max=1)
                        file_name = os.path.join(model_infos['log_path'], pattern,
                                                 f'Epoch_{epoch_id}_{img_name.split(".")[0]}.png')
                        plt.imsave(file_name, vis)

            collect_result['pred'] = torch.cat(collect_result['pred'], dim=0)
            collect_result['target'] = torch.cat(collect_result['target'], dim=0)
            IoU, OA, F1_score = get_metrics('seg', **collect_result)
            logger.write(f'{pattern}: {epoch_id}/{max_epoch} Iou:{IoU[-1]:.4f} OA:{OA[-1]:.4f} F1:{F1_score[-1]:.4f}\n')

        # periodic checkpoint every 20 epochs, plus a rolling latest checkpoint
        if epoch_id % 20 == 1:
            torch.save({
                'epoch_id': epoch_id,
                'model_state_dict': model.state_dict()
            }, os.path.join(model_infos['model_path'], f'ckpt_{epoch_id}.pt'))
        torch.save({
            'epoch_id': epoch_id,
            'model_state_dict': model.state_dict()
        }, os.path.join(model_infos['model_path'], 'ckpt_latest.pt'))
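    # A minimal sketch of how a saved checkpoint could later be reloaded for
    # inference (assumes the same model_infos as above; left commented out since
    # this script only trains):
    #   ckpt = torch.load(os.path.join(model_infos['model_path'], 'ckpt_latest.pt'), map_location='cpu')
    #   net = STTNet(**model_infos)
    #   net.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['model_state_dict'].items()})
    #   net.eval()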