# HEAT / HEAT.py — Hugging Face Space file (author: Egrt, initial commit 424188c)
'''
Author: [egrt]
Date: 2022-08-23 11:44:15
LastEditors: Egrt
LastEditTime: 2022-11-23 15:25:35
Description: HEAT model loading and prediction
'''
from turtle import pos
import torch
import torch.nn as nn
from models.resnet import ResNetBackbone
from models.corner_models import HeatCorner
from models.edge_models import HeatEdge
from models.corner_to_edge import get_infer_edge_pairs
from datasets.data_utils import get_pixel_features
from huggingface_hub import hf_hub_download
from PIL import Image
from utils import image_utils
from osgeo import gdal, ogr, osr
from tqdm import tqdm
import os
import scipy
import numpy as np
import cv2
import skimage
class HEAT(object):
    """End-to-end wrapper for HEAT corner/edge inference.

    Bundles the ResNet backbone, the corner detector and the edge detector,
    restores their weights from a single checkpoint downloaded from the
    Hugging Face Hub, and exposes ``detect_one_image`` which returns a PIL
    image with the predicted corners and edges drawn on it (patch-wise for
    large inputs).
    """
    #-----------------------------------------#
    #   Note: adjust model_path when using a local checkpoint
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_data: path of the full-network checkpoint
        #   (generate() below downloads from the Hub instead)
        #-----------------------------------------------#
        "model_data"        : 'model_data/heat_checkpoints/checkpoints/ckpts_heat_outdoor_256/checkpoint.pth',
        #-----------------------------------------------#
        #   image_size: pixel size the model predicts at
        #-----------------------------------------------#
        "image_size"        : [256, 256],
        #-----------------------------------------------#
        #   patch_size: side length of sliding-window patches
        #-----------------------------------------------#
        "patch_size"        : 512,
        #-----------------------------------------------#
        #   patch_overlap: overlap in pixels between adjacent patches
        #-----------------------------------------------#
        "patch_overlap"     : 0,
        #-----------------------------------------------#
        #   corner_thresh: confidence threshold for corner candidates
        #-----------------------------------------------#
        "corner_thresh"     : 0.01,
        #-----------------------------------------------#
        #   max edge candidates per corner (must not exceed 6)
        #-----------------------------------------------#
        "corner_to_edge_multiplier": 3,
        #-----------------------------------------------#
        #   number of passes of the iterative edge filtering loop
        #-----------------------------------------------#
        "infer_times"       : 3,
        #-------------------------------#
        #   whether to use CUDA
        #   set to False when there is no GPU
        #   NOTE(review): generate() and process_image() call .cuda()
        #   unconditionally, so this flag is currently not honored — confirm.
        #-------------------------------#
        "cuda"              : False,
    }
    #---------------------------------------------------#
    #   Initialize HEAT; _defaults may be overridden via kwargs
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        """Instantiate the three sub-networks and restore their weights."""
        # Download the combined checkpoint from the Hugging Face Hub
        filepath = hf_hub_download(repo_id="Egrt/HEAT", filename="checkpoint.pth")
        # NOTE(review): no map_location is given, so loading a GPU-saved
        # checkpoint requires CUDA to be available — confirm on CPU-only hosts.
        self.model = torch.load(filepath)
        # Backbone (shared feature extractor)
        self.backbone = ResNetBackbone()
        strides = self.backbone.strides
        num_channels = self.backbone.num_channels
        self.backbone = nn.DataParallel(self.backbone)
        self.backbone = self.backbone.cuda()
        self.backbone.eval()
        # Corner detection model
        self.corner_model = HeatCorner(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                       backbone_num_channels=num_channels)
        self.corner_model = nn.DataParallel(self.corner_model)
        self.corner_model = self.corner_model.cuda()
        self.corner_model.eval()
        # Edge detection model
        self.edge_model = HeatEdge(input_dim=128, hidden_dim=256, num_feature_levels=4, backbone_strides=strides,
                                   backbone_num_channels=num_channels)
        self.edge_model = nn.DataParallel(self.edge_model)
        self.edge_model = self.edge_model.cuda()
        self.edge_model.eval()
        # Load each component's state dict from the combined checkpoint
        self.backbone.load_state_dict(self.model['backbone'])
        self.corner_model.load_state_dict(self.model['corner_model'])
        self.edge_model.load_state_dict(self.model['edge_model'])

    def detect_one_image(self, image):
        """Run inference on one PIL image and return the visualized result.

        Large images (both sides >= patch_size) are processed patch by patch;
        smaller ones in a single pass. Returns a PIL image with corners/edges
        drawn on it; the final call also exports the edges to a shapefile.
        """
        #---------------------------------------------------------#
        #   Convert to RGB so grayscale inputs do not crash;
        #   only RGB prediction is supported, everything else is converted.
        #---------------------------------------------------------#
        image = cvtColor(image)
        # Decide whether the image must be split into patches
        if image.size[0] < self.patch_size or image.size[1] < self.patch_size:
            is_slice = False
        else:
            is_slice = True
        if is_slice:
            # PIL -> numpy
            image = np.array(image, dtype=np.uint8)
            # Keep a copy of the input as the drawing canvas
            viz_image = image.copy()
            height, width = image.shape[0], image.shape[1]
            # Scale factor between patch resolution and model resolution
            scale = self.patch_size / self.image_size[0]
            # Initialize corner/edge result holders
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np = [], [], [], [], []
            # Sliding-window stride
            stride = self.patch_size - self.patch_overlap
            patch_boundingboxes = image_utils.compute_patch_boundingboxes((height, width),
                                                                          stride=stride,
                                                                          patch_res=self.patch_size)
            edge_len = 0
            # Predict each patch in turn
            for bbox in tqdm(patch_boundingboxes, desc="使用切分进行预测", leave=False):
                # Crop the patch (bbox = x0, y0, x1, y1)
                crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
                # numpy -> PIL
                crop_image = Image.fromarray(crop_image)
                try:
                    # NOTE(review): the results are overwritten on every
                    # iteration, so only the last patch's predictions survive
                    # the loop — looks unintentional; confirm upstream.
                    pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, _ = self.predict_no_patching(crop_image)
                except RuntimeError as e:
                    print("ERROR: " + str(e))
                    print("INFO: 减小patch_size 直到适合内存")
                    raise e
                # Shift patch-local corner coordinates back to full-image space
                pred_corners[:, 0] = pred_corners[:, 0] * scale + bbox[0]
                pred_corners[:, 1] = pred_corners[:, 1] * scale + bbox[1]
                pred_corners_viz = pred_corners
                # Draw this patch's corners/edges onto the running canvas
                viz_image = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                      edge_confs=edge_confs, shpfile=False)
            hr_image = Image.fromarray(np.uint8(viz_image))
        else:
            pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image = self.predict_no_patching(image)
        #---------------------------------------------------------#
        #   Inference done: draw corners and edges onto the image
        #   and export the edges to a shapefile (shpfile=True).
        #---------------------------------------------------------#
        pred_corners_viz = pred_corners
        image_result = visualize_cond_generation(pred_corners_viz, pred_confs, viz_image, edges=pos_edges,
                                                 edge_confs=edge_confs, shpfile=True)
        hr_image = Image.fromarray(np.uint8(image_result))
        return hr_image

    #---------------------------------------------------------#
    #   Predict one image without patching.
    #   Returns the predicted corner coordinates and edges.
    #---------------------------------------------------------#
    def predict_no_patching(self, image):
        """Run corner + edge inference on one PIL image (resized to image_size).

        Returns (pred_corners, pred_confs, pos_edges, edge_confs,
        c_outputs_np, viz_image) where viz_image is the resized uint8 copy of
        the input used for drawing.
        """
        image = image.resize(tuple(self.image_size), Image.BICUBIC)
        # PIL -> numpy
        image = np.array(image, dtype=np.uint8)
        # Keep a copy of the (resized) input for visualization
        viz_image = image.copy()
        # preprocess image numpy -> normalized NCHW tensor
        image = process_image(image)
        # Positional encodings for every pixel (default image scale: 256)
        pixels, pixel_features = get_pixel_features(image_size=self.image_size[0])
        # Run the models without building autograd graphs
        with torch.no_grad():
            image_feats, feat_mask, all_image_feats = self.backbone(image)
            pixel_features = pixel_features.unsqueeze(0).repeat(image.shape[0], 1, 1, 1)
            preds_s1 = self.corner_model(image_feats, feat_mask, pixel_features, pixels, all_image_feats)
            c_outputs = preds_s1
            # Predicted corner heatmap -> numpy
            c_outputs_np = c_outputs[0].detach().cpu().numpy()
            # Keep the pixels whose score clears the threshold
            pos_indices = np.where(c_outputs_np >= self.corner_thresh)
            pred_corners = pixels[pos_indices]
            # Confidence of each kept corner
            pred_confs = c_outputs_np[pos_indices]
            # Non-maximum suppression on the corner confidences
            pred_corners, pred_confs = corner_nms(pred_corners, pred_confs, image_size=c_outputs.shape[1])
            # All pairwise corner combinations as candidate edges
            pred_corners, pred_confs, edge_coords, edge_mask, edge_ids = get_infer_edge_pairs(pred_corners, pred_confs)
            # Number of corners
            corner_nums = torch.tensor([len(pred_corners)]).to(image.device)
            max_candidates = torch.stack([corner_nums.max() * self.corner_to_edge_multiplier] * len(corner_nums), dim=0)
            # Set of confirmed edge ids (unordered, unique)
            all_pos_ids = set()
            # Edge id -> confidence
            all_edge_confs = dict()
            # Iterative edge filtering, infer_times passes
            for tt in range(self.infer_times):
                if tt == 0:
                    # gt_values matches edge_mask's shape, initialized to 0
                    gt_values = torch.zeros_like(edge_mask).long()
                    # value 2 marks an edge as "undecided"
                    gt_values[:, :] = 2
                # Edge prediction
                s1_logits, s2_logits_hb, s2_logits_rel, selected_ids, s2_mask, s2_gt_values = self.edge_model(image_feats,
                    feat_mask, pixel_features, edge_coords, edge_mask, gt_values, corner_nums, max_candidates, True)
                num_total = s1_logits.shape[2]
                num_selected = selected_ids.shape[1]
                num_filtered = num_total - num_selected
                # Softmax -> probabilities in (0, 1)
                s1_preds = s1_logits.squeeze().softmax(0)
                s2_preds_rel = s2_logits_rel.squeeze().softmax(0)
                s2_preds_hb = s2_logits_hb.squeeze().softmax(0)
                s1_preds_np = s1_preds[1, :].detach().cpu().numpy()
                s2_preds_rel_np = s2_preds_rel[1, :].detach().cpu().numpy()
                s2_preds_hb_np = s2_preds_hb[1, :].detach().cpu().numpy()
                selected_ids = selected_ids.squeeze().detach().cpu().numpy()
                # Thresholding: >= 0.9 -> accepted (1), <= 0.01 -> rejected (0),
                # everything in between stays undecided (2)
                if tt != self.infer_times - 1:
                    s2_preds_np = s2_preds_hb_np
                    pos_edge_ids = np.where(s2_preds_np >= 0.9)
                    neg_edge_ids = np.where(s2_preds_np <= 0.01)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        if gt_values[0, actual_id] != 2:
                            continue
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]
                        gt_values[0, actual_id] = 1
                    for neg_id in neg_edge_ids[0]:
                        actual_id = selected_ids[neg_id]
                        if gt_values[0, actual_id] != 2:
                            continue
                        gt_values[0, actual_id] = 0
                    num_to_pred = (gt_values == 2).sum()
                    # Stop early once the remaining undecided edges fit in one pass
                    if num_to_pred <= num_filtered:
                        break
                else:
                    # Final pass: accept everything above 0.5
                    s2_preds_np = s2_preds_hb_np
                    pos_edge_ids = np.where(s2_preds_np >= 0.5)
                    for pos_id in pos_edge_ids[0]:
                        actual_id = selected_ids[pos_id]
                        # NOTE(review): `s2_mask[0][pos_id] is True` compares a
                        # tensor element with the bool singleton and is always
                        # False — probably meant a truth-value test; confirm.
                        if s2_mask[0][pos_id] is True or gt_values[0, actual_id] != 2:
                            continue
                        all_pos_ids.add(actual_id)
                        all_edge_confs[actual_id] = s2_preds_np[pos_id]
        pos_edge_ids = list(all_pos_ids)
        edge_confs = [all_edge_confs[idx] for idx in pos_edge_ids]
        pos_edges = edge_ids[pos_edge_ids].cpu().numpy()
        edge_confs = np.array(edge_confs)
        # Rescale corners back to the canonical 256 scale if needed
        if self.image_size[0] != 256:
            pred_corners = pred_corners / (self.image_size[0] / 256)
        return pred_corners, pred_confs, pos_edges, edge_confs, c_outputs_np, viz_image
