LogicGoInfotechSpaces committed on
Commit
26a4e1b
·
verified ·
1 Parent(s): 4864dad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +537 -15
app.py CHANGED
@@ -261,8 +261,8 @@ def summarize_captions(captions: list[str]) -> str:
261
  with torch.no_grad():
262
  output_ids = model.generate(
263
  **inputs,
264
- max_length=300,
265
- min_length=50,
266
  length_penalty=2.0,
267
  num_beams=4,
268
  early_stopping=True,
@@ -274,7 +274,7 @@ def summarize_captions(captions: list[str]) -> str:
274
  return _finalize_caption(summary, max_sentences=10)
275
 
276
 
277
- def generate_caption_text(image: Image.Image) -> str:
278
  runtime_model, runtime_processor = _get_caption_runtime()
279
  model_device = str(next(runtime_model.parameters()).device)
280
 
@@ -300,7 +300,7 @@ def generate_caption_text(image: Image.Image) -> str:
300
  )
301
 
302
  try:
303
- inputs = _build_inputs(CAPTION_PROMPT)
304
  except Exception as exc:
305
  if "Mismatch in `image` token count" not in str(exc):
306
  raise AppError("Failed to preprocess image for captioning.", 422) from exc
@@ -324,10 +324,10 @@ def generate_caption_text(image: Image.Image) -> str:
324
  return _finalize_caption(caption)
325
 
326
 
327
- def generate_caption_text_safe(image: Image.Image) -> str:
328
  global _caption_model, _caption_processor, _caption_force_cpu
329
  try:
330
- return generate_caption_text(image)
331
  except Exception as exc:
332
  msg = str(exc)
333
  if "CUDA error" not in msg and "device-side assert" not in msg:
@@ -344,7 +344,7 @@ def generate_caption_text_safe(image: Image.Image) -> str:
344
  except Exception:
345
  pass
346
 
347
- return generate_caption_text(image)
348
 
349
 
350
  def insert_record(collection, payload: dict) -> str:
@@ -355,10 +355,7 @@ def insert_record(collection, payload: dict) -> str:
355
  raise AppError("MongoDB insert failed.", 503) from exc
356
 
357
 
358
- @app.post("/generate-caption")
359
- async def generate_caption(request: Request):
360
- _ensure_db_ready()
361
-
362
  try:
363
  form = await request.form()
364
  except Exception as exc:
@@ -381,8 +378,8 @@ async def generate_caption(request: Request):
381
  if len(uploads) > MAX_IMAGES:
382
  raise AppError("You can upload a maximum of 5 images.", 400)
383
 
384
- image_captions = []
385
- for upload in uploads:
386
  if upload.content_type and not upload.content_type.startswith("image/"):
387
  raise AppError("All uploaded files must be images.", 400)
388
 
@@ -397,11 +394,23 @@ async def generate_caption(request: Request):
397
  except OSError as exc:
398
  raise AppError("Unable to read one of the uploaded images.", 400) from exc
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  caption = generate_caption_text_safe(image)
401
  if not caption:
402
  raise AppError("Caption generation produced empty text.", 500)
403
-
404
- image_captions.append({"filename": upload.filename, "caption": caption})
405
 
406
  caption_texts = [x["caption"] for x in image_captions]
407
  caption = summarize_captions(caption_texts)
@@ -424,3 +433,516 @@ async def generate_caption(request: Request):
424
  response_data["created_at"] = response_data["created_at"].isoformat()
425
 
