# Source: KB-VQA-E / my_model / captioner / image_captioning.py
# Uploaded by: m7mdal7aj
# Commit message: adding captioning folder and files
# Commit: 75a53d9 (verified)
# File size: 3.08 kB
import os
import torch
import PIL
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
import captioning_config as config
class ImageCaptioningModel:
    """Caption images with an InstructBLIP model configured by ``captioning_config``.

    Typical use: construct the object, call :meth:`load_model` once, then call
    :meth:`generate_caption` (or :meth:`generate_captions_for_multiple_images`)
    with image file paths.
    """

    def __init__(self):
        """Copy the settings from ``captioning_config``; nothing is loaded yet."""
        self.model_type = config.MODEL_TYPE
        # Both stay ``None`` until ``load_model`` succeeds.
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        # NOTE(review): stored but not consulted by ``resize_image``, which
        # falls back to the MAX_IMAGE_SIZE environment variable instead --
        # confirm which source is intended.
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        # The attribute name is misspelled ("secial"); it is kept unchanged
        # because external code may already read or set it.
        self.skip_secial_tokens = config.SKIP_SPECIAL_TOKENS

    def load_model(self):
        """Load the processor and model from ``self.model_path``.

        Only the ``'i_blip'`` model type is handled; for any other value
        ``self.processor`` and ``self.model`` are left untouched.
        """
        if self.model_type == 'i_blip':
            # NOTE(review): load_in_8bit / torch_dtype / device_map look like
            # model-loading options rather than processor options -- confirm
            # the processor actually needs them.
            self.processor = InstructBlipProcessor.from_pretrained(
                self.model_path,
                load_in_8bit=self.load_in_8bit,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
            )
            self.model = InstructBlipForConditionalGeneration.from_pretrained(
                self.model_path,
                load_in_8bit=self.load_in_8bit,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=self.low_cpu_mem_usage,
                device_map=self.device_map,
            )

    @staticmethod
    def _scaled_size(width, height, max_image_size):
        """Return the downscaled ``(width, height)``, or ``None`` if no resize is needed.

        The longer side is scaled to at most *max_image_size* and the aspect
        ratio is preserved. Each dimension is truncated to an int but never
        drops below 1 pixel.
        """
        longest_side = max(width, height)
        if longest_side <= 0:
            return None
        scale = max_image_size / longest_side
        if scale >= 1:
            return None
        return max(1, int(width * scale)), max(1, int(height * scale))

    def resize_image(self, image, max_image_size=None):
        """Shrink *image* so that its longer side does not exceed the limit.

        Args:
            image: An object exposing ``size`` as ``(width, height)`` and a
                ``resize`` method (a PIL image).
            max_image_size: Maximum length of the longer side in pixels. When
                ``None``, the ``MAX_IMAGE_SIZE`` environment variable is used,
                defaulting to 1024.

        Returns:
            The resized image, or *image* itself when it already fits.

        Raises:
            ValueError: If the effective maximum size is not positive.
        """
        if max_image_size is None:
            max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
        if max_image_size <= 0:
            raise ValueError(f"max_image_size must be positive, got {max_image_size}")

        # PIL reports size as (width, height); unpacking it the other way
        # round would transpose the target dimensions and distort the image.
        width, height = image.size
        target_size = self._scaled_size(width, height, max_image_size)
        if target_size is None:
            return image
        return image.resize(target_size, resample=PIL.Image.Resampling.LANCZOS)

    def generate_caption(self, image_path):
        """Generate a caption for the image stored at *image_path*.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            The decoded caption with surrounding whitespace stripped.

        Raises:
            RuntimeError: If ``load_model`` has not populated the processor
                and model yet.
        """
        if self.processor is None or self.model is None:
            raise RuntimeError(
                "Model is not loaded; call load_model() before generating captions."
            )

        # The context manager closes the underlying file once the pixel data
        # has been consumed by the processor.
        with Image.open(image_path) as source_image:
            prepared_image = self.resize_image(source_image)
            # Send the inputs to wherever the model was placed instead of
            # assuming a CUDA device is present.
            inputs = self.processor(
                prepared_image, self.prompt, return_tensors="pt"
            ).to(self.model.device, self.torch_dtype)

        outputs = self.model.generate(
            **inputs,
            min_length=self.min_length,
            max_new_tokens=self.max_new_tokens,
        )
        caption = self.processor.decode(
            outputs[0], skip_special_tokens=self.skip_secial_tokens
        ).strip()
        return caption

    def generate_captions_for_multiple_images(self, image_paths):
        """Return a list with one caption per path in *image_paths*, in order."""
        return [self.generate_caption(image_path) for image_path in image_paths]
# Placeholder script entry point: running this module directly does nothing.
if __name__ == "__main__":
    pass