Vladyslav Humennyy Claude committed on
Commit
14729c6
·
1 Parent(s): a113c8a

Fix image handling for Gradio compatibility

Browse files

- Store images as file paths for Gradio display (type: "image")
- Keep base64 in _base64 metadata for model processing
- Clean metadata before displaying to avoid validation errors
- Combine text and image in a single message structure
- Properly convert base64 to PIL images for processor

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +56 -16
app.py CHANGED
@@ -61,7 +61,7 @@ model, tokenizer, processor, device = load_model()
61
 
62
 
63
  def user(user_message, image_data, history: list):
64
- """Format user message with optional image (like app_chat_vllm.py)."""
65
  import base64
66
  import io
67
  from PIL import Image
@@ -72,22 +72,30 @@ def user(user_message, image_data, history: list):
72
 
73
  stripped_message = user_message.strip()
74
 
75
- # Format message with image in base64 format (matching app_chat_vllm.py)
76
  if image_data is not None:
77
- # Convert PIL image to base64
 
 
 
 
 
78
  buffered = io.BytesIO()
79
  image_data.save(buffered, format="JPEG")
80
  img_base64 = base64.b64encode(buffered.getvalue()).decode()
81
 
82
  text_content = stripped_message if stripped_message else "Describe this image"
83
 
 
84
  updated_history.append({
85
  "role": "user",
86
  "content": [
87
  {"type": "text", "text": text_content},
88
  {
89
- "type": "image_url",
90
- "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
 
 
91
  },
92
  ],
93
  })
@@ -126,6 +134,33 @@ def _extract_text_from_content(content: Any) -> str:
126
  return str(content)
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @spaces.GPU
130
  def bot(
131
  history: list[dict[str, Any]]
@@ -147,7 +182,7 @@ def bot(
147
  # Check if any message contains images
148
  has_images = any(
149
  isinstance(msg.get("content"), list) and
150
- any(item.get("type") == "image_url" for item in msg.get("content") if isinstance(item, dict))
151
  for msg in history
152
  )
153
 
@@ -156,8 +191,7 @@ def bot(
156
  # Use processor if images are present
157
  if processor is not None and has_images:
158
  try:
159
- # Processor expects messages with PIL images, not base64
160
- # We need to convert base64 back to PIL for the processor
161
  from PIL import Image
162
  import base64
163
  import io
@@ -175,13 +209,19 @@ def bot(
175
  if isinstance(item, dict):
176
  if item.get("type") == "text":
177
  formatted_content.append({"type": "text", "text": item.get("text", "")})
178
- elif item.get("type") == "image_url":
179
- # Extract base64 and convert to PIL
180
- img_url = item.get("image_url", {}).get("url", "")
181
- if img_url.startswith("data:image"):
182
- base64_data = img_url.split(",")[1]
183
- img_data = base64.b64decode(base64_data)
184
- pil_image = Image.open(io.BytesIO(img_data))
 
 
 
 
 
 
185
  formatted_content.append({"type": "image", "image": pil_image})
186
  if formatted_content:
187
  processor_history.append({"role": role, "content": formatted_content})
@@ -241,7 +281,7 @@ def bot(
241
  # Yield tokens as they come in
242
  for new_text in streamer:
243
  history[-1]["content"] += new_text
244
- yield history
245
 
246
  assistant_message = history[-1]["content"]
247
  logger.log_interaction(user=user_message_text, answer=assistant_message)
 
61
 
62
 
63
  def user(user_message, image_data, history: list):
64
+ """Format user message with optional image."""
65
  import base64
66
  import io
67
  from PIL import Image
 
72
 
73
  stripped_message = user_message.strip()
74
 
75
+ # If we have an image, save it to temp file for Gradio display and also encode as base64 for model
76
  if image_data is not None:
77
+ # Save to temp file for Gradio display
78
+ fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
79
+ os.close(fd)
80
+ image_data.save(tmp_path, format="JPEG")
81
+
82
+ # Also encode as base64 for model processing (stored in metadata)
83
  buffered = io.BytesIO()
84
  image_data.save(buffered, format="JPEG")
85
  img_base64 = base64.b64encode(buffered.getvalue()).decode()
86
 
87
  text_content = stripped_message if stripped_message else "Describe this image"
88
 
89
+ # Store both text and image in a single message with base64 in metadata
90
  updated_history.append({
91
  "role": "user",
92
  "content": [
93
  {"type": "text", "text": text_content},
94
  {
95
+ "type": "image",
96
+ "path": tmp_path,
97
+ "alt_text": "User uploaded image",
98
+ "_base64": f"data:image/jpeg;base64,{img_base64}", # Store base64 for model
99
  },
100
  ],
101
  })
 
134
  return str(content)
135
 
136
 
137
+ def _clean_history_for_display(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
138
+ """Remove internal metadata fields like _base64 before displaying in Gradio."""
139
+ cleaned = []
140
+ for message in history:
141
+ cleaned_message = {"role": message.get("role", "user")}
142
+ content = message.get("content")
143
+
144
+ if isinstance(content, str):
145
+ cleaned_message["content"] = content
146
+ elif isinstance(content, list):
147
+ cleaned_content = []
148
+ for item in content:
149
+ if isinstance(item, dict):
150
+ # Remove _base64 metadata
151
+ cleaned_item = {k: v for k, v in item.items() if not k.startswith("_")}
152
+ cleaned_content.append(cleaned_item)
153
+ else:
154
+ cleaned_content.append(item)
155
+ cleaned_message["content"] = cleaned_content
156
+ else:
157
+ cleaned_message["content"] = content
158
+
159
+ cleaned.append(cleaned_message)
160
+
161
+ return cleaned
162
+
163
+
164
  @spaces.GPU
165
  def bot(
166
  history: list[dict[str, Any]]
 
182
  # Check if any message contains images
183
  has_images = any(
184
  isinstance(msg.get("content"), list) and
185
+ any(item.get("type") == "image" for item in msg.get("content") if isinstance(item, dict))
186
  for msg in history
187
  )
188
 
 
191
  # Use processor if images are present
192
  if processor is not None and has_images:
193
  try:
194
+ # Processor expects messages with PIL images
 
195
  from PIL import Image
196
  import base64
197
  import io
 
209
  if isinstance(item, dict):
210
  if item.get("type") == "text":
211
  formatted_content.append({"type": "text", "text": item.get("text", "")})
212
+ elif item.get("type") == "image":
213
+ # Use _base64 metadata if available, otherwise load from path
214
+ pil_image = None
215
+ if "_base64" in item:
216
+ img_url = item["_base64"]
217
+ if img_url.startswith("data:image"):
218
+ base64_data = img_url.split(",")[1]
219
+ img_data = base64.b64decode(base64_data)
220
+ pil_image = Image.open(io.BytesIO(img_data))
221
+ elif "path" in item:
222
+ pil_image = Image.open(item["path"])
223
+
224
+ if pil_image is not None:
225
  formatted_content.append({"type": "image", "image": pil_image})
226
  if formatted_content:
227
  processor_history.append({"role": role, "content": formatted_content})
 
281
  # Yield tokens as they come in
282
  for new_text in streamer:
283
  history[-1]["content"] += new_text
284
+ yield _clean_history_for_display(history)
285
 
286
  assistant_message = history[-1]["content"]
287
  logger.log_interaction(user=user_message_text, answer=assistant_message)