Spaces:
Running
Running
Upload api.py with huggingface_hub
Browse files
api.py
CHANGED
|
@@ -15,6 +15,7 @@ import math
|
|
| 15 |
import contextlib
|
| 16 |
import requests
|
| 17 |
import unicodedata
|
|
|
|
| 18 |
from huggingface_hub import login
|
| 19 |
|
| 20 |
app = FastAPI()
|
|
@@ -207,7 +208,6 @@ async def load_resources():
|
|
| 207 |
print("🧠 ALLOCATING MEMORY FOR LOCAL ENGLISH MODEL")
|
| 208 |
print("=" * 60)
|
| 209 |
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
| 210 |
-
import time
|
| 211 |
|
| 212 |
eng_model_path = "trocr-large-english"
|
| 213 |
if os.path.exists(eng_model_path):
|
|
@@ -371,6 +371,7 @@ def beam_search_decode(model, images, k=3, max_len=25):
|
|
| 371 |
beams = [(torch.full((1, 1), BOS_VAL, dtype=torch.long, device=device), 0.0, [])]
|
| 372 |
|
| 373 |
for step_idx in range(max_len):
|
|
|
|
| 374 |
candidates = []
|
| 375 |
for seq, score, history in beams:
|
| 376 |
# Skip beams that reached EOS
|
|
@@ -408,6 +409,17 @@ def beam_search_decode(model, images, k=3, max_len=25):
|
|
| 408 |
# Sort by cumulative score and prune to keep top K beams
|
| 409 |
beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
|
| 410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
# Stop if all surviving beams have reached EOS
|
| 412 |
if all(b[0][0, -1].item() == EOS_VAL for b in beams):
|
| 413 |
break
|
|
@@ -533,12 +545,15 @@ async def predict_ocr(file: UploadFile = File(...), lang: str = "hindi"):
|
|
| 533 |
pixel_values, debug_b64 = preprocess_english(image_bytes)
|
| 534 |
|
| 535 |
# Local Inference
|
|
|
|
| 536 |
with torch.no_grad():
|
| 537 |
generated_ids = model_eng.generate(pixel_values)
|
| 538 |
prediction = processor_eng.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
|
| 539 |
|
| 540 |
final_prediction = prediction
|
| 541 |
-
|
|
|
|
| 542 |
else:
|
| 543 |
if model is None: return {"error": "Hindi model not loaded"}
|
| 544 |
|
|
|
|
| 15 |
import contextlib
|
| 16 |
import requests
|
| 17 |
import unicodedata
|
| 18 |
+
import time
|
| 19 |
from huggingface_hub import login
|
| 20 |
|
| 21 |
app = FastAPI()
|
|
|
|
| 208 |
print("🧠 ALLOCATING MEMORY FOR LOCAL ENGLISH MODEL")
|
| 209 |
print("=" * 60)
|
| 210 |
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
|
|
|
|
| 211 |
|
| 212 |
eng_model_path = "trocr-large-english"
|
| 213 |
if os.path.exists(eng_model_path):
|
|
|
|
| 371 |
beams = [(torch.full((1, 1), BOS_VAL, dtype=torch.long, device=device), 0.0, [])]
|
| 372 |
|
| 373 |
for step_idx in range(max_len):
|
| 374 |
+
step_start_time = time.time()
|
| 375 |
candidates = []
|
| 376 |
for seq, score, history in beams:
|
| 377 |
# Skip beams that reached EOS
|
|
|
|
| 409 |
# Sort by cumulative score and prune to keep top K beams
|
| 410 |
beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
|
| 411 |
|
| 412 |
+
# Calculate step duration in seconds
|
| 413 |
+
step_duration_sec = time.time() - step_start_time
|
| 414 |
+
|
| 415 |
+
# Update history with duration for each beam
|
| 416 |
+
new_beams = []
|
| 417 |
+
for seq, score, history in beams:
|
| 418 |
+
if history:
|
| 419 |
+
history[-1]["duration_sec"] = round(step_duration_sec, 4)
|
| 420 |
+
new_beams.append((seq, score, history))
|
| 421 |
+
beams = new_beams
|
| 422 |
+
|
| 423 |
# Stop if all surviving beams have reached EOS
|
| 424 |
if all(b[0][0, -1].item() == EOS_VAL for b in beams):
|
| 425 |
break
|
|
|
|
| 545 |
pixel_values, debug_b64 = preprocess_english(image_bytes)
|
| 546 |
|
| 547 |
# Local Inference
|
| 548 |
+
start_eng = time.time()
|
| 549 |
with torch.no_grad():
|
| 550 |
generated_ids = model_eng.generate(pixel_values)
|
| 551 |
prediction = processor_eng.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 552 |
+
eng_duration_sec = time.time() - start_eng
|
| 553 |
|
| 554 |
final_prediction = prediction
|
| 555 |
+
inference_steps = [{"word": prediction, "steps": [{"step": "Total", "top_candidates": [{"char": "Full Sequence", "confidence": 1.0}], "duration_sec": round(eng_duration_sec, 3)}]}]
|
| 556 |
+
print(f"ROUTING TO '{final_lang}': Local Inference -> FINAL: '{final_prediction}' ({eng_duration_sec:.3f}s)")
|
| 557 |
else:
|
| 558 |
if model is None: return {"error": "Hindi model not loaded"}
|
| 559 |
|