AdithyaSK committed
Commit e03ab78 · 1 Parent(s): b3edf84

Update README.md and app.py: change SDK version to 6.0.2 and enhance error handling in document indexing

Files changed (2):
  1. README.md +1 -2
  2. app.py +157 -570
README.md CHANGED
@@ -4,12 +4,11 @@ emoji: 👁️
 colorFrom: yellow
 colorTo: purple
 sdk: gradio
-sdk_version: 5.9.1
+sdk_version: 6.0.2
 app_file: app.py
 pinned: false
 license: mit
 short_description: Universal Multilingual Multimodal Document Retrieval
-hardware: zero-gpu
 ---
 
 # NetraEmbed - Universal Multilingual Multimodal Document Retrieval
app.py CHANGED
@@ -10,15 +10,13 @@ Features:
 - Query input with top-k selection (default: 5)
 - Similarity score display
 - Side-by-side comparison when both models are selected
-- Progressive loading with real-time updates
-- Proper error handling
 - ZeroGPU integration for efficient GPU usage
 """
 
 import io
 import gc
 import math
-from typing import Iterator, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import gradio as gr
 import torch
@@ -37,7 +35,11 @@ from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map
 
 # Configuration
 MAX_BATCH_SIZE = 32  # Maximum pages to process at once
-DEFAULT_DURATION = 120  # Default GPU duration in seconds
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print(f"Device: {device}")
+if torch.cuda.is_available():
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
 
 # Global state for models and indexed documents
 class DocumentIndex:
@@ -49,37 +51,24 @@ class DocumentIndex:
         self.bigemma_processor = None
         self.colgemma_model = None
         self.colgemma_processor = None
-        self.models_loaded = {"bigemma": False, "colgemma": False}
 
 doc_index = DocumentIndex()
 
 # Helper functions
-def get_loaded_models() -> List[str]:
-    """Get list of currently loaded models."""
-    loaded = []
-    if doc_index.bigemma_model is not None:
-        loaded.append("BiGemma3")
-    if doc_index.colgemma_model is not None:
-        loaded.append("ColGemma3")
-    return loaded
-
-def get_model_choice_from_loaded() -> str:
-    """Determine model choice string based on what's loaded."""
-    loaded = get_loaded_models()
-    if "BiGemma3" in loaded and "ColGemma3" in loaded:
-        return "Both"
-    elif "BiGemma3" in loaded:
-        return "NetraEmbed (BiGemma3)"
-    elif "ColGemma3" in loaded:
-        return "ColNetraEmbed (ColGemma3)"
-    else:
-        return ""
-
-@spaces.GPU(duration=DEFAULT_DURATION)
+def pdf_to_images(pdf_path: str) -> List[Image.Image]:
+    """Convert PDF to list of PIL Images with error handling."""
+    try:
+        print(f"Converting PDF to images: {pdf_path}")
+        images = convert_from_path(pdf_path, dpi=200)
+        print(f"Converted {len(images)} pages")
+        return images
+    except Exception as e:
+        print(f"❌ PDF conversion error: {str(e)}")
+        raise gr.Error(f"Failed to convert PDF: {str(e)}")
+
+@spaces.GPU
 def load_bigemma_model():
     """Load BiGemma3 model and processor."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
     if doc_index.bigemma_model is None:
         print("Loading BiGemma3 (NetraEmbed)...")
         try:
@@ -93,18 +82,15 @@ def load_bigemma_model():
                 device_map=device,
             )
             doc_index.bigemma_model.eval()
-            doc_index.models_loaded["bigemma"] = True
             print("✓ BiGemma3 loaded successfully")
         except Exception as e:
             print(f"❌ Failed to load BiGemma3: {str(e)}")
-            raise
-    return doc_index.bigemma_model, doc_index.bigemma_processor
+            raise gr.Error(f"Failed to load BiGemma3: {str(e)}")
+    return "✅ BiGemma3 loaded"
 
-@spaces.GPU(duration=DEFAULT_DURATION)
+@spaces.GPU
 def load_colgemma_model():
     """Load ColGemma3 model and processor."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
     if doc_index.colgemma_model is None:
         print("Loading ColGemma3 (ColNetraEmbed)...")
         try:
@@ -118,12 +104,11 @@
                 "Cognitive-Lab/ColNetraEmbed",
                 use_fast=True,
             )
-            doc_index.models_loaded["colgemma"] = True
             print("✓ ColGemma3 loaded successfully")
         except Exception as e:
             print(f"❌ Failed to load ColGemma3: {str(e)}")
-            raise
-    return doc_index.colgemma_model, doc_index.colgemma_processor
+            raise gr.Error(f"Failed to load ColGemma3: {str(e)}")
+    return "✅ ColGemma3 loaded"
 
 def unload_models():
     """Unload models and free GPU memory."""
@@ -133,14 +118,12 @@
         del doc_index.bigemma_processor
         doc_index.bigemma_model = None
         doc_index.bigemma_processor = None
-        doc_index.models_loaded["bigemma"] = False
 
     if doc_index.colgemma_model is not None:
         del doc_index.colgemma_model
         del doc_index.colgemma_processor
         doc_index.colgemma_model = None
        doc_index.colgemma_processor = None
-        doc_index.models_loaded["colgemma"] = False
 
     # Clear embeddings and images
     doc_index.bigemma_embeddings = None
@@ -157,42 +140,74 @@
     except Exception as e:
         return f"❌ Error unloading models: {str(e)}"
 
-def clear_incompatible_embeddings(model_choice: str) -> str:
-    """Clear embeddings that are incompatible with currently loading models."""
-    cleared = []
-
-    # If loading only BiGemma3, clear ColGemma3 embeddings
-    if model_choice == "NetraEmbed (BiGemma3)":
-        if doc_index.colgemma_embeddings is not None:
-            doc_index.colgemma_embeddings = None
-            doc_index.images = []
-            cleared.append("ColGemma3")
-            print("Cleared ColGemma3 embeddings")
-
-    # If loading only ColGemma3, clear BiGemma3 embeddings
-    elif model_choice == "ColNetraEmbed (ColGemma3)":
-        if doc_index.bigemma_embeddings is not None:
-            doc_index.bigemma_embeddings = None
-            doc_index.images = []
-            cleared.append("BiGemma3")
-            print("Cleared BiGemma3 embeddings")
-
-    if cleared:
-        return f"Cleared {', '.join(cleared)} embeddings - please re-index"
-    return ""
+@spaces.GPU
+def index_bigemma_images(images: List[Image.Image]) -> torch.Tensor:
+    """Index images with BiGemma3 model."""
+    model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
+    batch_images = processor.process_images(images).to(device)
+    embeddings = model(**batch_images, embedding_dim=768)
+    return embeddings
+
+@spaces.GPU
+def index_colgemma_images(images: List[Image.Image]) -> torch.Tensor:
+    """Index images with ColGemma3 model."""
+    model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
+    batch_images = processor.process_images(images).to(device)
+    embeddings = model(**batch_images)
+    return embeddings
+
+def index_document(pdf_file, model_choice: str):
+    """Upload and index a PDF document."""
+    if pdf_file is None:
+        return "⚠️ Please upload a PDF document first."
 
-def pdf_to_images(pdf_path: str) -> List[Image.Image]:
-    """Convert PDF to list of PIL Images with error handling."""
     try:
-        print(f"Converting PDF to images: {pdf_path}")
-        images = convert_from_path(pdf_path, dpi=200)
-        print(f"Converted {len(images)} pages")
-        return images
+        status = []
+
+        # Convert PDF to images
+        status.append("⏳ Converting PDF to images...")
+        doc_index.images = pdf_to_images(pdf_file.name)
+        num_pages = len(doc_index.images)
+        status.append(f"✓ Converted PDF to {num_pages} images")
+
+        if num_pages > MAX_BATCH_SIZE:
+            status.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches...")
+
+        # Index with BiGemma3
+        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
+            if doc_index.bigemma_model is None:
+                status.append("⏳ Loading BiGemma3 model...")
+                load_bigemma_model()
+                status.append("✓ BiGemma3 loaded")
+            else:
+                status.append("✓ Using cached BiGemma3 model")
+
+            status.append("⏳ Encoding images with BiGemma3...")
+            doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
+            status.append(f"✓ Indexed with BiGemma3 (shape: {doc_index.bigemma_embeddings.shape})")
+
+        # Index with ColGemma3
+        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
+            if doc_index.colgemma_model is None:
+                status.append("⏳ Loading ColGemma3 model...")
+                load_colgemma_model()
+                status.append("✓ ColGemma3 loaded")
+            else:
+                status.append("✓ Using cached ColGemma3 model")
+
+            status.append("⏳ Encoding images with ColGemma3...")
+            doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
+            status.append(f"✓ Indexed with ColGemma3 (shape: {doc_index.colgemma_embeddings.shape})")
+
+        return "\n".join(status) + "\n\n✅ Document ready for querying!"
+
     except Exception as e:
-        print(f"❌ PDF conversion error: {str(e)}")
-        raise Exception(f"Failed to convert PDF: {str(e)}")
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Indexing error: {error_details}")
+        return f"❌ Error indexing document: {str(e)}"
 
-@spaces.GPU(duration=DEFAULT_DURATION)
+@spaces.GPU
 def generate_colgemma_heatmap(
     image: Image.Image,
     query: str,
@@ -203,17 +218,14 @@
 ) -> Image.Image:
     """Generate heatmap overlay for ColGemma3 results."""
     try:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
         # Re-process the single image to get the proper batch_images dict for image mask
         batch_images = processor.process_images([image]).to(device)
 
-        # Create image mask manually (ColGemmaProcessor3 doesn't have get_image_mask)
+        # Create image mask manually
         if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
             image_token_id = model.config.image_token_id
             image_mask = batch_images["input_ids"] == image_token_id
         else:
-            # Fallback: all tokens are image tokens
             image_mask = torch.ones(
                 image_embedding.shape[0], image_embedding.shape[1], dtype=torch.bool, device=device
             )
@@ -225,10 +237,9 @@
         if n_side * n_side == num_image_tokens:
             n_patches = (n_side, n_side)
         else:
-            # Fallback: use default calculation
            n_patches = (16, 16)
 
-        # Generate similarity maps (returns a list of tensors)
+        # Generate similarity maps
         similarity_maps_list = get_similarity_maps_from_embeddings(
             image_embeddings=image_embedding,
             query_embeddings=query_embedding,
@@ -236,10 +247,9 @@
             image_mask=image_mask,
         )
 
-        # Get the similarity map for our image (returns a list, get first element)
-        similarity_map = similarity_maps_list[0]  # (query_length, n_patches_x, n_patches_y)
+        similarity_map = similarity_maps_list[0]
 
-        # Aggregate across all query tokens (mean)
+        # Aggregate across all query tokens
         if similarity_map.dtype == torch.bfloat16:
             similarity_map = similarity_map.float()
         aggregated_map = torch.mean(similarity_map, dim=0)
@@ -247,10 +257,8 @@
         # Convert the image to an array
         img_array = np.array(image.convert("RGBA"))
 
-        # Normalize the similarity map and convert to numpy
+        # Normalize the similarity map
         similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
-
-        # Reshape to match PIL convention
         similarity_map_array = rearrange(similarity_map_array, "h w -> w h")
 
         # Create PIL image from similarity map
@@ -280,121 +288,19 @@
 
     except Exception as e:
         print(f"❌ Heatmap generation error: {str(e)}")
-        # Return original image if heatmap generation fails
         return image
 
-@spaces.GPU(duration=DEFAULT_DURATION)
-def index_bigemma_images(images: List[Image.Image]) -> torch.Tensor:
-    """Index images with BiGemma3 model."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
-
-    batch_images = processor.process_images(images).to(device)
-    embeddings = model(**batch_images, embedding_dim=768)
-
-    return embeddings
-
-@spaces.GPU(duration=DEFAULT_DURATION)
-def index_colgemma_images(images: List[Image.Image]) -> torch.Tensor:
-    """Index images with ColGemma3 model."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
-
-    batch_images = processor.process_images(images).to(device)
-    embeddings = model(**batch_images)
-
-    return embeddings
-
-def index_document(pdf_file, model_choice: str) -> Iterator[str]:
-    """Upload and index a PDF document with progress updates."""
-    if pdf_file is None:
-        yield "⚠️ Please upload a PDF document first."
-        return
-
-    try:
-        status_messages = []
-
-        # Convert PDF to images
-        status_messages.append("⏳ Converting PDF to images...")
-        yield "\n".join(status_messages)
-
-        doc_index.images = pdf_to_images(pdf_file.name)
-        num_pages = len(doc_index.images)
-
-        status_messages.append(f"✓ Converted PDF to {num_pages} images")
-
-        # Check if we need to batch process
-        if num_pages > MAX_BATCH_SIZE:
-            status_messages.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches of {MAX_BATCH_SIZE}...")
-            yield "\n".join(status_messages)
-
-        # Index with BiGemma3
-        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
-            if doc_index.bigemma_model is None:
-                status_messages.append("⏳ Loading BiGemma3 model...")
-                yield "\n".join(status_messages)
-                load_bigemma_model()
-                status_messages.append("✓ BiGemma3 loaded")
-            else:
-                status_messages.append("✓ Using cached BiGemma3 model")
-
-            yield "\n".join(status_messages)
-
-            status_messages.append("⏳ Encoding images with BiGemma3...")
-            yield "\n".join(status_messages)
-
-            doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
-
-            status_messages.append("✓ Indexed with BiGemma3 (shape: {})".format(doc_index.bigemma_embeddings.shape))
-            yield "\n".join(status_messages)
-
-        # Index with ColGemma3
-        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
-            if doc_index.colgemma_model is None:
-                status_messages.append("⏳ Loading ColGemma3 model...")
-                yield "\n".join(status_messages)
-                load_colgemma_model()
-                status_messages.append("✓ ColGemma3 loaded")
-            else:
-                status_messages.append("✓ Using cached ColGemma3 model")
-
-            yield "\n".join(status_messages)
-
-            status_messages.append("⏳ Encoding images with ColGemma3...")
-            yield "\n".join(status_messages)
-
-            doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
-
-            status_messages.append(
-                "✓ Indexed with ColGemma3 (shape: {})".format(doc_index.colgemma_embeddings.shape)
-            )
-            yield "\n".join(status_messages)
-
-        final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
-        yield final_status
-
-    except Exception as e:
-        import traceback
-
-        error_details = traceback.format_exc()
-        print(f"Indexing error: {error_details}")
-        yield f"❌ Error indexing document: {str(e)}"
-
-@spaces.GPU(duration=DEFAULT_DURATION)
+@spaces.GPU
 def query_bigemma(query: str, top_k: int) -> Tuple[str, List]:
     """Query indexed documents with BiGemma3."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
 
     # Encode query
     batch_query = processor.process_texts([query]).to(device)
     query_embedding = model(**batch_query, embedding_dim=768)
 
-    # Compute scores (cosine similarity)
-    scores = processor.score(
-        qs=query_embedding,
-        ps=doc_index.bigemma_embeddings,
-    )
+    # Compute scores
+    scores = processor.score(qs=query_embedding, ps=doc_index.bigemma_embeddings)
 
     # Get top-k results
     top_k_actual = min(top_k, len(doc_index.images))
@@ -413,21 +319,17 @@ def query_bigemma(query: str, top_k: int) -> Tuple[str, List]:
 
     return results_text, gallery_images
 
-@spaces.GPU(duration=DEFAULT_DURATION)
+@spaces.GPU
 def query_colgemma(query: str, top_k: int, show_heatmap: bool = False) -> Tuple[str, List]:
     """Query indexed documents with ColGemma3."""
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
 
     # Encode query
     batch_query = processor.process_queries([query]).to(device)
     query_embedding = model(**batch_query)
 
-    # Compute scores (MaxSim)
-    scores = processor.score_multi_vector(
-        qs=query_embedding,
-        ps=doc_index.colgemma_embeddings,
-    )
+    # Compute scores
+    scores = processor.score_multi_vector(qs=query_embedding, ps=doc_index.colgemma_embeddings)
 
     # Get top-k results
     top_k_actual = min(top_k, len(doc_index.images))
@@ -456,10 +358,7 @@ def query_colgemma(query: str, top_k: int, show_heatmap: bool = False) -> Tuple[str, List]:
             )
         else:
             gallery_images.append(
-                (
-                    doc_index.images[idx.item()],
-                    f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
-                )
+                (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
             )
 
     return results_text, gallery_images
@@ -484,14 +383,12 @@ def query_documents(
     if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
         if doc_index.bigemma_embeddings is None:
             return "⚠️ Please index the document with BiGemma3 first.", None, None, None
-
         results_bi, gallery_images_bi = query_bigemma(query, top_k)
 
     # Query with ColGemma3
     if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
         if doc_index.colgemma_embeddings is None:
             return "⚠️ Please index the document with ColGemma3 first.", None, None, None
-
         results_col, gallery_images_col = query_colgemma(query, top_k, show_heatmap)
 
     # Return results based on model choice
@@ -504,266 +401,57 @@
 
     except Exception as e:
         import traceback
-
        error_details = traceback.format_exc()
         print(f"Query error: {error_details}")
         return f"❌ Error during query: {str(e)}", None, None, None
 
-def load_models_with_progress(model_choice: str) -> Iterator[Tuple]:
-    """Load models with progress updates."""
-    if not model_choice:
-        yield (
-            "❌ Please select a model first.",
-            gr.update(visible=True),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(value="Load model first"),
-        )
-        return
-
-    try:
-        status_messages = []
-
-        # Clear incompatible embeddings
-        clear_msg = clear_incompatible_embeddings(model_choice)
-        if clear_msg:
-            status_messages.append(f"⚠️ {clear_msg}")
-
-        # Load BiGemma3
-        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
-            status_messages.append("⏳ Loading BiGemma3 (NetraEmbed)...")
-            yield (
-                "\n".join(status_messages),
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(value="Loading models..."),
-            )
-
-            load_bigemma_model()
-            status_messages[-1] = "✅ BiGemma3 loaded successfully"
-            yield (
-                "\n".join(status_messages),
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(value="Loading models..."),
-            )
-
-        # Load ColGemma3
-        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
-            status_messages.append("⏳ Loading ColGemma3 (ColNetraEmbed)...")
-            yield (
-                "\n".join(status_messages),
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(value="Loading models..."),
-            )
-
-            load_colgemma_model()
-            status_messages[-1] = "✅ ColGemma3 loaded successfully"
-            yield (
-                "\n".join(status_messages),
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(interactive=False),
-                gr.update(value="Loading models..."),
-            )
-
-        # Determine column visibility based on loaded models
-        show_bigemma = model_choice in ["NetraEmbed (BiGemma3)", "Both"]
-        show_colgemma = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
-        show_heatmap_checkbox = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
-
-        final_status = "\n".join(status_messages) + "\n\n✅ Ready!"
-        yield (
-            final_status,
-            gr.update(visible=False),
-            gr.update(visible=True),
-            gr.update(visible=show_bigemma),
-            gr.update(visible=show_colgemma),
-            gr.update(visible=show_heatmap_checkbox),
-            gr.update(interactive=True),
-            gr.update(interactive=True),
-            gr.update(interactive=True),
-            gr.update(interactive=True),
-            gr.update(interactive=True),
-            gr.update(value="Ready to index"),
-        )
-
-    except Exception as e:
-        import traceback
-
-        error_details = traceback.format_exc()
-        print(f"Model loading error: {error_details}")
-        yield (
-            f"❌ Failed to load models: {str(e)}",
-            gr.update(visible=True),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(visible=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(interactive=False),
-            gr.update(value="Load model first"),
-        )
-
-def unload_models_and_hide_ui():
-    """Unload models and hide main UI."""
-    status = unload_models()
-    return (
-        status,
-        gr.update(visible=True),
-        gr.update(visible=False),
-        gr.update(visible=False),
-        gr.update(visible=False),
-        gr.update(visible=False),
-        gr.update(interactive=False),
-        gr.update(interactive=False),
-        gr.update(interactive=False),
-        gr.update(interactive=False),
-        gr.update(interactive=False),
-        gr.update(value="Load model first"),
-    )
-
 # Create Gradio interface
-with gr.Blocks(
-    title="NetraEmbed Demo",
-) as demo:
-    # Header section with model info and banner
-    with gr.Row():
-        with gr.Column(scale=1):
-            gr.Markdown("# NetraEmbed")
-            gr.HTML(
-                """
-                <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
-                    <a href="https://arxiv.org/abs/2512.03514" target="_blank">
-                        <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
-                    </a>
-                    <a href="https://github.com/adithya-s-k/colpali" target="_blank">
-                        <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
-                    </a>
-                    <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
-                        <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
-                    </a>
-                    <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
-                        <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
-                    </a>
-                    <a href="https://cloud.cognitivelab.in" target="_blank">
-                        <img src="https://img.shields.io/badge/Demo-Try%20it%20out-green" alt="Demo">
-                    </a>
-                </div>
-                """
-            )
-            gr.Markdown(
-                """
-
-                **🚀 Universal Multilingual Multimodal Document Retrieval**
-
-                Upload a PDF document, select your model(s), and query using semantic search.
-
-                **Available Models:**
-                - **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation
-                  Fast retrieval with cosine similarity
-                - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction
-                  High-quality retrieval with MaxSim scoring and attention heatmaps
-
-                """
-            )
+with gr.Blocks(title="NetraEmbed Demo") as demo:
+    # Header section
+    gr.Markdown("# NetraEmbed")
+    gr.HTML(
+        """
+        <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
+            <a href="https://arxiv.org/abs/2512.03514" target="_blank">
+                <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
+            </a>
+            <a href="https://github.com/adithya-s-k/colpali" target="_blank">
+                <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
+            </a>
+            <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
+                <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
+            </a>
+        </div>
+        """
+    )
+    gr.Markdown(
+        """
+        **🚀 Universal Multilingual Multimodal Document Retrieval**
+
+        Upload a PDF document, select your model(s), and query using semantic search.
+
+        **Available Models:**
+        - **NetraEmbed (BiGemma3)**: Single-vector embedding - Fast retrieval with cosine similarity
+        - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding - High-quality retrieval with MaxSim scoring and heatmaps
+        """
+    )
 
-        with gr.Column(scale=1):
-            gr.HTML(
-                """
-                <div style="text-align: center;">
-                    <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
-                         alt="NetraEmbed Banner"
-                         style="width: 100%; height: auto; border-radius: 8px;">
-                </div>
-                """
-            )
-
-    gr.Markdown("---")
-
-    # Compact 3-column layout
     with gr.Row():
-        # Column 1: Model Management
+        # Column 1: Model Selection
         with gr.Column(scale=1):
-            gr.Markdown("### 🤖 Model Management")
+            gr.Markdown("### 🤖 Model Selection")
             model_select = gr.Radio(
                 choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
                 value="Both",
                 label="Select Model(s)",
             )
 
-            load_model_btn = gr.Button("🔄 Load Model", variant="primary", size="sm")
-            unload_model_btn = gr.Button("🗑️ Unload", variant="secondary", size="sm")
-
-            model_status = gr.Textbox(
-                label="Status",
-                lines=6,
-                interactive=False,
-                value="Select and load a model",
-            )
-
-            loading_info = gr.Markdown(
-                """
-                **First load:** 2-3 min
-                **Cached:** ~30 sec
-                """,
-                visible=True,
-            )
-
-        # Column 2: Document Upload & Indexing
+        # Column 2: Document Upload
         with gr.Column(scale=1):
             gr.Markdown("### 📄 Upload & Index")
-            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"], interactive=False)
-            index_btn = gr.Button("📥 Index Document", variant="primary", size="sm", interactive=False)
-
-            index_status = gr.Textbox(
-                label="Indexing Status",
-                lines=6,
-                interactive=False,
-                value="Load model first",
-            )
+            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
+            index_btn = gr.Button("📥 Index Document", variant="primary")
+            index_status = gr.Textbox(label="Status", lines=6, interactive=False)
 
         # Column 3: Query
         with gr.Column(scale=1):
@@ -772,145 +460,44 @@ with gr.Blocks(
                 label="Enter Query",
                 placeholder="e.g., financial report, organizational structure...",
                 lines=2,
-                interactive=False,
             )
-
            with gr.Row():
-                top_k_slider = gr.Slider(
-                    minimum=1,
-                    maximum=10,
-                    value=5,
-                    step=1,
-                    label="Top K",
-                    scale=2,
-                    interactive=False,
-                )
-                heatmap_checkbox = gr.Checkbox(
-                    label="Heatmaps",
-                    value=False,
-                    visible=False,
-                    scale=1,
-                )
-
-            query_btn = gr.Button("🔍 Search", variant="primary", size="sm", interactive=False)
+                top_k_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top K", scale=2)
+                heatmap_checkbox = gr.Checkbox(label="Heatmaps", value=False, scale=1)
+            query_btn = gr.Button("🔍 Search", variant="primary")
 
     gr.Markdown("---")
 
-    # Results section (always visible after model load)
-    with gr.Column(visible=False) as main_interface:
-        gr.Markdown("### 📊 Results")
-
-        with gr.Row(equal_height=True):
-            with gr.Column(scale=1, visible=False) as bigemma_column:
-                bigemma_results = gr.Markdown(
-                    value="*BiGemma3 results will appear here...*",
-                )
-                bigemma_gallery = gr.Gallery(
-                    label="BiGemma3 - Top Retrieved Pages",
-                    show_label=True,
-                    columns=2,
-                    height="auto",
-                    object_fit="contain",
-                )
-            with gr.Column(scale=1, visible=False) as colgemma_column:
-                colgemma_results = gr.Markdown(
-                    value="*ColGemma3 results will appear here...*",
-                )
-                colgemma_gallery = gr.Gallery(
-                    label="ColGemma3 - Top Retrieved Pages",
-                    show_label=True,
-                    columns=2,
-                    height="auto",
-                    object_fit="contain",
-                )
-
-    # Tips
-    with gr.Accordion("💡 Tips", open=False):
-        gr.Markdown(
-            """
-            - **Both models**: Compare results side-by-side
-            - **Scores**: BiGemma3 uses cosine similarity (-1 to 1), ColGemma3 uses MaxSim (higher is better)
-            - **Heatmaps**: Enable to visualize ColGemma3 attention patterns (brighter = higher attention)
-            """
+    # Results section
+    gr.Markdown("### 📊 Results")
+    with gr.Row():
+        with gr.Column(scale=1):
+            bigemma_results = gr.Markdown(value="*BiGemma3 results will appear here...*")
+            bigemma_gallery = gr.Gallery(
+                label="BiGemma3 - Top Retrieved Pages",
+                columns=2,
+                height="auto",
+            )
+        with gr.Column(scale=1):
+            colgemma_results = gr.Markdown(value="*ColGemma3 results will appear here...*")
+            colgemma_gallery = gr.Gallery(
+                label="ColGemma3 - Top Retrieved Pages",
+                columns=2,
+                height="auto",
        )
 
-    # Event handlers - Model Management
-    load_model_btn.click(
-        fn=load_models_with_progress,
-        inputs=[model_select],
-        outputs=[
-            model_status,
-            loading_info,
-            main_interface,
-            bigemma_column,
-            colgemma_column,
-            heatmap_checkbox,
-            pdf_upload,
-            index_btn,
-            query_input,
-            top_k_slider,
-            query_btn,
-            index_status,
-        ],
-    )
-
-    unload_model_btn.click(
-        fn=unload_models_and_hide_ui,
-        outputs=[
-            model_status,
-            loading_info,
-            main_interface,
-            bigemma_column,
-            colgemma_column,
-            heatmap_checkbox,
-            pdf_upload,
-            index_btn,
-            query_input,
-            top_k_slider,
-            query_btn,
-            index_status,
-        ],
-    )
-
-    # Event handlers - Main Interface
-    def index_with_current_models(pdf_file):
-        """Index document with currently loaded models."""
-        if pdf_file is None:
-            yield "⚠️ Please upload a PDF document first."
-            return
-
-        model_choice = get_model_choice_from_loaded()
-        if not model_choice:
-            yield "⚠️ No models loaded. Please load a model first."
-            return
-
-        # Use generator from index_document
-        for status in index_document(pdf_file, model_choice):
-            yield status
-
-    def query_with_current_models(query, top_k, show_heatmap):
-        """Query with currently loaded models."""
-        model_choice = get_model_choice_from_loaded()
-        if not model_choice:
-            return "⚠️ No models loaded. Please load a model first.", None, None, None
-
-        return query_documents(query, model_choice, top_k, show_heatmap)
-
+    # Event handlers
    index_btn.click(
-        fn=index_with_current_models,
-        inputs=[pdf_upload],
+        fn=index_document,
+        inputs=[pdf_upload, model_select],
        outputs=[index_status],
    )
 
    query_btn.click(
-        fn=query_with_current_models,
-        inputs=[query_input, top_k_slider, heatmap_checkbox],
+        fn=query_documents,
+        inputs=[query_input, model_select, top_k_slider, heatmap_checkbox],
        outputs=[bigemma_results, colgemma_results, bigemma_gallery, colgemma_gallery],
    )
 
-    # Enable queue for handling multiple requests
-    demo.queue(max_size=20)
-
 # Launch the app
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
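
For context, the two patterns this commit leans on — raising `gr.Error` so indexing failures surface in the Gradio UI rather than only in the server log, and the bare `@spaces.GPU` decorator that replaces the removed `DEFAULT_DURATION`-based form — are shown in isolation in the minimal sketch below. The sketch is illustrative only, not part of the commit; the handler and component names are invented.

```python
# Illustrative sketch (not from this commit): raising gr.Error inside a
# handler surfaces the failure to the user, while a bare @spaces.GPU
# decorator requests a ZeroGPU slice for its default duration.
import gradio as gr
import spaces  # available on Hugging Face ZeroGPU Spaces
from pdf2image import convert_from_path


@spaces.GPU  # no duration argument: ZeroGPU's default is used
def handle_index(pdf_file):
    if pdf_file is None:
        # A returned string is ordinary output and lands in the status Textbox.
        return "⚠️ Please upload a PDF document first."
    try:
        pages = convert_from_path(pdf_file.name, dpi=200)
    except Exception as e:
        # Raising gr.Error aborts the event and shows the message in the UI.
        raise gr.Error(f"Failed to convert PDF: {e}")
    return f"✅ Converted {len(pages)} pages"


with gr.Blocks() as demo:
    pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
    btn = gr.Button("Index")
    status = gr.Textbox(label="Status", interactive=False)
    btn.click(fn=handle_index, inputs=[pdf], outputs=[status])

demo.launch()
```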