Spaces:
Running
Running
File size: 2,630 Bytes
d86aec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image, ImageDraw
from torchvision.transforms import functional as F
from huggingface_hub import hf_hub_download
class FootDetection:
    """Foot detector built on a fine-tuned torchvision Faster R-CNN.

    The checkpoint ``fasterrcnn_foot.pth`` is looked up in ``checkpoint_dir``
    and, if missing, downloaded from the Hugging Face Hub repo
    ``tonyassi/foot-detection``. The result of the most recent :meth:`detect`
    call is kept in ``last_detection`` so :meth:`draw_boxes` can render it.
    """

    def __init__(self, device="cpu", checkpoint_dir="checkpoints"):
        """Load the detector onto ``device``.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda"``).
            checkpoint_dir: Directory where the checkpoint is looked up and,
                if absent, downloaded to. Defaults to ``"checkpoints"``.
        """
        self.device = torch.device(device)
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_file = "fasterrcnn_foot.pth"
        self.model = self._load_model()
        # Filled in by detect(); None until the first detection has run.
        self.last_detection = None

    def _load_model(self):
        """Build the Faster R-CNN model and load the foot-detection weights.

        Downloads the checkpoint first if it is not already on disk.

        Returns:
            The model moved to ``self.device`` and switched to eval mode.
        """
        local_path = os.path.join(self.checkpoint_dir, self.checkpoint_file)
        if not os.path.exists(local_path):
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            print("Downloading model from Hugging Face...")
            local_path = hf_hub_download(
                repo_id="tonyassi/foot-detection",
                filename=self.checkpoint_file,
                local_dir=self.checkpoint_dir,
            )

        # Rebuild the architecture the checkpoint was presumably fine-tuned
        # from (COCO-pretrained weights; they are overwritten by
        # load_state_dict below), then replace the box predictor with a
        # 2-class head -- presumably background + foot.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

        # SECURITY: the checkpoint may come from the network. weights_only=True
        # restricts unpickling to tensors and plain containers, so a tampered
        # file cannot execute arbitrary code (requires torch >= 1.13).
        state_dict = torch.load(local_path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def detect(self, image, threshold=0.1):
        """Run foot detection on a single image.

        Args:
            image: A PIL image (any mode; it is converted to RGB) or an
                already-prepared float ``(C, H, W)`` tensor with values in
                ``[0, 1]``.
            threshold: Minimum score (inclusive) a detection needs to be kept.

        Returns:
            ``{"boxes": [[x0, y0, x1, y1], ...], "scores": [float, ...]}``,
            in the order the model produced them. The same dict is stored
            in ``self.last_detection``.
        """
        if isinstance(image, torch.Tensor):
            image_tensor = image
        else:
            # The detector normalizes 3 channels: an RGBA image would yield a
            # 4-channel tensor and a palette ("P") image raw palette indices.
            if image.mode != "RGB":
                image = image.convert("RGB")
            image_tensor = F.to_tensor(image)
        batch = image_tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch)[0]

        # Boolean mask keeps the model's ordering and the inclusive >= test.
        keep = outputs["scores"] >= threshold
        self.last_detection = {
            "boxes": outputs["boxes"][keep].tolist(),
            "scores": outputs["scores"][keep].tolist(),
        }
        return self.last_detection

    def draw_boxes(self, image, *, color="red", width=3):
        """Draw the most recent detection boxes on a copy of ``image``.

        Boxes are drawn at the coordinates returned by :meth:`detect` without
        any rescaling, so ``image`` should be the image that was detected on.

        Args:
            image: PIL image to annotate; it is not modified.
            color: Colour for box outlines and score labels.
            width: Outline width of each rectangle, in pixels.

        Returns:
            An annotated copy of ``image``.

        Raises:
            ValueError: If :meth:`detect` has not been run yet.
        """
        if self.last_detection is None:
            raise ValueError("No detection results found. Run .detect(image) first.")
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        detections = zip(self.last_detection["boxes"], self.last_detection["scores"])
        for (x0, y0, x1, y1), score in detections:
            draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
            # Score label anchored at the box's top-left corner.
            draw.text((x0, y0), f"{score:.2f}", fill=color)
        return annotated
|