from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import torch

def generate_caption(image_path, trigger_word):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is only reliably supported on GPU; fall back to float32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load BLIP-2 (the 2.7B OPT variant is small enough for HF Spaces)
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        torch_dtype=dtype,
    ).to(device)
    
    # Generate a caption; convert to RGB in case of grayscale/RGBA inputs
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device, dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    
    # Prepend the trigger word so the caption can serve as a fine-tuning prompt
    return f"a photo of [{trigger_word}], {caption}"
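

# Minimal usage sketch. "photo.jpg" and "sks" are placeholder values for
# illustration only; they are assumptions, not part of the original file.
if __name__ == "__main__":
    print(generate_caption("photo.jpg", "sks"))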