Vladyslav Humennyy Claude committed on
Commit
a113c8a
·
1 Parent(s): 9203469

Rewrite image handling to match app_chat_vllm.py format

Browse files

- User function now converts images to base64 with image_url format
- Removed complex unused helper functions for message processing
- Bot function properly handles base64 images with processor
- Converts base64 back to PIL images when using processor
- Falls back to tokenizer for text-only messages
- Simplified and cleaner implementation matching app_chat_vllm.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (1) hide show
  1. app.py +112 -258
app.py CHANGED
@@ -60,44 +60,40 @@ def load_model():
60
  model, tokenizer, processor, device = load_model()
61
 
62
 
63
- def _ensure_image_path(image_data: Any) -> str | None:
64
- """Return a valid file path for the provided image data."""
65
- if image_data is None:
66
- return None
67
-
68
- try:
69
- from PIL import Image
70
- except ImportError: # pragma: no cover - PIL is bundled with Gradio's image component
71
- return None
72
-
73
- # Already a path string
74
- if isinstance(image_data, str) and os.path.exists(image_data):
75
- return image_data
76
-
77
- # PIL Image object - save to temp file
78
- if isinstance(image_data, Image.Image):
79
- fd, tmp_path = tempfile.mkstemp(suffix=".png")
80
- os.close(fd)
81
- image_data.save(tmp_path, format="PNG")
82
- return tmp_path
83
-
84
- return None
85
-
86
-
87
  def user(user_message, image_data, history: list):
88
- user_message = user_message or ""
 
 
 
89
 
 
90
  updated_history = list(history)
91
  has_content = False
92
 
93
  stripped_message = user_message.strip()
94
- if stripped_message:
95
- updated_history.append({"role": "user", "content": stripped_message})
96
- has_content = True
97
 
98
- image_path = _ensure_image_path(image_data)
99
- if image_path is not None:
100
- updated_history.append({"role": "user", "content": {"path": image_path, "alt_text": "User uploaded image"}})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  has_content = True
102
 
103
  if not has_content:
@@ -117,257 +113,116 @@ def append_example_message(x: gr.SelectData, history):
117
  return history
118
 
119
 
120
- def _message_contains_image(message: dict[str, Any]) -> bool:
121
- content = message.get("content")
122
- if isinstance(content, dict):
123
- if "path" in content or "image" in content:
124
- return True
125
- if content.get("type") in {"image", "image_url"}:
126
- return True
127
  if isinstance(content, list):
 
128
  for item in content:
129
- if isinstance(item, dict) and item.get("type") in {"image", "image_url"}:
130
- return True
131
- return False
132
-
133
-
134
- def _content_to_text(content: Any) -> str:
135
- if isinstance(content, dict):
136
- if "text" in content:
137
- return content.get("text", "")
138
- if "path" in content:
139
- alt_text = content.get("alt_text")
140
- placeholder = alt_text or os.path.basename(content["path"]) or "image"
141
- return f"[image: {placeholder}]"
142
- if "image" in content:
143
- return "[image]"
144
- if content.get("type") == "image_url":
145
- image_url = content.get("image_url")
146
- if isinstance(image_url, dict):
147
- image_url = image_url.get("url", "")
148
- return f"[image: {image_url}]"
149
- if content.get("type") == "text":
150
- return content.get("text", "")
151
- return str(content)
152
- if isinstance(content, list):
153
- text_parts: list[str] = []
154
- for item in content:
155
- if isinstance(item, dict):
156
- item_type = item.get("type")
157
- if item_type == "text":
158
- text_parts.append(item.get("text", ""))
159
- elif item_type == "image":
160
- text_parts.append("[image]")
161
- elif item_type == "image_url":
162
- image_url = item.get("image_url")
163
- if isinstance(image_url, dict):
164
- image_url = image_url.get("url", "")
165
- text_parts.append(f"[image: {image_url}]")
166
- else:
167
- text_parts.append(str(item))
168
- else:
169
- text_parts.append(str(item))
170
- filtered = [part for part in text_parts if part]
171
- return "\n".join(filtered) if filtered else "[image]"
172
  return str(content)
173
 
174
 
