|
|
|
import numpy as np |
|
import cv2 |
|
|
|
from models_loader import load_model |
|
from utils import ( |
|
draw_detections, apply_nms, count_classes, compute_metrics, |
|
generate_bar_chart, generate_pie_chart, generate_html_table, generate_metrics_table |
|
) |
|
|
|
def full_inference(image, model_choice):
    """Run the selected model once over the whole image.

    Args:
        image: Image handed unchanged to ``model.predict`` and to
            ``draw_detections``.
        model_choice: Identifier forwarded to ``load_model``.

    Returns:
        A 4-tuple ``(annotated, detections, class_counts, metrics)``.
        ``detections`` is a list of
        ``[x1, y1, x2, y2, conf, class_id, class_name]`` rows; the other
        three values are whatever ``draw_detections``, ``count_classes``
        and ``compute_metrics`` return for those rows.
    """
    detector = load_model(model_choice)
    # A single image is submitted, so only the first result is consumed.
    prediction = detector.predict(image, conf=0.25)[0]

    def as_row(box):
        # One detection -> flat row: four corner coordinates, score, id, label.
        class_id = int(box.cls[0].item())
        label = detector.names[class_id]
        corners = box.xyxy[0].cpu().numpy()
        score = float(box.conf[0].item())
        return [*corners, score, class_id, label]

    detections = [as_row(box) for box in prediction.boxes]

    # Helpers are invoked in the same order as before: draw, count, metrics.
    annotated = draw_detections(image, detections)
    class_counts = count_classes(detections)
    summary = compute_metrics(detections)
    return annotated, detections, class_counts, summary
|
|
|
def _gois_tile_patches(region, patch_size, step, origin_x=0, origin_y=0):
    """Cut ``region`` into ``patch_size`` squares laid out on a ``step`` grid.

    Returns two parallel lists: the patches, and each patch's top-left
    corner as ``(x, y)`` shifted by ``(origin_x, origin_y)``.
    """
    rows, cols = region.shape[:2]
    patches, corners = [], []
    for top in range(0, rows, step):
        for left in range(0, cols, step):
            # Slicing clips at the border, so edge patches may be smaller
            # than patch_size.
            patches.append(region[top:top + patch_size, left:left + patch_size])
            corners.append((origin_x + left, origin_y + top))
    return patches, corners


def _gois_collect_boxes(results, corners, class_names):
    """Convert per-patch predictions into detection rows in image coordinates."""
    detections = []
    for result, (shift_x, shift_y) in zip(results, corners):
        for box in result.boxes:
            class_id = int(box.cls[0].item())
            label = class_names[class_id]
            score = float(box.conf[0].item())
            xyxy = box.xyxy[0].cpu().numpy()
            # Out-of-place add: if xyxy is backed by a CPU torch tensor,
            # .numpy() shares its memory and an in-place += would silently
            # rewrite the prediction result as well.
            shift = np.array(
                [shift_x, shift_y, shift_x, shift_y], dtype=xyxy.dtype
            )
            detections.append([*(xyxy + shift), score, class_id, label])
    return detections


