|
''' |
|
Author: [egrt] |
|
Date: 2022-08-23 11:44:15 |
|
LastEditors: Egrt |
|
LastEditTime: 2022-11-23 15:25:35 |
|
Description: HEAT的模型加载与预测 |
|
''' |
|
# Standard library
import os
from turtle import pos  # NOTE(review): unused and pulls in tkinter; candidate for removal

# Third-party
import cv2
import numpy as np
import scipy
import scipy.ndimage
import skimage
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from osgeo import gdal, ogr, osr
from PIL import Image
from tqdm import tqdm

# Project-local
from datasets.data_utils import get_pixel_features
from models.corner_models import HeatCorner
from models.corner_to_edge import get_infer_edge_pairs
from models.edge_models import HeatEdge
from models.resnet import ResNetBackbone
from utils import image_utils
|
|
|
class HEAT(object):
    """End-to-end inference wrapper for the HEAT corner/edge models.

    Bundles the ResNet backbone, the corner detection head and the edge
    classification head, downloads their shared checkpoint from the
    HuggingFace Hub, and exposes ``detect_one_image`` which turns an input
    image into an annotated visualization (optionally writing the detected
    edges to an ESRI shapefile via ``writeShp``).
    """

    # Default inference configuration; any key can be overridden through
    # keyword arguments to __init__.
    _defaults = {
        # Local checkpoint path. NOTE(review): informational only --
        # generate() downloads the checkpoint from the Hub instead.
        "model_data"               : 'model_data/heat_checkpoints/checkpoints/ckpts_heat_outdoor_256/checkpoint.pth',
        # Resolution (w, h) the networks consume; inputs are resized to this.
        "image_size"               : [256, 256],
        # Sliding-window patch side length used for large inputs.
        "patch_size"               : 512,
        # Pixel overlap between adjacent patches (0 = no overlap).
        "patch_overlap"            : 0,
        # Pixels with corner confidence >= this become corner candidates.
        "corner_thresh"            : 0.01,
        # Edge-candidate budget per image = max corner count * multiplier.
        "corner_to_edge_multiplier": 3,
        # Number of iterative passes of the edge model; later passes reuse
        # confident decisions from earlier ones as hints.
        "infer_times"              : 3,
        # NOTE(review): this flag is never consulted -- generate() and
        # process_image() call .cuda() unconditionally, so CPU-only hosts
        # fail despite the False default. Confirm intended behavior.
        "cuda"                     : False,
    }

    def __init__(self, **kwargs):
        """Apply defaults, override them with kwargs, then build the models."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        """Download the checkpoint and instantiate the three sub-models."""
        # Fetch (and cache) the published checkpoint from the HuggingFace Hub.
        filepath = hf_hub_download(repo_id="Egrt/HEAT", filename="checkpoint.pth")
        # NOTE(review): no map_location -- loading requires the device the
        # checkpoint was saved from (typically CUDA).
        self.model = torch.load(filepath)

        # Shared convolutional feature extractor.
        self.backbone = ResNetBackbone()
        strides = self.backbone.strides
        num_channels = self.backbone.num_channels
        self.backbone = nn.DataParallel(self.backbone)
        self.backbone = self.backbone.cuda()
        self.backbone.eval()

        # Corner detection head.
        self.corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                       backbone_num_channels=num_channels)
        self.corner_model = nn.DataParallel(self.corner_model)
        self.corner_model = self.corner_model.cuda()
        self.corner_model.eval()

        # Edge classification head.
        self.edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                   backbone_num_channels=num_channels)
        self.edge_model = nn.DataParallel(self.edge_model)
        self.edge_model = self.edge_model.cuda()
        self.edge_model.eval()

        # Restore the trained weights of each sub-model from the checkpoint.
        self.backbone.load_state_dict(self.model['backbone'])
        self.corner_model.load_state_dict(self.model['corner_model'])
        self.edge_model.load_state_dict(self.model['edge_model'])

    def detect_one_image(self, image):
        """Run corner/edge inference on one PIL image.

        Small images are predicted in a single shot; images at least
        ``patch_size`` on both sides are processed patch-by-patch with a
        sliding window. Returns the annotated image as a PIL Image.
        """
        # Guarantee a 3-channel RGB input.
        image = cvtColor(image)

        # Decide between single-shot and sliding-window inference.
        if image.size[0] < self.patch_size or image.size[1] < self.patch_size:
            is_slice = False
        else:
            is_slice = True
        if is_slice:
            image = np.array(image, dtype=np.uint8)
            viz_image = image.copy()
            height, width = image.shape[0], image.shape[1]
            # Network works at image_size; scale maps its coordinates back
            # to patch pixels before the bbox offset is added.
            scale = self.patch_size / self.image_size[0]
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = [], [], [], [], []
            stride = self.patch_size - self.patch_overlap
            patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
                                                                          stride=stride,
                                                                          patch_res=self.patch_size)
            edge_len = 0  # NOTE(review): never used afterwards
            # NOTE(review): predictions are overwritten, not accumulated,
            # on every patch; only viz_image accumulates across patches.
            for bbox in tqdm(patch_boundingboxes, desc="使用切分进行预测", leave=False):
                crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
                crop_image = Image.fromarray(crop_image)
                try:
                    pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, _ = self.predict_no_patching(crop_image)
                except RuntimeError as e:
                    print("ERROR: " + str(e))
                    print("INFO: 减小patch_size 直到适合内存")
                    raise e
                # Map patch-local corner coordinates back into the full image.
                pred_corners[:, 0] = pred_corners[:, 0] * scale + bbox[0]
                pred_corners[:, 1] = pred_corners[:, 1] * scale + bbox[1]
                pred_corners_viz = pred_corners
                viz_image = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                      edge_confs=edge_confs, shpfile=False)
            hr_image = Image.fromarray(np.uint8(viz_image))
        else:
            # Single-shot path; also exports the shapefile (shpfile=True).
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image = self.predict_no_patching(image)
            pred_corners_viz = pred_corners
            image_result = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                     edge_confs=edge_confs, shpfile=True)
            hr_image = Image.fromarray(np.uint8(image_result))
        return hr_image

    def predict_no_patching(self, image):
        """Predict corners and edges for one image without tiling.

        Returns (pred_corners, pred_confs, pos_edges, edge_confs,
        c_outputs_np, viz_image) where pos_edges are corner-index pairs and
        viz_image is the resized uint8 image used for drawing.
        """
        # Resize to the network resolution and keep a copy for visualization.
        image = image.resize(tuple(self.image_size), Image.BICUBIC)
        image = np.array(image, dtype=np.uint8)
        viz_image = image.copy()
        # Normalize and convert to an NCHW tensor.
        image = process_image(image)
        pixels, pixel_features = get_pixel_features(image_size=self.image_size[0])
        with torch.no_grad():
            # Backbone features shared by the corner and edge heads.
            image_feats, feat_mask, all_image_feats = self.backbone(image)
            pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
            preds_s1 = self.corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)

            c_outputs = preds_s1
            c_outputs_np = c_outputs[0].detach().cpu().numpy()
            # Threshold the dense corner map, then suppress non-maxima.
            pos_indices = np.where(c_outputs_np >= self.corner_thresh)
            pred_corners = pixels[pos_indices]
            pred_confs = c_outputs_np[pos_indices]
            pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])

            # Build all candidate corner pairs for edge classification.
            pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
            corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
            max_candidates = torch.stack([corner_nums.max() * self.corner_to_edge_multiplier] * len(corner_nums), dim=0)

            all_pos_ids = set()        # edge ids accepted across passes
            all_edge_confs = dict()    # edge id -> accepted confidence

            for tt in range(self.infer_times):
                if tt == 0:
                    # First pass: every candidate starts as "unknown" (2);
                    # later passes keep 0/1 decisions made so far.
                    gt_values = torch.zeros_like(edge_mask).long()
                    gt_values[:, :] = 2

                # Run the edge model; selected_ids is the subset of
                # candidates it actually scored this pass.
                s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = self.edge_model(image_feats,
                    feat_mask,pixel_features,edge_coords, edge_mask,gt_values, corner_nums,max_candidates,True)
                num_total = s1_logits.shape[2]
                num_selected = selected_ids.shape[1]
                num_filtered = num_total - num_selected

                # Softmax over the 2-class (no-edge / edge) dimension and
                # take the positive-class probability.
                s1_preds = s1_logits.squeeze().softmax(0)
                s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
                s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
                s1_preds_np = s1_preds[1, :].detach().cpu().numpy()              # NOTE(review): unused below
                s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()      # NOTE(review): unused below
                s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()

                selected_ids = selected_ids.squeeze().detach().cpu().numpy()

                if tt != self.infer_times - 1:
                    # Intermediate pass: commit only very confident decisions
                    # and feed them back through gt_values as hints.
                    s2_preds_np = s2_preds_hb_np

                    pos_edge_ids = np.where(s2_preds_np >= 0.9)
                    neg_edge_ids = np.where(s2_preds_np <= 0.01)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        if gt_values[0, actual_id] != 2:
                            continue  # already decided in an earlier pass
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]
                        gt_values[0, actual_id] = 1
                    for neg_id in neg_edge_ids[0]:
                        actual_id = selected_ids[neg_id]
                        if gt_values[0, actual_id] != 2:
                            continue
                        gt_values[0, actual_id] = 0
                    num_to_pred = (gt_values == 2).sum()
                    # Stop early once the remaining unknowns fit in one pass.
                    if num_to_pred <= num_filtered:
                        break
                else:
                    # Final pass: accept every remaining candidate above 0.5.
                    s2_preds_np = s2_preds_hb_np

                    pos_edge_ids = np.where(s2_preds_np >= 0.5)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        # NOTE(review): `s2_mask[0][pos_id] is True` compares a
                        # tensor element to the bool singleton with `is`, which
                        # is always False; a truth-value test was likely meant.
                        if s2_mask[0][pos_id] is True or gt_values[0, actual_id] != 2:
                            continue
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]

            # Gather the final edge set and matching confidences.
            pos_edge_ids = list(all_pos_ids)
            edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
            pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
            edge_confs = np.array(edge_confs)

            # Corners are predicted at image_size resolution; rescale to the
            # canonical 256 grid when a different input size is configured.
            if self.image_size[0] != 256:
                pred_corners = pred_corners / (self.image_size[0] / 256)

        return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image
|
|
|
|
|
|
|
|
|
def cvtColor(image):
    """Ensure the input image has three channels.

    Images that are already HxWx3 are returned untouched; anything else
    (grayscale, RGBA, palette) is converted to RGB via PIL.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