def cvtColor(image):
    """Return *image* as a 3-channel RGB image.

    Inputs that already carry three channels are returned untouched; anything
    else (e.g. grayscale PIL images) is converted via PIL's convert('RGB') so
    prediction on grayscale input does not fail.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
def corner_nms(preds, confs, image_size):
    """Non-maximum suppression of corner predictions on a confidence map.

    Parameters
    ----------
    preds : (N, 2) int array of (x, y) pixel coordinates.
    confs : (N,) array of corner confidences.
    image_size : side length of the square confidence map.

    Returns
    -------
    (filtered_preds, new_confs): the (x, y) coordinates that are local maxima
    of the confidence map within a 5x5 neighborhood, and their confidences.
    """
    data = np.zeros([image_size, image_size])
    neighborhood_size = 5
    threshold = 0
    for i in range(len(preds)):
        # confidence map is indexed (row=y, col=x)
        data[preds[i, 1], preds[i, 0]] = confs[i]
    # FIX: `scipy.ndimage.filters` is a deprecated alias removed in recent
    # SciPy releases — call the functions on scipy.ndimage directly.
    data_max = scipy.ndimage.maximum_filter(data, neighborhood_size)
    maxima = (data == data_max)
    data_min = scipy.ndimage.minimum_filter(data, neighborhood_size)
    # keep only maxima whose neighborhood actually varies (true peaks)
    diff = ((data_max - data_min) > threshold)
    maxima[diff == 0] = 0
    results = np.where(maxima > 0)
    # back to (x, y) order
    filtered_preds = np.stack([results[1], results[0]], axis=-1)
    new_confs = np.array([data[y, x] for x, y in filtered_preds])
    return filtered_preds, new_confs
def process_image(img):
    """Convert an HWC RGB image to a normalized NCHW float tensor.

    Scales uint8 pixels to [0, 1], applies ImageNet mean/std normalization,
    and returns a (1, 3, H, W) torch tensor.

    FIX: the tensor is moved to the GPU only when CUDA is available — the
    previous unconditional `.cuda()` crashed on CPU-only machines. The
    skimage.img_as_float call is replaced by its numpy equivalent (uint8 is
    divided by 255, float input is passed through unscaled).
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img = np.asarray(img)
    if img.dtype == np.uint8:
        # same as skimage.img_as_float for uint8 input
        img = img.astype(np.float64) / 255.0
    else:
        img = img.astype(np.float64)
    # HWC -> CHW
    img = img.transpose((2, 0, 1))
    img = (img - np.array(mean)[:, np.newaxis, np.newaxis]) / np.array(std)[:, np.newaxis, np.newaxis]
    img = torch.Tensor(img)
    if torch.cuda.is_available():
        img = img.cuda()
    # add batch dimension
    img = img.unsqueeze(0)
    return img
