Upload 2 files

- cancel_registry.py +79 -0
- ocr_service.py +97 -10
cancel_registry.py
ADDED
@@ -0,0 +1,79 @@
+"""
+Cancel Registry for OCR Jobs
+Manages cancellation flags for cooperative job cancellation
+"""
+import asyncio
+from typing import Optional
+
+# Global registry: job_id -> asyncio.Event
+CANCEL_FLAGS: dict[str, asyncio.Event] = {}
+_cancel_lock = asyncio.Lock()
+
+
+def new_cancel_flag(job_id: str) -> asyncio.Event:
+    """Create a new cancellation flag for a job"""
+    async def _create():
+        async with _cancel_lock:
+            ev = asyncio.Event()
+            CANCEL_FLAGS[job_id] = ev
+            return ev
+    # Run in event loop if available, otherwise create synchronously
+    try:
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            ev = asyncio.Event()
+            asyncio.create_task(_create_sync(job_id, ev))
+            return ev
+        else:
+            return loop.run_until_complete(_create())
+    except RuntimeError:
+        # No event loop, create directly
+        ev = asyncio.Event()
+        CANCEL_FLAGS[job_id] = ev
+        return ev
+
+
+async def _create_sync(job_id: str, ev: asyncio.Event):
+    """Helper to register an event under the lock"""
+    async with _cancel_lock:
+        CANCEL_FLAGS[job_id] = ev
+
+
+def get_cancel_flag(job_id: str) -> Optional[asyncio.Event]:
+    """Get the cancellation flag for a job"""
+    return CANCEL_FLAGS.get(job_id)
+
+
+def cancel_job(job_id: str) -> bool:
+    """Cancel a job by setting its flag. Returns True if the job exists."""
+    ev = CANCEL_FLAGS.get(job_id)
+    if ev:
+        ev.set()
+        return True
+    return False
+
+
+async def is_cancelled(job_id: str) -> bool:
+    """Check whether a job has been cancelled"""
+    ev = CANCEL_FLAGS.get(job_id)
+    if ev:
+        return ev.is_set()
+    return False
+
+
+def remove_cancel_flag(job_id: str):
+    """Remove a cancellation flag (cleanup after the job completes)"""
+    async def _remove():
+        async with _cancel_lock:
+            CANCEL_FLAGS.pop(job_id, None)
+
+    try:
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            asyncio.create_task(_remove())
+        else:
+            loop.run_until_complete(_remove())
+    except RuntimeError:
+        # No event loop, remove directly
+        CANCEL_FLAGS.pop(job_id, None)
+
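Below, a minimal usage sketch of this registry (not part of the commit); run_job and its sleep loop are hypothetical stand-ins for a real worker:

# Hypothetical usage sketch; run_job is a stand-in worker, not part of this commit.
import asyncio
from cancel_registry import new_cancel_flag, cancel_job, remove_cancel_flag

async def run_job(job_id: str):
    flag = new_cancel_flag(job_id)      # register the flag before starting work
    try:
        for step in range(10):
            if flag.is_set():           # cooperative check between work units
                print(f"{job_id} cancelled at step {step}")
                return
            await asyncio.sleep(0.1)    # stand-in for one unit of real work
        print(f"{job_id} completed")
    finally:
        remove_cancel_flag(job_id)      # always clean up the registry

async def main():
    task = asyncio.create_task(run_job("job-1"))
    await asyncio.sleep(0.25)
    cancel_job("job-1")                 # sets the flag; run_job exits at its next check
    await task

asyncio.run(main())

Because the flag is just an asyncio.Event, cancellation is cooperative: work only stops at the points that actually poll the flag.
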
ocr_service.py
CHANGED
@@ -133,6 +133,37 @@ _jobs: dict[str, dict] = {}  # job_id -> {status, progress, result, error, cance
 _jobs_lock = asyncio.Lock()
 _cancellation_tokens: dict[str, asyncio.Event] = {}  # job_id -> cancellation event
 
+# Import cancel registry
+try:
+    from cancel_registry import cancel_job, get_cancel_flag, new_cancel_flag, remove_cancel_flag, is_cancelled
+except ImportError:
+    # Fallback if cancel_registry is not available
+    def cancel_job(job_id: str): return False
+    def get_cancel_flag(job_id: str): return _cancellation_tokens.get(job_id)
+    def new_cancel_flag(job_id: str): return _cancellation_tokens.setdefault(job_id, asyncio.Event())
+    def remove_cancel_flag(job_id: str): pass
+    async def is_cancelled(job_id: str): return False
+
+# StoppingCriteria for generation (if transformers supports it)
+try:
+    from transformers import StoppingCriteria, StoppingCriteriaList
+    _STOPPING_CRITERIA_AVAILABLE = True
+except ImportError:
+    _STOPPING_CRITERIA_AVAILABLE = False
+    StoppingCriteria = None
+    StoppingCriteriaList = None
+
+
+if _STOPPING_CRITERIA_AVAILABLE:
+    class CancelCriterion(StoppingCriteria):
+        """Stopping criterion that checks a cancellation flag"""
+        def __init__(self, cancel_flag: asyncio.Event):
+            self.cancel_flag = cancel_flag
+
+        def __call__(self, input_ids, scores, **kwargs):
+            """Return True to stop generation immediately"""
+            return self.cancel_flag.is_set()
+
 
 def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
     """
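CancelCriterion is defined above but nothing in this diff passes it to generation. A sketch of how it could be wired in, assuming the underlying model exposes the standard transformers generate(..., stopping_criteria=...) parameter (DeepSeek-OCR's model.infer() may not accept it):

# Sketch only: assumes access to a transformers-style generate() call.
def build_generate_kwargs(job_id: str) -> dict:
    kwargs = {"max_new_tokens": 2048}   # illustrative default, an assumption
    flag = get_cancel_flag(job_id)
    if flag is not None and _STOPPING_CRITERIA_AVAILABLE:
        # generate() polls every criterion after each decoding step, so a
        # set flag stops token generation within one step
        kwargs["stopping_criteria"] = StoppingCriteriaList([CancelCriterion(flag)])
    return kwargs

# output_ids = model.generate(**inputs, **build_generate_kwargs(job_id))
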
@@ -338,8 +368,17 @@ async def run_deepseek_ocr(
     # Note: We can't interrupt inference mid-process, but we can check before/after
     torch = _get_torch()
     with torch.inference_mode():
+        # Check cancellation one more time right before inference (critical point)
+        if job_id:
+            async with _jobs_lock:
+                cancel_event = _cancellation_tokens.get(job_id)
+                if cancel_event and cancel_event.is_set():
+                    raise asyncio.CancelledError(f"Job {job_id} was cancelled")
+
         # Estimate inference takes ~80% of time (10-90%)
         # We'll update progress during post-processing
+        # Note: This is a blocking call - once it starts, it runs to completion
+        # The cancellation will be checked immediately after it returns
         result = model.infer(
             tokenizer,
             prompt=prompt,
@@ -351,6 +390,13 @@ async def run_deepseek_ocr(
             save_results=False,
             test_compress=False,
         )
+
+        # Check cancellation immediately after inference completes
+        if job_id:
+            async with _jobs_lock:
+                cancel_event = _cancellation_tokens.get(job_id)
+                if cancel_event and cancel_event.is_set():
+                    raise asyncio.CancelledError(f"Job {job_id} was cancelled during inference")
 
     # Check for cancellation after inference
     if job_id:
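The comment above is accurate: model.infer() blocks the event loop, so the cancel endpoint and SSE handlers cannot run until it returns. A hedged alternative (not in this commit) is to run the call on a worker thread:

# Sketch only -- the commit calls model.infer() directly on the event loop.
# asyncio.to_thread keeps /jobs/{job_id}/cancel and the SSE handlers responsive
# while inference runs; cancellation is still only observed between calls.
result = await asyncio.to_thread(
    model.infer,
    tokenizer,
    prompt=prompt,
    save_results=False,
    test_compress=False,
    # ...remaining model.infer arguments as in the hunk above
)
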
@@ -402,6 +448,13 @@ async def run_deepseek_ocr(
             if cancel_event and cancel_event.is_set():
                 break
 
+            # Check cancellation right before each field detection
+            if job_id:
+                async with _jobs_lock:
+                    cancel_event = _cancellation_tokens.get(job_id)
+                    if cancel_event and cancel_event.is_set():
+                        raise asyncio.CancelledError(f"Job {job_id} was cancelled during field detection")
+
             # Run locator query for this field
             with torch.inference_mode():
                 locator_result = model.infer(
@@ -416,6 +469,13 @@ async def run_deepseek_ocr(
                     test_compress=False,
                 )
 
+                # Check cancellation immediately after locator inference
+                if job_id:
+                    async with _jobs_lock:
+                        cancel_event = _cancellation_tokens.get(job_id)
+                        if cancel_event and cancel_event.is_set():
+                            raise asyncio.CancelledError(f"Job {job_id} was cancelled after field detection")
+
                 # Parse locator boxes from result
                 locator_text = locator_result if isinstance(locator_result, str) else str(locator_result)
                 locator_boxes = _parse_locator_boxes(locator_text, field_name)
@@ -988,6 +1048,7 @@ async def ocr_page(
             _jobs[job_id]["status"] = "cancelled"
             _jobs[job_id]["message"] = "Job was cancelled"
             _cancellation_tokens.pop(job_id, None)
+            remove_cancel_flag(job_id)  # Cleanup cancel registry
             raise HTTPException(status_code=499, detail="Job was cancelled")
     except Exception as e:
         # Log the error and update job status
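Status 499 mirrors nginx's non-standard "client closed request" code; standard clients will treat it as a generic 4xx error.
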
@@ -1120,15 +1181,37 @@ async def run_ocr_job_async(job_id: str, file: UploadFile, bus):
 
 
 @app.get("/progress/{job_id}")
-async def get_progress_stream(job_id: str):
-    """SSE stream for real-time OCR progress updates"""
+async def get_progress_stream(job_id: str, request: Request):
+    """SSE stream for real-time OCR progress updates with client disconnect detection"""
     try:
         from progress_bus import bus
     except ImportError:
         raise HTTPException(status_code=503, detail="SSE streaming not available")
 
+    async def gen_with_disconnect_check():
+        """Generator that checks for client disconnect and auto-cancels"""
+        try:
+            async for event in bus.stream(job_id):
+                # Check if client disconnected
+                if await request.is_disconnected():
+                    # Auto-cancel job on disconnect (optional but recommended)
+                    cancel_job(job_id)
+                    if job_id in _cancellation_tokens:
+                        _cancellation_tokens[job_id].set()
+                    async with _jobs_lock:
+                        if job_id in _jobs:
+                            _jobs[job_id]["status"] = "cancelled"
+                            _jobs[job_id]["message"] = "Client disconnected"
+                    break
+                yield event
+        except asyncio.CancelledError:
+            # Stream was cancelled
+            cancel_job(job_id)
+            if job_id in _cancellation_tokens:
+                _cancellation_tokens[job_id].set()
+
     return StreamingResponse(
-        bus.stream(job_id),
+        gen_with_disconnect_check(),
         media_type="text/event-stream",
         headers={
             "Cache-Control": "no-cache",
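To exercise the disconnect path, a hypothetical client sketch (httpx is an assumption; any SSE-capable client works). The handler above checks request.is_disconnected() before forwarding each event, so dropping the stream mid-job auto-cancels the job:

# Hypothetical client; base_url and job_id are placeholders.
import httpx

def watch_then_drop(base_url: str, job_id: str, max_events: int = 5) -> None:
    with httpx.Client(timeout=None) as client:
        with client.stream("GET", f"{base_url}/progress/{job_id}") as resp:
            for i, line in enumerate(resp.iter_lines()):
                print(line)
                if i + 1 >= max_events:
                    break   # leaving the with-block closes the connection
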
@@ -1156,31 +1239,35 @@ async def get_job_status(job_id: str):
 
 
 @app.post("/jobs/{job_id}/cancel")
-async def cancel_job(job_id: str):
-    """Cancel a running OCR job"""
+async def cancel_job_endpoint(job_id: str):
+    """Cancel a running OCR job (cooperative cancellation with StoppingCriteria)"""
     async with _jobs_lock:
         if job_id not in _jobs:
             raise HTTPException(status_code=404, detail="Job not found")
 
         job = _jobs[job_id]
+
+        # Already finished?
         if job["status"] in ("completed", "failed", "cancelled"):
-            return {"message": f"Job already {job['status']}"}
+            return {"ok": True, "message": f"Job already {job['status']}", "job_id": job_id}
 
-        # Set cancellation flag
+        # Set cancellation flag (use cancel_registry for consistency)
+        cancel_job(job_id)
         if job_id in _cancellation_tokens:
             _cancellation_tokens[job_id].set()
 
-        job["status"] = "cancelling"
+        job["status"] = "cancelled"
         job["message"] = "Cancellation requested..."
+        job["progress"] = job.get("progress", 0.0)
 
         # Send cancellation to SSE stream
        try:
             from progress_bus import bus
-            await bus.error(job_id, "Job cancelled")
+            await bus.error(job_id, "Job cancelled by user")
         except ImportError:
             pass
 
-        return {"message": "Cancellation requested", "job_id": job_id}
+        return {"ok": True, "message": "Cancellation requested", "job_id": job_id}
 
 
 @app.post("/split")
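Finally, a hypothetical caller for the cancel endpoint above; base_url is a placeholder:

import httpx

def request_cancel(base_url: str, job_id: str) -> dict:
    resp = httpx.post(f"{base_url}/jobs/{job_id}/cancel")
    resp.raise_for_status()   # 404 if the job id is unknown
    return resp.json()        # {"ok": True, "message": ..., "job_id": ...}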