def gois_inference(
    image, model_choice,
    coarse_size, fine_size,
    coarse_overlap, fine_overlap,
    nms_thresh,
    skip_width_threshold=200,
    skip_height_threshold=200,
    coarse_conf=0.25,
    fine_conf=0.4,
):
    """Detect objects with a coarse tiling pass followed by a fine pass.

    The image is first tiled into ``coarse_size`` patches and the model is
    run on them. Every coarse detection that is smaller than both skip
    thresholds is then cropped from the image, tiled again with
    ``fine_size`` patches and re-run through the model. Coarse and fine
    detections are merged and filtered by ``apply_nms``.

    Args:
        image: Array-like image exposing ``.shape`` and 2-D slicing.
        model_choice: Identifier forwarded to ``load_model``.
        coarse_size: Side length of the coarse patches, in pixels.
        fine_size: Side length of the fine patches, in pixels.
        coarse_overlap: Fraction of ``coarse_size`` shared by neighbouring
            coarse patches; the stride is ``coarse_size * (1 - overlap)``,
            truncated to an int and never below 1.
        fine_overlap: Same as ``coarse_overlap`` for the fine pass.
        nms_thresh: Passed to ``apply_nms`` as ``iou_thresh``.
        skip_width_threshold: Coarse boxes at least this wide are not refined.
        skip_height_threshold: Coarse boxes at least this tall are not refined.
        coarse_conf: ``conf`` argument for the coarse ``predict`` call.
        fine_conf: ``conf`` argument for the fine ``predict`` call.

    Returns:
        A 4-tuple ``(annotated, final_boxes, class_counts, metrics)``.
        ``final_boxes`` is the output of ``apply_nms`` over rows shaped
        ``[x1, y1, x2, y2, conf, class_id, class_name]``; the remaining
        values come from ``draw_detections``, ``count_classes`` and
        ``compute_metrics``.
    """
    model = load_model(model_choice)
    H, W = image.shape[:2]

    # Strides are loop-invariant; max(1, ...) keeps range() valid even when
    # an overlap of 1.0 or more would give a zero or negative stride.
    coarse_step = max(1, int(coarse_size * (1 - coarse_overlap)))
    fine_step = max(1, int(fine_size * (1 - fine_overlap)))

    # ---- Coarse pass over the whole image ----
    coarse_patches, coarse_corners = _gois_tile_patches(
        image, coarse_size, coarse_step
    )
    coarse_boxes = []
    if coarse_patches:  # mirror the fine pass: never predict on an empty batch
        coarse_results = model.predict(coarse_patches, conf=coarse_conf, batch=16)
        coarse_boxes = _gois_collect_boxes(
            coarse_results, coarse_corners, model.names
        )

    # ---- Build fine patches from sufficiently small coarse detections ----
    fine_patches, fine_corners = [], []
    for x1, y1, x2, y2, *_ in coarse_boxes:
        if (x2 - x1) >= skip_width_threshold or (y2 - y1) >= skip_height_threshold:
            continue

        # Clamp the box to the image; drop it if nothing is left.
        left, top = max(0, int(x1)), max(0, int(y1))
        right, bottom = min(W, int(x2)), min(H, int(y2))
        if right <= left or bottom <= top:
            continue

        patches, corners = _gois_tile_patches(
            image[top:bottom, left:right], fine_size, fine_step, left, top
        )
        fine_patches.extend(patches)
        fine_corners.extend(corners)

    # ---- Fine pass ----
    fine_boxes = []
    if fine_patches:
        fine_results = model.predict(fine_patches, conf=fine_conf, batch=16)
        fine_boxes = _gois_collect_boxes(fine_results, fine_corners, model.names)

    # ---- Merge, suppress duplicates, summarise ----
    final_boxes = apply_nms(coarse_boxes + fine_boxes, iou_thresh=nms_thresh)

    annotated = draw_detections(image, final_boxes)
    c_counts = count_classes(final_boxes)
    metrics = compute_metrics(final_boxes)
    return annotated, final_boxes, c_counts, metrics
|
|
|
def run_inference(image, model_choice, c_size, f_size, c_overlap, f_overlap, nms):
    """Run whole-image and GOIS detection on one image and build the report.

    Args:
        image: Input image. Anything that is not already a NumPy array is
            converted with ``np.array(image, dtype=np.uint8)``.
        model_choice: Identifier forwarded to both inference functions.
        c_size: Coarse patch size for ``gois_inference``.
        f_size: Fine patch size for ``gois_inference``.
        c_overlap: Coarse overlap fraction for ``gois_inference``.
        f_overlap: Fine overlap fraction for ``gois_inference``.
        nms: NMS threshold for ``gois_inference``.

    Returns:
        An 8-tuple, in this order: whole-image annotated image, GOIS
        annotated image, whole-image class table, GOIS class table,
        metrics bar chart, whole-image pie chart, GOIS pie chart,
        comparison metrics table.
    """
    # ---- Normalise the input ----
    if not isinstance(image, np.ndarray):
        image = np.array(image, dtype=np.uint8)
    # Only a 3-D array can carry an alpha channel. Without the ndim check a
    # 2-D image that happens to be 4 pixels wide would be sent through the
    # RGBA conversion and fail there.
    if image.ndim == 3 and image.shape[-1] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # ---- Whole-image detection ----
    fi_annot, fi_boxes, fi_ccount, fi_metrics = full_inference(image, model_choice)

    # ---- Coarse-to-fine (GOIS) detection ----
    gois_annot, gois_boxes, gois_ccount, gois_metrics = gois_inference(
        image, model_choice,
        coarse_size=c_size,
        fine_size=f_size,
        coarse_overlap=c_overlap,
        fine_overlap=f_overlap,
        nms_thresh=nms,
    )

    # ---- Charts and tables (call order kept as in the original) ----
    bar_chart = generate_bar_chart(fi_metrics, gois_metrics)
    fi_pie = generate_pie_chart(fi_ccount, "FI-Det Class Distribution")
    gois_pie = generate_pie_chart(gois_ccount, "GOIS-Det Class Distribution")

    fi_table = generate_html_table(fi_ccount, "FI-Det Class Distribution")
    gois_table = generate_html_table(gois_ccount, "GOIS-Det Class Distribution")
    metrics_tbl = generate_metrics_table(fi_metrics, gois_metrics, "Comparison Metrics")

    return (
        fi_annot,
        gois_annot,
        fi_table,
        gois_table,
        bar_chart,
        fi_pie,
        gois_pie,
        metrics_tbl,
    )
|
|