Shoraky commited on
Commit
490eb58
·
verified ·
1 Parent(s): 6dc2def

Initial public API deployment

Browse files
Files changed (6) hide show
  1. .dockerignore +10 -0
  2. .gitignore +12 -0
  3. Dockerfile +29 -0
  4. README.md +16 -5
  5. api.py +962 -0
  6. requirements.txt +12 -0
.dockerignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .git/
5
+ .sporalize_runtime/
6
+ Storage/
7
+ Weights/
8
+ ViTPose/
9
+ pipeline.py
10
+ DEPLOYMENT.md
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .sporalize_runtime/
5
+ .pytest_cache/
6
+ .mypy_cache/
7
+ .ruff_cache/
8
+ .venv/
9
+ venv/
10
+ Storage/
11
+ Weights/
12
+ ViTPose/
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1 \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1 \
6
+ PORT=7860 \
7
+ HF_HOME=/home/appuser/.cache/huggingface
8
+
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ ffmpeg \
11
+ libglib2.0-0 \
12
+ libgl1 \
13
+ libsm6 \
14
+ libxext6 \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ RUN useradd --create-home --uid 1000 appuser
18
+ WORKDIR /app
19
+
20
+ COPY --chown=appuser:appuser requirements.txt /app/requirements.txt
21
+ RUN pip install --upgrade pip && pip install -r /app/requirements.txt
22
+
23
+ COPY --chown=appuser:appuser api.py /app/api.py
24
+ COPY --chown=appuser:appuser README.md /app/README.md
25
+
26
+ USER appuser
27
+ EXPOSE 7860
28
+
29
+ CMD ["sh", "-c", "uvicorn api:app --host 0.0.0.0 --port ${PORT:-7860}"]
README.md CHANGED
@@ -1,10 +1,21 @@
1
  ---
2
- title: Sporalize Api
3
- emoji: 📉
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Sporalize API
3
+ emoji:
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # Sporalize API
12
+
13
+ Public Docker Space for the Sporalize backend API.
14
+
15
+ Runtime behavior:
16
+
17
+ - Loads `pipeline.py`, `ViTPose`, and optional seeded `Storage` from a private Hugging Face repo at startup.
18
+ - Downloads model weights at startup if they are not already cached.
19
+ - Serves the FastAPI API on port `7860`.
20
+
21
+ Set the runtime secrets and variables in the Space settings as documented in `DEPLOYMENT.md`.
api.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import uuid
5
+ import shutil
6
+ import traceback
7
+ import re
8
+ import sys
9
+ import importlib.util
10
+ import cv2
11
+ import numpy as np
12
+ from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.concurrency import run_in_threadpool
15
+ from fastapi.staticfiles import StaticFiles
16
+ from typing import List
17
+ from huggingface_hub import hf_hub_download, snapshot_download
18
+
19
app = FastAPI(title="Sporalize Labs 3D Analysis Engine")

# Directory containing this file; anchors all relative asset/storage paths.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
22
+
23
+
24
def default_runtime_root():
    """Choose the runtime dir: persistent /data (Spaces) or next to this file."""
    persistent_volume = "/data"
    if os.path.isdir(persistent_volume):
        return os.path.join(persistent_volume, "sporalize_runtime")
    return os.path.join(CURRENT_DIR, ".sporalize_runtime")
28
+
29
+
30
# Root for all mutable runtime state; override with SPORALIZE_RUNTIME_DIR.
RUNTIME_ROOT = os.environ.get("SPORALIZE_RUNTIME_DIR", default_runtime_root())
# Where downloaded private assets (pipeline.py, ViTPose, seed Storage) land.
ASSETS_RUNTIME_ROOT = os.environ.get("SPORALIZE_ASSETS_DIR", os.path.join(RUNTIME_ROOT, "assets"))
# Local cache for downloaded model weights.
WEIGHTS_RUNTIME_ROOT = os.environ.get("SPORALIZE_WEIGHTS_DIR", os.path.join(RUNTIME_ROOT, "weights"))
DEFAULT_LOCAL_STORAGE_ROOT = os.path.join(CURRENT_DIR, "Storage")
# Session storage prefers the persistent /data volume when present
# (Hugging Face Spaces); otherwise a Storage dir next to this file.
STORAGE_ROOT = os.environ.get(
    "SPORALIZE_STORAGE_DIR",
    os.path.join("/data", "sporalize_storage") if os.path.isdir("/data") else DEFAULT_LOCAL_STORAGE_ROOT,
)


# Per-weight resolution specs consumed by ensure_weight_file():
#   filename       cache file name under WEIGHTS_RUNTIME_ROOT
#   repo_id/type   Hugging Face repo to download from
#   repo_file      path of the file inside that repo
#   override_env   env var naming an explicit local file to use instead
#   local_fallback bundled path checked before any download
DEFAULT_WEIGHT_SPECS = {
    "POSE_PATH": {
        "filename": "vitpose-s-coco_25.onnx",
        "repo_id": os.environ.get("SPORALIZE_POSE_MODEL_REPO_ID", "JunkyByte/easy_ViTPose"),
        "repo_type": os.environ.get("SPORALIZE_POSE_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_POSE_MODEL_FILE", "onnx/coco_25/vitpose-25-s.onnx"),
        "override_env": "SPORALIZE_POSE_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "vitpose-s-coco_25.onnx"),
    },
    "YOLO_PATH": {
        "filename": "yolov8m.pt",
        "repo_id": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_ID", "Ultralytics/YOLOv8"),
        "repo_type": os.environ.get("SPORALIZE_YOLO_MODEL_REPO_TYPE", "model"),
        "repo_file": os.environ.get("SPORALIZE_YOLO_MODEL_FILE", "yolov8m.pt"),
        "override_env": "SPORALIZE_YOLO_MODEL_PATH",
        "local_fallback": os.path.join(CURRENT_DIR, "Weights", "yolov8m.pt"),
    },
}

