# Repository file-view header (page residue, kept as a comment so the module parses):
# author MMUZAMMUL123 — commit 4e8f193 "Update gois_core.py" (verified), 5.25 kB
# gois_core.py
import numpy as np
import cv2
from models_loader import load_model
from utils import (
draw_detections, apply_nms, count_classes, compute_metrics,
generate_bar_chart, generate_pie_chart, generate_html_table, generate_metrics_table
)
def full_inference(image, model_choice, conf_thresh=0.25):
    """Run the selected detector once on the whole image (FI-Det baseline).

    Args:
        image: Image array passed directly to ``model.predict``.
        model_choice: Key forwarded to ``load_model`` to select the detector.
        conf_thresh: Minimum confidence passed to ``model.predict`` as
            ``conf``. Defaults to 0.25, the value previously hard-coded here.

    Returns:
        Tuple ``(annotated, final_boxes, class_counts, metrics)``:
        ``final_boxes`` is a list of ``[x1, y1, x2, y2, conf, class_id,
        class_name]`` entries; ``annotated`` comes from ``draw_detections``,
        ``class_counts`` from ``count_classes`` and ``metrics`` from
        ``compute_metrics``.
    """
    model = load_model(model_choice)
    result = model.predict(image, conf=conf_thresh)[0]

    # Move each box field off the device once per image instead of once per box.
    boxes = result.boxes
    xyxy = boxes.xyxy.cpu().numpy()
    scores = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy()

    final_boxes = []
    for coords, score, raw_cid in zip(xyxy, scores, class_ids):
        cid = int(raw_cid)
        final_boxes.append([*coords, float(score), cid, model.names[cid]])

    annotated = draw_detections(image, final_boxes)
    c_counts = count_classes(final_boxes)
    metrics = compute_metrics(final_boxes)
    return annotated, final_boxes, c_counts, metrics
def _tile_stride(size, overlap, label):
    """Validate one tile size/overlap pair; return ``(int_size, stride)``."""
    # Sizes may arrive as floats (e.g. from UI sliders); slicing needs ints.
    size = int(size)
    if size <= 0:
        raise ValueError(f"{label}_size must be positive, got {size}")
    # overlap >= 1 would collapse the stride to 1 pixel and generate one
    # patch per pixel, which never finishes in practice.
    if not 0 <= overlap < 1:
        raise ValueError(f"{label}_overlap must be in [0, 1), got {overlap}")
    return size, max(1, int(size * (1 - overlap)))


def _boxes_from_results(results, origins, names):
    """Convert per-patch predictions to ``[x1, y1, x2, y2, conf, cid, cname]``
    boxes in full-image coordinates, shifting by each patch's (x, y) origin."""
    out = []
    for result, (dx, dy) in zip(results, origins):
        boxes = result.boxes
        xyxy = boxes.xyxy.cpu().numpy()
        # Build a new array instead of adding in place: on CPU, .numpy()
        # shares memory with the result's tensor.
        xyxy = xyxy + np.array([dx, dy, dx, dy], dtype=xyxy.dtype)
        scores = boxes.conf.cpu().numpy()
        class_ids = boxes.cls.cpu().numpy()
        for coords, score, raw_cid in zip(xyxy, scores, class_ids):
            cid = int(raw_cid)
            out.append([*coords, float(score), cid, names[cid]])
    return out