426
  return ok("Caption generated successfully.", response_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  with torch.no_grad():
262
  output_ids = model.generate(
263
  **inputs,
264
+ max_length=512,
265
+ min_length=100,
266
  length_penalty=2.0,
267
  num_beams=4,
268
  early_stopping=True,
 
274
  return _finalize_caption(summary, max_sentences=10)
275
 
276
 
277
+ def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
278
  runtime_model, runtime_processor = _get_caption_runtime()
279
  model_device = str(next(runtime_model.parameters()).device)
280
 
 
300
  )
301
 
302
  try:
303
+ inputs = _build_inputs(prompt)
304
  except Exception as exc:
305
  if "Mismatch in `image` token count" not in str(exc):
306
  raise AppError("Failed to preprocess image for captioning.", 422) from exc
 
324
  return _finalize_caption(caption)
325
 
326
 
327
+ def generate_caption_text_safe(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
328
  global _caption_model, _caption_processor, _caption_force_cpu
329
  try:
330
+ return generate_caption_text(image, prompt)
331
  except Exception as exc:
332
  msg = str(exc)
333
  if "CUDA error" not in msg and "device-side assert" not in msg:
 
344
  except Exception:
345
  pass
346
 
347
+ return generate_caption_text(image, prompt)
348
 
349
 
350
  def insert_record(collection, payload: dict) -> str:
 
355
  raise AppError("MongoDB insert failed.", 503) from exc
356
 
357
 
358
+ async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
 
 
 
359
  try:
360
  form = await request.form()
361
  except Exception as exc:
 
378
  if len(uploads) > MAX_IMAGES:
379
  raise AppError("You can upload a maximum of 5 images.", 400)
380
 
381
+ parsed_images = []
382
+ for i, upload in enumerate(uploads):
383
  if upload.content_type and not upload.content_type.startswith("image/"):
384
  raise AppError("All uploaded files must be images.", 400)
385
 
 
394
  except OSError as exc:
395
  raise AppError("Unable to read one of the uploaded images.", 400) from exc
396
 
397
+ filename = upload.filename or f"image_{i+1}"
398
+ parsed_images.append((filename, image))
399
+
400
+ return parsed_images
401
+
402
+
403
+ @app.post("/generate-caption-summary")
404
+ async def generate_caption_summary(request: Request):
405
+ _ensure_db_ready()
406
+ images = await _parse_images(request)
407
+
408
+ image_captions = []
409
+ for filename, image in images:
410
  caption = generate_caption_text_safe(image)
411
  if not caption:
412
  raise AppError("Caption generation produced empty text.", 500)
413
+ image_captions.append({"filename": filename, "caption": caption})
 
414
 
415
  caption_texts = [x["caption"] for x in image_captions]
416
  caption = summarize_captions(caption_texts)
 
433
  response_data["created_at"] = response_data["created_at"].isoformat()
434
 
435
  return ok("Caption generated successfully.", response_data)
436
+
437
+
438
@app.post("/generate-caption-collage")
async def generate_caption_collage(request: Request):
    """Caption every upload as one picture: stitch the images into a
    horizontal strip, run a single caption pass over the strip, persist
    the record, and return it."""
    _ensure_db_ready()
    images = await _parse_images(request)

    # Normalize all images to a common height so the strip lines up,
    # keeping each image's aspect ratio.
    STRIP_HEIGHT = 512
    scaled = []
    for _, source in images:
        width = int(STRIP_HEIGHT * (source.width / source.height))
        scaled.append(source.resize((width, STRIP_HEIGHT), Image.Resampling.LANCZOS))

    # Paste the scaled images side by side onto a single RGB canvas.
    strip_width = sum(part.width for part in scaled)
    collage = Image.new("RGB", (strip_width, STRIP_HEIGHT))
    cursor = 0
    for part in scaled:
        collage.paste(part, (cursor, 0))
        cursor += part.width

    caption = generate_caption_text_safe(collage)
    if not caption:
        raise AppError("Collage caption generation produced empty text.", 500)

    # Keep the original filenames for provenance, but expose a single
    # "collage" caption entry so consumers don't mistake this for
    # per-image captions.
    source_filenames = [fname for fname, _ in images]

    mongo_payload = {
        "caption": caption,
        "source_filenames": source_filenames,
        "image_captions": [{"filename": "collage", "caption": caption}],
        "images_count": len(images),
        "is_summarized": False,  # direct caption of a collage, not a summary
        "created_at": datetime.now(timezone.utc),
    }

    audio_file_id = insert_record(caption_collection, mongo_payload)

    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)  # insert_one injects a non-JSON ObjectId
    response_data["created_at"] = response_data["created_at"].isoformat()

    return ok("Collage caption generated successfully.", response_data)
482
+
483
+
484
@app.post("/generate-caption-context")
async def generate_caption_context(request: Request):
    """Caption uploads sequentially, feeding each caption into the next
    image's prompt so later captions can carry narrative context from
    earlier ones."""
    _ensure_db_ready()
    images = await _parse_images(request)

    image_captions = []
    previous_context = ""

    for idx, (filename, image) in enumerate(images):
        # The first image (or any image with no usable context) gets the
        # plain prompt; subsequent ones get the prior caption prepended.
        if idx > 0 and previous_context:
            prompt = f"Context from previous image: {previous_context}. {CAPTION_PROMPT}"
        else:
            prompt = CAPTION_PROMPT

        caption = generate_caption_text_safe(image, prompt=prompt)
        if not caption:
            caption = "No caption generated."

        image_captions.append({"filename": filename, "caption": caption})
        previous_context = caption

    # The record's top-level caption is all per-image captions joined.
    full_text = " ".join(entry["caption"] for entry in image_captions)

    mongo_payload = {
        "caption": full_text,
        "source_filenames": [fname for fname, _ in images],
        "image_captions": image_captions,
        "images_count": len(images),
        "is_summarized": False,
        "created_at": datetime.now(timezone.utc),
    }

    audio_file_id = insert_record(caption_collection, mongo_payload)

    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)  # drop the non-serializable ObjectId
    response_data["created_at"] = response_data["created_at"].isoformat()

    return ok("Contextual captions generated successfully.", response_data)
