# SurgiTrackDemo / tracker.py
# akhellad's picture
# Initial commit
# 26a3529
"""
SurgiTrack - Tracker Module (Simplified for HF Space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
from scipy.optimize import linear_sum_assignment
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import cv2
# Tool class names. The list index is the integer class id used throughout
# this module (for example CLASS_NAMES[class_id] when per-class slots are
# built); index 0 is 'grasper'.
CLASS_NAMES = ['grasper', 'bipolar', 'hook', 'scissors', 'clipper', 'irrigator', 'specimenbag']
# Operator labels. NOTE(review): presumably indexed by the operator id that
# the direction model predicts (four entries, with index 3 being 'NULL'), and
# the abbreviations presumably name surgeon hands -- confirm against the
# dataset documentation; this list is not referenced elsewhere in this file.
OPERATORS = ['MSLH', 'MSRH', 'ASRH', 'NULL']
class CoordinateAttention(nn.Module):
    """Attention gate that re-weights a feature map separately along H and W.

    The input is average-pooled along each spatial axis, the two pooled
    strips are concatenated and passed through a shared 1x1 conv + batch-norm
    + ReLU bottleneck, then split again and turned into two sigmoid gates
    (one per axis) that multiply the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Channel reduction factor of the bottleneck; the bottleneck
            width is ``max(8, in_channels // reduction)``.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        # Directional pooling: pool_h collapses the width axis (output H x 1),
        # pool_w collapses the height axis (output 1 x W).
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        # Shared bottleneck; its width never drops below 8 channels.
        mid_channels = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.act = nn.ReLU(inplace=True)

        # One projection per axis, mapping back to the input channel count.
        self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1)
        self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` scaled by the height gate and the width gate.

        Args:
            x: 4-D tensor; unpacked as (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, _, height, width = x.shape

        rows = self.pool_h(x)                       # (B, C, H, 1)
        cols = self.pool_w(x).transpose(2, 3)       # (B, C, W, 1)

        # Process both strips jointly so they share the bottleneck weights.
        joint = torch.cat((rows, cols), dim=2)      # (B, C, H + W, 1)
        joint = self.act(self.bn1(self.conv1(joint)))

        rows, cols = joint.split([height, width], dim=2)
        gate_h = torch.sigmoid(self.conv_h(rows))                    # (B, C, H, 1)
        gate_w = torch.sigmoid(self.conv_w(cols.transpose(2, 3)))    # (B, C, 1, W)

        # Broadcasting applies the row gate and the column gate to every cell.
        return x * gate_h * gate_w
class DirectionEstimator(nn.Module):
    """EfficientNet-B0 based classifier that also exposes an embedding vector.

    Pipeline: backbone feature extractor -> coordinate attention -> backbone
    average pool -> embedding head (L2-normalised) -> direction head.

    Args:
        num_classes: Number of direction logits produced.
        embedding_dim: Size of the L2-normalised embedding.
        pretrained: When true, the backbone is created with the
            ``IMAGENET1K_V1`` weights; otherwise without weights.
    """

    def __init__(self, num_classes=4, embedding_dim=128, pretrained=True):
        super().__init__()

        backbone_weights = (
            models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        )
        self.backbone = models.efficientnet_b0(weights=backbone_weights)

        # Read the feature width from the stock classifier before removing it.
        feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.coord_attention = CoordinateAttention(feature_dim)

        self.embedding_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
        )
        self.direction_head = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )
        self.embedding_dim = embedding_dim

    def forward(self, x, return_embedding=False):
        """Compute direction logits (and optionally the embedding) for ``x``.

        Args:
            x: Image batch fed to the backbone's feature extractor.
            return_embedding: When true, return ``(logits, embedding)``
                instead of only the logits.

        Returns:
            The direction logits, or a ``(logits, embedding)`` tuple where the
            embedding is L2-normalised along dim 1.
        """
        attended = self.coord_attention(self.backbone.features(x))
        pooled = self.backbone.avgpool(attended).flatten(1)

        # The direction head consumes the normalised embedding, not the raw one.
        embedding = F.normalize(self.embedding_head(pooled), p=2, dim=1)
        direction = self.direction_head(embedding)

        if not return_embedding:
            return direction
        return direction, embedding
@dataclass
class Detection:
    """One detected tool instance in a single frame (tracker input record)."""

    # Box corners; the tracker unpacks this as (x1, y1, x2, y2).
    bbox: np.ndarray
    # Integer class id; the tracker compares it with its grasper class id and
    # its set of single-instance class ids.
    class_id: int
    # Human-readable name of the class.
    class_name: str
    # Detector score; the tracker processes detections in descending order of it.
    confidence: float
    # Frame index supplied by the producer of the detection.
    frame_id: int
@dataclass
class OperatorSlot:
    """Long-lived identity slot holding the latest state of one tracked tool.

    A slot is created once, then repeatedly refreshed through :meth:`update`
    whenever a detection is assigned to it. It keeps the most recent box,
    class, confidence and embedding plus a bounded history of boxes and class
    ids.
    """

    operator_id: int
    operator_name: str
    track_id: int
    active: bool = False
    class_id: int = -1
    class_name: str = ""
    # None until the first detection is assigned.
    bbox: Optional[np.ndarray] = None
    confidence: float = 0.0
    # None until the first detection is assigned.
    embedding: Optional[np.ndarray] = None
    # -1 means "never seen".
    last_seen_frame: int = -1
    total_detections: int = 0
    bbox_history: List[np.ndarray] = field(default_factory=list)
    class_history: List[int] = field(default_factory=list)

    # Maximum number of entries kept in each history list. Deliberately left
    # un-annotated so it stays a plain class constant, not a dataclass field.
    MAX_HISTORY = 100

    def update(self, detection: "Detection", embedding: np.ndarray, frame_id: int):
        """Refresh the slot from ``detection`` and mark it active.

        Args:
            detection: The detection assigned to this slot.
            embedding: Vector stored verbatim in ``self.embedding``.
            frame_id: Frame index recorded as ``last_seen_frame``.
        """
        self.active = True
        # Store private copies: the slot must not change if the caller later
        # mutates the detection's array, and editing ``self.bbox`` in place
        # must not rewrite the history entry.
        self.bbox = detection.bbox.copy()
        self.class_id = detection.class_id
        self.class_name = detection.class_name
        self.confidence = detection.confidence
        self.embedding = embedding
        self.last_seen_frame = frame_id
        self.total_detections += 1

        self.bbox_history.append(detection.bbox.copy())
        self.class_history.append(detection.class_id)
        # Keep only the newest MAX_HISTORY entries of both lists.
        if len(self.bbox_history) > self.MAX_HISTORY:
            del self.bbox_history[:-self.MAX_HISTORY]
            del self.class_history[:-self.MAX_HISTORY]

    def mark_inactive(self):
        """Flag the slot as inactive; all other state is kept."""
        self.active = False

    def frames_since_seen(self, current_frame: int) -> float:
        """Return how many frames have passed since the slot was last updated.

        Returns:
            ``current_frame - last_seen_frame``, or ``float('inf')`` when the
            slot has never been updated (which is why the return type is
            ``float`` rather than ``int``).
        """
        if self.last_seen_frame < 0:
            return float('inf')
        return current_frame - self.last_seen_frame
class OperatorBasedTracker:
    """Slot-based tool tracker with a fixed pool of identity slots.

    The tracker owns ``MAX_GRASPERS`` grasper slots plus exactly one slot for
    every class id in ``SINGLE_INSTANCE_CLASSES``. On every :meth:`update`
    call each detection is matched to a slot (highest confidence first); a
    slot accepts at most one detection per frame.

    Args:
        direction_model: Optional model called as
            ``model(crop, return_embedding=True)`` and expected to return
            ``(logits, embedding)``. When ``None`` a uniform fallback
            prediction is used.
        max_inactive_frames: Stored on the instance. NOTE(review): no code in
            this class reads it; deactivation uses
            ``DEACTIVATE_AFTER_FRAMES`` instead -- confirm which is intended.
        iou_threshold: Minimum IoU for a grasper detection to match a slot by
            overlap.
        direction_confidence_threshold: Minimum probability for a predicted
            operator to be treated as confident.
        device: Device the direction model and its inputs are moved to. If a
            CUDA device is requested but CUDA is unavailable, the tracker
            warns and falls back to ``"cpu"``.
    """

    MAX_GRASPERS = 3
    GRASPER_CLASS_ID = 0
    SINGLE_INSTANCE_CLASSES = {1, 2, 3, 4, 5, 6}

    # Operator id used when no real operator is predicted (also the fallback
    # prediction and the operator id given to single-instance slots).
    NULL_OPERATOR_ID = 3
    # Frame-gap thresholds (values unchanged from the former inline literals).
    DEACTIVATE_AFTER_FRAMES = 150   # active slot unseen for longer -> inactive
    REID_WINDOW_FRAMES = 75         # max gap for re-using an existing identity
    SHORT_GAP_FRAMES = 30           # max gap for proximity-only matching
    # Centre-distance gates, in the same units as the box coordinates.
    NEAR_MATCH_DISTANCE = 150
    REVIVE_DISTANCE = 100
    # Scoring terms used when ranking candidate grasper slots.
    PROXIMITY_BASE_SCORE = 0.1
    OPERATOR_MATCH_BONUS = 0.2

    def __init__(
        self,
        direction_model: "Optional[DirectionEstimator]" = None,
        max_inactive_frames: int = 300,
        iou_threshold: float = 0.3,
        direction_confidence_threshold: float = 0.5,
        device: str = "cuda"
    ):
        self.direction_model = direction_model
        self.max_inactive_frames = max_inactive_frames
        self.iou_threshold = iou_threshold
        self.direction_confidence_threshold = direction_confidence_threshold
        self.device = device
        self.grasper_slots: "List[OperatorSlot]" = []
        self.class_slots: "Dict[int, OperatorSlot]" = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()

        if self.direction_model is not None:
            # The default device is "cuda"; fall back instead of crashing on
            # CPU-only hosts.
            if torch.device(device).type == "cuda" and not torch.cuda.is_available():
                import warnings

                warnings.warn(
                    "CUDA was requested but is not available; "
                    "running the direction model on CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.device = "cpu"
            self.direction_model.to(self.device)
            self.direction_model.eval()

    def _initialize_slots(self):
        """Create the grasper slots and one slot per single-instance class."""
        for index in range(self.MAX_GRASPERS):
            slot = OperatorSlot(
                operator_id=-1,
                operator_name=f"grasper_{index + 1}",
                track_id=self.next_track_id,
            )
            slot.class_id = self.GRASPER_CLASS_ID
            slot.class_name = 'grasper'
            self.next_track_id += 1
            self.grasper_slots.append(slot)

        # sorted() makes the initial track-id assignment independent of set
        # iteration order.
        for class_id in sorted(self.SINGLE_INSTANCE_CLASSES):
            slot = OperatorSlot(
                operator_id=self.NULL_OPERATOR_ID,
                operator_name=f"CLASS_{CLASS_NAMES[class_id]}",
                track_id=self.next_track_id,
            )
            slot.class_id = class_id
            slot.class_name = CLASS_NAMES[class_id]
            self.next_track_id += 1
            self.class_slots[class_id] = slot

    def _get_direction_prediction(self, frame: np.ndarray, bbox: np.ndarray):
        """Predict the operator for the tool inside ``bbox``.

        Args:
            frame: Image array indexed as ``frame[y, x]`` with channels last.
            bbox: Box as (x1, y1, x2, y2).

        Returns:
            ``(operator_id, probabilities)``. Without a model, or when the
            padded box does not intersect the frame, this is
            ``(3, [0.25, 0.25, 0.25, 0.25])``.
        """
        if self.direction_model is None:
            return self.NULL_OPERATOR_ID, np.array([0.25, 0.25, 0.25, 0.25])

        x1, y1, x2, y2 = bbox.astype(int)
        frame_h, frame_w = frame.shape[:2]

        # Enlarge the box (30% horizontally, 50% vertically) to add context,
        # then clip it to the frame.
        pad_x = int((x2 - x1) * 0.3)
        pad_y = int((y2 - y1) * 0.5)
        x1 = max(0, x1 - pad_x)
        y1 = max(0, y1 - pad_y)
        x2 = min(frame_w, x2 + pad_x)
        y2 = min(frame_h, y2 + pad_y)

        # A box lying outside the frame would otherwise yield a negative slice
        # bound, which NumPy interprets relative to the end of the axis and so
        # silently crops the wrong region.
        if x2 <= x1 or y2 <= y1:
            return self.NULL_OPERATOR_ID, np.array([0.25, 0.25, 0.25, 0.25])

        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            return self.NULL_OPERATOR_ID, np.array([0.25, 0.25, 0.25, 0.25])

        crop = cv2.resize(crop, (224, 224))
        crop = crop.astype(np.float32) / 255.0
        # Per-channel mean/std normalisation.
        # NOTE(review): these look like the ImageNet RGB statistics, which
        # would require ``frame`` to be RGB (OpenCV decodes to BGR) -- confirm
        # with the caller.
        crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        batch = (
            torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
        )

        with torch.no_grad():
            # The embedding returned by the model is not used by the tracker.
            logits, _embedding = self.direction_model(batch, return_embedding=True)
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        return int(np.argmax(probs)), probs

    def _compute_iou(self, bbox1: np.ndarray, bbox2: np.ndarray) -> float:
        """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

        Returns 0.0 when either box is ``None``. A small epsilon in the
        denominator avoids division by zero for degenerate boxes.
        """
        if bbox1 is None or bbox2 is None:
            return 0.0
        left = max(bbox1[0], bbox2[0])
        top = max(bbox1[1], bbox2[1])
        right = min(bbox1[2], bbox2[2])
        bottom = min(bbox1[3], bbox2[3])
        inter = max(0, right - left) * max(0, bottom - top)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
        union = area1 + area2 - inter
        return inter / (union + 1e-6)

    @staticmethod
    def _center_distance(bbox1: np.ndarray, bbox2: np.ndarray) -> float:
        """Euclidean distance between the centres of two boxes."""
        center1 = (bbox1[:2] + bbox1[2:]) / 2
        center2 = (bbox2[:2] + bbox2[2:]) / 2
        return np.linalg.norm(center1 - center2)

    def _claimed_this_frame(self, slot) -> bool:
        """True if ``slot`` already received a detection in the current frame."""
        return slot.last_seen_frame == self.frame_count

    def _issue_track_id(self, slot) -> None:
        """Give ``slot`` a fresh track id (it starts a new identity)."""
        slot.track_id = self.next_track_id
        self.next_track_id += 1

    def _find_best_slot(
        self,
        detection: "Detection",
        predicted_op: int,
        direction_probs: np.ndarray,
    ) -> "Optional[OperatorSlot]":
        """Choose the slot that should receive ``detection``.

        Slots that already received a detection in the current frame are never
        returned: a slot represents one tool instance, so a second detection
        must neither be matched to it nor evict it.

        Args:
            detection: The detection to place.
            predicted_op: Operator id predicted for the detection.
            direction_probs: Probabilities indexed by operator id.

        Returns:
            The chosen slot, or ``None`` when the detection cannot be placed
            (unknown class, or every eligible slot is taken this frame).
        """
        class_id = detection.class_id

        if class_id in self.SINGLE_INSTANCE_CLASSES:
            return self._match_single_instance_slot(class_id)
        if class_id == self.GRASPER_CLASS_ID:
            return self._match_grasper_slot(detection, predicted_op, direction_probs)
        return None

    def _match_single_instance_slot(self, class_id: int) -> "Optional[OperatorSlot]":
        """Return the dedicated slot of a single-instance class, if still free."""
        slot = self.class_slots.get(class_id)
        if slot is None or self._claimed_this_frame(slot):
            return None
        # An inactive slot unseen for a long time (or never seen) restarts as
        # a new track identity.
        gap = slot.frames_since_seen(self.frame_count)
        if not slot.active and gap >= self.REID_WINDOW_FRAMES:
            self._issue_track_id(slot)
        return slot

    def _match_grasper_slot(
        self,
        detection: "Detection",
        predicted_op: int,
        direction_probs: np.ndarray,
    ) -> "Optional[OperatorSlot]":
        """Pick a grasper slot via overlap, proximity, revival, free slot, eviction."""
        direction_confident = (
            predicted_op < self.NULL_OPERATOR_ID
            and direction_probs[predicted_op] > self.direction_confidence_threshold
        )
        # Only slots that have not been assigned in this frame are eligible.
        candidates = [
            slot for slot in self.grasper_slots if not self._claimed_this_frame(slot)
        ]

        # 1) Best-scoring recently seen slot: IoU match, or a nearby slot seen
        #    very recently; agreeing operator ids earn a bonus.
        best_slot = None
        best_score = -1
        for slot in candidates:
            if slot.bbox is None:
                continue
            gap = slot.frames_since_seen(self.frame_count)
            if gap >= self.REID_WINDOW_FRAMES:
                continue
            iou = self._compute_iou(detection.bbox, slot.bbox)
            dist = self._center_distance(detection.bbox, slot.bbox)
            bonus = self.OPERATOR_MATCH_BONUS if slot.operator_id == predicted_op else 0
            if iou > self.iou_threshold:
                score = iou + bonus
            elif dist < self.NEAR_MATCH_DISTANCE and gap < self.SHORT_GAP_FRAMES:
                score = self.PROXIMITY_BASE_SCORE + bonus
            else:
                continue
            if score > best_score:
                best_score = score
                best_slot = slot
        if best_slot is not None:
            return best_slot

        # 2) Revive an inactive slot without changing its identity.
        # NOTE(review): update() only deactivates slots after
        # DEACTIVATE_AFTER_FRAMES, which exceeds both windows checked here, so
        # these branches are reachable only if mark_inactive() is called from
        # outside this class -- confirm intent.
        if direction_confident:
            for slot in candidates:
                if slot.active or slot.bbox is None:
                    continue
                if (
                    slot.operator_id == predicted_op
                    and slot.frames_since_seen(self.frame_count) < self.REID_WINDOW_FRAMES
                ):
                    return slot
        else:
            for slot in candidates:
                if slot.active or slot.bbox is None:
                    continue
                if slot.frames_since_seen(self.frame_count) >= self.SHORT_GAP_FRAMES:
                    continue
                if self._center_distance(detection.bbox, slot.bbox) < self.REVIVE_DISTANCE:
                    return slot

        # 3) Start a new identity in the first inactive slot.
        for slot in candidates:
            if not slot.active:
                self._issue_track_id(slot)
                return slot

        # 4) All eligible slots are active: evict the one overlapping least.
        worst_slot = None
        worst_iou = 1.0
        for slot in candidates:
            iou = self._compute_iou(detection.bbox, slot.bbox)
            if iou < worst_iou:
                worst_iou = iou
                worst_slot = slot
        if worst_slot is not None:
            self._issue_track_id(worst_slot)
            return worst_slot
        return None

    def update(self, frame: np.ndarray, detections: "List[Detection]") -> "List[OperatorSlot]":
        """Advance one frame and assign ``detections`` to slots.

        Args:
            frame: Current image; used to crop each detection for the
                direction model.
            detections: Detections of the current frame.

        Returns:
            The slots that received a detection in this frame.
        """
        self.frame_count += 1

        # Retire slots that have gone unseen for too long.
        for slot in self.grasper_slots + list(self.class_slots.values()):
            if (
                slot.active
                and slot.frames_since_seen(self.frame_count) > self.DEACTIVATE_AFTER_FRAMES
            ):
                slot.mark_inactive()

        if not detections:
            return self._get_active_slots()

        scored = [
            (det, *self._get_direction_prediction(frame, det.bbox))
            for det in detections
        ]
        # Most confident detections choose their slot first (stable order).
        scored.sort(key=lambda item: item[0].confidence, reverse=True)

        for det, pred_op, probs in scored:
            slot = self._find_best_slot(det, pred_op, probs)
            # Second guard against double assignment within one frame.
            if slot is None or self._claimed_this_frame(slot):
                continue
            # Note: the direction probabilities are what is stored in
            # ``slot.embedding``.
            slot.update(det, probs, self.frame_count)
            if det.class_id == self.GRASPER_CLASS_ID:
                slot.operator_id = pred_op

        return self._get_active_slots()

    def _get_active_slots(self) -> "List[OperatorSlot]":
        """Return active slots updated in the current frame (graspers first)."""
        ordered = self.grasper_slots + list(self.class_slots.values())
        return [
            slot
            for slot in ordered
            if slot.active and slot.last_seen_frame == self.frame_count
        ]

    def reset(self):
        """Discard all slots and counters and rebuild the initial slot pool."""
        self.grasper_slots = []
        self.class_slots = {}
        self.next_track_id = 1
        self.frame_count = 0
        self._initialize_slots()