gbreadman13code committed on
Commit 4f2b4bb · 1 Parent(s): 4518f25

Deploy SAM2 segmentation API

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +49 -0
  2. README.md +167 -4
  3. app.py +1109 -0
  4. download_model.py +80 -0
  5. requirements.txt +10 -0
  6. sam2_repo/README.md +224 -0
  7. sam2_repo/checkpoints/download_ckpts.sh +59 -0
  8. sam2_repo/pyproject.toml +6 -0
  9. sam2_repo/sam2/__init__.py +11 -0
  10. sam2_repo/sam2/__pycache__/__init__.cpython-313.pyc +0 -0
  11. sam2_repo/sam2/__pycache__/build_sam.cpython-313.pyc +0 -0
  12. sam2_repo/sam2/__pycache__/sam2_image_predictor.cpython-313.pyc +0 -0
  13. sam2_repo/sam2/automatic_mask_generator.py +454 -0
  14. sam2_repo/sam2/benchmark.py +92 -0
  15. sam2_repo/sam2/build_sam.py +174 -0
  16. sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  17. sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  18. sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  19. sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  20. sam2_repo/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  21. sam2_repo/sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  22. sam2_repo/sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  23. sam2_repo/sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  24. sam2_repo/sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  25. sam2_repo/sam2/csrc/connected_components.cu +289 -0
  26. sam2_repo/sam2/modeling/__init__.py +5 -0
  27. sam2_repo/sam2/modeling/__pycache__/__init__.cpython-313.pyc +0 -0
  28. sam2_repo/sam2/modeling/__pycache__/memory_attention.cpython-313.pyc +0 -0
  29. sam2_repo/sam2/modeling/__pycache__/memory_encoder.cpython-313.pyc +0 -0
  30. sam2_repo/sam2/modeling/__pycache__/position_encoding.cpython-313.pyc +0 -0
  31. sam2_repo/sam2/modeling/__pycache__/sam2_base.cpython-313.pyc +0 -0
  32. sam2_repo/sam2/modeling/__pycache__/sam2_utils.cpython-313.pyc +0 -0
  33. sam2_repo/sam2/modeling/backbones/__init__.py +5 -0
  34. sam2_repo/sam2/modeling/backbones/__pycache__/__init__.cpython-313.pyc +0 -0
  35. sam2_repo/sam2/modeling/backbones/__pycache__/hieradet.cpython-313.pyc +0 -0
  36. sam2_repo/sam2/modeling/backbones/__pycache__/image_encoder.cpython-313.pyc +0 -0
  37. sam2_repo/sam2/modeling/backbones/__pycache__/utils.cpython-313.pyc +0 -0
  38. sam2_repo/sam2/modeling/backbones/hieradet.py +317 -0
  39. sam2_repo/sam2/modeling/backbones/image_encoder.py +134 -0
  40. sam2_repo/sam2/modeling/backbones/utils.py +93 -0
  41. sam2_repo/sam2/modeling/memory_attention.py +169 -0
  42. sam2_repo/sam2/modeling/memory_encoder.py +181 -0
  43. sam2_repo/sam2/modeling/position_encoding.py +239 -0
  44. sam2_repo/sam2/modeling/sam/__init__.py +5 -0
  45. sam2_repo/sam2/modeling/sam/__pycache__/__init__.cpython-313.pyc +0 -0
  46. sam2_repo/sam2/modeling/sam/__pycache__/mask_decoder.cpython-313.pyc +0 -0
  47. sam2_repo/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-313.pyc +0 -0
  48. sam2_repo/sam2/modeling/sam/__pycache__/transformer.cpython-313.pyc +0 -0
  49. sam2_repo/sam2/modeling/sam/mask_decoder.py +295 -0
  50. sam2_repo/sam2/modeling/sam/prompt_encoder.py +202 -0
Dockerfile ADDED
@@ -0,0 +1,49 @@
+ # Dockerfile for Hugging Face Spaces
+ # Optimized for CPU
+
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # System dependencies for OpenCV and SAM2
+ RUN apt-get update && apt-get install -y \
+     git \
+     wget \
+     build-essential \
+     libglib2.0-0 \
+     libsm6 \
+     libxext6 \
+     libxrender-dev \
+     libgomp1 \
+     libgl1-mesa-glx \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application code
+ COPY app.py .
+ COPY download_model.py .
+ COPY web_demo.html .
+ COPY web_demo_advanced.html .
+
+ # Copy and install SAM2
+ COPY sam2_repo sam2_repo
+ RUN cd sam2_repo && pip install --no-cache-dir -e .
+
+ # Create the folder for model checkpoints
+ RUN mkdir -p checkpoints
+
+ # Download the tiny model (the lightest option for CPU)
+ RUN python download_model.py tiny
+
+ # Hugging Face Spaces uses port 7860
+ ENV PORT=7860
+ EXPOSE 7860
+
+ # Launch, binding to the expected host and port
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT}"]
+
README.md CHANGED
@@ -1,11 +1,174 @@
  ---
- title: Sam2 Api
- emoji: 📈
  colorFrom: purple
- colorTo: pink
  sdk: docker
  pinned: false
  license: apache-2.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
  ---
+ title: SAM2 Segmentation API
+ emoji: 🎯
  colorFrom: purple
+ colorTo: blue
  sdk: docker
+ app_port: 7860
  pinned: false
  license: apache-2.0
  ---
 
+ # 🎯 SAM2 Segmentation API
+
+ A powerful REST API for segmenting objects in images with Meta's SAM2 (Segment Anything Model 2).
+
+ ## ✨ Features
+
+ - **🎯 Box Prompts** - rectangular selection
+ - **🖌️ Brush Prompts** - brush drawing (green = object, red = background, white = object)
+ - **📍 Point Prompts** - clicks on objects
+ - **🔥 Batch API** - process multiple objects in a single request
+ - **🖼️ Extract Objects** - automatic object extraction with transparency
+ - **⚡ REST API** - full documentation in Swagger UI
+
+ ## 🚀 Quick Start
+
+ ### Web interface
+
+ Once the Space is running, open:
+
+ - **Simple interface**: `/web` - box prompts
+ - **Advanced**: `/web/advanced` - box + brush prompts
+ - **API documentation**: `/docs` - Swagger UI
+
+ ### API Endpoints
+
+ #### POST `/segment/batch` - Batch API (recommended)
+
+ Processes multiple objects in a single request.
+
+ **Example request:**
+ ```json
+ {
+   "image": "data:image/jpeg;base64,...",
+   "prompts": [
+     {
+       "id": 0,
+       "type": "mask",
+       "data": "data:image/png;base64,...",
+       "label": "person",
+       "selected": true
+     }
+   ],
+   "options": {
+     "extract_objects": true,
+     "include_masks": false,
+     "clean_masks": true
+   }
+ }
+ ```
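For reference, a successful `/segment/batch` response has the shape sketched below. The field names follow the `BatchSegmentResponse` and `SegmentResultModel` classes in `app.py`; all values here are illustrative placeholders.

```python
# Illustrative /segment/batch response (made-up values; fields follow BatchSegmentResponse)
example_response = {
    "success": True,
    "image_size": {"width": 1920, "height": 1080},
    "results": [
        {
            "id": 0,
            "label": "person",
            "bbox": {"x_min": 412, "y_min": 96, "x_max": 980, "y_max": 1055,
                     "width": 568, "height": 959},
            "area": 312450,
            "center": {"x": 701, "y": 588},
            "confidence": 0.97,
            "extracted_image": "data:image/png;base64,...",  # present when extract_objects is true
            # "contours" and "mask_rle" are added when include_masks is true
        }
    ],
}
```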
+
+ #### POST `/segment` - Simple segmentation
+
+ With a box prompt:
+ ```bash
+ curl -X POST "/segment?box_x1=50&box_y1=50&box_x2=300&box_y2=400&extract_objects=true" \
+   -F "file=@image.jpg"
+ ```
+
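The same call from Python, as a minimal sketch: `/segment` takes a multipart `file` upload plus query parameters (the Space URL and file name below are placeholders).

```python
import requests

url = "https://YOUR-SPACE.hf.space/segment"  # placeholder Space URL
params = {
    "box_x1": 50, "box_y1": 50, "box_x2": 300, "box_y2": 400,
    "extract_objects": "true",
}
with open("image.jpg", "rb") as f:
    resp = requests.post(url, params=params, files={"file": f})
resp.raise_for_status()
print(resp.json()["segments_count"])  # number of segments returned
```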
+ ## 📊 Performance
+
+ ⚠️ **CPU version**: runs on the free CPU tier of Hugging Face Spaces. Processing speed: ~5-10 seconds per image.
+
+ For faster processing, upgrading to a GPU is recommended (Settings → Hardware).
+
+ ## 🎨 Mask formats
+
+ The API supports several mask colour conventions (a client-side sketch follows the list):
+
+ - **🟢 Green** (R<100, G>150, B<100) - foreground (object)
+ - **⚪ White** (R>200, G>200, B>200) - foreground (object)
+ - **🔴 Red** (R>150, G<100, B<100) - background (exclude)
+
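A minimal sketch of how a client could produce a brush-mask PNG matching these conventions: pure green strokes over the object, pure red over regions to exclude, transparent everywhere else. The canvas size and stroke coordinates below are arbitrary example values; the canvas is assumed to match the photo being segmented.

```python
import numpy as np
from PIL import Image

# Transparent RGBA canvas (assumed here to be the same 1024x768 size as the photo)
canvas = np.zeros((768, 1024, 4), dtype=np.uint8)

# Green (foreground) stroke over the object...
canvas[300:360, 200:500] = [0, 255, 0, 255]
# ...and a red (background) stroke over an area to exclude
canvas[600:640, 100:400] = [255, 0, 0, 255]

Image.fromarray(canvas, "RGBA").save("brush_mask.png")
# Base64-encode brush_mask.png and send it as the "data" field of a "mask" prompt.
```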
+ ## 🔧 Technologies
+
+ - Meta SAM 2.1 (Segment Anything Model 2)
+ - FastAPI
+ - PyTorch
+ - OpenCV
+ - Pydantic
+
+ ## 📝 Usage examples
+
+ ### Python
+
+ ```python
+ import requests
+ import base64
+
+ # Load the image
+ with open("image.jpg", "rb") as f:
+     image_b64 = base64.b64encode(f.read()).decode()
+
+ # Send the request
+ response = requests.post(
+     "https://YOUR-SPACE.hf.space/segment/batch",
+     json={
+         "image": f"data:image/jpeg;base64,{image_b64}",
+         "prompts": [{
+             "id": 0,
+             "type": "box",
+             "data": "",
+             "bbox": {"x_min": 0.1, "y_min": 0.2, "x_max": 0.5, "y_max": 0.8},
+             "label": "person",
+             "selected": True
+         }],
+         "options": {"extract_objects": True}
+     }
+ )
+
+ result = response.json()
+ print(f"Objects processed: {len(result['results'])}")
+ ```
+
+ ### JavaScript
+
+ ```javascript
+ const response = await fetch('https://YOUR-SPACE.hf.space/segment/batch', {
+   method: 'POST',
+   headers: {'Content-Type': 'application/json'},
+   body: JSON.stringify({
+     image: imageBase64,
+     prompts: [{
+       id: 0,
+       type: "box",
+       data: "",
+       bbox: {x_min: 0.1, y_min: 0.2, x_max: 0.5, y_max: 0.8},
+       label: "person",
+       selected: true
+     }],
+     options: {extract_objects: true}
+   })
+ });
+
+ const result = await response.json();
+ console.log(`Processed ${result.results.length} objects`);
+ ```
+
+ ## 📚 Documentation
+
+ Full interactive documentation is available at `/docs` once the Space is running.
+
+ ## 🤝 Support
+
+ - Model: SAM 2.1 Hiera Tiny (for CPU)
+ - Image formats: JPG, PNG, WEBP, BMP
+ - Maximum size: up to 2048x2048 px is recommended for reasonable speed
+
+ ## ⚡ Optimizing for mobile apps
+
+ 1. Downscale the image before sending (1024x1024); see the sketch after this list
+ 2. Use `include_masks: false` if you do not need contours
+ 3. Cache results on the client
+ 4. Use the batch API for multiple objects
+
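A sketch of point 1, downscaling on the client with Pillow before encoding. The 1024 px cap comes from the recommendation above; the helper name and JPEG quality are arbitrary choices for illustration.

```python
import base64
import io
from PIL import Image

def image_to_data_url(path: str, max_side: int = 1024) -> str:
    """Downscale so the longest side is <= max_side, then return a JPEG data URL."""
    img = Image.open(path).convert("RGB")
    img.thumbnail((max_side, max_side))  # keeps aspect ratio, only ever shrinks
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=90)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/jpeg;base64,{b64}"

# Use the result as the "image" field of a /segment/batch request,
# ideally together with "include_masks": false when contours are not needed.
```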
+ ## 📄 License
+
+ Apache 2.0
+
+ ## 🔗 Links
+
+ - [SAM2 GitHub](https://github.com/facebookresearch/sam2)
+ - [SAM2 Paper](https://arxiv.org/abs/2408.00714)
+
app.py ADDED
@@ -0,0 +1,1109 @@
+ """
+ REST API server for image segmentation with SAM2.
+ A tired senior dev is writing this at 3 AM, so the code will be messy in places.
+ """
+
+ from contextlib import asynccontextmanager
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Body
+ from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+ from PIL import Image
+ import numpy as np
+ import torch
+ import io
+ import os
+ import base64
+ import cv2
+ from typing import List, Dict, Any, Optional, Literal
+ import logging
+ from datetime import datetime
+ import json
+
+ # Logging setup, because debugging this mess any other way is impossible
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Global variables for the model (too lazy to load it on every request)
+ predictor = None
+ device = None
+
+ # ===== Pydantic models for the batch API =====
+
+ class BBoxModel(BaseModel):
+     """Bounding box in normalized (0.0 - 1.0) or pixel coordinates"""
+     x_min: float = Field(..., description="X coordinate of the top-left corner")
+     y_min: float = Field(..., description="Y coordinate of the top-left corner")
+     x_max: float = Field(..., description="X coordinate of the bottom-right corner")
+     y_max: float = Field(..., description="Y coordinate of the bottom-right corner")
+
+ class PromptModel(BaseModel):
+     """Prompt for segmenting a single object"""
+     id: int = Field(..., description="Unique object ID")
+     type: Literal["mask", "box", "points"] = Field(..., description="Prompt type")
+     data: str = Field(..., description="Prompt payload (base64 for mask, JSON for points)")
+     bbox: Optional[BBoxModel] = Field(None, description="Optional bounding box")
+     label: Optional[str] = Field(None, description="Object label (person, car, etc.)")
+     selected: bool = Field(True, description="Whether to process this prompt")
+
+ class SegmentOptionsModel(BaseModel):
+     """Segmentation options"""
+     extract_objects: bool = Field(True, description="Return cut-out objects")
+     include_masks: bool = Field(False, description="Include mask contours")
+     clean_masks: bool = Field(True, description="Clean masks from artifacts")
+
+ class BatchSegmentRequest(BaseModel):
+     """Batch segmentation request"""
+     image: str = Field(..., description="Base64-encoded image (with or without a data URL prefix)")
+     prompts: List[PromptModel] = Field(..., description="Array of prompts")
+     options: Optional[SegmentOptionsModel] = Field(default_factory=SegmentOptionsModel)
+
+ class SegmentResultModel(BaseModel):
+     """Segmentation result for a single object"""
+     id: int
+     label: Optional[str] = None
+     bbox: Dict[str, Any]
+     area: int
+     center: Dict[str, int]
+     confidence: float
+     extracted_image: Optional[str] = None
+     contours: Optional[List[Dict[str, Any]]] = None
+     mask_rle: Optional[Dict[str, Any]] = None
+
+ class BatchSegmentResponse(BaseModel):
+     """Batch segmentation response"""
+     success: bool
+     image_size: Dict[str, int]
+     results: List[SegmentResultModel]
+
+ def save_batch_request_log(request_data: dict, response_data: dict, image_width: int, image_height: int):
+     """
+     Saves a batch request for auditing and debugging.
+     Creates a timestamped folder and stores metadata only:
+     1. Request log (request.json) - parameters without base64
+     2. Response log (response.json) - results without base64
+     3. Short summary (summary.json)
+
+     ⚠️ Images and masks are NOT saved, for safety!
+     """
+     try:
+         # Create the root folder for logs
+         logs_dir = "batch_logs"
+         os.makedirs(logs_dir, exist_ok=True)
+
+         # Create a timestamped folder
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]  # Milliseconds
+         request_dir = os.path.join(logs_dir, timestamp)
+         os.makedirs(request_dir, exist_ok=True)
+
+         logger.info(f"📁 Saving request log to: {request_dir}")
+
+         # 1. Save the request (without base64, for safety)
+         request_log = {
+             "timestamp": timestamp,
+             "image_size": {
+                 "width": image_width,
+                 "height": image_height
+             },
+             "prompts": [
+                 {
+                     "id": p.get("id"),
+                     "type": p.get("type"),
+                     "label": p.get("label"),
+                     "bbox": p.get("bbox"),
+                     "selected": p.get("selected"),
+                     "data_length": len(p.get("data", ""))  # Length instead of the data itself
+                 }
+                 for p in request_data.get("prompts", [])
+             ],
+             "options": request_data.get("options", {})
+         }
+
+         request_path = os.path.join(request_dir, "request.json")
+         with open(request_path, "w", encoding="utf-8") as f:
+             json.dump(request_log, f, indent=2, ensure_ascii=False)
+         logger.info(f"  ✓ Request log saved: {request_path}")
+
+         # 2. Save the response (without base64 objects)
+         response_log = {
+             "timestamp": timestamp,
+             "success": response_data.get("success"),
+             "image_size": response_data.get("image_size"),
+             "results": [
+                 {
+                     "id": r.get("id"),
+                     "label": r.get("label"),
+                     "bbox": r.get("bbox"),
+                     "area": r.get("area"),
+                     "center": r.get("center"),
+                     "confidence": r.get("confidence"),
+                     "has_extracted_image": "extracted_image" in r,
+                     "has_contours": "contours" in r
+                 }
+                 for r in response_data.get("results", [])
+             ]
+         }
+
+         response_path = os.path.join(request_dir, "response.json")
+         with open(response_path, "w", encoding="utf-8") as f:
+             json.dump(response_log, f, indent=2, ensure_ascii=False)
+         logger.info(f"  ✓ Response log saved: {response_path}")
+
+         # 3. Create the summary file
+         summary = {
+             "timestamp": timestamp,
+             "processed_prompts": len(response_data.get("results", [])),
+             "total_prompts": len(request_data.get("prompts", [])),
+             "selected_prompts": len([p for p in request_data.get("prompts", []) if p.get("selected", True)]),
+             "image_size": f"{image_width}x{image_height}",
+             "prompt_types": [p.get("type") for p in request_data.get("prompts", [])],
+             "files": {
+                 "request": "request.json",
+                 "response": "response.json"
+             }
+         }
+
+         summary_path = os.path.join(request_dir, "summary.json")
+         with open(summary_path, "w", encoding="utf-8") as f:
+             json.dump(summary, f, indent=2, ensure_ascii=False)
+
+         logger.info(f"✅ Request log saved: {request_dir}")
+
+     except Exception as e:
+         logger.error(f"❌ Failed to save the request log: {e}")
+         # Do not abort request processing if logging fails
+
+ def load_model(checkpoint_path: str = "checkpoints/sam2.1_hiera_tiny.pt"):
+     """
+     Loads the SAM2 model.
+     Called once at server startup.
+     """
+     global predictor, device
+
+     try:
+         from sam2.build_sam import build_sam2
+         from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+         # Check for CUDA
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {device}")
+
+         if device == "cpu":
+             logger.warning("CUDA is not available, running on CPU (it will be slow as a turtle)")
+
+         # Pick the config based on the checkpoint file name
+         # The path is relative to the configs/ directory inside the sam2 package
+         checkpoint_name = os.path.basename(checkpoint_path)
+         if "tiny" in checkpoint_name:
+             config = "configs/sam2.1/sam2.1_hiera_t.yaml"
+         elif "small" in checkpoint_name:
+             config = "configs/sam2.1/sam2.1_hiera_s.yaml"
+         elif "base_plus" in checkpoint_name:
+             config = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+         elif "large" in checkpoint_name:
+             config = "configs/sam2.1/sam2.1_hiera_l.yaml"
+         else:
+             logger.warning("Unknown model type, falling back to the tiny config")
+             config = "configs/sam2.1/sam2.1_hiera_t.yaml"
+
+         logger.info(f"Loading model from {checkpoint_path}")
+         logger.info(f"Config: {config}")
+
+         sam2_model = build_sam2(config, checkpoint_path, device=device)
+         predictor = SAM2ImagePredictor(sam2_model)
+
+         logger.info("✓ Model loaded successfully")
+
+     except Exception as e:
+         logger.error(f"Failed to load the model: {e}")
+         logger.error("Make sure SAM2 is installed (./install_sam2.sh)")
+         raise
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Load the model on startup, clean up on shutdown"""
+     # Startup
+     checkpoint_dir = "checkpoints"
+     if os.path.exists(checkpoint_dir):
+         checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
+         if checkpoints:
+             checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0])
+             load_model(checkpoint_path)
+         else:
+             logger.error("No checkpoints found in the checkpoints/ directory")
+             logger.error("Run: python download_model.py")
+     else:
+         logger.error("The checkpoints/ directory was not found")
+
+     yield  # Server is running
+
+     # Shutdown (if cleanup is ever needed)
+
+ # Create the FastAPI application with the lifespan handler
+ app = FastAPI(
+     title="SAM2 Segmentation API",
+     description="API for automatic object segmentation in images",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ # Add CORS so the web interface can call the API
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, restrict this to specific domains
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/")
+ async def root():
+     """Landing page - API information"""
+     return {
+         "message": "SAM2 Segmentation API is running",
+         "version": "2.0.0",
+         "web_ui": {
+             "simple": "/web - box prompts",
+             "advanced": "/web/advanced - box + brush prompts (drawing)"
+         },
+         "docs": "/docs",
+         "endpoints": {
+             "POST /segment": "Segment an image (supports points, box, mask via query params)",
+             "POST /segment/batch": "🔥 Batch segmentation (JSON API for multiple objects)",
+             "POST /segment/auto": "Automatic segmentation of all objects",
+             "GET /health": "Service health check"
+         }
+     }
+
+ @app.get("/web", response_class=HTMLResponse)
+ async def web_interface():
+     """Web interface for testing box prompts (simple)"""
+     web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo.html")
+     if os.path.exists(web_demo_path):
+         with open(web_demo_path, "r", encoding="utf-8") as f:
+             return f.read()
+     else:
+         return "<h1>Web interface not found</h1><p>The web_demo.html file is missing</p>"
+
+ @app.get("/web/advanced", response_class=HTMLResponse)
+ async def web_interface_advanced():
+     """Advanced web interface with box + brush prompts"""
+     web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo_advanced.html")
+     if os.path.exists(web_demo_path):
+         with open(web_demo_path, "r", encoding="utf-8") as f:
+             return f.read()
+     else:
+         return "<h1>Advanced interface not found</h1><p>The web_demo_advanced.html file is missing</p>"
+
+ @app.get("/health")
+ async def health():
+     """Check that everything is OK"""
+     return {
+         "status": "healthy" if predictor is not None else "model not loaded",
+         "device": str(device) if device else "unknown"
+     }
+
+ def process_image(image_bytes: bytes) -> np.ndarray:
+     """Converts raw bytes into a numpy array"""
+     image = Image.open(io.BytesIO(image_bytes))
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+     return np.array(image)
+
+ def masks_to_coords(masks: np.ndarray, include_contours: bool = False) -> List[Dict[str, Any]]:
+     """
+     Converts masks into bounding box coordinates and contours.
+     masks: (N, H, W) - N masks
+     include_contours: if True, also returns the mask contours
+     """
+     results = []
+
+     for i, mask in enumerate(masks):
+         # Find the pixel coordinates covered by the mask
+         y_coords, x_coords = np.where(mask > 0)
+
+         if len(x_coords) == 0:
+             continue
+
+         # Bounding box
+         x_min, x_max = int(x_coords.min()), int(x_coords.max())
+         y_min, y_max = int(y_coords.min()), int(y_coords.max())
+
+         # Segment area
+         area = int(mask.sum())
+
+         segment_data = {
+             "segment_id": i,
+             "bbox": {
+                 "x_min": x_min,
+                 "y_min": y_min,
+                 "x_max": x_max,
+                 "y_max": y_max,
+                 "width": x_max - x_min,
+                 "height": y_max - y_min
+             },
+             "area": area,
+             "center": {
+                 "x": int(x_coords.mean()),
+                 "y": int(y_coords.mean())
+             }
+         }
+
+         # Add contours if requested
+         if include_contours:
+             # Convert the mask to uint8 (handles boolean masks too)
+             if mask.dtype == bool:
+                 mask_uint8 = mask.astype(np.uint8) * 255
+             else:
+                 mask_uint8 = (mask * 255).astype(np.uint8)
+
+             try:
+                 # Find contours with hierarchy so that "holes" are supported
+                 # RETR_CCOMP: returns outer contours AND inner holes
+                 # CHAIN_APPROX_NONE: keeps ALL points for a pixel-perfect result
+                 contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+             except Exception as e:
+                 logger.warning(f"Contour extraction failed: {e}, using fallback")
+                 # Fall back to simple extraction without hierarchy
+                 contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+                 hierarchy = None
+
+             # Convert contours into point lists, taking the hierarchy into account
+             contour_data = []
+
+             if hierarchy is not None and len(contours) > 0:
+                 hierarchy = hierarchy[0]  # OpenCV returns the hierarchy in an awkward shape
+
+                 for ci, contour in enumerate(contours):
+                     try:
+                         # Slight simplification, only for very large contours
+                         if len(contour) > 1000:
+                             arc_length = cv2.arcLength(contour, True)
+                             if arc_length > 0:  # Guard against division by zero
+                                 epsilon = 0.0005 * arc_length
+                                 approx = cv2.approxPolyDP(contour, epsilon, True)
+                             else:
+                                 approx = contour
+                         else:
+                             approx = contour
+
+                         # Convert to a list of [x, y] points
+                         points = [[int(point[0][0]), int(point[0][1])] for point in approx]
+
+                         if len(points) > 2:
+                             # hierarchy[ci] = [Next, Previous, First_Child, Parent]
+                             # If Parent == -1, this is an outer contour
+                             # If Parent >= 0, this is a hole inside the parent contour
+                             is_hole = hierarchy[ci][3] != -1
+
+                             contour_data.append({
+                                 "points": points,
+                                 "is_hole": is_hole
+                             })
+                     except Exception as e:
+                         logger.warning(f"Failed to process contour {ci}: {e}")
+                         continue
+             else:
+                 # Fallback when no hierarchy was returned
+                 for contour in contours:
+                     try:
+                         if len(contour) > 1000:
+                             arc_length = cv2.arcLength(contour, True)
+                             if arc_length > 0:
+                                 epsilon = 0.0005 * arc_length
+                                 approx = cv2.approxPolyDP(contour, epsilon, True)
+                             else:
+                                 approx = contour
+                         else:
+                             approx = contour
+
+                         points = [[int(point[0][0]), int(point[0][1])] for point in approx]
+                         if len(points) > 2:
+                             contour_data.append({
+                                 "points": points,
+                                 "is_hole": False
+                             })
+                     except Exception as e:
+                         logger.warning(f"Failed to process contour: {e}")
+                         continue
+
+             segment_data["contours"] = contour_data if len(contour_data) > 0 else []
+
+             # Also add RLE (Run-Length Encoding) as a compact representation
+             # Useful when the exact mask needs to be reconstructed
+             segment_data["mask_rle"] = mask_to_rle(mask)
+
+         results.append(segment_data)
+
+     return results
+
+ def mask_to_rle(mask: np.ndarray) -> Dict[str, Any]:
+     """
+     Converts a binary mask to RLE (Run-Length Encoding).
+     Compact mask representation: counts alternate as [start, length, start, length, ...]
+     with 1-based starts into the row-major flattened mask.
+     """
+     # Convert to int if it is a boolean mask
+     if mask.dtype == bool:
+         pixels = mask.astype(np.uint8).flatten()
+     else:
+         pixels = mask.flatten()
+
+     pixels = np.concatenate([[0], pixels, [0]])
+     runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+     runs[1::2] -= runs[::2]
+
+     return {
+         "counts": [int(x) for x in runs],  # Convert numpy ints to Python ints
+         "size": [int(x) for x in mask.shape]  # Convert to Python ints
+     }
+
+ def convert_to_native_types(obj):
+     """
+     Recursively converts numpy types to native Python types.
+     Needed so FastAPI can serialize the result to JSON.
+     """
+     if isinstance(obj, np.integer):
+         return int(obj)
+     elif isinstance(obj, np.floating):
+         return float(obj)
+     elif isinstance(obj, np.ndarray):
+         return obj.tolist()
+     elif isinstance(obj, np.bool_):
+         return bool(obj)
+     elif isinstance(obj, dict):
+         return {key: convert_to_native_types(value) for key, value in obj.items()}
+     elif isinstance(obj, list):
+         return [convert_to_native_types(item) for item in obj]
+     return obj
+
+ def clean_mask(mask: np.ndarray, min_area: int = 100) -> np.ndarray:
+     """
+     Cleans a mask from small artifacts and holes.
+
+     mask: binary mask (H, W)
+     min_area: minimum component area in pixels
+
+     Returns: the cleaned mask
+     """
+     # Convert to uint8 if needed
+     if mask.dtype == bool:
+         mask_uint8 = mask.astype(np.uint8) * 255
+     else:
+         mask_uint8 = (mask * 255).astype(np.uint8)
+
+     # Morphological closing to remove small holes
+     kernel = np.ones((3, 3), np.uint8)
+     mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=2)
+
+     # Morphological opening to remove small noise
+     mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_OPEN, kernel, iterations=1)
+
+     # Find all connected components
+     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_uint8, connectivity=8)
+
+     # Start from an empty mask
+     cleaned = np.zeros_like(mask_uint8)
+
+     # Keep only the large components
+     for i in range(1, num_labels):  # Skip the background (0)
+         area = stats[i, cv2.CC_STAT_AREA]
+         if area >= min_area:
+             cleaned[labels == i] = 255
+
+     # If nothing survived, keep the largest component
+     if cleaned.sum() == 0 and num_labels > 1:
+         # Find the largest component
+         largest_component = 1 + np.argmax([stats[i, cv2.CC_STAT_AREA] for i in range(1, num_labels)])
+         cleaned[labels == largest_component] = 255
+
+     return (cleaned > 127).astype(bool)
+
+ def extract_object_image(image: np.ndarray, mask: np.ndarray, clean: bool = True) -> str:
+     """
+     Cuts the object out of the image using the mask and returns a base64 PNG with transparency.
+
+     image: RGB image (H, W, 3)
+     mask: binary mask (H, W)
+     clean: apply post-processing to remove artifacts
+
+     Returns: base64 string of a PNG image with an alpha channel
+     """
+     # Convert the mask to bool if needed
+     if mask.dtype != bool:
+         mask = mask > 0.5
+
+     # Clean the mask from artifacts
+     if clean:
+         mask = clean_mask(mask, min_area=100)
+
+     # Build an RGBA image
+     h, w = image.shape[:2]
+     rgba = np.zeros((h, w, 4), dtype=np.uint8)
+     rgba[:, :, :3] = image  # RGB channels
+     rgba[:, :, 3] = (mask * 255).astype(np.uint8)  # Alpha channel from the mask
+
+     # Convert to a PIL Image
+     pil_image = Image.fromarray(rgba, 'RGBA')
+
+     # Convert to base64
+     buffer = io.BytesIO()
+     pil_image.save(buffer, format='PNG')
+     buffer.seek(0)
+     img_base64 = base64.b64encode(buffer.read()).decode('utf-8')
+
+     return f"data:image/png;base64,{img_base64}"
+
+ @app.post("/segment")
+ async def segment_image(
+     file: UploadFile = File(...),
+     point_x: List[float] = Query(None, description="X coordinates of prompt points"),
+     point_y: List[float] = Query(None, description="Y coordinates of prompt points"),
+     point_labels: List[int] = Query(None, description="Point labels (1=foreground, 0=background)"),
+     box_x1: float = Query(None, description="X coordinate of the box's top-left corner"),
+     box_y1: float = Query(None, description="Y coordinate of the box's top-left corner"),
+     box_x2: float = Query(None, description="X coordinate of the box's bottom-right corner"),
+     box_y2: float = Query(None, description="Y coordinate of the box's bottom-right corner"),
+     mask_data: str = Query(None, description="Base64-encoded mask (PNG with an alpha channel)"),
+     include_masks: bool = Query(True, description="Include mask contours in the response"),
+     extract_objects: bool = Query(False, description="Return cut-out objects as base64 PNG"),
+ ):
+     """
+     Segments an image based on a prompt (points, box, mask, or a combination).
+
+     Supported prompts:
+     - Points (point_x, point_y, point_labels) - user clicks
+     - Box (box_x1, box_y1, box_x2, box_y2) - rectangular selection
+     - Mask (mask_data) - brush-drawn mask (green=foreground, red=background)
+     - A combination of prompts - for maximum accuracy
+
+     If no prompts are given, the central object is segmented.
+     If include_masks=True, mask contours are returned for precise rendering.
+     If extract_objects=True, ready-to-use cut-out objects are returned as base64 PNG.
+     """
+     if predictor is None:
+         raise HTTPException(status_code=503, detail="Model is not loaded, restart the server")
+
+     try:
+         # Read the image
+         image_bytes = await file.read()
+         image = process_image(image_bytes)
+
+         logger.info(f"Processing image: {image.shape}")
+         logger.info(f"Parameters: include_masks={include_masks}, extract_objects={extract_objects}")
+
+         # Set the image on the predictor
+         predictor.set_image(image)
+
+         # Prepare the prompts
+         points = None
+         labels = None
+         box = None
+
+         # Check for point prompts
+         if point_x and point_y:
+             if len(point_x) != len(point_y):
+                 raise HTTPException(status_code=400, detail="The number of X and Y coordinates must match")
+             points = np.array([[x, y] for x, y in zip(point_x, point_y)])
+             labels = np.array(point_labels) if point_labels else np.ones(len(points))
+             logger.info(f"Prompt: {len(points)} points")
+
+         # Check for a box prompt
+         if all(v is not None for v in [box_x1, box_y1, box_x2, box_y2]):
+             box = np.array([box_x1, box_y1, box_x2, box_y2])
+             logger.info(f"Prompt: box [{box_x1:.1f}, {box_y1:.1f}, {box_x2:.1f}, {box_y2:.1f}]")
+
+             # Validate the box
+             if box_x2 <= box_x1 or box_y2 <= box_y1:
+                 raise HTTPException(
+                     status_code=400,
+                     detail="Invalid box: x2 must be greater than x1, and y2 greater than y1"
+                 )
+
+         # Check for a drawn mask
+         if mask_data:
+             logger.info("Processing the drawn mask...")
+             try:
+                 # Decode base64
+                 if ',' in mask_data:
+                     mask_data = mask_data.split(',')[1]  # Drop the data:image/png;base64, prefix
+
+                 mask_bytes = base64.b64decode(mask_data)
+                 mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA')
+                 mask_array = np.array(mask_image)
+
+                 # Extract foreground and background pixels
+                 # Several formats are supported:
+                 # 1. Green (R<100, G>150, B<100) - classic foreground
+                 # 2. White/light (R>200, G>200, B>200) - often used by frontends
+                 # 3. Red (R>150, G<100, B<100) - background
+
+                 green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
+                 white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0)
+                 red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
+
+                 # Treat green and white as foreground
+                 foreground_mask = green_mask | white_mask
+
+                 # Sample points from the painted regions
+                 mask_points = []
+                 mask_labels = []
+
+                 # Foreground points (green + white)
+                 foreground_coords = np.argwhere(foreground_mask)
+                 if len(foreground_coords) > 0:
+                     # Rescale to the size of the original image
+                     scale_y = image.shape[0] / mask_array.shape[0]
+                     scale_x = image.shape[1] / mask_array.shape[1]
+
+                     # Sample up to 20 evenly spaced points (fewer = more stable)
+                     step = max(1, len(foreground_coords) // 20)
+                     sampled = foreground_coords[::step][:20]  # At most 20 points
+
+                     for y, x in sampled:
+                         mask_points.append([x * scale_x, y * scale_y])
+                         mask_labels.append(1)  # foreground
+
+                 # Background points (red)
+                 red_coords = np.argwhere(red_mask)
+                 if len(red_coords) > 0:
+                     scale_y = image.shape[0] / mask_array.shape[0]
+                     scale_x = image.shape[1] / mask_array.shape[1]
+
+                     step = max(1, len(red_coords) // 20)
+                     sampled = red_coords[::step][:20]  # At most 20 points
+
+                     for y, x in sampled:
+                         mask_points.append([x * scale_x, y * scale_y])
+                         mask_labels.append(0)  # background
+
+                 if mask_points:
+                     # Merge with any existing points
+                     if points is not None:
+                         points = np.vstack([points, np.array(mask_points)])
+                         labels = np.concatenate([labels, np.array(mask_labels)])
+                     else:
+                         points = np.array(mask_points)
+                         labels = np.array(mask_labels)
+
+                     logger.info(f"Mask prompt: {len(mask_points)} points ({np.sum(np.array(mask_labels) == 1)} foreground, {np.sum(np.array(mask_labels) == 0)} background)")
+                 else:
+                     logger.warning("The mask is empty or contains no foreground (green/white) or background (red) pixels")
+
+             except Exception as e:
+                 logger.error(f"Mask processing error: {e}")
+                 raise HTTPException(status_code=400, detail=f"Invalid mask: {str(e)}")
+
+         # Run prediction with the prompts
+         if points is not None or box is not None:
+             logger.info(f"Using prompts: points={points is not None}, box={box is not None}")
+
+             # With many points (>10), use a single mask for stability
+             # With few points or only a box, use multimask for variety
+             use_multimask = True
+             if points is not None and len(points) > 10:
+                 use_multimask = False
+                 logger.info("Many points, using single mask mode for stability")
+
+             masks, scores, logits = predictor.predict(
+                 point_coords=points,
+                 point_labels=labels,
+                 box=box,
+                 multimask_output=use_multimask,
+             )
+
+             # With multimask, pick the best mask by score
+             if use_multimask and len(masks) > 1:
+                 best_idx = np.argmax(scores)
+                 masks = masks[best_idx:best_idx+1]
+                 scores = scores[best_idx:best_idx+1]
+                 logger.info(f"Selected mask {best_idx} with confidence {scores[0]:.3f}")
+         else:
+             # Automatic segmentation - take the central point
+             logger.info("No prompts given, segmenting the central object")
+             h, w = image.shape[:2]
+             point = np.array([[w // 2, h // 2]])
+             label = np.array([1])
+
+             masks, scores, logits = predictor.predict(
+                 point_coords=point,
+                 point_labels=label,
+                 multimask_output=True,
+             )
+
+         # Convert the masks into coordinates (with contours if requested)
+         segments = masks_to_coords(masks, include_contours=include_masks)
+
+         logger.info(f"Segments found: {len(segments)}, masks: {len(masks)}")
+         logger.info(f"extract_objects = {extract_objects}")
+
+         # Add confidence scores
+         for i, seg in enumerate(segments):
+             seg["confidence"] = float(scores[i]) if i < len(scores) else 0.0
+
+             # If requested, cut the object out and attach it as base64
+             logger.info(f"Processing segment {i}: extract_objects={extract_objects}, i < len(masks) = {i < len(masks)}")
+             if extract_objects and i < len(masks):
+                 logger.info(f"Extracting object {i}...")
+                 seg["extracted_image"] = extract_object_image(image, masks[i])
+                 logger.info(f"✓ Object {i} extracted, mask size: {masks[i].sum()} pixels")
+             else:
+                 logger.warning(f"❌ Skipping object {i}: extract_objects={extract_objects}")
+
+         result = {
+             "success": True,
+             "image_size": {
+                 "width": int(image.shape[1]),
+                 "height": int(image.shape[0])
+             },
+             "segments_count": len(segments),
+             "segments": segments
+         }
+
+         # Convert all numpy types to native Python types
+         return convert_to_native_types(result)
+
+     except HTTPException:
+         # Let explicit HTTP errors (e.g. the 400 validation errors above) pass through unchanged
+         raise
+     except Exception as e:
+         logger.error(f"Segmentation error: {e}")
+         raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+ @app.post("/segment/auto")
+ async def segment_auto(
+     file: UploadFile = File(...),
+     points_per_side: int = Query(32, description="Number of points per side for automatic segmentation"),
+     include_masks: bool = Query(True, description="Include mask contours in the response"),
+ ):
+     """
+     Automatic segmentation of all objects in the image.
+     Uses a grid of points to find every possible object.
+     If include_masks=True, mask contours are returned for precise rendering.
+     """
+     if predictor is None:
+         raise HTTPException(status_code=503, detail="Model is not loaded")
+
+     try:
+         image_bytes = await file.read()
+         image = process_image(image_bytes)
+
+         logger.info(f"Automatic segmentation of image: {image.shape}")
+
+         predictor.set_image(image)
+
+         # Build the point grid
+         h, w = image.shape[:2]
+         x_coords = np.linspace(0, w, points_per_side)
+         y_coords = np.linspace(0, h, points_per_side)
+
+         all_segments = []
+         segment_id = 0
+
+         # Try to find an object at every grid point
+         for y in y_coords:
+             for x in x_coords:
+                 point = np.array([[x, y]])
+                 label = np.array([1])
+
+                 masks, scores, _ = predictor.predict(
+                     point_coords=point,
+                     point_labels=label,
+                     multimask_output=False,
+                 )
+
+                 if masks.shape[0] > 0 and scores[0] > 0.5:  # Confidence threshold
+                     segments = masks_to_coords(masks, include_contours=include_masks)
+                     for seg in segments:
+                         seg["segment_id"] = segment_id
+                         seg["confidence"] = float(scores[0])
+                         all_segments.append(seg)
+                         segment_id += 1
+
+         # Remove (approximate) duplicates
+         # Two segments count as duplicates if their centers are close
+         unique_segments = []
+         for seg in all_segments:
+             is_duplicate = False
+             for unique_seg in unique_segments:
+                 dx = seg["center"]["x"] - unique_seg["center"]["x"]
+                 dy = seg["center"]["y"] - unique_seg["center"]["y"]
+                 dist = (dx**2 + dy**2) ** 0.5
+
+                 if dist < 50:  # Distance threshold between centers
+                     is_duplicate = True
+                     break
+
+             if not is_duplicate:
+                 unique_segments.append(seg)
+
+         result = {
+             "success": True,
+             "image_size": {
+                 "width": int(image.shape[1]),
+                 "height": int(image.shape[0])
+             },
+             "segments_count": len(unique_segments),
+             "segments": unique_segments
+         }
+
+         # Convert all numpy types to native Python types
+         return convert_to_native_types(result)
+
+     except Exception as e:
+         logger.error(f"Automatic segmentation error: {e}")
+         raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+ @app.post("/segment/batch", response_model=BatchSegmentResponse)
+ async def segment_batch(request: BatchSegmentRequest = Body(...)):
+     """
+     Batch segmentation of several objects.
+
+     Accepts an image and an array of prompts (mask/box/points).
+     Each selected prompt is processed separately.
+     Returns an array of results with metadata.
+
+     Ideal for:
+     - Multiple objects
+     - Mobile applications
+     - Cases where the frontend has already separated the objects
+     """
+     if predictor is None:
+         raise HTTPException(status_code=503, detail="Model is not loaded, restart the server")
+
+     try:
+         # Decode the image from base64
+         image_data = request.image
+         if ',' in image_data:
+             image_data = image_data.split(',')[1]  # Drop the data:image/...;base64, prefix
+
+         image_bytes = base64.b64decode(image_data)
+         image = process_image(image_bytes)
+
+         logger.info(f"Batch segmentation: {image.shape}, prompts: {len(request.prompts)}")
+
+         # Set the image once
+         predictor.set_image(image)
+
+         results = []
+
+         # Keep only the selected prompts
+         selected_prompts = [p for p in request.prompts if p.selected]
+         logger.info(f"Processing {len(selected_prompts)} of {len(request.prompts)} prompts")
+
+         # Process each prompt separately
+         for prompt in selected_prompts:
+             logger.info(f"Processing prompt #{prompt.id}, type: {prompt.type}, label: {prompt.label}")
+
+             try:
+                 # Prepare the prompt depending on its type
+                 points = None
+                 labels = None
+                 box = None
+
+                 if prompt.type == "mask":
+                     # Decode the mask and extract points from it
+                     mask_data = prompt.data
+                     if ',' in mask_data:
+                         mask_data = mask_data.split(',')[1]
+
+                     mask_bytes = base64.b64decode(mask_data)
+                     mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA')
+                     mask_array = np.array(mask_image)
+
+                     # Extract foreground and background pixels
+                     # Several formats are supported:
+                     # 1. Green (R<100, G>150, B<100) - classic foreground
+                     # 2. White/light (R>200, G>200, B>200) - often used by frontends
+                     # 3. Red (R>150, G<100, B<100) - background
+
+                     green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
+                     white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0)
+                     red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
+
+                     # Treat green and white as foreground
+                     foreground_mask = green_mask | white_mask
+
+                     mask_points = []
+                     mask_labels = []
+
+                     # Foreground points (green + white)
+                     foreground_coords = np.argwhere(foreground_mask)
+                     if len(foreground_coords) > 0:
+                         scale_y = image.shape[0] / mask_array.shape[0]
+                         scale_x = image.shape[1] / mask_array.shape[1]
+                         step = max(1, len(foreground_coords) // 20)
+                         sampled = foreground_coords[::step][:20]
+
+                         for y, x in sampled:
+                             mask_points.append([x * scale_x, y * scale_y])
+                             mask_labels.append(1)
+
+                     # Background points (red)
+                     red_coords = np.argwhere(red_mask)
+                     if len(red_coords) > 0:
+                         scale_y = image.shape[0] / mask_array.shape[0]
+                         scale_x = image.shape[1] / mask_array.shape[1]
+                         step = max(1, len(red_coords) // 20)
+                         sampled = red_coords[::step][:20]
+
+                         for y, x in sampled:
+                             mask_points.append([x * scale_x, y * scale_y])
+                             mask_labels.append(0)
+
+                     if mask_points:
+                         points = np.array(mask_points)
+                         labels = np.array(mask_labels)
+
+                 elif prompt.type == "box":
+                     # Parse the bbox - it may be normalized (0-1) or in pixels
+                     bbox_data = prompt.bbox if prompt.bbox else None
+
+                     if bbox_data:
+                         x1 = bbox_data.x_min
+                         y1 = bbox_data.y_min
+                         x2 = bbox_data.x_max
+                         y2 = bbox_data.y_max
+
+                         # If the coordinates are normalized (0-1), convert them to pixels
+                         if x2 <= 1.0 and y2 <= 1.0:
+                             x1 *= image.shape[1]
+                             x2 *= image.shape[1]
+                             y1 *= image.shape[0]
+                             y2 *= image.shape[0]
+
+                         box = np.array([x1, y1, x2, y2])
+
+                 elif prompt.type == "points":
+                     # Expect JSON of the form [[x, y, label], ...]
+                     points_data = json.loads(prompt.data)
+
+                     points_list = []
+                     labels_list = []
+
+                     for point in points_data:
+                         x, y = point[0], point[1]
+                         label = point[2] if len(point) > 2 else 1
+
+                         # Convert normalized coordinates to pixels
+                         if x <= 1.0 and y <= 1.0:
+                             x *= image.shape[1]
+                             y *= image.shape[0]
+
+                         points_list.append([x, y])
+                         labels_list.append(label)
+
+                     points = np.array(points_list)
+                     labels = np.array(labels_list)
+
+                 # Run prediction
+                 if points is not None or box is not None:
+                     # Decide whether to use multimask
+                     use_multimask = True
+                     if points is not None and len(points) > 10:
+                         use_multimask = False
+
+                     masks, scores, logits = predictor.predict(
+                         point_coords=points,
+                         point_labels=labels,
+                         box=box,
+                         multimask_output=use_multimask,
+                     )
+
+                     # With multimask, pick the best mask
+                     if use_multimask and len(masks) > 1:
+                         best_idx = np.argmax(scores)
+                         masks = masks[best_idx:best_idx+1]
+                         scores = scores[best_idx:best_idx+1]
+
+                     # Take the first mask
+                     mask = masks[0]
+                     score = float(scores[0])
+
+                     # Clean the mask if requested
+                     if request.options.clean_masks:
+                         mask = clean_mask(mask, min_area=100)
+
+                     # Compute metrics
+                     y_coords, x_coords = np.where(mask > 0)
+
+                     if len(x_coords) > 0:
+                         x_min, x_max = int(x_coords.min()), int(x_coords.max())
+                         y_min, y_max = int(y_coords.min()), int(y_coords.max())
+                         area = int(mask.sum())
+                         center_x = int(x_coords.mean())
+                         center_y = int(y_coords.mean())
+
+                         # Build the result
+                         result = {
+                             "id": prompt.id,
+                             "label": prompt.label,
+                             "bbox": {
+                                 "x_min": x_min,
+                                 "y_min": y_min,
+                                 "x_max": x_max,
+                                 "y_max": y_max,
+                                 "width": x_max - x_min,
+                                 "height": y_max - y_min
+                             },
+                             "area": area,
+                             "center": {
+                                 "x": center_x,
+                                 "y": center_y
+                             },
+                             "confidence": score
+                         }
+
+                         # Attach the cut-out object if requested
+                         if request.options.extract_objects:
+                             result["extracted_image"] = extract_object_image(
+                                 image, mask, clean=request.options.clean_masks
+                             )
+
+                         # Attach contours if requested
+                         if request.options.include_masks:
+                             segments = masks_to_coords(masks, include_contours=True)
+                             if segments:
+                                 result["contours"] = segments[0].get("contours", [])
+                                 result["mask_rle"] = segments[0].get("mask_rle", {})
+
+                         results.append(result)
+                         logger.info(f"✓ Prompt #{prompt.id} processed, confidence: {score:.3f}")
+                     else:
+                         logger.warning(f"✗ Prompt #{prompt.id} produced no result")
+                 else:
+                     logger.warning(f"✗ Prompt #{prompt.id}: nothing to segment")
+
+             except Exception as e:
+                 logger.error(f"✗ Error while processing prompt #{prompt.id}: {e}")
+                 # Keep processing the remaining prompts
+                 continue
+
+         response = {
+             "success": True,
+             "image_size": {
+                 "width": int(image.shape[1]),
+                 "height": int(image.shape[0])
+             },
+             "results": results
+         }
+
+         logger.info(f"Batch finished: {len(results)} objects processed")
+
+         # Save the request log for auditing (metadata only, no images)
+         try:
+             request_dict = request.model_dump()  # Pydantic v2 equivalent of .dict()
+             save_batch_request_log(request_dict, response, image.shape[1], image.shape[0])
+         except Exception as e:
+             logger.warning(f"Failed to save the request log: {e}")
+
+         return convert_to_native_types(response)
+
+     except Exception as e:
+         logger.error(f"Batch segmentation error: {e}")
+         raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     # Port from the environment variable (for HF Spaces) or 8000 by default
+     port = int(os.getenv("PORT", 8000))
+     uvicorn.run(app, host="0.0.0.0", port=port)
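For clients that request contours, `mask_to_rle` above stores a mask as alternating `[start, length, start, length, ...]` counts with 1-based starts into the row-major flattened mask (note this is not the COCO RLE layout). A minimal client-side decoder sketch that inverts that encoding:

```python
import numpy as np

def rle_to_mask(rle: dict) -> np.ndarray:
    """Inverse of mask_to_rle: rebuild a boolean (H, W) mask from its RLE dict."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=bool)
    counts = rle["counts"]
    for start, length in zip(counts[0::2], counts[1::2]):
        flat[start - 1 : start - 1 + length] = True  # starts are 1-based
    return flat.reshape(h, w)
```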
download_model.py ADDED
@@ -0,0 +1,80 @@
+ #!/usr/bin/env python3
+ """
+ Script for downloading a SAM2 model checkpoint.
+ Damn, Facebook can't do pip packaging properly, so we fetch it by hand.
+ """
+
+ import os
+ import urllib.request
+ import sys
+
+ # Directory for checkpoints
+ CHECKPOINT_DIR = "checkpoints"
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+
+ # Available models
+ MODELS = {
+     "tiny": {
+         "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
+         "filename": "sam2.1_hiera_tiny.pt",
+         "size": "~39M params"
+     },
+     "small": {
+         "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
+         "filename": "sam2.1_hiera_small.pt",
+         "size": "~46M params"
+     },
+     "base_plus": {
+         "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
+         "filename": "sam2.1_hiera_base_plus.pt",
+         "size": "~81M params"
+     },
+     "large": {
+         "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
+         "filename": "sam2.1_hiera_large.pt",
+         "size": "~224M params"
+     }
+ }
+
+ def download_model(model_name="tiny"):
+     """Downloads the model, showing progress"""
+     if model_name not in MODELS:
+         print(f"Unknown model: {model_name}")
+         print(f"Available: {', '.join(MODELS.keys())}")
+         sys.exit(1)
+
+     model_info = MODELS[model_name]
+     filepath = os.path.join(CHECKPOINT_DIR, model_info["filename"])
+
+     if os.path.exists(filepath):
+         print(f"Model already downloaded: {filepath}")
+         return filepath
+
+     print(f"Downloading the {model_name} model ({model_info['size']})...")
+     print(f"URL: {model_info['url']}")
+
+     def progress_hook(block_num, block_size, total_size):
+         downloaded = block_num * block_size
+         if total_size > 0:
+             percent = min(100, downloaded * 100 / total_size)
+             sys.stdout.write(f"\rProgress: {percent:.1f}%")
+             sys.stdout.flush()
+
+     try:
+         urllib.request.urlretrieve(
+             model_info["url"],
+             filepath,
+             reporthook=progress_hook
+         )
+         print(f"\n✓ Model downloaded: {filepath}")
+         return filepath
+     except Exception as e:
+         print(f"\n✗ Download error: {e}")
+         if os.path.exists(filepath):
+             os.remove(filepath)
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "tiny"
+     download_model(model_name)
+
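The script can also be driven programmatically; a small usage sketch (model names as defined in `MODELS` above):

```python
from download_model import download_model

# Fetch the small checkpoint into checkpoints/ (skipped if it already exists)
checkpoint_path = download_model("small")
print(checkpoint_path)  # e.g. checkpoints/sam2.1_hiera_small.pt
```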
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi==0.115.0
+ uvicorn[standard]==0.32.0
+ python-multipart==0.0.12
+ Pillow==11.0.0
+ numpy==2.1.0
+ torch==2.6.0
+ torchvision==0.21.0
+ opencv-python==4.10.0.84
+ pydantic==2.9.0
+
sam2_repo/README.md ADDED
@@ -0,0 +1,224 @@
1
+ # SAM 2: Segment Anything in Images and Videos
2
+
3
+ **[AI at Meta, FAIR](https://ai.meta.com/research/)**
4
+
5
+ [Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
6
+
7
+ [[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
8
+
9
+ ![SAM 2 architecture](assets/model_diagram.png?raw=true)
10
+
11
+ **Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
12
+
13
+ ![SA-V dataset](assets/sa_v_dataset.jpg?raw=true)
14
+
15
+ ## Latest updates
16
+
17
+ **12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking**
18
+
19
+ - We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference.
20
+ - We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts.
21
+ - See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details.
22
+
23
+ **09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released**
24
+
25
+ - A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
26
+ * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
27
+ - The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
28
+ - The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
29
+
30
+ ## Installation
31
+
32
+ SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
33
+
34
+ ```bash
35
+ git clone https://github.com/facebookresearch/sam2.git && cd sam2
36
+
37
+ pip install -e .
38
+ ```
39
+ If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
40
+
41
+ To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
42
+
43
+ ```bash
44
+ pip install -e ".[notebooks]"
45
+ ```
46
+
47
+ Note:
48
+ 1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
49
+ 2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
50
+ 3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).
51
+
52
+ Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
53
+
54
+ ## Getting Started
55
+
56
+ ### Download Checkpoints
57
+
58
+ First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
59
+
60
+ ```bash
61
+ cd checkpoints && \
62
+ ./download_ckpts.sh && \
63
+ cd ..
64
+ ```
65
+
66
+ or individually from:
67
+
68
+ - [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)
69
+ - [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
70
+ - [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)
71
+ - [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)
72
+
73
+ (note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.)
74
+
75
+ Then SAM 2 can be used in a few lines as follows for image and video prediction.
76
+
77
+ ### Image prediction
78
+
79
+ SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
80
+
81
+ ```python
82
+ import torch
83
+ from sam2.build_sam import build_sam2
84
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
85
+
86
+ checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
87
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
88
+ predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
89
+
90
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
91
+ predictor.set_image(<your_image>)
92
+ masks, _, _ = predictor.predict(<input_prompts>)
93
+ ```
94
+
95
+ Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.
96
+
97
+ SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.
98
+
99
+ ### Video prediction
100
+
101
+ For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
102
+
103
+ ```python
104
+ import torch
105
+ from sam2.build_sam import build_sam2_video_predictor
106
+
107
+ checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
108
+ model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
109
+ predictor = build_sam2_video_predictor(model_cfg, checkpoint)
110
+
111
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
112
+ state = predictor.init_state(<your_video>)
113
+
114
+ # add new prompts and instantly get the output on the same frame
115
+ frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
116
+
117
+ # propagate the prompts to get masklets throughout the video
118
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
119
+ ...
120
+ ```
121
+
122
+ Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
123
+
124
+ ## Load from 🤗 Hugging Face
125
+
126
+ Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
127
+
128
+ For image prediction:
129
+
130
+ ```python
131
+ import torch
132
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
133
+
134
+ predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
135
+
136
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
137
+ predictor.set_image(<your_image>)
138
+ masks, _, _ = predictor.predict(<input_prompts>)
139
+ ```
140
+
141
+ For video prediction:
142
+
143
+ ```python
144
+ import torch
145
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
146
+
147
+ predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
148
+
149
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
150
+ state = predictor.init_state(<your_video>)
151
+
152
+ # add new prompts and instantly get the output on the same frame
153
+ frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
154
+
155
+ # propagate the prompts to get masklets throughout the video
156
+ for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
157
+ ...
158
+ ```
159
+
160
+ ## Model Description
161
+
162
+ ### SAM 2.1 checkpoints
163
+
164
+ The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
165
+ | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
166
+ | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
167
+ | sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
168
+ | sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
169
+ | sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
170
+ | sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
171
+
172
+ ### SAM 2 checkpoints
173
+
174
+ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:
175
+
176
+ | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
177
+ | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
178
+ | sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
179
+ | sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
180
+ | sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
181
+ | sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
182
+
183
+ Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
184
+ ## Segment Anything Video Dataset
185
+
186
+ See [sav_dataset/README.md](sav_dataset/README.md) for details.
187
+
188
+ ## Training SAM 2
189
+
190
+ You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started.
191
+
192
+ ## Web demo for SAM 2
193
+
194
+ We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details.
195
+
196
+ ## License
197
+
198
+ The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/).
199
+
200
+ ## Contributing
201
+
202
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
203
+
204
+ ## Contributors
205
+
206
+ The SAM 2 project was made possible with the help of many contributors (alphabetical):
207
+
208
+ Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
209
+
210
+ Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
211
+
212
+ ## Citing SAM 2
213
+
214
+ If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
215
+
216
+ ```bibtex
217
+ @article{ravi2024sam2,
218
+ title={SAM 2: Segment Anything in Images and Videos},
219
+ author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
220
+ journal={arXiv preprint arXiv:2408.00714},
221
+ url={https://arxiv.org/abs/2408.00714},
222
+ year={2024}
223
+ }
224
+ ```
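The prediction examples in this README assume a CUDA device with autocast. Since this Space is built for CPU inference, a minimal CPU-only sketch of the same image-prediction flow might look like the following (the image path, point prompt, and choice of the tiny checkpoint are placeholders, not part of the upstream README):

```python
# CPU-only variant of the README's image-prediction example (paths and prompts are placeholders).
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device="cpu"))

image = np.array(Image.open("example.jpg").convert("RGB"))
with torch.inference_mode():  # no CUDA autocast on CPU
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[210, 350]]),  # one positive click
        point_labels=np.array([1]),
    )
```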
sam2_repo/checkpoints/download_ckpts.sh ADDED
@@ -0,0 +1,59 @@
1
+ #!/bin/bash
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ # Use either wget or curl to download the checkpoints
10
+ if command -v wget &> /dev/null; then
11
+ CMD="wget"
12
+ elif command -v curl &> /dev/null; then
13
+ CMD="curl -L -O"
14
+ else
15
+ echo "Please install wget or curl to download the checkpoints."
16
+ exit 1
17
+ fi
18
+
19
+ # Define the URLs for SAM 2 checkpoints
20
+ # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
21
+ # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
22
+ # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
23
+ # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
24
+ # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
25
+
26
+ # Download each of the four checkpoints using wget
27
+ # echo "Downloading sam2_hiera_tiny.pt checkpoint..."
28
+ # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
29
+
30
+ # echo "Downloading sam2_hiera_small.pt checkpoint..."
31
+ # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
32
+
33
+ # echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
34
+ # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
35
+
36
+ # echo "Downloading sam2_hiera_large.pt checkpoint..."
37
+ # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
38
+
39
+ # Define the URLs for SAM 2.1 checkpoints
40
+ SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
41
+ sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
42
+ sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
43
+ sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
44
+ sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
45
+
46
+ # SAM 2.1 checkpoints
47
+ echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
48
+ $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
49
+
50
+ echo "Downloading sam2.1_hiera_small.pt checkpoint..."
51
+ $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
52
+
53
+ echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
54
+ $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
55
+
56
+ echo "Downloading sam2.1_hiera_large.pt checkpoint..."
57
+ $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
58
+
59
+ echo "All checkpoints are downloaded successfully."
sam2_repo/pyproject.toml ADDED
@@ -0,0 +1,6 @@
+ [build-system]
+ requires = [
+     "setuptools>=61.0",
+     "torch>=2.5.1",
+ ]
+ build-backend = "setuptools.build_meta"
sam2_repo/sam2/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from hydra import initialize_config_module
+ from hydra.core.global_hydra import GlobalHydra
+
+ if not GlobalHydra.instance().is_initialized():
+     initialize_config_module("sam2", version_base="1.2")
sam2_repo/sam2/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (455 Bytes).
 
sam2_repo/sam2/__pycache__/build_sam.cpython-313.pyc ADDED
Binary file (5.39 kB).
 
sam2_repo/sam2/__pycache__/sam2_image_predictor.cpython-313.pyc ADDED
Binary file (21.9 kB).
 
sam2_repo/sam2/automatic_mask_generator.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
13
+
14
+ from sam2.modeling.sam2_base import SAM2Base
15
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
16
+ from sam2.utils.amg import (
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ MaskData,
28
+ remove_small_regions,
29
+ rle_to_mask,
30
+ uncrop_boxes_xyxy,
31
+ uncrop_masks,
32
+ uncrop_points,
33
+ )
34
+
35
+
36
+ class SAM2AutomaticMaskGenerator:
37
+ def __init__(
38
+ self,
39
+ model: SAM2Base,
40
+ points_per_side: Optional[int] = 32,
41
+ points_per_batch: int = 64,
42
+ pred_iou_thresh: float = 0.8,
43
+ stability_score_thresh: float = 0.95,
44
+ stability_score_offset: float = 1.0,
45
+ mask_threshold: float = 0.0,
46
+ box_nms_thresh: float = 0.7,
47
+ crop_n_layers: int = 0,
48
+ crop_nms_thresh: float = 0.7,
49
+ crop_overlap_ratio: float = 512 / 1500,
50
+ crop_n_points_downscale_factor: int = 1,
51
+ point_grids: Optional[List[np.ndarray]] = None,
52
+ min_mask_region_area: int = 0,
53
+ output_mode: str = "binary_mask",
54
+ use_m2m: bool = False,
55
+ multimask_output: bool = True,
56
+ **kwargs,
57
+ ) -> None:
58
+ """
59
+ Using a SAM 2 model, generates masks for the entire image.
60
+ Generates a grid of point prompts over the image, then filters
61
+ low quality and duplicate masks. The default settings are chosen
62
+ for SAM 2 with a HieraL backbone.
63
+
64
+ Arguments:
65
+ model (Sam): The SAM 2 model to use for mask prediction.
66
+ points_per_side (int or None): The number of points to be sampled
67
+ along one side of the image. The total number of points is
68
+ points_per_side**2. If None, 'point_grids' must provide explicit
69
+ point sampling.
70
+ points_per_batch (int): Sets the number of points run simultaneously
71
+ by the model. Higher numbers may be faster but use more GPU memory.
72
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
73
+ model's predicted mask quality.
74
+ stability_score_thresh (float): A filtering threshold in [0,1], using
75
+ the stability of the mask under changes to the cutoff used to binarize
76
+ the model's mask predictions.
77
+ stability_score_offset (float): The amount to shift the cutoff when
78
+ calculated the stability score.
79
+ mask_threshold (float): Threshold for binarizing the mask logits
80
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
81
+ suppression to filter duplicate masks.
82
+ crop_n_layers (int): If >0, mask prediction will be run again on
83
+ crops of the image. Sets the number of layers to run, where each
84
+ layer has 2**i_layer number of image crops.
85
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
86
+ suppression to filter duplicate masks between different crops.
87
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
88
+ In the first crop layer, crops will overlap by this fraction of
89
+ the image length. Later layers with more crops scale down this overlap.
90
+ crop_n_points_downscale_factor (int): The number of points-per-side
91
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
92
+ point_grids (list(np.ndarray) or None): A list over explicit grids
93
+ of points used for sampling, normalized to [0,1]. The nth grid in the
94
+ list is used in the nth crop layer. Exclusive with points_per_side.
95
+ min_mask_region_area (int): If >0, postprocessing will be applied
96
+ to remove disconnected regions and holes in masks with area smaller
97
+ than min_mask_region_area. Requires opencv.
98
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
99
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
100
+ For large resolutions, 'binary_mask' may consume large amounts of
101
+ memory.
102
+ use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
103
+ multimask_output (bool): Whether to output multimask at each point of the grid.
104
+ """
105
+
106
+ assert (points_per_side is None) != (
107
+ point_grids is None
108
+ ), "Exactly one of points_per_side or point_grid must be provided."
109
+ if points_per_side is not None:
110
+ self.point_grids = build_all_layer_point_grids(
111
+ points_per_side,
112
+ crop_n_layers,
113
+ crop_n_points_downscale_factor,
114
+ )
115
+ elif point_grids is not None:
116
+ self.point_grids = point_grids
117
+ else:
118
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
119
+
120
+ assert output_mode in [
121
+ "binary_mask",
122
+ "uncompressed_rle",
123
+ "coco_rle",
124
+ ], f"Unknown output_mode {output_mode}."
125
+ if output_mode == "coco_rle":
126
+ try:
127
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
128
+ except ImportError as e:
129
+ print("Please install pycocotools")
130
+ raise e
131
+
132
+ self.predictor = SAM2ImagePredictor(
133
+ model,
134
+ max_hole_area=min_mask_region_area,
135
+ max_sprinkle_area=min_mask_region_area,
136
+ )
137
+ self.points_per_batch = points_per_batch
138
+ self.pred_iou_thresh = pred_iou_thresh
139
+ self.stability_score_thresh = stability_score_thresh
140
+ self.stability_score_offset = stability_score_offset
141
+ self.mask_threshold = mask_threshold
142
+ self.box_nms_thresh = box_nms_thresh
143
+ self.crop_n_layers = crop_n_layers
144
+ self.crop_nms_thresh = crop_nms_thresh
145
+ self.crop_overlap_ratio = crop_overlap_ratio
146
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
147
+ self.min_mask_region_area = min_mask_region_area
148
+ self.output_mode = output_mode
149
+ self.use_m2m = use_m2m
150
+ self.multimask_output = multimask_output
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
154
+ """
155
+ Load a pretrained model from the Hugging Face hub.
156
+
157
+ Arguments:
158
+ model_id (str): The Hugging Face repository ID.
159
+ **kwargs: Additional arguments to pass to the model constructor.
160
+
161
+ Returns:
162
+ (SAM2AutomaticMaskGenerator): The loaded model.
163
+ """
164
+ from sam2.build_sam import build_sam2_hf
165
+
166
+ sam_model = build_sam2_hf(model_id, **kwargs)
167
+ return cls(sam_model, **kwargs)
168
+
169
+ @torch.no_grad()
170
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
171
+ """
172
+ Generates masks for the given image.
173
+
174
+ Arguments:
175
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
176
+
177
+ Returns:
178
+ list(dict(str, any)): A list over records for masks. Each record is
179
+ a dict containing the following keys:
180
+ segmentation (dict(str, any) or np.ndarray): The mask. If
181
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
182
+ is a dictionary containing the RLE.
183
+ bbox (list(float)): The box around the mask, in XYWH format.
184
+ area (int): The area in pixels of the mask.
185
+ predicted_iou (float): The model's own prediction of the mask's
186
+ quality. This is filtered by the pred_iou_thresh parameter.
187
+ point_coords (list(list(float))): The point coordinates input
188
+ to the model to generate this mask.
189
+ stability_score (float): A measure of the mask's quality. This
190
+ is filtered on using the stability_score_thresh parameter.
191
+ crop_box (list(float)): The crop of the image used to generate
192
+ the mask, given in XYWH format.
193
+ """
194
+
195
+ # Generate masks
196
+ mask_data = self._generate_masks(image)
197
+
198
+ # Encode masks
199
+ if self.output_mode == "coco_rle":
200
+ mask_data["segmentations"] = [
201
+ coco_encode_rle(rle) for rle in mask_data["rles"]
202
+ ]
203
+ elif self.output_mode == "binary_mask":
204
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
205
+ else:
206
+ mask_data["segmentations"] = mask_data["rles"]
207
+
208
+ # Write mask records
209
+ curr_anns = []
210
+ for idx in range(len(mask_data["segmentations"])):
211
+ ann = {
212
+ "segmentation": mask_data["segmentations"][idx],
213
+ "area": area_from_rle(mask_data["rles"][idx]),
214
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
215
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
216
+ "point_coords": [mask_data["points"][idx].tolist()],
217
+ "stability_score": mask_data["stability_score"][idx].item(),
218
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
219
+ }
220
+ curr_anns.append(ann)
221
+
222
+ return curr_anns
223
+
224
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
225
+ orig_size = image.shape[:2]
226
+ crop_boxes, layer_idxs = generate_crop_boxes(
227
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
228
+ )
229
+
230
+ # Iterate over image crops
231
+ data = MaskData()
232
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
233
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
234
+ data.cat(crop_data)
235
+
236
+ # Remove duplicate masks between crops
237
+ if len(crop_boxes) > 1:
238
+ # Prefer masks from smaller crops
239
+ scores = 1 / box_area(data["crop_boxes"])
240
+ scores = scores.to(data["boxes"].device)
241
+ keep_by_nms = batched_nms(
242
+ data["boxes"].float(),
243
+ scores,
244
+ torch.zeros_like(data["boxes"][:, 0]), # categories
245
+ iou_threshold=self.crop_nms_thresh,
246
+ )
247
+ data.filter(keep_by_nms)
248
+ data.to_numpy()
249
+ return data
250
+
251
+ def _process_crop(
252
+ self,
253
+ image: np.ndarray,
254
+ crop_box: List[int],
255
+ crop_layer_idx: int,
256
+ orig_size: Tuple[int, ...],
257
+ ) -> MaskData:
258
+ # Crop the image and calculate embeddings
259
+ x0, y0, x1, y1 = crop_box
260
+ cropped_im = image[y0:y1, x0:x1, :]
261
+ cropped_im_size = cropped_im.shape[:2]
262
+ self.predictor.set_image(cropped_im)
263
+
264
+ # Get points for this crop
265
+ points_scale = np.array(cropped_im_size)[None, ::-1]
266
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
267
+
268
+ # Generate masks for this crop in batches
269
+ data = MaskData()
270
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
271
+ batch_data = self._process_batch(
272
+ points, cropped_im_size, crop_box, orig_size, normalize=True
273
+ )
274
+ data.cat(batch_data)
275
+ del batch_data
276
+ self.predictor.reset_predictor()
277
+
278
+ # Remove duplicates within this crop.
279
+ keep_by_nms = batched_nms(
280
+ data["boxes"].float(),
281
+ data["iou_preds"],
282
+ torch.zeros_like(data["boxes"][:, 0]), # categories
283
+ iou_threshold=self.box_nms_thresh,
284
+ )
285
+ data.filter(keep_by_nms)
286
+
287
+ # Return to the original image frame
288
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
289
+ data["points"] = uncrop_points(data["points"], crop_box)
290
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
291
+
292
+ return data
293
+
294
+ def _process_batch(
295
+ self,
296
+ points: np.ndarray,
297
+ im_size: Tuple[int, ...],
298
+ crop_box: List[int],
299
+ orig_size: Tuple[int, ...],
300
+ normalize=False,
301
+ ) -> MaskData:
302
+ orig_h, orig_w = orig_size
303
+
304
+ # Run model on this batch
305
+ points = torch.as_tensor(
306
+ points, dtype=torch.float32, device=self.predictor.device
307
+ )
308
+ in_points = self.predictor._transforms.transform_coords(
309
+ points, normalize=normalize, orig_hw=im_size
310
+ )
311
+ in_labels = torch.ones(
312
+ in_points.shape[0], dtype=torch.int, device=in_points.device
313
+ )
314
+ masks, iou_preds, low_res_masks = self.predictor._predict(
315
+ in_points[:, None, :],
316
+ in_labels[:, None],
317
+ multimask_output=self.multimask_output,
318
+ return_logits=True,
319
+ )
320
+
321
+ # Serialize predictions and store in MaskData
322
+ data = MaskData(
323
+ masks=masks.flatten(0, 1),
324
+ iou_preds=iou_preds.flatten(0, 1),
325
+ points=points.repeat_interleave(masks.shape[1], dim=0),
326
+ low_res_masks=low_res_masks.flatten(0, 1),
327
+ )
328
+ del masks
329
+
330
+ if not self.use_m2m:
331
+ # Filter by predicted IoU
332
+ if self.pred_iou_thresh > 0.0:
333
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
334
+ data.filter(keep_mask)
335
+
336
+ # Calculate and filter by stability score
337
+ data["stability_score"] = calculate_stability_score(
338
+ data["masks"], self.mask_threshold, self.stability_score_offset
339
+ )
340
+ if self.stability_score_thresh > 0.0:
341
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
342
+ data.filter(keep_mask)
343
+ else:
344
+ # One step refinement using previous mask predictions
345
+ in_points = self.predictor._transforms.transform_coords(
346
+ data["points"], normalize=normalize, orig_hw=im_size
347
+ )
348
+ labels = torch.ones(
349
+ in_points.shape[0], dtype=torch.int, device=in_points.device
350
+ )
351
+ masks, ious = self.refine_with_m2m(
352
+ in_points, labels, data["low_res_masks"], self.points_per_batch
353
+ )
354
+ data["masks"] = masks.squeeze(1)
355
+ data["iou_preds"] = ious.squeeze(1)
356
+
357
+ if self.pred_iou_thresh > 0.0:
358
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
359
+ data.filter(keep_mask)
360
+
361
+ data["stability_score"] = calculate_stability_score(
362
+ data["masks"], self.mask_threshold, self.stability_score_offset
363
+ )
364
+ if self.stability_score_thresh > 0.0:
365
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
366
+ data.filter(keep_mask)
367
+
368
+ # Threshold masks and calculate boxes
369
+ data["masks"] = data["masks"] > self.mask_threshold
370
+ data["boxes"] = batched_mask_to_box(data["masks"])
371
+
372
+ # Filter boxes that touch crop boundaries
373
+ keep_mask = ~is_box_near_crop_edge(
374
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
375
+ )
376
+ if not torch.all(keep_mask):
377
+ data.filter(keep_mask)
378
+
379
+ # Compress to RLE
380
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
381
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
382
+ del data["masks"]
383
+
384
+ return data
385
+
386
+ @staticmethod
387
+ def postprocess_small_regions(
388
+ mask_data: MaskData, min_area: int, nms_thresh: float
389
+ ) -> MaskData:
390
+ """
391
+ Removes small disconnected regions and holes in masks, then reruns
392
+ box NMS to remove any new duplicates.
393
+
394
+ Edits mask_data in place.
395
+
396
+ Requires open-cv as a dependency.
397
+ """
398
+ if len(mask_data["rles"]) == 0:
399
+ return mask_data
400
+
401
+ # Filter small disconnected regions and holes
402
+ new_masks = []
403
+ scores = []
404
+ for rle in mask_data["rles"]:
405
+ mask = rle_to_mask(rle)
406
+
407
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
408
+ unchanged = not changed
409
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
410
+ unchanged = unchanged and not changed
411
+
412
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
413
+ # Give score=0 to changed masks and score=1 to unchanged masks
414
+ # so NMS will prefer ones that didn't need postprocessing
415
+ scores.append(float(unchanged))
416
+
417
+ # Recalculate boxes and remove any new duplicates
418
+ masks = torch.cat(new_masks, dim=0)
419
+ boxes = batched_mask_to_box(masks)
420
+ keep_by_nms = batched_nms(
421
+ boxes.float(),
422
+ torch.as_tensor(scores),
423
+ torch.zeros_like(boxes[:, 0]), # categories
424
+ iou_threshold=nms_thresh,
425
+ )
426
+
427
+ # Only recalculate RLEs for masks that have changed
428
+ for i_mask in keep_by_nms:
429
+ if scores[i_mask] == 0.0:
430
+ mask_torch = masks[i_mask].unsqueeze(0)
431
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
432
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
433
+ mask_data.filter(keep_by_nms)
434
+
435
+ return mask_data
436
+
437
+ def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
438
+ new_masks = []
439
+ new_iou_preds = []
440
+
441
+ for cur_points, cur_point_labels, low_res_mask in batch_iterator(
442
+ points_per_batch, points, point_labels, low_res_masks
443
+ ):
444
+ best_masks, best_iou_preds, _ = self.predictor._predict(
445
+ cur_points[:, None, :],
446
+ cur_point_labels[:, None],
447
+ mask_input=low_res_mask[:, None, :],
448
+ multimask_output=False,
449
+ return_logits=True,
450
+ )
451
+ new_masks.append(best_masks)
452
+ new_iou_preds.append(best_iou_preds)
453
+ masks = torch.cat(new_masks, dim=0)
454
+ return masks, torch.cat(new_iou_preds, dim=0)
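The `generate()` docstring above spells out the record format it returns; here is a short usage sketch under the assumption of a local tiny checkpoint on CPU (checkpoint and image paths are placeholders):

```python
# Usage sketch for SAM2AutomaticMaskGenerator (checkpoint and image paths are placeholders).
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_t.yaml",
    "./checkpoints/sam2.1_hiera_tiny.pt",
    device="cpu",
)
mask_generator = SAM2AutomaticMaskGenerator(model, points_per_side=16)

image = np.array(Image.open("example.jpg").convert("RGB"))  # HWC uint8
records = mask_generator.generate(image)

# Each record exposes the keys documented in generate()'s docstring.
for rec in records[:3]:
    print(rec["bbox"], rec["area"], rec["predicted_iou"], rec["stability_score"])
```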
sam2_repo/sam2/benchmark.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import time
9
+
10
+ import numpy as np
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ from sam2.build_sam import build_sam2_video_predictor
15
+
16
+ # Only cuda supported
17
+ assert torch.cuda.is_available()
18
+ device = torch.device("cuda")
19
+
20
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
21
+ if torch.cuda.get_device_properties(0).major >= 8:
22
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ torch.backends.cudnn.allow_tf32 = True
25
+
26
+ # Config and checkpoint
27
+ sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
28
+ model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
29
+
30
+ # Build video predictor with vos_optimized=True setting
31
+ predictor = build_sam2_video_predictor(
32
+ model_cfg, sam2_checkpoint, device=device, vos_optimized=True
33
+ )
34
+
35
+
36
+ # Initialize with video
37
+ video_dir = "notebooks/videos/bedroom"
38
+ # scan all the JPEG frame names in this directory
39
+ frame_names = [
40
+ p
41
+ for p in os.listdir(video_dir)
42
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
43
+ ]
44
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
45
+ inference_state = predictor.init_state(video_path=video_dir)
46
+
47
+
48
+ # Number of runs, warmup etc
49
+ warm_up, runs = 5, 25
50
+ verbose = True
51
+ num_frames = len(frame_names)
52
+ total, count = 0, 0
53
+ torch.cuda.empty_cache()
54
+
55
+ # We will select an object with a click.
56
+ # See video_predictor_example.ipynb for more detailed explanation
57
+ ann_frame_idx, ann_obj_id = 0, 1
58
+ # Add a positive click at (x, y) = (210, 350)
59
+ # For labels, `1` means positive click
60
+ points = np.array([[210, 350]], dtype=np.float32)
61
+ labels = np.array([1], np.int32)
62
+
63
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
64
+ inference_state=inference_state,
65
+ frame_idx=ann_frame_idx,
66
+ obj_id=ann_obj_id,
67
+ points=points,
68
+ labels=labels,
69
+ )
70
+
71
+ # Warmup and then average FPS over several runs
72
+ with torch.autocast("cuda", torch.bfloat16):
73
+ with torch.inference_mode():
74
+ for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
75
+ start = time.time()
76
+ # Start tracking
77
+ for (
78
+ out_frame_idx,
79
+ out_obj_ids,
80
+ out_mask_logits,
81
+ ) in predictor.propagate_in_video(inference_state):
82
+ pass
83
+
84
+ end = time.time()
85
+ total += end - start
86
+ count += 1
87
+ if i == warm_up - 1:
88
+ print("Warmup FPS: ", count * num_frames / total)
89
+ total = 0
90
+ count = 0
91
+
92
+ print("FPS: ", count * num_frames / total)
sam2_repo/sam2/build_sam.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+ from hydra import compose
12
+ from hydra.utils import instantiate
13
+ from omegaconf import OmegaConf
14
+
15
+ import sam2
16
+
17
+ # Check if the user is running Python from the parent directory of the sam2 repo
18
+ # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
+ # it could shadow the sam2 package and cause issues.
20
+ if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
+ # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
22
+ # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
+ # This typically happens because the user is running Python from the parent directory
24
+ # that contains the sam2 repo they cloned.
25
+ raise RuntimeError(
26
+ "You're likely running Python from the parent directory of the sam2 repository "
27
+ "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
+ "This is not supported since the `sam2` Python package could be shadowed by the "
29
+ "repository name (the repository is also named `sam2` and contains the Python package "
30
+ "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
+ "rather than its parent dir, or from your home directory) after installing SAM 2."
32
+ )
33
+
34
+
35
+ HF_MODEL_ID_TO_FILENAMES = {
36
+ "facebook/sam2-hiera-tiny": (
37
+ "configs/sam2/sam2_hiera_t.yaml",
38
+ "sam2_hiera_tiny.pt",
39
+ ),
40
+ "facebook/sam2-hiera-small": (
41
+ "configs/sam2/sam2_hiera_s.yaml",
42
+ "sam2_hiera_small.pt",
43
+ ),
44
+ "facebook/sam2-hiera-base-plus": (
45
+ "configs/sam2/sam2_hiera_b+.yaml",
46
+ "sam2_hiera_base_plus.pt",
47
+ ),
48
+ "facebook/sam2-hiera-large": (
49
+ "configs/sam2/sam2_hiera_l.yaml",
50
+ "sam2_hiera_large.pt",
51
+ ),
52
+ "facebook/sam2.1-hiera-tiny": (
53
+ "configs/sam2.1/sam2.1_hiera_t.yaml",
54
+ "sam2.1_hiera_tiny.pt",
55
+ ),
56
+ "facebook/sam2.1-hiera-small": (
57
+ "configs/sam2.1/sam2.1_hiera_s.yaml",
58
+ "sam2.1_hiera_small.pt",
59
+ ),
60
+ "facebook/sam2.1-hiera-base-plus": (
61
+ "configs/sam2.1/sam2.1_hiera_b+.yaml",
62
+ "sam2.1_hiera_base_plus.pt",
63
+ ),
64
+ "facebook/sam2.1-hiera-large": (
65
+ "configs/sam2.1/sam2.1_hiera_l.yaml",
66
+ "sam2.1_hiera_large.pt",
67
+ ),
68
+ }
69
+
70
+
71
+ def build_sam2(
72
+ config_file,
73
+ ckpt_path=None,
74
+ device="cuda",
75
+ mode="eval",
76
+ hydra_overrides_extra=[],
77
+ apply_postprocessing=True,
78
+ **kwargs,
79
+ ):
80
+
81
+ if apply_postprocessing:
82
+ hydra_overrides_extra = hydra_overrides_extra.copy()
83
+ hydra_overrides_extra += [
84
+ # dynamically fall back to multi-mask if the single mask is not stable
85
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
+ ]
89
+ # Read config and init model
90
+ cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
+ OmegaConf.resolve(cfg)
92
+ model = instantiate(cfg.model, _recursive_=True)
93
+ _load_checkpoint(model, ckpt_path)
94
+ model = model.to(device)
95
+ if mode == "eval":
96
+ model.eval()
97
+ return model
98
+
99
+
100
+ def build_sam2_video_predictor(
101
+ config_file,
102
+ ckpt_path=None,
103
+ device="cuda",
104
+ mode="eval",
105
+ hydra_overrides_extra=[],
106
+ apply_postprocessing=True,
107
+ vos_optimized=False,
108
+ **kwargs,
109
+ ):
110
+ hydra_overrides = [
111
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
112
+ ]
113
+ if vos_optimized:
114
+ hydra_overrides = [
115
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
116
+ "++model.compile_image_encoder=True", # Let sam2_base handle this
117
+ ]
118
+
119
+ if apply_postprocessing:
120
+ hydra_overrides_extra = hydra_overrides_extra.copy()
121
+ hydra_overrides_extra += [
122
+ # dynamically fall back to multi-mask if the single mask is not stable
123
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
124
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
125
+ "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
126
+ # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
127
+ "++model.binarize_mask_from_pts_for_mem_enc=true",
128
+ # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
129
+ "++model.fill_hole_area=8",
130
+ ]
131
+ hydra_overrides.extend(hydra_overrides_extra)
132
+
133
+ # Read config and init model
134
+ cfg = compose(config_name=config_file, overrides=hydra_overrides)
135
+ OmegaConf.resolve(cfg)
136
+ model = instantiate(cfg.model, _recursive_=True)
137
+ _load_checkpoint(model, ckpt_path)
138
+ model = model.to(device)
139
+ if mode == "eval":
140
+ model.eval()
141
+ return model
142
+
143
+
144
+ def _hf_download(model_id):
145
+ from huggingface_hub import hf_hub_download
146
+
147
+ config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
148
+ ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
149
+ return config_name, ckpt_path
150
+
151
+
152
+ def build_sam2_hf(model_id, **kwargs):
153
+ config_name, ckpt_path = _hf_download(model_id)
154
+ return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
155
+
156
+
157
+ def build_sam2_video_predictor_hf(model_id, **kwargs):
158
+ config_name, ckpt_path = _hf_download(model_id)
159
+ return build_sam2_video_predictor(
160
+ config_file=config_name, ckpt_path=ckpt_path, **kwargs
161
+ )
162
+
163
+
164
+ def _load_checkpoint(model, ckpt_path):
165
+ if ckpt_path is not None:
166
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
167
+ missing_keys, unexpected_keys = model.load_state_dict(sd)
168
+ if missing_keys:
169
+ logging.error(missing_keys)
170
+ raise RuntimeError()
171
+ if unexpected_keys:
172
+ logging.error(unexpected_keys)
173
+ raise RuntimeError()
174
+ logging.info("Loaded checkpoint sucessfully")
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [64, 64]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [64, 64]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ no_obj_embed_spatial: true
93
+ # use high-resolution feature map in the SAM mask decoder
94
+ use_high_res_features_in_sam: true
95
+ # output 3 masks on the first click on initial conditioning frames
96
+ multimask_output_in_sam: true
97
+ # SAM heads
98
+ iou_prediction_use_sigmoid: True
99
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100
+ use_obj_ptrs_in_encoder: true
101
+ add_tpos_enc_to_obj_ptrs: true
102
+ proj_tpos_enc_in_obj_ptrs: true
103
+ use_signed_tpos_enc_to_obj_ptrs: true
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
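The `compile_image_encoder` flag that closes this config ties back to the README's speed note. Rather than editing the YAML, the flag can presumably be flipped at build time through the `hydra_overrides_extra` mechanism in `build_sam.py`; a sketch (GPU-oriented, so not something this CPU Space would enable):

```python
# Sketch: toggling image-encoder compilation via a Hydra override instead of editing the YAML.
from sam2.build_sam import build_sam2

model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "./checkpoints/sam2.1_hiera_base_plus.pt",
    hydra_overrides_extra=["++model.compile_image_encoder=True"],
)
```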
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_l.yaml ADDED
@@ -0,0 +1,120 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [64, 64]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [64, 64]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ compile_image_encoder: False
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_s.yaml ADDED
@@ -0,0 +1,119 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ no_obj_embed_spatial: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: true
105
+ proj_tpos_enc_in_obj_ptrs: true
106
+ use_signed_tpos_enc_to_obj_ptrs: true
107
+ only_obj_ptrs_in_the_past_for_eval: true
108
+ # object occlusion prediction
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ # multimask tracking settings
113
+ multimask_output_for_tracking: true
114
+ use_multimask_token_for_obj_ptr: true
115
+ multimask_min_pt_num: 0
116
+ multimask_max_pt_num: 1
117
+ use_mlp_for_obj_ptr_proj: true
118
+ # Compilation flag
119
+ compile_image_encoder: False
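
Note: these Hydra configs are consumed through `build_sam.py` (added in this commit), which instantiates every `_target_` node. Below is a minimal sketch of how the hiera_s variant above would typically be loaded and queried; the checkpoint path and image file are placeholders, not files guaranteed to exist in this Space.

```python
# Minimal sketch: load the small (hiera_s) SAM 2.1 variant via build_sam2()
# and run a single point prompt. Checkpoint path and image are assumptions.
import numpy as np
import torch
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

config = "configs/sam2.1/sam2.1_hiera_s.yaml"      # resolved by Hydra inside build_sam2
checkpoint = "checkpoints/sam2.1_hiera_small.pt"   # assumed local path
model = build_sam2(config, checkpoint, device="cpu")

predictor = SAM2ImagePredictor(model)
image = np.array(Image.open("example.jpg").convert("RGB"))  # placeholder image

with torch.inference_mode():
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[512, 512]]),  # one foreground click
        point_labels=np.array([1]),
        multimask_output=True,                # matches multimask_output_in_sam: true
    )
print(masks.shape, scores)  # (3, H, W) masks with their predicted IoU scores
```
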
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_t.yaml ADDED
@@ -0,0 +1,121 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ no_obj_embed_spatial: true
97
+ # use high-resolution feature map in the SAM mask decoder
98
+ use_high_res_features_in_sam: true
99
+ # output 3 masks on the first click on initial conditioning frames
100
+ multimask_output_in_sam: true
101
+ # SAM heads
102
+ iou_prediction_use_sigmoid: True
103
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104
+ use_obj_ptrs_in_encoder: true
105
+ add_tpos_enc_to_obj_ptrs: true
106
+ proj_tpos_enc_in_obj_ptrs: true
107
+ use_signed_tpos_enc_to_obj_ptrs: true
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ # object occlusion prediction
110
+ pred_obj_scores: true
111
+ pred_obj_scores_mlp: true
112
+ fixed_no_obj_ptr: true
113
+ # multimask tracking settings
114
+ multimask_output_for_tracking: true
115
+ use_multimask_token_for_obj_ptr: true
116
+ multimask_min_pt_num: 0
117
+ multimask_max_pt_num: 1
118
+ use_mlp_for_obj_ptr_proj: true
119
+ # Compilation flag
120
+ # HieraT does not currently support compilation, should always be set to False
121
+ compile_image_encoder: False
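
The SAM 2.1 variants differ mainly in the Hiera trunk settings (`embed_dim`, `stages`, `global_att_blocks`, `window_spec`); the rest of the model is shared. The neck's `backbone_channel_list` follows from `embed_dim` and the per-stage `dim_mul` of 2.0, as this small worked check (not part of the deployed code) illustrates:

```python
# Worked check: Hiera doubles the channel dim at each stage change (dim_mul = 2.0),
# and FpnNeck lists those dims from lowest to highest resolution (reverse stage order).
def neck_channel_list(embed_dim: int, num_stages: int = 4, dim_mul: float = 2.0):
    dims = [int(embed_dim * dim_mul**i) for i in range(num_stages)]
    return dims[::-1]

print(neck_channel_list(96))   # [768, 384, 192, 96]   -> hiera_s / hiera_t configs
print(neck_channel_list(112))  # [896, 448, 224, 112]  -> hiera_b+ configs
print(neck_channel_list(144))  # [1152, 576, 288, 144] -> hiera_l configs
```
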
sam2_repo/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml ADDED
@@ -0,0 +1,339 @@
1
+ # @package _global_
2
+
3
+ scratch:
4
+ resolution: 1024
5
+ train_batch_size: 1
6
+ num_train_workers: 10
7
+ num_frames: 8
8
+ max_num_objects: 3
9
+ base_lr: 5.0e-6
10
+ vision_lr: 3.0e-06
11
+ phases_per_epoch: 1
12
+ num_epochs: 40
13
+
14
+ dataset:
15
+ # PATHS to Dataset
16
+ img_folder: null # PATH to MOSE JPEGImages folder
17
+ gt_folder: null # PATH to MOSE Annotations folder
18
+ file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
19
+ multiplier: 2
20
+
21
+ # Video transforms
22
+ vos:
23
+ train_transforms:
24
+ - _target_: training.dataset.transforms.ComposeAPI
25
+ transforms:
26
+ - _target_: training.dataset.transforms.RandomHorizontalFlip
27
+ consistent_transform: True
28
+ - _target_: training.dataset.transforms.RandomAffine
29
+ degrees: 25
30
+ shear: 20
31
+ image_interpolation: bilinear
32
+ consistent_transform: True
33
+ - _target_: training.dataset.transforms.RandomResizeAPI
34
+ sizes: ${scratch.resolution}
35
+ square: true
36
+ consistent_transform: True
37
+ - _target_: training.dataset.transforms.ColorJitter
38
+ consistent_transform: True
39
+ brightness: 0.1
40
+ contrast: 0.03
41
+ saturation: 0.03
42
+ hue: null
43
+ - _target_: training.dataset.transforms.RandomGrayscale
44
+ p: 0.05
45
+ consistent_transform: True
46
+ - _target_: training.dataset.transforms.ColorJitter
47
+ consistent_transform: False
48
+ brightness: 0.1
49
+ contrast: 0.05
50
+ saturation: 0.05
51
+ hue: null
52
+ - _target_: training.dataset.transforms.ToTensorAPI
53
+ - _target_: training.dataset.transforms.NormalizeAPI
54
+ mean: [0.485, 0.456, 0.406]
55
+ std: [0.229, 0.224, 0.225]
56
+
57
+ trainer:
58
+ _target_: training.trainer.Trainer
59
+ mode: train_only
60
+ max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
61
+ accelerator: cuda
62
+ seed_value: 123
63
+
64
+ model:
65
+ _target_: training.model.sam2.SAM2Train
66
+ image_encoder:
67
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
68
+ scalp: 1
69
+ trunk:
70
+ _target_: sam2.modeling.backbones.hieradet.Hiera
71
+ embed_dim: 112
72
+ num_heads: 2
73
+ drop_path_rate: 0.1
74
+ neck:
75
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
76
+ position_encoding:
77
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
78
+ num_pos_feats: 256
79
+ normalize: true
80
+ scale: null
81
+ temperature: 10000
82
+ d_model: 256
83
+ backbone_channel_list: [896, 448, 224, 112]
84
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
85
+ fpn_interp_model: nearest
86
+
87
+ memory_attention:
88
+ _target_: sam2.modeling.memory_attention.MemoryAttention
89
+ d_model: 256
90
+ pos_enc_at_input: true
91
+ layer:
92
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
93
+ activation: relu
94
+ dim_feedforward: 2048
95
+ dropout: 0.1
96
+ pos_enc_at_attn: false
97
+ self_attention:
98
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
99
+ rope_theta: 10000.0
100
+ feat_sizes: [64, 64]
101
+ embedding_dim: 256
102
+ num_heads: 1
103
+ downsample_rate: 1
104
+ dropout: 0.1
105
+ d_model: 256
106
+ pos_enc_at_cross_attn_keys: true
107
+ pos_enc_at_cross_attn_queries: false
108
+ cross_attention:
109
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
110
+ rope_theta: 10000.0
111
+ feat_sizes: [64, 64]
112
+ rope_k_repeat: True
113
+ embedding_dim: 256
114
+ num_heads: 1
115
+ downsample_rate: 1
116
+ dropout: 0.1
117
+ kv_in_dim: 64
118
+ num_layers: 4
119
+
120
+ memory_encoder:
121
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
122
+ out_dim: 64
123
+ position_encoding:
124
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
125
+ num_pos_feats: 64
126
+ normalize: true
127
+ scale: null
128
+ temperature: 10000
129
+ mask_downsampler:
130
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
131
+ kernel_size: 3
132
+ stride: 2
133
+ padding: 1
134
+ fuser:
135
+ _target_: sam2.modeling.memory_encoder.Fuser
136
+ layer:
137
+ _target_: sam2.modeling.memory_encoder.CXBlock
138
+ dim: 256
139
+ kernel_size: 7
140
+ padding: 3
141
+ layer_scale_init_value: 1e-6
142
+ use_dwconv: True # depth-wise convs
143
+ num_layers: 2
144
+
145
+ num_maskmem: 7
146
+ image_size: ${scratch.resolution}
147
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
148
+ sigmoid_scale_for_mem_enc: 20.0
149
+ sigmoid_bias_for_mem_enc: -10.0
150
+ use_mask_input_as_output_without_sam: true
151
+ # Memory
152
+ directly_add_no_mem_embed: true
153
+ no_obj_embed_spatial: true
154
+ # use high-resolution feature map in the SAM mask decoder
155
+ use_high_res_features_in_sam: true
156
+ # output 3 masks on the first click on initial conditioning frames
157
+ multimask_output_in_sam: true
158
+ # SAM heads
159
+ iou_prediction_use_sigmoid: True
160
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
161
+ use_obj_ptrs_in_encoder: true
162
+ add_tpos_enc_to_obj_ptrs: true
163
+ proj_tpos_enc_in_obj_ptrs: true
164
+ use_signed_tpos_enc_to_obj_ptrs: true
165
+ only_obj_ptrs_in_the_past_for_eval: true
166
+ # object occlusion prediction
167
+ pred_obj_scores: true
168
+ pred_obj_scores_mlp: true
169
+ fixed_no_obj_ptr: true
170
+ # multimask tracking settings
171
+ multimask_output_for_tracking: true
172
+ use_multimask_token_for_obj_ptr: true
173
+ multimask_min_pt_num: 0
174
+ multimask_max_pt_num: 1
175
+ use_mlp_for_obj_ptr_proj: true
176
+ # Compilation flag
177
+ # compile_image_encoder: False
178
+
179
+ ####### Training specific params #######
180
+ # box/point input and corrections
181
+ prob_to_use_pt_input_for_train: 0.5
182
+ prob_to_use_pt_input_for_eval: 0.0
183
+ prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
184
+ prob_to_use_box_input_for_eval: 0.0
185
+ prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
186
+ num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
187
+ num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
188
+ rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
189
+ add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
190
+ # maximum 2 initial conditioning frames
191
+ num_init_cond_frames_for_train: 2
192
+ rand_init_cond_frames_for_train: True # random 1~2
193
+ num_correction_pt_per_frame: 7
194
+ use_act_ckpt_iterative_pt_sampling: false
195
+
196
+
197
+
198
+ num_init_cond_frames_for_eval: 1 # only mask on the first frame
199
+ forward_backbone_per_frame_for_eval: True
200
+
201
+
202
+ data:
203
+ train:
204
+ _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
205
+ phases_per_epoch: ${scratch.phases_per_epoch}
206
+ batch_sizes:
207
+ - ${scratch.train_batch_size}
208
+
209
+ datasets:
210
+ - _target_: training.dataset.utils.RepeatFactorWrapper
211
+ dataset:
212
+ _target_: training.dataset.utils.ConcatDataset
213
+ datasets:
214
+ - _target_: training.dataset.vos_dataset.VOSDataset
215
+ transforms: ${vos.train_transforms}
216
+ training: true
217
+ video_dataset:
218
+ _target_: training.dataset.vos_raw_dataset.PNGRawDataset
219
+ img_folder: ${dataset.img_folder}
220
+ gt_folder: ${dataset.gt_folder}
221
+ file_list_txt: ${dataset.file_list_txt}
222
+ sampler:
223
+ _target_: training.dataset.vos_sampler.RandomUniformSampler
224
+ num_frames: ${scratch.num_frames}
225
+ max_num_objects: ${scratch.max_num_objects}
226
+ multiplier: ${dataset.multiplier}
227
+ shuffle: True
228
+ num_workers: ${scratch.num_train_workers}
229
+ pin_memory: True
230
+ drop_last: True
231
+ collate_fn:
232
+ _target_: training.utils.data_utils.collate_fn
233
+ _partial_: true
234
+ dict_key: all
235
+
236
+ optim:
237
+ amp:
238
+ enabled: True
239
+ amp_dtype: bfloat16
240
+
241
+ optimizer:
242
+ _target_: torch.optim.AdamW
243
+
244
+ gradient_clip:
245
+ _target_: training.optimizer.GradientClipper
246
+ max_norm: 0.1
247
+ norm_type: 2
248
+
249
+ param_group_modifiers:
250
+ - _target_: training.optimizer.layer_decay_param_modifier
251
+ _partial_: True
252
+ layer_decay_value: 0.9
253
+ apply_to: 'image_encoder.trunk'
254
+ overrides:
255
+ - pattern: '*pos_embed*'
256
+ value: 1.0
257
+
258
+ options:
259
+ lr:
260
+ - scheduler:
261
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
262
+ start_value: ${scratch.base_lr}
263
+ end_value: ${divide:${scratch.base_lr},10}
264
+ - scheduler:
265
+ _target_: fvcore.common.param_scheduler.CosineParamScheduler
266
+ start_value: ${scratch.vision_lr}
267
+ end_value: ${divide:${scratch.vision_lr},10}
268
+ param_names:
269
+ - 'image_encoder.*'
270
+ weight_decay:
271
+ - scheduler:
272
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
273
+ value: 0.1
274
+ - scheduler:
275
+ _target_: fvcore.common.param_scheduler.ConstantParamScheduler
276
+ value: 0.0
277
+ param_names:
278
+ - '*bias*'
279
+ module_cls_names: ['torch.nn.LayerNorm']
280
+
281
+ loss:
282
+ all:
283
+ _target_: training.loss_fns.MultiStepMultiMasksAndIous
284
+ weight_dict:
285
+ loss_mask: 20
286
+ loss_dice: 1
287
+ loss_iou: 1
288
+ loss_class: 1
289
+ supervise_all_iou: true
290
+ iou_use_l1_loss: true
291
+ pred_obj_scores: true
292
+ focal_gamma_obj_score: 0.0
293
+ focal_alpha_obj_score: -1.0
294
+
295
+ distributed:
296
+ backend: nccl
297
+ find_unused_parameters: True
298
+
299
+ logging:
300
+ tensorboard_writer:
301
+ _target_: training.utils.logger.make_tensorboard_logger
302
+ log_dir: ${launcher.experiment_log_dir}/tensorboard
303
+ flush_secs: 120
304
+ should_log: True
305
+ log_dir: ${launcher.experiment_log_dir}/logs
306
+ log_freq: 10
307
+
308
+ # initialize from a SAM 2 checkpoint
309
+ checkpoint:
310
+ save_dir: ${launcher.experiment_log_dir}/checkpoints
311
+ save_freq: 0 # 0 only last checkpoint is saved.
312
+ model_weight_initializer:
313
+ _partial_: True
314
+ _target_: training.utils.checkpoint_utils.load_state_dict_into_model
315
+ strict: True
316
+ ignore_unexpected_keys: null
317
+ ignore_missing_keys: null
318
+
319
+ state_dict:
320
+ _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
321
+ checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
322
+ ckpt_state_dict_keys: ['model']
323
+
324
+ launcher:
325
+ num_nodes: 1
326
+ gpus_per_node: 8
327
+ experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
328
+
329
+ # SLURM args if running on a cluster
330
+ submitit:
331
+ partition: null
332
+ account: null
333
+ qos: null
334
+ cpus_per_task: 10
335
+ use_cluster: false
336
+ timeout_hour: 24
337
+ name: null
338
+ port_range: [10000, 65000]
339
+
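
For reference, a worked sketch of the two cosine learning-rate schedules defined under `optim.options.lr` above. The numbers assume fvcore's `CosineParamScheduler` formula and are not part of the deployed code:

```python
# Worked sketch of the two cosine LR schedules, assuming fvcore's semantics:
# value(t) = end + 0.5 * (start - end) * (1 + cos(pi * t)), t in [0, 1].
import math

def cosine_lr(start: float, end: float, where: float) -> float:
    return end + 0.5 * (start - end) * (1 + math.cos(math.pi * where))

base_lr, vision_lr = 5.0e-6, 3.0e-6        # scratch.base_lr / scratch.vision_lr
for where in (0.0, 0.5, 1.0):              # fraction of training completed
    print(
        f"progress={where:.1f}  "
        f"default lr={cosine_lr(base_lr, base_lr / 10, where):.2e}  "
        f"image_encoder lr={cosine_lr(vision_lr, vision_lr / 10, where):.2e}"
    )
# image_encoder.* parameters follow the second (lower) schedule; the trunk
# additionally gets layer-wise decay 0.9 via layer_decay_param_modifier.
```
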
sam2_repo/sam2/configs/sam2/sam2_hiera_b+.yaml ADDED
@@ -0,0 +1,113 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 112
12
+ num_heads: 2
13
+ neck:
14
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15
+ position_encoding:
16
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17
+ num_pos_feats: 256
18
+ normalize: true
19
+ scale: null
20
+ temperature: 10000
21
+ d_model: 256
22
+ backbone_channel_list: [896, 448, 224, 112]
23
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24
+ fpn_interp_model: nearest
25
+
26
+ memory_attention:
27
+ _target_: sam2.modeling.memory_attention.MemoryAttention
28
+ d_model: 256
29
+ pos_enc_at_input: true
30
+ layer:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32
+ activation: relu
33
+ dim_feedforward: 2048
34
+ dropout: 0.1
35
+ pos_enc_at_attn: false
36
+ self_attention:
37
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
38
+ rope_theta: 10000.0
39
+ feat_sizes: [64, 64]
40
+ embedding_dim: 256
41
+ num_heads: 1
42
+ downsample_rate: 1
43
+ dropout: 0.1
44
+ d_model: 256
45
+ pos_enc_at_cross_attn_keys: true
46
+ pos_enc_at_cross_attn_queries: false
47
+ cross_attention:
48
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
49
+ rope_theta: 10000.0
50
+ feat_sizes: [64, 64]
51
+ rope_k_repeat: True
52
+ embedding_dim: 256
53
+ num_heads: 1
54
+ downsample_rate: 1
55
+ dropout: 0.1
56
+ kv_in_dim: 64
57
+ num_layers: 4
58
+
59
+ memory_encoder:
60
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
61
+ out_dim: 64
62
+ position_encoding:
63
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64
+ num_pos_feats: 64
65
+ normalize: true
66
+ scale: null
67
+ temperature: 10000
68
+ mask_downsampler:
69
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
70
+ kernel_size: 3
71
+ stride: 2
72
+ padding: 1
73
+ fuser:
74
+ _target_: sam2.modeling.memory_encoder.Fuser
75
+ layer:
76
+ _target_: sam2.modeling.memory_encoder.CXBlock
77
+ dim: 256
78
+ kernel_size: 7
79
+ padding: 3
80
+ layer_scale_init_value: 1e-6
81
+ use_dwconv: True # depth-wise convs
82
+ num_layers: 2
83
+
84
+ num_maskmem: 7
85
+ image_size: 1024
86
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87
+ sigmoid_scale_for_mem_enc: 20.0
88
+ sigmoid_bias_for_mem_enc: -10.0
89
+ use_mask_input_as_output_without_sam: true
90
+ # Memory
91
+ directly_add_no_mem_embed: true
92
+ # use high-resolution feature map in the SAM mask decoder
93
+ use_high_res_features_in_sam: true
94
+ # output 3 masks on the first click on initial conditioning frames
95
+ multimask_output_in_sam: true
96
+ # SAM heads
97
+ iou_prediction_use_sigmoid: True
98
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99
+ use_obj_ptrs_in_encoder: true
100
+ add_tpos_enc_to_obj_ptrs: false
101
+ only_obj_ptrs_in_the_past_for_eval: true
102
+ # object occlusion prediction
103
+ pred_obj_scores: true
104
+ pred_obj_scores_mlp: true
105
+ fixed_no_obj_ptr: true
106
+ # multimask tracking settings
107
+ multimask_output_for_tracking: true
108
+ use_multimask_token_for_obj_ptr: true
109
+ multimask_min_pt_num: 0
110
+ multimask_max_pt_num: 1
111
+ use_mlp_for_obj_ptr_proj: true
112
+ # Compilation flag
113
+ compile_image_encoder: False
sam2_repo/sam2/configs/sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1,117 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [64, 64]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [64, 64]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True # depth-wise convs
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ compile_image_encoder: False
sam2_repo/sam2/configs/sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1,116 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 11, 2]
14
+ global_att_blocks: [7, 10, 13]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ # Memory
94
+ directly_add_no_mem_embed: true
95
+ # use high-resolution feature map in the SAM mask decoder
96
+ use_high_res_features_in_sam: true
97
+ # output 3 masks on the first click on initial conditioning frames
98
+ multimask_output_in_sam: true
99
+ # SAM heads
100
+ iou_prediction_use_sigmoid: True
101
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102
+ use_obj_ptrs_in_encoder: true
103
+ add_tpos_enc_to_obj_ptrs: false
104
+ only_obj_ptrs_in_the_past_for_eval: true
105
+ # object occlusion prediction
106
+ pred_obj_scores: true
107
+ pred_obj_scores_mlp: true
108
+ fixed_no_obj_ptr: true
109
+ # multimask tracking settings
110
+ multimask_output_for_tracking: true
111
+ use_multimask_token_for_obj_ptr: true
112
+ multimask_min_pt_num: 0
113
+ multimask_max_pt_num: 1
114
+ use_mlp_for_obj_ptr_proj: true
115
+ # Compilation flag
116
+ compile_image_encoder: False
sam2_repo/sam2/configs/sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1,118 @@
1
+ # @package _global_
2
+
3
+ # Model
4
+ model:
5
+ _target_: sam2.modeling.sam2_base.SAM2Base
6
+ image_encoder:
7
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8
+ scalp: 1
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 96
12
+ num_heads: 1
13
+ stages: [1, 2, 7, 2]
14
+ global_att_blocks: [5, 7, 9]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ neck:
17
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18
+ position_encoding:
19
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20
+ num_pos_feats: 256
21
+ normalize: true
22
+ scale: null
23
+ temperature: 10000
24
+ d_model: 256
25
+ backbone_channel_list: [768, 384, 192, 96]
26
+ fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27
+ fpn_interp_model: nearest
28
+
29
+ memory_attention:
30
+ _target_: sam2.modeling.memory_attention.MemoryAttention
31
+ d_model: 256
32
+ pos_enc_at_input: true
33
+ layer:
34
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35
+ activation: relu
36
+ dim_feedforward: 2048
37
+ dropout: 0.1
38
+ pos_enc_at_attn: false
39
+ self_attention:
40
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
41
+ rope_theta: 10000.0
42
+ feat_sizes: [64, 64]
43
+ embedding_dim: 256
44
+ num_heads: 1
45
+ downsample_rate: 1
46
+ dropout: 0.1
47
+ d_model: 256
48
+ pos_enc_at_cross_attn_keys: true
49
+ pos_enc_at_cross_attn_queries: false
50
+ cross_attention:
51
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
52
+ rope_theta: 10000.0
53
+ feat_sizes: [64, 64]
54
+ rope_k_repeat: True
55
+ embedding_dim: 256
56
+ num_heads: 1
57
+ downsample_rate: 1
58
+ dropout: 0.1
59
+ kv_in_dim: 64
60
+ num_layers: 4
61
+
62
+ memory_encoder:
63
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
64
+ out_dim: 64
65
+ position_encoding:
66
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67
+ num_pos_feats: 64
68
+ normalize: true
69
+ scale: null
70
+ temperature: 10000
71
+ mask_downsampler:
72
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
73
+ kernel_size: 3
74
+ stride: 2
75
+ padding: 1
76
+ fuser:
77
+ _target_: sam2.modeling.memory_encoder.Fuser
78
+ layer:
79
+ _target_: sam2.modeling.memory_encoder.CXBlock
80
+ dim: 256
81
+ kernel_size: 7
82
+ padding: 3
83
+ layer_scale_init_value: 1e-6
84
+ use_dwconv: True # depth-wise convs
85
+ num_layers: 2
86
+
87
+ num_maskmem: 7
88
+ image_size: 1024
89
+ # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90
+ # SAM decoder
91
+ sigmoid_scale_for_mem_enc: 20.0
92
+ sigmoid_bias_for_mem_enc: -10.0
93
+ use_mask_input_as_output_without_sam: true
94
+ # Memory
95
+ directly_add_no_mem_embed: true
96
+ # use high-resolution feature map in the SAM mask decoder
97
+ use_high_res_features_in_sam: true
98
+ # output 3 masks on the first click on initial conditioning frames
99
+ multimask_output_in_sam: true
100
+ # SAM heads
101
+ iou_prediction_use_sigmoid: True
102
+ # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103
+ use_obj_ptrs_in_encoder: true
104
+ add_tpos_enc_to_obj_ptrs: false
105
+ only_obj_ptrs_in_the_past_for_eval: true
106
+ # object occlusion prediction
107
+ pred_obj_scores: true
108
+ pred_obj_scores_mlp: true
109
+ fixed_no_obj_ptr: true
110
+ # multimask tracking settings
111
+ multimask_output_for_tracking: true
112
+ use_multimask_token_for_obj_ptr: true
113
+ multimask_min_pt_num: 0
114
+ multimask_max_pt_num: 1
115
+ use_mlp_for_obj_ptr_proj: true
116
+ # Compilation flag
117
+ # HieraT does not currently support compilation, should always be set to False
118
+ compile_image_encoder: False
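
The sam2 (2.0) configs in this directory mirror the sam2.1 ones above but omit the newer object-pointer and no-object options. A quick, hypothetical way to surface the difference (assumes PyYAML is installed and the script is run from `sam2_repo/sam2/configs`):

```python
# Hypothetical helper: list which model-level keys the sam2.1 configs add.
import yaml

def model_keys(path: str) -> set:
    with open(path) as f:
        return set(yaml.safe_load(f)["model"].keys())

old = model_keys("sam2/sam2_hiera_t.yaml")
new = model_keys("sam2.1/sam2.1_hiera_t.yaml")
print(sorted(new - old))
# Expected from the diffs above: ['no_obj_embed_spatial',
#  'proj_tpos_enc_in_obj_ptrs', 'use_signed_tpos_enc_to_obj_ptrs']
# (and add_tpos_enc_to_obj_ptrs flips from false to true).
```
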
sam2_repo/sam2/csrc/connected_components.cu ADDED
@@ -0,0 +1,289 @@
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ // adapted from https://github.com/zsef123/Connected_components_PyTorch
8
+ // with license found in the LICENSE_cctorch file in the root directory.
9
+ #include <ATen/cuda/CUDAContext.h>
10
+ #include <cuda.h>
11
+ #include <cuda_runtime.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/script.h>
14
+ #include <vector>
15
+
16
+ // 2d
17
+ #define BLOCK_ROWS 16
18
+ #define BLOCK_COLS 16
19
+
20
+ namespace cc2d {
21
+
22
+ template <typename T>
23
+ __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
24
+ return (bitmap >> pos) & 1;
25
+ }
26
+
27
+ __device__ int32_t find(const int32_t* s_buf, int32_t n) {
28
+ while (s_buf[n] != n)
29
+ n = s_buf[n];
30
+ return n;
31
+ }
32
+
33
+ __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
34
+ const int32_t id = n;
35
+ while (s_buf[n] != n) {
36
+ n = s_buf[n];
37
+ s_buf[id] = n;
38
+ }
39
+ return n;
40
+ }
41
+
42
+ __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
43
+ bool done;
44
+ do {
45
+ a = find(s_buf, a);
46
+ b = find(s_buf, b);
47
+
48
+ if (a < b) {
49
+ int32_t old = atomicMin(s_buf + b, a);
50
+ done = (old == b);
51
+ b = old;
52
+ } else if (b < a) {
53
+ int32_t old = atomicMin(s_buf + a, b);
54
+ done = (old == a);
55
+ a = old;
56
+ } else
57
+ done = true;
58
+
59
+ } while (!done);
60
+ }
61
+
62
+ __global__ void
63
+ init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
64
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
65
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
66
+ const uint32_t idx = row * W + col;
67
+
68
+ if (row < H && col < W)
69
+ label[idx] = idx;
70
+ }
71
+
72
+ __global__ void
73
+ merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
74
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
75
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
76
+ const uint32_t idx = row * W + col;
77
+
78
+ if (row >= H || col >= W)
79
+ return;
80
+
81
+ uint32_t P = 0;
82
+
83
+ if (img[idx])
84
+ P |= 0x777;
85
+ if (row + 1 < H && img[idx + W])
86
+ P |= 0x777 << 4;
87
+ if (col + 1 < W && img[idx + 1])
88
+ P |= 0x777 << 1;
89
+
90
+ if (col == 0)
91
+ P &= 0xEEEE;
92
+ if (col + 1 >= W)
93
+ P &= 0x3333;
94
+ else if (col + 2 >= W)
95
+ P &= 0x7777;
96
+
97
+ if (row == 0)
98
+ P &= 0xFFF0;
99
+ if (row + 1 >= H)
100
+ P &= 0xFF;
101
+
102
+ if (P > 0) {
103
+ // If bit 0 of P is set, check the top-left pixel and, if it is
105
+ // foreground, merge with the top-left block
106
+ union_(label, idx, idx - 2 * W - 2); // top left block
107
+ }
108
+
109
+ if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
110
+ union_(label, idx, idx - 2 * W); // block directly above
111
+
112
+ if (hasBit(P, 3) && img[idx + 2 - W])
113
+ union_(label, idx, idx - 2 * W + 2); // top right block
114
+
115
+ if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
116
+ union_(label, idx, idx - 2); // just left block
117
+ }
118
+ }
119
+
120
+ __global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
121
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
122
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
123
+ const uint32_t idx = row * W + col;
124
+
125
+ if (row < H && col < W)
126
+ find_n_compress(label, idx);
127
+ }
128
+
129
+ __global__ void final_labeling(
130
+ const uint8_t* img,
131
+ int32_t* label,
132
+ const int32_t W,
133
+ const int32_t H) {
134
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
135
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
136
+ const uint32_t idx = row * W + col;
137
+
138
+ if (row >= H || col >= W)
139
+ return;
140
+
141
+ int32_t y = label[idx] + 1;
142
+
143
+ if (img[idx])
144
+ label[idx] = y;
145
+ else
146
+ label[idx] = 0;
147
+
148
+ if (col + 1 < W) {
149
+ if (img[idx + 1])
150
+ label[idx + 1] = y;
151
+ else
152
+ label[idx + 1] = 0;
153
+
154
+ if (row + 1 < H) {
155
+ if (img[idx + W + 1])
156
+ label[idx + W + 1] = y;
157
+ else
158
+ label[idx + W + 1] = 0;
159
+ }
160
+ }
161
+
162
+ if (row + 1 < H) {
163
+ if (img[idx + W])
164
+ label[idx + W] = y;
165
+ else
166
+ label[idx + W] = 0;
167
+ }
168
+ }
169
+
170
+ __global__ void init_counting(
171
+ const int32_t* label,
172
+ int32_t* count_init,
173
+ const int32_t W,
174
+ const int32_t H) {
175
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
176
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
177
+ const uint32_t idx = row * W + col;
178
+
179
+ if (row >= H || col >= W)
180
+ return;
181
+
182
+ int32_t y = label[idx];
183
+ if (y > 0) {
184
+ int32_t count_idx = y - 1;
185
+ atomicAdd(count_init + count_idx, 1);
186
+ }
187
+ }
188
+
189
+ __global__ void final_counting(
190
+ const int32_t* label,
191
+ const int32_t* count_init,
192
+ int32_t* count_final,
193
+ const int32_t W,
194
+ const int32_t H) {
195
+ const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
196
+ const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
197
+ const uint32_t idx = row * W + col;
198
+
199
+ if (row >= H || col >= W)
200
+ return;
201
+
202
+ int32_t y = label[idx];
203
+ if (y > 0) {
204
+ int32_t count_idx = y - 1;
205
+ count_final[idx] = count_init[count_idx];
206
+ } else {
207
+ count_final[idx] = 0;
208
+ }
209
+ }
210
+
211
+ } // namespace cc2d
212
+
213
+ std::vector<torch::Tensor> get_connected_componnets(
214
+ const torch::Tensor& inputs) {
215
+ AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
216
+ AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
217
+ AT_ASSERTM(
218
+ inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
219
+
220
+ const uint32_t N = inputs.size(0);
221
+ const uint32_t C = inputs.size(1);
222
+ const uint32_t H = inputs.size(2);
223
+ const uint32_t W = inputs.size(3);
224
+
225
+ AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
226
+ AT_ASSERTM((H % 2) == 0, "height must be an even number");
227
+ AT_ASSERTM((W % 2) == 0, "width must be an even number");
228
+
229
+ // labels are stored as int32 tensors
230
+ auto label_options =
231
+ torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
232
+ torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
233
+ torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
234
+ torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
235
+
236
+ dim3 grid = dim3(
237
+ ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
238
+ ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
239
+ dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
240
+ dim3 grid_count =
241
+ dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
242
+ dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
243
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
244
+
245
+ for (int n = 0; n < N; n++) {
246
+ uint32_t offset = n * H * W;
247
+
248
+ cc2d::init_labeling<<<grid, block, 0, stream>>>(
249
+ labels.data_ptr<int32_t>() + offset, W, H);
250
+ cc2d::merge<<<grid, block, 0, stream>>>(
251
+ inputs.data_ptr<uint8_t>() + offset,
252
+ labels.data_ptr<int32_t>() + offset,
253
+ W,
254
+ H);
255
+ cc2d::compression<<<grid, block, 0, stream>>>(
256
+ labels.data_ptr<int32_t>() + offset, W, H);
257
+ cc2d::final_labeling<<<grid, block, 0, stream>>>(
258
+ inputs.data_ptr<uint8_t>() + offset,
259
+ labels.data_ptr<int32_t>() + offset,
260
+ W,
261
+ H);
262
+
263
+ // get the counting of each pixel
264
+ cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
265
+ labels.data_ptr<int32_t>() + offset,
266
+ counts_init.data_ptr<int32_t>() + offset,
267
+ W,
268
+ H);
269
+ cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
270
+ labels.data_ptr<int32_t>() + offset,
271
+ counts_init.data_ptr<int32_t>() + offset,
272
+ counts_final.data_ptr<int32_t>() + offset,
273
+ W,
274
+ H);
275
+ }
276
+
277
+ // returned values are [labels, counts]
278
+ std::vector<torch::Tensor> outputs;
279
+ outputs.push_back(labels);
280
+ outputs.push_back(counts_final);
281
+ return outputs;
282
+ }
283
+
284
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
285
+ m.def(
286
+ "get_connected_componnets",
287
+ &get_connected_componnets,
288
+ "get_connected_componnets");
289
+ }
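
This kernel is only usable when the CUDA extension is compiled, which a CPU-oriented Space presumably does not do. For reference, a hypothetical sketch of the calling convention the assertions above enforce, assuming the extension is exposed as `sam2._C` as in the upstream repo:

```python
# Hypothetical usage sketch; only valid if the CUDA extension was built.
import torch

from sam2 import _C  # assumed module name; absent without a CUDA build

# Input must be a CUDA uint8 tensor of shape [N, 1, H, W] with even H and W.
mask = (torch.rand(1, 1, 256, 256, device="cuda") > 0.5).to(torch.uint8)
labels, counts = _C.get_connected_componnets(mask)  # upstream keeps this spelling
# labels: int32 component id per pixel (0 = background)
# counts: int32 area of the component each pixel belongs to
print(labels.unique().numel() - 1, "connected components")
```
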
sam2_repo/sam2/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2_repo/sam2/modeling/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (165 Bytes). View file
 
sam2_repo/sam2/modeling/__pycache__/memory_attention.cpython-313.pyc ADDED
Binary file (7.06 kB). View file
 
sam2_repo/sam2/modeling/__pycache__/memory_encoder.cpython-313.pyc ADDED
Binary file (7.99 kB). View file
 
sam2_repo/sam2/modeling/__pycache__/position_encoding.cpython-313.pyc ADDED
Binary file (15.1 kB). View file
 
sam2_repo/sam2/modeling/__pycache__/sam2_base.cpython-313.pyc ADDED
Binary file (31 kB). View file
 
sam2_repo/sam2/modeling/__pycache__/sam2_utils.cpython-313.pyc ADDED
Binary file (17.3 kB). View file
 
sam2_repo/sam2/modeling/backbones/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2_repo/sam2/modeling/backbones/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (175 Bytes). View file
 
sam2_repo/sam2/modeling/backbones/__pycache__/hieradet.cpython-313.pyc ADDED
Binary file (13.6 kB). View file
 
sam2_repo/sam2/modeling/backbones/__pycache__/image_encoder.cpython-313.pyc ADDED
Binary file (5.53 kB). View file
 
sam2_repo/sam2/modeling/backbones/__pycache__/utils.cpython-313.pyc ADDED
Binary file (4.06 kB). View file
 
sam2_repo/sam2/modeling/backbones/hieradet.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import partial
9
+ from typing import List, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from iopath.common.file_io import g_pathmgr
15
+
16
+ from sam2.modeling.backbones.utils import (
17
+ PatchEmbed,
18
+ window_partition,
19
+ window_unpartition,
20
+ )
21
+
22
+ from sam2.modeling.sam2_utils import DropPath, MLP
23
+
24
+
25
+ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
26
+ if pool is None:
27
+ return x
28
+ # (B, H, W, C) -> (B, C, H, W)
29
+ x = x.permute(0, 3, 1, 2)
30
+ x = pool(x)
31
+ # (B, C, H', W') -> (B, H', W', C)
32
+ x = x.permute(0, 2, 3, 1)
33
+ if norm:
34
+ x = norm(x)
35
+
36
+ return x
37
+
38
+
39
+ class MultiScaleAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ dim_out: int,
44
+ num_heads: int,
45
+ q_pool: nn.Module = None,
46
+ ):
47
+ super().__init__()
48
+
49
+ self.dim = dim
50
+ self.dim_out = dim_out
51
+ self.num_heads = num_heads
52
+ self.q_pool = q_pool
53
+ self.qkv = nn.Linear(dim, dim_out * 3)
54
+ self.proj = nn.Linear(dim_out, dim_out)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ B, H, W, _ = x.shape
58
+ # qkv with shape (B, H * W, 3, nHead, C)
59
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
60
+ # q, k, v with shape (B, H * W, nheads, C)
61
+ q, k, v = torch.unbind(qkv, 2)
62
+
63
+ # Q pooling (for downsample at stage changes)
64
+ if self.q_pool:
65
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
66
+ H, W = q.shape[1:3] # downsampled shape
67
+ q = q.reshape(B, H * W, self.num_heads, -1)
68
+
69
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
70
+ x = F.scaled_dot_product_attention(
71
+ q.transpose(1, 2),
72
+ k.transpose(1, 2),
73
+ v.transpose(1, 2),
74
+ )
75
+ # Transpose back
76
+ x = x.transpose(1, 2)
77
+ x = x.reshape(B, H, W, -1)
78
+
79
+ x = self.proj(x)
80
+
81
+ return x
82
+
83
+
84
+ class MultiScaleBlock(nn.Module):
85
+ def __init__(
86
+ self,
87
+ dim: int,
88
+ dim_out: int,
89
+ num_heads: int,
90
+ mlp_ratio: float = 4.0,
91
+ drop_path: float = 0.0,
92
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
93
+ q_stride: Tuple[int, int] = None,
94
+ act_layer: nn.Module = nn.GELU,
95
+ window_size: int = 0,
96
+ ):
97
+ super().__init__()
98
+
99
+ if isinstance(norm_layer, str):
100
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
101
+
102
+ self.dim = dim
103
+ self.dim_out = dim_out
104
+ self.norm1 = norm_layer(dim)
105
+
106
+ self.window_size = window_size
107
+
108
+ self.pool, self.q_stride = None, q_stride
109
+ if self.q_stride:
110
+ self.pool = nn.MaxPool2d(
111
+ kernel_size=q_stride, stride=q_stride, ceil_mode=False
112
+ )
113
+
114
+ self.attn = MultiScaleAttention(
115
+ dim,
116
+ dim_out,
117
+ num_heads=num_heads,
118
+ q_pool=self.pool,
119
+ )
120
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
121
+
122
+ self.norm2 = norm_layer(dim_out)
123
+ self.mlp = MLP(
124
+ dim_out,
125
+ int(dim_out * mlp_ratio),
126
+ dim_out,
127
+ num_layers=2,
128
+ activation=act_layer,
129
+ )
130
+
131
+ if dim != dim_out:
132
+ self.proj = nn.Linear(dim, dim_out)
133
+
134
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
135
+ shortcut = x # B, H, W, C
136
+ x = self.norm1(x)
137
+
138
+ # Skip connection
139
+ if self.dim != self.dim_out:
140
+ shortcut = do_pool(self.proj(x), self.pool)
141
+
142
+ # Window partition
143
+ window_size = self.window_size
144
+ if window_size > 0:
145
+ H, W = x.shape[1], x.shape[2]
146
+ x, pad_hw = window_partition(x, window_size)
147
+
148
+ # Window Attention + Q Pooling (if stage change)
149
+ x = self.attn(x)
150
+ if self.q_stride:
151
+ # Shapes have changed due to Q pooling
152
+ window_size = self.window_size // self.q_stride[0]
153
+ H, W = shortcut.shape[1:3]
154
+
155
+ pad_h = (window_size - H % window_size) % window_size
156
+ pad_w = (window_size - W % window_size) % window_size
157
+ pad_hw = (H + pad_h, W + pad_w)
158
+
159
+ # Reverse window partition
160
+ if self.window_size > 0:
161
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
162
+
163
+ x = shortcut + self.drop_path(x)
164
+ # MLP
165
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
166
+ return x
167
+
168
+
169
+ class Hiera(nn.Module):
170
+ """
171
+ Reference: https://arxiv.org/abs/2306.00989
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ embed_dim: int = 96, # initial embed dim
177
+ num_heads: int = 1, # initial number of heads
178
+ drop_path_rate: float = 0.0, # stochastic depth
179
+ q_pool: int = 3, # number of q_pool stages
180
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
181
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
182
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
183
+ head_mul: float = 2.0, # head_mul factor at stage shift
184
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
185
+ # window size per stage, when not using global att.
186
+ window_spec: Tuple[int, ...] = (
187
+ 8,
188
+ 4,
189
+ 14,
190
+ 7,
191
+ ),
192
+ # global attn in these blocks
193
+ global_att_blocks: Tuple[int, ...] = (
194
+ 12,
195
+ 16,
196
+ 20,
197
+ ),
198
+ weights_path=None,
199
+ return_interm_layers=True, # return feats from every stage
200
+ ):
201
+ super().__init__()
202
+
203
+ assert len(stages) == len(window_spec)
204
+ self.window_spec = window_spec
205
+
206
+ depth = sum(stages)
207
+ self.q_stride = q_stride
208
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
210
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211
+ self.return_interm_layers = return_interm_layers
212
+
213
+ self.patch_embed = PatchEmbed(
214
+ embed_dim=embed_dim,
215
+ )
216
+ # Which blocks have global att?
217
+ self.global_att_blocks = global_att_blocks
218
+
219
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221
+ self.pos_embed = nn.Parameter(
222
+ torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223
+ )
224
+ self.pos_embed_window = nn.Parameter(
225
+ torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226
+ )
227
+
228
+ dpr = [
229
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
230
+ ] # stochastic depth decay rule
231
+
232
+ cur_stage = 1
233
+ self.blocks = nn.ModuleList()
234
+
235
+ for i in range(depth):
236
+ dim_out = embed_dim
237
+ # lags by a block, so first block of
238
+ # next stage uses an initial window size
239
+ # of previous stage and final window size of current stage
240
+ window_size = self.window_spec[cur_stage - 1]
241
+
242
+ if self.global_att_blocks is not None:
243
+ window_size = 0 if i in self.global_att_blocks else window_size
244
+
245
+ if i - 1 in self.stage_ends:
246
+ dim_out = int(embed_dim * dim_mul)
247
+ num_heads = int(num_heads * head_mul)
248
+ cur_stage += 1
249
+
250
+ block = MultiScaleBlock(
251
+ dim=embed_dim,
252
+ dim_out=dim_out,
253
+ num_heads=num_heads,
254
+ drop_path=dpr[i],
255
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
256
+ window_size=window_size,
257
+ )
258
+
259
+ embed_dim = dim_out
260
+ self.blocks.append(block)
261
+
262
+ self.channel_list = (
263
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264
+ if return_interm_layers
265
+ else [self.blocks[-1].dim_out]
266
+ )
267
+
268
+ if weights_path is not None:
269
+ with g_pathmgr.open(weights_path, "rb") as f:
270
+ chkpt = torch.load(f, map_location="cpu")
271
+ logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
272
+
273
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
274
+ h, w = hw
275
+ window_embed = self.pos_embed_window
276
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
277
+ pos_embed = pos_embed + window_embed.tile(
278
+ [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
279
+ )
280
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
281
+ return pos_embed
282
+
283
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
284
+ x = self.patch_embed(x)
285
+ # x: (B, H, W, C)
286
+
287
+ # Add pos embed
288
+ x = x + self._get_pos_embed(x.shape[1:3])
289
+
290
+ outputs = []
291
+ for i, blk in enumerate(self.blocks):
292
+ x = blk(x)
293
+ if (i == self.stage_ends[-1]) or (
294
+ i in self.stage_ends and self.return_interm_layers
295
+ ):
296
+ feats = x.permute(0, 3, 1, 2)
297
+ outputs.append(feats)
298
+
299
+ return outputs
300
+
301
+ def get_layer_id(self, layer_name):
302
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
303
+ num_layers = self.get_num_layers()
304
+
305
+ if layer_name.find("rel_pos") != -1:
306
+ return num_layers + 1
307
+ elif layer_name.find("pos_embed") != -1:
308
+ return 0
309
+ elif layer_name.find("patch_embed") != -1:
310
+ return 0
311
+ elif layer_name.find("blocks") != -1:
312
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
313
+ else:
314
+ return num_layers + 1
315
+
316
+ def get_num_layers(self) -> int:
317
+ return len(self.blocks)
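
A minimal sketch of the trunk in isolation, using the hyper-parameters from `sam2.1_hiera_t.yaml` above; the expected shapes assume the default `PatchEmbed` stride of 4 and `q_stride=(2, 2)`:

```python
# Minimal sketch: run a dummy 1024x1024 image through the tiny Hiera trunk and
# inspect the multi-scale features it returns (one per stage end).
import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera(
    embed_dim=96,
    num_heads=1,
    stages=(1, 2, 7, 2),
    global_att_blocks=(5, 7, 9),
    window_pos_embed_bkg_spatial_size=(7, 7),
)

with torch.no_grad():
    feats = trunk(torch.zeros(1, 3, 1024, 1024))

for f in feats:
    print(tuple(f.shape))
# (1, 96, 256, 256), (1, 192, 128, 128), (1, 384, 64, 64), (1, 768, 32, 32)
# trunk.channel_list is the reverse, [768, 384, 192, 96], which is what the
# FpnNeck's backbone_channel_list in the configs has to match.
```
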
sam2_repo/sam2/modeling/backbones/image_encoder.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class ImageEncoder(nn.Module):
15
+ def __init__(
16
+ self,
17
+ trunk: nn.Module,
18
+ neck: nn.Module,
19
+ scalp: int = 0,
20
+ ):
21
+ super().__init__()
22
+ self.trunk = trunk
23
+ self.neck = neck
24
+ self.scalp = scalp
25
+ assert (
26
+ self.trunk.channel_list == self.neck.backbone_channel_list
27
+ ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28
+
29
+ def forward(self, sample: torch.Tensor):
30
+ # Forward through backbone
31
+ features, pos = self.neck(self.trunk(sample))
32
+ if self.scalp > 0:
33
+ # Discard the lowest resolution features
34
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
35
+
36
+ src = features[-1]
37
+ output = {
38
+ "vision_features": src,
39
+ "vision_pos_enc": pos,
40
+ "backbone_fpn": features,
41
+ }
42
+ return output
43
+
44
+
45
+ class FpnNeck(nn.Module):
46
+ """
47
+ A modified variant of Feature Pyramid Network (FPN) neck
48
+ (we remove output conv and also do bicubic interpolation similar to ViT
49
+ pos embed interpolation)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ position_encoding: nn.Module,
55
+ d_model: int,
56
+ backbone_channel_list: List[int],
57
+ kernel_size: int = 1,
58
+ stride: int = 1,
59
+ padding: int = 0,
60
+ fpn_interp_model: str = "bilinear",
61
+ fuse_type: str = "sum",
62
+ fpn_top_down_levels: Optional[List[int]] = None,
63
+ ):
64
+ """Initialize the neck
65
+ :param trunk: the backbone
66
+ :param position_encoding: the positional encoding to use
67
+ :param d_model: the dimension of the model
68
+ :param neck_norm: the normalization to use
69
+ """
70
+ super().__init__()
71
+ self.position_encoding = position_encoding
72
+ self.convs = nn.ModuleList()
73
+ self.backbone_channel_list = backbone_channel_list
74
+ self.d_model = d_model
75
+ for dim in backbone_channel_list:
76
+ current = nn.Sequential()
77
+ current.add_module(
78
+ "conv",
79
+ nn.Conv2d(
80
+ in_channels=dim,
81
+ out_channels=d_model,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding,
85
+ ),
86
+ )
87
+
88
+ self.convs.append(current)
89
+ self.fpn_interp_model = fpn_interp_model
90
+ assert fuse_type in ["sum", "avg"]
91
+ self.fuse_type = fuse_type
92
+
93
+ # levels to have top-down features in its outputs
94
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95
+ # have top-down propagation, while outputs of level 0 and level 1 have only
96
+ # lateral features from the same backbone level.
97
+ if fpn_top_down_levels is None:
98
+ # default is to have top-down features on all levels
99
+ fpn_top_down_levels = range(len(self.convs))
100
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
101
+
102
+ def forward(self, xs: List[torch.Tensor]):
103
+
104
+ out = [None] * len(self.convs)
105
+ pos = [None] * len(self.convs)
106
+ assert len(xs) == len(self.convs)
107
+ # fpn forward pass
108
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109
+ prev_features = None
110
+ # forward in top-down order (from low to high resolution)
111
+ n = len(self.convs) - 1
112
+ for i in range(n, -1, -1):
113
+ x = xs[i]
114
+ lateral_features = self.convs[n - i](x)
115
+ if i in self.fpn_top_down_levels and prev_features is not None:
116
+ top_down_features = F.interpolate(
117
+ prev_features.to(dtype=torch.float32),
118
+ scale_factor=2.0,
119
+ mode=self.fpn_interp_model,
120
+ align_corners=(
121
+ None if self.fpn_interp_model == "nearest" else False
122
+ ),
123
+ antialias=False,
124
+ )
125
+ prev_features = lateral_features + top_down_features
126
+ if self.fuse_type == "avg":
127
+ prev_features /= 2
128
+ else:
129
+ prev_features = lateral_features
130
+ x_out = prev_features
131
+ out[i] = x_out
132
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133
+
134
+ return out, pos
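`ImageEncoder` simply glues a trunk and an `FpnNeck` together and returns the coarsest kept feature map, its positional encodings, and the FPN pyramid. A sketch of that wiring with illustrative hyperparameters (the values actually used by this Space come from the YAML configs under `sam2_repo/sam2/configs`, not from this snippet; the trunk defaults are the same assumption as in the previous sketch):

```python
import torch
from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from sam2.modeling.position_encoding import PositionEmbeddingSine

trunk = Hiera()  # default channel_list: [768, 384, 192, 96]
neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256, warmup_cache=False),
    d_model=256,
    # Must mirror trunk.channel_list (lowest-resolution stage first), per the assert in ImageEncoder.
    backbone_channel_list=[768, 384, 192, 96],
)
encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)  # scalp=1 drops the lowest-resolution level

with torch.no_grad():
    out = encoder(torch.zeros(1, 3, 1024, 1024))
print(out["vision_features"].shape)             # coarsest kept level: (1, 256, 64, 64)
print([tuple(f.shape) for f in out["backbone_fpn"]])
```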
sam2_repo/sam2/modeling/backbones/utils.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Some utilities for backbones, in particular for windowing"""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+ Returns:
23
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
24
+ (Hp, Wp): padded height and width before partition
25
+ """
26
+ B, H, W, C = x.shape
27
+
28
+ pad_h = (window_size - H % window_size) % window_size
29
+ pad_w = (window_size - W % window_size) % window_size
30
+ if pad_h > 0 or pad_w > 0:
31
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32
+ Hp, Wp = H + pad_h, W + pad_w
33
+
34
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35
+ windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
36
+ return windows, (Hp, Wp)
37
+
38
+
39
+ def window_unpartition(windows, window_size, pad_hw, hw):
40
+ """
41
+ Window unpartition into original sequences and removing padding.
42
+ Args:
43
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
44
+ window_size (int): window size.
45
+ pad_hw (Tuple): padded height and width (Hp, Wp).
46
+ hw (Tuple): original height and width (H, W) before padding.
47
+ Returns:
48
+ x: unpartitioned sequences with [B, H, W, C].
49
+ """
50
+ Hp, Wp = pad_hw
51
+ H, W = hw
52
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
53
+ x = windows.reshape(
54
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
55
+ )
56
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
57
+
58
+ if Hp > H or Wp > W:
59
+ x = x[:, :H, :W, :]
60
+ return x
61
+
62
+
63
+ class PatchEmbed(nn.Module):
64
+ """
65
+ Image to Patch Embedding.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ kernel_size: Tuple[int, ...] = (7, 7),
71
+ stride: Tuple[int, ...] = (4, 4),
72
+ padding: Tuple[int, ...] = (3, 3),
73
+ in_chans: int = 3,
74
+ embed_dim: int = 768,
75
+ ):
76
+ """
77
+ Args:
78
+ kernel_size (Tuple): kernel size of the projection layer.
79
+ stride (Tuple): stride of the projection layer.
80
+ padding (Tuple): padding size of the projection layer.
81
+ in_chans (int): Number of input image channels.
82
+ embed_dim (int): Patch embedding dimension.
83
+ """
84
+ super().__init__()
85
+ self.proj = nn.Conv2d(
86
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
87
+ )
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ x = self.proj(x)
91
+ # B C H W -> B H W C
92
+ x = x.permute(0, 2, 3, 1)
93
+ return x
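`window_partition` and `window_unpartition` are exact inverses once the padding is cropped back off. A quick round-trip check, assuming the vendored `sam2` package is importable:

```python
import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 50, 70, 32)  # (B, H, W, C); H is not a multiple of the window size
windows, (Hp, Wp) = window_partition(x, window_size=14)   # zero-pads to Hp=56, Wp=70
print(windows.shape)                                      # (2 * 4 * 5, 14, 14, 32)

y = window_unpartition(windows, 14, (Hp, Wp), (50, 70))   # crops the padding back off
assert torch.equal(x, y)
```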
sam2_repo/sam2/modeling/memory_attention.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from sam2.modeling.sam.transformer import RoPEAttention
13
+
14
+ from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15
+
16
+
17
+ class MemoryAttentionLayer(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ activation: str,
22
+ cross_attention: nn.Module,
23
+ d_model: int,
24
+ dim_feedforward: int,
25
+ dropout: float,
26
+ pos_enc_at_attn: bool,
27
+ pos_enc_at_cross_attn_keys: bool,
28
+ pos_enc_at_cross_attn_queries: bool,
29
+ self_attention: nn.Module,
30
+ ):
31
+ super().__init__()
32
+ self.d_model = d_model
33
+ self.dim_feedforward = dim_feedforward
34
+ self.dropout_value = dropout
35
+ self.self_attn = self_attention
36
+ self.cross_attn_image = cross_attention
37
+
38
+ # Implementation of Feedforward model
39
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
40
+ self.dropout = nn.Dropout(dropout)
41
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
42
+
43
+ self.norm1 = nn.LayerNorm(d_model)
44
+ self.norm2 = nn.LayerNorm(d_model)
45
+ self.norm3 = nn.LayerNorm(d_model)
46
+ self.dropout1 = nn.Dropout(dropout)
47
+ self.dropout2 = nn.Dropout(dropout)
48
+ self.dropout3 = nn.Dropout(dropout)
49
+
50
+ self.activation_str = activation
51
+ self.activation = get_activation_fn(activation)
52
+
53
+ # Where to add pos enc
54
+ self.pos_enc_at_attn = pos_enc_at_attn
55
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57
+
58
+ def _forward_sa(self, tgt, query_pos):
59
+ # Self-Attention
60
+ tgt2 = self.norm1(tgt)
61
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62
+ tgt2 = self.self_attn(q, k, v=tgt2)
63
+ tgt = tgt + self.dropout1(tgt2)
64
+ return tgt
65
+
66
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67
+ kwds = {}
68
+ if num_k_exclude_rope > 0:
69
+ assert isinstance(self.cross_attn_image, RoPEAttention)
70
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71
+
72
+ # Cross-Attention
73
+ tgt2 = self.norm2(tgt)
74
+ tgt2 = self.cross_attn_image(
75
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77
+ v=memory,
78
+ **kwds,
79
+ )
80
+ tgt = tgt + self.dropout2(tgt2)
81
+ return tgt
82
+
83
+ def forward(
84
+ self,
85
+ tgt,
86
+ memory,
87
+ pos: Optional[Tensor] = None,
88
+ query_pos: Optional[Tensor] = None,
89
+ num_k_exclude_rope: int = 0,
90
+ ) -> torch.Tensor:
91
+
92
+ # Self-Attn, Cross-Attn
93
+ tgt = self._forward_sa(tgt, query_pos)
94
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95
+ # MLP
96
+ tgt2 = self.norm3(tgt)
97
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98
+ tgt = tgt + self.dropout3(tgt2)
99
+ return tgt
100
+
101
+
102
+ class MemoryAttention(nn.Module):
103
+ def __init__(
104
+ self,
105
+ d_model: int,
106
+ pos_enc_at_input: bool,
107
+ layer: nn.Module,
108
+ num_layers: int,
109
+ batch_first: bool = True, # Do layers expect batch first input?
110
+ ):
111
+ super().__init__()
112
+ self.d_model = d_model
113
+ self.layers = get_clones(layer, num_layers)
114
+ self.num_layers = num_layers
115
+ self.norm = nn.LayerNorm(d_model)
116
+ self.pos_enc_at_input = pos_enc_at_input
117
+ self.batch_first = batch_first
118
+
119
+ def forward(
120
+ self,
121
+ curr: torch.Tensor, # self-attention inputs
122
+ memory: torch.Tensor, # cross-attention inputs
123
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126
+ ):
127
+ if isinstance(curr, list):
128
+ assert isinstance(curr_pos, list)
129
+ assert len(curr) == len(curr_pos) == 1
130
+ curr, curr_pos = (
131
+ curr[0],
132
+ curr_pos[0],
133
+ )
134
+
135
+ assert (
136
+ curr.shape[1] == memory.shape[1]
137
+ ), "Batch size must be the same for curr and memory"
138
+
139
+ output = curr
140
+ if self.pos_enc_at_input and curr_pos is not None:
141
+ output = output + 0.1 * curr_pos
142
+
143
+ if self.batch_first:
144
+ # Convert to batch first
145
+ output = output.transpose(0, 1)
146
+ curr_pos = curr_pos.transpose(0, 1)
147
+ memory = memory.transpose(0, 1)
148
+ memory_pos = memory_pos.transpose(0, 1)
149
+
150
+ for layer in self.layers:
151
+ kwds = {}
152
+ if isinstance(layer.cross_attn_image, RoPEAttention):
153
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154
+
155
+ output = layer(
156
+ tgt=output,
157
+ memory=memory,
158
+ pos=memory_pos,
159
+ query_pos=curr_pos,
160
+ **kwds,
161
+ )
162
+ normed_output = self.norm(output)
163
+
164
+ if self.batch_first:
165
+ # Convert back to seq first
166
+ normed_output = normed_output.transpose(0, 1)
167
+ curr_pos = curr_pos.transpose(0, 1)
168
+
169
+ return normed_output
sam2_repo/sam2/modeling/memory_encoder.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15
+
16
+
17
+ class MaskDownSampler(nn.Module):
18
+ """
19
+ Progressively downsample a mask by total_stride, each time by stride.
20
+ Note that LayerNorm is applied per *token*, like in ViT.
21
+
22
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23
+ In the end, we linearly project to embed_dim channels.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim=256,
29
+ kernel_size=4,
30
+ stride=4,
31
+ padding=0,
32
+ total_stride=16,
33
+ activation=nn.GELU,
34
+ ):
35
+ super().__init__()
36
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
37
+ assert stride**num_layers == total_stride
38
+ self.encoder = nn.Sequential()
39
+ mask_in_chans, mask_out_chans = 1, 1
40
+ for _ in range(num_layers):
41
+ mask_out_chans = mask_in_chans * (stride**2)
42
+ self.encoder.append(
43
+ nn.Conv2d(
44
+ mask_in_chans,
45
+ mask_out_chans,
46
+ kernel_size=kernel_size,
47
+ stride=stride,
48
+ padding=padding,
49
+ )
50
+ )
51
+ self.encoder.append(LayerNorm2d(mask_out_chans))
52
+ self.encoder.append(activation())
53
+ mask_in_chans = mask_out_chans
54
+
55
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56
+
57
+ def forward(self, x):
58
+ return self.encoder(x)
59
+
60
+
61
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62
+ class CXBlock(nn.Module):
63
+ r"""ConvNeXt Block. There are two equivalent implementations:
64
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66
+ We use (2) as we find it slightly faster in PyTorch
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ drop_path (float): Stochastic depth rate. Default: 0.0
71
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ kernel_size=7,
78
+ padding=3,
79
+ drop_path=0.0,
80
+ layer_scale_init_value=1e-6,
81
+ use_dwconv=True,
82
+ ):
83
+ super().__init__()
84
+ self.dwconv = nn.Conv2d(
85
+ dim,
86
+ dim,
87
+ kernel_size=kernel_size,
88
+ padding=padding,
89
+ groups=dim if use_dwconv else 1,
90
+ ) # depthwise conv
91
+ self.norm = LayerNorm2d(dim, eps=1e-6)
92
+ self.pwconv1 = nn.Linear(
93
+ dim, 4 * dim
94
+ ) # pointwise/1x1 convs, implemented with linear layers
95
+ self.act = nn.GELU()
96
+ self.pwconv2 = nn.Linear(4 * dim, dim)
97
+ self.gamma = (
98
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99
+ if layer_scale_init_value > 0
100
+ else None
101
+ )
102
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ input = x
106
+ x = self.dwconv(x)
107
+ x = self.norm(x)
108
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109
+ x = self.pwconv1(x)
110
+ x = self.act(x)
111
+ x = self.pwconv2(x)
112
+ if self.gamma is not None:
113
+ x = self.gamma * x
114
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115
+
116
+ x = input + self.drop_path(x)
117
+ return x
118
+
119
+
120
+ class Fuser(nn.Module):
121
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
122
+ super().__init__()
123
+ self.proj = nn.Identity()
124
+ self.layers = get_clones(layer, num_layers)
125
+
126
+ if input_projection:
127
+ assert dim is not None
128
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129
+
130
+ def forward(self, x):
131
+ # normally x: (N, C, H, W)
132
+ x = self.proj(x)
133
+ for layer in self.layers:
134
+ x = layer(x)
135
+ return x
136
+
137
+
138
+ class MemoryEncoder(nn.Module):
139
+ def __init__(
140
+ self,
141
+ out_dim,
142
+ mask_downsampler,
143
+ fuser,
144
+ position_encoding,
145
+ in_dim=256, # in_dim of pix_feats
146
+ ):
147
+ super().__init__()
148
+
149
+ self.mask_downsampler = mask_downsampler
150
+
151
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152
+ self.fuser = fuser
153
+ self.position_encoding = position_encoding
154
+ self.out_proj = nn.Identity()
155
+ if out_dim != in_dim:
156
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157
+
158
+ def forward(
159
+ self,
160
+ pix_feat: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ skip_mask_sigmoid: bool = False,
163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
164
+ ## Process masks
165
+ # sigmoid, so that less domain shift from gt masks which are bool
166
+ if not skip_mask_sigmoid:
167
+ masks = F.sigmoid(masks)
168
+ masks = self.mask_downsampler(masks)
169
+
170
+ ## Fuse pix_feats and downsampled masks
171
+ # in case the visual features are on CPU, cast them to CUDA
172
+ pix_feat = pix_feat.to(masks.device)
173
+
174
+ x = self.pix_feat_proj(pix_feat)
175
+ x = x + masks
176
+ x = self.fuser(x)
177
+ x = self.out_proj(x)
178
+
179
+ pos = self.position_encoding(x).to(x.dtype)
180
+
181
+ return {"vision_features": x, "vision_pos_enc": [pos]}
sam2_repo/sam2/modeling/position_encoding.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, Optional, Tuple
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class PositionEmbeddingSine(nn.Module):
17
+ """
18
+ This is a more standard version of the position embedding, very similar to the one
19
+ used by the Attention Is All You Need paper, generalized to work on images.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ num_pos_feats,
25
+ temperature: int = 10000,
26
+ normalize: bool = True,
27
+ scale: Optional[float] = None,
28
+ # Following settings only relevant
29
+ # for warmping up cache for compilation
30
+ warmup_cache: bool = True,
31
+ image_size: int = 1024,
32
+ strides: Tuple[int] = (4, 8, 16, 32),
33
+ ):
34
+ super().__init__()
35
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
36
+ self.num_pos_feats = num_pos_feats // 2
37
+ self.temperature = temperature
38
+ self.normalize = normalize
39
+ if scale is not None and normalize is False:
40
+ raise ValueError("normalize should be True if scale is passed")
41
+ if scale is None:
42
+ scale = 2 * math.pi
43
+ self.scale = scale
44
+
45
+ self.cache = {}
46
+ if warmup_cache and torch.cuda.is_available():
47
+ # Warmup cache for cuda, to help with compilation
48
+ device = torch.device("cuda")
49
+ for stride in strides:
50
+ cache_key = (image_size // stride, image_size // stride)
51
+ self._pe(1, device, *cache_key)
52
+
53
+ def _encode_xy(self, x, y):
54
+ # The positions are expected to be normalized
55
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
56
+ x_embed = x * self.scale
57
+ y_embed = y * self.scale
58
+
59
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
60
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
61
+
62
+ pos_x = x_embed[:, None] / dim_t
63
+ pos_y = y_embed[:, None] / dim_t
64
+ pos_x = torch.stack(
65
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
66
+ ).flatten(1)
67
+ pos_y = torch.stack(
68
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
69
+ ).flatten(1)
70
+ return pos_x, pos_y
71
+
72
+ @torch.no_grad()
73
+ def encode_boxes(self, x, y, w, h):
74
+ pos_x, pos_y = self._encode_xy(x, y)
75
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
76
+ return pos
77
+
78
+ encode = encode_boxes # Backwards compatibility
79
+
80
+ @torch.no_grad()
81
+ def encode_points(self, x, y, labels):
82
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
83
+ assert bx == by and nx == ny and bx == bl and nx == nl
84
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
85
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
86
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
87
+ return pos
88
+
89
+ @torch.no_grad()
90
+ def _pe(self, B, device, *cache_key):
91
+ H, W = cache_key
92
+ if cache_key in self.cache:
93
+ return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
94
+
95
+ y_embed = (
96
+ torch.arange(1, H + 1, dtype=torch.float32, device=device)
97
+ .view(1, -1, 1)
98
+ .repeat(B, 1, W)
99
+ )
100
+ x_embed = (
101
+ torch.arange(1, W + 1, dtype=torch.float32, device=device)
102
+ .view(1, 1, -1)
103
+ .repeat(B, H, 1)
104
+ )
105
+
106
+ if self.normalize:
107
+ eps = 1e-6
108
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
109
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
110
+
111
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
112
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
113
+
114
+ pos_x = x_embed[:, :, :, None] / dim_t
115
+ pos_y = y_embed[:, :, :, None] / dim_t
116
+ pos_x = torch.stack(
117
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
118
+ ).flatten(3)
119
+ pos_y = torch.stack(
120
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
121
+ ).flatten(3)
122
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
123
+ self.cache[cache_key] = pos[0]
124
+ return pos
125
+
126
+ @torch.no_grad()
127
+ def forward(self, x: torch.Tensor):
128
+ B = x.shape[0]
129
+ cache_key = (x.shape[-2], x.shape[-1])
130
+ return self._pe(B, x.device, *cache_key)
131
+
132
+
133
+ class PositionEmbeddingRandom(nn.Module):
134
+ """
135
+ Positional encoding using random spatial frequencies.
136
+ """
137
+
138
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
139
+ super().__init__()
140
+ if scale is None or scale <= 0.0:
141
+ scale = 1.0
142
+ self.register_buffer(
143
+ "positional_encoding_gaussian_matrix",
144
+ scale * torch.randn((2, num_pos_feats)),
145
+ )
146
+
147
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
148
+ """Positionally encode points that are normalized to [0,1]."""
149
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
150
+ coords = 2 * coords - 1
151
+ coords = coords @ self.positional_encoding_gaussian_matrix
152
+ coords = 2 * np.pi * coords
153
+ # outputs d_1 x ... x d_n x C shape
154
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
155
+
156
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
157
+ """Generate positional encoding for a grid of the specified size."""
158
+ h, w = size
159
+ device: Any = self.positional_encoding_gaussian_matrix.device
160
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
161
+ y_embed = grid.cumsum(dim=0) - 0.5
162
+ x_embed = grid.cumsum(dim=1) - 0.5
163
+ y_embed = y_embed / h
164
+ x_embed = x_embed / w
165
+
166
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
167
+ return pe.permute(2, 0, 1) # C x H x W
168
+
169
+ def forward_with_coords(
170
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
171
+ ) -> torch.Tensor:
172
+ """Positionally encode points that are not normalized to [0,1]."""
173
+ coords = coords_input.clone()
174
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
175
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
176
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
177
+
178
+
179
+ # Rotary Positional Encoding, adapted from:
180
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
181
+ # 2. https://github.com/naver-ai/rope-vit
182
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
183
+
184
+
185
+ def init_t_xy(end_x: int, end_y: int):
186
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
187
+ t_x = (t % end_x).float()
188
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
189
+ return t_x, t_y
190
+
191
+
192
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
193
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
194
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
195
+
196
+ t_x, t_y = init_t_xy(end_x, end_y)
197
+ freqs_x = torch.outer(t_x, freqs_x)
198
+ freqs_y = torch.outer(t_y, freqs_y)
199
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
200
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
201
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
202
+
203
+
204
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
205
+ ndim = x.ndim
206
+ assert 0 <= 1 < ndim
207
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
208
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
209
+ return freqs_cis.view(*shape)
210
+
211
+
212
+ def apply_rotary_enc(
213
+ xq: torch.Tensor,
214
+ xk: torch.Tensor,
215
+ freqs_cis: torch.Tensor,
216
+ repeat_freqs_k: bool = False,
217
+ ):
218
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
219
+ xk_ = (
220
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
221
+ if xk.shape[-2] != 0
222
+ else None
223
+ )
224
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
225
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
226
+ if xk_ is None:
227
+ # no keys to rotate, due to dropout
228
+ return xq_out.type_as(xq).to(xq.device), xk
229
+ # repeat freqs along seq_len dim to match k seq_len
230
+ if repeat_freqs_k:
231
+ r = xk_.shape[-2] // xq_.shape[-2]
232
+ if freqs_cis.is_cuda:
233
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
234
+ else:
235
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
236
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
237
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
238
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
239
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
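Two of the encodings above in isolation: the dense sine embedding returns a map with `num_pos_feats` channels, and the axial RoPE helpers rotate query/key tensors shaped (B, heads, seq, head_dim). A small sketch, assuming the vendored `sam2` package is importable:

```python
import torch
from sam2.modeling.position_encoding import (
    PositionEmbeddingSine,
    apply_rotary_enc,
    compute_axial_cis,
)

# Dense sine embedding for a (B, C, H, W) feature map; output channels == num_pos_feats.
pe = PositionEmbeddingSine(num_pos_feats=256, warmup_cache=False)
pos = pe(torch.zeros(2, 256, 64, 64))
print(pos.shape)  # (2, 256, 64, 64)

# Axial RoPE: complex frequencies for a 64x64 token grid with head dim 64.
freqs_cis = compute_axial_cis(dim=64, end_x=64, end_y=64)
q = torch.randn(1, 1, 64 * 64, 64)
k = torch.randn(1, 1, 64 * 64, 64)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)
print(q_rot.shape, k_rot.shape)  # both (1, 1, 4096, 64)
```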
sam2_repo/sam2/modeling/sam/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sam2_repo/sam2/modeling/sam/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (169 Bytes).
sam2_repo/sam2/modeling/sam/__pycache__/mask_decoder.cpython-313.pyc ADDED
Binary file (12.4 kB).
sam2_repo/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-313.pyc ADDED
Binary file (9.69 kB).
sam2_repo/sam2/modeling/sam/__pycache__/transformer.cpython-313.pyc ADDED
Binary file (13.2 kB).
sam2_repo/sam2/modeling/sam/mask_decoder.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.sam2_utils import LayerNorm2d, MLP
13
+
14
+
15
+ class MaskDecoder(nn.Module):
16
+ def __init__(
17
+ self,
18
+ *,
19
+ transformer_dim: int,
20
+ transformer: nn.Module,
21
+ num_multimask_outputs: int = 3,
22
+ activation: Type[nn.Module] = nn.GELU,
23
+ iou_head_depth: int = 3,
24
+ iou_head_hidden_dim: int = 256,
25
+ use_high_res_features: bool = False,
26
+ iou_prediction_use_sigmoid=False,
27
+ dynamic_multimask_via_stability=False,
28
+ dynamic_multimask_stability_delta=0.05,
29
+ dynamic_multimask_stability_thresh=0.98,
30
+ pred_obj_scores: bool = False,
31
+ pred_obj_scores_mlp: bool = False,
32
+ use_multimask_token_for_obj_ptr: bool = False,
33
+ ) -> None:
34
+ """
35
+ Predicts masks given an image and prompt embeddings, using a
36
+ transformer architecture.
37
+
38
+ Arguments:
39
+ transformer_dim (int): the channel dimension of the transformer
40
+ transformer (nn.Module): the transformer used to predict masks
41
+ num_multimask_outputs (int): the number of masks to predict
42
+ when disambiguating masks
43
+ activation (nn.Module): the type of activation to use when
44
+ upscaling masks
45
+ iou_head_depth (int): the depth of the MLP used to predict
46
+ mask quality
47
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
48
+ used to predict mask quality
49
+ """
50
+ super().__init__()
51
+ self.transformer_dim = transformer_dim
52
+ self.transformer = transformer
53
+
54
+ self.num_multimask_outputs = num_multimask_outputs
55
+
56
+ self.iou_token = nn.Embedding(1, transformer_dim)
57
+ self.num_mask_tokens = num_multimask_outputs + 1
58
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59
+
60
+ self.pred_obj_scores = pred_obj_scores
61
+ if self.pred_obj_scores:
62
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
63
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64
+
65
+ self.output_upscaling = nn.Sequential(
66
+ nn.ConvTranspose2d(
67
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68
+ ),
69
+ LayerNorm2d(transformer_dim // 4),
70
+ activation(),
71
+ nn.ConvTranspose2d(
72
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73
+ ),
74
+ activation(),
75
+ )
76
+ self.use_high_res_features = use_high_res_features
77
+ if use_high_res_features:
78
+ self.conv_s0 = nn.Conv2d(
79
+ transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80
+ )
81
+ self.conv_s1 = nn.Conv2d(
82
+ transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83
+ )
84
+
85
+ self.output_hypernetworks_mlps = nn.ModuleList(
86
+ [
87
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88
+ for i in range(self.num_mask_tokens)
89
+ ]
90
+ )
91
+
92
+ self.iou_prediction_head = MLP(
93
+ transformer_dim,
94
+ iou_head_hidden_dim,
95
+ self.num_mask_tokens,
96
+ iou_head_depth,
97
+ sigmoid_output=iou_prediction_use_sigmoid,
98
+ )
99
+ if self.pred_obj_scores:
100
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101
+ if pred_obj_scores_mlp:
102
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103
+
104
+ # When outputting a single mask, optionally we can dynamically fall back to the best
105
+ # multimask output token if the single mask output token gives low stability scores.
106
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109
+
110
+ def forward(
111
+ self,
112
+ image_embeddings: torch.Tensor,
113
+ image_pe: torch.Tensor,
114
+ sparse_prompt_embeddings: torch.Tensor,
115
+ dense_prompt_embeddings: torch.Tensor,
116
+ multimask_output: bool,
117
+ repeat_image: bool,
118
+ high_res_features: Optional[List[torch.Tensor]] = None,
119
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """
121
+ Predict masks given image and prompt embeddings.
122
+
123
+ Arguments:
124
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
125
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128
+ multimask_output (bool): Whether to return multiple masks or a single
129
+ mask.
130
+
131
+ Returns:
132
+ torch.Tensor: batched predicted masks
133
+ torch.Tensor: batched predictions of mask quality
134
+ torch.Tensor: batched SAM token for mask output
135
+ """
136
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137
+ image_embeddings=image_embeddings,
138
+ image_pe=image_pe,
139
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
140
+ dense_prompt_embeddings=dense_prompt_embeddings,
141
+ repeat_image=repeat_image,
142
+ high_res_features=high_res_features,
143
+ )
144
+
145
+ # Select the correct mask or masks for output
146
+ if multimask_output:
147
+ masks = masks[:, 1:, :, :]
148
+ iou_pred = iou_pred[:, 1:]
149
+ elif self.dynamic_multimask_via_stability and not self.training:
150
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151
+ else:
152
+ masks = masks[:, 0:1, :, :]
153
+ iou_pred = iou_pred[:, 0:1]
154
+
155
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
156
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157
+ else:
158
+ # Take the mask output token. Here we *always* use the token for single mask output.
159
+ # At test time, even if we track after 1-click (and using multimask_output=True),
160
+ # we still take the single mask token here. The rationale is that we always track
161
+ # after multiple clicks during training, so the past tokens seen during training
162
+ # are always the single mask token (and we'll let it be the object-memory token).
163
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164
+
165
+ # Prepare output
166
+ return masks, iou_pred, sam_tokens_out, object_score_logits
167
+
168
+ def predict_masks(
169
+ self,
170
+ image_embeddings: torch.Tensor,
171
+ image_pe: torch.Tensor,
172
+ sparse_prompt_embeddings: torch.Tensor,
173
+ dense_prompt_embeddings: torch.Tensor,
174
+ repeat_image: bool,
175
+ high_res_features: Optional[List[torch.Tensor]] = None,
176
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
177
+ """Predicts masks. See 'forward' for more details."""
178
+ # Concatenate output tokens
179
+ s = 0
180
+ if self.pred_obj_scores:
181
+ output_tokens = torch.cat(
182
+ [
183
+ self.obj_score_token.weight,
184
+ self.iou_token.weight,
185
+ self.mask_tokens.weight,
186
+ ],
187
+ dim=0,
188
+ )
189
+ s = 1
190
+ else:
191
+ output_tokens = torch.cat(
192
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
193
+ )
194
+ output_tokens = output_tokens.unsqueeze(0).expand(
195
+ sparse_prompt_embeddings.size(0), -1, -1
196
+ )
197
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198
+
199
+ # Expand per-image data in batch direction to be per-mask
200
+ if repeat_image:
201
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202
+ else:
203
+ assert image_embeddings.shape[0] == tokens.shape[0]
204
+ src = image_embeddings
205
+ src = src + dense_prompt_embeddings
206
+ assert (
207
+ image_pe.size(0) == 1
208
+ ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210
+ b, c, h, w = src.shape
211
+
212
+ # Run the transformer
213
+ hs, src = self.transformer(src, pos_src, tokens)
214
+ iou_token_out = hs[:, s, :]
215
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
216
+
217
+ # Upscale mask embeddings and predict masks using the mask tokens
218
+ src = src.transpose(1, 2).view(b, c, h, w)
219
+ if not self.use_high_res_features:
220
+ upscaled_embedding = self.output_upscaling(src)
221
+ else:
222
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
223
+ feat_s0, feat_s1 = high_res_features
224
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
225
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
226
+
227
+ hyper_in_list: List[torch.Tensor] = []
228
+ for i in range(self.num_mask_tokens):
229
+ hyper_in_list.append(
230
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231
+ )
232
+ hyper_in = torch.stack(hyper_in_list, dim=1)
233
+ b, c, h, w = upscaled_embedding.shape
234
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
235
+
236
+ # Generate mask quality predictions
237
+ iou_pred = self.iou_prediction_head(iou_token_out)
238
+ if self.pred_obj_scores:
239
+ assert s == 1
240
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
241
+ else:
242
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
243
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
244
+
245
+ return masks, iou_pred, mask_tokens_out, object_score_logits
246
+
247
+ def _get_stability_scores(self, mask_logits):
248
+ """
249
+ Compute stability scores of the mask logits based on the IoU between upper and
250
+ lower thresholds.
251
+ """
252
+ mask_logits = mask_logits.flatten(-2)
253
+ stability_delta = self.dynamic_multimask_stability_delta
254
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
255
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
256
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
257
+ return stability_scores
258
+
259
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
260
+ """
261
+ When outputting a single mask, if the stability score from the current single-mask
262
+ output (based on output token 0) falls below a threshold, we instead select from
263
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
264
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
265
+ """
266
+ # The best mask from multimask output tokens (1~3)
267
+ multimask_logits = all_mask_logits[:, 1:, :, :]
268
+ multimask_iou_scores = all_iou_scores[:, 1:]
269
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270
+ batch_inds = torch.arange(
271
+ multimask_iou_scores.size(0), device=all_iou_scores.device
272
+ )
273
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
275
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
276
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
277
+
278
+ # The mask from singlemask output token 0 and its stability score
279
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
280
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
281
+ stability_scores = self._get_stability_scores(singlemask_logits)
282
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
283
+
284
+ # Dynamically fall back to best multimask output upon low stability scores.
285
+ mask_logits_out = torch.where(
286
+ is_stable[..., None, None].expand_as(singlemask_logits),
287
+ singlemask_logits,
288
+ best_multimask_logits,
289
+ )
290
+ iou_scores_out = torch.where(
291
+ is_stable.expand_as(singlemask_iou_scores),
292
+ singlemask_iou_scores,
293
+ best_multimask_iou_scores,
294
+ )
295
+ return mask_logits_out, iou_scores_out
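The stability score used by the dynamic single/multi-mask fallback above is just the IoU between the mask thresholded at +delta and at -delta. The same arithmetic as `_get_stability_scores`, written out on a toy tensor with plain PyTorch:

```python
import torch

mask_logits = torch.randn(2, 1, 256, 256)     # (B, tokens, H, W) mask logits
delta = 0.05                                  # dynamic_multimask_stability_delta

flat = mask_logits.flatten(-2)                # flatten the two spatial dims
area_i = (flat > delta).sum(dim=-1).float()   # "tight" mask area (intersection proxy)
area_u = (flat > -delta).sum(dim=-1).float()  # "loose" mask area (union proxy)
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))
print(stability)  # close to 1.0 means the mask barely changes as the threshold shifts
```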
sam2_repo/sam2/modeling/sam/prompt_encoder.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from sam2.modeling.position_encoding import PositionEmbeddingRandom
13
+
14
+ from sam2.modeling.sam2_utils import LayerNorm2d
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ mask_in_chans: int,
24
+ activation: Type[nn.Module] = nn.GELU,
25
+ ) -> None:
26
+ """
27
+ Encodes prompts for input to SAM's mask decoder.
28
+
29
+ Arguments:
30
+ embed_dim (int): The prompts' embedding dimension
31
+ image_embedding_size (tuple(int, int)): The spatial size of the
32
+ image embedding, as (H, W).
33
+ input_image_size (int): The padded size of the image as input
34
+ to the image encoder, as (H, W).
35
+ mask_in_chans (int): The number of hidden channels used for
36
+ encoding input masks.
37
+ activation (nn.Module): The activation to use when encoding
38
+ input masks.
39
+ """
40
+ super().__init__()
41
+ self.embed_dim = embed_dim
42
+ self.input_image_size = input_image_size
43
+ self.image_embedding_size = image_embedding_size
44
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45
+
46
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47
+ point_embeddings = [
48
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49
+ ]
50
+ self.point_embeddings = nn.ModuleList(point_embeddings)
51
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
52
+
53
+ self.mask_input_size = (
54
+ 4 * image_embedding_size[0],
55
+ 4 * image_embedding_size[1],
56
+ )
57
+ self.mask_downscaling = nn.Sequential(
58
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59
+ LayerNorm2d(mask_in_chans // 4),
60
+ activation(),
61
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62
+ LayerNorm2d(mask_in_chans),
63
+ activation(),
64
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65
+ )
66
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
67
+
68
+ def get_dense_pe(self) -> torch.Tensor:
69
+ """
70
+ Returns the positional encoding used to encode point prompts,
71
+ applied to a dense set of points the shape of the image encoding.
72
+
73
+ Returns:
74
+ torch.Tensor: Positional encoding with shape
75
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
76
+ """
77
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78
+
79
+ def _embed_points(
80
+ self,
81
+ points: torch.Tensor,
82
+ labels: torch.Tensor,
83
+ pad: bool,
84
+ ) -> torch.Tensor:
85
+ """Embeds point prompts."""
86
+ points = points + 0.5 # Shift to center of pixel
87
+ if pad:
88
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90
+ points = torch.cat([points, padding_point], dim=1)
91
+ labels = torch.cat([labels, padding_label], dim=1)
92
+ point_embedding = self.pe_layer.forward_with_coords(
93
+ points, self.input_image_size
94
+ )
95
+
96
+ point_embedding = torch.where(
97
+ (labels == -1).unsqueeze(-1),
98
+ torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
99
+ point_embedding,
100
+ )
101
+ point_embedding = torch.where(
102
+ (labels == 0).unsqueeze(-1),
103
+ point_embedding + self.point_embeddings[0].weight,
104
+ point_embedding,
105
+ )
106
+ point_embedding = torch.where(
107
+ (labels == 1).unsqueeze(-1),
108
+ point_embedding + self.point_embeddings[1].weight,
109
+ point_embedding,
110
+ )
111
+ point_embedding = torch.where(
112
+ (labels == 2).unsqueeze(-1),
113
+ point_embedding + self.point_embeddings[2].weight,
114
+ point_embedding,
115
+ )
116
+ point_embedding = torch.where(
117
+ (labels == 3).unsqueeze(-1),
118
+ point_embedding + self.point_embeddings[3].weight,
119
+ point_embedding,
120
+ )
121
+ return point_embedding
122
+
123
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
124
+ """Embeds box prompts."""
125
+ boxes = boxes + 0.5 # Shift to center of pixel
126
+ coords = boxes.reshape(-1, 2, 2)
127
+ corner_embedding = self.pe_layer.forward_with_coords(
128
+ coords, self.input_image_size
129
+ )
130
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
131
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
132
+ return corner_embedding
133
+
134
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
135
+ """Embeds mask inputs."""
136
+ mask_embedding = self.mask_downscaling(masks)
137
+ return mask_embedding
138
+
139
+ def _get_batch_size(
140
+ self,
141
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
142
+ boxes: Optional[torch.Tensor],
143
+ masks: Optional[torch.Tensor],
144
+ ) -> int:
145
+ """
146
+ Gets the batch size of the output given the batch size of the input prompts.
147
+ """
148
+ if points is not None:
149
+ return points[0].shape[0]
150
+ elif boxes is not None:
151
+ return boxes.shape[0]
152
+ elif masks is not None:
153
+ return masks.shape[0]
154
+ else:
155
+ return 1
156
+
157
+ def _get_device(self) -> torch.device:
158
+ return self.point_embeddings[0].weight.device
159
+
160
+ def forward(
161
+ self,
162
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
163
+ boxes: Optional[torch.Tensor],
164
+ masks: Optional[torch.Tensor],
165
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
166
+ """
167
+ Embeds different types of prompts, returning both sparse and dense
168
+ embeddings.
169
+
170
+ Arguments:
171
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
172
+ and labels to embed.
173
+ boxes (torch.Tensor or none): boxes to embed
174
+ masks (torch.Tensor or none): masks to embed
175
+
176
+ Returns:
177
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
178
+ BxNx(embed_dim), where N is determined by the number of input points
179
+ and boxes.
180
+ torch.Tensor: dense embeddings for the masks, in the shape
181
+ Bx(embed_dim)x(embed_H)x(embed_W)
182
+ """
183
+ bs = self._get_batch_size(points, boxes, masks)
184
+ sparse_embeddings = torch.empty(
185
+ (bs, 0, self.embed_dim), device=self._get_device()
186
+ )
187
+ if points is not None:
188
+ coords, labels = points
189
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
190
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
191
+ if boxes is not None:
192
+ box_embeddings = self._embed_boxes(boxes)
193
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
194
+
195
+ if masks is not None:
196
+ dense_embeddings = self._embed_masks(masks)
197
+ else:
198
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
199
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
200
+ )
201
+
202
+ return sparse_embeddings, dense_embeddings
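`PromptEncoder` turns clicks, boxes, and mask inputs into the sparse and dense embeddings consumed by the mask decoder. A minimal sketch with illustrative sizes (a 1024x1024 input image and 64x64 image embeddings; the real values are set by the model config loaded by this Space):

```python
import torch
from sam2.modeling.sam.prompt_encoder import PromptEncoder

enc = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),   # stride-16 features of a 1024x1024 input
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

points = torch.tensor([[[500.0, 375.0]]])  # (B, N, 2) pixel coordinates
labels = torch.tensor([[1]])               # 1 = foreground click
sparse, dense = enc(points=(points, labels), boxes=None, masks=None)
print(sparse.shape)  # (1, 2, 256): the click plus a padding point
print(dense.shape)   # (1, 256, 64, 64): the learned "no mask" embedding
```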