m7mdal7aj committed on
Commit
b9d4498
1 Parent(s): e72ec95

Update my_model/captioner/image_captioning.py

Browse files
my_model/captioner/image_captioning.py CHANGED
@@ -3,6 +3,7 @@ import io
3
  import torch
4
  import PIL
5
  from PIL import Image
 
6
  from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
7
  import bitsandbytes
8
  import accelerate
@@ -11,7 +12,31 @@ from my_model.utilities.gen_utilities import free_gpu_resources
11
 
12
 
13
  class ImageCaptioningModel:
14
- def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  self.model_type = config.MODEL_TYPE
16
  self.processor = None
17
  self.model = None
@@ -29,9 +54,12 @@ class ImageCaptioningModel:
29
 
30
 
31
 
32
- def load_model(self):
33
-
34
- if self.load_in_4bit and self.load_in_8bit: # check if in case both set to True by mistake.
 
 
 
35
  self.load_in_4bit = False
36
 
37
  if self.model_type == 'i_blip':
@@ -53,7 +81,18 @@ class ImageCaptioningModel:
53
  free_gpu_resources()
54
 
55
 
56
- def resize_image(self, image, max_image_size=None):
 
 
 
 
 
 
 
 
 
 
 
57
  if max_image_size is None:
58
  max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
59
  h, w = image.size
@@ -67,7 +106,17 @@ class ImageCaptioningModel:
67
  return image
68
 
69
 
70
- def generate_caption(self, image_path):
 
 
 
 
 
 
 
 
 
 
71
  free_gpu_resources()
72
  free_gpu_resources()
73
  if isinstance(image_path, str) or isinstance(image_path, io.IOBase):
@@ -85,12 +134,30 @@ class ImageCaptioningModel:
85
  free_gpu_resources()
86
  return caption
87
 
88
- def generate_captions_for_multiple_images(self, image_paths):
 
 
 
 
 
 
 
 
 
89
 
90
  return [self.generate_caption(image_path) for image_path in image_paths]
91
 
92
 
93
- def get_caption(img):
 
 
 
 
 
 
 
 
 
94
  captioner = ImageCaptioningModel()
95
  free_gpu_resources()
96
  captioner.load_model()
 
3
  import torch
4
  import PIL
5
  from PIL import Image
6
+ from typing import Optional, Union, List
7
  from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
8
  import bitsandbytes
9
  import accelerate
 
12
 
13
 
14
  class ImageCaptioningModel:
15
+ """
16
+ A class to handle image captioning using InstructBlip model.
17
+
18
+ Attributes:
19
+ model_type (str): Type of the model to use.
20
+ processor (InstructBlipProcessor or None): The processor for handling image input.
21
+ model (InstructBlipForConditionalGeneration or None): The loaded model.
22
+ prompt (str): Prompt for the model.
23
+ max_image_size (int): Maximum size for the input image.
24
+ min_length (int): Minimum length of the generated caption.
25
+ max_new_tokens (int): Maximum number of new tokens to generate.
26
+ model_path (str): Path to the pre-trained model.
27
+ device_map (str): Device map for model loading.
28
+ torch_dtype (torch.dtype): Data type for torch tensors.
29
+ load_in_8bit (bool): Whether to load the model in 8-bit precision.
30
+ load_in_4bit (bool): Whether to load the model in 4-bit precision.
31
+ low_cpu_mem_usage (bool): Whether to optimize for low CPU memory usage.
32
+ skip_special_tokens (bool): Whether to skip special tokens in the generated captions.
33
+ """
34
+
35
+ def __init__(self) -> None:
36
+ """
37
+ Initializes the ImageCaptioningModel class with configuration settings.
38
+ """
39
+
40
  self.model_type = config.MODEL_TYPE
41
  self.processor = None
42
  self.model = None
 
54
 
55
 
56
 
57
+ def load_model(self) -> None:
58
+ """
59
+ Loads the InstructBlip model and processor based on the specified configuration.
60
+ """
61
+
62
+ if self.load_in_4bit and self.load_in_8bit: # Ensure only one of 4-bit or 8-bit precision is used.
63
  self.load_in_4bit = False
64
 
65
  if self.model_type == 'i_blip':
 
81
  free_gpu_resources()
82
 
83
 
84
+ def resize_image(self, image: Image.Image, max_image_size: Optional[int] = None) -> Image.Image:
85
+ """
86
+ Resizes the image to fit within the specified maximum size while maintaining aspect ratio.
87
+
88
+ Args:
89
+ image (Image.Image): The input image to resize.
90
+ max_image_size (Optional[int]): The maximum size for the resized image. Defaults to None.
91
+
92
+ Returns:
93
+ Image.Image: The resized image.
94
+ """
95
+
96
  if max_image_size is None:
97
  max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
98
  h, w = image.size
 
106
  return image
107
 
108
 
109
+ def generate_caption(self, image_path: Union[str, io.IOBase, Image.Image]) -> str:
110
+ """
111
+ Generates a caption for the given image.
112
+
113
+ Args:
114
+ image_path (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.
115
+
116
+ Returns:
117
+ str: The generated caption for the image.
118
+ """
119
+
120
  free_gpu_resources()
121
  free_gpu_resources()
122
  if isinstance(image_path, str) or isinstance(image_path, io.IOBase):
 
134
  free_gpu_resources()
135
  return caption
136
 
137
+ def generate_captions_for_multiple_images(self, image_paths: List[Union[str, io.IOBase, Image.Image]]) -> List[str]:
138
+ """
139
+ Generates captions for multiple images.
140
+
141
+ Args:
142
+ image_paths (List[Union[str, io.IOBase, Image.Image]]): A list of paths to images, file-like objects, or PIL Images.
143
+
144
+ Returns:
145
+ List[str]: A list of captions for the provided images.
146
+ """
147
 
148
  return [self.generate_caption(image_path) for image_path in image_paths]
149
 
150
 
151
+ def get_caption(img: Union[str, io.IOBase, Image.Image]) -> str:
152
+ """
153
+ Loads the captioning model and generates a caption for a single image.
154
+
155
+ Args:
156
+ img (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.
157
+
158
+ Returns:
159
+ str: The generated caption for the image.
160
+ """
161
  captioner = ImageCaptioningModel()
162
  free_gpu_resources()
163
  captioner.load_model()