Spaces:
Running
Running
import os | |
import torch | |
import torchvision | |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
from PIL import Image, ImageDraw | |
from torchvision.transforms import functional as F | |
from huggingface_hub import hf_hub_download | |
class FootDetection: | |
def __init__(self, device="cpu"): | |
self.device = torch.device(device) | |
self.checkpoint_dir = "checkpoints" | |
self.checkpoint_file = "fasterrcnn_foot.pth" | |
self.model = self._load_model() | |
self.last_detection = None | |
def _load_model(self): | |
local_path = os.path.join(self.checkpoint_dir, self.checkpoint_file) | |
# Download if not exists | |
if not os.path.exists(local_path): | |
os.makedirs(self.checkpoint_dir, exist_ok=True) | |
print("Downloading model from Hugging Face...") | |
local_path = hf_hub_download( | |
repo_id="tonyassi/foot-detection", | |
filename=self.checkpoint_file, | |
local_dir=self.checkpoint_dir | |
) | |
# Load model | |
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT") | |
in_features = model.roi_heads.box_predictor.cls_score.in_features | |
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2) | |
model.load_state_dict(torch.load(local_path, map_location=self.device)) | |
model.to(self.device) | |
model.eval() | |
return model | |
def detect(self, image, threshold=0.1): | |
"""Run foot detection on a PIL image""" | |
image_tensor = F.to_tensor(image).unsqueeze(0).to(self.device) | |
with torch.no_grad(): | |
outputs = self.model(image_tensor)[0] | |
boxes = [] | |
scores = [] | |
for box, score in zip(outputs["boxes"], outputs["scores"]): | |
if score >= threshold: | |
boxes.append(box.tolist()) | |
scores.append(score.item()) | |
self.last_detection = { | |
"boxes": boxes, | |
"scores": scores | |
} | |
return self.last_detection | |
def draw_boxes(self, image): | |
"""Draw the most recent detection boxes on a copy of the image""" | |
if self.last_detection is None: | |
raise ValueError("No detection results found. Run .detect(image) first.") | |
image_copy = image.copy() | |
draw = ImageDraw.Draw(image_copy) | |
for box, score in zip(self.last_detection["boxes"], self.last_detection["scores"]): | |
x0, y0, x1, y1 = box | |
draw.rectangle([x0, y0, x1, y1], outline="red", width=3) | |
draw.text((x0, y0), f"{score:.2f}", fill="red") | |
return image_copy | |