aheman20 commited on
Commit
fcb2da8
·
verified ·
1 Parent(s): fef1de3

courtline detector

Browse files
court_line_detector/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .court_line_detector import CourtLineDetector
court_line_detector/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (202 Bytes). View file
 
court_line_detector/__pycache__/court_line_detector.cpython-312.pyc ADDED
Binary file (3.82 kB). View file
 
court_line_detector/court_line_detector.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import cv2
4
+ from torchvision import models
5
+ import numpy as np
6
+
7
+ class CourtLineDetector:
8
+ def __init__(self, model_path):
9
+ self.model = models.resnet50(pretrained=True)
10
+ self.model.fc = torch.nn.Linear(self.model.fc.in_features, 14*2)
11
+ self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
12
+ self.transform = transforms.Compose([
13
+ transforms.ToPILImage(),
14
+ transforms.Resize((224, 224)),
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17
+ ])
18
+
19
+ def predict(self, image):
20
+
21
+
22
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
23
+ image_tensor = self.transform(image_rgb).unsqueeze(0)
24
+ with torch.no_grad():
25
+ outputs = self.model(image_tensor)
26
+ keypoints = outputs.squeeze().cpu().numpy()
27
+ original_h, original_w = image.shape[:2]
28
+ keypoints[::2] *= original_w / 224.0
29
+ keypoints[1::2] *= original_h / 224.0
30
+
31
+ return keypoints
32
+
33
+ def draw_keypoints(self, image, keypoints):
34
+ if isinstance(image, np.ndarray):
35
+ for i in range(0, len(keypoints), 2):
36
+ x, y = int(keypoints[i]), int(keypoints[i+1])
37
+ cv2.circle(image, (x, y), 5, (0, 255, 0), -1)
38
+ cv2.putText(image, str(i//2), (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
39
+ else:
40
+ print("image is not a numpy array")
41
+ return image
42
+
43
+ def draw_keypoints_on_video(self, video_frames, keypoints):
44
+ output_video_frames = []
45
+ for frame in video_frames:
46
+ frame = self.draw_keypoints(frame, keypoints)
47
+ output_video_frames.append(frame)
48
+ return output_video_frames
yolov8x.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3df4ada6b4dad6d657868f2fdf7faecfb34dcfccf3a25c4b82079064718524c8
3
+ size 136890692