# Lazily populated by ensure_runtime_ready(); once "ready" is True it holds
# the loaded pipeline callable and resolved weight paths.
runtime_state = {
    "ready": False,
    "pipeline_root": None,
    "run_pipeline": None,
    "weights": {},
}
65
+
66
+
67
def get_hf_token():
    """Return the Hugging Face auth token from the environment, if configured."""
    primary = os.environ.get("HF_TOKEN")
    # HF_TOKEN wins when set and non-empty; otherwise fall back to the
    # legacy HUGGING_FACE_HUB_TOKEN variable.
    return primary if primary else os.environ.get("HUGGING_FACE_HUB_TOKEN")
69
+
70
+
71
def path_has_session_data(directory: str):
    """Return True when *directory* exists and any subdirectory holds a session.json."""
    if not os.path.isdir(directory):
        return False
    return any("session.json" in files for _root, _dirs, files in os.walk(directory))
78
+
79
+
80
def seed_storage_if_needed(seed_dir: str, target_dir: str):
    """Merge the seed Storage tree into *target_dir* unless it already has sessions.

    A missing seed directory makes this a no-op; an existing session.json
    anywhere under the target suppresses the copy so live data is never
    overwritten.
    """
    if os.path.isdir(seed_dir):
        os.makedirs(target_dir, exist_ok=True)
        if not path_has_session_data(target_dir):
            shutil.copytree(seed_dir, target_dir, dirs_exist_ok=True)
87
+
88
+
89
def resolve_pipeline_root():
    """Locate the directory holding pipeline.py (+ ViTPose), fetching it if needed.

    A locally bundled copy next to this file is preferred; otherwise the
    private assets repo named by SPORALIZE_ASSETS_REPO_ID is snapshotted and
    any seed Storage it carries is copied into STORAGE_ROOT.

    Returns:
        Absolute path of the directory containing pipeline.py.

    Raises:
        RuntimeError: when no local bundle exists and SPORALIZE_ASSETS_REPO_ID
            is unset.
    """
    local_pipeline = os.path.join(CURRENT_DIR, "pipeline.py")
    local_vitpose = os.path.join(CURRENT_DIR, "ViTPose")
    if os.path.isfile(local_pipeline) and os.path.isdir(local_vitpose):
        return CURRENT_DIR

    repo_id = os.environ.get("SPORALIZE_ASSETS_REPO_ID")
    if not repo_id:
        raise RuntimeError(
            "SPORALIZE_ASSETS_REPO_ID is required when Backend/pipeline.py is not bundled locally."
        )

    # One sub-directory per assets repo so switching repos cannot mix files.
    assets_dir = os.path.join(ASSETS_RUNTIME_ROOT, safe_name(repo_id))
    snapshot_download(
        repo_id=repo_id,
        repo_type=os.environ.get("SPORALIZE_ASSETS_REPO_TYPE", "dataset"),
        revision=os.environ.get("SPORALIZE_ASSETS_REVISION"),
        token=get_hf_token(),
        local_dir=assets_dir,
        allow_patterns=["pipeline.py", "ViTPose/**", "Storage/**"],
    )
    # Seeded Storage in the assets repo pre-populates sessions on first boot.
    seed_storage_if_needed(os.path.join(assets_dir, "Storage"), STORAGE_ROOT)
    return assets_dir
