# TopicGen / utils / caption_utils.py
# Author: Testys — commit 03c3cc5 ("Possible same"), 757 bytes.
from PIL import Image
import io
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
# Torch device string shared by the model and its inputs: use CUDA when torch
# reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class ImageCaptioning:
    """Generate natural-language captions for images with a BLIP model.

    The processor and model are loaded once in ``__init__``; the model and
    every batch of inputs are placed on the module-level ``device``.
    """

    def __init__(self, model_name="Salesforce/blip-image-captioning-base"):
        """Load the BLIP processor and captioning model.

        Args:
            model_name: Hugging Face Hub id (or local path) of a BLIP
                image-captioning checkpoint. Defaults to the base checkpoint
                this class originally hard-coded.
        """
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(device)

    def get_caption(self, image_bytes, **generate_kwargs):
        """Return a caption for an encoded image.

        Args:
            image_bytes: Raw bytes of an image file in any format PIL can decode.
            **generate_kwargs: Optional extra keyword arguments forwarded to
                ``self.model.generate`` (e.g. ``max_new_tokens``, ``num_beams``).

        Returns:
            The caption decoded from the first generated sequence, with
            special tokens removed.

        Raises:
            PIL.UnidentifiedImageError: if ``image_bytes`` is not a
                recognizable image.
        """
        # Decode inside a context manager so the file handle PIL keeps open for
        # lazy loading is released. Normalize to RGB because uploads may be
        # RGBA, palette ("P") or greyscale ("L"), while BLIP expects 3 channels.
        # convert() also forces the decode before the handle is closed.
        with Image.open(io.BytesIO(image_bytes)) as img:
            rgb_img = img.convert("RGB")
        img_tensors = self.processor(rgb_img, return_tensors="pt").to(device)
        output = self.model.generate(**img_tensors, **generate_kwargs)
        return self.processor.batch_decode(output, skip_special_tokens=True)[0]