|
|
|
|
|
|
def corner_nms(preds, confs, image_size):
    """Non-maximum suppression of corner predictions on a dense grid.

    Args:
        preds: (N, 2) integer array of corner (x, y) coordinates.
        confs: (N,) array of corner confidences, one per prediction.
        image_size: side length of the square grid the coordinates live on.

    Returns:
        (filtered_preds, new_confs): the (x, y) coordinates that are local
        maxima of the confidence map within a 5x5 neighborhood, and their
        confidences.
    """
    data = np.zeros([image_size, image_size])
    neighborhood_size = 5
    threshold = 0

    # Scatter each corner's confidence onto the grid (row = y, col = x).
    for i in range(len(preds)):
        data[preds[i, 1], preds[i, 0]] = confs[i]

    # Fix: the `scipy.ndimage.filters` namespace is deprecated/removed in
    # modern SciPy; the filters live directly under `scipy.ndimage`.
    data_max = scipy.ndimage.maximum_filter(data, neighborhood_size)
    maxima = (data == data_max)
    data_min = scipy.ndimage.minimum_filter(data, neighborhood_size)
    # Keep a maximum only where the local dynamic range is non-trivial;
    # this discards the flat all-zero background (max == min == 0).
    diff = ((data_max - data_min) > threshold)
    maxima[diff == 0] = 0

    # Convert surviving peaks back to (x, y) coordinate pairs.
    results = np.where(maxima > 0)
    filtered_preds = np.stack([results[1], results[0]], axis=-1)

    new_confs = list()
    for i, pred in enumerate(filtered_preds):
        new_confs.append(data[pred[1], pred[0]])
    new_confs = np.array(new_confs)

    return filtered_preds, new_confs
