File size: 7,026 Bytes
75a53d9
7553f0c
75a53d9
 
 
b9d4498
75a53d9
 
 
b434799
ab38e0e
5f4a46b
75a53d9
 
b9d4498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75a53d9
 
 
 
 
 
 
 
 
 
 
178416a
75a53d9
 
 
 
 
b9d4498
 
 
 
 
 
398c0e8
 
75a53d9
 
 
1089b06
75a53d9
 
 
609d6f1
75a53d9
 
f711846
75a53d9
 
 
 
 
609d6f1
75a53d9
609d6f1
b9d4498
 
 
 
 
 
 
 
 
 
 
 
75a53d9
 
 
 
 
 
 
 
 
 
 
 
 
b9d4498
 
 
 
 
 
 
 
 
 
 
609d6f1
 
aefece3
15d3f2d
 
 
54d3921
 
 
75a53d9
 
 
 
609d6f1
 
75a53d9
 
b9d4498
 
 
 
 
 
 
 
 
 
75a53d9
 
8f97cdd
 
b9d4498
 
 
 
 
 
 
 
 
 
8f97cdd
609d6f1
8f97cdd
609d6f1
8f97cdd
609d6f1
8f97cdd
5f4a46b
609d6f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import io
import torch
import PIL
from PIL import Image
from typing import Optional, Union, List
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
from my_model.config import captioning_config as config
from my_model.utilities.gen_utilities import free_gpu_resources
    

class ImageCaptioningModel:
    """
    A class to handle image captioning using InstructBlip model.

    Attributes:
        model_type (str): Type of the model to use (only 'i_blip' is supported).
        processor (InstructBlipProcessor or None): The processor for handling image input.
        model (InstructBlipForConditionalGeneration or None): The loaded model.
        prompt (str): Prompt for the model.
        max_image_size (int): Maximum size (longest side, in pixels) for the input image.
        min_length (int): Minimum length of the generated caption.
        max_new_tokens (int): Maximum number of new tokens to generate.
        model_path (str): Path to the pre-trained model.
        device_map (str): Device map for model loading.
        torch_dtype (torch.dtype): Data type for torch tensors.
        load_in_8bit (bool): Whether to load the model in 8-bit precision.
        load_in_4bit (bool): Whether to load the model in 4-bit precision.
        low_cpu_mem_usage (bool): Whether to optimize for low CPU memory usage.
        skip_special_tokens (bool): Whether to skip special tokens in the generated captions.
    """

    def __init__(self) -> None:
        """
        Initializes the ImageCaptioningModel class with configuration settings.

        Only configuration values are read here; the processor and model stay ``None``
        until ``load_model`` is called.
        """
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self) -> bool:
        """Misspelled alias of ``skip_special_tokens``, kept so existing callers keep working."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value: bool) -> None:
        self.skip_special_tokens = value

    def load_model(self) -> None:
        """
        Loads the InstructBlip model and processor based on the specified configuration.

        If both 8-bit and 4-bit loading are requested, 4-bit is disabled (and
        ``self.load_in_4bit`` is set to False) so that only 8-bit is used.

        Raises:
            ValueError: If ``model_type`` is not 'i_blip'.
        """
        if self.load_in_4bit and self.load_in_8bit:  # Ensure only one of 4-bit or 8-bit precision is used.
            self.load_in_4bit = False

        if self.model_type != 'i_blip':
            # Previously this silently loaded nothing and generate_caption failed later on None.
            raise ValueError(f"Unsupported model_type {self.model_type!r}; expected 'i_blip'.")

        # The processor (tokenizer + image preprocessing) has no weights, so quantization,
        # dtype and device-placement options are not applicable to it.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        free_gpu_resources()

        # Express quantization via BitsAndBytesConfig instead of the deprecated
        # load_in_8bit / load_in_4bit keyword arguments of from_pretrained.
        quantization_config = None
        if self.load_in_8bit or self.load_in_4bit:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(load_in_8bit=bool(self.load_in_8bit),
                                                     load_in_4bit=bool(self.load_in_4bit))

        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            quantization_config=quantization_config,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )
        free_gpu_resources()

    def resize_image(self, image: Image.Image, max_image_size: Optional[int] = None) -> Image.Image:
        """
        Resizes the image to fit within the specified maximum size while maintaining aspect ratio.

        Images whose longest side already fits are returned unchanged (never upscaled).

        Args:
            image (Image.Image): The input image to resize.
            max_image_size (Optional[int]): The maximum length of the longest side, in pixels.
                When None, the ``MAX_IMAGE_SIZE`` environment variable is used (default 1024).

        Returns:
            Image.Image: The resized image, or the original image if no resize was needed.
        """
        if max_image_size is None:
            max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

        # PIL reports size as (width, height); unpacking it the other way round used to
        # transpose every non-square image.
        width, height = image.size
        scale = max_image_size / max(width, height)

        if scale < 1:
            # Clamp to at least 1 pixel: very elongated images could otherwise round a side to 0,
            # which PIL rejects.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)

        return image

    def generate_caption(self, image_path: Union[str, io.IOBase, Image.Image]) -> str:
        """
        Generates a caption for the given image.

        Args:
            image_path (Union[str, io.IOBase, Image.Image]): The path to the image (str or
                os.PathLike), a file-like object, or a PIL Image.

        Returns:
            str: The generated caption for the image.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
            TypeError: If ``image_path`` is not one of the supported input types.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")

        free_gpu_resources()
        free_gpu_resources()

        if isinstance(image_path, Image.Image):
            image = image_path
        elif isinstance(image_path, (str, os.PathLike, io.IOBase)):
            # File path or file-like object: open it as a PIL Image.
            image = Image.open(image_path)
        else:
            raise TypeError(
                "image_path must be a path, a file-like object or a PIL Image, "
                f"got {type(image_path).__name__}"
            )

        # Honour the configured size limit; resize_image falls back to the env var if it is None.
        image = self.resize_image(image, self.max_image_size)
        # Send inputs to the model's own device instead of hard-coding "cuda".
        inputs = self.processor(image, self.prompt, return_tensors="pt").to(self.model.device, self.torch_dtype)
        outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
        caption = self.processor.decode(outputs[0], skip_special_tokens=self.skip_special_tokens).strip()
        free_gpu_resources()
        free_gpu_resources()
        return caption

    def generate_captions_for_multiple_images(self, image_paths: List[Union[str, io.IOBase, Image.Image]]) -> List[str]:
        """
        Generates captions for multiple images, one at a time, in input order.

        Args:
            image_paths (List[Union[str, io.IOBase, Image.Image]]): A list of paths to images, file-like objects, or PIL Images.

        Returns:
            List[str]: A list of captions for the provided images, in the same order.
        """
        return [self.generate_caption(image_path) for image_path in image_paths]
        

def get_caption(img: Union[str, io.IOBase, Image.Image]) -> str:
    """
    Loads the captioning model and generates a caption for a single image.

    The model is loaded on every call and released before returning. To caption many
    images, create one ``ImageCaptioningModel``, call ``load_model()`` once and reuse it.

    Args:
        img (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.

    Returns:
        str: The generated caption for the image.
    """
    captioner = ImageCaptioningModel()
    free_gpu_resources()
    try:
        captioner.load_model()
        free_gpu_resources()
        caption = captioner.generate_caption(img)
    finally:
        # Drop the last reference to the loaded model before cleaning up. Presumably
        # free_gpu_resources (defined elsewhere) can only reclaim memory that is no longer
        # referenced, so calling it while `captioner` is alive could not release the weights.
        # This also runs if loading or generation raised.
        del captioner
        free_gpu_resources()

    return caption