m7mdal7aj commited on
Commit
c59fc6b
1 Parent(s): f7df8ad

added kbvqa draft

Browse files
Files changed (1) hide show
  1. my_model/KBVQA.py +146 -0
my_model/KBVQA.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+ from typing import Optional
5
+ from my_model.image_captioning import ImageCaptioningModel
6
+ from my_model.object_detection import ObjectDetector
7
+
8
+
9
+ class KBVQA():
10
+
11
+ def __init__(self):
12
+ self.kbvqa_model_name = "m7mdal7aj/fine_tunned_llama_2_merged"
13
+ self.quantization='4bit'
14
+ self.bnb_config = self.create_bnb_config()
15
+ self.max_context_window = 4096
16
+ self.add_eos_token = False
17
+ self.trust_remote = False
18
+ self.use_fast = True
19
+ self.kbvqa_tokenizer = None
20
+ self.captioner = None
21
+ self.detector = None
22
+ self.kbvqa_model = None
23
+ # self.kbvqa_model_loaded = self.all_models_loaded()
24
+
25
+
26
+ def create_bnb_config(self) -> BitsAndBytesConfig:
27
+ """
28
+ Creates a BitsAndBytes configuration based on the quantization setting.
29
+ Returns:
30
+ BitsAndBytesConfig: Configuration for BitsAndBytes optimized model.
31
+ """
32
+ if self.quantization == '4bit':
33
+ return BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_use_double_quant=True,
36
+ bnb_4bit_quant_type="nf4",
37
+ bnb_4bit_compute_dtype=torch.bfloat16
38
+ )
39
+ elif self.quantization == '8bit':
40
+ return BitsAndBytesConfig(
41
+ load_in_8bit=True,
42
+ bnb_8bit_use_double_quant=True,
43
+ bnb_8bit_quant_type="nf4",
44
+ bnb_8bit_compute_dtype=torch.bfloat16
45
+ )
46
+
47
+
48
+ def load_caption_model(self):
49
+ self.captioner = ImageCaptioningModel(model_type='i_blip')
50
+ self.captioner.load_model()
51
+
52
+ def get_caption(self, img):
53
+
54
+ return self.captioner.generate_caption(img)
55
+
56
+ def load_detector(self, model):
57
+
58
+ self.detector = ObjectDetector()
59
+ self.detector.load_model(model)
60
+
61
+ def detect_objects(self, img, threshold=0.2):
62
+ image = self.detector.process_image(img)
63
+ detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=threshold)
64
+ image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
65
+ return image_with_boxes, detected_objects_string
66
+
67
+ def load_fine_tuned_model(self):
68
+
69
+ self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name, device_map="auto", quantization_config=self.bnb_config)
70
+ self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name, use_fast=self.use_fast, trust_remote_code=self.trust_remote, add_eos_token=self.add_eos_token)
71
+
72
+
73
+ @property
74
+ def all_models_loaded(self):
75
+ return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
76
+
77
+
78
+
79
+ def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
80
+
81
+ if sys_prompt is None:
82
+ sys_prompt = "You are a helpful, respectful and honest assistant for visual question answering. you are provided with a caption of an image and a list of objects detected in the image along with their bounding boxes and level of certainty, you will output an answer to the given questions in no more than one sentence. Use logical reasoning to reach to the answer, but do not output your reasoning process unless asked for it. If provided, you will use the [CAP] and [/CAP] tags to indicate the begining and end of the caption respectively. If provided you will use the [OBJ] and [/OBJ] tags to indicate the begining and end of the list of detected objects in the image along with their bounding boxes respectively.if provided, you will use [QES] and [/QES] tags to indicate the begining and end of the question respectively."
83
+
84
+ B_SENT = '<s>'
85
+ E_SENT = '</s>'
86
+ B_INST = '[INST]'
87
+ E_INST = '[/INST]'
88
+ B_SYS = '<<SYS>>\n'
89
+ E_SYS = '\n<</SYS>>\n\n'
90
+ B_CAP = '[CAP]'
91
+ E_CAP = '[/CAP]'
92
+ B_QES = '[QES]'
93
+ E_QES = '[/QES]'
94
+ B_OBJ = '[OBJ]'
95
+ E_OBJ = '[/OBJ]'
96
+
97
+
98
+ current_query = current_query.strip()
99
+ sys_prompt = sys_prompt.strip()
100
+
101
+ if history is None:
102
+ if objects is None:
103
+ p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
104
+ else:
105
+ p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
106
+ else:
107
+ p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
108
+
109
+
110
+ return p
111
+
112
+
113
+ def generate_answer(self, question, caption, detected_objects_str):
114
+ prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
115
+ num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
116
+ if num_tokens > self.max_context_window:
117
+ st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
118
+ return
119
+
120
+ model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
121
+ input_ids = model_inputs["input_ids"]
122
+ output_ids = self.kbvqa_model.generate(input_ids)
123
+ index = input_ids.shape[1] # needed to avoid printing the input prompt
124
+ history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)
125
+ output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
126
+
127
+ return output_text.capitalize()
128
+
129
+ def prepare_kbvqa_model(detection_model):
130
+ kbvqa = KBVQA()
131
+ # Progress bar for model loading
132
+ with st.spinner('Loading models...'):
133
+ progress_bar = st.progress(0)
134
+ kbvqa.load_fine_tuned_model()
135
+ progress_bar.progress(33)
136
+ kbvqa.load_caption_model()
137
+ progress_bar.progress(66)
138
+ kbvqa.load_detector(detection_model) # Replace with your model
139
+ progress_bar.progress(100)
140
+
141
+ if kbvqa.all_models_loaded:
142
+ st.success('Model loaded successfully!')
143
+ kbvqa.kbvqa_model.eval()
144
+ return kbvqa
145
+
146
+