523
+ # import io
524
+ # import logging
525
+ # import os
526
+ # import re
527
+ # import threading
528
+ # from datetime import datetime, timezone
529
+
530
+ # # Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
531
+ # _omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
532
+ # if not _omp_threads.isdigit() or int(_omp_threads) < 1:
533
+ # os.environ["OMP_NUM_THREADS"] = "8"
534
+
535
+ # import torch
536
+ # from dotenv import load_dotenv
537
+ # from fastapi import FastAPI, Request, UploadFile
538
+ # from fastapi.exceptions import RequestValidationError
539
+ # from fastapi.responses import JSONResponse
540
+ # from PIL import Image, UnidentifiedImageError
541
+ # from pymongo import MongoClient
542
+ # from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
543
+ # from starlette.datastructures import UploadFile as StarletteUploadFile
544
+ # from transformers import (
545
+ # AutoModelForImageTextToText,
546
+ # AutoModelForSeq2SeqLM,
547
+ # AutoProcessor,
548
+ # AutoTokenizer,
549
+ # )
550
+
551
+
552
+ # load_dotenv()
553
+
554
+ # CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
555
+ # SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn")
556
+ # DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
557
+ # DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
558
+ # MAX_NEW_TOKENS = 120
559
+ # MAX_IMAGES = 5
560
+ # MONGO_URI = (os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or "").strip().strip('"').strip("'")
561
+ # MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech")
562
+
563
+ # CAPTION_PROMPT = (
564
+ # "Act as a professional news reporter delivering a live on-scene report in real time. "
565
+ # "Speak naturally, as if you are addressing viewers who are watching this unfold right now. "
566
+ # "Describe the scene in 3 to 4 complete, vivid sentences. "
567
+ # "Mention what is happening, the surrounding environment, and the overall mood, "
568
+ # "and convey the urgency or emotion of the moment when appropriate."
569
+ # )
570
+ # CAPTION_RETRY_PROMPT = (
571
+ # "Describe this image in 2 to 3 complete sentences. "
572
+ # "Mention the main subject, action, environment, and mood."
573
+ # )
574
+ # CAPTION_MIN_SENTENCES = 3
575
+ # CAPTION_MAX_SENTENCES = 4
576
+ # PROCESSOR_MAX_LENGTH = 8192
577
+
578
+ # logger = logging.getLogger(__name__)
579
+
580
+
581
+ # def ok(message: str, data):
582
+ # return JSONResponse(
583
+ # status_code=200,
584
+ # content={"success": True, "message": message, "data": data},
585
+ # )
586
+
587
+
588
+ # def fail(message: str, status_code: int = 400):
589
+ # return JSONResponse(
590
+ # status_code=status_code,
591
+ # content={"success": False, "message": message, "data": None},
592
+ # )
593
+
594
+
595
+ # class AppError(Exception):
596
+ # def __init__(self, message: str, status_code: int = 400):
597
+ # super().__init__(message)
598
+ # self.message = message
599
+ # self.status_code = status_code
600
+
601
+
602
+ # torch.set_num_threads(8)
603
+ # _caption_model = None
604
+ # _caption_processor = None
605
+ # _caption_lock = threading.Lock()
606
+ # _caption_force_cpu = False
607
+ # _summarizer_model = None
608
+ # _summarizer_tokenizer = None
609
+ # _summarizer_lock = threading.Lock()
610
+
611
+ # app = FastAPI(title="Image to Text API")
612
+
613
+ # mongo_client = None
614
+ # mongo_db = None
615
+ # caption_collection = None
616
+ # db_init_error = None
617
+
618
+ # if not MONGO_URI:
619
+ # db_init_error = "MONGO_URI (or MONGODB_URI) is not set."
620
+ # else:
621
+ # try:
622
+ # mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
623
+ # mongo_client.admin.command("ping")
624
+ # mongo_db = mongo_client[MONGO_DB_NAME]
625
+ # caption_collection = mongo_db["captions"]
626
+ # except ServerSelectionTimeoutError:
627
+ # db_init_error = "Unable to connect to MongoDB (timeout)."
628
+ # except PyMongoError as exc:
629
+ # db_init_error = "Unable to initialize MongoDB: {}".format(exc)
630
+
631
+
632
+ # @app.get("/")
633
+ # def root():
634
+ # return {
635
+ # "success": True,
636
+ # "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).",
637
+ # "data": None,
638
+ # }
639
+
640
+
641
+ # @app.get("/health")
642
+ # def health():
643
+ # if db_init_error:
644
+ # return {
645
+ # "success": False,
646
+ # "message": db_init_error,
647
+ # "data": {
648
+ # "caption_model_id": CAPTION_MODEL_ID,
649
+ # "summarizer_model_id": SUMMARIZER_MODEL_ID,
650
+ # },
651
+ # }
652
+ # return {
653
+ # "success": True,
654
+ # "message": "ok",
655
+ # "data": {
656
+ # "caption_model_id": CAPTION_MODEL_ID,
657
+ # "summarizer_model_id": SUMMARIZER_MODEL_ID,
658
+ # },
659
+ # }
660
+
661
+
662
+ # @app.on_event("startup")
663
+ # async def preload_runtime_models():
664
+ # if os.getenv("PRELOAD_MODELS", "1").strip().lower() in {"0", "false", "no"}:
665
+ # logger.info("Model preloading disabled via PRELOAD_MODELS.")
666
+ # return
667
+ # try:
668
+ # _get_caption_runtime()
669
+ # logger.info("Caption model preloaded successfully.")
670
+ # except Exception as exc:
671
+ # logger.warning("Caption model preload failed: %s", exc)
672
+ # try:
673
+ # _get_summarizer_runtime()
674
+ # logger.info("Summarizer model preloaded successfully.")
675
+ # except Exception as exc:
676
+ # logger.warning("Summarizer model preload failed: %s", exc)
677
+
678
+
679
+ # @app.exception_handler(AppError)
680
+ # async def app_error_handler(_, exc: AppError):
681
+ # return fail(exc.message, exc.status_code)
682
+
683
+
684
+ # @app.exception_handler(RequestValidationError)
685
+ # async def validation_error_handler(_, exc: RequestValidationError):
686
+ # return fail("Invalid request payload.", 422)
687
+
688
+
689
+ # @app.exception_handler(Exception)
690
+ # async def unhandled_error_handler(_, exc: Exception):
691
+ # logger.exception("Unhandled server error: %s", exc)
692
+ # return fail("Internal server error.", 500)
693
+
694
+
695
+ # def _ensure_db_ready():
696
+ # if db_init_error:
697
+ # raise AppError(db_init_error, 503)
698
+
699
+
700
+ # def _finalize_caption(raw_text: str, max_sentences: int = CAPTION_MAX_SENTENCES) -> str:
701
+ # text = " ".join(raw_text.split()).strip()
702
+ # if not text:
703
+ # return ""
704
+
705
+ # sentences = re.findall(r"[^.!?]+[.!?]", text)
706
+ # sentences = [s.strip() for s in sentences if s.strip()]
707
+
708
+ # if len(sentences) >= CAPTION_MIN_SENTENCES:
709
+ # return " ".join(sentences[:max_sentences]).strip()
710
+
711
+ # if text and text[-1] not in ".!?":
712
+ # text = re.sub(r"[,:;\-]\s*[^,:;\-]*$", "", text).strip()
713
+ # return text
714
+
715
+
716
+ # def _get_caption_runtime():
717
+ # global _caption_model, _caption_processor, _caption_force_cpu
718
+ # if _caption_model is not None and _caption_processor is not None:
719
+ # return _caption_model, _caption_processor
720
+
721
+ # with _caption_lock:
722
+ # if _caption_model is None or _caption_processor is None:
723
+ # device = "cpu" if _caption_force_cpu else DEVICE
724
+ # dtype = torch.float32 if device == "cpu" else DTYPE
725
+ # try:
726
+ # loaded_model = AutoModelForImageTextToText.from_pretrained(
727
+ # CAPTION_MODEL_ID,
728
+ # trust_remote_code=True,
729
+ # torch_dtype=dtype,
730
+ # low_cpu_mem_usage=True,
731
+ # ).to(device)
732
+ # loaded_processor = AutoProcessor.from_pretrained(
733
+ # CAPTION_MODEL_ID,
734
+ # trust_remote_code=True,
735
+ # )
736
+ # except Exception as exc:
737
+ # raise AppError("Failed to load caption model.", 503) from exc
738
+ # loaded_model.eval()
739
+ # _caption_model = loaded_model
740
+ # _caption_processor = loaded_processor
741
+
742
+ # return _caption_model, _caption_processor
743
+
744
+
745
+ # def _get_summarizer_runtime():
746
+ # global _summarizer_model, _summarizer_tokenizer
747
+ # if _summarizer_model is not None and _summarizer_tokenizer is not None:
748
+ # return _summarizer_model, _summarizer_tokenizer
749
+
750
+ # with _summarizer_lock:
751
+ # if _summarizer_model is None or _summarizer_tokenizer is None:
752
+ # try:
753
+ # tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
754
+ # model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
755
+ # except Exception as exc:
756
+ # raise AppError("Failed to load summarization model.", 503) from exc
757
+ # model.eval()
758
+ # _summarizer_tokenizer = tokenizer
759
+ # _summarizer_model = model
760
+
761
+ # return _summarizer_model, _summarizer_tokenizer
762
+
763
+
764
+ # def summarize_captions(captions: list[str]) -> str:
765
+ # if not captions:
766
+ # return ""
767
+ # if len(captions) == 1:
768
+ # return captions[0]
769
+
770
+ # model, tokenizer = _get_summarizer_runtime()
771
+ # combined = " ".join(c.strip() for c in captions if c and c.strip())
772
+ # if not combined:
773
+ # return ""
774
+
775
+ # try:
776
+ # inputs = tokenizer(
777
+ # combined,
778
+ # max_length=1024,
779
+ # truncation=True,
780
+ # return_tensors="pt",
781
+ # )
782
+ # inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
783
+ # with torch.no_grad():
784
+ # output_ids = model.generate(
785
+ # **inputs,
786
+ # max_length=300,
787
+ # min_length=50,
788
+ # length_penalty=2.0,
789
+ # num_beams=4,
790
+ # early_stopping=True,
791
+ # )
792
+ # summary = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
793
+ # except Exception as exc:
794
+ # raise AppError("Failed to summarize captions.", 500) from exc
795
+
796
+ # return _finalize_caption(summary, max_sentences=10)
797
+
798
+
799
+ # def generate_caption_text(image: Image.Image) -> str:
800
+ # runtime_model, runtime_processor = _get_caption_runtime()
801
+ # model_device = str(next(runtime_model.parameters()).device)
802
+
803
+ # def _build_inputs(prompt: str):
804
+ # messages = [
805
+ # {
806
+ # "role": "user",
807
+ # "content": [
808
+ # {"type": "image"},
809
+ # {"type": "text", "text": prompt},
810
+ # ],
811
+ # }
812
+ # ]
813
+ # text = runtime_processor.apply_chat_template(
814
+ # messages, tokenize=False, add_generation_prompt=True
815
+ # )
816
+ # return runtime_processor(
817
+ # text=text,
818
+ # images=image,
819
+ # return_tensors="pt",
820
+ # truncation=False,
821
+ # max_length=PROCESSOR_MAX_LENGTH,
822
+ # )
823
+
824
+ # try:
825
+ # inputs = _build_inputs(CAPTION_PROMPT)
826
+ # except Exception as exc:
827
+ # if "Mismatch in `image` token count" not in str(exc):
828
+ # raise AppError("Failed to preprocess image for captioning.", 422) from exc
829
+ # inputs = _build_inputs(CAPTION_RETRY_PROMPT)
830
+
831
+ # inputs = {k: v.to(model_device) for k, v in inputs.items()}
832
+
833
+ # try:
834
+ # with torch.no_grad():
835
+ # outputs = runtime_model.generate(
836
+ # **inputs,
837
+ # max_new_tokens=MAX_NEW_TOKENS,
838
+ # do_sample=False,
839
+ # num_beams=1,
840
+ # )
841
+ # except Exception as exc:
842
+ # raise AppError("Caption generation failed.", 500) from exc
843
+
844
+ # decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip()
845
+ # caption = decoded.split("assistant")[-1].lstrip(":\n ").strip()
846
+ # return _finalize_caption(caption)
847
+
848
+
849
+ # def generate_caption_text_safe(image: Image.Image) -> str:
850
+ # global _caption_model, _caption_processor, _caption_force_cpu
851
+ # try:
852
+ # return generate_caption_text(image)
853
+ # except Exception as exc:
854
+ # msg = str(exc)
855
+ # if "CUDA error" not in msg and "device-side assert" not in msg:
856
+ # raise
857
+
858
+ # with _caption_lock:
859
+ # _caption_force_cpu = True
860
+ # _caption_model = None
861
+ # _caption_processor = None
862
+
863
+ # if torch.cuda.is_available():
864
+ # try:
865
+ # torch.cuda.empty_cache()
866
+ # except Exception:
867
+ # pass
868
+
869
+ # return generate_caption_text(image)
870
+
871
+
872
+ # def insert_record(collection, payload: dict) -> str:
873
+ # try:
874
+ # result = collection.insert_one(payload)
875
+ # return str(result.inserted_id)
876
+ # except PyMongoError as exc:
877
+ # raise AppError("MongoDB insert failed.", 503) from exc
878
+
879
+
880
+ # @app.post("/generate-caption")
881
+ # async def generate_caption(request: Request):
882
+ # _ensure_db_ready()
883
+
884
+ # try:
885
+ # form = await request.form()
886
+ # except Exception as exc:
887
+ # raise AppError("Invalid request payload.", 422) from exc
888
+
889
+ # uploads: list[UploadFile | StarletteUploadFile] = []
890
+ # for key in ("files", "files[]", "file"):
891
+ # for value in form.getlist(key):
892
+ # if isinstance(value, (UploadFile, StarletteUploadFile)):
893
+ # uploads.append(value)
894
+
895
+ # # Fallback for clients that send non-standard multipart keys.
896
+ # if not uploads:
897
+ # for _, value in form.multi_items():
898
+ # if isinstance(value, (UploadFile, StarletteUploadFile)):
899
+ # uploads.append(value)
900
+
901
+ # if not uploads:
902
+ # raise AppError("At least one image is required.", 400)
903
+ # if len(uploads) > MAX_IMAGES:
904
+ # raise AppError("You can upload a maximum of 5 images.", 400)
905
+
906
+ # image_captions = []
907
+ # for upload in uploads:
908
+ # if upload.content_type and not upload.content_type.startswith("image/"):
909
+ # raise AppError("All uploaded files must be images.", 400)
910
+
911
+ # file_bytes = await upload.read()
912
+ # if not file_bytes:
913
+ # raise AppError("One of the uploaded images is empty.", 400)
914
+
915
+ # try:
916
+ # image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
917
+ # except UnidentifiedImageError as exc:
918
+ # raise AppError("One of the uploaded files is not a valid image.", 400) from exc
919
+ # except OSError as exc:
920
+ # raise AppError("Unable to read one of the uploaded images.", 400) from exc
921
+
922
+ # caption = generate_caption_text_safe(image)
923
+ # if not caption:
924
+ # raise AppError("Caption generation produced empty text.", 500)
925
+
926
+ # image_captions.append({"filename": upload.filename, "caption": caption})
927
+
928
+ # caption_texts = [x["caption"] for x in image_captions]
929
+ # caption = summarize_captions(caption_texts)
930
+ # if not caption:
931
+ # raise AppError("Caption summarization produced empty text.", 500)
932
+
933
+ # mongo_payload = {
934
+ # "caption": caption,
935
+ # "source_filenames": [item["filename"] for item in image_captions],
936
+ # "image_captions": image_captions,
937
+ # "images_count": len(image_captions),
938
+ # "is_summarized": len(image_captions) > 1,
939
+ # "created_at": datetime.now(timezone.utc),
940
+ # }
941
+
942
+ # audio_file_id = insert_record(caption_collection, mongo_payload)
943
+
944
+ # response_data = {**mongo_payload, "audio_file_id": audio_file_id}
945
+ # response_data.pop("_id", None) # Remove ObjectId as it is not JSON serializable
946
+ # response_data["created_at"] = response_data["created_at"].isoformat()
947
+
948
+ # return ok("Caption generated successfully.", response_data)