175
- def _collect_recent_user_contents(history: list[dict[str, Any]]) -> list[Any]:
176
- """Collect the trailing sequence of user messages prior to the assistant reply."""
177
- chunks: list[Any] = []
178
- for message in reversed(history):
179
- if message.get("role") != "user":
180
- break
181
- chunks.append(message.get("content"))
182
- chunks.reverse()
183
- return chunks
184
-
185
-
186
- def _prepare_text_history(history: list[dict[str, Any]]) -> list[dict[str, str]]:
187
- text_history: list[dict[str, str]] = []
188
- for message in history:
189
- role = message.get("role", "user")
190
- content_text = _content_to_text(message.get("content"))
191
- if not content_text:
192
- continue
193
- if text_history and text_history[-1]["role"] == role:
194
- text_history[-1]["content"] = text_history[-1]["content"] + "\n" + content_text
195
- else:
196
- text_history.append({"role": role, "content": content_text})
197
- return text_history
198
-
199
-
200
- def _prepare_processor_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
201
- """Prepare history for processor with proper image format."""
202
- processor_history = []
203
-
204
- for message in history:
205
- role = message.get("role", "user")
206
- content = message.get("content")
207
-
208
- # Handle different content formats
209
- if isinstance(content, str):
210
- # Simple text message
211
- processor_history.append({"role": role, "content": content})
212
- elif isinstance(content, list):
213
- # Multi-modal content (text + images)
214
- formatted_content = []
215
- for item in content:
216
- if isinstance(item, dict):
217
- item_type = item.get("type")
218
- if item_type == "text":
219
- formatted_content.append({"type": "text", "text": item.get("text", "")})
220
- elif item_type == "image":
221
- # Extract PIL Image from _pil_image field or load from path
222
- pil_image = item.get("_pil_image")
223
- if pil_image is None and "path" in item:
224
- from PIL import Image
225
- pil_image = Image.open(item["path"])
226
- if pil_image is not None:
227
- formatted_content.append({"type": "image", "image": pil_image})
228
- if formatted_content:
229
- processor_history.append({"role": role, "content": formatted_content})
230
- elif isinstance(content, dict):
231
- # Legacy format or single image
232
- if "image" in content or "_pil_image" in content:
233
- pil_image = content.get("_pil_image") or content.get("image")
234
- if pil_image is None and "path" in content:
235
- from PIL import Image
236
- pil_image = Image.open(content["path"])
237
- if pil_image is not None:
238
- processor_history.append({
239
- "role": role,
240
- "content": [{"type": "image", "image": pil_image}]
241
- })
242
- else:
243
- # Try to extract text
244
- text = _content_to_text(content)
245
- if text:
246
- processor_history.append({"role": role, "content": text})
247
-
248
- return processor_history
249
-
250
-
251
- def _clean_history_for_display(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
252
- """Remove internal metadata fields like _pil_image before displaying in Gradio."""
253
- cleaned = []
254
-
255
- for message in history:
256
- cleaned_message = {"role": message.get("role", "user")}
257
- content = message.get("content")
258
-
259
- if isinstance(content, str):
260
- cleaned_message["content"] = content
261
- elif isinstance(content, list):
262
- cleaned_content = []
263
- for item in content:
264
- if isinstance(item, dict):
265
- # Remove _pil_image and ensure alt_text is string or absent
266
- cleaned_item = {}
267
- for k, v in item.items():
268
- if k == "_pil_image":
269
- continue
270
- if k == "alt_text":
271
- # Ensure alt_text is a string
272
- if isinstance(v, str):
273
- cleaned_item[k] = v
274
- # Skip non-string alt_text values
275
- continue
276
- cleaned_item[k] = v
277
- # Ensure alt_text exists for image type
278
- if cleaned_item.get("type") == "image" and "alt_text" not in cleaned_item:
279
- cleaned_item["alt_text"] = "uploaded image"
280
- cleaned_content.append(cleaned_item)
281
- else:
282
- cleaned_content.append(item)
283
- cleaned_message["content"] = cleaned_content
284
- elif isinstance(content, dict):
285
- # Remove _pil_image and ensure alt_text is string or absent
286
- cleaned_item = {}
287
- for k, v in content.items():
288
- if k == "_pil_image":
289
- continue
290
- if k == "alt_text":
291
- # Ensure alt_text is a string
292
- if isinstance(v, str):
293
- cleaned_item[k] = v
294
- # Skip non-string alt_text values
295
- continue
296
- cleaned_item[k] = v
297
- # Ensure alt_text exists for image content
298
- if "path" in cleaned_item and "alt_text" not in cleaned_item:
299
- cleaned_item["alt_text"] = "uploaded image"
300
- cleaned_message["content"] = cleaned_item
301
- else:
302
- cleaned_message["content"] = content
303
-
304
- cleaned.append(cleaned_message)
305
-
306
- return cleaned
307
-
308
-
309
  @spaces.GPU
310
  def bot(
311
  history: list[dict[str, Any]]
312
- # max_tokens,
313
- # temperature,
314
- # top_p,
315
  ):
316
- user_chunks = _collect_recent_user_contents(history)
317
- if not user_chunks:
318
- user_message_text = ""
319
- else:
320
- user_message_text = "\n".join(filter(None, (_content_to_text(chunk) for chunk in user_chunks)))
321
- print('User message:', user_message_text)
322
- # [{"role": "system", "content": system_message}] +
323
- # Build conversation
324
  max_tokens = 4096
325
  temperature = 0.7
326
  top_p = 0.95
327
 
328
- text_history = _prepare_text_history(history)
329
-
330
- # Handle empty history case
331
- if not text_history:
332
- input_text = ""
333
- else:
334
- input_text: str = tokenizer.apply_chat_template(
335
- text_history,
336
- tokenize=False,
337
- add_generation_prompt=True,
338
- # enable_thinking=True,
339
- )
340
-
341
- if input_text and tokenizer.bos_token:
342
- input_text = input_text.replace(tokenizer.bos_token, "", 1)
343
- print(input_text)
344
- model_inputs = None
345
-
346
  # Early return if no input
347
- if not input_text and not any(_message_contains_image(msg) for msg in history):
348
  return
349
 
350
- if processor is not None and any(_message_contains_image(msg) for msg in history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  try:
352
- processor_history = _prepare_processor_history(history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  model_inputs = processor(
354
  messages=processor_history,
355
  return_tensors="pt",
356
  add_generation_prompt=True,
357
  ).to(model.device)
358
- except Exception as exc: # pragma: no cover - diagnostic logging
359
- print(f"Processor failed, using tokenizer pipeline instead: {exc}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  if model_inputs is None:
362
- model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device) # .to(device)
363
 
364
- decoded_input = tokenizer.decode(model_inputs["input_ids"][0])
365
- print("Decoded input:", decoded_input)
366
- print([{int(token_id.item()): tokenizer.decode([int(token_id.item())])} for token_id in model_inputs["input_ids"][0]])
367
  # Streamer setup
368
- streamer = TextIteratorStreamer(
369
- tokenizer, skip_prompt=True # skip_special_tokens=True # ,
370
- )
371
 
372
  # Run model.generate in background thread
373
  generation_kwargs = dict(
@@ -377,7 +232,6 @@ def bot(
377
  top_p=top_p,
378
  top_k=64,
379
  do_sample=True,
380
- # eos_token_id=tokenizer.eos_token_id,
381
  streamer=streamer,
382
  )
383
  thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
@@ -387,7 +241,7 @@ def bot(
387
  # Yield tokens as they come in
388
  for new_text in streamer:
389
  history[-1]["content"] += new_text
390
- yield _clean_history_for_display(history)
391
 
392
  assistant_message = history[-1]["content"]
393
  logger.log_interaction(user=user_message_text, answer=assistant_message)
 
60
  model, tokenizer, processor, device = load_model()
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def user(user_message, image_data, history: list):
64
+ """Format user message with optional image (like app_chat_vllm.py)."""
65
+ import base64
66
+ import io
67
+ from PIL import Image
68
 
69
+ user_message = user_message or ""
70
  updated_history = list(history)
71
  has_content = False
72
 
73
  stripped_message = user_message.strip()
 
 
 
74
 
75
+ # Format message with image in base64 format (matching app_chat_vllm.py)
76
+ if image_data is not None:
77
+ # Convert PIL image to base64
78
+ buffered = io.BytesIO()
79
+ image_data.save(buffered, format="JPEG")
80
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
81
+
82
+ text_content = stripped_message if stripped_message else "Describe this image"
83
+
84
+ updated_history.append({
85
+ "role": "user",
86
+ "content": [
87
+ {"type": "text", "text": text_content},
88
+ {
89
+ "type": "image_url",
90
+ "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
91
+ },
92
+ ],
93
+ })
94
+ has_content = True
95
+ elif stripped_message:
96
+ updated_history.append({"role": "user", "content": stripped_message})
97
  has_content = True
98
 
99
  if not has_content:
 
113
  return history
114
 
115
 
116
+ def _extract_text_from_content(content: Any) -> str:
117
+ """Extract text from message content for logging."""
118
+ if isinstance(content, str):
119
+ return content
 
 
 
120
  if isinstance(content, list):
121
+ text_parts = []
122
  for item in content:
123
+ if isinstance(item, dict) and item.get("type") == "text":
124
+ text_parts.append(item.get("text", ""))
125
+ return " ".join(text_parts) if text_parts else "[Image]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  return str(content)
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @spaces.GPU
130
  def bot(
131
  history: list[dict[str, Any]]
 
 
 
132
  ):
133
+ """Generate bot response with support for text and images."""
 
 
 
 
 
 
 
134
  max_tokens = 4096
135
  temperature = 0.7
136
  top_p = 0.95
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # Early return if no input
139
+ if not history:
140
  return
141
 
142
+ # Extract last user message for logging
143
+ last_user_msg = next((msg for msg in reversed(history) if msg.get("role") == "user"), None)
144
+ user_message_text = _extract_text_from_content(last_user_msg.get("content")) if last_user_msg else ""
145
+ print('User message:', user_message_text)
146
+
147
+ # Check if any message contains images
148
+ has_images = any(
149
+ isinstance(msg.get("content"), list) and
150
+ any(item.get("type") == "image_url" for item in msg.get("content") if isinstance(item, dict))
151
+ for msg in history
152
+ )
153
+
154
+ model_inputs = None
155
+
156
+ # Use processor if images are present
157
+ if processor is not None and has_images:
158
  try:
159
+ # Processor expects messages with PIL images, not base64
160
+ # We need to convert base64 back to PIL for the processor
161
+ from PIL import Image
162
+ import base64
163
+ import io
164
+
165
+ processor_history = []
166
+ for msg in history:
167
+ role = msg.get("role", "user")
168
+ content = msg.get("content")
169
+
170
+ if isinstance(content, str):
171
+ processor_history.append({"role": role, "content": content})
172
+ elif isinstance(content, list):
173
+ formatted_content = []
174
+ for item in content:
175
+ if isinstance(item, dict):
176
+ if item.get("type") == "text":
177
+ formatted_content.append({"type": "text", "text": item.get("text", "")})
178
+ elif item.get("type") == "image_url":
179
+ # Extract base64 and convert to PIL
180
+ img_url = item.get("image_url", {}).get("url", "")
181
+ if img_url.startswith("data:image"):
182
+ base64_data = img_url.split(",")[1]
183
+ img_data = base64.b64decode(base64_data)
184
+ pil_image = Image.open(io.BytesIO(img_data))
185
+ formatted_content.append({"type": "image", "image": pil_image})
186
+ if formatted_content:
187
+ processor_history.append({"role": role, "content": formatted_content})
188
+
189
  model_inputs = processor(
190
  messages=processor_history,
191
  return_tensors="pt",
192
  add_generation_prompt=True,
193
  ).to(model.device)
194
+ print("Using processor for vision input")
195
+ except Exception as exc:
196
+ print(f"Processor failed: {exc}")
197
+ model_inputs = None
198
+
199
+ # Fallback to tokenizer for text-only
200
+ if model_inputs is None:
201
+ # Convert to text-only format for tokenizer
202
+ text_history = []
203
+ for msg in history:
204
+ role = msg.get("role", "user")
205
+ content = msg.get("content")
206
+ text_content = _extract_text_from_content(content)
207
+ if text_content:
208
+ text_history.append({"role": role, "content": text_content})
209
+
210
+ if text_history:
211
+ input_text = tokenizer.apply_chat_template(
212
+ text_history,
213
+ tokenize=False,
214
+ add_generation_prompt=True,
215
+ )
216
+ if input_text and tokenizer.bos_token:
217
+ input_text = input_text.replace(tokenizer.bos_token, "", 1)
218
+ model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
219
+ print("Using tokenizer for text-only input")
220
 
221
  if model_inputs is None:
222
+ return
223
 
 
 
 
224
  # Streamer setup
225
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
 
226
 
227
  # Run model.generate in background thread
228
  generation_kwargs = dict(
 
232
  top_p=top_p,
233
  top_k=64,
234
  do_sample=True,
 
235
  streamer=streamer,
236
  )
237
  thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
 
241
  # Yield tokens as they come in
242
  for new_text in streamer:
243
  history[-1]["content"] += new_text
244
+ yield history
245
 
246
  assistant_message = history[-1]["content"]
247
  logger.log_interaction(user=user_message_text, answer=assistant_message)