def postprocess_preds(corners, confs, edges):
    """Drop corners that belong to no edge and re-index the edges to match.

    When every corner participates in at least one edge, the inputs are
    returned unchanged; otherwise the surviving corners/confidences are
    selected and each edge pair is rewritten in the compacted index space.
    """
    degree = dict()
    for a, b in edges:
        degree[a] = degree.setdefault(a, 0) + 1
        degree[b] = degree.setdefault(b, 0) + 1
    kept = [i for i in range(len(corners)) if i in degree]
    if len(kept) == len(corners):
        return corners, confs, edges
    # old index -> new compacted index
    remap = {old: new for new, old in enumerate(kept)}
    new_edges = np.array([(remap[a], remap[b]) for a, b in edges])
    return corners[kept], confs[kept], new_edges
#---------------------------------------------------------#
#   Visualize the predicted corners/edges on the input image.
#   Unlike the original repo, the image object is returned
#   directly instead of being saved to a path.
#---------------------------------------------------------#
def visualize_cond_generation(positive_pixels, confs, image, gt_corners=None, prec=None, recall=None,
                              image_masks=None, edges=None, edge_confs=None, shpfile=False):
    """Draw predicted corners (and optionally edges, GT corners, masks and
    precision/recall text) onto a copy of *image* and return it.

    Parameters
    ----------
    positive_pixels : (N, 2) array of predicted (x, y) corner coordinates.
    confs : per-corner confidences, or None for a fixed circle color.
    image : numpy image used as the drawing canvas (copied, not mutated).
    edges : (M, 2) array of corner-index pairs, or None.
    edge_confs : per-edge confidences (required when edges is given).
    shpfile : when True, also export the edges via writeShp().
    """
    image = image.copy()
    if confs is not None:
        viz_confs = confs
    if edges is not None:
        preds = positive_pixels.astype(int)
        c_degrees = dict()
        for edge_i, edge_pair in enumerate(edges):
            # map edge confidence [0.5, 1] onto [0, 1] for line brightness
            conf = (edge_confs[edge_i] * 2) - 1
            cv2.line(image, tuple(preds[edge_pair[0]]), tuple(preds[edge_pair[1]]), (255 * conf, 255 * conf, 0), 2)
            c_degrees[edge_pair[0]] = c_degrees.setdefault(edge_pair[0], 0) + 1
            c_degrees[edge_pair[1]] = c_degrees.setdefault(edge_pair[1], 0) + 1
    for idx, c in enumerate(positive_pixels):
        # skip corners that belong to no edge when edges are drawn
        if edges is not None and idx not in c_degrees:
            continue
        if confs is None:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255), -1)
        else:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 0, 255 * viz_confs[idx]), -1)
    if gt_corners is not None:
        # ground-truth corners in green
        for c in gt_corners:
            cv2.circle(image, (int(c[0]), int(c[1])), 3, (0, 255, 0), -1)
    if image_masks is not None:
        # mask ids index a 64x64 grid of 4x4-pixel cells
        mask_ids = np.where(image_masks == 1)[0]
        for mask_id in mask_ids:
            y_idx = mask_id // 64
            x_idx = (mask_id - y_idx * 64)
            x_coord = x_idx * 4
            y_coord = y_idx * 4
            cv2.rectangle(image, (x_coord, y_coord), (x_coord + 3, y_coord + 3), (127, 127, 0), thickness=-1)
    if prec is not None:
        if isinstance(prec, tuple):
            cv2.putText(image, 'edge p={:.2f}, edge r={:.2f}'.format(prec[0], recall[0]), (20, 20),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
            cv2.putText(image, 'region p={:.2f}, region r={:.2f}'.format(prec[1], recall[1]), (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
        else:
            cv2.putText(image, 'prec={:.2f}, recall={:.2f}'.format(prec, recall), (20, 20), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 0), 1, cv2.LINE_AA)
    # Optionally export the edge segments to a shapefile.
    # FIX: previously `edges` was iterated unconditionally here, raising a
    # TypeError when shpfile=True but edges is None.
    if shpfile and edges is not None:
        preds = positive_pixels.astype(int)
        Polyline = []
        for edge_i, edge_pair in enumerate(edges):
            Polyline.append([preds[edge_pair[0]], preds[edge_pair[1]]])
        Polyline = np.array(Polyline, dtype=np.int32)
        writeShp(save_file_dir="shpfile", Polyline=Polyline)
    return image
def writeShp(save_file_dir="shpfile", Polyline=None):
# 创建文件夹
if os.path.exists(save_file_dir) is False:
os.makedirs(save_file_dir)
# 支持中文路径
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
# 属性表字段支持中文
gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
# 注册驱动
ogr.RegisterAll()
# 创建shp数据
strDriverName = "ESRI Shapefile"
oDriver = ogr.GetDriverByName(strDriverName)
if oDriver == None:
return "驱动不可用:"+strDriverName
# 创建数据源
file_path = os.path.join(save_file_dir, "result.shp")
oDS = oDriver.CreateDataSource(file_path)
if oDS == None:
return "创建文件失败:result.shp"
if Polyline is not None:
# 创建一个多边形图层,指定坐标系为WGS84
papszLCO = []
geosrs = osr.SpatialReference()
geosrs.SetWellKnownGeogCS("WGS84")
# 线:ogr_type = ogr.wkbLineString
# 点:ogr_type = ogr.wkbPoint
ogr_type = ogr.wkbMultiLineString
# 面的类型为Polygon,线的类型为Polyline,点的类型为Point
oLayer = oDS.CreateLayer("Polyline", geosrs, ogr_type, papszLCO)
if oLayer == None:
return "图层创建失败!"
# 创建属性表
# 创建id字段
oId = ogr.FieldDefn("id", ogr.OFTInteger)
oLayer.CreateField(oId, 1)
# 创建name字段
oName = ogr.FieldDefn("name", ogr.OFTString)
oLayer.CreateField(oName, 1)
oDefn = oLayer.GetLayerDefn()
# 创建要素
# 数据集
# wkt_geom id name
point_str_list = ['({} {},{} {})'.format(row[0, 0], row[0, 1], row[1, 0], row[1, 1]) for row in Polyline]
Polyline_Wkt = ','.join(point_str_list)
features = ['Polyline0;MULTILINESTRING({})'.format(Polyline_Wkt)]
for index, f in enumerate(features):
oFeaturePolygon = ogr.Feature(oDefn)
oFeaturePolygon.SetField("id",index)
oFeaturePolygon.SetField("name",f.split(";")[0])
geomPolygon = ogr.CreateGeometryFromWkt(f.split(";")[1])
oFeaturePolygon.SetGeometry(geomPolygon)
oLayer.CreateFeature(oFeaturePolygon)
# 创建完成后,关闭进程
oDS.Destroy()
return "数据集创建完成!"