def gois_inference(
    image, model_choice,
    coarse_size, fine_size,
    coarse_overlap, fine_overlap,
    nms_thresh,
    skip_width_threshold=200,
    skip_height_threshold=200,
    coarse_conf=0.25,
    fine_conf=0.4,
    batch_size=16,
):
    """Two-stage sliced detection (GOIS-Det): coarse tiling, then fine re-slicing.

    1. The whole image is tiled into ``coarse_size`` squares with fractional
       overlap ``coarse_overlap``; all tiles are predicted in batches.
    2. Every coarse box smaller than both skip thresholds is cropped from the
       image and re-tiled with ``fine_size`` / ``fine_overlap``; those tiles
       are predicted with ``fine_conf``.
    3. Coarse and fine boxes are merged with ``apply_nms(iou_thresh=nms_thresh)``.

    Args:
        image: Array whose first two dimensions are (height, width).
        model_choice: Key forwarded to ``load_model``.
        coarse_size, fine_size: Tile edge lengths in pixels (cast to int).
        coarse_overlap, fine_overlap: Fractional overlap in [0, 1).
        nms_thresh: IoU threshold forwarded to ``apply_nms``.
        skip_width_threshold, skip_height_threshold: Coarse boxes at least
            this wide OR this tall are not re-sliced in the fine stage.
        coarse_conf, fine_conf: ``conf`` passed to ``model.predict`` for each
            stage (defaults match the previously hard-coded 0.25 / 0.4).
        batch_size: ``batch`` passed to ``model.predict`` (default 16).

    Returns:
        Tuple ``(annotated, final_boxes, class_counts, metrics)`` in the same
        format as ``full_inference``.

    Raises:
        ValueError: If a size is not positive or an overlap is outside [0, 1).
    """
    coarse_size, coarse_step = _tile_stride(coarse_size, coarse_overlap, "coarse")
    fine_size, fine_step = _tile_stride(fine_size, fine_overlap, "fine")

    model = load_model(model_choice)
    H, W = image.shape[:2]

    # --- Coarse stage: row-major tiles; tiles at the right/bottom edge are smaller.
    coarse_origins = [
        (x, y)
        for y in range(0, H, coarse_step)
        for x in range(0, W, coarse_step)
    ]
    coarse_patches = [
        image[y:y + coarse_size, x:x + coarse_size] for x, y in coarse_origins
    ]
    coarse_results = model.predict(coarse_patches, conf=coarse_conf, batch=batch_size)
    coarse_boxes = _boxes_from_results(coarse_results, coarse_origins, model.names)

    # --- Fine stage: re-slice only the small coarse detections.
    fine_patches = []
    fine_origins = []
    for x1, y1, x2, y2, *_ in coarse_boxes:
        if (x2 - x1) >= skip_width_threshold or (y2 - y1) >= skip_height_threshold:
            continue
        x1c, y1c = max(0, int(x1)), max(0, int(y1))
        x2c, y2c = min(W, int(x2)), min(H, int(y2))
        if x2c <= x1c or y2c <= y1c:
            continue  # box lies entirely outside the image after clipping
        roi = image[y1c:y2c, x1c:x2c]
        roi_h, roi_w = roi.shape[:2]
        for fy in range(0, roi_h, fine_step):
            for fx in range(0, roi_w, fine_step):
                fine_patches.append(roi[fy:fy + fine_size, fx:fx + fine_size])
                fine_origins.append((x1c + fx, y1c + fy))

    fine_boxes = []
    if fine_patches:
        fine_results = model.predict(fine_patches, conf=fine_conf, batch=batch_size)
        fine_boxes = _boxes_from_results(fine_results, fine_origins, model.names)

    final_boxes = apply_nms(coarse_boxes + fine_boxes, iou_thresh=nms_thresh)
    annotated = draw_detections(image, final_boxes)
    c_counts = count_classes(final_boxes)
    metrics = compute_metrics(final_boxes)
    return annotated, final_boxes, c_counts, metrics
def _to_rgb_uint8(image):
    """Normalize an input image to a 3-channel RGB ``numpy.ndarray``.

    Non-array inputs that expose ``convert`` (PIL images) are converted with
    ``convert("RGB")`` first, so palette ("P"), grayscale and alpha modes all
    become real RGB colours. Calling ``np.array`` on a "P" image would
    otherwise return palette indices. Arrays are expanded from grayscale
    (2-D or a single channel) or have the alpha channel dropped (4 channels).

    NOTE(review): Ultralytics interprets ndarray inputs as BGR, while this
    produces RGB, so the model sees swapped channels. Confirm what
    ``draw_detections`` expects before changing the channel order.
    """
    if not isinstance(image, np.ndarray):
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        image = np.array(image, dtype=np.uint8)
    # Check ndim first: a 2-D grayscale image that is 4 pixels wide must not
    # be mistaken for RGBA.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def run_inference(image, model_choice, c_size, f_size, c_overlap, f_overlap, nms):
    """Compare full-image detection with GOIS sliced detection on one image.

    Args:
        image: ndarray or PIL-style image; normalized to 3-channel RGB.
        model_choice: Detector key forwarded to both inference paths.
        c_size, f_size: Coarse / fine tile sizes for ``gois_inference``.
        c_overlap, f_overlap: Coarse / fine fractional tile overlaps.
        nms: IoU threshold for merging GOIS boxes.

    Returns:
        8-tuple ``(fi_annot, gois_annot, fi_table, gois_table, bar_chart,
        fi_pie, gois_pie, metrics_tbl)``.
    """
    image = _to_rgb_uint8(image)

    # Full-image detection (baseline).
    fi_annot, _, fi_ccount, fi_metrics = full_inference(image, model_choice)

    # GOIS two-stage sliced detection.
    gois_annot, _, gois_ccount, gois_metrics = gois_inference(
        image, model_choice,
        coarse_size=c_size,
        fine_size=f_size,
        coarse_overlap=c_overlap,
        fine_overlap=f_overlap,
        nms_thresh=nms,
    )

    # Charts and HTML tables comparing the two runs.
    bar_chart = generate_bar_chart(fi_metrics, gois_metrics)
    fi_pie = generate_pie_chart(fi_ccount, "FI-Det Class Distribution")
    gois_pie = generate_pie_chart(gois_ccount, "GOIS-Det Class Distribution")
    fi_table = generate_html_table(fi_ccount, "FI-Det Class Distribution")
    gois_table = generate_html_table(gois_ccount, "GOIS-Det Class Distribution")
    metrics_tbl = generate_metrics_table(fi_metrics, gois_metrics, "Comparison Metrics")

    return (
        fi_annot,
        gois_annot,
        fi_table,
        gois_table,
        bar_chart,
        fi_pie,
        gois_pie,
        metrics_tbl,
    )