| | from typing import Dict, Any |
| | from PIL import Image |
| | import torch |
| | import requests |
| | from io import BytesIO |
| | from transformers import BlipForConditionalGeneration, BlipProcessor |
| |
|
# Inference device shared by the whole module: the GPU when CUDA is usable,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
class EndpointHandler():
    """Inference handler that captions an image downloaded from a URL.

    The BLIP processor and model are loaded once at construction time and
    reused for every request.
    """

    # Seconds to wait on the remote image server before giving up, so that a
    # stalled host cannot block the worker indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, path=""):
        """
        Load the processor and the model, and move the model to ``device``.

        Args:
            path (:obj:`str`):
                Accepted for interface compatibility but not used: the
                weights are always loaded from the hard-coded
                "Salesforce/blip-image-captioning-base" identifier.
        """
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ).to(device)
        # Inference only: switch the model to evaluation mode.
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Generate a caption for the image referenced by the request.

        Args:
            data (:obj:`dict`):
                Request body. ``data["image"]`` is the URL of the image to
                caption. ``data["parameters"]`` is an optional mapping of
                extra keyword arguments forwarded to ``self.model.generate``.
        Return:
            A :obj:`dict`. On success it contains:
                - "caption": A string corresponding to the generated caption.
            When the request is malformed or the image cannot be downloaded
            or decoded it contains instead:
                - "error": A string describing the problem.
        """
        from collections.abc import Mapping

        # Report a malformed body the same way as every other bad input
        # instead of failing with AttributeError on ``.get``.
        if not isinstance(data, Mapping):
            return {"error": "Request body must be a JSON object."}

        image_url = data.get("image")
        if not image_url:
            return {"error": "Missing 'image' field in request body."}

        # An explicit ``"parameters": null`` must behave like an absent key;
        # ``data.get("parameters", {})`` would hand back None in that case.
        parameters = data.get("parameters") or {}
        # Validate before downloading so a bad request costs no network I/O.
        if not isinstance(parameters, Mapping):
            return {"error": "'parameters' must be a JSON object."}

        # NOTE(review): the URL is caller-supplied and fetched server-side;
        # restrict reachable hosts if this endpoint is exposed to untrusted
        # clients.
        try:
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            raw_image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Request boundary: network, HTTP-status and image-decoding
            # failures are all reported to the caller rather than raised.
            return {"error": f"Failed to fetch image from URL: {str(e)}"}

        model_inputs = self.processor(images=raw_image, return_tensors="pt")
        model_inputs["pixel_values"] = model_inputs["pixel_values"].to(device)

        # Caller-supplied parameters are merged last, so they take precedence
        # over keys produced by the processor.
        generate_kwargs = {**model_inputs, **parameters}

        with torch.no_grad():
            out = self.model.generate(**generate_kwargs)

        caption = self.processor.decode(out[0], skip_special_tokens=True)

        return {"caption": caption}
| |
|