Spaces:
Sleeping
Sleeping
courtline detector
Browse files
court_line_detector/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .court_line_detector import CourtLineDetector
|
court_line_detector/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (202 Bytes). View file
|
|
court_line_detector/__pycache__/court_line_detector.cpython-312.pyc
ADDED
Binary file (3.82 kB). View file
|
|
court_line_detector/court_line_detector.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
import cv2
|
4 |
+
from torchvision import models
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class CourtLineDetector:
|
8 |
+
def __init__(self, model_path):
|
9 |
+
self.model = models.resnet50(pretrained=True)
|
10 |
+
self.model.fc = torch.nn.Linear(self.model.fc.in_features, 14*2)
|
11 |
+
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
12 |
+
self.transform = transforms.Compose([
|
13 |
+
transforms.ToPILImage(),
|
14 |
+
transforms.Resize((224, 224)),
|
15 |
+
transforms.ToTensor(),
|
16 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
17 |
+
])
|
18 |
+
|
19 |
+
def predict(self, image):
|
20 |
+
|
21 |
+
|
22 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
23 |
+
image_tensor = self.transform(image_rgb).unsqueeze(0)
|
24 |
+
with torch.no_grad():
|
25 |
+
outputs = self.model(image_tensor)
|
26 |
+
keypoints = outputs.squeeze().cpu().numpy()
|
27 |
+
original_h, original_w = image.shape[:2]
|
28 |
+
keypoints[::2] *= original_w / 224.0
|
29 |
+
keypoints[1::2] *= original_h / 224.0
|
30 |
+
|
31 |
+
return keypoints
|
32 |
+
|
33 |
+
def draw_keypoints(self, image, keypoints):
|
34 |
+
if isinstance(image, np.ndarray):
|
35 |
+
for i in range(0, len(keypoints), 2):
|
36 |
+
x, y = int(keypoints[i]), int(keypoints[i+1])
|
37 |
+
cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
|
38 |
+
cv2.putText(image, str(i//2), (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
|
39 |
+
else:
|
40 |
+
print("image is not a numpy array")
|
41 |
+
return image
|
42 |
+
|
43 |
+
def draw_keypoints_on_video(self, video_frames, keypoints):
|
44 |
+
output_video_frames = []
|
45 |
+
for frame in video_frames:
|
46 |
+
frame = self.draw_keypoints(frame, keypoints)
|
47 |
+
output_video_frames.append(frame)
|
48 |
+
return output_video_frames
|
yolov8x.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3df4ada6b4dad6d657868f2fdf7faecfb34dcfccf3a25c4b82079064718524c8
|
3 |
+
size 136890692
|