Upload 2 files
- app.py +170 -0
- inference.py +198 -0
app.py
ADDED
@@ -0,0 +1,170 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import torch
+import gradio as gr
+from gradio_image_prompter import ImagePrompter
+from torch.nn import DataParallel
+from models.counter_infer import build_model
+from utils.arg_parser import get_argparser
+from utils.data import resize_and_pad
+import torchvision.ops as ops
+from torchvision import transforms as T
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+
+# Load the model once, to avoid reloading on every request
+def load_model():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args = get_argparser().parse_args()
+    args.zero_shot = True
+    model = DataParallel(build_model(args).to(device))
+    model.load_state_dict(torch.load('CNTQG_multitrain_ca44.pth', weights_only=True)['model'], strict=False)
+    model.eval()
+    return model, device
+
+model, device = load_model()
+
+# Process the image once (run inference)
+def process_image_once(inputs, enable_mask):
+    model.module.return_masks = enable_mask
+
+    image = inputs['image']
+    drawn_boxes = inputs['points']
+    image_tensor = torch.tensor(image).to(device)
+    image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
+    image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
+
+    # each ImagePrompter box is laid out as [x1, y1, label, x2, y2, label]
+    bboxes_tensor = torch.tensor([[box[0], box[1], box[3], box[4]] for box in drawn_boxes], dtype=torch.float32).to(device)
+
+    img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
+    img = img.unsqueeze(0).to(device)
+    bboxes = bboxes.unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        outputs, _, _, _, masks = model(img, bboxes)
+
+    return image, outputs, masks, img, scale, drawn_boxes
+
+# Post-process and update the output
+def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
+    idx = 0
+    threshold = 1 / threshold
+    keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold],
+                   outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold], 0.5)
+
+    pred_boxes = outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold][keep]
+    pred_boxes = torch.clamp(pred_boxes, 0, 1)
+
+    pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()
+
+    image = Image.fromarray(image.astype(np.uint8))
+
+    if enable_mask:
+        from matplotlib import pyplot as plt
+        masks_ = masks[idx][(outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold)[0]]
+        N_masks = masks_.shape[0]
+        indices = torch.randint(1, N_masks + 1, (1, N_masks), device=masks_.device).view(-1, 1, 1)
+        masks = (masks_ * indices).sum(dim=0)
+        mask_display = (
+            T.Resize((int(img.shape[2] / scale), int(img.shape[3] / scale)), interpolation=T.InterpolationMode.NEAREST)(
+                masks.cpu().unsqueeze(0))[0])[:image.size[1], :image.size[0]]
+        cmap = plt.cm.tab20
+        norm = plt.Normalize(vmin=0, vmax=N_masks)
+        del masks
+        del masks_
+        del outputs
+        rgba_image = cmap(norm(mask_display))
+        rgba_image[mask_display == 0, -1] = 0
+        rgba_image[mask_display != 0, -1] = 0.5
+
+        overlay = Image.fromarray((rgba_image * 255).astype(np.uint8), mode="RGBA")
+        image = image.convert("RGBA")
+        image = Image.alpha_composite(image, overlay)
+
+    draw = ImageDraw.Draw(image)
+    for box in pred_boxes:
+        draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
+    # for box in drawn_boxes:
+    #     draw.rectangle([box[0], box[1], box[3], box[4]], outline="red", width=3)
+
+    width, height = image.size
+    square_size = int(0.05 * width)
+    x1, y1 = 10, height - square_size - 10
+    x2, y2 = x1 + square_size, y1 + square_size
+
+    # draw.rectangle([x1, y1, x2, y2], outline="black", fill="black", width=1)
+    # font = ImageFont.load_default()
+    # txt = str(len(pred_boxes))
+    # w = draw.textlength(txt, font=font)
+    # text_x = x1 + (square_size - w) / 2
+    # text_y = y1 + (square_size - 10) / 2
+    # draw.text((text_x, text_y), txt, fill="white", font=font)
+
+    return image, len(pred_boxes)
+
+
+iface = gr.Blocks()
+
+with iface:
+    # Store intermediate states
+    image_input = gr.State()
+    outputs_state = gr.State()
+    masks_state = gr.State()
+    img_state = gr.State()
+    scale_state = gr.State()
+    drawn_boxes_state = gr.State()
+
+    # UI layout: input section
+    with gr.Row():
+        image_prompter = ImagePrompter()
+        image_output = gr.Image(type="pil")
+
+    # UI layout: output section
+    with gr.Row():
+        count_output = gr.Number(label="Total Count")
+        enable_mask = gr.Checkbox(label="Predict masks", value=True)  # mask enabled by default
+        threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold")
+
+    # Create the 'Count' button
+    count_button = gr.Button("Count")
+
+    # Process the image once when the "Count" button is pressed
+    def initial_process(inputs, enable_mask, threshold):
+        # Perform inference once
+        image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
+
+        # Save intermediate states
+        return (
+            *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),  # processed outputs
+            image, outputs, masks, img, scale, drawn_boxes  # stored in states for later use
+        )
+
+    # Update the image and count when the threshold slider changes (post-process only)
+    def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
+        return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
+
+    # Run initial inference and post-processing when the "Count" button is clicked
+    count_button.click(
+        initial_process,
+        [image_prompter, enable_mask, threshold],  # inputs
+        [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state]  # outputs + states
+    )
+
+    # Adjust the output dynamically based on the threshold slider (no re-inference)
+    threshold.change(
+        update_threshold,
+        [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
+        [image_output, count_output]
+    )
+
+    enable_mask.change(
+        update_threshold,
+        [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
+        [image_output, count_output]
+    )
+
+iface.launch(share=True)
+
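
Note: post_process applies a relative confidence cutoff rather than an absolute one. The slider value s is inverted (threshold = 1/s), so a box survives when box_v > box_v.max() / (1/s), i.e. when its score exceeds s times the best score; survivors are then deduplicated with IoU-0.5 NMS. A minimal, self-contained sketch of that scheme on dummy tensors (filter_detections is an illustrative name, not part of the app):

    import torch
    import torchvision.ops as ops

    def filter_detections(pred_boxes, box_v, slider=0.33):
        # Keep boxes whose score exceeds `slider` times the best score.
        keep_mask = box_v > box_v.max() * slider
        boxes, scores = pred_boxes[keep_mask], box_v[keep_mask]
        keep = ops.nms(boxes, scores, 0.5)  # drop overlapping duplicates
        return torch.clamp(boxes[keep], 0, 1), scores[keep]

    # Example: 50 random normalized xyxy boxes with random scores.
    xy = torch.rand(50, 2) * 0.8
    boxes, scores = filter_detections(torch.cat([xy, xy + 0.1], dim=1), torch.rand(50))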
inference.py
ADDED
@@ -0,0 +1,198 @@
+import argparse
+import json
+import math
+import os
+
+import numpy as np
+import skimage
+import torch
+from torch.nn import DataParallel
+from torch.utils.data import DataLoader
+from torchvision import ops
+from torchvision.transforms import Resize
+from tqdm import tqdm
+from models.counter_infer import build_model
+from models.matcher import build_matcher
+from utils.arg_parser import get_argparser
+from utils.box_ops import BoxList
+from utils.data import FSC147DATASET, pad_collate_test
+from utils.losses import SetCriterion
+
+
+@torch.no_grad()
+def evaluate(args):
+    gpu = 0
+    torch.cuda.set_device(gpu)
+    device = torch.device(gpu)
+
+    model = DataParallel(
+        build_model(args).to(device),
+        device_ids=[gpu],
+        output_device=gpu
+    )
+
+    state_dict = torch.load(os.path.join(args.model_path, f'{args.model_name}.pth'))['model']
+    state_dict = {k if 'module.' in k else 'module.' + k: v for k, v in state_dict.items()}
+    model.load_state_dict(state_dict, strict=False)
+
+    for split in ['val', 'test']:
+        test = FSC147DATASET(
+            args.data_path,
+            args.image_size,
+            split=split,
+            num_objects=args.num_objects,
+            tiling_p=args.tiling_p,
+            return_ids=True,
+            training=False
+        )
+        test_loader = DataLoader(
+            test,
+            batch_size=args.batch_size,
+            drop_last=False,
+            num_workers=args.num_workers,
+            collate_fn=pad_collate_test,
+        )
+        ae = torch.tensor(0.0).to(device)
+        se = torch.tensor(0.0).to(device)
+        model.eval()
+        matcher = build_matcher(args)
+        criterion = SetCriterion(0, matcher, {"loss_giou": args.giou_loss_coef}, ["bboxes", "ce"], focal_alpha=args.focal_alpha)
+        criterion.to(device)
+
+
+        predictions = dict()
+        predictions["categories"] = [{"name": "fg", "id": 1}]
+        predictions["images"] = list()
+        predictions["annotations"] = list()
+        anno_id = 1
+
+        for img, bboxes, density_map, ids, gt_bboxes, scaling_factor, padwh in test_loader:
+            img = img.to(device)
+            bboxes = bboxes.to(device)
+            gt_bboxes = gt_bboxes.to(device)
+
+            outputs, ref_points, _, _, masks = model(img, bboxes)
+
+            w, h = img.shape[-1], img.shape[-2]
+            losses = []
+            num_objects_gt = []
+            num_objects_pred = []
+            nms_bboxes = []
+            nms_scores = []
+            nms_masks = []
+            for idx in range(img.shape[0]):
+
+                thr = 1 / 0.11
+
+                if len(outputs[idx]['pred_boxes'][-1]) == 0:
+                    nms_bboxes.append(torch.zeros((0, 4)))
+                    nms_scores.append(torch.zeros((0)))
+                    num_objects_pred.append(0)
+
+                else:
+
+                    # threshold and NMS
+                    v = outputs[idx]["box_v"]
+                    v_thr = v.max() / thr
+                    mask = v > v_thr
+                    keep = ops.nms(
+                        outputs[idx]["pred_boxes"][mask],
+                        v[mask],
+                        0.5,
+                    )
+                    boxes = outputs[idx]["pred_boxes"][mask][keep]
+                    boxes = torch.clamp(boxes, 0, 1)
+                    scores = outputs[idx]["scores"][mask][keep]
+
+                    # remove bboxes in padded area
+                    maxw = (img.shape[-1] - padwh[idx][0]).to(device)
+                    maxh = (img.shape[-2] - padwh[idx][1]).to(device)
+                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
+                    valid = (center[:, 0] * h < maxw) & (center[:, 1] * w < maxh)
+                    scores = scores[valid]
+                    boxes = boxes[valid]
+
+                    nms_bboxes.append(boxes)
+                    nms_scores.append(scores)
+                    num_objects_pred.append(len(boxes))
+
+                if False:  # disabled debug visualization; note that `centerness` below is undefined in this file
+                    from matplotlib import pyplot as plt
+                    fig1 = plt.figure(figsize=(8, 8))
+                    ((ax1_11, ax1_12), (ax1_21, ax1_22)) = fig1.subplots(2, 2)
+                    fig1.tight_layout(pad=2.5)
+                    img_ = np.array(img.cpu()[idx].permute(1, 2, 0))
+                    img_ = img_ - np.min(img_)
+                    img_ = img_ / np.max(img_)
+                    ax1_11.imshow(img_)
+                    ax1_11.set_title("Input", fontsize=8)
+                    bboxes_ = np.array(bboxes.cpu())[idx]
+                    for i in range(3):
+                        ax1_11.plot([bboxes_[i][0], bboxes_[i][0], bboxes_[i][2], bboxes_[i][2], bboxes_[i][0]],
+                                    [bboxes_[i][1], bboxes_[i][3], bboxes_[i][3], bboxes_[i][1], bboxes_[i][1]], c='r')
+                    ax1_12.imshow(img_)
+                    ax1_12.set_title("gt bboxes", fontsize=8)
+                    target_bboxes = gt_bboxes[idx][torch.logical_not((gt_bboxes[idx] == 0).all(dim=1))]
+                    bboxes_ = target_bboxes.detach().cpu()
+                    for i in range(len(bboxes_)):
+                        ax1_12.plot([bboxes_[i][0], bboxes_[i][0], bboxes_[i][2], bboxes_[i][2], bboxes_[i][0]],
+                                    [bboxes_[i][1], bboxes_[i][3], bboxes_[i][3], bboxes_[i][1], bboxes_[i][1]], c='g')
+                    ax1_21.imshow(img_)
+
+                    bboxes_pred = nms_bboxes[idx]
+                    bboxes_ = (bboxes_pred * img_.shape[0]).detach().cpu()
+                    for i in range(len(bboxes_)):
+                        ax1_21.plot([bboxes_[i][0], bboxes_[i][0], bboxes_[i][2], bboxes_[i][2], bboxes_[i][0]],
+                                    [bboxes_[i][1], bboxes_[i][3], bboxes_[i][3], bboxes_[i][1], bboxes_[i][1]],
+                                    c='orange', linewidth=0.5)
+                    ax1_21.set_title("#GT-#PRED=" + str(len(target_bboxes) - len(bboxes_pred)))
+                    from torchvision import transforms as T
+                    res = T.Resize((1024, 1024))
+                    ax1_21.imshow(res(centerness).detach().cpu()[idx][0], alpha=0.6)
+                    plt.savefig(test.image_names[ids[idx].item()], dpi=200)
+                    plt.close()
+
+            for idx in range(img.shape[0]):
+                img_info = {
+                    "id": test.map_img_name_to_ori_id()[test.image_names[ids[idx].item()]],
+                    "file_name": "None",
+                }
+                bboxes = ops.box_convert(nms_bboxes[idx], 'xyxy', 'xywh')
+                bboxes = bboxes * img.shape[-1] / scaling_factor[idx]
+                for idxi in range(len(nms_bboxes[idx])):
+                    box = bboxes[idxi].detach().cpu()
+                    anno = {
+                        "id": anno_id,
+                        "image_id": test.map_img_name_to_ori_id()[test.image_names[ids[idx].item()]],
+                        "area": int((box[2] * box[3]).item()),
+                        "bbox": [int(box[0].item()), int(box[1].item()), int(box[2].item()), int(box[3].item())],
+                        "category_id": 1,
+                        "score": float(nms_scores[idx][idxi].item()),
+                    }
+                    anno_id += 1
+                    predictions["annotations"].append(anno)
+                predictions["images"].append(img_info)
+            num_objects_gt = density_map.flatten(1).sum(dim=1)
+            num_objects_pred = torch.tensor(num_objects_pred)
+            ae += torch.abs(
+                num_objects_gt - num_objects_pred
+            ).sum()
+            se += torch.pow(
+                num_objects_gt - num_objects_pred, 2
+            ).sum()
+        print(
+            f"{split.capitalize()} set",
+            f"MAE: {ae.item() / len(test):.2f}",
+            f"RMSE: {torch.sqrt(se / len(test)).item():.2f}",
+        )
+
+        with open("geco2_" + split + ".json", "w") as handle:
+            json.dump(predictions, handle)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('GECO2', parents=[get_argparser()])
+    args = parser.parse_args()
+    print(args)
+    print("model_name: ", args.model_name)
+    evaluate(args)
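
Usage note: the flag names below are assumptions inferred from the args attributes referenced above; the actual interface is defined by get_argparser() in utils/arg_parser.py, which is not part of this commit.

    python inference.py --model_path ./checkpoints --model_name CNTQG_multitrain_ca44 --data_path ./FSC147

For each split the script prints MAE and RMSE and writes a COCO-style detection file (geco2_val.json, geco2_test.json).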