File size: 19,010 Bytes
fc498e0 c59fc6b 0b430a0 c59fc6b 61e10b7 9f7ab84 d26dd8d 24fe4cb 61e10b7 f958c4b c59fc6b 64de28a 61e10b7 fc498e0 61e10b7 d72aea6 61e10b7 c59fc6b fc498e0 c0e58be e59b1eb 502ab4a fc498e0 502ab4a d72aea6 c59fc6b 61e10b7 a97003a c59fc6b a97003a c59fc6b a97003a c59fc6b 4c70c9c 61e10b7 fc498e0 61e10b7 fc498e0 61e10b7 fc498e0 61347b1 33ddd8a 1cdf777 fc498e0 33ddd8a 1cdf777 33ddd8a e21af99 1cdf777 e21af99 fc498e0 5b25ca3 fc498e0 61347b1 7ea3839 61347b1 e21af99 61347b1 fc498e0 61347b1 c59fc6b fc498e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 |
# Main script for KBVQA: Knowledge-Based Visual Question Answering Module
# This module is the central component for implementing the designed model architecture for the Knowledge-Based Visual
# Question Answering (KB-VQA) project. It integrates various sub-modules, including image captioning, object detection,
# and a fine-tuned language model, to provide a comprehensive solution for answering questions based on visual input.
# --- Description ---
# **KBVQA class**:
# The KBVQA class encapsulates the functionality needed to perform visual question answering using a combination of
# multimodal models.
# The class handles the following tasks:
# - Loading and managing a fine-tuned language model (LLaMA-2) for question answering.
# - Integrating an image captioning model to generate descriptive captions for input images.
# - Utilizing an object detection model to identify and describe objects within the images.
# - Formatting and generating prompts for the language model based on the image captions and detected objects.
# - Providing methods to analyze images and generate answers to user-provided questions.
# **prepare_kbvqa_model function**:
# - The prepare_kbvqa_model function orchestrates the loading and initialization of the KBVQA class, ensuring it is
# ready for inference.
# ---Instructions---
# **Model Preparation**:
# Use the prepare_kbvqa_model function to prepare and initialize the KBVQA system, ensuring all required models are
# loaded and ready for use.
# **Image Processing and Question Answering**:
# Use the get_caption method to generate captions for input images.
# Use the detect_objects method to identify and describe objects in the images.
# Use the generate_answer method to answer questions based on the image captions and detected objects.
# This module forms the backbone of the KB-VQA project, integrating advanced models to provide an end-to-end solution
# for visual question answering tasks.
# Ensure all dependencies are installed and the required configuration file is in place before running this script.
# The configurations for the KBVQA class are defined in the 'my_model/config/kbvqa_config.py' file.
# ---------- Import this module (e.g. from the Streamlit demo app) to use the full KB-VQA functionality ----------#
# ---------- Please ensure this is run on a GPU ----------#
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.detector.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config
class KBVQA:
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
    It integrates an image captioning model, an object detection model and a language model (LLaMA-2) fine-tuned
    on the OK-VQA dataset to generate answers to visual questions.

    Attributes:
        kbvqa_model_name (Optional[str]): Name of the fine-tuned language model, selected from
            st.session_state["method"]; None when the selected method is not one of the fine-tuned models.
        quantization (str): The quantization setting for the model (e.g., '4bit', '8bit').
        max_context_window (int): The maximum number of prompt tokens allowed (config.MAX_CONTEXT_WINDOW).
        add_eos_token (bool): Whether the tokenizer should append an end-of-sequence token.
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Whether to load the model/tokenizer with reduced CPU memory usage.
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): Confidence threshold attribute; not read by this class
            (detect_objects reads st.session_state['confidence_level'] instead).
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): BitsAndBytes quantization config, or None when the
            quantization setting is neither '4bit' nor '8bit'.
        access_token (str): Access token for the Hugging Face Hub.
        current_prompt_length (Optional[int]): Token length of the most recent prompt built by generate_answer.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Checks if all the required models are loaded.
        format_prompt: Formats the prompt for the KBVQA model.
        trim_objects: Removes the last detected object from the objects string.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self) -> None:
        """
        Initializes the KBVQA instance from the configuration module and the Streamlit session state.

        The fine-tuned language model is chosen from st.session_state["method"]. No sub-model is loaded here;
        use load_detector, load_caption_model and load_fine_tuned_model (or prepare_kbvqa_model).
        """
        method = st.session_state["method"]
        if method == "7b-Fine-Tuned Model":
            self.kbvqa_model_name: Optional[str] = config.KBVQA_MODEL_NAME_7b
        elif method == "13b-Fine-Tuned Model":
            self.kbvqa_model_name = config.KBVQA_MODEL_NAME_13b
        else:
            # Keep the attribute defined (None) so construction still works, e.g. for detector-only reloads;
            # load_fine_tuned_model raises a clear error if a language model is actually requested.
            self.kbvqa_model_name = None
        self.quantization: str = config.QUANTIZATION
        self.max_context_window: int = config.MAX_CONTEXT_WINDOW
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.low_cpu_mem_usage: bool = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer: Optional[AutoTokenizer] = None
        self.captioner: Optional[ImageCaptioningModel] = None
        self.detector: Optional[ObjectDetector] = None
        self.detection_model: Optional[str] = None
        self.detection_confidence: Optional[float] = None
        self.kbvqa_model: Optional[AutoModelForCausalLM] = None
        # Must come after self.quantization is set, since create_bnb_config reads it.
        self.bnb_config: Optional[BitsAndBytesConfig] = self.create_bnb_config()
        self.access_token: str = config.HUGGINGFACE_TOKEN
        self.current_prompt_length: Optional[int] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: An NF4, double-quantized, bfloat16-compute config for '4bit'; an
            LLM.int8 config for '8bit'; None for any other setting (no bitsandbytes quantization).
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        if self.quantization == '8bit':
            # BitsAndBytesConfig has no bnb_8bit_* options; NF4 and double quantization are 4-bit-only
            # features. 8-bit loading is plain LLM.int8, configured by load_in_8bit alone.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.

        Returns:
            None
        """
        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()
        free_gpu_resources()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """
        caption = self.captioner.generate_caption(img)
        free_gpu_resources()
        return caption

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.

        Returns:
            None
        """
        self.detector = ObjectDetector()
        self.detector.load_model(model)
        free_gpu_resources()

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        The detection threshold is read from st.session_state['confidence_level'].

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The image with detected objects drawn and a string representation of the detected objects.
        """
        image = self.detector.process_image(img)
        free_gpu_resources()
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=st.session_state['confidence_level'])
        free_gpu_resources()
        # Boxes are drawn on the original input image, not on the processed one.
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        free_gpu_resources()
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.

        Returns:
            None

        Raises:
            ValueError: If no fine-tuned model name was selected in __init__.
        """
        if self.kbvqa_model_name is None:
            raise ValueError(
                "No fine-tuned KB-VQA model selected: st.session_state['method'] must be "
                "'7b-Fine-Tuned Model' or '13b-Fine-Tuned Model'."
            )
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=self.low_cpu_mem_usage,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)
        free_gpu_resources()
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             low_cpu_mem_usage=self.low_cpu_mem_usage,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)
        free_gpu_resources()

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """
        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.
        This implements the Prompt Engineering Module of the overall KB-VQA architecture.

        Args:
            current_query (str): The current question to be answered (surrounding whitespace is stripped).
            history (str, optional): The history of previous interactions. When given, sys_prompt, caption and
                objects are ignored and only the new question turn is appended to the history.
            sys_prompt (str, optional): The system prompt; defaults to config.SYSTEM_PROMPT.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image; when None, the [OBJ] section is omitted.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """
        # These are the special tokens designed for the model to be fine-tuned on.
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'
        # These are the default special tokens of the LLaMA-2 Chat model.
        B_SENT = '<s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()
        # History can be used to facilitate multi-turn chat; it is not used by the Run Inference tool of the demo app.
        # The templates below must match the fine-tuning format exactly, so they are kept verbatim.
        if history is None:
            if objects is None:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
            else:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
        else:
            p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
        return p

    @staticmethod
    def trim_objects(detected_objects_str: str) -> str:
        """
        Trim the last object (line) from the detected objects string.

        Used by generate_answer to shrink the prompt until it fits within the context window.

        Args:
            detected_objects_str (str): Newline-separated string of detected objects.

        Returns:
            str: The string with the last line removed ("" for a single line or an empty string).
        """
        objects = detected_objects_str.strip().split("\n")
        # str.split always yields at least one element, so dropping the last one is always safe.
        return "\n".join(objects[:-1])

    def generate_answer(self, question: str, caption: str, detected_objects_str: Optional[str]) -> str:
        """
        Generates an answer to a given question using the KBVQA model.

        If the prompt exceeds max_context_window tokens, detected objects are removed from the end of the
        list one at a time until it fits or no objects remain.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (Optional[str]): Newline-separated detected objects; None omits the [OBJ] section.

        Returns:
            str: The generated answer, stripped and capitalized (str.capitalize also lowercases the rest).
        """
        free_gpu_resources()
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))

        if self.current_prompt_length > self.max_context_window:
            st.warning(
                f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2,"
                f" objects detected with low confidence will be removed one at a time until the prompt length is within the"
                f" maximum context window ...")
            # Trim objects from the bottom of the list; stop early once there is nothing left to remove
            # (this also avoids calling trim_objects on None).
            while self.current_prompt_length > self.max_context_window and detected_objects_str:
                detected_objects_str = self.trim_objects(detected_objects_str)
                prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
                self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
            st.warning(f"New prompt length is: {self.current_prompt_length}")

        # The prompt already contains '<s>', so the tokenizer must not add another BOS token.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(
            self.kbvqa_model.device)
        free_gpu_resources()
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids, attention_mask=model_inputs.get("attention_mask"))
        free_gpu_resources()
        # Decode only the newly generated tokens so the input prompt is not echoed back.
        prompt_token_count = input_ids.shape[1]
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_token_count:], skip_special_tokens=True)
        return output_text.strip().capitalize()
def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.
    This serves as the main function for loading and reloading the KB-VQA model.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is loaded into the new
            instance; its captioner and language model stay None.
        force_reload (bool): If True, drops any instance cached in st.session_state['kbvqa'] before loading.

    Returns:
        KBVQA: A new KBVQA instance. It is put in eval mode and ready for inference only when all sub-models
        were loaded (i.e. when only_reload_detection_model is False).
    """
    if force_reload:
        loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
        # Drop the cached instance first so its GPU memory can actually be released.
        try:
            del st.session_state['kbvqa']
        except KeyError:
            pass  # Nothing cached yet, so there is nothing to drop.
    else:
        loading_message = 'Loading model.. this should take no more than 2 or 3 minutes!'
    free_gpu_resources()

    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model
    # Progress bar for model loading
    with st.spinner(loading_message):
        progress_bar = st.progress(0)
        if only_reload_detection_model:
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)
        else:
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)

    if kbvqa.all_models_loaded:
        st.success('Model loaded successfully and ready for inference!')
        kbvqa.kbvqa_model.eval()
        free_gpu_resources()
    return kbvqa
if __name__ == "__main__":
    # Running this module directly is intentionally a no-op; it is meant to be imported.
    #
    # Example usage. It must run inside a Streamlit session, because KBVQA() reads
    # st.session_state["method"], prepare_kbvqa_model reads st.session_state.detection_model,
    # and detect_objects reads st.session_state["confidence_level"]:
    #
    #   kbvqa = prepare_kbvqa_model()                                  # prepare the KBVQA model
    #   image = Image.open('path_to_image.jpg')                        # load an image
    #   caption = kbvqa.get_caption(image)                             # caption it
    #   image_with_boxes, detected_objects_str = kbvqa.detect_objects(image)
    #   question = "What is the object in the image?"
    #   answer = kbvqa.generate_answer(question, caption, detected_objects_str)
    #   print(f"Answer: {answer}")
    ...
|