File size: 19,010 Bytes
fc498e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
0b430a0
c59fc6b
61e10b7
9f7ab84
d26dd8d
24fe4cb
61e10b7
f958c4b
c59fc6b
64de28a
61e10b7
fc498e0
 
61e10b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d72aea6
61e10b7
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
fc498e0
 
 
 
c0e58be
e59b1eb
 
 
 
502ab4a
fc498e0
502ab4a
 
 
 
 
 
 
 
 
 
 
 
d72aea6
c59fc6b
61e10b7
a97003a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
a97003a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
a97003a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
4c70c9c
61e10b7
 
 
fc498e0
 
61e10b7
 
fc498e0
61e10b7
 
 
 
fc498e0
61347b1
33ddd8a
 
1cdf777
fc498e0
33ddd8a
 
1cdf777
33ddd8a
e21af99
1cdf777
e21af99
fc498e0
 
 
5b25ca3
 
 
 
 
fc498e0
61347b1
 
 
 
 
 
 
 
 
 
7ea3839
61347b1
 
e21af99
61347b1
 
 
fc498e0
61347b1
 
 
 
 
c59fc6b
fc498e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
#  Main script for KBVQA: Knowledge-Based Visual Question Answering Module

#  This module is the central component for implementing the designed model architecture for the Knowledge-Based Visual
#  Question Answering (KB-VQA) project. It integrates various sub-modules, including image captioning, object detection,
#  and a fine-tuned language model, to provide a comprehensive solution for answering questions based on visual input.

#  --- Description ---
#  **KBVQA class**:
#  The KBVQA class encapsulates the functionality needed to perform visual question answering using a combination of
#  multimodal models.
#  The class handles the following tasks:
#   - Loading and managing a fine-tuned language model (LLaMA-2) for question answering.
#   - Integrating an image captioning model to generate descriptive captions for input images.
#   - Utilizing an object detection model to identify and describe objects within the images.
#   - Formatting and generating prompts for the language model based on the image captions and detected objects.
#   - Providing methods to analyze images and generate answers to user-provided questions.

#  **prepare_kbvqa_model function**:
#   - The prepare_kbvqa_model function orchestrates the loading and initialization of the KBVQA class, ensuring it is
#     ready for inference.

#  ---Instructions---
#   **Model Preparation**:
#   Use the prepare_kbvqa_model function to prepare and initialize the KBVQA system, ensuring all required models are
#   loaded and ready for use.

#   **Image Processing and Question Answering**:
#    Use the get_caption method to generate captions for input images.
#    Use the detect_objects method to identify and describe objects in the images.
#    Use the generate_answer method to answer questions based on the image captions and detected objects.

#  This module forms the backbone of the KB-VQA project, integrating advanced models to provide an end-to-end solution
#  for visual question answering tasks.
#  Ensure all dependencies are installed and the required configuration file is in place before running this script.
#  The configurations for the KBVQA class are defined in the 'my_model/config/kbvqa_config.py' file.

#  ---------- Please run this module to utilize the full KB-VQA functionality ----------#
#  ---------- Please ensure this is run on a GPU ----------#


import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.detector.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config


