Spaces:
Paused
Paused
| import os | |
| import torch | |
| import torchvision | |
| from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
| from PIL import Image, ImageDraw | |
| from torchvision.transforms import functional as F | |
| from huggingface_hub import hf_hub_download | |
class FootDetection:
    """Foot detector backed by a torchvision Faster R-CNN checkpoint.

    Constructing an instance loads the model immediately, downloading the
    checkpoint from the ``tonyassi/foot-detection`` Hugging Face repo when it
    is not already present under ``checkpoints/``. ``detect`` runs inference
    on one image and remembers the result so ``draw_boxes`` can render it.
    """

    def __init__(self, device="cpu"):
        """Load the detector onto *device* (anything ``torch.device`` accepts)."""
        self.device = torch.device(device)
        self.checkpoint_dir = "checkpoints"
        self.checkpoint_file = "fasterrcnn_foot.pth"
        self.model = self._load_model()
        # Result of the most recent detect() call; None until detect() runs.
        self.last_detection = None

    def _load_model(self):
        """Build the network, load the checkpoint and return it in eval mode.

        The checkpoint is fetched with ``hf_hub_download`` only when
        ``<checkpoint_dir>/<checkpoint_file>`` does not exist yet.
        """
        local_path = os.path.join(self.checkpoint_dir, self.checkpoint_file)

        # Download if not exists
        if not os.path.exists(local_path):
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            print("Downloading model from Hugging Face...")
            local_path = hf_hub_download(
                repo_id="tonyassi/foot-detection",
                filename=self.checkpoint_file,
                local_dir=self.checkpoint_dir
            )

        # NOTE(review): the "DEFAULT" weights loaded here are overwritten by
        # load_state_dict() below. They are presumably kept so the network is
        # built the same way as when the checkpoint was trained -- confirm
        # before switching to weights=None.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

        # Swap the box predictor for a 2-class head so its shape matches the
        # checkpoint being loaded.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

        # NOTE(security): torch.load unpickles a file obtained from the
        # network. Consider weights_only=True once the minimum supported torch
        # version is confirmed to provide it.
        model.load_state_dict(torch.load(local_path, map_location=self.device))
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _filter_detections(outputs, threshold):
        """Return boxes and scores with ``score >= threshold`` as plain lists.

        *outputs* is the per-image dict produced by the model; only its
        ``"boxes"`` and ``"scores"`` tensors are read.
        """
        scores = outputs["scores"]
        keep = scores >= threshold
        return {
            "boxes": outputs["boxes"][keep].tolist(),
            "scores": scores[keep].tolist(),
        }

    def detect(self, image, threshold=0.1):
        """Run foot detection on a PIL image.

        Args:
            image: The picture to analyse. PIL images in a mode other than
                RGB (for example RGBA or palette PNGs) are converted to RGB
                first; the caller's image object is not modified.
            threshold: Minimum score (inclusive) a detection needs to be kept.

        Returns:
            A dict with ``"boxes"`` (a list of ``[x0, y0, x1, y1]`` lists) and
            ``"scores"`` (a list of floats), in the order the model returned
            them. The same dict is stored in ``self.last_detection``.
        """
        # The detector works on 3-channel RGB input. Without this conversion
        # an RGBA image reaches the model with four channels and a palette
        # image is interpreted as raw palette indices.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        image_tensor = F.to_tensor(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # A batch of one image goes in; take the single result back out.
            outputs = self.model(image_tensor)[0]

        self.last_detection = self._filter_detections(outputs, threshold)
        return self.last_detection

    def draw_boxes(self, image):
        """Draw the most recent detection boxes on a copy of the image.

        Each kept box is outlined in red with its score, formatted to two
        decimals, written at the box's top-left corner.

        Raises:
            ValueError: If ``detect`` has not been called yet.
        """
        if self.last_detection is None:
            raise ValueError("No detection results found. Run .detect(image) first.")

        # Draw on a copy so the caller's image stays untouched.
        image_copy = image.copy()
        draw = ImageDraw.Draw(image_copy)
        for box, score in zip(self.last_detection["boxes"], self.last_detection["scores"]):
            x0, y0, x1, y1 = box
            draw.rectangle([x0, y0, x1, y1], outline="red", width=3)
            draw.text((x0, y0), f"{score:.2f}", fill="red")
        return image_copy