Files changed (4)
  1. README.md +1 -1
  2. app.py +402 -327
  3. pyproject.toml +60 -0
  4. uv.lock +0 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐠
  colorFrom: yellow
  colorTo: gray
  sdk: gradio
- sdk_version: 5.49.1
+ sdk_version: 5.47.2
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py CHANGED
@@ -1,37 +1,106 @@
1
  import colorsys
2
  import gc
3
- import os
4
- from typing import Optional
 
 
5
 
6
  import cv2
7
  import gradio as gr
8
  import numpy as np
 
9
  import torch
10
  from gradio.themes import Soft
11
  from PIL import Image, ImageDraw, ImageFont
12
-
13
  from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
14
 
 
 
 
15
 
16
- def get_device_and_dtype() -> tuple[str, torch.dtype]:
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- dtype = torch.bfloat16
19
- return device, dtype
20
 
 
 
 
21
 
22
- _GLOBAL_DEVICE, _GLOBAL_DTYPE = get_device_and_dtype()
23
- _GLOBAL_MODEL_REPO_ID = "facebook/sam3"
24
- _GLOBAL_TOKEN = os.getenv("HF_TOKEN")
25
 
26
- _GLOBAL_TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(
27
- _GLOBAL_MODEL_REPO_ID, torch_dtype=_GLOBAL_DTYPE, device_map=_GLOBAL_DEVICE
28
- ).eval()
29
- _GLOBAL_TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID, token=_GLOBAL_TOKEN)
 
 
 
30
 
31
- _GLOBAL_TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(_GLOBAL_MODEL_REPO_ID, token=_GLOBAL_TOKEN)
32
- _GLOBAL_TEXT_VIDEO_MODEL = _GLOBAL_TEXT_VIDEO_MODEL.to(_GLOBAL_DEVICE, dtype=_GLOBAL_DTYPE).eval()
33
- _GLOBAL_TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(_GLOBAL_MODEL_REPO_ID, token=_GLOBAL_TOKEN)
34
- print("Models loaded successfully!")
 
35
 
36
 
37
  def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
@@ -106,10 +175,10 @@ def pastel_color_for_prompt(prompt_text: str) -> tuple[int, int, int]:
106
 
107
 
108
  class AppState:
109
- def __init__(self):
110
  self.reset()
111
 
112
- def reset(self):
113
  self.video_frames: list[Image.Image] = []
114
  self.inference_session = None
115
  self.video_fps: float | None = None
@@ -130,7 +199,7 @@ class AppState:
130
  self.pending_box_start_obj_id: int | None = None
131
  self.active_tab: str = "point_box"
132
 
133
- def __repr__(self):
134
  return f"AppState(video_frames={len(self.video_frames)}, video_fps={self.video_fps}, masks_by_frame={len(self.masks_by_frame)}, color_by_obj={len(self.color_by_obj)})"
135
 
136
  @property
@@ -139,23 +208,20 @@ class AppState:
139
 
140
 
141
  def init_video_session(
142
- GLOBAL_STATE: gr.State, video: str | dict, active_tab: str = "point_box"
143
  ) -> tuple[AppState, int, int, Image.Image, str]:
144
- GLOBAL_STATE.video_frames = []
145
- GLOBAL_STATE.masks_by_frame = {}
146
- GLOBAL_STATE.color_by_obj = {}
147
- GLOBAL_STATE.color_by_prompt = {}
148
- GLOBAL_STATE.text_prompts_by_frame_obj = {}
149
- GLOBAL_STATE.clicks_by_frame_obj = {}
150
- GLOBAL_STATE.boxes_by_frame_obj = {}
151
- GLOBAL_STATE.composited_frames = {}
152
- GLOBAL_STATE.inference_session = None
153
- GLOBAL_STATE.active_tab = active_tab
154
-
155
- device = _GLOBAL_DEVICE
156
- dtype = _GLOBAL_DTYPE
157
-
158
- video_path: Optional[str] = None
159
  if isinstance(video, dict):
160
  video_path = video.get("name") or video.get("path") or video.get("data")
161
  elif isinstance(video, str):
@@ -170,7 +236,6 @@ def init_video_session(
170
  if len(frames) == 0:
171
  raise gr.Error("No frames could be loaded from the video.")
172
 
173
- MAX_SECONDS = 8.0
174
  trimmed_note = ""
175
  fps_in = info.get("fps")
176
  max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
@@ -179,44 +244,49 @@ def init_video_session(
179
  trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
180
  if isinstance(info, dict):
181
  info["num_frames"] = len(frames)
182
- GLOBAL_STATE.video_frames = frames
183
- GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
184
 
185
  raw_video = [np.array(frame) for frame in frames]
186
 
187
  if active_tab == "text":
188
- processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
189
- GLOBAL_STATE.inference_session = processor.init_video_session(
190
  video=frames,
191
- inference_device=device,
 
192
  processing_device="cpu",
193
  video_storage_device="cpu",
194
- dtype=dtype,
195
  )
196
  else:
197
- processor = _GLOBAL_TRACKER_PROCESSOR
198
- GLOBAL_STATE.inference_session = processor.init_video_session(
199
  video=raw_video,
200
- inference_device=device,
201
- video_storage_device="cpu",
202
  processing_device="cpu",
203
- inference_state_device=device,
204
- dtype=dtype,
205
  )
206
 
 
 
 
 
207
  first_frame = frames[0]
208
  max_idx = len(frames) - 1
209
  if active_tab == "text":
210
  status = (
211
- f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
212
- f"Device: {device}, dtype: bfloat16. Ready for text prompting."
213
  )
214
  else:
215
  status = (
216
- f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
217
- f"Device: {device}, dtype: bfloat16. Video session initialized."
218
  )
219
- return GLOBAL_STATE, 0, max_idx, first_frame, status
220
 
221
 
222
  def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
@@ -288,7 +358,7 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
288
  try:
289
  font = ImageFont.truetype(font_path, font_size)
290
  break
291
- except (OSError, IOError):
292
  continue
293
  if font is None:
294
  # Fallback to default font
@@ -340,7 +410,7 @@ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
340
  return compose_frame(state, frame_idx)
341
 
342
 
343
- def _get_prompt_for_obj(state: AppState, obj_id: int) -> Optional[str]:
344
  """Get the prompt text associated with an object ID."""
345
  # Priority 1: Check text_prompts_by_frame_obj (most reliable)
346
  for frame_texts in state.text_prompts_by_frame_obj.values():
@@ -348,19 +418,18 @@ def _get_prompt_for_obj(state: AppState, obj_id: int) -> Optional[str]:
348
  return frame_texts[obj_id].strip()
349
 
350
  # Priority 2: Check inference session mapping
351
- if state.inference_session is not None:
352
- if (
353
- hasattr(state.inference_session, "obj_id_to_prompt_id")
354
- and obj_id in state.inference_session.obj_id_to_prompt_id
355
- ):
356
- prompt_id = state.inference_session.obj_id_to_prompt_id[obj_id]
357
- if hasattr(state.inference_session, "prompts") and prompt_id in state.inference_session.prompts:
358
- return state.inference_session.prompts[prompt_id].strip()
359
 
360
  return None
361
 
362
 
363
- def _ensure_color_for_obj(state: AppState, obj_id: int):
364
  """Assign color to object based on its prompt if available, otherwise use object ID."""
365
  prompt_text = _get_prompt_for_obj(state, obj_id)
366
 
@@ -375,6 +444,7 @@ def _ensure_color_for_obj(state: AppState, obj_id: int):
375
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
376
 
377
 
 
378
  def on_image_click(
379
  img: Image.Image | np.ndarray,
380
  state: AppState,
@@ -383,12 +453,13 @@ def on_image_click(
383
  label: str,
384
  clear_old: bool,
385
  evt: gr.SelectData,
386
- ) -> Image.Image:
387
  if state is None or state.inference_session is None:
388
  return img
389
 
390
- model = _GLOBAL_TRACKER_MODEL
391
- processor = _GLOBAL_TRACKER_PROCESSOR
 
392
 
393
  x = y = None
394
  if evt is not None:
@@ -417,29 +488,28 @@ def on_image_click(
417
  state.pending_box_start_obj_id = ann_obj_id
418
  state.composited_frames.pop(ann_frame_idx, None)
419
  return update_frame_display(state, ann_frame_idx)
420
- else:
421
- x1, y1 = state.pending_box_start
422
- x2, y2 = int(x), int(y)
423
- state.pending_box_start = None
424
- state.pending_box_start_frame_idx = None
425
- state.pending_box_start_obj_id = None
426
- state.composited_frames.pop(ann_frame_idx, None)
427
- x_min, y_min = min(x1, x2), min(y1, y2)
428
- x_max, y_max = max(x1, x2), max(y1, y2)
429
 
430
- box = [[[x_min, y_min, x_max, y_max]]]
431
- processor.add_inputs_to_inference_session(
432
- inference_session=state.inference_session,
433
- frame_idx=ann_frame_idx,
434
- obj_ids=ann_obj_id,
435
- input_boxes=box,
436
- )
437
 
438
- frame_boxes = state.boxes_by_frame_obj.setdefault(ann_frame_idx, {})
439
- obj_boxes = frame_boxes.setdefault(ann_obj_id, [])
440
- obj_boxes.clear()
441
- obj_boxes.append((x_min, y_min, x_max, y_max))
442
- state.composited_frames.pop(ann_frame_idx, None)
443
  else:
444
  label_int = 1 if str(label).lower().startswith("pos") else 0
445
 
@@ -485,23 +555,26 @@ def on_image_click(
485
 
486
  state.composited_frames.pop(ann_frame_idx, None)
487
 
488
- return update_frame_display(state, ann_frame_idx)
489
 
 
490
 
 
 
491
  def on_text_prompt(
492
  state: AppState,
493
  frame_idx: int,
494
  text_prompt: str,
495
- ) -> tuple[Image.Image, str, str]:
496
  if state is None or state.inference_session is None:
497
  return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
498
 
499
- model = _GLOBAL_TEXT_VIDEO_MODEL
500
- processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
501
 
502
  if not text_prompt or not text_prompt.strip():
503
  active_prompts = _get_active_prompts_display(state)
504
- return update_frame_display(state, int(frame_idx)), "Please enter a text prompt.", active_prompts
505
 
506
  frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
507
 
@@ -509,7 +582,9 @@ def on_text_prompt(
509
  prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
510
  if not prompt_texts:
511
  active_prompts = _get_active_prompts_display(state)
512
- return update_frame_display(state, int(frame_idx)), "Please enter a valid text prompt.", active_prompts
 
 
513
 
514
  # Add text prompt(s) - supports both single string and list of strings
515
  state.inference_session = processor.add_text_prompt(
@@ -593,7 +668,10 @@ def on_text_prompt(
593
  status = f"Processed text prompt(s) {prompts_str} on frame {frame_idx}. No objects detected."
594
 
595
  active_prompts = _get_active_prompts_display(state)
596
- return update_frame_display(state, int(frame_idx)), status, active_prompts
 
 
 
597
 
598
 
599
  def _get_active_prompts_display(state: AppState) -> str:
@@ -610,32 +688,35 @@ def _get_active_prompts_display(state: AppState) -> str:
610
  return "**Active prompts:** None"
611
 
612
 
613
- def propagate_masks(GLOBAL_STATE: gr.State):
614
- if GLOBAL_STATE is None:
615
- return GLOBAL_STATE, "Load a video first.", gr.update()
 
616
 
617
- if GLOBAL_STATE.active_tab != "text" and GLOBAL_STATE.inference_session is None:
618
- return GLOBAL_STATE, "Load a video first.", gr.update()
619
 
620
- total = max(1, GLOBAL_STATE.num_frames)
621
  processed = 0
622
 
623
- yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
624
 
625
  last_frame_idx = 0
626
 
627
  with torch.no_grad():
628
- if GLOBAL_STATE.active_tab == "text":
629
- if GLOBAL_STATE.inference_session is None:
630
- yield GLOBAL_STATE, "Text video model not loaded.", gr.update()
631
  return
632
 
633
- model = _GLOBAL_TEXT_VIDEO_MODEL
634
- processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
 
 
635
 
636
  # Collect all unique prompts from existing frame annotations
637
  text_prompt_to_obj_ids = {}
638
- for frame_idx, frame_texts in GLOBAL_STATE.text_prompts_by_frame_obj.items():
639
  for obj_id, text_prompt in frame_texts.items():
640
  if text_prompt not in text_prompt_to_obj_ids:
641
  text_prompt_to_obj_ids[text_prompt] = []
@@ -643,8 +724,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
643
  text_prompt_to_obj_ids[text_prompt].append(obj_id)
644
 
645
  # Also check if there are prompts already in the inference session
646
- if hasattr(GLOBAL_STATE.inference_session, "prompts") and GLOBAL_STATE.inference_session.prompts:
647
- for prompt_text in GLOBAL_STATE.inference_session.prompts.values():
648
  if prompt_text not in text_prompt_to_obj_ids:
649
  text_prompt_to_obj_ids[prompt_text] = []
650
 
@@ -652,31 +733,30 @@ def propagate_masks(GLOBAL_STATE: gr.State):
652
  text_prompt_to_obj_ids[text_prompt].sort()
653
 
654
  if not text_prompt_to_obj_ids:
655
- yield GLOBAL_STATE, "No text prompts found. Please add a text prompt first.", gr.update()
 
656
  return
657
 
658
  # Add all prompts to the inference session (processor handles deduplication)
659
- for text_prompt in text_prompt_to_obj_ids.keys():
660
- GLOBAL_STATE.inference_session = processor.add_text_prompt(
661
- inference_session=GLOBAL_STATE.inference_session,
662
  text=text_prompt,
663
  )
664
 
665
- earliest_frame = (
666
- min(GLOBAL_STATE.text_prompts_by_frame_obj.keys()) if GLOBAL_STATE.text_prompts_by_frame_obj else 0
667
- )
668
 
669
- frames_to_track = GLOBAL_STATE.num_frames - earliest_frame
670
 
671
  outputs_per_frame = {}
672
 
673
  for model_outputs in model.propagate_in_video_iterator(
674
- inference_session=GLOBAL_STATE.inference_session,
675
  start_frame_idx=earliest_frame,
676
  max_frame_num_to_track=frames_to_track,
677
  ):
678
  processed_outputs = processor.postprocess_outputs(
679
- GLOBAL_STATE.inference_session,
680
  model_outputs,
681
  )
682
  frame_idx = model_outputs.frame_idx
@@ -687,8 +767,8 @@ def propagate_masks(GLOBAL_STATE: gr.State):
687
  scores = processed_outputs["scores"]
688
  prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
689
 
690
- masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
691
- frame_texts = GLOBAL_STATE.text_prompts_by_frame_obj.setdefault(frame_idx, {})
692
 
693
  num_objects = len(object_ids)
694
  if num_objects > 0:
@@ -715,183 +795,185 @@ def propagate_masks(GLOBAL_STATE: gr.State):
715
  # Store prompt and assign color
716
  if found_prompt:
717
  frame_texts[current_obj_id] = found_prompt.strip()
718
- _ensure_color_for_obj(GLOBAL_STATE, current_obj_id)
719
 
720
- GLOBAL_STATE.composited_frames.pop(frame_idx, None)
721
  last_frame_idx = frame_idx
722
  processed += 1
723
  if processed % 30 == 0 or processed == total:
724
- yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
 
725
  else:
726
- if GLOBAL_STATE.inference_session is None:
727
- yield GLOBAL_STATE, "Tracker model not loaded.", gr.update()
728
  return
729
 
730
- model = _GLOBAL_TRACKER_MODEL
731
- processor = _GLOBAL_TRACKER_PROCESSOR
732
 
733
- for sam2_video_output in model.propagate_in_video_iterator(
734
- inference_session=GLOBAL_STATE.inference_session
735
- ):
736
  video_res_masks = processor.post_process_masks(
737
  [sam2_video_output.pred_masks],
738
- original_sizes=[
739
- [GLOBAL_STATE.inference_session.video_height, GLOBAL_STATE.inference_session.video_width]
740
- ],
741
  )[0]
742
 
743
  frame_idx = sam2_video_output.frame_idx
744
- for i, out_obj_id in enumerate(GLOBAL_STATE.inference_session.obj_ids):
745
- _ensure_color_for_obj(GLOBAL_STATE, int(out_obj_id))
746
  mask_2d = video_res_masks[i].cpu().numpy()
747
- masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
748
  masks_for_frame[int(out_obj_id)] = mask_2d
749
- GLOBAL_STATE.composited_frames.pop(frame_idx, None)
750
 
751
  last_frame_idx = frame_idx
752
  processed += 1
753
  if processed % 30 == 0 or processed == total:
754
- yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
 
755
 
756
  text = f"Propagated masks across {processed} frames."
757
- yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
 
758
 
759
 
760
- def reset_prompts(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, str, str]:
761
  """Reset prompts and all outputs, but keep processed frames and cached vision features."""
762
- if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
763
- active_prompts = _get_active_prompts_display(GLOBAL_STATE)
764
- return GLOBAL_STATE, None, "No active session to reset.", active_prompts
765
 
766
- if GLOBAL_STATE.active_tab != "text":
767
- active_prompts = _get_active_prompts_display(GLOBAL_STATE)
768
- return GLOBAL_STATE, None, "Reset prompts is only available for text prompting mode.", active_prompts
769
 
770
  # Reset inference session tracking data but keep cache and processed frames
771
- if hasattr(GLOBAL_STATE.inference_session, "reset_tracking_data"):
772
- GLOBAL_STATE.inference_session.reset_tracking_data()
773
 
774
  # Manually clear prompts (reset_tracking_data doesn't clear prompts themselves)
775
- if hasattr(GLOBAL_STATE.inference_session, "prompts"):
776
- GLOBAL_STATE.inference_session.prompts.clear()
777
- if hasattr(GLOBAL_STATE.inference_session, "prompt_input_ids"):
778
- GLOBAL_STATE.inference_session.prompt_input_ids.clear()
779
- if hasattr(GLOBAL_STATE.inference_session, "prompt_embeddings"):
780
- GLOBAL_STATE.inference_session.prompt_embeddings.clear()
781
- if hasattr(GLOBAL_STATE.inference_session, "prompt_attention_masks"):
782
- GLOBAL_STATE.inference_session.prompt_attention_masks.clear()
783
- if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_prompt_id"):
784
- GLOBAL_STATE.inference_session.obj_id_to_prompt_id.clear()
785
 
786
  # Reset detection-tracking fusion state
787
- if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_score"):
788
- GLOBAL_STATE.inference_session.obj_id_to_score.clear()
789
- if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_tracker_score_frame_wise"):
790
- GLOBAL_STATE.inference_session.obj_id_to_tracker_score_frame_wise.clear()
791
- if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_last_occluded"):
792
- GLOBAL_STATE.inference_session.obj_id_to_last_occluded.clear()
793
- if hasattr(GLOBAL_STATE.inference_session, "max_obj_id"):
794
- GLOBAL_STATE.inference_session.max_obj_id = -1
795
- if hasattr(GLOBAL_STATE.inference_session, "obj_first_frame_idx"):
796
- GLOBAL_STATE.inference_session.obj_first_frame_idx.clear()
797
- if hasattr(GLOBAL_STATE.inference_session, "unmatched_frame_inds"):
798
- GLOBAL_STATE.inference_session.unmatched_frame_inds.clear()
799
- if hasattr(GLOBAL_STATE.inference_session, "overlap_pair_to_frame_inds"):
800
- GLOBAL_STATE.inference_session.overlap_pair_to_frame_inds.clear()
801
- if hasattr(GLOBAL_STATE.inference_session, "trk_keep_alive"):
802
- GLOBAL_STATE.inference_session.trk_keep_alive.clear()
803
- if hasattr(GLOBAL_STATE.inference_session, "removed_obj_ids"):
804
- GLOBAL_STATE.inference_session.removed_obj_ids.clear()
805
- if hasattr(GLOBAL_STATE.inference_session, "suppressed_obj_ids"):
806
- GLOBAL_STATE.inference_session.suppressed_obj_ids.clear()
807
- if hasattr(GLOBAL_STATE.inference_session, "hotstart_removed_obj_ids"):
808
- GLOBAL_STATE.inference_session.hotstart_removed_obj_ids.clear()
809
 
810
  # Clear all app state outputs
811
- GLOBAL_STATE.masks_by_frame.clear()
812
- GLOBAL_STATE.text_prompts_by_frame_obj.clear()
813
- GLOBAL_STATE.composited_frames.clear()
814
- GLOBAL_STATE.color_by_obj.clear()
815
- GLOBAL_STATE.color_by_prompt.clear()
816
 
817
  # Update display
818
- current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
819
- current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
820
- preview_img = update_frame_display(GLOBAL_STATE, current_idx)
821
- active_prompts = _get_active_prompts_display(GLOBAL_STATE)
822
  status = "Prompts and outputs reset. Processed frames and cached vision features preserved."
823
 
824
- return GLOBAL_STATE, preview_img, status, active_prompts
825
 
826
 
827
- def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, str]:
828
- if not GLOBAL_STATE.video_frames:
829
- return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
830
 
831
- if GLOBAL_STATE.active_tab == "text":
832
- if GLOBAL_STATE.video_frames:
833
- processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
834
- GLOBAL_STATE.inference_session = processor.init_video_session(
835
- video=GLOBAL_STATE.video_frames,
836
- inference_device=_GLOBAL_DEVICE,
837
  processing_device="cpu",
838
  video_storage_device="cpu",
839
- dtype=_GLOBAL_DTYPE,
840
- )
841
- elif GLOBAL_STATE.inference_session is not None and hasattr(
842
- GLOBAL_STATE.inference_session, "reset_inference_session"
843
- ):
844
- GLOBAL_STATE.inference_session.reset_inference_session()
845
- else:
846
- if GLOBAL_STATE.video_frames:
847
- processor = _GLOBAL_TRACKER_PROCESSOR
848
- raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
849
- GLOBAL_STATE.inference_session = processor.init_video_session(
850
- video=raw_video,
851
- inference_device=_GLOBAL_DEVICE,
852
- video_storage_device="cpu",
853
- processing_device="cpu",
854
- dtype=_GLOBAL_DTYPE,
855
  )
856
 
857
- GLOBAL_STATE.masks_by_frame.clear()
858
- GLOBAL_STATE.clicks_by_frame_obj.clear()
859
- GLOBAL_STATE.boxes_by_frame_obj.clear()
860
- GLOBAL_STATE.text_prompts_by_frame_obj.clear()
861
- GLOBAL_STATE.composited_frames.clear()
862
- GLOBAL_STATE.color_by_obj.clear()
863
- GLOBAL_STATE.color_by_prompt.clear()
864
- GLOBAL_STATE.pending_box_start = None
865
- GLOBAL_STATE.pending_box_start_frame_idx = None
866
- GLOBAL_STATE.pending_box_start_obj_id = None
867
 
868
  gc.collect()
869
 
870
- current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
871
- current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
872
- preview_img = update_frame_display(GLOBAL_STATE, current_idx)
873
- slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
874
  slider_value = gr.update(value=current_idx)
875
  status = "Session reset. Prompts cleared; video preserved."
876
- active_prompts = _get_active_prompts_display(GLOBAL_STATE)
877
- return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status, active_prompts
878
 
879
 
880
- def _on_video_change_pointbox(GLOBAL_STATE: gr.State, video):
881
- GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "point_box")
882
  return (
883
- GLOBAL_STATE,
884
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
885
  first_frame,
886
  status,
887
  )
888
 
889
 
890
- def _on_video_change_text(GLOBAL_STATE: gr.State, video):
891
- GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "text")
892
- active_prompts = _get_active_prompts_display(GLOBAL_STATE)
 
 
893
  return (
894
- GLOBAL_STATE,
895
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
896
  first_frame,
897
  status,
@@ -899,10 +981,8 @@ def _on_video_change_text(GLOBAL_STATE: gr.State, video):
899
  )
900
 
901
 
902
- theme = Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")
903
-
904
- with gr.Blocks(title="SAM3", theme=theme) as demo:
905
- GLOBAL_STATE = gr.State(AppState())
906
 
907
  gr.Markdown(
908
  """
@@ -934,15 +1014,13 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
934
 
935
  with gr.Row():
936
  with gr.Column(scale=1):
937
- video_in_text = gr.Video(label="Upload video", sources=["upload", "webcam"], interactive=True)
938
  load_status_text = gr.Markdown(visible=True)
939
  reset_btn_text = gr.Button("Reset Session", variant="secondary")
940
  with gr.Column(scale=2):
941
- preview_text = gr.Image(label="Preview", interactive=True)
942
  with gr.Row():
943
- frame_slider_text = gr.Slider(
944
- label="Frame", minimum=0, maximum=0, step=1, value=0, interactive=True
945
- )
946
  with gr.Column(scale=0):
947
  propagate_btn_text = gr.Button("Propagate across video", variant="primary")
948
  propagate_status_text = gr.Markdown(visible=True)
@@ -969,12 +1047,9 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
969
  ]
970
  with gr.Row():
971
  gr.Examples(
972
- examples=examples_list_text,
973
- inputs=[GLOBAL_STATE, video_in_text],
974
- fn=_on_video_change_text,
975
- outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
976
  label="Examples",
977
- cache_examples=False,
 
978
  examples_per_page=5,
979
  )
980
 
@@ -1000,17 +1075,13 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
1000
 
1001
  with gr.Row():
1002
  with gr.Column(scale=1):
1003
- video_in_pointbox = gr.Video(
1004
- label="Upload video", sources=["upload", "webcam"], interactive=True, max_length=7
1005
- )
1006
  load_status_pointbox = gr.Markdown(visible=True)
1007
  reset_btn_pointbox = gr.Button("Reset Session", variant="secondary")
1008
  with gr.Column(scale=2):
1009
- preview_pointbox = gr.Image(label="Preview", interactive=True)
1010
  with gr.Row():
1011
- frame_slider_pointbox = gr.Slider(
1012
- label="Frame", minimum=0, maximum=0, step=1, value=0, interactive=True
1013
- )
1014
  with gr.Column(scale=0):
1015
  propagate_btn_pointbox = gr.Button("Propagate across video", variant="primary")
1016
  propagate_status_pointbox = gr.Markdown(visible=True)
@@ -1032,105 +1103,101 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
1032
  ]
1033
  with gr.Row():
1034
  gr.Examples(
1035
- examples=examples_list_pointbox,
1036
- inputs=[GLOBAL_STATE, video_in_pointbox],
1037
- fn=_on_video_change_pointbox,
1038
- outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
1039
  label="Examples",
1040
- cache_examples=False,
 
1041
  examples_per_page=5,
1042
  )
1043
 
1044
  video_in_pointbox.change(
1045
- _on_video_change_pointbox,
1046
- inputs=[GLOBAL_STATE, video_in_pointbox],
1047
- outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
1048
  show_progress=True,
1049
  )
1050
 
1051
- def _sync_frame_idx_pointbox(state_in: AppState, idx: int):
1052
  if state_in is not None:
1053
  state_in.current_frame_idx = int(idx)
1054
  return update_frame_display(state_in, int(idx))
1055
 
1056
  frame_slider_pointbox.change(
1057
- _sync_frame_idx_pointbox,
1058
- inputs=[GLOBAL_STATE, frame_slider_pointbox],
1059
  outputs=preview_pointbox,
1060
  )
1061
 
1062
  video_in_text.change(
1063
- _on_video_change_text,
1064
- inputs=[GLOBAL_STATE, video_in_text],
1065
- outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
1066
  show_progress=True,
1067
  )
1068
 
1069
- def _sync_frame_idx_text(state_in: AppState, idx: int):
1070
  if state_in is not None:
1071
  state_in.current_frame_idx = int(idx)
1072
  return update_frame_display(state_in, int(idx))
1073
 
1074
  frame_slider_text.change(
1075
- _sync_frame_idx_text,
1076
- inputs=[GLOBAL_STATE, frame_slider_text],
1077
  outputs=preview_text,
1078
  )
1079
 
1080
- def _sync_obj_id(s: AppState, oid):
1081
  if s is not None and oid is not None:
1082
  s.current_obj_id = int(oid)
1083
- return gr.update()
1084
 
1085
- obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])
 
 
 
1086
 
1087
- def _sync_label(s: AppState, lab: str):
1088
  if s is not None and lab is not None:
1089
  s.current_label = str(lab)
1090
- return gr.update()
1091
 
1092
- label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])
 
 
 
1093
 
1094
- def _sync_prompt_type(s: AppState, val: str):
1095
  if s is not None and val is not None:
1096
  s.current_prompt_type = str(val)
1097
  s.pending_box_start = None
1098
  is_points = str(val).lower() == "points"
1099
- updates = [
1100
  gr.update(visible=is_points),
1101
  gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
1102
- ]
1103
- return updates
1104
 
1105
  prompt_type.change(
1106
- _sync_prompt_type,
1107
- inputs=[GLOBAL_STATE, prompt_type],
1108
  outputs=[label_radio, clear_old_chk],
1109
  )
1110
 
1111
  preview_pointbox.select(
1112
- on_image_click,
1113
- [preview_pointbox, GLOBAL_STATE, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
1114
- preview_pointbox,
1115
  )
1116
 
1117
- def _on_text_apply(state: AppState, frame_idx: int, text: str):
1118
- img, status, active_prompts = on_text_prompt(state, frame_idx, text)
1119
- return img, status, active_prompts
1120
-
1121
  text_apply_btn.click(
1122
- _on_text_apply,
1123
- inputs=[GLOBAL_STATE, frame_slider_text, text_prompt_input],
1124
- outputs=[preview_text, text_status, active_prompts_display],
1125
  )
1126
 
1127
  reset_prompts_btn.click(
1128
- reset_prompts,
1129
- inputs=[GLOBAL_STATE],
1130
- outputs=[GLOBAL_STATE, preview_text, text_status, active_prompts_display],
1131
  )
1132
 
1133
- def _render_video(s: AppState):
1134
  if s is None or s.num_frames == 0:
1135
  raise gr.Error("Load a video first.")
1136
  fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
@@ -1144,44 +1211,52 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
1144
  frames_np.append(np.array(img)[:, :, ::-1])
1145
  if (idx + 1) % 60 == 0:
1146
  gc.collect()
1147
- out_path = "/tmp/sam3_playback.mp4"
1148
  try:
1149
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
1150
- writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
1151
- for fr_bgr in frames_np:
1152
- writer.write(fr_bgr)
1153
- writer.release()
1154
- return out_path
 
1155
  except Exception as e:
1156
  print(f"Failed to render video with cv2: {e}")
1157
  raise gr.Error(f"Failed to render video: {e}")
1158
 
1159
- render_btn_pointbox.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_pointbox])
1160
- render_btn_text.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_text])
1161
 
1162
  propagate_btn_pointbox.click(
1163
- propagate_masks,
1164
- inputs=[GLOBAL_STATE],
1165
- outputs=[GLOBAL_STATE, propagate_status_pointbox, frame_slider_pointbox],
1166
  )
1167
 
1168
  propagate_btn_text.click(
1169
- propagate_masks,
1170
- inputs=[GLOBAL_STATE],
1171
- outputs=[GLOBAL_STATE, propagate_status_text, frame_slider_text],
1172
  )
1173
 
1174
  reset_btn_pointbox.click(
1175
- reset_session,
1176
- inputs=GLOBAL_STATE,
1177
- outputs=[GLOBAL_STATE, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
1178
  )
1179
 
1180
  reset_btn_text.click(
1181
- reset_session,
1182
- inputs=GLOBAL_STATE,
1183
  outputs=[
1184
- GLOBAL_STATE,
1185
  preview_text,
1186
  frame_slider_text,
1187
  frame_slider_text,
 
1
  import colorsys
2
  import gc
3
+ import tempfile
4
+ from collections import defaultdict
5
+ from collections.abc import Iterator, Mapping, Sequence
6
+ from typing import Any
7
 
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
11
+ import spaces
12
  import torch
13
  from gradio.themes import Soft
14
  from PIL import Image, ImageDraw, ImageFont
 
15
  from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
16
 
17
+ MODEL_ID = "facebook/sam3"
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ DTYPE = torch.bfloat16
20
 
21
+ TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE).eval()
22
+ TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)
 
 
23
 
24
+ TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
25
+ TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)
26
+ print("Models loaded successfully!")
27
 
28
+ MAX_SECONDS = 8.0
29
+
30
+
31
+ def to_device_recursive(obj: Any, device: str | torch.device) -> Any: # noqa: ANN401
32
+ """Return a new object where all torch.Tensors reachable from `obj` are moved to the given device.
33
+
34
+ - Does NOT mutate the original object.
35
+ - Handles:
36
+ * torch.Tensor
37
+ * Mapping (e.g. dict, defaultdict, OrderedDict, etc.)
38
+ * Sequence (e.g. list, tuple) except str/bytes
39
+ * Custom classes with attributes (__dict__)
40
+ - Tries to preserve container types where reasonable.
41
+ """
42
+ device = torch.device(device)
43
+ memo = {}
44
+
45
+ def _convert(x: Any) -> Any: # noqa: ANN401, C901
46
+ obj_id = id(x)
47
+ if obj_id in memo:
48
+ return memo[obj_id]
49
+
50
+ # 1. Tensor
51
+ if isinstance(x, torch.Tensor):
52
+ y = x.to(device)
53
+ memo[obj_id] = y
54
+ return y
55
+
56
+ # 2. Mapping (dict, defaultdict, etc.)
57
+ if isinstance(x, Mapping):
58
+ # Special case: defaultdict
59
+ if isinstance(x, defaultdict):
60
+ y = defaultdict(x.default_factory)
61
+ memo[obj_id] = y
62
+ for k, v in x.items():
63
+ y[k] = _convert(v)
64
+ return y
65
+
66
+ # Try to rebuild the same type using (key, value) pairs
67
+ try:
68
+ y = type(x)((k, _convert(v)) for k, v in x.items())
69
+ memo[obj_id] = y
70
+ return y
71
+ except TypeError:
72
+ # Fallback: plain dict
73
+ y = {k: _convert(v) for k, v in x.items()}
74
+ memo[obj_id] = y
75
+ return y
76
+
77
+ # 3. Sequence (list/tuple/etc.) but not str/bytes
78
+ if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
79
+ if isinstance(x, list):
80
+ y = [_convert(v) for v in x]
81
+ elif isinstance(x, tuple):
82
+ y = type(x)(_convert(v) for v in x)
83
+ else:
84
+ try:
85
+ y = type(x)(_convert(v) for v in x)
86
+ except TypeError:
87
+ y = [_convert(v) for v in x]
88
+ memo[obj_id] = y
89
+ return y
90
 
91
+ # 4. Custom object with attributes (__dict__)
92
+ if hasattr(x, "__dict__") and not isinstance(x, type):
93
+ new_obj = x.__class__.__new__(x.__class__)
94
+ memo[obj_id] = new_obj
95
+ for name, value in vars(x).items():
96
+ setattr(new_obj, name, _convert(value))
97
+ return new_obj
98
 
99
+ # 5. Everything else → keep as-is
100
+ memo[obj_id] = x
101
+ return x
102
+
103
+ return _convert(obj)
104
 
105
 
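A minimal usage sketch of the `to_device_recursive` helper defined above (an assumption: the helper is in scope, e.g. defined earlier in the same module; `FakeSession` is a hypothetical stand-in for the real inference session object). It mirrors how the handlers below move the nested session onto the GPU for inference and back to CPU before storing it in Gradio state.

```python
import torch

# Hypothetical stand-in for the SAM3 inference session: a plain object whose
# attributes contain tensors nested inside lists and dicts.
class FakeSession:
    def __init__(self) -> None:
        self.frames = [torch.zeros(3, 4), torch.ones(2, 2)]
        self.cache = {"embeddings": torch.randn(1, 8)}


session = FakeSession()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumes to_device_recursive (defined above) is importable/in scope.
gpu_session = to_device_recursive(session, device)   # every reachable tensor moved to `device`
# ... run inference with gpu_session here ...
cpu_session = to_device_recursive(gpu_session, "cpu")  # park tensors on CPU before gr.State

print(cpu_session.cache["embeddings"].device)  # -> cpu
```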
106
  def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
 
175
 
176
 
177
  class AppState:
178
+ def __init__(self) -> None:
179
  self.reset()
180
 
181
+ def reset(self) -> None:
182
  self.video_frames: list[Image.Image] = []
183
  self.inference_session = None
184
  self.video_fps: float | None = None
 
199
  self.pending_box_start_obj_id: int | None = None
200
  self.active_tab: str = "point_box"
201
 
202
+ def __repr__(self) -> str:
203
  return f"AppState(video_frames={len(self.video_frames)}, video_fps={self.video_fps}, masks_by_frame={len(self.masks_by_frame)}, color_by_obj={len(self.color_by_obj)})"
204
 
205
  @property
 
208
 
209
 
210
  def init_video_session(
211
+ state: AppState, video: str | dict, active_tab: str = "point_box"
212
  ) -> tuple[AppState, int, int, Image.Image, str]:
213
+ state.video_frames = []
214
+ state.masks_by_frame = {}
215
+ state.color_by_obj = {}
216
+ state.color_by_prompt = {}
217
+ state.text_prompts_by_frame_obj = {}
218
+ state.clicks_by_frame_obj = {}
219
+ state.boxes_by_frame_obj = {}
220
+ state.composited_frames = {}
221
+ state.inference_session = None
222
+ state.active_tab = active_tab
223
+
224
+ video_path: str | None = None
 
 
 
225
  if isinstance(video, dict):
226
  video_path = video.get("name") or video.get("path") or video.get("data")
227
  elif isinstance(video, str):
 
236
  if len(frames) == 0:
237
  raise gr.Error("No frames could be loaded from the video.")
238
 
 
239
  trimmed_note = ""
240
  fps_in = info.get("fps")
241
  max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
 
244
  trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
245
  if isinstance(info, dict):
246
  info["num_frames"] = len(frames)
247
+ state.video_frames = frames
248
+ state.video_fps = float(fps_in) if fps_in else None
249
 
250
  raw_video = [np.array(frame) for frame in frames]
251
 
252
  if active_tab == "text":
253
+ processor = TEXT_VIDEO_PROCESSOR
254
+ state.inference_session = processor.init_video_session(
255
  video=frames,
256
+ inference_device="cpu",
257
+ inference_state_device="cpu",
258
  processing_device="cpu",
259
  video_storage_device="cpu",
260
+ dtype=DTYPE,
261
  )
262
  else:
263
+ processor = TRACKER_PROCESSOR
264
+ state.inference_session = processor.init_video_session(
265
  video=raw_video,
266
+ inference_device="cpu",
267
+ inference_state_device="cpu",
268
  processing_device="cpu",
269
+ video_storage_device="cpu",
270
+ dtype=DTYPE,
271
  )
272
 
273
+ state.inference_session.inference_device = DEVICE
274
+ state.inference_session.processing_device = DEVICE
275
+ state.inference_session.cache.inference_device = DEVICE
276
+
277
  first_frame = frames[0]
278
  max_idx = len(frames) - 1
279
  if active_tab == "text":
280
  status = (
281
+ f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
282
+ f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
283
  )
284
  else:
285
  status = (
286
+ f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
287
+ f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
288
  )
289
+ return state, 0, max_idx, first_frame, status
290
 
291
 
292
  def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
 
358
  try:
359
  font = ImageFont.truetype(font_path, font_size)
360
  break
361
+ except OSError:
362
  continue
363
  if font is None:
364
  # Fallback to default font
 
410
  return compose_frame(state, frame_idx)
411
 
412
 
413
+ def _get_prompt_for_obj(state: AppState, obj_id: int) -> str | None:
414
  """Get the prompt text associated with an object ID."""
415
  # Priority 1: Check text_prompts_by_frame_obj (most reliable)
416
  for frame_texts in state.text_prompts_by_frame_obj.values():
 
418
  return frame_texts[obj_id].strip()
419
 
420
  # Priority 2: Check inference session mapping
421
+ if state.inference_session is not None and (
422
+ hasattr(state.inference_session, "obj_id_to_prompt_id")
423
+ and obj_id in state.inference_session.obj_id_to_prompt_id
424
+ ):
425
+ prompt_id = state.inference_session.obj_id_to_prompt_id[obj_id]
426
+ if hasattr(state.inference_session, "prompts") and prompt_id in state.inference_session.prompts:
427
+ return state.inference_session.prompts[prompt_id].strip()
 
428
 
429
  return None
430
 
431
 
432
+ def _ensure_color_for_obj(state: AppState, obj_id: int) -> None:
433
  """Assign color to object based on its prompt if available, otherwise use object ID."""
434
  prompt_text = _get_prompt_for_obj(state, obj_id)
435
 
 
444
  state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
445
 
446
 
447
+ @spaces.GPU
448
  def on_image_click(
449
  img: Image.Image | np.ndarray,
450
  state: AppState,
 
453
  label: str,
454
  clear_old: bool,
455
  evt: gr.SelectData,
456
+ ) -> tuple[Image.Image, AppState]:
457
  if state is None or state.inference_session is None:
458
  return img
459
 
460
+ model = TRACKER_MODEL
461
+ processor = TRACKER_PROCESSOR
462
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
463
 
464
  x = y = None
465
  if evt is not None:
 
488
  state.pending_box_start_obj_id = ann_obj_id
489
  state.composited_frames.pop(ann_frame_idx, None)
490
  return update_frame_display(state, ann_frame_idx)
491
+ x1, y1 = state.pending_box_start
492
+ x2, y2 = int(x), int(y)
493
+ state.pending_box_start = None
494
+ state.pending_box_start_frame_idx = None
495
+ state.pending_box_start_obj_id = None
496
+ state.composited_frames.pop(ann_frame_idx, None)
497
+ x_min, y_min = min(x1, x2), min(y1, y2)
498
+ x_max, y_max = max(x1, x2), max(y1, y2)
 
499
 
500
+ box = [[[x_min, y_min, x_max, y_max]]]
501
+ processor.add_inputs_to_inference_session(
502
+ inference_session=state.inference_session,
503
+ frame_idx=ann_frame_idx,
504
+ obj_ids=ann_obj_id,
505
+ input_boxes=box,
506
+ )
507
 
508
+ frame_boxes = state.boxes_by_frame_obj.setdefault(ann_frame_idx, {})
509
+ obj_boxes = frame_boxes.setdefault(ann_obj_id, [])
510
+ obj_boxes.clear()
511
+ obj_boxes.append((x_min, y_min, x_max, y_max))
512
+ state.composited_frames.pop(ann_frame_idx, None)
513
  else:
514
  label_int = 1 if str(label).lower().startswith("pos") else 0
515
 
 
555
 
556
  state.composited_frames.pop(ann_frame_idx, None)
557
 
558
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
559
 
560
+ return update_frame_display(state, ann_frame_idx), state
561
 
562
+
563
+ @spaces.GPU
564
  def on_text_prompt(
565
  state: AppState,
566
  frame_idx: int,
567
  text_prompt: str,
568
+ ) -> tuple[Image.Image, str, str, AppState]:
569
  if state is None or state.inference_session is None:
570
  return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
571
 
572
+ model = TEXT_VIDEO_MODEL
573
+ processor = TEXT_VIDEO_PROCESSOR
574
 
575
  if not text_prompt or not text_prompt.strip():
576
  active_prompts = _get_active_prompts_display(state)
577
+ return update_frame_display(state, int(frame_idx)), "Please enter a text prompt.", active_prompts, state
578
 
579
  frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
580
 
 
582
  prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
583
  if not prompt_texts:
584
  active_prompts = _get_active_prompts_display(state)
585
+ return update_frame_display(state, int(frame_idx)), "Please enter a valid text prompt.", active_prompts, state
586
+
587
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
588
 
589
  # Add text prompt(s) - supports both single string and list of strings
590
  state.inference_session = processor.add_text_prompt(
 
668
  status = f"Processed text prompt(s) {prompts_str} on frame {frame_idx}. No objects detected."
669
 
670
  active_prompts = _get_active_prompts_display(state)
671
+
672
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
673
+
674
+ return update_frame_display(state, int(frame_idx)), status, active_prompts, state
675
 
676
 
677
  def _get_active_prompts_display(state: AppState) -> str:
 
688
  return "**Active prompts:** None"
689
 
690
 
691
+ @spaces.GPU
692
+ def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
693
+ if state is None:
694
+ return state, "Load a video first.", gr.update()
695
 
696
+ if state.active_tab != "text" and state.inference_session is None:
697
+ return state, "Load a video first.", gr.update()
698
 
699
+ total = max(1, state.num_frames)
700
  processed = 0
701
 
702
+ yield state, f"Propagating masks: {processed}/{total}", gr.update()
703
 
704
  last_frame_idx = 0
705
 
706
  with torch.no_grad():
707
+ if state.active_tab == "text":
708
+ if state.inference_session is None:
709
+ yield state, "Text video model not loaded.", gr.update()
710
  return
711
 
712
+ model = TEXT_VIDEO_MODEL
713
+ processor = TEXT_VIDEO_PROCESSOR
714
+
715
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
716
 
717
  # Collect all unique prompts from existing frame annotations
718
  text_prompt_to_obj_ids = {}
719
+ for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items():
720
  for obj_id, text_prompt in frame_texts.items():
721
  if text_prompt not in text_prompt_to_obj_ids:
722
  text_prompt_to_obj_ids[text_prompt] = []
 
724
  text_prompt_to_obj_ids[text_prompt].append(obj_id)
725
 
726
  # Also check if there are prompts already in the inference session
727
+ if hasattr(state.inference_session, "prompts") and state.inference_session.prompts:
728
+ for prompt_text in state.inference_session.prompts.values():
729
  if prompt_text not in text_prompt_to_obj_ids:
730
  text_prompt_to_obj_ids[prompt_text] = []
731
 
 
733
  text_prompt_to_obj_ids[text_prompt].sort()
734
 
735
  if not text_prompt_to_obj_ids:
736
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
737
+ yield state, "No text prompts found. Please add a text prompt first.", gr.update()
738
  return
739
 
740
  # Add all prompts to the inference session (processor handles deduplication)
741
+ for text_prompt in text_prompt_to_obj_ids:
742
+ state.inference_session = processor.add_text_prompt(
743
+ inference_session=state.inference_session,
744
  text=text_prompt,
745
  )
746
 
747
+ earliest_frame = min(state.text_prompts_by_frame_obj.keys()) if state.text_prompts_by_frame_obj else 0
 
 
748
 
749
+ frames_to_track = state.num_frames - earliest_frame
750
 
751
  outputs_per_frame = {}
752
 
753
  for model_outputs in model.propagate_in_video_iterator(
754
+ inference_session=state.inference_session,
755
  start_frame_idx=earliest_frame,
756
  max_frame_num_to_track=frames_to_track,
757
  ):
758
  processed_outputs = processor.postprocess_outputs(
759
+ state.inference_session,
760
  model_outputs,
761
  )
762
  frame_idx = model_outputs.frame_idx
 
767
  scores = processed_outputs["scores"]
768
  prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
769
 
770
+ masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
771
+ frame_texts = state.text_prompts_by_frame_obj.setdefault(frame_idx, {})
772
 
773
  num_objects = len(object_ids)
774
  if num_objects > 0:
 
795
  # Store prompt and assign color
796
  if found_prompt:
797
  frame_texts[current_obj_id] = found_prompt.strip()
798
+ _ensure_color_for_obj(state, current_obj_id)
799
 
800
+ state.composited_frames.pop(frame_idx, None)
801
  last_frame_idx = frame_idx
802
  processed += 1
803
  if processed % 30 == 0 or processed == total:
804
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
805
+ yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
806
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
807
  else:
808
+ if state.inference_session is None:
809
+ yield state, "Tracker model not loaded.", gr.update()
810
  return
811
 
812
+ model = TRACKER_MODEL
813
+ processor = TRACKER_PROCESSOR
814
 
815
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
816
+
817
+ for sam2_video_output in model.propagate_in_video_iterator(inference_session=state.inference_session):
818
  video_res_masks = processor.post_process_masks(
819
  [sam2_video_output.pred_masks],
820
+ original_sizes=[[state.inference_session.video_height, state.inference_session.video_width]],
 
 
821
  )[0]
822
 
823
  frame_idx = sam2_video_output.frame_idx
824
+ for i, out_obj_id in enumerate(state.inference_session.obj_ids):
825
+ _ensure_color_for_obj(state, int(out_obj_id))
826
  mask_2d = video_res_masks[i].cpu().numpy()
827
+ masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
828
  masks_for_frame[int(out_obj_id)] = mask_2d
829
+ state.composited_frames.pop(frame_idx, None)
830
 
831
  last_frame_idx = frame_idx
832
  processed += 1
833
  if processed % 30 == 0 or processed == total:
834
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
835
+ yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
836
+ state.inference_session = to_device_recursive(state.inference_session, DEVICE)
837
 
838
  text = f"Propagated masks across {processed} frames."
839
+ state.inference_session = to_device_recursive(state.inference_session, "cpu")
840
+ yield state, text, gr.update(value=last_frame_idx)
841
 
842
 
843
+ def reset_prompts(state: AppState) -> tuple[AppState, Image.Image, str, str]:
844
  """Reset prompts and all outputs, but keep processed frames and cached vision features."""
845
+ if state is None or state.inference_session is None:
846
+ active_prompts = _get_active_prompts_display(state)
847
+ return state, None, "No active session to reset.", active_prompts
848
 
849
+ if state.active_tab != "text":
850
+ active_prompts = _get_active_prompts_display(state)
851
+ return state, None, "Reset prompts is only available for text prompting mode.", active_prompts
852
 
853
  # Reset inference session tracking data but keep cache and processed frames
854
+ if hasattr(state.inference_session, "reset_tracking_data"):
855
+ state.inference_session.reset_tracking_data()
856
 
857
  # Manually clear prompts (reset_tracking_data doesn't clear prompts themselves)
858
+ if hasattr(state.inference_session, "prompts"):
859
+ state.inference_session.prompts.clear()
860
+ if hasattr(state.inference_session, "prompt_input_ids"):
861
+ state.inference_session.prompt_input_ids.clear()
862
+ if hasattr(state.inference_session, "prompt_embeddings"):
863
+ state.inference_session.prompt_embeddings.clear()
864
+ if hasattr(state.inference_session, "prompt_attention_masks"):
865
+ state.inference_session.prompt_attention_masks.clear()
866
+ if hasattr(state.inference_session, "obj_id_to_prompt_id"):
867
+ state.inference_session.obj_id_to_prompt_id.clear()
868
 
869
  # Reset detection-tracking fusion state
870
+ if hasattr(state.inference_session, "obj_id_to_score"):
871
+ state.inference_session.obj_id_to_score.clear()
872
+ if hasattr(state.inference_session, "obj_id_to_tracker_score_frame_wise"):
873
+ state.inference_session.obj_id_to_tracker_score_frame_wise.clear()
874
+ if hasattr(state.inference_session, "obj_id_to_last_occluded"):
875
+ state.inference_session.obj_id_to_last_occluded.clear()
876
+ if hasattr(state.inference_session, "max_obj_id"):
877
+ state.inference_session.max_obj_id = -1
878
+ if hasattr(state.inference_session, "obj_first_frame_idx"):
879
+ state.inference_session.obj_first_frame_idx.clear()
880
+ if hasattr(state.inference_session, "unmatched_frame_inds"):
881
+ state.inference_session.unmatched_frame_inds.clear()
882
+ if hasattr(state.inference_session, "overlap_pair_to_frame_inds"):
883
+ state.inference_session.overlap_pair_to_frame_inds.clear()
884
+ if hasattr(state.inference_session, "trk_keep_alive"):
885
+ state.inference_session.trk_keep_alive.clear()
886
+ if hasattr(state.inference_session, "removed_obj_ids"):
887
+ state.inference_session.removed_obj_ids.clear()
888
+ if hasattr(state.inference_session, "suppressed_obj_ids"):
889
+ state.inference_session.suppressed_obj_ids.clear()
890
+ if hasattr(state.inference_session, "hotstart_removed_obj_ids"):
891
+ state.inference_session.hotstart_removed_obj_ids.clear()
892
 
893
  # Clear all app state outputs
894
+ state.masks_by_frame.clear()
895
+ state.text_prompts_by_frame_obj.clear()
896
+ state.composited_frames.clear()
897
+ state.color_by_obj.clear()
898
+ state.color_by_prompt.clear()
899
 
900
  # Update display
901
+ current_idx = int(getattr(state, "current_frame_idx", 0))
902
+ current_idx = max(0, min(current_idx, state.num_frames - 1))
903
+ preview_img = update_frame_display(state, current_idx)
904
+ active_prompts = _get_active_prompts_display(state)
905
  status = "Prompts and outputs reset. Processed frames and cached vision features preserved."
906
 
907
+ return state, preview_img, status, active_prompts
908
 
909
 
910
+ def reset_session(state: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
911
+ if not state.video_frames:
912
+ return state, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
913
 
914
+ if state.active_tab == "text":
915
+ if state.video_frames:
916
+ processor = TEXT_VIDEO_PROCESSOR
917
+ state.inference_session = processor.init_video_session(
918
+ video=state.video_frames,
919
+ inference_device=DEVICE,
920
  processing_device="cpu",
921
  video_storage_device="cpu",
922
+ dtype=DTYPE,
 
923
  )
924
+ elif state.inference_session is not None and hasattr(state.inference_session, "reset_inference_session"):
925
+ state.inference_session.reset_inference_session()
926
+ elif state.video_frames:
927
+ processor = TRACKER_PROCESSOR
928
+ raw_video = [np.array(frame) for frame in state.video_frames]
929
+ state.inference_session = processor.init_video_session(
930
+ video=raw_video,
931
+ inference_device=DEVICE,
932
+ video_storage_device="cpu",
933
+ processing_device="cpu",
934
+ dtype=DTYPE,
935
+ )
936
 
937
+ state.masks_by_frame.clear()
938
+ state.clicks_by_frame_obj.clear()
939
+ state.boxes_by_frame_obj.clear()
940
+ state.text_prompts_by_frame_obj.clear()
941
+ state.composited_frames.clear()
942
+ state.color_by_obj.clear()
943
+ state.color_by_prompt.clear()
944
+ state.pending_box_start = None
945
+ state.pending_box_start_frame_idx = None
946
+ state.pending_box_start_obj_id = None
947
 
948
  gc.collect()
949
 
950
+ current_idx = int(getattr(state, "current_frame_idx", 0))
951
+ current_idx = max(0, min(current_idx, state.num_frames - 1))
952
+ preview_img = update_frame_display(state, current_idx)
953
+ slider_minmax = gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True)
954
  slider_value = gr.update(value=current_idx)
955
  status = "Session reset. Prompts cleared; video preserved."
956
+ active_prompts = _get_active_prompts_display(state)
957
+ return state, preview_img, slider_minmax, slider_value, status, active_prompts
958
 
959
 
960
+ def _on_video_change_pointbox(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str]:
961
+ state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "point_box")
962
  return (
963
+ state,
964
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
965
  first_frame,
966
  status,
967
  )
968
 
969
 
970
+ def _on_video_change_text(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
971
+ if video is None:
972
+ return state, None, None, None, None
973
+ state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "text")
974
+ active_prompts = _get_active_prompts_display(state)
975
  return (
976
+ state,
977
  gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
978
  first_frame,
979
  status,
 
981
  )
982
 
983
 
984
+ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
985
+ app_state = gr.State(AppState())
 
 
986
 
987
  gr.Markdown(
988
  """
 
1014
 
1015
  with gr.Row():
1016
  with gr.Column(scale=1):
1017
+ video_in_text = gr.Video(label="Upload video", sources=["upload", "webcam"])
1018
  load_status_text = gr.Markdown(visible=True)
1019
  reset_btn_text = gr.Button("Reset Session", variant="secondary")
1020
  with gr.Column(scale=2):
1021
+ preview_text = gr.Image(label="Preview")
1022
  with gr.Row():
1023
+ frame_slider_text = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
 
 
1024
  with gr.Column(scale=0):
1025
  propagate_btn_text = gr.Button("Propagate across video", variant="primary")
1026
  propagate_status_text = gr.Markdown(visible=True)
 
1047
  ]
1048
  with gr.Row():
1049
  gr.Examples(
 
 
 
 
1050
  label="Examples",
1051
+ examples=examples_list_text,
1052
+ inputs=[app_state, video_in_text],
1053
  examples_per_page=5,
1054
  )
1055
 
 
1075
 
1076
  with gr.Row():
1077
  with gr.Column(scale=1):
1078
+ video_in_pointbox = gr.Video(label="Upload video", sources=["upload", "webcam"], max_length=7)
 
 
1079
  load_status_pointbox = gr.Markdown(visible=True)
1080
  reset_btn_pointbox = gr.Button("Reset Session", variant="secondary")
1081
  with gr.Column(scale=2):
1082
+ preview_pointbox = gr.Image(label="Preview")
1083
  with gr.Row():
1084
+ frame_slider_pointbox = gr.Slider(label="Frame", minimum=0, maximum=0, step=1, value=0)
 
 
1085
  with gr.Column(scale=0):
1086
  propagate_btn_pointbox = gr.Button("Propagate across video", variant="primary")
1087
  propagate_status_pointbox = gr.Markdown(visible=True)
 
1103
  ]
1104
  with gr.Row():
1105
  gr.Examples(
 
 
 
 
1106
  label="Examples",
1107
+ examples=examples_list_pointbox,
1108
+ inputs=[app_state, video_in_pointbox],
1109
  examples_per_page=5,
1110
  )
1111
 
1112
  video_in_pointbox.change(
1113
+ fn=_on_video_change_pointbox,
1114
+ inputs=[app_state, video_in_pointbox],
1115
+ outputs=[app_state, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
1116
  show_progress=True,
1117
  )
1118
 
1119
+ def _sync_frame_idx_pointbox(state_in: AppState, idx: int) -> Image.Image:
1120
  if state_in is not None:
1121
  state_in.current_frame_idx = int(idx)
1122
  return update_frame_display(state_in, int(idx))
1123
 
1124
  frame_slider_pointbox.change(
1125
+ fn=_sync_frame_idx_pointbox,
1126
+ inputs=[app_state, frame_slider_pointbox],
1127
  outputs=preview_pointbox,
1128
  )
1129
 
1130
  video_in_text.change(
1131
+ fn=_on_video_change_text,
1132
+ inputs=[app_state, video_in_text],
1133
+ outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display],
1134
  show_progress=True,
1135
  )
1136
 
1137
+ def _sync_frame_idx_text(state_in: AppState, idx: int) -> Image.Image:
1138
  if state_in is not None:
1139
  state_in.current_frame_idx = int(idx)
1140
  return update_frame_display(state_in, int(idx))
1141
 
1142
  frame_slider_text.change(
1143
+ fn=_sync_frame_idx_text,
1144
+ inputs=[app_state, frame_slider_text],
1145
  outputs=preview_text,
1146
  )
1147
 
1148
+ def _sync_obj_id(s: AppState, oid: int) -> None:
1149
  if s is not None and oid is not None:
1150
  s.current_obj_id = int(oid)
 
1151
 
1152
+ obj_id_inp.change(
1153
+ fn=_sync_obj_id,
1154
+ inputs=[app_state, obj_id_inp],
1155
+ )
1156
 
1157
+ def _sync_label(s: AppState, lab: str) -> None:
1158
  if s is not None and lab is not None:
1159
  s.current_label = str(lab)
 
1160
 
1161
+ label_radio.change(
1162
+ fn=_sync_label,
1163
+ inputs=[app_state, label_radio],
1164
+ )
1165
 
1166
+ def _sync_prompt_type(s: AppState, val: str) -> tuple[dict, dict]:
1167
  if s is not None and val is not None:
1168
  s.current_prompt_type = str(val)
1169
  s.pending_box_start = None
1170
  is_points = str(val).lower() == "points"
1171
+ return (
1172
  gr.update(visible=is_points),
1173
  gr.update(interactive=is_points) if is_points else gr.update(value=True, interactive=False),
1174
+ )
 
1175
 
1176
  prompt_type.change(
1177
+ fn=_sync_prompt_type,
1178
+ inputs=[app_state, prompt_type],
1179
  outputs=[label_radio, clear_old_chk],
1180
  )
1181
 
1182
  preview_pointbox.select(
1183
+ fn=on_image_click,
1184
+ inputs=[preview_pointbox, app_state, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
1185
+ outputs=[preview_pointbox, app_state],
1186
  )
1187
 
 
 
 
 
1188
  text_apply_btn.click(
1189
+ fn=on_text_prompt,
1190
+ inputs=[app_state, frame_slider_text, text_prompt_input],
1191
+ outputs=[preview_text, text_status, active_prompts_display, app_state],
1192
  )
1193
 
1194
  reset_prompts_btn.click(
1195
+ fn=reset_prompts,
1196
+ inputs=app_state,
1197
+ outputs=[app_state, preview_text, text_status, active_prompts_display],
1198
  )
1199
 
1200
+ def _render_video(s: AppState) -> str:
1201
  if s is None or s.num_frames == 0:
1202
  raise gr.Error("Load a video first.")
1203
  fps = s.video_fps if s.video_fps and s.video_fps > 0 else 12
 
1211
  frames_np.append(np.array(img)[:, :, ::-1])
1212
  if (idx + 1) % 60 == 0:
1213
  gc.collect()
 
1214
  try:
1215
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_path:
1216
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
1217
+ writer = cv2.VideoWriter(out_path.name, fourcc, fps, (w, h))
1218
+ for fr_bgr in frames_np:
1219
+ writer.write(fr_bgr)
1220
+ writer.release()
1221
+ return out_path.name
1222
  except Exception as e:
1223
  print(f"Failed to render video with cv2: {e}")
1224
  raise gr.Error(f"Failed to render video: {e}")
1225
 
1226
+ render_btn_pointbox.click(
1227
+ fn=_render_video,
1228
+ inputs=app_state,
1229
+ outputs=playback_video_pointbox,
1230
+ )
1231
+ render_btn_text.click(
1232
+ fn=_render_video,
1233
+ inputs=app_state,
1234
+ outputs=playback_video_text,
1235
+ )
1236
 
1237
  propagate_btn_pointbox.click(
1238
+ fn=propagate_masks,
1239
+ inputs=app_state,
1240
+ outputs=[app_state, propagate_status_pointbox, frame_slider_pointbox],
1241
  )
1242
 
1243
  propagate_btn_text.click(
1244
+ fn=propagate_masks,
1245
+ inputs=app_state,
1246
+ outputs=[app_state, propagate_status_text, frame_slider_text],
1247
  )
1248
 
1249
  reset_btn_pointbox.click(
1250
+ fn=reset_session,
1251
+ inputs=app_state,
1252
+ outputs=[app_state, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
1253
  )
1254
 
1255
  reset_btn_text.click(
1256
+ fn=reset_session,
1257
+ inputs=app_state,
1258
  outputs=[
1259
+ app_state,
1260
  preview_text,
1261
  frame_slider_text,
1262
  frame_slider_text,
pyproject.toml ADDED
@@ -0,0 +1,60 @@
1
+ [project]
2
+ name = "sam3-video-segmentation"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "accelerate>=1.11.0",
9
+ "gradio>=5.49.1",
10
+ "imageio[pyav]>=2.37.2",
11
+ "kernels>=0.11.0",
12
+ "opencv-python>=4.12.0.88",
13
+ "spaces>=0.42.1",
14
+ "torch==2.8.0",
15
+ "torchvision>=0.23.0",
16
+ "transformers",
17
+ ]
18
+
19
+ [tool.ruff]
20
+ line-length = 119
21
+
22
+ [tool.ruff.lint]
23
+ select = ["ALL"]
24
+ ignore = [
25
+ "COM812", # missing-trailing-comma
26
+ "D203", # one-blank-line-before-class
27
+ "D213", # multi-line-summary-second-line
28
+ "E501", # line-too-long
29
+ "SIM117", # multiple-with-statements
30
+ #
31
+ "D100", # undocumented-public-module
32
+ "D101", # undocumented-public-class
33
+ "D102", # undocumented-public-method
34
+ "D103", # undocumented-public-function
35
+ "D104", # undocumented-public-package
36
+ "D105", # undocumented-magic-method
37
+ "D107", # undocumented-public-init
38
+ "EM101", # raw-string-in-exception
39
+ "FBT001", # boolean-type-hint-positional-argument
40
+ "FBT002", # boolean-default-value-positional-argument
41
+ "PGH003", # blanket-type-ignore
42
+ "PLR0913", # too-many-arguments
43
+ "PLR0915", # too-many-statements
44
+ "TRY003", # raise-vanilla-args
45
+ ]
46
+ unfixable = [
47
+ "F401", # unused-import
48
+ ]
49
+
50
+ [tool.ruff.lint.pydocstyle]
51
+ convention = "google"
52
+
53
+ [tool.ruff.lint.per-file-ignores]
54
+ "*.ipynb" = ["T201", "T203"]
55
+
56
+ [tool.ruff.format]
57
+ docstring-code-format = true
58
+
59
+ [tool.uv.sources]
60
+ transformers = { git = "https://github.com/huggingface/transformers.git", rev = "69f003696b" }
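The transformers dependency is pinned to a specific git commit. A quick sanity check, under the assumption that the SAM3 video classes only ship in this development build rather than a tagged release, is to confirm that the imports used by app.py resolve:

```python
# Sanity check: these are exactly the classes app.py imports. If this raises
# ImportError, the pinned transformers commit (or a newer build with SAM3) is
# not installed in the environment.
from transformers import (  # noqa: F401
    Sam3TrackerVideoModel,
    Sam3TrackerVideoProcessor,
    Sam3VideoModel,
    Sam3VideoProcessor,
)

print("SAM3 video classes are importable.")
```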
uv.lock ADDED
The diff for this file is too large to render. See raw diff