Update app.py

app.py CHANGED
@@ -108,7 +108,7 @@ else:
 
 
 @app.get("/")
-def root():
+async def root():
     return {
         "success": True,
         "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).",
@@ -117,7 +117,7 @@ def root():
 
 
 @app.get("/health")
-def health():
+async def health():
    if db_init_error:
        return {
            "success": False,
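Both handlers above become async def; in FastAPI that moves them off the worker threadpool and onto the event loop, which is safe here since neither does blocking work. A quick smoke test of the converted endpoints, as a sketch (assumes the FastAPI instance in app.py is importable as app, and httpx is installed for TestClient):

from fastapi.testclient import TestClient

from app import app  # the FastAPI(title="Image to Text API") instance

client = TestClient(app)
print(client.get("/").json())        # {"success": True, "message": "Use POST /generate-caption ...", "data": None}
print(client.get("/health").json())  # reports db_init_error (if any) plus the configured model ids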
@@ -304,7 +304,10 @@ def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
     except Exception as exc:
         if "Mismatch in `image` token count" not in str(exc):
             raise AppError("Failed to preprocess image for captioning.", 422) from exc
-        inputs = _build_inputs(CAPTION_RETRY_PROMPT)
+        try:
+            inputs = _build_inputs(CAPTION_RETRY_PROMPT)
+        except Exception as retry_exc:
+            raise AppError("Failed to preprocess image during retry.", 422) from retry_exc
 
     inputs = {k: v.to(model_device) for k, v in inputs.items()}
 
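This hunk hardens the fallback path: previously a failure while rebuilding inputs with CAPTION_RETRY_PROMPT escaped as an unhandled exception; now it surfaces as a clean 422 AppError. The same retry shape in isolation, as a minimal sketch (build_inputs and the prompt arguments stand in for the app's _build_inputs and prompt constants):

class AppError(Exception):
    # Mirrors the app's AppError: a message plus an HTTP status code.
    def __init__(self, message: str, status_code: int = 400):
        super().__init__(message)
        self.message = message
        self.status_code = status_code

def build_with_retry(build_inputs, primary_prompt: str, retry_prompt: str):
    try:
        return build_inputs(primary_prompt)
    except Exception as exc:
        # Only the known token-count mismatch triggers the retry prompt.
        if "Mismatch in `image` token count" not in str(exc):
            raise AppError("Failed to preprocess image for captioning.", 422) from exc
        try:
            return build_inputs(retry_prompt)
        except Exception as retry_exc:
            raise AppError("Failed to preprocess image during retry.", 422) from retry_exc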
@@ -313,13 +316,15 @@ def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
             outputs = runtime_model.generate(
                 **inputs,
                 max_new_tokens=MAX_NEW_TOKENS,
-                do_sample=False,
-                num_beams=1,
+                do_sample=True,
+                top_p=0.9,
+                temperature=0.7,
+                repetition_penalty=1.2,
             )
+        decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip()
     except Exception as exc:
         raise AppError("Caption generation failed.", 500) from exc
 
-    decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip()
     caption = decoded.split("assistant")[-1].lstrip(":\n ").strip()
     return _finalize_caption(caption)
 
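This hunk swaps greedy decoding (do_sample=False, num_beams=1) for nucleus sampling with a repetition penalty, and moves the decode call inside the try so decoding errors are also reported as "Caption generation failed." The flags are standard transformers generate() arguments; an illustrative, self-contained sketch with gpt2 as a stand-in for the app's caption model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The scene outside the stadium", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=40,
        do_sample=True,          # sample instead of taking the argmax token
        top_p=0.9,               # nucleus sampling: smallest token set covering 90% probability
        temperature=0.7,         # sharpen the distribution before sampling (values < 1 sharpen)
        repetition_penalty=1.2,  # down-weight tokens that already appeared
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Sampled output varies from run to run, which appears to be the point of the change: deterministic greedy decoding was presumably producing flat or repetitive captions.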
@@ -383,7 +388,11 @@ async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
         if upload.content_type and not upload.content_type.startswith("image/"):
             raise AppError("All uploaded files must be images.", 400)
 
-        file_bytes = await upload.read()
+        try:
+            file_bytes = await upload.read()
+        except Exception as exc:
+            raise AppError("Failed to read uploaded file content.", 400) from exc
+
         if not file_bytes:
             raise AppError("One of the uploaded images is empty.", 400)
 
@@ -400,22 +409,35 @@ async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
     return parsed_images
 
 
-@app.post("/generate-caption-summary")
-async def generate_caption_summary(request: Request):
+@app.post("/generate-caption")
+async def generate_caption(request: Request):
     _ensure_db_ready()
     images = await _parse_images(request)
 
     image_captions = []
     for filename, image in images:
-        caption = generate_caption_text_safe(image)
-        if not caption:
-            raise AppError("Caption generation produced empty text.", 500)
-        image_captions.append({"filename": filename, "caption": caption})
+        try:
+            caption = generate_caption_text_safe(image)
+            if not caption:
+                raise AppError(f"Caption generation produced empty text for {filename}.", 500)
+            image_captions.append({"filename": filename, "caption": caption})
+        except AppError:
+            raise
+        except Exception as exc:
+            logger.error(f"Error generating caption for {filename}: {exc}")
+            raise AppError(f"Failed to generate caption for {filename}.", 500) from exc
 
     caption_texts = [x["caption"] for x in image_captions]
-    caption = summarize_captions(caption_texts)
-    if not caption:
-        raise AppError("Caption summarization produced empty text.", 500)
+
+    try:
+        caption = summarize_captions(caption_texts)
+        if not caption:
+            raise AppError("Caption summarization produced empty text.", 500)
+    except AppError:
+        raise
+    except Exception as exc:
+        logger.error(f"Summarization error: {exc}")
+        raise AppError("Failed to summarize captions.", 500) from exc
 
     mongo_payload = {
         "caption": caption,
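With the endpoint now exposed as /generate-caption and per-file error handling added, a client call looks like this (base URL and filenames are assumptions; the form keys 'file' and 'files' come from the app's own root message):

import requests

# Up to MAX_IMAGES (5) images under the "files" key; a single image may use "file" instead.
files = [
    ("files", ("photo1.jpg", open("photo1.jpg", "rb"), "image/jpeg")),
    ("files", ("photo2.jpg", open("photo2.jpg", "rb"), "image/jpeg")),
]
resp = requests.post("http://localhost:7860/generate-caption", files=files)
print(resp.status_code, resp.json())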
@@ -426,523 +448,16 @@ async def generate_caption_summary(request: Request):
         "created_at": datetime.now(timezone.utc),
     }
 
-    audio_file_id = insert_record(caption_collection, mongo_payload)
+    try:
+        audio_file_id = insert_record(caption_collection, mongo_payload)
+    except AppError:
+        raise
+    except Exception as exc:
+        logger.error(f"Database insert error: {exc}")
+        raise AppError("Failed to save record to database.", 503) from exc
 
     response_data = {**mongo_payload, "audio_file_id": audio_file_id}
     response_data.pop("_id", None)  # Remove ObjectId as it is not JSON serializable
     response_data["created_at"] = response_data["created_at"].isoformat()
 
-    return ok("Caption generated successfully.", response_data)
-
-
-@app.post("/generate-caption-collage")
-async def generate_caption_collage(request: Request):
-    _ensure_db_ready()
-    images = await _parse_images(request)
-
-    # Create collage (horizontal strip, resized to height 512 for consistency)
-    resized_images = []
-    target_height = 512
-    for _, img in images:
-        aspect_ratio = img.width / img.height
-        new_width = int(target_height * aspect_ratio)
-        resized_images.append(img.resize((new_width, target_height), Image.Resampling.LANCZOS))
-
-    total_width = sum(img.width for img in resized_images)
-    collage = Image.new("RGB", (total_width, target_height))
-    x_offset = 0
-    for img in resized_images:
-        collage.paste(img, (x_offset, 0))
-        x_offset += img.width
-
-    caption = generate_caption_text_safe(collage)
-    if not caption:
-        raise AppError("Collage caption generation produced empty text.", 500)
-
-    # For database storage, we list source filenames but the 'image_captions'
-    # will just contain the single collage caption to avoid confusion.
-    source_filenames = [fname for fname, _ in images]
-
-    mongo_payload = {
-        "caption": caption,
-        "source_filenames": source_filenames,
-        "image_captions": [{"filename": "collage", "caption": caption}],
-        "images_count": len(images),
-        "is_summarized": False,  # It's a direct caption of a collage
-        "created_at": datetime.now(timezone.utc),
-    }
-
-    audio_file_id = insert_record(caption_collection, mongo_payload)
-
-    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
-    response_data.pop("_id", None)
-    response_data["created_at"] = response_data["created_at"].isoformat()
-
-    return ok("Collage caption generated successfully.", response_data)
-
-
-@app.post("/generate-caption-context")
-async def generate_caption_context(request: Request):
-    _ensure_db_ready()
-    images = await _parse_images(request)
-
-    image_captions = []
-    previous_context = ""
-
-    for i, (filename, image) in enumerate(images):
-        prompt = CAPTION_PROMPT
-        if i > 0 and previous_context:
-            prompt = f"Context from previous image: {previous_context}. {CAPTION_PROMPT}"
-
-        caption = generate_caption_text_safe(image, prompt=prompt)
-        if not caption:
-            caption = "No caption generated."
-
-        image_captions.append({"filename": filename, "caption": caption})
-        previous_context = caption
-
-    # Combine captions for the main 'caption' field
-    full_text = " ".join([ic["caption"] for ic in image_captions])
-
-    mongo_payload = {
-        "caption": full_text,
-        "source_filenames": [fname for fname, _ in images],
-        "image_captions": image_captions,
-        "images_count": len(images),
-        "is_summarized": False,
-        "created_at": datetime.now(timezone.utc),
-    }
-
-    audio_file_id = insert_record(caption_collection, mongo_payload)
-
-    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
-    response_data.pop("_id", None)
-    response_data["created_at"] = response_data["created_at"].isoformat()
-
-    return ok("Contextual captions generated successfully.", response_data)
-# import io
-# import logging
-# import os
-# import re
-# import threading
-# from datetime import datetime, timezone
-
-# # Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
-# _omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
-# if not _omp_threads.isdigit() or int(_omp_threads) < 1:
-#     os.environ["OMP_NUM_THREADS"] = "8"
-
-# import torch
-# from dotenv import load_dotenv
-# from fastapi import FastAPI, Request, UploadFile
-# from fastapi.exceptions import RequestValidationError
-# from fastapi.responses import JSONResponse
-# from PIL import Image, UnidentifiedImageError
-# from pymongo import MongoClient
-# from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
-# from starlette.datastructures import UploadFile as StarletteUploadFile
-# from transformers import (
-#     AutoModelForImageTextToText,
-#     AutoModelForSeq2SeqLM,
-#     AutoProcessor,
-#     AutoTokenizer,
-# )
-
-
-# load_dotenv()
-
-# CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
-# SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn")
-# DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
-# DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
-# MAX_NEW_TOKENS = 120
-# MAX_IMAGES = 5
-# MONGO_URI = (os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or "").strip().strip('"').strip("'")
-# MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech")
-
-# CAPTION_PROMPT = (
-#     "Act as a professional news reporter delivering a live on-scene report in real time. "
-#     "Speak naturally, as if you are addressing viewers who are watching this unfold right now. "
-#     "Describe the scene in 3 to 4 complete, vivid sentences. "
-#     "Mention what is happening, the surrounding environment, and the overall mood, "
-#     "and convey the urgency or emotion of the moment when appropriate."
-# )
-# CAPTION_RETRY_PROMPT = (
-#     "Describe this image in 2 to 3 complete sentences. "
-#     "Mention the main subject, action, environment, and mood."
-# )
-# CAPTION_MIN_SENTENCES = 3
-# CAPTION_MAX_SENTENCES = 4
-# PROCESSOR_MAX_LENGTH = 8192
-
-# logger = logging.getLogger(__name__)
-
-
-# def ok(message: str, data):
-#     return JSONResponse(
-#         status_code=200,
-#         content={"success": True, "message": message, "data": data},
-#     )
-
-
-# def fail(message: str, status_code: int = 400):
-#     return JSONResponse(
-#         status_code=status_code,
-#         content={"success": False, "message": message, "data": None},
-#     )
-
-
-# class AppError(Exception):
-#     def __init__(self, message: str, status_code: int = 400):
-#         super().__init__(message)
-#         self.message = message
-#         self.status_code = status_code
-
-
-# torch.set_num_threads(8)
-# _caption_model = None
-# _caption_processor = None
-# _caption_lock = threading.Lock()
-# _caption_force_cpu = False
-# _summarizer_model = None
-# _summarizer_tokenizer = None
-# _summarizer_lock = threading.Lock()
-
-# app = FastAPI(title="Image to Text API")
-
-# mongo_client = None
-# mongo_db = None
-# caption_collection = None
-# db_init_error = None
-
-# if not MONGO_URI:
-#     db_init_error = "MONGO_URI (or MONGODB_URI) is not set."
-# else:
-#     try:
-#         mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
-#         mongo_client.admin.command("ping")
-#         mongo_db = mongo_client[MONGO_DB_NAME]
-#         caption_collection = mongo_db["captions"]
-#     except ServerSelectionTimeoutError:
-#         db_init_error = "Unable to connect to MongoDB (timeout)."
-#     except PyMongoError as exc:
-#         db_init_error = "Unable to initialize MongoDB: {}".format(exc)
-
-
-# @app.get("/")
-# def root():
-#     return {
-#         "success": True,
-#         "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).",
-#         "data": None,
-#     }
-
-
-# @app.get("/health")
-# def health():
-#     if db_init_error:
-#         return {
-#             "success": False,
-#             "message": db_init_error,
-#             "data": {
-#                 "caption_model_id": CAPTION_MODEL_ID,
-#                 "summarizer_model_id": SUMMARIZER_MODEL_ID,
-#             },
-#         }
-#     return {
-#         "success": True,
-#         "message": "ok",
-#         "data": {
-#             "caption_model_id": CAPTION_MODEL_ID,
-#             "summarizer_model_id": SUMMARIZER_MODEL_ID,
-#         },
-#     }
-
-
-# @app.on_event("startup")
-# async def preload_runtime_models():
-#     if os.getenv("PRELOAD_MODELS", "1").strip().lower() in {"0", "false", "no"}:
-#         logger.info("Model preloading disabled via PRELOAD_MODELS.")
-#         return
-#     try:
-#         _get_caption_runtime()
-#         logger.info("Caption model preloaded successfully.")
-#     except Exception as exc:
-#         logger.warning("Caption model preload failed: %s", exc)
-#     try:
-#         _get_summarizer_runtime()
-#         logger.info("Summarizer model preloaded successfully.")
-#     except Exception as exc:
-#         logger.warning("Summarizer model preload failed: %s", exc)
-
-
-# @app.exception_handler(AppError)
-# async def app_error_handler(_, exc: AppError):
-#     return fail(exc.message, exc.status_code)
-
-
-# @app.exception_handler(RequestValidationError)
-# async def validation_error_handler(_, exc: RequestValidationError):
-#     return fail("Invalid request payload.", 422)
-
-
-# @app.exception_handler(Exception)
-# async def unhandled_error_handler(_, exc: Exception):
-#     logger.exception("Unhandled server error: %s", exc)
-#     return fail("Internal server error.", 500)
-
-
-# def _ensure_db_ready():
-#     if db_init_error:
-#         raise AppError(db_init_error, 503)
-
-
-# def _finalize_caption(raw_text: str, max_sentences: int = CAPTION_MAX_SENTENCES) -> str:
-#     text = " ".join(raw_text.split()).strip()
-#     if not text:
-#         return ""
-
-#     sentences = re.findall(r"[^.!?]+[.!?]", text)
-#     sentences = [s.strip() for s in sentences if s.strip()]
-
-#     if len(sentences) >= CAPTION_MIN_SENTENCES:
-#         return " ".join(sentences[:max_sentences]).strip()
-
-#     if text and text[-1] not in ".!?":
-#         text = re.sub(r"[,:;\-]\s*[^,:;\-]*$", "", text).strip()
-#     return text
-
-
-# def _get_caption_runtime():
-#     global _caption_model, _caption_processor, _caption_force_cpu
-#     if _caption_model is not None and _caption_processor is not None:
-#         return _caption_model, _caption_processor
-
-#     with _caption_lock:
-#         if _caption_model is None or _caption_processor is None:
-#             device = "cpu" if _caption_force_cpu else DEVICE
-#             dtype = torch.float32 if device == "cpu" else DTYPE
-#             try:
-#                 loaded_model = AutoModelForImageTextToText.from_pretrained(
-#                     CAPTION_MODEL_ID,
-#                     trust_remote_code=True,
-#                     torch_dtype=dtype,
-#                     low_cpu_mem_usage=True,
-#                 ).to(device)
-#                 loaded_processor = AutoProcessor.from_pretrained(
-#                     CAPTION_MODEL_ID,
-#                     trust_remote_code=True,
-#                 )
-#             except Exception as exc:
-#                 raise AppError("Failed to load caption model.", 503) from exc
-#             loaded_model.eval()
-#             _caption_model = loaded_model
-#             _caption_processor = loaded_processor
-
-#     return _caption_model, _caption_processor
-
-
-# def _get_summarizer_runtime():
-#     global _summarizer_model, _summarizer_tokenizer
-#     if _summarizer_model is not None and _summarizer_tokenizer is not None:
-#         return _summarizer_model, _summarizer_tokenizer
-
-#     with _summarizer_lock:
-#         if _summarizer_model is None or _summarizer_tokenizer is None:
-#             try:
-#                 tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
-#                 model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
-#             except Exception as exc:
-#                 raise AppError("Failed to load summarization model.", 503) from exc
-#             model.eval()
-#             _summarizer_tokenizer = tokenizer
-#             _summarizer_model = model
-
-#     return _summarizer_model, _summarizer_tokenizer
-
-
-# def summarize_captions(captions: list[str]) -> str:
-#     if not captions:
-#         return ""
-#     if len(captions) == 1:
-#         return captions[0]
-
-#     model, tokenizer = _get_summarizer_runtime()
-#     combined = " ".join(c.strip() for c in captions if c and c.strip())
-#     if not combined:
-#         return ""
-
-#     try:
-#         inputs = tokenizer(
-#             combined,
-#             max_length=1024,
-#             truncation=True,
-#             return_tensors="pt",
-#         )
-#         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-#         with torch.no_grad():
-#             output_ids = model.generate(
-#                 **inputs,
-#                 max_length=300,
-#                 min_length=50,
-#                 length_penalty=2.0,
-#                 num_beams=4,
-#                 early_stopping=True,
-#             )
-#         summary = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
-#     except Exception as exc:
-#         raise AppError("Failed to summarize captions.", 500) from exc
-
-#     return _finalize_caption(summary, max_sentences=10)
-
-
-# def generate_caption_text(image: Image.Image) -> str:
-#     runtime_model, runtime_processor = _get_caption_runtime()
-#     model_device = str(next(runtime_model.parameters()).device)
-
-#     def _build_inputs(prompt: str):
-#         messages = [
-#             {
-#                 "role": "user",
-#                 "content": [
-#                     {"type": "image"},
-#                     {"type": "text", "text": prompt},
-#                 ],
-#             }
-#         ]
-#         text = runtime_processor.apply_chat_template(
-#             messages, tokenize=False, add_generation_prompt=True
-#         )
-#         return runtime_processor(
-#             text=text,
-#             images=image,
-#             return_tensors="pt",
-#             truncation=False,
-#             max_length=PROCESSOR_MAX_LENGTH,
-#         )
-
-#     try:
-#         inputs = _build_inputs(CAPTION_PROMPT)
-#     except Exception as exc:
-#         if "Mismatch in `image` token count" not in str(exc):
-#             raise AppError("Failed to preprocess image for captioning.", 422) from exc
-#         inputs = _build_inputs(CAPTION_RETRY_PROMPT)
-
-#     inputs = {k: v.to(model_device) for k, v in inputs.items()}
-
-#     try:
-#         with torch.no_grad():
-#             outputs = runtime_model.generate(
-#                 **inputs,
-#                 max_new_tokens=MAX_NEW_TOKENS,
-#                 do_sample=False,
-#                 num_beams=1,
-#             )
-#     except Exception as exc:
-#         raise AppError("Caption generation failed.", 500) from exc
-
-#     decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip()
-#     caption = decoded.split("assistant")[-1].lstrip(":\n ").strip()
-#     return _finalize_caption(caption)
-
-
-# def generate_caption_text_safe(image: Image.Image) -> str:
-#     global _caption_model, _caption_processor, _caption_force_cpu
-#     try:
-#         return generate_caption_text(image)
-#     except Exception as exc:
-#         msg = str(exc)
-#         if "CUDA error" not in msg and "device-side assert" not in msg:
-#             raise
-
-#     with _caption_lock:
-#         _caption_force_cpu = True
-#         _caption_model = None
-#         _caption_processor = None
-
-#     if torch.cuda.is_available():
-#         try:
-#             torch.cuda.empty_cache()
-#         except Exception:
-#             pass
-
-#     return generate_caption_text(image)
-
-
-# def insert_record(collection, payload: dict) -> str:
-#     try:
-#         result = collection.insert_one(payload)
-#         return str(result.inserted_id)
-#     except PyMongoError as exc:
-#         raise AppError("MongoDB insert failed.", 503) from exc
-
-
-# @app.post("/generate-caption")
-# async def generate_caption(request: Request):
-#     _ensure_db_ready()
-
-#     try:
-#         form = await request.form()
-#     except Exception as exc:
-#         raise AppError("Invalid request payload.", 422) from exc
-
-#     uploads: list[UploadFile | StarletteUploadFile] = []
-#     for key in ("files", "files[]", "file"):
-#         for value in form.getlist(key):
-#             if isinstance(value, (UploadFile, StarletteUploadFile)):
-#                 uploads.append(value)
-
-#     # Fallback for clients that send non-standard multipart keys.
-#     if not uploads:
-#         for _, value in form.multi_items():
-#             if isinstance(value, (UploadFile, StarletteUploadFile)):
-#                 uploads.append(value)
-
-#     if not uploads:
-#         raise AppError("At least one image is required.", 400)
-#     if len(uploads) > MAX_IMAGES:
-#         raise AppError("You can upload a maximum of 5 images.", 400)
-
-#     image_captions = []
-#     for upload in uploads:
-#         if upload.content_type and not upload.content_type.startswith("image/"):
-#             raise AppError("All uploaded files must be images.", 400)
-
-#         file_bytes = await upload.read()
-#         if not file_bytes:
-#             raise AppError("One of the uploaded images is empty.", 400)
-
-#         try:
-#             image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
-#         except UnidentifiedImageError as exc:
-#             raise AppError("One of the uploaded files is not a valid image.", 400) from exc
-#         except OSError as exc:
-#             raise AppError("Unable to read one of the uploaded images.", 400) from exc
-
-#         caption = generate_caption_text_safe(image)
-#         if not caption:
-#             raise AppError("Caption generation produced empty text.", 500)
-
-#         image_captions.append({"filename": upload.filename, "caption": caption})
-
-#     caption_texts = [x["caption"] for x in image_captions]
-#     caption = summarize_captions(caption_texts)
-#     if not caption:
-#         raise AppError("Caption summarization produced empty text.", 500)
-
-#     mongo_payload = {
-#         "caption": caption,
-#         "source_filenames": [item["filename"] for item in image_captions],
-#         "image_captions": image_captions,
-#         "images_count": len(image_captions),
-#         "is_summarized": len(image_captions) > 1,
-#         "created_at": datetime.now(timezone.utc),
-#     }
-
-#     audio_file_id = insert_record(caption_collection, mongo_payload)
-
-#     response_data = {**mongo_payload, "audio_file_id": audio_file_id}
-#     response_data.pop("_id", None)  # Remove ObjectId as it is not JSON serializable
-#     response_data["created_at"] = response_data["created_at"].isoformat()
-
-#     return ok("Caption generated successfully.", response_data)
+    return ok("Caption generated successfully.", response_data)
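Beyond wrapping the database insert, this final hunk removes the /generate-caption-collage and /generate-caption-context endpoints along with a large commented-out copy of an earlier version of the module. The surviving response_data lines exist because a raw MongoDB document cannot go straight into a JSON response: insert_one() adds an ObjectId under "_id", and datetime values are not JSON serializable either. A small demonstration (bson ships with pymongo):

import json
from datetime import datetime, timezone

from bson import ObjectId  # installed with pymongo

doc = {"_id": ObjectId(), "caption": "a street scene", "created_at": datetime.now(timezone.utc)}
try:
    json.dumps(doc)
except TypeError as exc:
    print(f"not serializable as-is: {exc}")

doc.pop("_id", None)                               # drop the ObjectId, as the endpoint does
doc["created_at"] = doc["created_at"].isoformat()  # datetime -> ISO-8601 string
print(json.dumps(doc))                             # now serializes cleanly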