baohuynhbk14 commited on
Commit
4d0481d
·
1 Parent(s): 46c8f02

Refactor message handling in conversation and prediction functions to improve clarity and functionality

Browse files
Files changed (3) hide show
  1. app.py +26 -14
  2. conversation.py +15 -6
  3. models.py +2 -9
app.py CHANGED
@@ -15,7 +15,7 @@ from filelock import FileLock
15
  from io import BytesIO
16
  from PIL import Image, ImageDraw, ImageFont
17
  from models import load_image
18
- from constants import LOGDIR
19
  from utils import (
20
  build_logger,
21
  server_error_msg,
@@ -164,6 +164,10 @@ def add_text(state, message, system_prompt, request: gr.Request):
164
 
165
  if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
166
  state = init_state(state)
 
 
 
 
167
  state.set_system_message(system_prompt)
168
  state.append_message(Conversation.USER, text, images)
169
  state.skip_next = False
@@ -183,19 +187,29 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, us
183
  @spaces.GPU
184
  def predict(message,
185
  image_path,
186
- history,
187
  max_input_tiles=6,
188
  temperature=1.0,
189
  max_output_tokens=700,
190
  top_p=0.7,
191
  repetition_penalty=2.5):
192
- pixel_values = load_image(image_path, max_num=max_input_tiles).to(torch.bfloat16).cuda()
 
 
 
193
  generation_config = dict(temperature=temperature, max_new_tokens= max_output_tokens, top_p=top_p, do_sample=False, num_beams = 3, repetition_penalty=repetition_penalty)
194
- if pixel_values is not None:
195
- question = '<image>\n'+message
196
- else:
197
- question = message
198
- print(f"FULL predict question: {question}")
 
 
 
 
 
 
 
199
  response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
200
  return response, conv_history
201
 
@@ -246,21 +260,19 @@ def http_bot(
246
 
247
  try:
248
  # Stream output
249
- message = state.get_last_user_message(source=state.USER)
250
  logger.info(f"==== User message ====\n{message}")
251
  logger.info(f"==== Image paths ====\n{all_image_paths}")
252
 
253
- history = state.get_prompt()
254
- logger.info(f"==== History ====\n{history}")
255
- response, conv_history = predict(message,
256
  all_image_paths[0],
257
- history,
258
  max_input_tiles,
259
  temperature,
260
  max_new_tokens,
261
  top_p,
262
  repetition_penalty)
263
- logger.info(f"==== AI history ====\n{conv_history}")
264
 
265
 
266
  # response = "This is a test response"
 
15
  from io import BytesIO
16
  from PIL import Image, ImageDraw, ImageFont
17
  from models import load_image
18
+ from constants import LOGDIR, DEFAULT_IMAGE_TOKEN
19
  from utils import (
20
  build_logger,
21
  server_error_msg,
 
164
 
165
  if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
166
  state = init_state(state)
167
+
168
+ if len(images) > 0 and len(state.get_images(source=state.USER)) == 0:
169
+ text = DEFAULT_IMAGE_TOKEN + "\n" + text
170
+
171
  state.set_system_message(system_prompt)
172
  state.append_message(Conversation.USER, text, images)
173
  state.skip_next = False
 
187
  @spaces.GPU
188
  def predict(message,
189
  image_path,
190
+ state,
191
  max_input_tiles=6,
192
  temperature=1.0,
193
  max_output_tokens=700,
194
  top_p=0.7,
195
  repetition_penalty=2.5):
196
+
197
+ history = state.get_prompt()
198
+ logger.info(f"==== History ====\n{history}")
199
+
200
  generation_config = dict(temperature=temperature, max_new_tokens= max_output_tokens, top_p=top_p, do_sample=False, num_beams = 3, repetition_penalty=repetition_penalty)
201
+
202
+ question = message
203
+ pixel_values = None
204
+ if image_path is not None:
205
+ pixel_values = load_image(image_path, max_num=max_input_tiles).to(torch.bfloat16).cuda()
206
+ if pixel_values is not None:
207
+ # Check the first user message to see if it is an image
208
+ index, first_user_message = state.get_user_message(source=state.USER, position='first')
209
+ if first_user_message is not None and \
210
+ DEFAULT_IMAGE_TOKEN not in first_user_message:
211
+ state.messages[index]['content'] = DEFAULT_IMAGE_TOKEN + "\n" + first_user_message
212
+
213
  response, conv_history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
214
  return response, conv_history
215
 
 
260
 
261
  try:
262
  # Stream output
263
+ message = state.get_user_message(source=state.USER, position='last')
264
  logger.info(f"==== User message ====\n{message}")
265
  logger.info(f"==== Image paths ====\n{all_image_paths}")
266
 
267
+ response, _ = predict(message,
 
 
268
  all_image_paths[0],
269
+ state,
270
  max_input_tiles,
271
  temperature,
272
  max_new_tokens,
273
  top_p,
274
  repetition_penalty)
275
+ # logger.info(f"==== AI history ====\n{conv_history}")
276
 
277
 
278
  # response = "This is a test response"
conversation.py CHANGED
@@ -174,14 +174,23 @@ class Conversation:
174
 
175
  return images
176
 
177
- def get_last_user_message(self, source: Union[str, None] = None):
178
  assert len(self.messages) > 0, "No message in the conversation."
179
  assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {source}"
180
- for i in range(len(self.messages) - 1, -1, -1):
181
- if source and self.messages[i]["role"] != source:
182
- continue
183
- if self.messages[i]["role"] == self.USER:
184
- return self.messages[i]["content"]
 
 
 
 
 
 
 
 
 
185
 
186
  def to_gradio_chatbot(self):
187
  ret = []
 
174
 
175
  return images
176
 
177
+ def get_user_message(self, source: Union[str, None] = None, position="first"):
178
  assert len(self.messages) > 0, "No message in the conversation."
179
  assert source in [self.USER, self.ASSISTANT, None], f"Invalid source: {source}"
180
+
181
+ if position == "first":
182
+ for i, msg in enumerate(self.messages):
183
+ if source and msg["role"] != source:
184
+ continue
185
+ if msg["role"] == self.USER:
186
+ return i, msg["content"]
187
+
188
+ elif position == "last":
189
+ for i in range(len(self.messages) - 1, -1, -1):
190
+ if source and self.messages[i]["role"] != source:
191
+ continue
192
+ if self.messages[i]["role"] == self.USER:
193
+ return i, self.messages[i]["content"]
194
 
195
  def to_gradio_chatbot(self):
196
  ret = []
models.py CHANGED
@@ -74,16 +74,12 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbna
74
  return processed_images
75
 
76
  def correct_image_orientation(image_path):
77
- # Mở ảnh
78
  image = Image.open(image_path)
79
-
80
- # Kiểm tra dữ liệu Exif (nếu có)
81
  try:
82
  exif = image._getexif()
83
  if exif is not None:
84
  for tag, value in exif.items():
85
  if ExifTags.TAGS.get(tag) == "Orientation":
86
- # Sửa hướng dựa trên Orientation
87
  if value == 3:
88
  image = image.rotate(180, expand=True)
89
  elif value == 6:
@@ -92,7 +88,8 @@ def correct_image_orientation(image_path):
92
  image = image.rotate(90, expand=True)
93
  break
94
  except Exception as e:
95
- print("Không thể xử lý Exif:", e)
 
96
 
97
  return image
98
 
@@ -100,13 +97,9 @@ def load_image(image_file, input_size=448, max_num=12):
100
  try:
101
  print("Loading image:", image_file)
102
  image = correct_image_orientation(image_file).convert('RGB')
103
- print("Image size:", image.size)
104
  transform = build_transform(input_size=input_size)
105
- print("Transform built.")
106
  images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
107
- print("Number of images:", len(images))
108
  pixel_values = [transform(image) for image in images]
109
- print("Images transformed.")
110
  pixel_values = torch.stack(pixel_values)
111
  print("Image loaded successfully.")
112
  except Exception as e:
 
74
  return processed_images
75
 
76
  def correct_image_orientation(image_path):
 
77
  image = Image.open(image_path)
 
 
78
  try:
79
  exif = image._getexif()
80
  if exif is not None:
81
  for tag, value in exif.items():
82
  if ExifTags.TAGS.get(tag) == "Orientation":
 
83
  if value == 3:
84
  image = image.rotate(180, expand=True)
85
  elif value == 6:
 
88
  image = image.rotate(90, expand=True)
89
  break
90
  except Exception as e:
91
+ print("Error reading exif:", e)
92
+ print(traceback.format_exc())
93
 
94
  return image
95
 
 
97
  try:
98
  print("Loading image:", image_file)
99
  image = correct_image_orientation(image_file).convert('RGB')
 
100
  transform = build_transform(input_size=input_size)
 
101
  images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
 
102
  pixel_values = [transform(image) for image in images]
 
103
  pixel_values = torch.stack(pixel_values)
104
  print("Image loaded successfully.")
105
  except Exception as e: