'''
Author: Egrt
Date: 2022-03-19 10:25:50
LastEditors: Egrt
LastEditTime: 2022-03-20 14:58:13
FilePath: \Luuu\gis.py
'''
import os
import numpy as np
import skimage.io
import torch
from tqdm import tqdm
from frame_field_learning import data_transforms, save_utils
from frame_field_learning.model import FrameFieldModel
from frame_field_learning import inference
from frame_field_learning import local_utils
from backbone import get_backbone
from torch_lydorn import torchvision
import argparse
from lydorn_utils import print_utils
from lydorn_utils import run_utils
class GIS(object):
#-----------------------------------------#
#   Note: remember to set model_path
#-----------------------------------------#
_defaults = {
}
#---------------------------------------------------#
#   Initialize the GIS inference wrapper
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)
self.args = self.get_args()
self.config = self.launch_inference_from_filepath(self.args)
self.generate()
def get_args(self):
argparser = argparse.ArgumentParser(description=__doc__)
argparser.add_argument(
'--in_filepath',
type=str,
nargs='*',
default='images/ex1images',
help='For launching prediction on several images, use this argument to specify their paths. '
'If --out_dirpath is specified, prediction outputs will be saved there. '
'If --out_dirpath is not specified, predictions will be saved next to inputs. '
'Make sure to also specify the run_name of the model to use for prediction.')
argparser.add_argument(
'--out_dirpath',
type=str,
default='images',
help='Path to the output directory of prediction when using the --in_filepath option to launch prediction on several images.')
argparser.add_argument(
'-c', '--config',
type=str,
help='Name of the config file, excluding the .json file extension.')
argparser.add_argument(
'--dataset_params',
type=str,
help='Allows to overwrite the dataset_params in the config file. Accepts a path to a .json file.')
argparser.add_argument(
'-r', '--runs_dirpath',
default="runs",
type=str,
help='Directory where runs are recorded (model saves and logs).')
argparser.add_argument(
'--run_name',
type=str,
default='mapping_dataset.unet_resnet101_pretrained.train_val',
help='Name of the run to use. '
'That name does not include the timestamp of the folder name: <run_name> | <yyyy-mm-dd hh:mm:ss>.')
argparser.add_argument(
'--new_run',
action='store_true',
help="Train from scratch (when True) or train from the last checkpoint (when False)")
argparser.add_argument(
'--init_run_name',
type=str,
help="This is the run_name to initialize the weights from."
"If None, weights will be initialized randomly."
"This is a single word, without the timestamp.")
argparser.add_argument(
'--samples',
type=int,
help='Limits the number of samples to train (and validate and test) if set.')
argparser.add_argument(
'-b', '--batch_size',
type=int,
help='Batch size. Default value can be set in the config file. It is doubled when no back-propagation is done (i.e. in eval mode). If a specific effective batch size is desired, set the eval_batch_size argument.')
argparser.add_argument(
'--eval_batch_size',
type=int,
help='Batch size for evaluation. Overrides the effective batch size when evaluating.')
argparser.add_argument(
'-m', '--mode',
default="train",
type=str,
choices=['train', 'eval', 'eval_coco'],
help='Mode to launch the script in. '
'Train: train model on specified folds. '
'Eval: eval model on specified fold. '
'Eval_coco: measures COCO metrics of the specified fold.')
argparser.add_argument(
'--fold',
nargs='*',
type=str,
choices=['train', 'val', 'test'],
help='If training (mode=train): all folds entered here will be used for optimizing the network. '
'If the train fold is selected and not the val fold, the val fold will be used during training to validate at each epoch. '
'The most common scenario is to optimize on train and validate on val: select only train. '
'When optimizing the network for the last time before test, we would like to optimize it on train + val: in that case select both the train and val folds. '
'Then for evaluation (mode=eval), we might want to evaluate on the val fold for hyper-parameter selection. '
'And finally evaluate (mode=eval) on the test fold for the final predictions (and possibly metrics) for the paper/competition.')
argparser.add_argument(
'--max_epoch',
type=int,
help='Stop training when max_epoch is reached. If not set, value in config is used.')
argparser.add_argument(
'--eval_patch_size',
type=int,
help='When evaluating, the patch size the tile is split into.')
argparser.add_argument(
'--eval_patch_overlap',
type=int,
help='When evaluating, patch the tile with the specified overlap to reduce edge artifacts when reconstructing '
'the whole tile.')
argparser.add_argument('--master_addr', default="localhost", type=str, help="Address of master node")
argparser.add_argument('--master_port', default="6666", type=str, help="Port on master node")
argparser.add_argument('-n', '--nodes', default=1, type=int, metavar='N', help="Number of total nodes")
argparser.add_argument('-g', '--gpus', default=1, type=int, help='Number of gpus per node')
argparser.add_argument('-nr', '--nr', default=0, type=int, help='Ranking within the nodes')
args = argparser.parse_args()
return args
def launch_inference_from_filepath(self, args):
# --- First step: figure out what run (experiment) is to be evaluated
# Option 1: the run_name argument is given, in which case that's our run
run_name = None
config = None
if args.run_name is not None:
run_name = args.run_name
# Else option 2: Check if a config has been given to look for the run_name
if args.config is not None:
config = run_utils.load_config(args.config)
if config is not None and "run_name" in config and run_name is None:
run_name = config["run_name"]
# Else abort...
if run_name is None:
print_utils.print_error("ERROR: the run to evaluate could no be identified with the given arguments. "
"Please specify either the --run_name argument or the --config argument "
"linking to a config file that has a 'run_name' field filled with the name of "
"the run name to evaluate.")
# --- Second step: get path to the run and if --config was not specified, load the config from the run's folder
run_dirpath = local_utils.get_run_dirpath(args.runs_dirpath, run_name)
if config is None:
config = run_utils.load_config(config_dirpath=run_dirpath)
if config is None:
print_utils.print_error(f"ERROR: the default run's config file at {run_dirpath} could not be loaded. "
f"Exiting now...")
# --- Add command-line arguments
if args.batch_size is not None:
config["optim_params"]["batch_size"] = args.batch_size
if args.eval_batch_size is not None:
config["optim_params"]["eval_batch_size"] = args.eval_batch_size
else:
config["optim_params"]["eval_batch_size"] = 2*config["optim_params"]["batch_size"]
# --- Load params in config set as relative path to another JSON file
config = run_utils.load_defaults_in_config(config, filepath_key="defaults_filepath")
config["eval_params"]["run_dirpath"] = run_dirpath
if args.eval_patch_size is not None:
config["eval_params"]["patch_size"] = args.eval_patch_size
if args.eval_patch_overlap is not None:
config["eval_params"]["patch_overlap"] = args.eval_patch_overlap
self.backbone = get_backbone(config["backbone_params"])
return config
# Load the model
def generate(self):
# --- Online transform performed on the device (GPU):
eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(self.config)
print("Loading model...")
self.model = FrameFieldModel(self.config, backbone=self.backbone, eval_transform=eval_online_cuda_transform)
self.model.to(self.config["device"])
checkpoints_dirpath = run_utils.setup_run_subdir(self.config["eval_params"]["run_dirpath"], self.config["optim_params"]["checkpoints_dirname"])
self.model = inference.load_checkpoint(self.model, checkpoints_dirpath, self.config["device"])
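# Switch the network to evaluation mode (disables dropout and fixes batch-norm statistics).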
self.model.eval()
def get_save_filepath(self, base_filepath, name=None, ext=""):
if type(base_filepath) is tuple:
if name is not None:
save_filepath = os.path.join(base_filepath[0], name, base_filepath[1] + ext)
else:
save_filepath = os.path.join(base_filepath[0], base_filepath[1] + ext)
elif type(base_filepath) is str:
if name is not None:
save_filepath = base_filepath + "." + name + ext
else:
save_filepath = base_filepath + ext
return save_filepath
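# Illustrative examples of the naming scheme:
#   get_save_filepath(("images", "ex1"), "seg", ".png") -> "images/seg/ex1.png"
#   get_save_filepath("images/ex1", "seg", ".png")      -> "images/ex1.seg.png"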
# Detect a single image
def detect_image(self, in_filepath):
out_dirpath = self.args.out_dirpath
image = skimage.io.imread(in_filepath)
patch_size = self.config['eval_params']['patch_size']
# Disable patch-wise processing if the image is smaller than the expected patch size
if image.shape[0] < patch_size or image.shape[1] < patch_size:
self.config['eval_params']['patch_size'] = None
if 3 < image.shape[2]:
print_utils.print_info(f"Image {in_filepath} has more than 3 channels. Keeping the first 3 channels and discarding the rest...")
image = image[:, :, :3]
elif image.shape[2] < 3:
print_utils.print_error(f"Image {in_filepath} has only {image.shape[2]} channels but the network expects 3 channels.")
raise ValueError
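# Per-image normalization statistics: the sample below carries the channel-wise
# mean and std of the [0, 1]-scaled image alongside the image tensor.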
image_float = image / 255
mean = np.mean(image_float.reshape(-1, image_float.shape[-1]), axis=0)
std = np.std(image_float.reshape(-1, image_float.shape[-1]), axis=0)
sample = {
"image": torchvision.transforms.functional.to_tensor(image)[None, ...],
"image_mean": torch.from_numpy(mean)[None, ...],
"image_std": torch.from_numpy(std)[None, ...],
"image_filepath": [in_filepath],
}
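# Run inference on the single-image batch; compute_polygonization=True also
# produces polygons and their probabilities in addition to the segmentation.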
tile_data = inference.inference(self.config, self.model, sample, compute_polygonization=True)
tile_data = local_utils.batch_to_cpu(tile_data)
# Remove batch dim:
tile_data = local_utils.split_batch(tile_data)[0]
# Figure out out_base_filepath:
if out_dirpath is None:
out_dirpath = os.path.dirname(in_filepath)
base_filename = os.path.splitext(os.path.basename(in_filepath))[0]
out_base_filepath = (out_dirpath, base_filename)
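# out_base_filepath is an (output directory, base filename) tuple consumed by the save_utils helpers below.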
if self.config["compute_seg"]:
if self.config["eval_params"]["save_individual_outputs"]["seg_mask"]:
seg_mask = 0.5 < tile_data["seg"][0]
result_seg_mask_path = save_utils.save_seg_mask(seg_mask, out_base_filepath, "mask", tile_data["image_filepath"])
if self.config["eval_params"]["save_individual_outputs"]["seg"]:
result_seg_path = save_utils.save_seg(tile_data["seg"], out_base_filepath, "seg", tile_data["image_filepath"])
if "poly_viz" in self.config["eval_params"]["save_individual_outputs"] and \
self.config["eval_params"]["save_individual_outputs"]["poly_viz"]:
save_utils.save_poly_viz(tile_data["image"], tile_data["polygons"], tile_data["polygon_probs"], out_base_filepath, "poly_viz")
if self.config["eval_params"]["save_individual_outputs"]["poly_shapefile"]:
save_utils.save_shapefile(tile_data["polygons"], out_base_filepath, "poly_shapefile", tile_data["image_filepath"])
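# Expected output locations of the polygonization results (the 'acm' method with tolerance 0.125):
# the visualization PDF plus the shapefile and its standard sidecar files (.cpg, .dbf, .shx, .prj).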
pdf_filepath = os.path.join(out_dirpath, 'poly_viz.acm.tol_0.125', base_filename + ".pdf")
cpg_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".cpg")
dbf_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".dbf")
shx_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shx")
shp_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".shp")
prj_filepath = os.path.join(out_dirpath, 'poly_shapefile.acm.tol_0.125', base_filename + ".prj")
return base_filename, [result_seg_mask_path, result_seg_path, pdf_filepath, cpg_filepath, dbf_filepath, shx_filepath, shp_filepath, prj_filepath]
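if __name__ == "__main__":
    # Minimal usage sketch, assuming a trained run can be found under --runs_dirpath
    # and that "images/test.png" exists; both are illustrative placeholders.
    gis = GIS()
    base_filename, output_filepaths = gis.detect_image("images/test.png")
    print(base_filename)
    for output_filepath in output_filepaths:
        print(output_filepath)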