make789 committed on
Commit
42834e0
·
verified ·
1 Parent(s): 1e4fbbe

Upload 2 files

Browse files
Files changed (2) hide show
  1. cancel_registry.py +79 -0
  2. ocr_service.py +97 -10
cancel_registry.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Cancel Registry for OCR Jobs
Manages cancellation flags for cooperative job cancellation
"""
import asyncio
from typing import Optional

# Global registry: job_id -> asyncio.Event
# All access happens on the server's single event-loop thread, and plain
# dict get/set/pop are atomic under the GIL, so no lock is required.
CANCEL_FLAGS: dict[str, asyncio.Event] = {}


def new_cancel_flag(job_id: str) -> asyncio.Event:
    """Create, register and return a fresh cancellation flag for *job_id*.

    Registration is synchronous: the previous implementation deferred the
    dict insert to a scheduled task when a loop was running, which left a
    window where get_cancel_flag() could return None for a live job.
    """
    ev = asyncio.Event()
    CANCEL_FLAGS[job_id] = ev
    return ev


def get_cancel_flag(job_id: str) -> Optional[asyncio.Event]:
    """Return the cancellation flag for *job_id*, or None if unknown."""
    return CANCEL_FLAGS.get(job_id)


def cancel_job(job_id: str) -> bool:
    """Cancel a job by setting its flag. Returns True if the job exists."""
    ev = CANCEL_FLAGS.get(job_id)
    if ev is None:
        return False
    ev.set()
    return True


async def is_cancelled(job_id: str) -> bool:
    """Check whether *job_id* has been cancelled.

    Kept ``async`` for interface compatibility with existing awaiting
    callers, even though the check itself never needs to await.
    """
    ev = CANCEL_FLAGS.get(job_id)
    return ev.is_set() if ev is not None else False


def remove_cancel_flag(job_id: str) -> None:
    """Remove a job's cancellation flag (cleanup after the job completes).

    ``dict.pop`` is atomic, so the previous lock + create_task() dance was
    unnecessary — and the scheduled task could silently never run if the
    event loop shut down first. Removing a missing flag is a no-op.
    """
    CANCEL_FLAGS.pop(job_id, None)
ocr_service.py CHANGED
@@ -133,6 +133,36 @@ _jobs: dict[str, dict] = {} # job_id -> {status, progress, result, error, cance
133
  _jobs_lock = asyncio.Lock()
134
  _cancellation_tokens: dict[str, asyncio.Event] = {} # job_id -> cancellation event
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
138
  """