|
|
|
def process_image(img):
    """Convert an HxWx3 image array to a normalized NCHW float tensor.

    Applies ImageNet mean/std normalization and adds a batch dimension.
    The tensor is moved to the GPU only when one is available.

    Args:
        img: HxWx3 array; callers in this file always pass np.uint8
            (see detect_one_image / predict_no_patching).

    Returns:
        A (1, 3, H, W) torch.Tensor, on CUDA when available.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img = np.asarray(img)
    # skimage.img_as_float maps uint8 to [0, 1] by dividing by 255; do that
    # directly for the (always-used) uint8 path so it does not depend on
    # skimage, and keep skimage as the fallback for other dtypes.
    if img.dtype == np.uint8:
        img = img.astype(np.float64) / 255.0
    else:
        img = skimage.img_as_float(img)
    img = img.transpose((2, 0, 1))  # HWC -> CHW
    img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
    img = torch.Tensor(img)
    # Fix: the original called .cuda() unconditionally, which crashes on
    # CPU-only hosts even though the HEAT defaults set "cuda": False.
    if torch.cuda.is_available():
        img = img.cuda()
    img = img.unsqueeze(0)
    return img
|
|
|
def postprocess_preds(corners, confs, edges):
    """Drop corners that touch no edge and re-index the edge list.

    Args:
        corners: (N, 2) array of corner coordinates.
        confs: (N,) array of corner confidences.
        edges: (E, 2) array of corner-index pairs.

    Returns:
        (corners, confs, edges) unchanged when every corner has at least one
        incident edge; otherwise the filtered corners/confs and an edge
        array rewritten against the new corner indices.
    """
    # Count incident edges per corner index.
    degree = dict()
    for a, b in edges:
        degree[a] = degree.get(a, 0) + 1
        degree[b] = degree.get(b, 0) + 1

    kept = [i for i in range(len(corners)) if i in degree]
    if len(kept) == len(corners):
        # Every corner is used by some edge; nothing to filter.
        return corners, confs, edges

    # Map surviving old indices onto a compact 0..K-1 range.
    remap = {old: new for new, old in enumerate(kept)}
    reindexed = np.array([(remap[a], remap[b]) for a, b in edges])
    return corners[kept], confs[kept], reindexed
|
|
|
|
|
|
|
|
|
|
|
def visualize_cond_generation(positive_pixels, confs, image, gt_corners=None, prec=None, recall=None,
                              image_masks=None, edges=None, edge_confs=None, shpfile=False):
    """Render corner/edge predictions onto a copy of `image`.

    Args:
        positive_pixels: (N, 2) array of predicted corner (x, y) coordinates.
        confs: per-corner confidences, or None to draw dots at full intensity.
        image: HxWx3 uint8 array to draw on (a copy is modified, not the input).
        gt_corners: optional ground-truth corners, drawn in green.
        prec, recall: optional metrics printed in the top-left corner;
            a tuple means paired (edge, region) metrics on two lines.
        image_masks: optional flat mask over a 64x64 grid; 1-cells are
            filled as 4x4-pixel rectangles (stride hard-coded to 4 px).
        edges: optional (E, 2) array of corner-index pairs drawn as lines.
        edge_confs: per-edge confidences controlling line brightness.
        shpfile: when True, also export the drawn edges via writeShp().
            NOTE(review): this branch iterates `edges` and will fail when
            shpfile=True is passed together with edges=None.

    Returns:
        The annotated image array.
    """
    image = image.copy()
    if confs is not None:
        viz_confs = confs

    # Draw edges first so corner dots end up on top of line endpoints.
    if edges is not None:
        preds = positive_pixels.astype(int)
        c_degrees = dict()  # corner index -> number of incident edges
        for edge_i, edge_pair in enumerate(edges):
            # Rescale confidence (2c - 1) before using it as color intensity.
            conf = (edge_confs[edge_i] * 2) - 1
            cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
            c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
            c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1

    # Draw corner dots; when edges are given, skip corners with no edge.
    for idx, c in enumerate(positive_pixels):
        if edges is not None and idx not in c_degrees:
            continue
        if confs is None:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
        else:
            # Dot intensity scales with the corner's confidence.
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)

    # Ground-truth corners (if provided) in green for visual comparison.
    if gt_corners is not None:
        for c in gt_corners:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)

    # Fill masked cells of the 64x64 grid; each cell covers 4x4 pixels.
    if image_masks is not None:
        mask_ids = np.where(image_masks == 1)[0]
        for mask_id in mask_ids:
            y_idx = mask_id // 64
            x_idx = (mask_id - y_idx * 64)
            x_coord = x_idx * 4
            y_coord = y_idx * 4
            cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)

    # Optional precision/recall caption(s) in the top-left corner.
    if prec is not None:
        if isinstance(prec, tuple):
            cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
            cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
        else:
            cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)

    # Export each edge as a 2-point polyline segment in an ESRI shapefile.
    if shpfile:
        preds = positive_pixels.astype(int)

        Polyline = []
        for edge_i, edge_pair in enumerate(edges):
            Polyline.append([preds[edge_pair[0]], preds[edge_pair[1]]])
        Polyline = np.array(Polyline, dtype=np.int32)

        writeShp(save_file_dir="shpfile", Polyline=Polyline)

    return image
|
|
|
def writeShp(save_file_dir="shpfile", Polyline=None):
    """Write detected edge segments to an ESRI shapefile as a MULTILINESTRING.

    Args:
        save_file_dir: output directory; created when missing. The file is
            always written as "result.shp" inside it.
        Polyline: (E, 2, 2) array holding both endpoints of each edge in
            pixel coordinates, or None to create an empty data source only.

    Returns:
        A status string (in Chinese, matching the original messages)
        describing success or the failure cause.
    """
    # Fix: replaced `os.path.exists(...) is False` check + makedirs with the
    # race-free exist_ok form.
    os.makedirs(save_file_dir, exist_ok=True)

    # Allow non-ASCII paths and write attributes as UTF-8.
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
    ogr.RegisterAll()

    strDriverName = "ESRI Shapefile"
    oDriver = ogr.GetDriverByName(strDriverName)
    # Fix: identity comparison (`is None`) instead of `== None`.
    if oDriver is None:
        return "驱动不可用:"+strDriverName

    file_path = os.path.join(save_file_dir, "result.shp")
    oDS = oDriver.CreateDataSource(file_path)
    if oDS is None:
        return "创建文件失败:result.shp"
    if Polyline is not None:
        papszLCO = []
        # WGS84 geographic reference for the layer.
        geosrs = osr.SpatialReference()
        geosrs.SetWellKnownGeogCS("WGS84")

        ogr_type = ogr.wkbMultiLineString

        oLayer = oDS.CreateLayer("Polyline", geosrs, ogr_type, papszLCO)
        if oLayer is None:
            return "图层创建失败!"

        # Attribute schema: integer id + string name.
        oId = ogr.FieldDefn("id", ogr.OFTInteger)
        oLayer.CreateField(oId, 1)
        oName = ogr.FieldDefn("name", ogr.OFTString)
        oLayer.CreateField(oName, 1)
        oDefn = oLayer.GetLayerDefn()

        # Assemble one MULTILINESTRING WKT covering every edge segment.
        point_str_list = ['({} {},{} {})'.format(row[0, 0], row[0, 1], row[1, 0], row[1, 1]) for row in Polyline]
        Polyline_Wkt = ','.join(point_str_list)
        features = ['Polyline0;MULTILINESTRING({})'.format(Polyline_Wkt)]
        for index, f in enumerate(features):
            oFeaturePolygon = ogr.Feature(oDefn)
            oFeaturePolygon.SetField("id", index)
            oFeaturePolygon.SetField("name", f.split(";")[0])
            geomPolygon = ogr.CreateGeometryFromWkt(f.split(";")[1])
            oFeaturePolygon.SetGeometry(geomPolygon)
            oLayer.CreateFeature(oFeaturePolygon)

    # Flush and close the data source.
    oDS.Destroy()
    return "数据集创建完成!"