112
+
113
+
114
def load_pipeline_callable(pipeline_root: str):
    """Import pipeline.py from *pipeline_root* and return its run_pipeline callable.

    Raises:
        RuntimeError: if pipeline.py is missing, an import spec cannot be
            built, or the module does not define run_pipeline.
    """
    pipeline_path = os.path.join(pipeline_root, "pipeline.py")
    if not os.path.isfile(pipeline_path):
        raise RuntimeError(f"pipeline.py was not found at {pipeline_path}")

    # pipeline.py may import siblings (e.g. ViTPose) relative to its own
    # directory, so that directory must be importable.
    if pipeline_root not in sys.path:
        sys.path.insert(0, pipeline_root)

    # Fixed module name so a forced re-initialization replaces the old module.
    module_name = "sporalize_runtime_pipeline"
    if module_name in sys.modules:
        del sys.modules[module_name]

    spec = importlib.util.spec_from_file_location(module_name, pipeline_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Unable to create import spec for {pipeline_path}")
    module = importlib.util.module_from_spec(spec)
    # Register in sys.modules BEFORE exec so imports during exec can resolve it.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    run_pipeline = getattr(module, "run_pipeline", None)
    if run_pipeline is None:
        raise RuntimeError("run_pipeline was not found in the resolved pipeline module")
    return run_pipeline
137
+
138
+
139
def ensure_weight_file(spec: dict):
    """Resolve one model weight file to a local path, downloading as a last resort.

    Resolution order: explicit override env var -> bundled local fallback ->
    previously cached file under WEIGHTS_RUNTIME_ROOT -> hf_hub_download.

    Args:
        spec: one entry of DEFAULT_WEIGHT_SPECS (filename/repo/env keys).

    Returns:
        Path to an existing weight file.
    """
    override_path = os.environ.get(spec["override_env"])
    if override_path and os.path.isfile(override_path):
        return override_path

    local_fallback = spec.get("local_fallback")
    if local_fallback and os.path.isfile(local_fallback):
        return local_fallback

    os.makedirs(WEIGHTS_RUNTIME_ROOT, exist_ok=True)
    # NOTE(review): the cache probe uses spec["filename"], but hf_hub_download
    # writes the file under spec["repo_file"]'s relative path inside
    # local_dir. For the pose model those differ ("vitpose-s-coco_25.onnx"
    # vs "onnx/coco_25/vitpose-25-s.onnx"), so this cache check may never
    # hit and the hub is consulted on every cold start — confirm intended.
    cached_path = os.path.join(WEIGHTS_RUNTIME_ROOT, spec["filename"])
    if os.path.isfile(cached_path):
        return cached_path

    return hf_hub_download(
        repo_id=spec["repo_id"],
        repo_type=spec.get("repo_type", "model"),
        filename=spec["repo_file"],
        token=get_hf_token(),
        local_dir=WEIGHTS_RUNTIME_ROOT,
    )
160
+
161
+
162
def ensure_runtime_ready(force: bool = False):
    """Initialize the pipeline module and model weights once; return the state.

    Pass force=True to re-resolve assets and weights even when already ready.
    """
    if runtime_state["ready"] and not force:
        return runtime_state

    for required_dir in (RUNTIME_ROOT, STORAGE_ROOT):
        os.makedirs(required_dir, exist_ok=True)

    pipeline_root = resolve_pipeline_root()
    runtime_state.update({
        "ready": True,
        "pipeline_root": pipeline_root,
        "run_pipeline": load_pipeline_callable(pipeline_root),
        "weights": {name: ensure_weight_file(spec) for name, spec in DEFAULT_WEIGHT_SPECS.items()},
    })
    return runtime_state
180
+
181
+
182
# Expose stored sessions/videos for direct download by the frontend.
os.makedirs(STORAGE_ROOT, exist_ok=True)
app.mount("/storage", StaticFiles(directory=STORAGE_ROOT), name="storage")

# In-memory progress snapshots keyed by client id (see /api/progress).
progress_store = {}

# Cooperative cancellation flags keyed by client id (see /api/cancel).
cancel_store = {}
188
+
189
+
190
def safe_name(value: str) -> str:
    """Sanitize *value* into a filesystem-safe path component.

    Alphanumerics plus '-', '_', '.' pass through; everything else becomes
    '_'. Leading/trailing '.'/'_' are stripped; an empty result yields "item".
    """
    sanitized = "".join(
        ch if (ch.isalnum() or ch in "-_.") else "_"
        for ch in str(value)
    ).strip("._")
    return sanitized or "item"
199
+
200
+
201
def session_storage_paths(player_id: str, session_id: str):
    """Return (player_dir, session_dir, videos_dir) for a session under STORAGE_ROOT."""
    base = os.path.join(STORAGE_ROOT, safe_name(player_id))
    session = os.path.join(base, safe_name(session_id))
    return base, session, os.path.join(session, "videos")
206
+
207
+
208
def list_session_files():
    """Find every session.json under STORAGE_ROOT, newest first (by mtime)."""
    found = [
        os.path.join(root, "session.json")
        for root, _dirs, files in os.walk(STORAGE_ROOT)
        if "session.json" in files
    ]
    found.sort(key=os.path.getmtime, reverse=True)
    return found
214
+
215
+
216
def build_storage_url(request: Request, *parts: str) -> str:
    """Build an absolute /storage URL; all parts except the last are sanitized.

    The final part (the file name) keeps its characters but has backslashes
    normalized to forward slashes.
    """
    last_index = len(parts) - 1
    pieces = [
        part.replace("\\", "/") if idx == last_index else safe_name(part)
        for idx, part in enumerate(parts)
    ]
    return str(request.base_url).rstrip("/") + "/storage/" + "/".join(pieces)
219
+
220
+
221
def parse_video_timecode(value, fps=30.0):
    """Convert a timecode into non-negative seconds.

    Accepts:
      - None                      -> 0.0
      - numeric (int/float/numpy) -> the value itself (clamped to >= 0)
      - "HH:MM:SS:FF"             -> frames converted via *fps*
      - "HH:MM:SS" / "MM:SS"      -> plain clock time (new, backward-compatible)
      - plain numeric strings     -> float(value)
      - anything unparsable       -> 0.0 (previously a malformed 4-part
        timecode raised an uncaught ValueError)
    """
    if value is None:
        return 0.0
    if isinstance(value, (int, float, np.integer, np.floating)):
        return max(0.0, float(value))

    parts = str(value).split(":")
    try:
        # Empty fields count as 0 (e.g. "1::30" -> 1, 0, 30).
        fields = [int(float(part or 0)) for part in parts]
    except Exception:
        fields = None  # non-numeric field: fall through to the float() path

    if fields is not None:
        if len(fields) == 4:
            h, m, s, f = fields
            # Guard fps against 0/negative so the frame term never divides by 0.
            return max(0.0, (h * 3600) + (m * 60) + s + (f / max(1.0, float(fps))))
        if len(fields) == 3:
            h, m, s = fields
            return max(0.0, (h * 3600) + (m * 60) + s)
        if len(fields) == 2:
            m, s = fields
            return max(0.0, (m * 60) + s)

    try:
        return max(0.0, float(value))
    except Exception:
        return 0.0
236
+
237
+
238
def detect_camera_id(file_name: str):
    """Extract the integer camera id from a stored video file name.

    Stored names follow ``<idx>_cam_<id>_<original>``; returns None when no
    ``_cam_<id>_`` token is present.
    """
    found = re.search(r"_cam_(\d+)_", file_name)
    return int(found.group(1)) if found else None
243
+
244
+
245
def build_camera_video_entries(request: Request, player_id: str, session_id: str, camera_map):
    """Map {camera_id: local video path} to [{cameraId, url}] entries.

    Only files that actually exist on disk are included; output is sorted by
    camera id.
    """
    entries = []
    for camera_id, video_path in sorted(camera_map.items()):
        if not video_path or not os.path.exists(video_path):
            continue
        entries.append({
            "cameraId": int(camera_id),
            "url": build_storage_url(
                request,
                safe_name(player_id),
                safe_name(session_id),
                "videos",
                os.path.basename(video_path),
            ),
        })
    return entries
260
+
261
+
262
def build_action_clip_entries(request: Request, player_id: str, session_id: str, clip_map):
    """Map {camera_id: clip path} to [{cameraId, url}] entries under 'clips'.

    Mirrors build_camera_video_entries but points at the per-action clip
    directory; missing files are skipped and output is sorted by camera id.
    """
    entries = []
    for camera_id, clip_path in sorted(clip_map.items()):
        if not clip_path or not os.path.exists(clip_path):
            continue
        entries.append({
            "cameraId": int(camera_id),
            "url": build_storage_url(
                request,
                safe_name(player_id),
                safe_name(session_id),
                "clips",
                os.path.basename(clip_path),
            ),
        })
    return entries
277
+
278
+
279
def normalize_session_payload(session: dict, request: Request):
    """Rehydrate a stored session dict with playback data for the frontend.

    For each action, resolves absolute camera-clip URLs and converts frame
    indices/timecodes into startSeconds/endSeconds. Returns the session
    unchanged when its storage directory or camera videos cannot be located.
    The input dict is not mutated.
    """
    session_id = session.get("id")
    player_id = session.get("playerId")
    if not session_id or not player_id:
        return session

    session_dir = find_session_path(session_id)
    if not session_dir:
        return session

    videos_dir = os.path.join(session_dir, "videos")
    if not os.path.isdir(videos_dir):
        return session

    # Rebuild camera id -> source video path from the stored file names.
    camera_map = {}
    for file_name in os.listdir(videos_dir):
        camera_id = detect_camera_id(file_name)
        if camera_id is None:
            continue
        camera_map[camera_id] = os.path.join(videos_dir, file_name)

    if not camera_map:
        return session

    normalized_actions = []
    for action in session.get("actions", []):
        normalized_action = dict(action)  # shallow copy; do not mutate input
        fps = float(normalized_action.get("fps") or 30.0)
        fps = max(1.0, fps)  # guard against 0/negative fps in stored data

        # Prefer absolute source-frame bounds; fall back to start/end frames.
        absolute_start_frame = normalized_action.get("sourceStartFrame")
        absolute_end_frame = normalized_action.get("sourceEndFrame")
        if absolute_start_frame is None or absolute_end_frame is None:
            absolute_start_frame = normalized_action.get("startFrame")
            absolute_end_frame = normalized_action.get("endFrame")

        try:
            absolute_start_frame = int(absolute_start_frame) if absolute_start_frame is not None else None
            absolute_end_frame = int(absolute_end_frame) if absolute_end_frame is not None else None
        except Exception:
            # Non-numeric frame data: treat as absent and use the time path.
            absolute_start_frame = None
            absolute_end_frame = None

        if absolute_start_frame is not None and absolute_end_frame is not None and absolute_end_frame >= absolute_start_frame:
            start_seconds = max(0.0, absolute_start_frame / fps)
            # end frame is inclusive, so +1 converts it to an exclusive bound
            end_seconds = max(start_seconds, (absolute_end_frame + 1) / fps)
            normalized_action["startFrame"] = absolute_start_frame
            normalized_action["endFrame"] = absolute_end_frame
        else:
            # No usable frame bounds: derive times from timecode + frame count.
            total_frames = int(normalized_action.get("totalFrames") or 0)
            start_seconds = parse_video_timecode(normalized_action.get("start"), fps=fps)
            if total_frames > 0:
                end_seconds = start_seconds + (total_frames / fps)
            else:
                end_seconds = max(start_seconds, parse_video_timecode(normalized_action.get("end"), fps=fps))

        normalized_action["cameraClips"] = normalized_action.get("sourceCameraClips") or build_camera_video_entries(
            request, player_id, session_id, camera_map
        )
        normalized_action["startSeconds"] = round(start_seconds, 6)
        normalized_action["endSeconds"] = round(end_seconds, 6)
        normalized_actions.append(normalized_action)

    normalized_session = dict(session)
    normalized_session["actions"] = normalized_actions
    return normalized_session
345
+
346
+
347
def json_default(value):
    """json.dump fallback: convert NumPy scalars/arrays to native Python types."""
    # ndarray and generic are disjoint types, so check order is irrelevant.
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
353
+
354
+
355
def export_action_clips(camera_map, clips_dir, action_index, start_frame, end_frame, fps):
    """Cut the inclusive [start_frame, end_frame] range out of every camera video.

    Writes one mp4 per camera into *clips_dir* and returns
    {camera_id: clip_path} for clips that produced at least one frame.
    Cameras whose source cannot be opened or decoded are skipped; empty or
    failed clip files are removed.
    """
    os.makedirs(clips_dir, exist_ok=True)
    frame_count = max(0, end_frame - start_frame + 1)  # inclusive range
    clip_paths = {}

    for camera_id, video_path in sorted(camera_map.items()):
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            continue

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        if width <= 0 or height <= 0:
            cap.release()
            continue

        clip_name = f"action_{action_index:02d}_cam_{camera_id}.mp4"
        clip_path = os.path.join(clips_dir, clip_name)
        writer = cv2.VideoWriter(
            clip_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            max(1.0, float(fps)),
            (width, height),
        )
        # VideoWriter construction can fail silently (missing codec, bad
        # path); skip this camera instead of writing into a dead writer.
        if not writer.isOpened():
            writer.release()
            cap.release()
            if os.path.exists(clip_path):
                os.remove(clip_path)
            continue

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        written = 0
        while written < frame_count:
            ok, frame = cap.read()
            if not ok:
                break  # source ran out early; keep whatever was written
            writer.write(frame)
            written += 1

        writer.release()
        cap.release()

        if written > 0 and os.path.exists(clip_path):
            clip_paths[camera_id] = clip_path
        elif os.path.exists(clip_path):
            # Zero-frame clip: remove the empty artifact.
            os.remove(clip_path)

    return clip_paths
398
+
399
+
400
def load_session_by_id(session_id: str):
    """Load the session.json payload whose directory name matches *session_id*.

    Returns the parsed dict, or None when no stored session matches.
    """
    wanted = safe_name(session_id)
    for session_file in list_session_files():
        if os.path.basename(os.path.dirname(session_file)) != wanted:
            continue
        with open(session_file, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return None
409
+
410
+
411
def find_session_path(session_id: str):
    """Return the on-disk directory for *session_id*, or None if not stored."""
    wanted = safe_name(session_id)
    for session_file in list_session_files():
        candidate = os.path.dirname(session_file)
        if os.path.basename(candidate) == wanted:
            return candidate
    return None
418
+
419
+
420
def player_storage_path(player_id: str):
    """Return the storage directory that holds all of a player's sessions."""
    return os.path.join(STORAGE_ROOT, safe_name(player_id))
422
+
423
+
424
def get_cors_origins():
    """Parse CORS_ALLOW_ORIGINS (comma separated); '*' or empty allows all."""
    raw = os.environ.get("CORS_ALLOW_ORIGINS", "*").strip()
    if raw in ("", "*"):
        return ["*"]
    origins = []
    for piece in raw.split(","):
        piece = piece.strip()
        if piece:  # drop empty entries from trailing/duplicate commas
            origins.append(piece)
    return origins
429
+
430
+
431
+ @app.on_event("startup")
432
+ def startup_event():
433
+ ensure_runtime_ready()
434
+
435
+
436
+ @app.get("/healthz")
437
+ def healthz():
438
+ runtime = ensure_runtime_ready()
439
+ return {
440
+ "status": "ok",
441
+ "storageRoot": STORAGE_ROOT,
442
+ "pipelineRoot": runtime.get("pipeline_root"),
443
+ }
444
+
445
+ @app.post("/api/cancel/{client_id}")
446
+ def cancel_processing(client_id: str):
447
+ cancel_store[client_id] = True
448
+ return {"status": "cancelled"}
449
+
450
+ @app.get("/api/progress/{client_id}")
451
+ def get_progress(client_id: str):
452
+ return progress_store.get(client_id, {"progress": 0.0, "phase": "Initializing"})
453
+
454
+
455
+ @app.get("/api/sessions/{session_id}")
456
+ def get_session(session_id: str, request: Request):
457
+ session = load_session_by_id(session_id)
458
+ if session is None:
459
+ raise HTTPException(status_code=404, detail="Session not found")
460
+ return normalize_session_payload(session, request)
461
+
462
+
463
+ @app.delete("/api/sessions/{session_id}")
464
+ def delete_session(session_id: str):
465
+ session_dir = find_session_path(session_id)
466
+ if session_dir is None:
467
+ raise HTTPException(status_code=404, detail="Session not found")
468
+
469
+ player_dir = os.path.dirname(session_dir)
470
+ shutil.rmtree(session_dir, ignore_errors=True)
471
+
472
+ if os.path.isdir(player_dir) and not os.listdir(player_dir):
473
+ os.rmdir(player_dir)
474
+
475
+ return {"status": "deleted", "sessionId": session_id}
476
+
477
+
478
+ @app.delete("/api/players/{player_id}")
479
+ def delete_player(player_id: str):
480
+ player_dir = player_storage_path(player_id)
481
+ if not os.path.isdir(player_dir):
482
+ raise HTTPException(status_code=404, detail="Player storage not found")
483
+
484
+ shutil.rmtree(player_dir, ignore_errors=True)
485
+ return {"status": "deleted", "playerId": player_id}
486
+
487
# CORS registration; Starlette applies middleware regardless of where it is
# registered relative to the route definitions above.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec (Starlette then echoes the request Origin
# instead) — confirm this is intended for the default "*" configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=get_cors_origins(),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
494
+
495
+
496
def format_metric_series(name, unit, values_list):
    """Shape a per-frame metric into {name, unit, values:[{frame, value}]}.

    Values are coerced via safe_float, so non-numeric entries become None.
    """
    frame_values = [
        {"frame": index, "value": safe_float(raw)}
        for index, raw in enumerate(values_list)
    ]
    return {"name": name, "unit": unit, "values": frame_values}
505
+
506
+
507
def safe_float(value):
    """Coerce *value* to float, returning None for anything non-finite.

    NaN and ±inf both map to None: json.dump would otherwise emit the
    non-standard tokens NaN/Infinity, which strict JSON parsers (including
    browsers' JSON.parse) reject. Unparsable values also yield None.
    """
    try:
        number = float(value)
    except Exception:
        return None
    return number if np.isfinite(number) else None
513
+
514
+
515
def metric_name(key: str) -> str:
    """Turn a snake_case metric key into a display name ('l_knee_angle' -> 'L Knee Angle')."""
    return key.replace("_", " ").title()
517
+
518
+
519
# Metric keys reported over the full action interval, independent of action type.
FULL_INTERVAL_KEYS = [
    "left_knee_angles",
    "right_knee_angles",
    "torso_pitch_angles",
    "head_angles",
    "mid_foot_ball_distances",
    "left_right_foot_distances",
]

# Preferred metric ordering per action type. "pre"/"in"/"post" name the
# phases around the ball touch; "frames" is used for continuous actions
# (Dribble); "top_level_scalars" are single per-action values. Keys listed
# here are ordered first by ordered_metric_keys(); any extra keys coming
# back from the pipeline are appended alphabetically.
ACTION_METRIC_LAYOUTS = {
    "Pass": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "active_foot_height_pct",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Shot": {
        "pre": ["body_to_ball_angle"],
        "in": [
            "body_to_ball_angle",
            "l_r_foot_distance",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_knee_angle",
            "r_knee_angle",
            "head_angle",
            "head_pitch_angle",
            "head_roll_angle",
            "stand_foot_angle",
            "l_elbow_shoulder_hip_angle",
            "r_elbow_shoulder_hip_angle",
            "active_ankle_angle",
        ],
        "post": ["head_angle", "body_to_ball_angle"],
        "top_level_scalars": ["backward_weighted_angle", "forward_weighted_angle"],
    },
    "Receive": {
        "pre": ["body_orientation_vs_ball", "head_angle"],
        "in": [
            "head_angle",
            "l_knee_angle",
            "r_knee_angle",
            "trunc_pitch_angle",
            "trunc_roll_angle",
            "left_foot_orientation_angle",
            "right_foot_orientation_angle",
            "difference_in_angles",
            "l_r_foot_distance",
            "stand_foot_angle",
            "body_orientation_vs_ball",
            "active_foot_height_pct",
        ],
        "post": ["mid_feet_ball_dist", "ball_height_pct_body"],
        "top_level_scalars": [],
    },
    "Dribble": {
        "frames": [
            "ball_feet_distance",
            "trunk_pitch",
            "trunk_roll",
            "head_angle",
            "ball_possession_score",
        ],
        "top_level_scalars": [],
    },
}
603
+
604
+
605
def ordered_metric_keys(observed_keys, preferred_keys=None):
    """Return preferred keys (in given order) first, then remaining keys sorted."""
    front = [k for k in (preferred_keys or []) if k in observed_keys]
    rest = sorted(k for k in observed_keys if k not in front)
    return front + rest
609
+
610
+
611
def build_series_from_entries(entries, unit_for, skip_keys=None, preferred_keys=None):
    """Build metric series from a list of per-frame dicts.

    Every key seen in *entries* (minus "frame" and *skip_keys*) plus all
    *preferred_keys* becomes one series, with one value per entry (missing
    keys surface as None via safe_float inside format_metric_series).
    """
    skip = {"frame"} | set(skip_keys or [])

    # Preferred keys are always included, even if absent from the entries.
    metric_keys = set(preferred_keys or [])
    for entry in entries:
        metric_keys.update(key for key in entry if key not in skip)

    return [
        format_metric_series(metric_name(key), unit_for(key), [entry.get(key) for entry in entries])
        for key in ordered_metric_keys(metric_keys, preferred_keys)
    ]
628
+
629
+
630
def build_scalar_metrics(payload, unit_for, skip_keys=None, preferred_keys=None):
    """Build [{name, value, unit}] scalar entries from *payload*.

    Preferred keys come first (and are emitted even when missing, with a
    None value); remaining keys follow alphabetically; skip_keys are omitted.
    """
    skip = set(skip_keys or [])
    candidates = {key for key in payload if key not in skip}
    candidates.update(key for key in (preferred_keys or []) if key not in skip)

    metrics = []
    for key in ordered_metric_keys(candidates, preferred_keys):
        if key in skip:
            continue
        value = safe_float(payload.get(key))
        metrics.append({
            "name": metric_name(key),
            "value": None if value is None else round(value, 3),
            "unit": unit_for(key),
        })
    return metrics
645
+
646
+
647
def build_top_level_interval_metrics(analytics, unit_for, skip_keys=None, preferred_keys=None):
    """Build series for list-valued top-level analytics entries.

    Bookkeeping keys (action/phase containers) are always excluded;
    non-list values are ignored; a preferred key missing from *analytics*
    yields an empty series. Output follows *preferred_keys* order, extras last.
    """
    skip = {
        "action",
        "active_foot",
        "touch_frame",
        "pre_action",
        "action_frame",
        "post_action",
        "frames",
    } | set(skip_keys or [])

    observed_keys = set(preferred_keys or []) | set(analytics.keys())

    series = []
    for key in ordered_metric_keys(observed_keys, preferred_keys):
        if key in skip:
            continue
        values = analytics.get(key)
        if values is None:
            values = []
        if not isinstance(values, list):
            continue
        series.append(format_metric_series(metric_name(key), unit_for(key), values))

    # Stable re-sort by preferred position; unknown names keep relative order
    # at the end (sorted() is stable).
    order_index = {metric_name(key): idx for idx, key in enumerate(preferred_keys or [])}
    return sorted(series, key=lambda item: order_index.get(item["name"], len(order_index)))
675
+
676
+
677
+ @app.post("/api/analyze")
678
+ async def analyze_endpoint(
679
+ request: Request,
680
+ playerId: str = Form(...),
681
+ targetW: float = Form(...),
682
+ targetH: float = Form(...),
683
+ clientId: str = Form(...),
684
+ videoOrders: List[int] = Form(...),
685
+ actionsJson: UploadFile = File(...),
686
+ calibration: UploadFile = File(...),
687
+ videos: List[UploadFile] = File(...)
688
+ ):
689
+ temp_dir = None
690
+ session_id = f"session-{int(time.time())}-{uuid.uuid4().hex[:6]}"
691
+ player_dir, session_dir, videos_dir = session_storage_paths(playerId, session_id)
692
+ try:
693
+ # Clear any previous cancellation flags
694
+ cancel_store.pop(clientId, None)
695
+ progress_store[clientId] = {"progress": 2.0, "step": 0, "total": 0, "phase": "Uploading & Validating Data"}
696
+
697
+ os.makedirs(videos_dir, exist_ok=True)
698
+ temp_dir = session_dir
699
+
700
+ # 1. Store incoming payloads
701
+ actions_path = os.path.join(session_dir, "actions.json")
702
+ with open(actions_path, "wb") as f:
703
+ f.write(await actionsJson.read())
704
+
705
+ calib_path = os.path.join(session_dir, "calibration.npz")
706
+ with open(calib_path, "wb") as f:
707
+ f.write(await calibration.read())
708
+
709
+ if len(videoOrders) != len(videos):
710
+ raise ValueError("Each uploaded video must include a matching camera order")
711
+ if len(set(videoOrders)) != len(videoOrders):
712
+ raise ValueError("Camera order values must be unique")
713
+
714
+ camera_map = {}
715
+ for idx, (camera_order, video) in enumerate(zip(videoOrders, videos)):
716
+ original_name = video.filename or f"camera_{camera_order}.mp4"
717
+ video_name = f"{idx:02d}_cam_{camera_order}_{safe_name(os.path.basename(original_name))}"
718
+ vid_path = os.path.join(videos_dir, video_name)
719
+ with open(vid_path, "wb") as f:
720
+ f.write(await video.read())
721
+ camera_map[int(camera_order)] = vid_path
722
+
723
+ progress_store[clientId] = {"progress": 10.0, "step": 0, "total": 0, "phase": "Preparing AI Models"}
724
+ runtime = ensure_runtime_ready()
725
+ utils_paths = {
726
+ "POSE_PATH": runtime["weights"]["POSE_PATH"],
727
+ "YOLO_PATH": runtime["weights"]["YOLO_PATH"],
728
+ "CALIBRATION_PATH": calib_path,
729
+ "ACTIONS_PATH": actions_path
730
+ }
731
+
732
+ sizes = {
733
+ "TARGET_SIZE": (int(targetW), int(targetH)),
734
+ "YOLO_IMGSZ": 960
735
+ }
736
+
737
+ print("Starting physical pipeline execution...")
738
+ progress_store[clientId] = {"progress": 20.0, "step": 0, "total": 0, "phase": "Extracting 3D Kinematics"}
739
+
740
+ def progress_tracker(current_act, total_act, step, total_frames):
741
+ if cancel_store.get(clientId):
742
+ return False # Signal pipeline to abort
743
+
744
+ base_p = current_act / max(1, total_act)
745
+ segment_p = (step / max(1, total_frames)) * (1.0 / max(1, total_act))
746
+ # Rescale 20% to 90% for processing
747
+ pct = 20.0 + round((base_p + segment_p) * 70.0, 1)
748
+ progress_store[clientId] = {
749
+ "progress": pct,
750
+ "step": step,
751
+ "total": total_frames,
752
+ "phase": f"Processing Action {current_act + 1}/{total_act}"
753
+ }
754
+ return True
755
+
756
+ # 2. Yield to worker thread to allow concurrent polling from front-end
757
+ def execute_pipeline():
758
+ return runtime["run_pipeline"](camera_map, utils_paths, sizes, progress_tracker)
759
+
760
+ reports = await run_in_threadpool(execute_pipeline)
761
+ progress_store[clientId] = {"progress": 100.0, "step": 0, "total": 0}
762
+
763
+ raw_reports_path = os.path.join(session_dir, "raw_reports.json")
764
+ with open(raw_reports_path, "w", encoding="utf-8") as f:
765
+ json.dump(reports, f, indent=2, default=json_default)
766
+
767
+ # 3. Format output dict perfectly mapping to the Frontend Types
768
+ with open(actions_path, "r") as f:
769
+ raw_actions = json.load(f).get("actions", [])
770
+
771
+ formatted_actions = []
772
+ failed_actions = []
773
+ camera_videos = build_camera_video_entries(request, playerId, session_id, camera_map)
774
+ for i, rep in enumerate(reports):
775
+ raw = raw_actions[i] if i < len(raw_actions) else {}
776
+
777
+ if "error" in rep:
778
+ failed_actions.append({
779
+ "id": f"err-{uuid.uuid4().hex[:6]}",
780
+ "label": rep.get("action", raw.get("label", "Unknown")),
781
+ "start": raw.get("start", "00:00:00:00"),
782
+ "end": raw.get("end", "00:00:00:00"),
783
+ "error": rep["error"]
784
+ })
785
+ continue
786
+
787
+ an = rep["analytics"]
788
+ sf = rep["start_frame"]
789
+ ef = rep["end_frame"]
790
+ fps = float(rep.get("fps", 30))
791
+ is_dribble = (an.get("action") == "Dribble")
792
+
793
+ if is_dribble:
794
+ tf = (sf + ef) // 2
795
+ else:
796
+ tf = an.get("touch_frame", (sf + ef) // 2)
797
+
798
+ # --- Skeleton: support full COCO-25/WB joint range (0–32) ---
799
+ skeleton_frames = []
800
+ raw_ball_history = rep.get("ball_history", {})
801
+ for f_idx in range(sf, ef + 1):
802
+ raw_skel = rep["skel_history"].get(f_idx, {})
803
+ raw_ball = raw_ball_history.get(f_idx)
804
+ # Find the max joint index present so we don't truncate
805
+ max_joint = max(raw_skel.keys()) if raw_skel else 32
806
+ n_joints = max(33, max_joint + 1)
807
+ joints = []
808
+ for j in range(n_joints):
809
+ pt = raw_skel.get(j)
810
+ if pt is not None:
811
+ joints.append([float(pt[0]), float(pt[1]), float(pt[2])])
812
+ else:
813
+ joints.append([0.0, 0.0, 0.0])
814
+ frame_payload = {"frame": f_idx - sf, "joints": joints}
815
+ if raw_ball is not None:
816
+ frame_payload["ball"] = [float(raw_ball[0]), float(raw_ball[1]), float(raw_ball[2])]
817
+ skeleton_frames.append(frame_payload)
818
+
819
+ # --- Unit dictionary for known metric names ---
820
+ UNITS = {
821
+ "head_angle": "°", "l_knee_angle": "°", "r_knee_angle": "°",
822
+ "trunc_pitch_angle": "°", "trunc_roll_angle": "°",
823
+ "trunk_pitch": "°", "trunk_roll": "°",
824
+ "head_pitch_angle": "°", "head_roll_angle": "°",
825
+ "left_foot_orientation_angle": "°", "right_foot_orientation_angle": "°",
826
+ "difference_in_angles": "°", "body_to_ball_angle": "°",
827
+ "body_orientation_vs_ball": "°", "stand_foot_angle": "°",
828
+ "active_ankle_angle": "°", "l_elbow_shoulder_hip_angle": "°",
829
+ "r_elbow_shoulder_hip_angle": "°", "backward_weighted_angle": "°",
830
+ "forward_weighted_angle": "°", "leg_separation_angle": "°",
831
+ "l_r_foot_distance": "cm", "l_foot_ball_distance": "cm",
832
+ "r_foot_ball_distance": "cm", "mid_feet_ball_dist": "cm",
833
+ "active_foot_height_pct": "%", "ball_height_pct_body": "%",
834
+ "ball_possession_score": "%", "ball_feet_distance": "cm",
835
+ }
836
+
837
+ def unit_for(key):
838
+ return UNITS.get(key, "")
839
+
840
+ action_layout = ACTION_METRIC_LAYOUTS.get(rep["action"], {})
841
+
842
+ if is_dribble:
843
+ dribble_frames = an.get("frames", [])
844
+ pre_metrics = build_series_from_entries(
845
+ dribble_frames,
846
+ unit_for,
847
+ preferred_keys=action_layout.get("frames"),
848
+ )
849
+ in_action_metrics = []
850
+ frame_metric_keys = ordered_metric_keys(
851
+ {k for frame in dribble_frames for k in frame.keys() if k != "frame"},
852
+ action_layout.get("frames"),
853
+ )
854
+ for key in frame_metric_keys:
855
+ numeric_values = [safe_float(frame.get(key)) for frame in dribble_frames]
856
+ numeric_values = [value for value in numeric_values if value is not None]
857
+ in_action_metrics.append({
858
+ "name": f"Avg {metric_name(key)}",
859
+ "value": round(float(np.mean(numeric_values)), 3) if numeric_values else None,
860
+ "unit": unit_for(key),
861
+ })
862
+ post_metrics = []
863
+ else:
864
+ pre_entries = an.get("pre_action", [])
865
+ post_entries = an.get("post_action", [])
866
+ action_frame_data = an.get("action_frame", {})
867
+ pre_metrics = build_series_from_entries(
868
+ pre_entries,
869
+ unit_for,
870
+ preferred_keys=action_layout.get("pre"),
871
+ )
872
+ in_action_metrics = build_scalar_metrics(
873
+ action_frame_data,
874
+ unit_for,
875
+ skip_keys={"active_foot"},
876
+ preferred_keys=action_layout.get("in"),
877
+ )
878
+ in_action_metrics.extend(
879
+ build_scalar_metrics(
880
+ an,
881
+ unit_for,
882
+ skip_keys={
883
+ "action",
884
+ "active_foot",
885
+ "touch_frame",
886
+ "pre_action",
887
+ "action_frame",
888
+ "post_action",
889
+ "frames",
890
+ "left_knee_angles",
891
+ "right_knee_angles",
892
+ "torso_pitch_angles",
893
+ "head_angles",
894
+ "mid_foot_ball_distances",
895
+ "left_right_foot_distances",
896
+ },
897
+ preferred_keys=action_layout.get("top_level_scalars"),
898
+ )
899
+ )
900
+ post_metrics = build_series_from_entries(
901
+ post_entries,
902
+ unit_for,
903
+ preferred_keys=action_layout.get("post"),
904
+ )
905
+
906
+ full_interval_metrics = build_top_level_interval_metrics(
907
+ an,
908
+ unit_for,
909
+ preferred_keys=FULL_INTERVAL_KEYS,
910
+ )
911
+
912
+ formatted_actions.append({
913
+ "id": f"{rep['action'].lower()}-{uuid.uuid4().hex[:6]}",
914
+ "label": rep["action"],
915
+ "start": raw.get("start", "00:00:00:00"),
916
+ "end": raw.get("end", "00:00:00:00"),
917
+ "fps": fps,
918
+ "startFrame": sf,
919
+ "endFrame": ef,
920
+ "startSeconds": max(0.0, sf / max(1.0, fps)),
921
+ "endSeconds": max(0.0, (ef + 1) / max(1.0, fps)),
922
+ "totalFrames": ef - sf + 1,
923
+ "preFrames": tf - sf,
924
+ "inFrame": tf - sf,
925
+ "postFrames": ef - tf,
926
+ "cameraClips": camera_videos,
927
+ "preMetrics": pre_metrics,
928
+ "inActionMetrics": in_action_metrics,
929
+ "postMetrics": post_metrics,
930
+ "fullIntervalMetrics": full_interval_metrics,
931
+ "skeleton": skeleton_frames,
932
+ "rawAnalytics": an,
933
+ })
934
+
935
+ print("Pipeline successful. Yielding payload payload.")
936
+ response_payload = {
937
+ "id": session_id,
938
+ "playerId": playerId,
939
+ "createdAt": int(time.time() * 1000),
940
+ "targetSize": [int(targetW), int(targetH)],
941
+ "cameraCount": len(camera_map),
942
+ "actions": formatted_actions,
943
+ "failedActions": failed_actions
944
+ }
945
+
946
+ session_json_path = os.path.join(session_dir, "session.json")
947
+ with open(session_json_path, "w", encoding="utf-8") as f:
948
+ json.dump(response_payload, f, indent=2, default=json_default)
949
+
950
+ return response_payload
951
+
952
+ except Exception as e:
953
+ print("--- PIPELINE ERROR ---")
954
+ traceback.print_exc()
955
+ if temp_dir and os.path.isdir(temp_dir):
956
+ shutil.rmtree(temp_dir, ignore_errors=True)
957
+ return {"error": str(e)}
958
+
959
if __name__ == "__main__":
    import uvicorn

    # Local development entry point: serve the FastAPI app directly so the
    # React/Vite frontend can talk to it on localhost. (In the deployed
    # container, uvicorn is launched by the Dockerfile CMD instead, which
    # defaults PORT to 7860.)
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.100.0
2
+ uvicorn>=0.23.0
3
+ python-multipart>=0.0.6
4
+ huggingface_hub>=0.34.0
5
+ numpy
6
+ scipy
7
+ opencv-python
8
+ onnxruntime
9
+ supervision
10
+ torch
11
+ ultralytics
12
+ plotly