@@ -338,8 +368,17 @@ async def run_deepseek_ocr(
338
  # Note: We can't interrupt inference mid-process, but we can check before/after
339
  torch = _get_torch()
340
  with torch.inference_mode():
 
 
 
 
 
 
 
341
  # Estimate inference takes ~80% of time (10-90%)
342
  # We'll update progress during post-processing
 
 
343
  result = model.infer(
344
  tokenizer,
345
  prompt=prompt,
@@ -351,6 +390,13 @@ async def run_deepseek_ocr(
351
  save_results=False,
352
  test_compress=False,
353
  )
 
 
 
 
 
 
 
354
 
355
  # Check for cancellation after inference
356
  if job_id:
@@ -402,6 +448,13 @@ async def run_deepseek_ocr(
402
  if cancel_event and cancel_event.is_set():
403
  break
404
 
 
 
 
 
 
 
 
405
  # Run locator query for this field
406
  with torch.inference_mode():
407
  locator_result = model.infer(
@@ -416,6 +469,13 @@ async def run_deepseek_ocr(
416
  test_compress=False,
417
  )
418
 
 
 
 
 
 
 
 
419
  # Parse locator boxes from result
420
  locator_text = locator_result if isinstance(locator_result, str) else str(locator_result)
421
  locator_boxes = _parse_locator_boxes(locator_text, field_name)
@@ -988,6 +1048,7 @@ async def ocr_page(
988
  _jobs[job_id]["status"] = "cancelled"
989
  _jobs[job_id]["message"] = "Job was cancelled"
990
  _cancellation_tokens.pop(job_id, None)
 
991
  raise HTTPException(status_code=499, detail="Job was cancelled")
992
  except Exception as e:
993
  # Log the error and update job status
@@ -1120,15 +1181,37 @@ async def run_ocr_job_async(job_id: str, file: UploadFile, bus):
1120
 
1121
 
1122
  @app.get("/progress/{job_id}")
1123
- async def get_progress_stream(job_id: str):
1124
- """SSE stream for real-time OCR progress updates"""
1125
  try:
1126
  from progress_bus import bus
1127
  except ImportError:
1128
  raise HTTPException(status_code=503, detail="SSE streaming not available")
1129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
  return StreamingResponse(
1131
- bus.stream(job_id),
1132
  media_type="text/event-stream",
1133
  headers={
1134
  "Cache-Control": "no-cache",
@@ -1156,31 +1239,35 @@ async def get_job_status(job_id: str):
1156
 
1157
 
1158
  @app.post("/jobs/{job_id}/cancel")
1159
- async def cancel_job(job_id: str):
1160
- """Cancel a running OCR job"""
1161
  async with _jobs_lock:
1162
  if job_id not in _jobs:
1163
  raise HTTPException(status_code=404, detail="Job not found")
1164
 
1165
  job = _jobs[job_id]
 
 
1166
  if job["status"] in ("completed", "failed", "cancelled"):
1167
- return {"message": f"Job already {job['status']}"}
1168
 
1169
- # Set cancellation flag
 
1170
  if job_id in _cancellation_tokens:
1171
  _cancellation_tokens[job_id].set()
1172
 
1173
- job["status"] = "cancelling"
1174
  job["message"] = "Cancellation requested..."
 
1175
 
1176
  # Send cancellation to SSE stream
1177
  try:
1178
  from progress_bus import bus
1179
- await bus.error(job_id, "Cancellation requested")
1180
  except ImportError:
1181
  pass
1182
 
1183
- return {"message": "Cancellation requested", "job_id": job_id}
1184
 
1185
 
1186
  @app.post("/split")
 
133
  _jobs_lock = asyncio.Lock()
134
  _cancellation_tokens: dict[str, asyncio.Event] = {} # job_id -> cancellation event
135
 
136
# Import cancel registry
try:
    from cancel_registry import cancel_job, get_cancel_flag, new_cancel_flag, remove_cancel_flag, is_cancelled
except ImportError:
    # Fallback if cancel_registry not available: degrade to the legacy
    # _cancellation_tokens mechanism (or no-ops) so the module still imports.
    def cancel_job(job_id: str): return False
    def get_cancel_flag(job_id: str): return _cancellation_tokens.get(job_id)
    def new_cancel_flag(job_id: str): return _cancellation_tokens.setdefault(job_id, asyncio.Event())
    def remove_cancel_flag(job_id: str): pass
    async def is_cancelled(job_id: str): return False

# StoppingCriteria for generation (if transformers supports it)
try:
    from transformers import StoppingCriteria, StoppingCriteriaList
    _STOPPING_CRITERIA_AVAILABLE = True
except ImportError:
    _STOPPING_CRITERIA_AVAILABLE = False
    # Fall back to `object`, NOT None: the class statement below would
    # raise TypeError at import time when subclassing None.
    StoppingCriteria = object
    StoppingCriteriaList = None


class CancelCriterion(StoppingCriteria):
    """Stopping criterion that halts generation once a cancel flag is set.

    Only useful for generation when transformers is installed
    (_STOPPING_CRITERIA_AVAILABLE is True); otherwise it is a plain
    callable that is never handed to a generate() call.
    """

    def __init__(self, cancel_flag: asyncio.Event):
        # Event set by the cancel endpoint / registry; polled per step.
        self.cancel_flag = cancel_flag

    def __call__(self, input_ids, scores, **kwargs):
        """Return True to stop generation immediately."""
        return self.cancel_flag.is_set()
166
 
167
  def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
168
  """
 
368
  # Note: We can't interrupt inference mid-process, but we can check before/after
369
  torch = _get_torch()
370
  with torch.inference_mode():
371
+ # Check cancellation one more time right before inference (critical point)
372
+ if job_id:
373
+ async with _jobs_lock:
374
+ cancel_event = _cancellation_tokens.get(job_id)
375
+ if cancel_event and cancel_event.is_set():
376
+ raise asyncio.CancelledError(f"Job {job_id} was cancelled")
377
+
378
  # Estimate inference takes ~80% of time (10-90%)
379
  # We'll update progress during post-processing
380
+ # Note: This is a blocking call - once it starts, it runs to completion
381
+ # The cancellation will be checked immediately after it returns
382
  result = model.infer(
383
  tokenizer,
384
  prompt=prompt,
 
390
  save_results=False,
391
  test_compress=False,
392
  )
393
+
394
+ # Check cancellation immediately after inference completes
395
+ if job_id:
396
+ async with _jobs_lock:
397
+ cancel_event = _cancellation_tokens.get(job_id)
398
+ if cancel_event and cancel_event.is_set():
399
+ raise asyncio.CancelledError(f"Job {job_id} was cancelled during inference")
400
 
401
  # Check for cancellation after inference
402
  if job_id:
 
448
  if cancel_event and cancel_event.is_set():
449
  break
450
 
451
+ # Check cancellation right before each field detection
452
+ if job_id:
453
+ async with _jobs_lock:
454
+ cancel_event = _cancellation_tokens.get(job_id)
455
+ if cancel_event and cancel_event.is_set():
456
+ raise asyncio.CancelledError(f"Job {job_id} was cancelled during field detection")
457
+
458
  # Run locator query for this field
459
  with torch.inference_mode():
460
  locator_result = model.infer(
 
469
  test_compress=False,
470
  )
471
 
472
+ # Check cancellation immediately after locator inference
473
+ if job_id:
474
+ async with _jobs_lock:
475
+ cancel_event = _cancellation_tokens.get(job_id)
476
+ if cancel_event and cancel_event.is_set():
477
+ raise asyncio.CancelledError(f"Job {job_id} was cancelled after field detection")
478
+
479
  # Parse locator boxes from result
480
  locator_text = locator_result if isinstance(locator_result, str) else str(locator_result)
481
  locator_boxes = _parse_locator_boxes(locator_text, field_name)
 
1048
  _jobs[job_id]["status"] = "cancelled"
1049
  _jobs[job_id]["message"] = "Job was cancelled"
1050
  _cancellation_tokens.pop(job_id, None)
1051
+ remove_cancel_flag(job_id) # Cleanup cancel registry
1052
  raise HTTPException(status_code=499, detail="Job was cancelled")
1053
  except Exception as e:
1054
  # Log the error and update job status
 
1181
 
1182
 
1183
  @app.get("/progress/{job_id}")
1184
+ async def get_progress_stream(job_id: str, request: Request):
1185
+ """SSE stream for real-time OCR progress updates with client disconnect detection"""
1186
  try:
1187
  from progress_bus import bus
1188
  except ImportError:
1189
  raise HTTPException(status_code=503, detail="SSE streaming not available")
1190
 
1191
+ async def gen_with_disconnect_check():
1192
+ """Generator that checks for client disconnect and auto-cancels"""
1193
+ try:
1194
+ async for event in bus.stream(job_id):
1195
+ # Check if client disconnected
1196
+ if await request.is_disconnected():
1197
+ # Auto-cancel job on disconnect (optional but recommended)
1198
+ cancel_job(job_id)
1199
+ if job_id in _cancellation_tokens:
1200
+ _cancellation_tokens[job_id].set()
1201
+ async with _jobs_lock:
1202
+ if job_id in _jobs:
1203
+ _jobs[job_id]["status"] = "cancelled"
1204
+ _jobs[job_id]["message"] = "Client disconnected"
1205
+ break
1206
+ yield event
1207
+ except asyncio.CancelledError:
1208
+ # Stream was cancelled
1209
+ cancel_job(job_id)
1210
+ if job_id in _cancellation_tokens:
1211
+ _cancellation_tokens[job_id].set()
1212
+
1213
  return StreamingResponse(
1214
+ gen_with_disconnect_check(),
1215
  media_type="text/event-stream",
1216
  headers={
1217
  "Cache-Control": "no-cache",
 
1239
 
1240
 
1241
@app.post("/jobs/{job_id}/cancel")
async def cancel_job_endpoint(job_id: str):
    """Cancel a running OCR job (cooperative cancellation with StoppingCriteria).

    Signals cancellation through BOTH mechanisms — the cancel_registry flag
    and the legacy _cancellation_tokens event — so workers polling either
    one observe the request. Returns 404 for unknown jobs and a no-op
    success payload for jobs that already finished.
    """
    async with _jobs_lock:
        if job_id not in _jobs:
            raise HTTPException(status_code=404, detail="Job not found")

        job = _jobs[job_id]

        # Already finished? Nothing to cancel.
        if job["status"] in ("completed", "failed", "cancelled"):
            return {"ok": True, "message": f"Job already {job['status']}", "job_id": job_id}

        # Set cancellation flag (cancel_registry may be a no-op fallback,
        # so the legacy token is always set as well).
        cancel_job(job_id)
        if job_id in _cancellation_tokens:
            _cancellation_tokens[job_id].set()

        # NOTE(review): the status is marked terminal immediately even though
        # the worker may still be winding down cooperatively; the message
        # stays "requested" so clients can tell the difference.
        job["status"] = "cancelled"
        job["message"] = "Cancellation requested..."
        job.setdefault("progress", 0.0)  # make sure clients always see a progress key

        # Best-effort notification to any open SSE stream.
        # NOTE(review): awaited while _jobs_lock is held — confirm bus.error
        # never tries to take this lock.
        try:
            from progress_bus import bus
            await bus.error(job_id, "Job cancelled by user")
        except ImportError:
            pass

        return {"ok": True, "message": "Cancellation requested", "job_id": job_id}
1271
 
1272
 
1273
  @app.post("/split")