Spaces: Running on ZeroGPU (#4)

- Clean up (207380d371e725816ae5c3f53d531910ba67504d)
- Clean up (2f7dd3f5db9e9b315769dff39950a17399dbbaf5)
- ruff (6228bae0defd157d867ee7c11dbc471d1302990e)
- gradio==5.47.2 (2ca03b40be4b7e3e269f5fe923006ed6bd5bd037)
- Clean up (05517058ea0a3f027c86838ca74b4a813d5b36af)
- Clean up (b3503590167f6ef34dd1c47eb02e9a1f4109cc44)
- Clean up (899c724afb8a9537d8153c201537982d958706ff)
- Add missing type annotation (fd135add43ba4849b42d5a6a48a6145f7c4890af)
- Fix type annotation (9827da7e076f2799498620fd469d7e3c9480e9f7)
- Rename (9487a4249ff9574b8eb19f9a555b4da404db399b)
- Clean up (bec66180eb459c8dfca06ed7289d5690100a1915)
- Clean up (57b535c8bbd55e17113e91fb848e12474776fa77)
- Fix (c5ddb5172dcf01b92d44239d9600ece930b6b0b7)
- Add missing type annotation (dd289a80dcbb51991c3a7f262e3078ad7784ad96)
- Use tempfile (bb8d6819946800c97f7f0b34bb84195b582482f4)
- Clean up (5f16955badf395d39d104cd2a02c4ed0e37c9d9e)
- Add missing type annotation (a594ef6a99311692f11a0f39d1cee6b070a37b87)
- Clean up (8224854fa85dbb96e43f4955cbb8ceb49f94bc75)
- Fix (74ad55952b232b45cca92948c7665a60ad2d3eab)
- Clean up (4e68fb424493d6c682e55f08aa5ac524e1962173)
- ZeroGPU (ffce5fe119880498bc912e544612c92893db305b)
Co-authored-by: hysts <[email protected]>
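
The commits above migrate the demo to ZeroGPU: models are loaded once at import time, per-user session state is kept on the CPU between calls, and only the GPU-bound handlers request a GPU through the `spaces` package. A minimal sketch of that pattern (illustrative names, not the app's real handlers; the real ones are in the app.py diff below):

    import spaces
    import torch

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(4, 2).eval()  # stand-in for the SAM3 models loaded at import time

    @spaces.GPU  # a ZeroGPU device is attached only while this function runs
    def run_inference(x: torch.Tensor) -> torch.Tensor:
        # move the work to the GPU for the duration of the call and hand back CPU tensors,
        # so nothing stored between calls keeps CUDA memory alive
        with torch.no_grad():
            return model.to(DEVICE)(x.to(DEVICE)).cpu()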
@@ -4,7 +4,7 @@ emoji: 🐠
 colorFrom: yellow
 colorTo: gray
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.47.2
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py: the removed (pre-ZeroGPU) side of this diff is truncated in the capture, so the hunks below reproduce only the added and context lines. The removed code relied on module-level globals (_GLOBAL_DEVICE, _GLOBAL_DTYPE, _GLOBAL_TRACKER_PROCESSOR), typing.Optional annotations, a GLOBAL_STATE gr.State threaded through every handler, and a hard-coded /tmp/sam3_playback.mp4 output path.
@@ -1,37 +1,106 @@
 import colorsys
 import gc
+import tempfile
+from collections import defaultdict
+from collections.abc import Iterator, Mapping, Sequence
+from typing import Any

 import cv2
 import gradio as gr
 import numpy as np
+import spaces
 import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw, ImageFont
 from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor

+MODEL_ID = "facebook/sam3"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16

+TRACKER_MODEL = Sam3TrackerVideoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE).eval()
+TRACKER_PROCESSOR = Sam3TrackerVideoProcessor.from_pretrained(MODEL_ID)

+TEXT_VIDEO_MODEL = Sam3VideoModel.from_pretrained(MODEL_ID).to(DEVICE, dtype=DTYPE).eval()
+TEXT_VIDEO_PROCESSOR = Sam3VideoProcessor.from_pretrained(MODEL_ID)
+print("Models loaded successfully!")

+MAX_SECONDS = 8.0
+
+
+def to_device_recursive(obj: Any, device: str | torch.device) -> Any:  # noqa: ANN401
+    """Return a new object where all torch.Tensors reachable from `obj` are moved to the given device.
+
+    - Does NOT mutate the original object.
+    - Handles:
+        * torch.Tensor
+        * Mapping (e.g. dict, defaultdict, OrderedDict, etc.)
+        * Sequence (e.g. list, tuple) except str/bytes
+        * Custom classes with attributes (__dict__)
+    - Tries to preserve container types where reasonable.
+    """
+    device = torch.device(device)
+    memo = {}
+
+    def _convert(x: Any) -> Any:  # noqa: ANN401, C901
+        obj_id = id(x)
+        if obj_id in memo:
+            return memo[obj_id]
+
+        # 1. Tensor
+        if isinstance(x, torch.Tensor):
+            y = x.to(device)
+            memo[obj_id] = y
+            return y
+
+        # 2. Mapping (dict, defaultdict, etc.)
+        if isinstance(x, Mapping):
+            # Special case: defaultdict
+            if isinstance(x, defaultdict):
+                y = defaultdict(x.default_factory)
+                memo[obj_id] = y
+                for k, v in x.items():
+                    y[k] = _convert(v)
+                return y
+
+            # Try to rebuild the same type using (key, value) pairs
+            try:
+                y = type(x)((k, _convert(v)) for k, v in x.items())
+                memo[obj_id] = y
+                return y
+            except TypeError:
+                # Fallback: plain dict
+                y = {k: _convert(v) for k, v in x.items()}
+                memo[obj_id] = y
+                return y
+
+        # 3. Sequence (list/tuple/etc.) but not str/bytes
+        if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
+            if isinstance(x, list):
+                y = [_convert(v) for v in x]
+            elif isinstance(x, tuple):
+                y = type(x)(_convert(v) for v in x)
+            else:
+                try:
+                    y = type(x)(_convert(v) for v in x)
+                except TypeError:
+                    y = [_convert(v) for v in x]
+            memo[obj_id] = y
+            return y

+        # 4. Custom object with attributes (__dict__)
+        if hasattr(x, "__dict__") and not isinstance(x, type):
+            new_obj = x.__class__.__new__(x.__class__)
+            memo[obj_id] = new_obj
+            for name, value in vars(x).items():
+                setattr(new_obj, name, _convert(value))
+            return new_obj

+        # 5. Everything else → keep as-is
+        memo[obj_id] = x
+        return x
+
+    return _convert(obj)


 def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
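
A quick illustration of what to_device_recursive is for (a sketch, not part of the diff): it builds a device-moved copy of an arbitrarily nested session object without mutating the original, so the app can shuttle its inference session between the CPU (while stored in gr.State) and the ZeroGPU device (while a handler runs).

    import torch

    session = {"features": [torch.zeros(2, 3)], "meta": {"fps": 24.0}}
    target = "cuda" if torch.cuda.is_available() else "cpu"
    moved = to_device_recursive(session, target)
    assert session["features"][0].device.type == "cpu"        # original left untouched
    print(moved["features"][0].device, moved["meta"]["fps"])  # tensors moved, plain values kept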
@@ -106,10 +175,10 @@ def pastel_color_for_prompt(prompt_text: str) -> tuple[int, int, int]:


 class AppState:
+    def __init__(self) -> None:
         self.reset()

+    def reset(self) -> None:
         self.video_frames: list[Image.Image] = []
         self.inference_session = None
         self.video_fps: float | None = None
@@ -130,7 +199,7 @@ class AppState:
         self.pending_box_start_obj_id: int | None = None
         self.active_tab: str = "point_box"

+    def __repr__(self) -> str:
         return f"AppState(video_frames={len(self.video_frames)}, video_fps={self.video_fps}, masks_by_frame={len(self.masks_by_frame)}, color_by_obj={len(self.color_by_obj)})"

     @property
@@ -139,23 +208,20 @@ class AppState:


 def init_video_session(
+    state: AppState, video: str | dict, active_tab: str = "point_box"
 ) -> tuple[AppState, int, int, Image.Image, str]:
+    state.video_frames = []
+    state.masks_by_frame = {}
+    state.color_by_obj = {}
+    state.color_by_prompt = {}
+    state.text_prompts_by_frame_obj = {}
+    state.clicks_by_frame_obj = {}
+    state.boxes_by_frame_obj = {}
+    state.composited_frames = {}
+    state.inference_session = None
+    state.active_tab = active_tab
+
+    video_path: str | None = None
     if isinstance(video, dict):
         video_path = video.get("name") or video.get("path") or video.get("data")
     elif isinstance(video, str):
@@ -170,7 +236,6 @@ def init_video_session(
     if len(frames) == 0:
         raise gr.Error("No frames could be loaded from the video.")

     trimmed_note = ""
     fps_in = info.get("fps")
     max_frames_allowed = int(MAX_SECONDS * fps_in) if fps_in else len(frames)
@@ -179,44 +244,49 @@
         trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
     if isinstance(info, dict):
         info["num_frames"] = len(frames)
+    state.video_frames = frames
+    state.video_fps = float(fps_in) if fps_in else None

     raw_video = [np.array(frame) for frame in frames]

     if active_tab == "text":
+        processor = TEXT_VIDEO_PROCESSOR
+        state.inference_session = processor.init_video_session(
             video=frames,
+            inference_device="cpu",
+            inference_state_device="cpu",
             processing_device="cpu",
             video_storage_device="cpu",
+            dtype=DTYPE,
         )
     else:
+        processor = TRACKER_PROCESSOR
+        state.inference_session = processor.init_video_session(
             video=raw_video,
+            inference_device="cpu",
+            inference_state_device="cpu",
             processing_device="cpu",
+            video_storage_device="cpu",
+            dtype=DTYPE,
         )

+    state.inference_session.inference_device = DEVICE
+    state.inference_session.processing_device = DEVICE
+    state.inference_session.cache.inference_device = DEVICE
+
     first_frame = frames[0]
     max_idx = len(frames) - 1
     if active_tab == "text":
         status = (
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
         )
     else:
         status = (
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
         )
+    return state, 0, max_idx, first_frame, status


 def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
Remaining app.py hunks (added lines summarized; the removed counterparts are truncated in this capture):

- @@ -288,7 +358,7 @@ compose_frame: the bare `except` around ImageFont.truetype(...) becomes `except OSError:`.
- @@ -340,7 +410,7 @@ and @@ -348,19 +418,18 @@ _get_prompt_for_obj: the return annotation becomes `str | None`, and the inference-session lookup is guarded with hasattr()/membership checks on `obj_id_to_prompt_id` and `prompts` before `state.inference_session.prompts[prompt_id].strip()` is returned.
- @@ -375,6 +444,7 @@ _ensure_color_for_obj gains a `-> None` annotation.
- @@ -383,12 +453,13 @@, @@ -417,29 +488,28 @@, @@ -485,23 +555,26 @@ on_image_click: decorated with `@spaces.GPU`, now returns `tuple[Image.Image, AppState]`, uses the module-level TRACKER_MODEL / TRACKER_PROCESSOR, and wraps the work in `state.inference_session = to_device_recursive(state.inference_session, DEVICE)` on entry and `to_device_recursive(state.inference_session, "cpu")` before returning. The box branch completes the pending box, builds `box = [[[x_min, y_min, x_max, y_max]]]`, calls `processor.add_inputs_to_inference_session(inference_session=state.inference_session, frame_idx=ann_frame_idx, obj_ids=ann_obj_id, input_boxes=box)`, and records the box in `state.boxes_by_frame_obj`.
- @@ -509,7 +582,9 @@, @@ -593,7 +668,10 @@ on_text_prompt: decorated with `@spaces.GPU`, returns `tuple[Image.Image, str, str, AppState]`, uses TEXT_VIDEO_MODEL / TEXT_VIDEO_PROCESSOR, and moves the session to DEVICE before `processor.add_text_prompt(...)` and back to "cpu" before returning.
- @@ -610,32 +688,35 @@ through @@ -715,183 +795,185 @@ propagate_masks: the `GLOBAL_STATE: gr.State` argument becomes `state: AppState`, the function is decorated with `@spaces.GPU` and turned into a generator yielding `(state, status, gr.update(...))`. Both the text branch (TEXT_VIDEO_MODEL/PROCESSOR, `add_text_prompt` for every stored prompt, then `model.propagate_in_video_iterator(inference_session=..., start_frame_idx=earliest_frame, max_frame_num_to_track=frames_to_track)` with `processor.postprocess_outputs(...)`) and the point/box branch (TRACKER_MODEL/PROCESSOR with `processor.post_process_masks(...)`) keep the session on DEVICE while tracking and, every 30 processed frames, move it back to "cpu", yield a progress update, and move it to DEVICE again; the final yield reports "Propagated masks across N frames." and leaves the session on "cpu" (see the sketch after this list). The same hunk rewrites reset_prompts and reset_session to take and return the AppState: reset_prompts clears the session's prompt and tracking attributes (prompts, prompt_input_ids, prompt_embeddings, prompt_attention_masks, obj_id_to_prompt_id, obj_id_to_score, obj_first_frame_idx, removed/suppressed object-id sets, and so on) plus the app-side masks and colors, while reset_session re-initializes the processor session from the preserved frames and clears clicks, boxes, prompts, masks, and colors before recomputing the preview and slider ranges.
- @@ -899,10 +981,8 @@ _on_video_change_pointbox / _on_video_change_text: typed wrappers around init_video_session that return the updated state plus slider, preview, and status updates (the text variant short-circuits when no video is given).
- @@ -934,15 +1014,13 @@, @@ -969,12 +1047,9 @@, @@ -1000,17 +1075,13 @@, @@ -1032,105 +1103,101 @@ Gradio UI: `with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate"))`, a single `app_state = gr.State(AppState())`, simplified gr.Video / gr.Image / gr.Slider constructors, gr.Examples fed `inputs=[app_state, video_in_*]`, and every .change/.select/.click wiring rewritten with explicit `fn=`, `inputs=`, and `outputs=` lists that pass app_state through (the sync helpers gain type annotations).
- @@ -1144,44 +1211,52 @@ _render_video: writes the playback MP4 with cv2.VideoWriter into `tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)` instead of a fixed /tmp/sam3_playback.mp4, and the render, propagate, and reset buttons of both tabs are wired to _render_video, propagate_masks, and reset_session with app_state as input and the matching preview, slider, and status components as outputs.
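
A minimal sketch of the progress-streaming pattern propagate_masks uses under @spaces.GPU (illustrative, simplified from the real handler; `state.session` and `state.frames` stand in for the app's AppState fields, and to_device_recursive is the helper defined above):

    import spaces
    import torch

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    @spaces.GPU
    def propagate(state):
        # keep heavy tensors on the GPU while working, but hand Gradio a CPU-only
        # snapshot every time we yield, since the yielded state outlives the GPU lease
        state.session = to_device_recursive(state.session, DEVICE)
        for i, frame in enumerate(state.frames):
            ...  # run one tracking step on `frame`
            if (i + 1) % 30 == 0:
                state.session = to_device_recursive(state.session, "cpu")
                yield state, f"{i + 1}/{len(state.frames)} frames"
                state.session = to_device_recursive(state.session, DEVICE)
        state.session = to_device_recursive(state.session, "cpu")
        yield state, "done"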
pyproject.toml (new file)
@@ -0,0 +1,60 @@
+[project]
+name = "sam3-video-segmentation"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "accelerate>=1.11.0",
+    "gradio>=5.49.1",
+    "imageio[pyav]>=2.37.2",
+    "kernels>=0.11.0",
+    "opencv-python>=4.12.0.88",
+    "spaces>=0.42.1",
+    "torch==2.8.0",
+    "torchvision>=0.23.0",
+    "transformers",
+]
+
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+select = ["ALL"]
+ignore = [
+    "COM812",  # missing-trailing-comma
+    "D203",  # one-blank-line-before-class
+    "D213",  # multi-line-summary-second-line
+    "E501",  # line-too-long
+    "SIM117",  # multiple-with-statements
+    #
+    "D100",  # undocumented-public-module
+    "D101",  # undocumented-public-class
+    "D102",  # undocumented-public-method
+    "D103",  # undocumented-public-function
+    "D104",  # undocumented-public-package
+    "D105",  # undocumented-magic-method
+    "D107",  # undocumented-public-init
+    "EM101",  # raw-string-in-exception
+    "FBT001",  # boolean-type-hint-positional-argument
+    "FBT002",  # boolean-default-value-positional-argument
+    "PGH003",  # blanket-type-ignore
+    "PLR0913",  # too-many-arguments
+    "PLR0915",  # too-many-statements
+    "TRY003",  # raise-vanilla-args
+]
+unfixable = [
+    "F401",  # unused-import
+]
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
+
+[tool.ruff.lint.per-file-ignores]
+"*.ipynb" = ["T201", "T203"]
+
+[tool.ruff.format]
+docstring-code-format = true
+
+[tool.uv.sources]
+transformers = { git = "https://github.com/huggingface/transformers.git", rev = "69f003696b" }
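
The dependency list pins transformers to a specific git revision, presumably because the SAM3 video classes that app.py imports are newer than the latest release. A quick runtime check (a sketch, assuming the pinned build is installed):

    import transformers
    from transformers import Sam3TrackerVideoModel, Sam3VideoModel  # ImportError on an older release

    print(transformers.__version__, Sam3TrackerVideoModel.__name__, Sam3VideoModel.__name__)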
The diff for this file is too large to render. See raw diff.