class KBVQA:
    """
    Knowledge-Based Visual Question Answering (KB-VQA) model.

    Combines an image captioning model, an object detection model and a LLaMA-2 language model fine-tuned on the
    OK-VQA dataset to generate answers to questions about images.

    Attributes:
        kbvqa_model_name (Optional[str]): Name of the fine-tuned language model, selected from
            ``st.session_state["method"]``; None when the selected method is not a supported fine-tuned option.
        quantization (str): The quantization setting for the model (e.g., '4bit', '8bit').
        max_context_window (int): The maximum number of tokens allowed in the prompt (see the config file).
        add_eos_token (bool): Flag to indicate whether to add an end-of-sentence token to the tokenizer.
        trust_remote (bool): Flag to indicate whether to trust remote code when using the tokenizer.
        use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Flag passed to ``from_pretrained`` to reduce CPU memory usage while loading.
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold for object detection.
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): BitsAndBytes quantization config, or None when no quantization
            setting matches.
        access_token (str): Access token for Hugging Face API.
        current_prompt_length (Optional[int]): Token length of the most recent prompt built by generate_answer.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Checks if all the required models are loaded.
        format_prompt: Formats the prompt for the KBVQA model.
        trim_objects: Removes the last object line from a detected-objects string.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self) -> None:
        """
        Initializes the KBVQA instance with configuration parameters.

        The fine-tuned language model is chosen from ``st.session_state["method"]``. If the method is not one of the
        supported fine-tuned options, ``kbvqa_model_name`` is left as None and ``load_fine_tuned_model`` raises.
        """

        method = st.session_state["method"]
        self.kbvqa_model_name: Optional[str] = None
        if method == "7b-Fine-Tuned Model":
            self.kbvqa_model_name = config.KBVQA_MODEL_NAME_7b
        elif method == "13b-Fine-Tuned Model":
            self.kbvqa_model_name = config.KBVQA_MODEL_NAME_13b
        self.quantization: str = config.QUANTIZATION
        self.max_context_window: int = config.MAX_CONTEXT_WINDOW  # set to 4,000 tokens per the config file
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.low_cpu_mem_usage: bool = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer: Optional[AutoTokenizer] = None
        self.captioner: Optional[ImageCaptioningModel] = None
        self.detector: Optional[ObjectDetector] = None
        self.detection_model: Optional[str] = None
        self.detection_confidence: Optional[float] = None
        self.kbvqa_model: Optional[AutoModelForCausalLM] = None
        self.bnb_config: Optional[BitsAndBytesConfig] = self.create_bnb_config()
        self.access_token: str = config.HUGGINGFACE_TOKEN
        self.current_prompt_length: Optional[int] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: A 4-bit NF4 (double-quantized, bfloat16 compute) config for '4bit', a plain
            8-bit config for '8bit', and None for any other setting (no quantization config is passed to the loader).
        """

        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        if self.quantization == '8bit':
            # BitsAndBytesConfig has no ``bnb_8bit_*`` options (NF4 and double quantization are 4-bit features);
            # such kwargs were ignored by transformers, so only ``load_in_8bit`` has any effect.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.

        Returns:
            None
        """

        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()
        free_gpu_resources()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.

        Raises:
            RuntimeError: If the captioning model has not been loaded.
        """

        if self.captioner is None:
            raise RuntimeError("The image captioning model is not loaded; call load_caption_model() first.")
        caption = self.captioner.generate_caption(img)
        free_gpu_resources()
        return caption

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.

        Returns:
            None
        """

        self.detector = ObjectDetector()
        self.detector.load_model(model)
        free_gpu_resources()

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        The detection threshold is read from ``st.session_state['confidence_level']``.

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The image with detected objects drawn and a string representation of the detected objects.

        Raises:
            RuntimeError: If the object detection model has not been loaded.
        """

        if self.detector is None:
            raise RuntimeError("The object detection model is not loaded; call load_detector() first.")
        image = self.detector.process_image(img)
        free_gpu_resources()
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=st.session_state['confidence_level'])
        free_gpu_resources()
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        free_gpu_resources()
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.

        Returns:
            None

        Raises:
            ValueError: If no fine-tuned model name was selected in ``__init__`` (unsupported method).
        """

        if self.kbvqa_model_name is None:
            raise ValueError("No fine-tuned KB-VQA model is configured for the selected method "
                             f"{st.session_state.get('method')!r}.")

        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=self.low_cpu_mem_usage,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)
        free_gpu_resources()

        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             low_cpu_mem_usage=self.low_cpu_mem_usage,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)
        free_gpu_resources()

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """

        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.

        This implements the Prompt Engineering Module of the overall KB-VQA architecture.

        Args:
            current_query (str): The current question to be answered.
            history (str, optional): The history of previous interactions; when given, only the new question turn is
                appended to it (caption, objects and system prompt are not repeated).
            sys_prompt (str, optional): The system prompt; defaults to ``config.SYSTEM_PROMPT``.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image; when None the object section is omitted.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """

        # Special tokens the model was fine-tuned with.
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        # Default special tokens of the LLaMA-2 Chat model.
        B_SENT = '<s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'

        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()

        # History can be used for multi-turn chat; it is not used by the Run Inference tool of the demo app.
        if history is not None:
            return f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
        if objects is None:
            return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
        return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""

    @staticmethod
    def trim_objects(detected_objects_str: str) -> str:
        """
        Removes the last object line from the detected objects string.

        Used to shrink the prompt until it fits the context window (threshold set in the config file).

        Args:
            detected_objects_str (str): Newline-separated detected objects.

        Returns:
            str: The string with the last line removed ("" when at most one line remains).
        """

        # str.split always yields at least one element, so slicing off the last one is always safe.
        objects = detected_objects_str.strip().split("\n")
        return "\n".join(objects[:-1])

    def _prompt_length(self, prompt: str) -> int:
        """Returns the number of tokens the KBVQA tokenizer produces for ``prompt``."""
        return len(self.kbvqa_tokenizer.tokenize(prompt))

    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
        """
        Generates an answer to a given question using the KBVQA model.

        If the prompt exceeds ``max_context_window`` tokens, object lines are removed from the end of the detected
        objects string one at a time until it fits or no objects remain.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            str: The generated answer, capitalized.

        Raises:
            RuntimeError: If the fine-tuned model or its tokenizer has not been loaded.
        """

        if self.kbvqa_model is None or self.kbvqa_tokenizer is None:
            raise RuntimeError("The fine-tuned KB-VQA model is not loaded; call load_fine_tuned_model() first.")

        free_gpu_resources()
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        self.current_prompt_length = self._prompt_length(prompt)

        if self.current_prompt_length > self.max_context_window:
            st.warning(
                f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2,"
                f" objects detected with low confidence will be removed one at a time until the prompt length is within the"
                f" maximum context window ...")
            # Objects at the end of the list are presumably the lowest-confidence ones (per the warning above).
            while self.current_prompt_length > self.max_context_window and detected_objects_str:
                detected_objects_str = self.trim_objects(detected_objects_str)
                prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
                self.current_prompt_length = self._prompt_length(prompt)
            st.warning(f"New prompt length is: {self.current_prompt_length}")

        # The prompt already contains the <s> token, so the tokenizer must not add special tokens again.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False,
                                            return_tensors="pt").to(self.kbvqa_model.device)
        free_gpu_resources()
        # Pass input_ids together with the attention mask instead of letting generate() infer the mask.
        output_ids = self.kbvqa_model.generate(**model_inputs)
        free_gpu_resources()
        prompt_token_count = model_inputs["input_ids"].shape[1]  # skip the echoed prompt in the output
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_token_count:], skip_special_tokens=True)

        return output_text.capitalize()
    

def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> Optional[KBVQA]:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.

    This serves as the main function for loading and reloading the KB-VQA model.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is (re)loaded. When a KBVQA
            instance already exists in ``st.session_state['kbvqa']`` (and ``force_reload`` is False), its detector is
            replaced in place so the already-loaded captioner and language model are reused.
        force_reload (bool): If True, any existing ``st.session_state['kbvqa']`` instance is discarded first.

    Returns:
        Optional[KBVQA]: The prepared instance when all sub-models are loaded; None otherwise (e.g. a detection-only
        reload when no instance was loaded before).
    """

    free_gpu_resources()
    if force_reload:
        loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
        # Drop the previous instance so the memory held by its models can be reclaimed before loading again.
        if 'kbvqa' in st.session_state:
            del st.session_state['kbvqa']
        free_gpu_resources()
    else:
        loading_message = 'Loading model.. this should take no more than 2 or 3 minutes!'

    existing = st.session_state.get('kbvqa') if only_reload_detection_model else None
    if existing is not None:
        # Reuse the loaded captioner / language model and only swap the detector.
        kbvqa = existing
        kbvqa.detector = None
        free_gpu_resources()
    else:
        kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model

    with st.spinner(loading_message):
        progress_bar = st.progress(0)
        kbvqa.load_detector(kbvqa.detection_model)
        if only_reload_detection_model:
            progress_bar.progress(100)
        else:
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            progress_bar.progress(100)
        free_gpu_resources()

    if not kbvqa.all_models_loaded:
        return None

    st.success('Model loaded successfully and ready for inference!')
    kbvqa.kbvqa_model.eval()
    free_gpu_resources()
    return kbvqa


if __name__ == "__main__":
    # Intentionally a no-op: the module is meant to be imported by the Streamlit app.
    pass

    #### Example on how to use the module ####
    # NOTE(review): prepare_kbvqa_model and KBVQA read st.session_state (e.g. 'method', 'detection_model',
    # 'confidence_level'), so this example only works inside a Streamlit session where those keys are set.

    # Prepare the KBVQA model
    # kbvqa = prepare_kbvqa_model()

    # Load an image
    # image = Image.open('path_to_image.jpg')

    # Generate a caption for the image
    # caption = kbvqa.get_caption(image)

    # Detect objects in the image
    # image_with_boxes, detected_objects_str = kbvqa.detect_objects(image)

    # Generate an answer to a question about the image
    # question = "What is the object in the image?"
    # answer = kbvqa.generate_answer(question, caption, detected_objects_str)

    # print(f"Answer: {answer}")