Spaces:
Sleeping
Sleeping
gbreadman13code
commited on
Commit
·
4f2b4bb
1
Parent(s):
4518f25
Deploy SAM2 segmentation API
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Dockerfile +49 -0
- README.md +167 -4
- app.py +1109 -0
- download_model.py +80 -0
- requirements.txt +10 -0
- sam2_repo/README.md +224 -0
- sam2_repo/checkpoints/download_ckpts.sh +59 -0
- sam2_repo/pyproject.toml +6 -0
- sam2_repo/sam2/__init__.py +11 -0
- sam2_repo/sam2/__pycache__/__init__.cpython-313.pyc +0 -0
- sam2_repo/sam2/__pycache__/build_sam.cpython-313.pyc +0 -0
- sam2_repo/sam2/__pycache__/sam2_image_predictor.cpython-313.pyc +0 -0
- sam2_repo/sam2/automatic_mask_generator.py +454 -0
- sam2_repo/sam2/benchmark.py +92 -0
- sam2_repo/sam2/build_sam.py +174 -0
- sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2_repo/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2_repo/sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2_repo/sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2_repo/sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2_repo/sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2_repo/sam2/csrc/connected_components.cu +289 -0
- sam2_repo/sam2/modeling/__init__.py +5 -0
- sam2_repo/sam2/modeling/__pycache__/__init__.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/__pycache__/memory_attention.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/__pycache__/memory_encoder.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/__pycache__/position_encoding.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/__pycache__/sam2_base.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/__pycache__/sam2_utils.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/backbones/__init__.py +5 -0
- sam2_repo/sam2/modeling/backbones/__pycache__/__init__.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/backbones/__pycache__/hieradet.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/backbones/__pycache__/image_encoder.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/backbones/__pycache__/utils.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/backbones/hieradet.py +317 -0
- sam2_repo/sam2/modeling/backbones/image_encoder.py +134 -0
- sam2_repo/sam2/modeling/backbones/utils.py +93 -0
- sam2_repo/sam2/modeling/memory_attention.py +169 -0
- sam2_repo/sam2/modeling/memory_encoder.py +181 -0
- sam2_repo/sam2/modeling/position_encoding.py +239 -0
- sam2_repo/sam2/modeling/sam/__init__.py +5 -0
- sam2_repo/sam2/modeling/sam/__pycache__/__init__.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/sam/__pycache__/mask_decoder.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/sam/__pycache__/transformer.cpython-313.pyc +0 -0
- sam2_repo/sam2/modeling/sam/mask_decoder.py +295 -0
- sam2_repo/sam2/modeling/sam/prompt_encoder.py +202 -0
Dockerfile
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile для Hugging Face Spaces
|
| 2 |
+
# Оптимизирован для CPU
|
| 3 |
+
|
| 4 |
+
FROM python:3.10-slim
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Системные зависимости для OpenCV и SAM2
|
| 9 |
+
RUN apt-get update && apt-get install -y \
|
| 10 |
+
git \
|
| 11 |
+
wget \
|
| 12 |
+
build-essential \
|
| 13 |
+
libglib2.0-0 \
|
| 14 |
+
libsm6 \
|
| 15 |
+
libxext6 \
|
| 16 |
+
libxrender-dev \
|
| 17 |
+
libgomp1 \
|
| 18 |
+
libgl1-mesa-glx \
|
| 19 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 20 |
+
|
| 21 |
+
# Копируем requirements
|
| 22 |
+
COPY requirements.txt .
|
| 23 |
+
|
| 24 |
+
# Устанавливаем Python зависимости
|
| 25 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 26 |
+
|
| 27 |
+
# Копируем код приложения
|
| 28 |
+
COPY app.py .
|
| 29 |
+
COPY download_model.py .
|
| 30 |
+
COPY web_demo.html .
|
| 31 |
+
COPY web_demo_advanced.html .
|
| 32 |
+
|
| 33 |
+
# Копируем и устанавливаем SAM2
|
| 34 |
+
COPY sam2_repo sam2_repo
|
| 35 |
+
RUN cd sam2_repo && pip install --no-cache-dir -e .
|
| 36 |
+
|
| 37 |
+
# Создаем папку для моделей
|
| 38 |
+
RUN mkdir -p checkpoints
|
| 39 |
+
|
| 40 |
+
# Скачиваем tiny модель (самая легкая для CPU)
|
| 41 |
+
RUN python download_model.py tiny
|
| 42 |
+
|
| 43 |
+
# Hugging Face Spaces использует порт 7860
|
| 44 |
+
ENV PORT=7860
|
| 45 |
+
EXPOSE 7860
|
| 46 |
+
|
| 47 |
+
# Запуск с указанием хоста и порта
|
| 48 |
+
CMD ["sh", "-c", "python -c 'import uvicorn; uvicorn.run(\"app:app\", host=\"0.0.0.0\", port=${PORT})'"]
|
| 49 |
+
|
README.md
CHANGED
|
@@ -1,11 +1,174 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: purple
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
license: apache-2.0
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SAM2 Segmentation API
|
| 3 |
+
emoji: 🎯
|
| 4 |
colorFrom: purple
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 🎯 SAM2 Segmentation API
|
| 13 |
+
|
| 14 |
+
Мощный REST API для сегментации объектов на изображениях с использованием Meta SAM2 (Segment Anything Model 2).
|
| 15 |
+
|
| 16 |
+
## ✨ Возможности
|
| 17 |
+
|
| 18 |
+
- **🎯 Box Prompts** - выделение прямоугольником
|
| 19 |
+
- **🖌️ Brush Prompts** - рисование кистью (зеленый = объект, красный = фон, белый = объект)
|
| 20 |
+
- **📍 Point Prompts** - клики по объектам
|
| 21 |
+
- **🔥 Batch API** - обработка множественных объектов за один запрос
|
| 22 |
+
- **🖼️ Extract Objects** - автоматическое извлечение объектов с прозрачностью
|
| 23 |
+
- **⚡ REST API** - полная документация в Swagger UI
|
| 24 |
+
|
| 25 |
+
## 🚀 Быстрый старт
|
| 26 |
+
|
| 27 |
+
### Web интерфейс
|
| 28 |
+
|
| 29 |
+
После запуска Space откройте:
|
| 30 |
+
|
| 31 |
+
- **Простой интерфейс**: `/web` - Box промпты
|
| 32 |
+
- **Продвинутый**: `/web/advanced` - Box + Brush промпты
|
| 33 |
+
- **API документация**: `/docs` - Swagger UI
|
| 34 |
+
|
| 35 |
+
### API Endpoints
|
| 36 |
+
|
| 37 |
+
#### POST `/segment/batch` - Батчинг API (рекомендуется)
|
| 38 |
+
|
| 39 |
+
Обрабатывает множественные объекты за один запрос.
|
| 40 |
+
|
| 41 |
+
**Пример запроса:**
|
| 42 |
+
```json
|
| 43 |
+
{
|
| 44 |
+
"image": "data:image/jpeg;base64,...",
|
| 45 |
+
"prompts": [
|
| 46 |
+
{
|
| 47 |
+
"id": 0,
|
| 48 |
+
"type": "mask",
|
| 49 |
+
"data": "data:image/png;base64,...",
|
| 50 |
+
"label": "person",
|
| 51 |
+
"selected": true
|
| 52 |
+
}
|
| 53 |
+
],
|
| 54 |
+
"options": {
|
| 55 |
+
"extract_objects": true,
|
| 56 |
+
"include_masks": false,
|
| 57 |
+
"clean_masks": true
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
#### POST `/segment` - Простая сегментация
|
| 63 |
+
|
| 64 |
+
С box промптом:
|
| 65 |
+
```bash
|
| 66 |
+
curl -X POST "/segment?box_x1=50&box_y1=50&box_x2=300&box_y2=400&extract_objects=true" \
|
| 67 |
+
-F "[email protected]"
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## 📊 Производительность
|
| 71 |
+
|
| 72 |
+
⚠️ **CPU Version**: Работает на бесплатном CPU tier Hugging Face Spaces. Скорость обработки: ~5-10 секунд на изображение.
|
| 73 |
+
|
| 74 |
+
Для более быстрой обработки рекомендуется upgrade на GPU (Settings → Hardware).
|
| 75 |
+
|
| 76 |
+
## 🎨 Форматы масок
|
| 77 |
+
|
| 78 |
+
API поддерживает несколько форматов масок:
|
| 79 |
+
|
| 80 |
+
- **🟢 Зеленый** (R<100, G>150, B<100) - foreground (объект)
|
| 81 |
+
- **⚪ Белый** (R>200, G>200, B>200) - foreground (объект)
|
| 82 |
+
- **🔴 Красный** (R>150, G<100, B<100) - background (исключить)
|
| 83 |
+
|
| 84 |
+
## 🔧 Технологии
|
| 85 |
+
|
| 86 |
+
- Meta SAM2 2.1 (Segment Anything Model)
|
| 87 |
+
- FastAPI
|
| 88 |
+
- PyTorch
|
| 89 |
+
- OpenCV
|
| 90 |
+
- Pydantic
|
| 91 |
+
|
| 92 |
+
## 📝 Примеры использования
|
| 93 |
+
|
| 94 |
+
### Python
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
import requests
|
| 98 |
+
import base64
|
| 99 |
+
|
| 100 |
+
# Загрузить изображение
|
| 101 |
+
with open("image.jpg", "rb") as f:
|
| 102 |
+
image_b64 = base64.b64encode(f.read()).decode()
|
| 103 |
+
|
| 104 |
+
# Отправить запрос
|
| 105 |
+
response = requests.post(
|
| 106 |
+
"https://YOUR-SPACE.hf.space/segment/batch",
|
| 107 |
+
json={
|
| 108 |
+
"image": f"data:image/jpeg;base64,{image_b64}",
|
| 109 |
+
"prompts": [{
|
| 110 |
+
"id": 0,
|
| 111 |
+
"type": "box",
|
| 112 |
+
"data": "",
|
| 113 |
+
"bbox": {"x_min": 0.1, "y_min": 0.2, "x_max": 0.5, "y_max": 0.8},
|
| 114 |
+
"label": "person",
|
| 115 |
+
"selected": True
|
| 116 |
+
}],
|
| 117 |
+
"options": {"extract_objects": True}
|
| 118 |
+
}
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
result = response.json()
|
| 122 |
+
print(f"Обработано объектов: {len(result['results'])}")
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
### JavaScript
|
| 126 |
+
|
| 127 |
+
```javascript
|
| 128 |
+
const response = await fetch('https://YOUR-SPACE.hf.space/segment/batch', {
|
| 129 |
+
method: 'POST',
|
| 130 |
+
headers: {'Content-Type': 'application/json'},
|
| 131 |
+
body: JSON.stringify({
|
| 132 |
+
image: imageBase64,
|
| 133 |
+
prompts: [{
|
| 134 |
+
id: 0,
|
| 135 |
+
type: "box",
|
| 136 |
+
data: "",
|
| 137 |
+
bbox: {x_min: 0.1, y_min: 0.2, x_max: 0.5, y_max: 0.8},
|
| 138 |
+
label: "person",
|
| 139 |
+
selected: true
|
| 140 |
+
}],
|
| 141 |
+
options: {extract_objects: true}
|
| 142 |
+
})
|
| 143 |
+
});
|
| 144 |
+
|
| 145 |
+
const result = await response.json();
|
| 146 |
+
console.log(`Обработано: ${result.results.length} объектов`);
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
## 📚 Документация
|
| 150 |
+
|
| 151 |
+
Полная интерактивная документация доступна по адресу `/docs` после запуска Space.
|
| 152 |
+
|
| 153 |
+
## 🤝 Поддержка
|
| 154 |
+
|
| 155 |
+
- Модель: SAM 2.1 Hiera Tiny (для CPU)
|
| 156 |
+
- Форматы изображений: JPG, PNG, WEBP, BMP
|
| 157 |
+
- Максимальный размер: рекомендуется до 2048x2048px для разумной скорости
|
| 158 |
+
|
| 159 |
+
## ⚡ Оптимизация для мобильных приложений
|
| 160 |
+
|
| 161 |
+
1. Уменьшайте размер изоб��ажения перед отправкой (1024x1024)
|
| 162 |
+
2. Используйте `include_masks: false` если контуры не нужны
|
| 163 |
+
3. Кэшируйте результаты на клиенте
|
| 164 |
+
4. Используйте батчинг API для множественных объектов
|
| 165 |
+
|
| 166 |
+
## 📄 Лицензия
|
| 167 |
+
|
| 168 |
+
Apache 2.0
|
| 169 |
+
|
| 170 |
+
## 🔗 Ссылки
|
| 171 |
+
|
| 172 |
+
- [SAM2 GitHub](https://github.com/facebookresearch/sam2)
|
| 173 |
+
- [SAM2 Paper](https://arxiv.org/abs/2408.00714)
|
| 174 |
+
|
app.py
ADDED
|
@@ -0,0 +1,1109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
REST API сервер для сегментации изображений через SAM2.
|
| 3 |
+
Уставший сеньор кодит это в 3 часа ночи, поэтому код местами будет грязный.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Body
|
| 8 |
+
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import io
|
| 15 |
+
import os
|
| 16 |
+
import base64
|
| 17 |
+
import cv2
|
| 18 |
+
from typing import List, Dict, Any, Optional, Literal
|
| 19 |
+
import logging
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
import json
|
| 22 |
+
|
| 23 |
+
# Настройка логирования, потому что дебажить это говно иначе невозможно
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# Глобальные переменные для модели (лень каждый раз загружать)
|
| 28 |
+
predictor = None
|
| 29 |
+
device = None
|
| 30 |
+
|
| 31 |
+
# ===== Pydantic модели для батчинг API =====
|
| 32 |
+
|
| 33 |
+
class BBoxModel(BaseModel):
|
| 34 |
+
"""Bounding box в нормализованных координатах (0.0 - 1.0) или пиксельных"""
|
| 35 |
+
x_min: float = Field(..., description="X координата левого верхнего угла")
|
| 36 |
+
y_min: float = Field(..., description="Y координата левого верхнего угла")
|
| 37 |
+
x_max: float = Field(..., description="X координата правого нижнего угла")
|
| 38 |
+
y_max: float = Field(..., description="Y координата правого нижнего угла")
|
| 39 |
+
|
| 40 |
+
class PromptModel(BaseModel):
|
| 41 |
+
"""Промпт для сегментации одного объекта"""
|
| 42 |
+
id: int = Field(..., description="Уникальный ID объекта")
|
| 43 |
+
type: Literal["mask", "box", "points"] = Field(..., description="Тип промпта")
|
| 44 |
+
data: str = Field(..., description="Данные промпта (base64 для mask, JSON для points)")
|
| 45 |
+
bbox: Optional[BBoxModel] = Field(None, description="Опциональный bounding box")
|
| 46 |
+
label: Optional[str] = Field(None, description="Метка объекта (person, car, etc)")
|
| 47 |
+
selected: bool = Field(True, description="Обрабатывать ли этот промпт")
|
| 48 |
+
|
| 49 |
+
class SegmentOptionsModel(BaseModel):
|
| 50 |
+
"""Опции сегментации"""
|
| 51 |
+
extract_objects: bool = Field(True, description="Вернуть вырезанные объекты")
|
| 52 |
+
include_masks: bool = Field(False, description="Включить контуры масок")
|
| 53 |
+
clean_masks: bool = Field(True, description="Очистить маски от артефактов")
|
| 54 |
+
|
| 55 |
+
class BatchSegmentRequest(BaseModel):
|
| 56 |
+
"""Запрос на батчинг сегментацию"""
|
| 57 |
+
image: str = Field(..., description="Изображение в base64 (с data URL или без)")
|
| 58 |
+
prompts: List[PromptModel] = Field(..., description="Массив промптов")
|
| 59 |
+
options: Optional[SegmentOptionsModel] = Field(default_factory=SegmentOptionsModel)
|
| 60 |
+
|
| 61 |
+
class SegmentResultModel(BaseModel):
|
| 62 |
+
"""Результат сегментации одного объекта"""
|
| 63 |
+
id: int
|
| 64 |
+
label: Optional[str] = None
|
| 65 |
+
bbox: Dict[str, Any]
|
| 66 |
+
area: int
|
| 67 |
+
center: Dict[str, int]
|
| 68 |
+
confidence: float
|
| 69 |
+
extracted_image: Optional[str] = None
|
| 70 |
+
contours: Optional[List[Dict[str, Any]]] = None
|
| 71 |
+
mask_rle: Optional[Dict[str, Any]] = None
|
| 72 |
+
|
| 73 |
+
class BatchSegmentResponse(BaseModel):
|
| 74 |
+
"""Ответ батчинг сегментации"""
|
| 75 |
+
success: bool
|
| 76 |
+
image_size: Dict[str, int]
|
| 77 |
+
results: List[SegmentResultModel]
|
| 78 |
+
|
| 79 |
+
def save_batch_request_log(request_data: dict, response_data: dict, image_width: int, image_height: int):
|
| 80 |
+
"""
|
| 81 |
+
Сохраняет запрос батчинга для аудита и дебага.
|
| 82 |
+
Создает папку с timestamp и сохраняет только метаданные:
|
| 83 |
+
1. Лог запроса (request.json) - параметры без base64
|
| 84 |
+
2. Лог ответа (response.json) - результаты без base64
|
| 85 |
+
3. Краткую сводку (summary.json)
|
| 86 |
+
|
| 87 |
+
⚠️ Изображения и маски НЕ сохраняются для безопасности!
|
| 88 |
+
"""
|
| 89 |
+
try:
|
| 90 |
+
# Создаем корневую папку для логов
|
| 91 |
+
logs_dir = "batch_logs"
|
| 92 |
+
os.makedirs(logs_dir, exist_ok=True)
|
| 93 |
+
|
| 94 |
+
# Создаем папку с timestamp
|
| 95 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Миллисекунды
|
| 96 |
+
request_dir = os.path.join(logs_dir, timestamp)
|
| 97 |
+
os.makedirs(request_dir, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
logger.info(f"📁 Сохраняю лог запроса в: {request_dir}")
|
| 100 |
+
|
| 101 |
+
# Сохраняем запрос (без base64 для безопасности)
|
| 102 |
+
request_log = {
|
| 103 |
+
"timestamp": timestamp,
|
| 104 |
+
"image_size": {
|
| 105 |
+
"width": image_width,
|
| 106 |
+
"height": image_height
|
| 107 |
+
},
|
| 108 |
+
"prompts": [
|
| 109 |
+
{
|
| 110 |
+
"id": p.get("id"),
|
| 111 |
+
"type": p.get("type"),
|
| 112 |
+
"label": p.get("label"),
|
| 113 |
+
"bbox": p.get("bbox"),
|
| 114 |
+
"selected": p.get("selected"),
|
| 115 |
+
"data_length": len(p.get("data", "")) # Длина вместо самих данных
|
| 116 |
+
}
|
| 117 |
+
for p in request_data.get("prompts", [])
|
| 118 |
+
],
|
| 119 |
+
"options": request_data.get("options", {})
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
request_path = os.path.join(request_dir, "request.json")
|
| 123 |
+
with open(request_path, "w", encoding="utf-8") as f:
|
| 124 |
+
json.dump(request_log, f, indent=2, ensure_ascii=False)
|
| 125 |
+
logger.info(f" ✓ Сохранен лог запроса: {request_path}")
|
| 126 |
+
|
| 127 |
+
# 4. Сохраняем ответ (без base64 объектов)
|
| 128 |
+
response_log = {
|
| 129 |
+
"timestamp": timestamp,
|
| 130 |
+
"success": response_data.get("success"),
|
| 131 |
+
"image_size": response_data.get("image_size"),
|
| 132 |
+
"results": [
|
| 133 |
+
{
|
| 134 |
+
"id": r.get("id"),
|
| 135 |
+
"label": r.get("label"),
|
| 136 |
+
"bbox": r.get("bbox"),
|
| 137 |
+
"area": r.get("area"),
|
| 138 |
+
"center": r.get("center"),
|
| 139 |
+
"confidence": r.get("confidence"),
|
| 140 |
+
"has_extracted_image": "extracted_image" in r,
|
| 141 |
+
"has_contours": "contours" in r
|
| 142 |
+
}
|
| 143 |
+
for r in response_data.get("results", [])
|
| 144 |
+
]
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
response_path = os.path.join(request_dir, "response.json")
|
| 148 |
+
with open(response_path, "w", encoding="utf-8") as f:
|
| 149 |
+
json.dump(response_log, f, indent=2, ensure_ascii=False)
|
| 150 |
+
logger.info(f" ✓ Сохранен лог ответа: {response_path}")
|
| 151 |
+
|
| 152 |
+
# 3. Создаем summary файл
|
| 153 |
+
summary = {
|
| 154 |
+
"timestamp": timestamp,
|
| 155 |
+
"processed_prompts": len(response_data.get("results", [])),
|
| 156 |
+
"total_prompts": len(request_data.get("prompts", [])),
|
| 157 |
+
"selected_prompts": len([p for p in request_data.get("prompts", []) if p.get("selected", True)]),
|
| 158 |
+
"image_size": f"{image_width}x{image_height}",
|
| 159 |
+
"prompt_types": [p.get("type") for p in request_data.get("prompts", [])],
|
| 160 |
+
"files": {
|
| 161 |
+
"request": "request.json",
|
| 162 |
+
"response": "response.json"
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
summary_path = os.path.join(request_dir, "summary.json")
|
| 167 |
+
with open(summary_path, "w", encoding="utf-8") as f:
|
| 168 |
+
json.dump(summary, f, indent=2, ensure_ascii=False)
|
| 169 |
+
|
| 170 |
+
logger.info(f"✅ Лог запроса сохранен: {request_dir}")
|
| 171 |
+
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.error(f"❌ Ошибка при сохранении лога: {e}")
|
| 174 |
+
# Не прерываем обработку запроса если не удалось сохранить лог
|
| 175 |
+
|
| 176 |
+
def load_model(checkpoint_path: str = "checkpoints/sam2.1_hiera_tiny.pt"):
|
| 177 |
+
"""
|
| 178 |
+
Загружает модель SAM2.
|
| 179 |
+
Вызывается один раз при старте сервера.
|
| 180 |
+
"""
|
| 181 |
+
global predictor, device
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
from sam2.build_sam import build_sam2
|
| 185 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 186 |
+
|
| 187 |
+
# Проверяем CUDA
|
| 188 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 189 |
+
logger.info(f"Используем устройство: {device}")
|
| 190 |
+
|
| 191 |
+
if device == "cpu":
|
| 192 |
+
logger.warning("CUDA недоступна, работаем на CPU (будет медленно как черепаха)")
|
| 193 |
+
|
| 194 |
+
# Определяем конфиг по имени файла чекпоинта
|
| 195 |
+
# Указываем путь относительно configs/ директории в пакете sam2
|
| 196 |
+
checkpoint_name = os.path.basename(checkpoint_path)
|
| 197 |
+
if "tiny" in checkpoint_name:
|
| 198 |
+
config = "configs/sam2.1/sam2.1_hiera_t.yaml"
|
| 199 |
+
elif "small" in checkpoint_name:
|
| 200 |
+
config = "configs/sam2.1/sam2.1_hiera_s.yaml"
|
| 201 |
+
elif "base_plus" in checkpoint_name:
|
| 202 |
+
config = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 203 |
+
elif "large" in checkpoint_name:
|
| 204 |
+
config = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 205 |
+
else:
|
| 206 |
+
logger.warning(f"Неизвестный тип модели, пробую tiny конфиг")
|
| 207 |
+
config = "configs/sam2.1/sam2.1_hiera_t.yaml"
|
| 208 |
+
|
| 209 |
+
logger.info(f"Загружаю модель из {checkpoint_path}")
|
| 210 |
+
logger.info(f"Конфиг: {config}")
|
| 211 |
+
|
| 212 |
+
sam2_model = build_sam2(config, checkpoint_path, device=device)
|
| 213 |
+
predictor = SAM2ImagePredictor(sam2_model)
|
| 214 |
+
|
| 215 |
+
logger.info("✓ Модель загружена успешно")
|
| 216 |
+
|
| 217 |
+
except Exception as e:
|
| 218 |
+
logger.error(f"Не удалось загрузить модель: {e}")
|
| 219 |
+
logger.error("Убедись что SAM2 установлен (./install_sam2.sh)")
|
| 220 |
+
raise
|
| 221 |
+
|
| 222 |
+
@asynccontextmanager
|
| 223 |
+
async def lifespan(app: FastAPI):
|
| 224 |
+
"""Загружаем модель при старте, выгружаем при остановке"""
|
| 225 |
+
# Startup
|
| 226 |
+
checkpoint_dir = "checkpoints"
|
| 227 |
+
if os.path.exists(checkpoint_dir):
|
| 228 |
+
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
|
| 229 |
+
if checkpoints:
|
| 230 |
+
checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0])
|
| 231 |
+
load_model(checkpoint_path)
|
| 232 |
+
else:
|
| 233 |
+
logger.error("Нет чекпоинтов в директории checkpoints/")
|
| 234 |
+
logger.error("Запусти: python download_model.py")
|
| 235 |
+
else:
|
| 236 |
+
logger.error("Директория checkpoints/ не найдена")
|
| 237 |
+
|
| 238 |
+
yield # Сервер работает
|
| 239 |
+
|
| 240 |
+
# Shutdown (если нужна очистка)
|
| 241 |
+
|
| 242 |
+
# Создаем FastAPI приложение с lifespan
|
| 243 |
+
app = FastAPI(
|
| 244 |
+
title="SAM2 Segmentation API",
|
| 245 |
+
description="API для автоматической сегментации объектов на изображениях",
|
| 246 |
+
version="1.0.0",
|
| 247 |
+
lifespan=lifespan
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Добавляем CORS для работы с веб-интерфейсом
|
| 251 |
+
app.add_middleware(
|
| 252 |
+
CORSMiddleware,
|
| 253 |
+
allow_origins=["*"], # В продакшене указать конкретные домены
|
| 254 |
+
allow_credentials=True,
|
| 255 |
+
allow_methods=["*"],
|
| 256 |
+
allow_headers=["*"],
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
@app.get("/")
|
| 260 |
+
async def root():
|
| 261 |
+
"""Главная страница - информация об API"""
|
| 262 |
+
return {
|
| 263 |
+
"message": "SAM2 Segmentation API работает",
|
| 264 |
+
"version": "2.0.0",
|
| 265 |
+
"web_ui": {
|
| 266 |
+
"simple": "/web - Box промпты",
|
| 267 |
+
"advanced": "/web/advanced - Box + Brush промпты (рисование)"
|
| 268 |
+
},
|
| 269 |
+
"docs": "/docs",
|
| 270 |
+
"endpoints": {
|
| 271 |
+
"POST /segment": "Сегментация изображения (поддерживает points, box, mask via query params)",
|
| 272 |
+
"POST /segment/batch": "🔥 Батчинг сегментация (JSON API для множественных объектов)",
|
| 273 |
+
"POST /segment/auto": "Автоматическая сегментация всех объектов",
|
| 274 |
+
"GET /health": "Проверка здоровья сервиса"
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
@app.get("/web", response_class=HTMLResponse)
|
| 279 |
+
async def web_interface():
|
| 280 |
+
"""Веб-интерфейс для тестирования Box Prompts (простой)"""
|
| 281 |
+
web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo.html")
|
| 282 |
+
if os.path.exists(web_demo_path):
|
| 283 |
+
with open(web_demo_path, "r", encoding="utf-8") as f:
|
| 284 |
+
return f.read()
|
| 285 |
+
else:
|
| 286 |
+
return "<h1>Веб-интерфейс не найден</h1><p>Файл web_demo.html отсутствует</p>"
|
| 287 |
+
|
| 288 |
+
@app.get("/web/advanced", response_class=HTMLResponse)
|
| 289 |
+
async def web_interface_advanced():
|
| 290 |
+
"""Продвинутый веб-интерфейс с Box + Brush промптами"""
|
| 291 |
+
web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo_advanced.html")
|
| 292 |
+
if os.path.exists(web_demo_path):
|
| 293 |
+
with open(web_demo_path, "r", encoding="utf-8") as f:
|
| 294 |
+
return f.read()
|
| 295 |
+
else:
|
| 296 |
+
return "<h1>Продвинутый интерфейс не найден</h1><p>Файл web_demo_advanced.html отсутствует</p>"
|
| 297 |
+
|
| 298 |
+
@app.get("/health")
|
| 299 |
+
async def health():
|
| 300 |
+
"""Проверка что всё ок"""
|
| 301 |
+
return {
|
| 302 |
+
"status": "healthy" if predictor is not None else "model not loaded",
|
| 303 |
+
"device": str(device) if device else "unknown"
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
def process_image(image_bytes: bytes) -> np.ndarray:
|
| 307 |
+
"""Конвертирует байты в numpy array"""
|
| 308 |
+
image = Image.open(io.BytesIO(image_bytes))
|
| 309 |
+
if image.mode != "RGB":
|
| 310 |
+
image = image.convert("RGB")
|
| 311 |
+
return np.array(image)
|
| 312 |
+
|
| 313 |
+
def masks_to_coords(masks: np.ndarray, include_contours: bool = False) -> List[Dict[str, Any]]:
|
| 314 |
+
"""
|
| 315 |
+
Конвертирует маски в координаты bounding box и контуров.
|
| 316 |
+
masks: (N, H, W) - N масок
|
| 317 |
+
include_contours: если True, добавляет контуры масок
|
| 318 |
+
"""
|
| 319 |
+
results = []
|
| 320 |
+
|
| 321 |
+
for i, mask in enumerate(masks):
|
| 322 |
+
# Находим координаты пикселей маски
|
| 323 |
+
y_coords, x_coords = np.where(mask > 0)
|
| 324 |
+
|
| 325 |
+
if len(x_coords) == 0:
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
# Bounding box
|
| 329 |
+
x_min, x_max = int(x_coords.min()), int(x_coords.max())
|
| 330 |
+
y_min, y_max = int(y_coords.min()), int(y_coords.max())
|
| 331 |
+
|
| 332 |
+
# Площадь сегмента
|
| 333 |
+
area = int(mask.sum())
|
| 334 |
+
|
| 335 |
+
segment_data = {
|
| 336 |
+
"segment_id": i,
|
| 337 |
+
"bbox": {
|
| 338 |
+
"x_min": x_min,
|
| 339 |
+
"y_min": y_min,
|
| 340 |
+
"x_max": x_max,
|
| 341 |
+
"y_max": y_max,
|
| 342 |
+
"width": x_max - x_min,
|
| 343 |
+
"height": y_max - y_min
|
| 344 |
+
},
|
| 345 |
+
"area": area,
|
| 346 |
+
"center": {
|
| 347 |
+
"x": int(x_coords.mean()),
|
| 348 |
+
"y": int(y_coords.mean())
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
# Добавляем контуры если нужно
|
| 353 |
+
if include_contours:
|
| 354 |
+
try:
|
| 355 |
+
# Конвертируем маску в uint8 (защита от булевых масок)
|
| 356 |
+
if mask.dtype == bool:
|
| 357 |
+
mask_uint8 = mask.astype(np.uint8) * 255
|
| 358 |
+
else:
|
| 359 |
+
mask_uint8 = (mask * 255).astype(np.uint8)
|
| 360 |
+
|
| 361 |
+
# Находим контуры с иерархией для поддержки "дыр"
|
| 362 |
+
# RETR_CCOMP: находит внешние контуры И внутренние дыры (holes)
|
| 363 |
+
# CHAIN_APPROX_NONE: сохраняет ВСЕ точки для pixel-perfect результата
|
| 364 |
+
contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
| 365 |
+
except Exception as e:
|
| 366 |
+
logger.warning(f"Ошибка при извлечении контуров: {e}, использую fallback")
|
| 367 |
+
# Fallback на простое извлечение без иерархии
|
| 368 |
+
contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
| 369 |
+
hierarchy = None
|
| 370 |
+
|
| 371 |
+
# Конвертируем контуры в список точек с учетом иерархии
|
| 372 |
+
contour_data = []
|
| 373 |
+
|
| 374 |
+
if hierarchy is not None and len(contours) > 0:
|
| 375 |
+
hierarchy = hierarchy[0] # OpenCV возвращает hierarchy в странном формате
|
| 376 |
+
|
| 377 |
+
for i, contour in enumerate(contours):
|
| 378 |
+
try:
|
| 379 |
+
# Небольшое упрощение только для очень больших контуров
|
| 380 |
+
if len(contour) > 1000:
|
| 381 |
+
arc_length = cv2.arcLength(contour, True)
|
| 382 |
+
if arc_length > 0: # Защита от деления на 0
|
| 383 |
+
epsilon = 0.0005 * arc_length
|
| 384 |
+
approx = cv2.approxPolyDP(contour, epsilon, True)
|
| 385 |
+
else:
|
| 386 |
+
approx = contour
|
| 387 |
+
else:
|
| 388 |
+
approx = contour
|
| 389 |
+
|
| 390 |
+
# Конвертируем в список [x, y]
|
| 391 |
+
points = [[int(point[0][0]), int(point[0][1])] for point in approx]
|
| 392 |
+
|
| 393 |
+
if len(points) > 2:
|
| 394 |
+
# hierarchy[i] = [Next, Previous, First_Child, Parent]
|
| 395 |
+
# Если Parent == -1, это внешний контур
|
| 396 |
+
# Если Parent >= 0, это дыра (hole) внутри родительского контура
|
| 397 |
+
is_hole = hierarchy[i][3] != -1
|
| 398 |
+
|
| 399 |
+
contour_data.append({
|
| 400 |
+
"points": points,
|
| 401 |
+
"is_hole": is_hole
|
| 402 |
+
})
|
| 403 |
+
except Exception as e:
|
| 404 |
+
logger.warning(f"Ошибка при обработке контура {i}: {e}")
|
| 405 |
+
continue
|
| 406 |
+
else:
|
| 407 |
+
# Fallback если hierarchy не вернулась
|
| 408 |
+
for contour in contours:
|
| 409 |
+
try:
|
| 410 |
+
if len(contour) > 1000:
|
| 411 |
+
arc_length = cv2.arcLength(contour, True)
|
| 412 |
+
if arc_length > 0:
|
| 413 |
+
epsilon = 0.0005 * arc_length
|
| 414 |
+
approx = cv2.approxPolyDP(contour, epsilon, True)
|
| 415 |
+
else:
|
| 416 |
+
approx = contour
|
| 417 |
+
else:
|
| 418 |
+
approx = contour
|
| 419 |
+
|
| 420 |
+
points = [[int(point[0][0]), int(point[0][1])] for point in approx]
|
| 421 |
+
if len(points) > 2:
|
| 422 |
+
contour_data.append({
|
| 423 |
+
"points": points,
|
| 424 |
+
"is_hole": False
|
| 425 |
+
})
|
| 426 |
+
except Exception as e:
|
| 427 |
+
logger.warning(f"Ошибка при обработке контура: {e}")
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
segment_data["contours"] = contour_data if len(contour_data) > 0 else []
|
| 431 |
+
|
| 432 |
+
# Также добавляем RLE (Run-Length Encoding) для компактного представления
|
| 433 |
+
# Это полезно если нужно восстановить точную маску
|
| 434 |
+
segment_data["mask_rle"] = mask_to_rle(mask)
|
| 435 |
+
|
| 436 |
+
results.append(segment_data)
|
| 437 |
+
|
| 438 |
+
return results
|
| 439 |
+
|
| 440 |
+
def mask_to_rle(mask: np.ndarray) -> Dict[str, Any]:
|
| 441 |
+
"""
|
| 442 |
+
Конвертирует бинарную маску в RLE (Run-Length Encoding)
|
| 443 |
+
Компактное представление маски
|
| 444 |
+
"""
|
| 445 |
+
# Конвертируем в int если это bool
|
| 446 |
+
if mask.dtype == bool:
|
| 447 |
+
pixels = mask.astype(np.uint8).flatten()
|
| 448 |
+
else:
|
| 449 |
+
pixels = mask.flatten()
|
| 450 |
+
|
| 451 |
+
pixels = np.concatenate([[0], pixels, [0]])
|
| 452 |
+
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
|
| 453 |
+
runs[1::2] -= runs[::2]
|
| 454 |
+
|
| 455 |
+
return {
|
| 456 |
+
"counts": [int(x) for x in runs], # Конвертируем numpy int в Python int
|
| 457 |
+
"size": [int(x) for x in mask.shape] # Конвертируем в Python int
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
def convert_to_native_types(obj):
|
| 461 |
+
"""
|
| 462 |
+
Рекурсивно конвертирует numpy типы в нативные Python типы
|
| 463 |
+
Нужно для сериализации в JSON через FastAPI
|
| 464 |
+
"""
|
| 465 |
+
if isinstance(obj, np.integer):
|
| 466 |
+
return int(obj)
|
| 467 |
+
elif isinstance(obj, np.floating):
|
| 468 |
+
return float(obj)
|
| 469 |
+
elif isinstance(obj, np.ndarray):
|
| 470 |
+
return obj.tolist()
|
| 471 |
+
elif isinstance(obj, np.bool_):
|
| 472 |
+
return bool(obj)
|
| 473 |
+
elif isinstance(obj, dict):
|
| 474 |
+
return {key: convert_to_native_types(value) for key, value in obj.items()}
|
| 475 |
+
elif isinstance(obj, list):
|
| 476 |
+
return [convert_to_native_types(item) for item in obj]
|
| 477 |
+
return obj
|
| 478 |
+
|
| 479 |
+
def clean_mask(mask: np.ndarray, min_area: int = 100) -> np.ndarray:
|
| 480 |
+
"""
|
| 481 |
+
Очищает маску от мелких артефактов и дыр.
|
| 482 |
+
|
| 483 |
+
mask: бинарная маска (H, W)
|
| 484 |
+
min_area: минимальная площадь компонента в пикселях
|
| 485 |
+
|
| 486 |
+
Returns: очищенная маска
|
| 487 |
+
"""
|
| 488 |
+
# Конвертируем в uint8 если нужно
|
| 489 |
+
if mask.dtype == bool:
|
| 490 |
+
mask_uint8 = mask.astype(np.uint8) * 255
|
| 491 |
+
else:
|
| 492 |
+
mask_uint8 = (mask * 255).astype(np.uint8)
|
| 493 |
+
|
| 494 |
+
# Морфологическое закрытие для удаления мелких дыр
|
| 495 |
+
kernel = np.ones((3, 3), np.uint8)
|
| 496 |
+
mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=2)
|
| 497 |
+
|
| 498 |
+
# Морфологическое открытие для удаления мелких шумов
|
| 499 |
+
mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_OPEN, kernel, iterations=1)
|
| 500 |
+
|
| 501 |
+
# Находим все связанные компоненты
|
| 502 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_uint8, connectivity=8)
|
| 503 |
+
|
| 504 |
+
# Создаем чистую маску
|
| 505 |
+
clean_mask = np.zeros_like(mask_uint8)
|
| 506 |
+
|
| 507 |
+
# Оставляем только большие компоненты
|
| 508 |
+
for i in range(1, num_labels): # Пропускаем фон (0)
|
| 509 |
+
area = stats[i, cv2.CC_STAT_AREA]
|
| 510 |
+
if area >= min_area:
|
| 511 |
+
clean_mask[labels == i] = 255
|
| 512 |
+
|
| 513 |
+
# Если ничего не осталось, возвращаем самый большой компонент
|
| 514 |
+
if clean_mask.sum() == 0 and num_labels > 1:
|
| 515 |
+
# Находим самый большой компонент
|
| 516 |
+
largest_component = 1 + np.argmax([stats[i, cv2.CC_STAT_AREA] for i in range(1, num_labels)])
|
| 517 |
+
clean_mask[labels == largest_component] = 255
|
| 518 |
+
|
| 519 |
+
return (clean_mask > 127).astype(bool)
|
| 520 |
+
|
| 521 |
+
def extract_object_image(image: np.ndarray, mask: np.ndarray, clean: bool = True) -> str:
|
| 522 |
+
"""
|
| 523 |
+
Вырезает объект из изображения по маске и возвращает base64 PNG с прозрачностью.
|
| 524 |
+
|
| 525 |
+
image: RGB изображение (H, W, 3)
|
| 526 |
+
mask: бинарная маска (H, W)
|
| 527 |
+
clean: применить постобработку для удаления артефактов
|
| 528 |
+
|
| 529 |
+
Returns: base64 строка PNG изображения с альфа-каналом
|
| 530 |
+
"""
|
| 531 |
+
# Конвертируем маску в bool если нужно
|
| 532 |
+
if mask.dtype != bool:
|
| 533 |
+
mask = mask > 0.5
|
| 534 |
+
|
| 535 |
+
# Очищаем маску от артефактов
|
| 536 |
+
if clean:
|
| 537 |
+
mask = clean_mask(mask, min_area=100)
|
| 538 |
+
|
| 539 |
+
# Создаем RGBA изображение
|
| 540 |
+
h, w = image.shape[:2]
|
| 541 |
+
rgba = np.zeros((h, w, 4), dtype=np.uint8)
|
| 542 |
+
rgba[:, :, :3] = image # RGB каналы
|
| 543 |
+
rgba[:, :, 3] = (mask * 255).astype(np.uint8) # Alpha канал из маски
|
| 544 |
+
|
| 545 |
+
# Конвертиру��м в PIL Image
|
| 546 |
+
pil_image = Image.fromarray(rgba, 'RGBA')
|
| 547 |
+
|
| 548 |
+
# Конвертируем в base64
|
| 549 |
+
buffer = io.BytesIO()
|
| 550 |
+
pil_image.save(buffer, format='PNG')
|
| 551 |
+
buffer.seek(0)
|
| 552 |
+
img_base64 = base64.b64encode(buffer.read()).decode('utf-8')
|
| 553 |
+
|
| 554 |
+
return f"data:image/png;base64,{img_base64}"
|
| 555 |
+
|
| 556 |
+
@app.post("/segment")
|
| 557 |
+
async def segment_image(
|
| 558 |
+
file: UploadFile = File(...),
|
| 559 |
+
point_x: List[float] = Query(None, description="X координаты точек промпта"),
|
| 560 |
+
point_y: List[float] = Query(None, description="Y координаты точек промпта"),
|
| 561 |
+
point_labels: List[int] = Query(None, description="Лейблы точек (1=foreground, 0=background)"),
|
| 562 |
+
box_x1: float = Query(None, description="X координата левого верхнего угла бокса"),
|
| 563 |
+
box_y1: float = Query(None, description="Y координата левого верхнего угла бокса"),
|
| 564 |
+
box_x2: float = Query(None, description="X координата правого нижнего угла бокса"),
|
| 565 |
+
box_y2: float = Query(None, description="Y координата правого нижнего угла бокса"),
|
| 566 |
+
mask_data: str = Query(None, description="Base64 закодированная маска (PNG с альфа-каналом)"),
|
| 567 |
+
include_masks: bool = Query(True, description="Включить контуры масок в ответ"),
|
| 568 |
+
extract_objects: bool = Query(False, description="Вернуть вырезанные объекты как base64 PNG"),
|
| 569 |
+
):
|
| 570 |
+
"""
|
| 571 |
+
Сегментирует изображение по промпту (точкам, боксу, маске или их комбинации).
|
| 572 |
+
|
| 573 |
+
Поддерживаемые промпты:
|
| 574 |
+
- Точки (point_x, point_y, point_labels) - клики пользователя
|
| 575 |
+
- Бокс (box_x1, box_y1, box_x2, box_y2) - прямоугольное выделение
|
| 576 |
+
- Маска (mask_data) - нарисованная кистью маска (зеленый=foreground, красный=background)
|
| 577 |
+
- Комбинация промптов - для максимальной точности
|
| 578 |
+
|
| 579 |
+
Если промпты не указаны, сегментирует центральный объект.
|
| 580 |
+
Если include_masks=True, возвращает контуры масок для точной отрисовки.
|
| 581 |
+
Если extract_objects=True, возвращает готовые вырезанные объекты как base64 PNG.
|
| 582 |
+
"""
|
| 583 |
+
if predictor is None:
|
| 584 |
+
raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер")
|
| 585 |
+
|
| 586 |
+
try:
|
| 587 |
+
# Читаем изображение
|
| 588 |
+
image_bytes = await file.read()
|
| 589 |
+
image = process_image(image_bytes)
|
| 590 |
+
|
| 591 |
+
logger.info(f"Обрабатываю изображение: {image.shape}")
|
| 592 |
+
logger.info(f"Параметры: include_masks={include_masks}, extract_objects={extract_objects}")
|
| 593 |
+
|
| 594 |
+
# Устанавливаем изображение в предиктор
|
| 595 |
+
predictor.set_image(image)
|
| 596 |
+
|
| 597 |
+
# Подготавливаем промпты
|
| 598 |
+
points = None
|
| 599 |
+
labels = None
|
| 600 |
+
box = None
|
| 601 |
+
|
| 602 |
+
# Проверяем наличие точек
|
| 603 |
+
if point_x and point_y:
|
| 604 |
+
if len(point_x) != len(point_y):
|
| 605 |
+
raise HTTPException(status_code=400, detail="Количество X и Y координат должно совпадать")
|
| 606 |
+
points = np.array([[x, y] for x, y in zip(point_x, point_y)])
|
| 607 |
+
labels = np.array(point_labels) if point_labels else np.ones(len(points))
|
| 608 |
+
logger.info(f"Промпт: {len(points)} точек")
|
| 609 |
+
|
| 610 |
+
# Проверяем наличие бокса
|
| 611 |
+
if all(v is not None for v in [box_x1, box_y1, box_x2, box_y2]):
|
| 612 |
+
box = np.array([box_x1, box_y1, box_x2, box_y2])
|
| 613 |
+
logger.info(f"Промпт: бокс [{box_x1:.1f}, {box_y1:.1f}, {box_x2:.1f}, {box_y2:.1f}]")
|
| 614 |
+
|
| 615 |
+
# Валидация бокса
|
| 616 |
+
if box_x2 <= box_x1 or box_y2 <= box_y1:
|
| 617 |
+
raise HTTPException(
|
| 618 |
+
status_code=400,
|
| 619 |
+
detail="Некорректный бокс: x2 должен быть больше x1, y2 больше y1"
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# Проверяем наличие нарисованной маски
|
| 623 |
+
if mask_data:
|
| 624 |
+
logger.info("Обрабатываю нарисованную маску...")
|
| 625 |
+
try:
|
| 626 |
+
# Декодируем base64
|
| 627 |
+
if ',' in mask_data:
|
| 628 |
+
mask_data = mask_data.split(',')[1] # Убираем data:image/png;base64,
|
| 629 |
+
|
| 630 |
+
mask_bytes = base64.b64decode(mask_data)
|
| 631 |
+
mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA')
|
| 632 |
+
mask_array = np.array(mask_image)
|
| 633 |
+
|
| 634 |
+
# Извлекаем foreground и background пиксели
|
| 635 |
+
# Поддерживаем несколько форматов:
|
| 636 |
+
# 1. Зеленый (R<100, G>150, B<100) - классический foreground
|
| 637 |
+
# 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами
|
| 638 |
+
# 3. Красный (R>150, G<100, B<100) - background
|
| 639 |
+
|
| 640 |
+
green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
|
| 641 |
+
white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0)
|
| 642 |
+
red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
|
| 643 |
+
|
| 644 |
+
# Объединяем зеленые и белые как foreground
|
| 645 |
+
foreground_mask = green_mask | white_mask
|
| 646 |
+
|
| 647 |
+
# Сэмплируем точки из закрашенных областей
|
| 648 |
+
mask_points = []
|
| 649 |
+
mask_labels = []
|
| 650 |
+
|
| 651 |
+
# Foreground точки (зеленые + белые)
|
| 652 |
+
foreground_coords = np.argwhere(foreground_mask)
|
| 653 |
+
if len(foreground_coords) > 0:
|
| 654 |
+
# Масштабируем к размеру исходного изображения
|
| 655 |
+
scale_y = image.shape[0] / mask_array.shape[0]
|
| 656 |
+
scale_x = image.shape[1] / mask_array.shape[1]
|
| 657 |
+
|
| 658 |
+
# Сэмплируем до 20 точек равномерно (меньше = стабильнее)
|
| 659 |
+
step = max(1, len(foreground_coords) // 20)
|
| 660 |
+
sampled = foreground_coords[::step][:20] # Максимум 20 точек
|
| 661 |
+
|
| 662 |
+
for y, x in sampled:
|
| 663 |
+
mask_points.append([x * scale_x, y * scale_y])
|
| 664 |
+
mask_labels.append(1) # foreground
|
| 665 |
+
|
| 666 |
+
# Background точки (красные)
|
| 667 |
+
red_coords = np.argwhere(red_mask)
|
| 668 |
+
if len(red_coords) > 0:
|
| 669 |
+
scale_y = image.shape[0] / mask_array.shape[0]
|
| 670 |
+
scale_x = image.shape[1] / mask_array.shape[1]
|
| 671 |
+
|
| 672 |
+
step = max(1, len(red_coords) // 20)
|
| 673 |
+
sampled = red_coords[::step][:20] # Максимум 20 точек
|
| 674 |
+
|
| 675 |
+
for y, x in sampled:
|
| 676 |
+
mask_points.append([x * scale_x, y * scale_y])
|
| 677 |
+
mask_labels.append(0) # background
|
| 678 |
+
|
| 679 |
+
if mask_points:
|
| 680 |
+
# Объединяем с существующими точками
|
| 681 |
+
if points is not None:
|
| 682 |
+
points = np.vstack([points, np.array(mask_points)])
|
| 683 |
+
labels = np.concatenate([labels, np.array(mask_labels)])
|
| 684 |
+
else:
|
| 685 |
+
points = np.array(mask_points)
|
| 686 |
+
labels = np.array(mask_labels)
|
| 687 |
+
|
| 688 |
+
logger.info(f"Промпт из маски: {len(mask_points)} точек ({np.sum(np.array(mask_labels) == 1)} foreground, {np.sum(np.array(mask_labels) == 0)} background)")
|
| 689 |
+
else:
|
| 690 |
+
logger.warning("Маска пустая или не содержит foreground (зеленых/белых) или background (красных) пикселей")
|
| 691 |
+
|
| 692 |
+
except Exception as e:
|
| 693 |
+
logger.error(f"Ошибка обработки маски: {e}")
|
| 694 |
+
raise HTTPException(status_code=400, detail=f"Некорректная маска: {str(e)}")
|
| 695 |
+
|
| 696 |
+
# Делаем предсказание с промптами
|
| 697 |
+
if points is not None or box is not None:
|
| 698 |
+
logger.info(f"Используем промпты: points={points is not None}, box={box is not None}")
|
| 699 |
+
|
| 700 |
+
# Если много точек (>10), используем single mask для стабильности
|
| 701 |
+
# Если мало точек или только box, используем multimask для вариативности
|
| 702 |
+
use_multimask = True
|
| 703 |
+
if points is not None and len(points) > 10:
|
| 704 |
+
use_multimask = False
|
| 705 |
+
logger.info("Много точек, используем single mask mode для стабильности")
|
| 706 |
+
|
| 707 |
+
masks, scores, logits = predictor.predict(
|
| 708 |
+
point_coords=points,
|
| 709 |
+
point_labels=labels,
|
| 710 |
+
box=box,
|
| 711 |
+
multimask_output=use_multimask,
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
# Если multimask, выбираем лучшую по score
|
| 715 |
+
if use_multimask and len(masks) > 1:
|
| 716 |
+
best_idx = np.argmax(scores)
|
| 717 |
+
masks = masks[best_idx:best_idx+1]
|
| 718 |
+
scores = scores[best_idx:best_idx+1]
|
| 719 |
+
logger.info(f"Выбрана маска {best_idx} с confidence {scores[0]:.3f}")
|
| 720 |
+
else:
|
| 721 |
+
# Автоматическая сегментация - берем центральную точку
|
| 722 |
+
logger.info("Промпты не указаны, сегментирую центральный объект")
|
| 723 |
+
h, w = image.shape[:2]
|
| 724 |
+
point = np.array([[w // 2, h // 2]])
|
| 725 |
+
label = np.array([1])
|
| 726 |
+
|
| 727 |
+
masks, scores, logits = predictor.predict(
|
| 728 |
+
point_coords=point,
|
| 729 |
+
point_labels=label,
|
| 730 |
+
multimask_output=True,
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# Конвертируем маски в координаты (с контурами если нужно)
|
| 734 |
+
segments = masks_to_coords(masks, include_contours=include_masks)
|
| 735 |
+
|
| 736 |
+
logger.info(f"Найдено сегментов: {len(segments)}, масок: {len(masks)}")
|
| 737 |
+
logger.info(f"extract_objects = {extract_objects}")
|
| 738 |
+
|
| 739 |
+
# Добавляем confidence scores
|
| 740 |
+
for i, seg in enumerate(segments):
|
| 741 |
+
seg["confidence"] = float(scores[i]) if i < len(scores) else 0.0
|
| 742 |
+
|
| 743 |
+
# Если нужно - вырезаем объект и добавляем base64
|
| 744 |
+
logger.info(f"Обрабатываю сегмент {i}: extract_objects={extract_objects}, i < len(masks) = {i < len(masks)}")
|
| 745 |
+
if extract_objects and i < len(masks):
|
| 746 |
+
logger.info(f"Вырезаю объект {i}...")
|
| 747 |
+
seg["extracted_image"] = extract_object_image(image, masks[i])
|
| 748 |
+
logger.info(f"✓ Вырезан объект {i}, размер маски: {masks[i].sum()} пикселей")
|
| 749 |
+
else:
|
| 750 |
+
logger.warning(f"❌ Пропускаю объект {i}: extract_objects={extract_objects}")
|
| 751 |
+
|
| 752 |
+
result = {
|
| 753 |
+
"success": True,
|
| 754 |
+
"image_size": {
|
| 755 |
+
"width": int(image.shape[1]),
|
| 756 |
+
"height": int(image.shape[0])
|
| 757 |
+
},
|
| 758 |
+
"segments_count": len(segments),
|
| 759 |
+
"segments": segments
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
# Конвертируем все numpy типы в нативные Python типы
|
| 763 |
+
return convert_to_native_types(result)
|
| 764 |
+
|
| 765 |
+
except Exception as e:
|
| 766 |
+
logger.error(f"Ошибка при сегментации: {e}")
|
| 767 |
+
raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}")
|
| 768 |
+
|
| 769 |
+
@app.post("/segment/auto")
|
| 770 |
+
async def segment_auto(
|
| 771 |
+
file: UploadFile = File(...),
|
| 772 |
+
points_per_side: int = Query(32, description="Количество точек на сторону для автосегментации"),
|
| 773 |
+
include_masks: bool = Query(True, description="Включить контуры масок в ответ"),
|
| 774 |
+
):
|
| 775 |
+
"""
|
| 776 |
+
Автоматическая сегментация всех объектов на изображении.
|
| 777 |
+
Использует grid of points для поиска всех возможных объектов.
|
| 778 |
+
Если include_masks=True, возвращает контуры масок для точной отрисовки.
|
| 779 |
+
"""
|
| 780 |
+
if predictor is None:
|
| 781 |
+
raise HTTPException(status_code=503, detail="Модель не загружена")
|
| 782 |
+
|
| 783 |
+
try:
|
| 784 |
+
image_bytes = await file.read()
|
| 785 |
+
image = process_image(image_bytes)
|
| 786 |
+
|
| 787 |
+
logger.info(f"Автосегментация изображения: {image.shape}")
|
| 788 |
+
|
| 789 |
+
predictor.set_image(image)
|
| 790 |
+
|
| 791 |
+
# Создаем сетку точек
|
| 792 |
+
h, w = image.shape[:2]
|
| 793 |
+
x_coords = np.linspace(0, w, points_per_side)
|
| 794 |
+
y_coords = np.linspace(0, h, points_per_side)
|
| 795 |
+
|
| 796 |
+
all_segments = []
|
| 797 |
+
segment_id = 0
|
| 798 |
+
|
| 799 |
+
# Для каждой точки в сетке пытаемся найти объект
|
| 800 |
+
for y in y_coords:
|
| 801 |
+
for x in x_coords:
|
| 802 |
+
point = np.array([[x, y]])
|
| 803 |
+
label = np.array([1])
|
| 804 |
+
|
| 805 |
+
masks, scores, _ = predictor.predict(
|
| 806 |
+
point_coords=point,
|
| 807 |
+
point_labels=label,
|
| 808 |
+
multimask_output=False,
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
if masks.shape[0] > 0 and scores[0] > 0.5: # Порог confidence
|
| 812 |
+
segments = masks_to_coords(masks, include_contours=include_masks)
|
| 813 |
+
for seg in segments:
|
| 814 |
+
seg["segment_id"] = segment_id
|
| 815 |
+
seg["confidence"] = float(scores[0])
|
| 816 |
+
all_segments.append(seg)
|
| 817 |
+
segment_id += 1
|
| 818 |
+
|
| 819 |
+
# Убир��ем дубликаты (примерно)
|
| 820 |
+
# Два сегмента считаем дубликатами если их центры близко
|
| 821 |
+
unique_segments = []
|
| 822 |
+
for seg in all_segments:
|
| 823 |
+
is_duplicate = False
|
| 824 |
+
for unique_seg in unique_segments:
|
| 825 |
+
dx = seg["center"]["x"] - unique_seg["center"]["x"]
|
| 826 |
+
dy = seg["center"]["y"] - unique_seg["center"]["y"]
|
| 827 |
+
dist = (dx**2 + dy**2) ** 0.5
|
| 828 |
+
|
| 829 |
+
if dist < 50: # Порог расстояния между центрами
|
| 830 |
+
is_duplicate = True
|
| 831 |
+
break
|
| 832 |
+
|
| 833 |
+
if not is_duplicate:
|
| 834 |
+
unique_segments.append(seg)
|
| 835 |
+
|
| 836 |
+
result = {
|
| 837 |
+
"success": True,
|
| 838 |
+
"image_size": {
|
| 839 |
+
"width": int(image.shape[1]),
|
| 840 |
+
"height": int(image.shape[0])
|
| 841 |
+
},
|
| 842 |
+
"segments_count": len(unique_segments),
|
| 843 |
+
"segments": unique_segments
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
# Конвертируем все numpy типы в нативные Python типы
|
| 847 |
+
return convert_to_native_types(result)
|
| 848 |
+
|
| 849 |
+
except Exception as e:
|
| 850 |
+
logger.error(f"Ошибка при автосегментации: {e}")
|
| 851 |
+
raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}")
|
| 852 |
+
|
| 853 |
+
@app.post("/segment/batch", response_model=BatchSegmentResponse)
|
| 854 |
+
async def segment_batch(request: BatchSegmentRequest = Body(...)):
|
| 855 |
+
"""
|
| 856 |
+
Батчинг сегментация нескольких объектов.
|
| 857 |
+
|
| 858 |
+
Принимает изображение и массив промптов (mask/box/points).
|
| 859 |
+
Обрабатывает каждый selected промпт отдельно.
|
| 860 |
+
Возвращает массив результатов с метаданными.
|
| 861 |
+
|
| 862 |
+
Идеально для:
|
| 863 |
+
- Множественных объектов
|
| 864 |
+
- Мобильных приложений
|
| 865 |
+
- Когда фронт уже разделил объекты
|
| 866 |
+
"""
|
| 867 |
+
if predictor is None:
|
| 868 |
+
raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер")
|
| 869 |
+
|
| 870 |
+
try:
|
| 871 |
+
# Декодируем изображение из base64
|
| 872 |
+
image_data = request.image
|
| 873 |
+
if ',' in image_data:
|
| 874 |
+
image_data = image_data.split(',')[1] # Убираем data:image/...;base64,
|
| 875 |
+
|
| 876 |
+
image_bytes = base64.b64decode(image_data)
|
| 877 |
+
image = process_image(image_bytes)
|
| 878 |
+
|
| 879 |
+
logger.info(f"Батчинг сегментация: {image.shape}, промптов: {len(request.prompts)}")
|
| 880 |
+
|
| 881 |
+
# Устанавливаем изображение один раз
|
| 882 |
+
predictor.set_image(image)
|
| 883 |
+
|
| 884 |
+
results = []
|
| 885 |
+
|
| 886 |
+
# Фильтруем только selected промпты
|
| 887 |
+
selected_prompts = [p for p in request.prompts if p.selected]
|
| 888 |
+
logger.info(f"Обрабатываем {len(selected_prompts)} из {len(request.prompts)} промптов")
|
| 889 |
+
|
| 890 |
+
# Обрабатываем каждый промпт отдельно
|
| 891 |
+
for prompt in selected_prompts:
|
| 892 |
+
logger.info(f"Обрабатываю промпт #{prompt.id}, тип: {prompt.type}, label: {prompt.label}")
|
| 893 |
+
|
| 894 |
+
try:
|
| 895 |
+
# Подготавливаем промпт в зависимости от типа
|
| 896 |
+
points = None
|
| 897 |
+
labels = None
|
| 898 |
+
box = None
|
| 899 |
+
|
| 900 |
+
if prompt.type == "mask":
|
| 901 |
+
# Декодируем маску и извлекаем точки
|
| 902 |
+
mask_data = prompt.data
|
| 903 |
+
if ',' in mask_data:
|
| 904 |
+
mask_data = mask_data.split(',')[1]
|
| 905 |
+
|
| 906 |
+
mask_bytes = base64.b64decode(mask_data)
|
| 907 |
+
mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA')
|
| 908 |
+
mask_array = np.array(mask_image)
|
| 909 |
+
|
| 910 |
+
# Извлекаем foreground и background пиксели
|
| 911 |
+
# Поддерживаем несколько форматов:
|
| 912 |
+
# 1. Зеленый (R<100, G>150, B<100) - классический foreground
|
| 913 |
+
# 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами
|
| 914 |
+
# 3. Красный (R>150, G<100, B<100) - background
|
| 915 |
+
|
| 916 |
+
green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
|
| 917 |
+
white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0)
|
| 918 |
+
red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0)
|
| 919 |
+
|
| 920 |
+
# Объединяем зеленые и белые как foreground
|
| 921 |
+
foreground_mask = green_mask | white_mask
|
| 922 |
+
|
| 923 |
+
mask_points = []
|
| 924 |
+
mask_labels = []
|
| 925 |
+
|
| 926 |
+
# Foreground точки (зеленые + белые)
|
| 927 |
+
foreground_coords = np.argwhere(foreground_mask)
|
| 928 |
+
if len(foreground_coords) > 0:
|
| 929 |
+
scale_y = image.shape[0] / mask_array.shape[0]
|
| 930 |
+
scale_x = image.shape[1] / mask_array.shape[1]
|
| 931 |
+
step = max(1, len(foreground_coords) // 20)
|
| 932 |
+
sampled = foreground_coords[::step][:20]
|
| 933 |
+
|
| 934 |
+
for y, x in sampled:
|
| 935 |
+
mask_points.append([x * scale_x, y * scale_y])
|
| 936 |
+
mask_labels.append(1)
|
| 937 |
+
|
| 938 |
+
# Background точки
|
| 939 |
+
red_coords = np.argwhere(red_mask)
|
| 940 |
+
if len(red_coords) > 0:
|
| 941 |
+
scale_y = image.shape[0] / mask_array.shape[0]
|
| 942 |
+
scale_x = image.shape[1] / mask_array.shape[1]
|
| 943 |
+
step = max(1, len(red_coords) // 20)
|
| 944 |
+
sampled = red_coords[::step][:20]
|
| 945 |
+
|
| 946 |
+
for y, x in sampled:
|
| 947 |
+
mask_points.append([x * scale_x, y * scale_y])
|
| 948 |
+
mask_labels.append(0)
|
| 949 |
+
|
| 950 |
+
if mask_points:
|
| 951 |
+
points = np.array(mask_points)
|
| 952 |
+
labels = np.array(mask_labels)
|
| 953 |
+
|
| 954 |
+
elif prompt.type == "box":
|
| 955 |
+
# Парсим bbox - может быть нормализованный (0-1) или пиксельный
|
| 956 |
+
bbox_data = prompt.bbox if prompt.bbox else None
|
| 957 |
+
|
| 958 |
+
if bbox_data:
|
| 959 |
+
x1 = bbox_data.x_min
|
| 960 |
+
y1 = bbox_data.y_min
|
| 961 |
+
x2 = bbox_data.x_max
|
| 962 |
+
y2 = bbox_data.y_max
|
| 963 |
+
|
| 964 |
+
# Если нормализованные координаты (0-1), конвертируем в пиксели
|
| 965 |
+
if x2 <= 1.0 and y2 <= 1.0:
|
| 966 |
+
x1 *= image.shape[1]
|
| 967 |
+
x2 *= image.shape[1]
|
| 968 |
+
y1 *= image.shape[0]
|
| 969 |
+
y2 *= image.shape[0]
|
| 970 |
+
|
| 971 |
+
box = np.array([x1, y1, x2, y2])
|
| 972 |
+
|
| 973 |
+
elif prompt.type == "points":
|
| 974 |
+
# Ожидаем JSON в формате [[x, y, label], ...]
|
| 975 |
+
import json
|
| 976 |
+
points_data = json.loads(prompt.data)
|
| 977 |
+
|
| 978 |
+
points_list = []
|
| 979 |
+
labels_list = []
|
| 980 |
+
|
| 981 |
+
for point in points_data:
|
| 982 |
+
x, y = point[0], point[1]
|
| 983 |
+
label = point[2] if len(point) > 2 else 1
|
| 984 |
+
|
| 985 |
+
# Если нормализованные, конвертируем
|
| 986 |
+
if x <= 1.0 and y <= 1.0:
|
| 987 |
+
x *= image.shape[1]
|
| 988 |
+
y *= image.shape[0]
|
| 989 |
+
|
| 990 |
+
points_list.append([x, y])
|
| 991 |
+
labels_list.append(label)
|
| 992 |
+
|
| 993 |
+
points = np.array(points_list)
|
| 994 |
+
labels = np.array(labels_list)
|
| 995 |
+
|
| 996 |
+
# Делаем предсказание
|
| 997 |
+
if points is not None or box is not None:
|
| 998 |
+
# Решаем использовать ли multimask
|
| 999 |
+
use_multimask = True
|
| 1000 |
+
if points is not None and len(points) > 10:
|
| 1001 |
+
use_multimask = False
|
| 1002 |
+
|
| 1003 |
+
masks, scores, logits = predictor.predict(
|
| 1004 |
+
point_coords=points,
|
| 1005 |
+
point_labels=labels,
|
| 1006 |
+
box=box,
|
| 1007 |
+
multimask_output=use_multimask,
|
| 1008 |
+
)
|
| 1009 |
+
|
| 1010 |
+
# Если multimask, выбираем лучшую
|
| 1011 |
+
if use_multimask and len(masks) > 1:
|
| 1012 |
+
best_idx = np.argmax(scores)
|
| 1013 |
+
masks = masks[best_idx:best_idx+1]
|
| 1014 |
+
scores = scores[best_idx:best_idx+1]
|
| 1015 |
+
|
| 1016 |
+
# Берем первую маску
|
| 1017 |
+
mask = masks[0]
|
| 1018 |
+
score = float(scores[0])
|
| 1019 |
+
|
| 1020 |
+
# Очищаем маску если нужно
|
| 1021 |
+
if request.options.clean_masks:
|
| 1022 |
+
mask = clean_mask(mask, min_area=100)
|
| 1023 |
+
|
| 1024 |
+
# Вычисляем метрики
|
| 1025 |
+
y_coords, x_coords = np.where(mask > 0)
|
| 1026 |
+
|
| 1027 |
+
if len(x_coords) > 0:
|
| 1028 |
+
x_min, x_max = int(x_coords.min()), int(x_coords.max())
|
| 1029 |
+
y_min, y_max = int(y_coords.min()), int(y_coords.max())
|
| 1030 |
+
area = int(mask.sum())
|
| 1031 |
+
center_x = int(x_coords.mean())
|
| 1032 |
+
center_y = int(y_coords.mean())
|
| 1033 |
+
|
| 1034 |
+
# Формируем результат
|
| 1035 |
+
result = {
|
| 1036 |
+
"id": prompt.id,
|
| 1037 |
+
"label": prompt.label,
|
| 1038 |
+
"bbox": {
|
| 1039 |
+
"x_min": x_min,
|
| 1040 |
+
"y_min": y_min,
|
| 1041 |
+
"x_max": x_max,
|
| 1042 |
+
"y_max": y_max,
|
| 1043 |
+
"width": x_max - x_min,
|
| 1044 |
+
"height": y_max - y_min
|
| 1045 |
+
},
|
| 1046 |
+
"area": area,
|
| 1047 |
+
"center": {
|
| 1048 |
+
"x": center_x,
|
| 1049 |
+
"y": center_y
|
| 1050 |
+
},
|
| 1051 |
+
"confidence": score
|
| 1052 |
+
}
|
| 1053 |
+
|
| 1054 |
+
# Добавляем вырезанный объект если нужно
|
| 1055 |
+
if request.options.extract_objects:
|
| 1056 |
+
result["extracted_image"] = extract_object_image(
|
| 1057 |
+
image, mask, clean=request.options.clean_masks
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
# Добавляем контуры если нужно
|
| 1061 |
+
if request.options.include_masks:
|
| 1062 |
+
segments = masks_to_coords(masks, include_contours=True)
|
| 1063 |
+
if segments:
|
| 1064 |
+
result["contours"] = segments[0].get("contours", [])
|
| 1065 |
+
result["mask_rle"] = segments[0].get("mask_rle", {})
|
| 1066 |
+
|
| 1067 |
+
results.append(result)
|
| 1068 |
+
logger.info(f"✓ Промпт #{prompt.id} обработан, confidence: {score:.3f}")
|
| 1069 |
+
else:
|
| 1070 |
+
logger.warning(f"✗ Промпт #{prompt.id} не дал результата")
|
| 1071 |
+
else:
|
| 1072 |
+
logger.warning(f"✗ Промпт #{prompt.id}: нет данных для сегментации")
|
| 1073 |
+
|
| 1074 |
+
except Exception as e:
|
| 1075 |
+
logger.error(f"✗ Ошибка обработки промпта #{prompt.id}: {e}")
|
| 1076 |
+
# Продолжаем обработку остальных промптов
|
| 1077 |
+
continue
|
| 1078 |
+
|
| 1079 |
+
response = {
|
| 1080 |
+
"success": True,
|
| 1081 |
+
"image_size": {
|
| 1082 |
+
"width": int(image.shape[1]),
|
| 1083 |
+
"height": int(image.shape[0])
|
| 1084 |
+
},
|
| 1085 |
+
"results": results
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
logger.info(f"Батчинг завершен: обработано {len(results)} объектов")
|
| 1089 |
+
|
| 1090 |
+
# Сохраняем лог запроса для аудита (только метаданные, без изображений)
|
| 1091 |
+
try:
|
| 1092 |
+
request_dict = request.dict()
|
| 1093 |
+
save_batch_request_log(request_dict, response, image.shape[1], image.shape[0])
|
| 1094 |
+
except Exception as e:
|
| 1095 |
+
logger.warning(f"Не удалось сохранить лог запроса: {e}")
|
| 1096 |
+
|
| 1097 |
+
return convert_to_native_types(response)
|
| 1098 |
+
|
| 1099 |
+
except Exception as e:
|
| 1100 |
+
logger.error(f"Ошибка при батчинг сегментации: {e}")
|
| 1101 |
+
raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}")
|
| 1102 |
+
|
| 1103 |
+
if __name__ == "__main__":
|
| 1104 |
+
import uvicorn
|
| 1105 |
+
import os
|
| 1106 |
+
|
| 1107 |
+
# Порт из переменной окружения (для HF Spaces) или 8000 по умолчанию
|
| 1108 |
+
port = int(os.getenv("PORT", 8000))
|
| 1109 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
download_model.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Скрипт для скачивания модели SAM2.
|
| 4 |
+
Блин, Facebook не может нормально в pip packaging, поэтому качаем руками.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import urllib.request
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
# Директория для чекпоинтов
|
| 12 |
+
CHECKPOINT_DIR = "checkpoints"
|
| 13 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 14 |
+
|
| 15 |
+
# Модели на выбор
|
| 16 |
+
MODELS = {
|
| 17 |
+
"tiny": {
|
| 18 |
+
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
|
| 19 |
+
"filename": "sam2.1_hiera_tiny.pt",
|
| 20 |
+
"size": "~39MB"
|
| 21 |
+
},
|
| 22 |
+
"small": {
|
| 23 |
+
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
|
| 24 |
+
"filename": "sam2.1_hiera_small.pt",
|
| 25 |
+
"size": "~46MB"
|
| 26 |
+
},
|
| 27 |
+
"base_plus": {
|
| 28 |
+
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
|
| 29 |
+
"filename": "sam2.1_hiera_base_plus.pt",
|
| 30 |
+
"size": "~81MB"
|
| 31 |
+
},
|
| 32 |
+
"large": {
|
| 33 |
+
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
|
| 34 |
+
"filename": "sam2.1_hiera_large.pt",
|
| 35 |
+
"size": "~224MB"
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
def download_model(model_name="tiny"):
|
| 40 |
+
"""Качает модель, показывает прогресс"""
|
| 41 |
+
if model_name not in MODELS:
|
| 42 |
+
print(f"Неизвестная модель: {model_name}")
|
| 43 |
+
print(f"Доступные: {', '.join(MODELS.keys())}")
|
| 44 |
+
sys.exit(1)
|
| 45 |
+
|
| 46 |
+
model_info = MODELS[model_name]
|
| 47 |
+
filepath = os.path.join(CHECKPOINT_DIR, model_info["filename"])
|
| 48 |
+
|
| 49 |
+
if os.path.exists(filepath):
|
| 50 |
+
print(f"Модель уже скачана: {filepath}")
|
| 51 |
+
return filepath
|
| 52 |
+
|
| 53 |
+
print(f"Качаю {model_name} модель ({model_info['size']})...")
|
| 54 |
+
print(f"URL: {model_info['url']}")
|
| 55 |
+
|
| 56 |
+
def progress_hook(block_num, block_size, total_size):
|
| 57 |
+
downloaded = block_num * block_size
|
| 58 |
+
if total_size > 0:
|
| 59 |
+
percent = min(100, downloaded * 100 / total_size)
|
| 60 |
+
sys.stdout.write(f"\rПрогресс: {percent:.1f}%")
|
| 61 |
+
sys.stdout.flush()
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
urllib.request.urlretrieve(
|
| 65 |
+
model_info["url"],
|
| 66 |
+
filepath,
|
| 67 |
+
reporthook=progress_hook
|
| 68 |
+
)
|
| 69 |
+
print(f"\n✓ Модель скачана: {filepath}")
|
| 70 |
+
return filepath
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"\n✗ Ошибка при скачивании: {e}")
|
| 73 |
+
if os.path.exists(filepath):
|
| 74 |
+
os.remove(filepath)
|
| 75 |
+
sys.exit(1)
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
model_name = sys.argv[1] if len(sys.argv) > 1 else "tiny"
|
| 79 |
+
download_model(model_name)
|
| 80 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn[standard]==0.32.0
|
| 3 |
+
python-multipart==0.0.12
|
| 4 |
+
Pillow==11.0.0
|
| 5 |
+
numpy==2.1.0
|
| 6 |
+
torch==2.6.0
|
| 7 |
+
torchvision==0.21.0
|
| 8 |
+
opencv-python==4.10.0.84
|
| 9 |
+
pydantic==2.9.0
|
| 10 |
+
|
sam2_repo/README.md
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAM 2: Segment Anything in Images and Videos
|
| 2 |
+
|
| 3 |
+
**[AI at Meta, FAIR](https://ai.meta.com/research/)**
|
| 4 |
+
|
| 5 |
+
[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
|
| 6 |
+
|
| 7 |
+
[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
## Latest updates
|
| 16 |
+
|
| 17 |
+
**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking**
|
| 18 |
+
|
| 19 |
+
- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference.
|
| 20 |
+
- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts.
|
| 21 |
+
- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details.
|
| 22 |
+
|
| 23 |
+
**09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released**
|
| 24 |
+
|
| 25 |
+
- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
|
| 26 |
+
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
|
| 27 |
+
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
|
| 28 |
+
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
|
| 29 |
+
|
| 30 |
+
## Installation
|
| 31 |
+
|
| 32 |
+
SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
git clone https://github.com/facebookresearch/sam2.git && cd sam2
|
| 36 |
+
|
| 37 |
+
pip install -e .
|
| 38 |
+
```
|
| 39 |
+
If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
| 40 |
+
|
| 41 |
+
To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
pip install -e ".[notebooks]"
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Note:
|
| 48 |
+
1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
|
| 49 |
+
2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
|
| 50 |
+
3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).
|
| 51 |
+
|
| 52 |
+
Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
|
| 53 |
+
|
| 54 |
+
## Getting Started
|
| 55 |
+
|
| 56 |
+
### Download Checkpoints
|
| 57 |
+
|
| 58 |
+
First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
cd checkpoints && \
|
| 62 |
+
./download_ckpts.sh && \
|
| 63 |
+
cd ..
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
or individually from:
|
| 67 |
+
|
| 68 |
+
- [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)
|
| 69 |
+
- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
|
| 70 |
+
- [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)
|
| 71 |
+
- [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)
|
| 72 |
+
|
| 73 |
+
(note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.)
|
| 74 |
+
|
| 75 |
+
Then SAM 2 can be used in a few lines as follows for image and video prediction.
|
| 76 |
+
|
| 77 |
+
### Image prediction
|
| 78 |
+
|
| 79 |
+
SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
import torch
|
| 83 |
+
from sam2.build_sam import build_sam2
|
| 84 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 85 |
+
|
| 86 |
+
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
| 87 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 88 |
+
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
|
| 89 |
+
|
| 90 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 91 |
+
predictor.set_image(<your_image>)
|
| 92 |
+
masks, _, _ = predictor.predict(<input_prompts>)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.
|
| 96 |
+
|
| 97 |
+
SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.
|
| 98 |
+
|
| 99 |
+
### Video prediction
|
| 100 |
+
|
| 101 |
+
For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
import torch
|
| 105 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 106 |
+
|
| 107 |
+
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
| 108 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 109 |
+
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
|
| 110 |
+
|
| 111 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 112 |
+
state = predictor.init_state(<your_video>)
|
| 113 |
+
|
| 114 |
+
# add new prompts and instantly get the output on the same frame
|
| 115 |
+
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
| 116 |
+
|
| 117 |
+
# propagate the prompts to get masklets throughout the video
|
| 118 |
+
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
| 119 |
+
...
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
|
| 123 |
+
|
| 124 |
+
## Load from 🤗 Hugging Face
|
| 125 |
+
|
| 126 |
+
Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
|
| 127 |
+
|
| 128 |
+
For image prediction:
|
| 129 |
+
|
| 130 |
+
```python
|
| 131 |
+
import torch
|
| 132 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 133 |
+
|
| 134 |
+
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
| 135 |
+
|
| 136 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 137 |
+
predictor.set_image(<your_image>)
|
| 138 |
+
masks, _, _ = predictor.predict(<input_prompts>)
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
For video prediction:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
import torch
|
| 145 |
+
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
| 146 |
+
|
| 147 |
+
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
|
| 148 |
+
|
| 149 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 150 |
+
state = predictor.init_state(<your_video>)
|
| 151 |
+
|
| 152 |
+
# add new prompts and instantly get the output on the same frame
|
| 153 |
+
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
| 154 |
+
|
| 155 |
+
# propagate the prompts to get masklets throughout the video
|
| 156 |
+
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
| 157 |
+
...
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Model Description
|
| 161 |
+
|
| 162 |
+
### SAM 2.1 checkpoints
|
| 163 |
+
|
| 164 |
+
The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
|
| 165 |
+
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
|
| 166 |
+
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
|
| 167 |
+
| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
|
| 168 |
+
| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
|
| 169 |
+
| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
|
| 170 |
+
| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
|
| 171 |
+
|
| 172 |
+
### SAM 2 checkpoints
|
| 173 |
+
|
| 174 |
+
The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:
|
| 175 |
+
|
| 176 |
+
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
|
| 177 |
+
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
|
| 178 |
+
| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
|
| 179 |
+
| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
|
| 180 |
+
| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
|
| 181 |
+
| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
|
| 182 |
+
|
| 183 |
+
Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
|
| 184 |
+
## Segment Anything Video Dataset
|
| 185 |
+
|
| 186 |
+
See [sav_dataset/README.md](sav_dataset/README.md) for details.
|
| 187 |
+
|
| 188 |
+
## Training SAM 2
|
| 189 |
+
|
| 190 |
+
You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started.
|
| 191 |
+
|
| 192 |
+
## Web demo for SAM 2
|
| 193 |
+
|
| 194 |
+
We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details.
|
| 195 |
+
|
| 196 |
+
## License
|
| 197 |
+
|
| 198 |
+
The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/).
|
| 199 |
+
|
| 200 |
+
## Contributing
|
| 201 |
+
|
| 202 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 203 |
+
|
| 204 |
+
## Contributors
|
| 205 |
+
|
| 206 |
+
The SAM 2 project was made possible with the help of many contributors (alphabetical):
|
| 207 |
+
|
| 208 |
+
Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
|
| 209 |
+
|
| 210 |
+
Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
|
| 211 |
+
|
| 212 |
+
## Citing SAM 2
|
| 213 |
+
|
| 214 |
+
If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
|
| 215 |
+
|
| 216 |
+
```bibtex
|
| 217 |
+
@article{ravi2024sam2,
|
| 218 |
+
title={SAM 2: Segment Anything in Images and Videos},
|
| 219 |
+
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
|
| 220 |
+
journal={arXiv preprint arXiv:2408.00714},
|
| 221 |
+
url={https://arxiv.org/abs/2408.00714},
|
| 222 |
+
year={2024}
|
| 223 |
+
}
|
| 224 |
+
```
|
sam2_repo/checkpoints/download_ckpts.sh
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
|
| 6 |
+
# This source code is licensed under the license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
# Use either wget or curl to download the checkpoints
|
| 10 |
+
if command -v wget &> /dev/null; then
|
| 11 |
+
CMD="wget"
|
| 12 |
+
elif command -v curl &> /dev/null; then
|
| 13 |
+
CMD="curl -L -O"
|
| 14 |
+
else
|
| 15 |
+
echo "Please install wget or curl to download the checkpoints."
|
| 16 |
+
exit 1
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
# Define the URLs for SAM 2 checkpoints
|
| 20 |
+
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
|
| 21 |
+
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
|
| 22 |
+
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
|
| 23 |
+
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
|
| 24 |
+
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
|
| 25 |
+
|
| 26 |
+
# Download each of the four checkpoints using wget
|
| 27 |
+
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
|
| 28 |
+
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
|
| 29 |
+
|
| 30 |
+
# echo "Downloading sam2_hiera_small.pt checkpoint..."
|
| 31 |
+
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
|
| 32 |
+
|
| 33 |
+
# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
|
| 34 |
+
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
|
| 35 |
+
|
| 36 |
+
# echo "Downloading sam2_hiera_large.pt checkpoint..."
|
| 37 |
+
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
|
| 38 |
+
|
| 39 |
+
# Define the URLs for SAM 2.1 checkpoints
|
| 40 |
+
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
|
| 41 |
+
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
|
| 42 |
+
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
|
| 43 |
+
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
|
| 44 |
+
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
|
| 45 |
+
|
| 46 |
+
# SAM 2.1 checkpoints
|
| 47 |
+
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
|
| 48 |
+
$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
|
| 49 |
+
|
| 50 |
+
echo "Downloading sam2.1_hiera_small.pt checkpoint..."
|
| 51 |
+
$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
|
| 52 |
+
|
| 53 |
+
echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
|
| 54 |
+
$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
|
| 55 |
+
|
| 56 |
+
echo "Downloading sam2.1_hiera_large.pt checkpoint..."
|
| 57 |
+
$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
|
| 58 |
+
|
| 59 |
+
echo "All checkpoints are downloaded successfully."
|
sam2_repo/pyproject.toml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = [
|
| 3 |
+
"setuptools>=61.0",
|
| 4 |
+
"torch>=2.5.1",
|
| 5 |
+
]
|
| 6 |
+
build-backend = "setuptools.build_meta"
|
sam2_repo/sam2/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from hydra import initialize_config_module
|
| 8 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 9 |
+
|
| 10 |
+
if not GlobalHydra.instance().is_initialized():
|
| 11 |
+
initialize_config_module("sam2", version_base="1.2")
|
sam2_repo/sam2/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (455 Bytes). View file
|
|
|
sam2_repo/sam2/__pycache__/build_sam.cpython-313.pyc
ADDED
|
Binary file (5.39 kB). View file
|
|
|
sam2_repo/sam2/__pycache__/sam2_image_predictor.cpython-313.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
sam2_repo/sam2/automatic_mask_generator.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_base import SAM2Base
|
| 15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 16 |
+
from sam2.utils.amg import (
|
| 17 |
+
area_from_rle,
|
| 18 |
+
batch_iterator,
|
| 19 |
+
batched_mask_to_box,
|
| 20 |
+
box_xyxy_to_xywh,
|
| 21 |
+
build_all_layer_point_grids,
|
| 22 |
+
calculate_stability_score,
|
| 23 |
+
coco_encode_rle,
|
| 24 |
+
generate_crop_boxes,
|
| 25 |
+
is_box_near_crop_edge,
|
| 26 |
+
mask_to_rle_pytorch,
|
| 27 |
+
MaskData,
|
| 28 |
+
remove_small_regions,
|
| 29 |
+
rle_to_mask,
|
| 30 |
+
uncrop_boxes_xyxy,
|
| 31 |
+
uncrop_masks,
|
| 32 |
+
uncrop_points,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SAM2AutomaticMaskGenerator:
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
model: SAM2Base,
|
| 40 |
+
points_per_side: Optional[int] = 32,
|
| 41 |
+
points_per_batch: int = 64,
|
| 42 |
+
pred_iou_thresh: float = 0.8,
|
| 43 |
+
stability_score_thresh: float = 0.95,
|
| 44 |
+
stability_score_offset: float = 1.0,
|
| 45 |
+
mask_threshold: float = 0.0,
|
| 46 |
+
box_nms_thresh: float = 0.7,
|
| 47 |
+
crop_n_layers: int = 0,
|
| 48 |
+
crop_nms_thresh: float = 0.7,
|
| 49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
| 50 |
+
crop_n_points_downscale_factor: int = 1,
|
| 51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
| 52 |
+
min_mask_region_area: int = 0,
|
| 53 |
+
output_mode: str = "binary_mask",
|
| 54 |
+
use_m2m: bool = False,
|
| 55 |
+
multimask_output: bool = True,
|
| 56 |
+
**kwargs,
|
| 57 |
+
) -> None:
|
| 58 |
+
"""
|
| 59 |
+
Using a SAM 2 model, generates masks for the entire image.
|
| 60 |
+
Generates a grid of point prompts over the image, then filters
|
| 61 |
+
low quality and duplicate masks. The default settings are chosen
|
| 62 |
+
for SAM 2 with a HieraL backbone.
|
| 63 |
+
|
| 64 |
+
Arguments:
|
| 65 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
| 66 |
+
points_per_side (int or None): The number of points to be sampled
|
| 67 |
+
along one side of the image. The total number of points is
|
| 68 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
| 69 |
+
point sampling.
|
| 70 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
| 71 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
| 72 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
| 73 |
+
model's predicted mask quality.
|
| 74 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
| 75 |
+
the stability of the mask under changes to the cutoff used to binarize
|
| 76 |
+
the model's mask predictions.
|
| 77 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
| 78 |
+
calculated the stability score.
|
| 79 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
| 80 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 81 |
+
suppression to filter duplicate masks.
|
| 82 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
| 83 |
+
crops of the image. Sets the number of layers to run, where each
|
| 84 |
+
layer has 2**i_layer number of image crops.
|
| 85 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 86 |
+
suppression to filter duplicate masks between different crops.
|
| 87 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
| 88 |
+
In the first crop layer, crops will overlap by this fraction of
|
| 89 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 90 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
| 91 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 92 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
| 93 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
| 94 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
| 95 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
| 96 |
+
to remove disconnected regions and holes in masks with area smaller
|
| 97 |
+
than min_mask_region_area. Requires opencv.
|
| 98 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
| 99 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
| 100 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
| 101 |
+
memory.
|
| 102 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
| 103 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
assert (points_per_side is None) != (
|
| 107 |
+
point_grids is None
|
| 108 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
| 109 |
+
if points_per_side is not None:
|
| 110 |
+
self.point_grids = build_all_layer_point_grids(
|
| 111 |
+
points_per_side,
|
| 112 |
+
crop_n_layers,
|
| 113 |
+
crop_n_points_downscale_factor,
|
| 114 |
+
)
|
| 115 |
+
elif point_grids is not None:
|
| 116 |
+
self.point_grids = point_grids
|
| 117 |
+
else:
|
| 118 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
| 119 |
+
|
| 120 |
+
assert output_mode in [
|
| 121 |
+
"binary_mask",
|
| 122 |
+
"uncompressed_rle",
|
| 123 |
+
"coco_rle",
|
| 124 |
+
], f"Unknown output_mode {output_mode}."
|
| 125 |
+
if output_mode == "coco_rle":
|
| 126 |
+
try:
|
| 127 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
| 128 |
+
except ImportError as e:
|
| 129 |
+
print("Please install pycocotools")
|
| 130 |
+
raise e
|
| 131 |
+
|
| 132 |
+
self.predictor = SAM2ImagePredictor(
|
| 133 |
+
model,
|
| 134 |
+
max_hole_area=min_mask_region_area,
|
| 135 |
+
max_sprinkle_area=min_mask_region_area,
|
| 136 |
+
)
|
| 137 |
+
self.points_per_batch = points_per_batch
|
| 138 |
+
self.pred_iou_thresh = pred_iou_thresh
|
| 139 |
+
self.stability_score_thresh = stability_score_thresh
|
| 140 |
+
self.stability_score_offset = stability_score_offset
|
| 141 |
+
self.mask_threshold = mask_threshold
|
| 142 |
+
self.box_nms_thresh = box_nms_thresh
|
| 143 |
+
self.crop_n_layers = crop_n_layers
|
| 144 |
+
self.crop_nms_thresh = crop_nms_thresh
|
| 145 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
| 146 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
| 147 |
+
self.min_mask_region_area = min_mask_region_area
|
| 148 |
+
self.output_mode = output_mode
|
| 149 |
+
self.use_m2m = use_m2m
|
| 150 |
+
self.multimask_output = multimask_output
|
| 151 |
+
|
| 152 |
+
@classmethod
|
| 153 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
| 154 |
+
"""
|
| 155 |
+
Load a pretrained model from the Hugging Face hub.
|
| 156 |
+
|
| 157 |
+
Arguments:
|
| 158 |
+
model_id (str): The Hugging Face repository ID.
|
| 159 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
(SAM2AutomaticMaskGenerator): The loaded model.
|
| 163 |
+
"""
|
| 164 |
+
from sam2.build_sam import build_sam2_hf
|
| 165 |
+
|
| 166 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 167 |
+
return cls(sam_model, **kwargs)
|
| 168 |
+
|
| 169 |
+
@torch.no_grad()
|
| 170 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
| 171 |
+
"""
|
| 172 |
+
Generates masks for the given image.
|
| 173 |
+
|
| 174 |
+
Arguments:
|
| 175 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
| 179 |
+
a dict containing the following keys:
|
| 180 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
| 181 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
| 182 |
+
is a dictionary containing the RLE.
|
| 183 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
| 184 |
+
area (int): The area in pixels of the mask.
|
| 185 |
+
predicted_iou (float): The model's own prediction of the mask's
|
| 186 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
| 187 |
+
point_coords (list(list(float))): The point coordinates input
|
| 188 |
+
to the model to generate this mask.
|
| 189 |
+
stability_score (float): A measure of the mask's quality. This
|
| 190 |
+
is filtered on using the stability_score_thresh parameter.
|
| 191 |
+
crop_box (list(float)): The crop of the image used to generate
|
| 192 |
+
the mask, given in XYWH format.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
# Generate masks
|
| 196 |
+
mask_data = self._generate_masks(image)
|
| 197 |
+
|
| 198 |
+
# Encode masks
|
| 199 |
+
if self.output_mode == "coco_rle":
|
| 200 |
+
mask_data["segmentations"] = [
|
| 201 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
| 202 |
+
]
|
| 203 |
+
elif self.output_mode == "binary_mask":
|
| 204 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 205 |
+
else:
|
| 206 |
+
mask_data["segmentations"] = mask_data["rles"]
|
| 207 |
+
|
| 208 |
+
# Write mask records
|
| 209 |
+
curr_anns = []
|
| 210 |
+
for idx in range(len(mask_data["segmentations"])):
|
| 211 |
+
ann = {
|
| 212 |
+
"segmentation": mask_data["segmentations"][idx],
|
| 213 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
| 214 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
| 215 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
| 216 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
| 217 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
| 218 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
| 219 |
+
}
|
| 220 |
+
curr_anns.append(ann)
|
| 221 |
+
|
| 222 |
+
return curr_anns
|
| 223 |
+
|
| 224 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
| 225 |
+
orig_size = image.shape[:2]
|
| 226 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 227 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Iterate over image crops
|
| 231 |
+
data = MaskData()
|
| 232 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
| 233 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
| 234 |
+
data.cat(crop_data)
|
| 235 |
+
|
| 236 |
+
# Remove duplicate masks between crops
|
| 237 |
+
if len(crop_boxes) > 1:
|
| 238 |
+
# Prefer masks from smaller crops
|
| 239 |
+
scores = 1 / box_area(data["crop_boxes"])
|
| 240 |
+
scores = scores.to(data["boxes"].device)
|
| 241 |
+
keep_by_nms = batched_nms(
|
| 242 |
+
data["boxes"].float(),
|
| 243 |
+
scores,
|
| 244 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 245 |
+
iou_threshold=self.crop_nms_thresh,
|
| 246 |
+
)
|
| 247 |
+
data.filter(keep_by_nms)
|
| 248 |
+
data.to_numpy()
|
| 249 |
+
return data
|
| 250 |
+
|
| 251 |
+
def _process_crop(
|
| 252 |
+
self,
|
| 253 |
+
image: np.ndarray,
|
| 254 |
+
crop_box: List[int],
|
| 255 |
+
crop_layer_idx: int,
|
| 256 |
+
orig_size: Tuple[int, ...],
|
| 257 |
+
) -> MaskData:
|
| 258 |
+
# Crop the image and calculate embeddings
|
| 259 |
+
x0, y0, x1, y1 = crop_box
|
| 260 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
| 261 |
+
cropped_im_size = cropped_im.shape[:2]
|
| 262 |
+
self.predictor.set_image(cropped_im)
|
| 263 |
+
|
| 264 |
+
# Get points for this crop
|
| 265 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 266 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
| 267 |
+
|
| 268 |
+
# Generate masks for this crop in batches
|
| 269 |
+
data = MaskData()
|
| 270 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 271 |
+
batch_data = self._process_batch(
|
| 272 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True
|
| 273 |
+
)
|
| 274 |
+
data.cat(batch_data)
|
| 275 |
+
del batch_data
|
| 276 |
+
self.predictor.reset_predictor()
|
| 277 |
+
|
| 278 |
+
# Remove duplicates within this crop.
|
| 279 |
+
keep_by_nms = batched_nms(
|
| 280 |
+
data["boxes"].float(),
|
| 281 |
+
data["iou_preds"],
|
| 282 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 283 |
+
iou_threshold=self.box_nms_thresh,
|
| 284 |
+
)
|
| 285 |
+
data.filter(keep_by_nms)
|
| 286 |
+
|
| 287 |
+
# Return to the original image frame
|
| 288 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
| 289 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
| 290 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
| 291 |
+
|
| 292 |
+
return data
|
| 293 |
+
|
| 294 |
+
def _process_batch(
|
| 295 |
+
self,
|
| 296 |
+
points: np.ndarray,
|
| 297 |
+
im_size: Tuple[int, ...],
|
| 298 |
+
crop_box: List[int],
|
| 299 |
+
orig_size: Tuple[int, ...],
|
| 300 |
+
normalize=False,
|
| 301 |
+
) -> MaskData:
|
| 302 |
+
orig_h, orig_w = orig_size
|
| 303 |
+
|
| 304 |
+
# Run model on this batch
|
| 305 |
+
points = torch.as_tensor(
|
| 306 |
+
points, dtype=torch.float32, device=self.predictor.device
|
| 307 |
+
)
|
| 308 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 309 |
+
points, normalize=normalize, orig_hw=im_size
|
| 310 |
+
)
|
| 311 |
+
in_labels = torch.ones(
|
| 312 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 313 |
+
)
|
| 314 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
| 315 |
+
in_points[:, None, :],
|
| 316 |
+
in_labels[:, None],
|
| 317 |
+
multimask_output=self.multimask_output,
|
| 318 |
+
return_logits=True,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Serialize predictions and store in MaskData
|
| 322 |
+
data = MaskData(
|
| 323 |
+
masks=masks.flatten(0, 1),
|
| 324 |
+
iou_preds=iou_preds.flatten(0, 1),
|
| 325 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
| 326 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
| 327 |
+
)
|
| 328 |
+
del masks
|
| 329 |
+
|
| 330 |
+
if not self.use_m2m:
|
| 331 |
+
# Filter by predicted IoU
|
| 332 |
+
if self.pred_iou_thresh > 0.0:
|
| 333 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 334 |
+
data.filter(keep_mask)
|
| 335 |
+
|
| 336 |
+
# Calculate and filter by stability score
|
| 337 |
+
data["stability_score"] = calculate_stability_score(
|
| 338 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 339 |
+
)
|
| 340 |
+
if self.stability_score_thresh > 0.0:
|
| 341 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 342 |
+
data.filter(keep_mask)
|
| 343 |
+
else:
|
| 344 |
+
# One step refinement using previous mask predictions
|
| 345 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 346 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
| 347 |
+
)
|
| 348 |
+
labels = torch.ones(
|
| 349 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 350 |
+
)
|
| 351 |
+
masks, ious = self.refine_with_m2m(
|
| 352 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch
|
| 353 |
+
)
|
| 354 |
+
data["masks"] = masks.squeeze(1)
|
| 355 |
+
data["iou_preds"] = ious.squeeze(1)
|
| 356 |
+
|
| 357 |
+
if self.pred_iou_thresh > 0.0:
|
| 358 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 359 |
+
data.filter(keep_mask)
|
| 360 |
+
|
| 361 |
+
data["stability_score"] = calculate_stability_score(
|
| 362 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 363 |
+
)
|
| 364 |
+
if self.stability_score_thresh > 0.0:
|
| 365 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 366 |
+
data.filter(keep_mask)
|
| 367 |
+
|
| 368 |
+
# Threshold masks and calculate boxes
|
| 369 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
| 370 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 371 |
+
|
| 372 |
+
# Filter boxes that touch crop boundaries
|
| 373 |
+
keep_mask = ~is_box_near_crop_edge(
|
| 374 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
| 375 |
+
)
|
| 376 |
+
if not torch.all(keep_mask):
|
| 377 |
+
data.filter(keep_mask)
|
| 378 |
+
|
| 379 |
+
# Compress to RLE
|
| 380 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
| 381 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
| 382 |
+
del data["masks"]
|
| 383 |
+
|
| 384 |
+
return data
|
| 385 |
+
|
| 386 |
+
@staticmethod
|
| 387 |
+
def postprocess_small_regions(
|
| 388 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 389 |
+
) -> MaskData:
|
| 390 |
+
"""
|
| 391 |
+
Removes small disconnected regions and holes in masks, then reruns
|
| 392 |
+
box NMS to remove any new duplicates.
|
| 393 |
+
|
| 394 |
+
Edits mask_data in place.
|
| 395 |
+
|
| 396 |
+
Requires open-cv as a dependency.
|
| 397 |
+
"""
|
| 398 |
+
if len(mask_data["rles"]) == 0:
|
| 399 |
+
return mask_data
|
| 400 |
+
|
| 401 |
+
# Filter small disconnected regions and holes
|
| 402 |
+
new_masks = []
|
| 403 |
+
scores = []
|
| 404 |
+
for rle in mask_data["rles"]:
|
| 405 |
+
mask = rle_to_mask(rle)
|
| 406 |
+
|
| 407 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
| 408 |
+
unchanged = not changed
|
| 409 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
| 410 |
+
unchanged = unchanged and not changed
|
| 411 |
+
|
| 412 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
| 413 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
| 414 |
+
# so NMS will prefer ones that didn't need postprocessing
|
| 415 |
+
scores.append(float(unchanged))
|
| 416 |
+
|
| 417 |
+
# Recalculate boxes and remove any new duplicates
|
| 418 |
+
masks = torch.cat(new_masks, dim=0)
|
| 419 |
+
boxes = batched_mask_to_box(masks)
|
| 420 |
+
keep_by_nms = batched_nms(
|
| 421 |
+
boxes.float(),
|
| 422 |
+
torch.as_tensor(scores),
|
| 423 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
| 424 |
+
iou_threshold=nms_thresh,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# Only recalculate RLEs for masks that have changed
|
| 428 |
+
for i_mask in keep_by_nms:
|
| 429 |
+
if scores[i_mask] == 0.0:
|
| 430 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
| 431 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
| 432 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
| 433 |
+
mask_data.filter(keep_by_nms)
|
| 434 |
+
|
| 435 |
+
return mask_data
|
| 436 |
+
|
| 437 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
|
| 438 |
+
new_masks = []
|
| 439 |
+
new_iou_preds = []
|
| 440 |
+
|
| 441 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
| 442 |
+
points_per_batch, points, point_labels, low_res_masks
|
| 443 |
+
):
|
| 444 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
| 445 |
+
cur_points[:, None, :],
|
| 446 |
+
cur_point_labels[:, None],
|
| 447 |
+
mask_input=low_res_mask[:, None, :],
|
| 448 |
+
multimask_output=False,
|
| 449 |
+
return_logits=True,
|
| 450 |
+
)
|
| 451 |
+
new_masks.append(best_masks)
|
| 452 |
+
new_iou_preds.append(best_iou_preds)
|
| 453 |
+
masks = torch.cat(new_masks, dim=0)
|
| 454 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
sam2_repo/sam2/benchmark.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 15 |
+
|
| 16 |
+
# Only cuda supported
|
| 17 |
+
assert torch.cuda.is_available()
|
| 18 |
+
device = torch.device("cuda")
|
| 19 |
+
|
| 20 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 21 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
| 22 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
| 23 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 24 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 25 |
+
|
| 26 |
+
# Config and checkpoint
|
| 27 |
+
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
|
| 28 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
|
| 29 |
+
|
| 30 |
+
# Build video predictor with vos_optimized=True setting
|
| 31 |
+
predictor = build_sam2_video_predictor(
|
| 32 |
+
model_cfg, sam2_checkpoint, device=device, vos_optimized=True
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Initialize with video
|
| 37 |
+
video_dir = "notebooks/videos/bedroom"
|
| 38 |
+
# scan all the JPEG frame names in this directory
|
| 39 |
+
frame_names = [
|
| 40 |
+
p
|
| 41 |
+
for p in os.listdir(video_dir)
|
| 42 |
+
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
| 43 |
+
]
|
| 44 |
+
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
| 45 |
+
inference_state = predictor.init_state(video_path=video_dir)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Number of runs, warmup etc
|
| 49 |
+
warm_up, runs = 5, 25
|
| 50 |
+
verbose = True
|
| 51 |
+
num_frames = len(frame_names)
|
| 52 |
+
total, count = 0, 0
|
| 53 |
+
torch.cuda.empty_cache()
|
| 54 |
+
|
| 55 |
+
# We will select an object with a click.
|
| 56 |
+
# See video_predictor_example.ipynb for more detailed explanation
|
| 57 |
+
ann_frame_idx, ann_obj_id = 0, 1
|
| 58 |
+
# Add a positive click at (x, y) = (210, 350)
|
| 59 |
+
# For labels, `1` means positive click
|
| 60 |
+
points = np.array([[210, 350]], dtype=np.float32)
|
| 61 |
+
labels = np.array([1], np.int32)
|
| 62 |
+
|
| 63 |
+
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
|
| 64 |
+
inference_state=inference_state,
|
| 65 |
+
frame_idx=ann_frame_idx,
|
| 66 |
+
obj_id=ann_obj_id,
|
| 67 |
+
points=points,
|
| 68 |
+
labels=labels,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Warmup and then average FPS over several runs
|
| 72 |
+
with torch.autocast("cuda", torch.bfloat16):
|
| 73 |
+
with torch.inference_mode():
|
| 74 |
+
for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
|
| 75 |
+
start = time.time()
|
| 76 |
+
# Start tracking
|
| 77 |
+
for (
|
| 78 |
+
out_frame_idx,
|
| 79 |
+
out_obj_ids,
|
| 80 |
+
out_mask_logits,
|
| 81 |
+
) in predictor.propagate_in_video(inference_state):
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
end = time.time()
|
| 85 |
+
total += end - start
|
| 86 |
+
count += 1
|
| 87 |
+
if i == warm_up - 1:
|
| 88 |
+
print("Warmup FPS: ", count * num_frames / total)
|
| 89 |
+
total = 0
|
| 90 |
+
count = 0
|
| 91 |
+
|
| 92 |
+
print("FPS: ", count * num_frames / total)
|
sam2_repo/sam2/build_sam.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from hydra import compose
|
| 12 |
+
from hydra.utils import instantiate
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
|
| 15 |
+
import sam2
|
| 16 |
+
|
| 17 |
+
# Check if the user is running Python from the parent directory of the sam2 repo
|
| 18 |
+
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
| 19 |
+
# it could shadow the sam2 package and cause issues.
|
| 20 |
+
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
| 21 |
+
# If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
| 22 |
+
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
| 23 |
+
# This typically happens because the user is running Python from the parent directory
|
| 24 |
+
# that contains the sam2 repo they cloned.
|
| 25 |
+
raise RuntimeError(
|
| 26 |
+
"You're likely running Python from the parent directory of the sam2 repository "
|
| 27 |
+
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
| 28 |
+
"This is not supported since the `sam2` Python package could be shadowed by the "
|
| 29 |
+
"repository name (the repository is also named `sam2` and contains the Python package "
|
| 30 |
+
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
| 31 |
+
"rather than its parent dir, or from your home directory) after installing SAM 2."
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
HF_MODEL_ID_TO_FILENAMES = {
|
| 36 |
+
"facebook/sam2-hiera-tiny": (
|
| 37 |
+
"configs/sam2/sam2_hiera_t.yaml",
|
| 38 |
+
"sam2_hiera_tiny.pt",
|
| 39 |
+
),
|
| 40 |
+
"facebook/sam2-hiera-small": (
|
| 41 |
+
"configs/sam2/sam2_hiera_s.yaml",
|
| 42 |
+
"sam2_hiera_small.pt",
|
| 43 |
+
),
|
| 44 |
+
"facebook/sam2-hiera-base-plus": (
|
| 45 |
+
"configs/sam2/sam2_hiera_b+.yaml",
|
| 46 |
+
"sam2_hiera_base_plus.pt",
|
| 47 |
+
),
|
| 48 |
+
"facebook/sam2-hiera-large": (
|
| 49 |
+
"configs/sam2/sam2_hiera_l.yaml",
|
| 50 |
+
"sam2_hiera_large.pt",
|
| 51 |
+
),
|
| 52 |
+
"facebook/sam2.1-hiera-tiny": (
|
| 53 |
+
"configs/sam2.1/sam2.1_hiera_t.yaml",
|
| 54 |
+
"sam2.1_hiera_tiny.pt",
|
| 55 |
+
),
|
| 56 |
+
"facebook/sam2.1-hiera-small": (
|
| 57 |
+
"configs/sam2.1/sam2.1_hiera_s.yaml",
|
| 58 |
+
"sam2.1_hiera_small.pt",
|
| 59 |
+
),
|
| 60 |
+
"facebook/sam2.1-hiera-base-plus": (
|
| 61 |
+
"configs/sam2.1/sam2.1_hiera_b+.yaml",
|
| 62 |
+
"sam2.1_hiera_base_plus.pt",
|
| 63 |
+
),
|
| 64 |
+
"facebook/sam2.1-hiera-large": (
|
| 65 |
+
"configs/sam2.1/sam2.1_hiera_l.yaml",
|
| 66 |
+
"sam2.1_hiera_large.pt",
|
| 67 |
+
),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_sam2(
|
| 72 |
+
config_file,
|
| 73 |
+
ckpt_path=None,
|
| 74 |
+
device="cuda",
|
| 75 |
+
mode="eval",
|
| 76 |
+
hydra_overrides_extra=[],
|
| 77 |
+
apply_postprocessing=True,
|
| 78 |
+
**kwargs,
|
| 79 |
+
):
|
| 80 |
+
|
| 81 |
+
if apply_postprocessing:
|
| 82 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 83 |
+
hydra_overrides_extra += [
|
| 84 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 85 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 86 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 87 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 88 |
+
]
|
| 89 |
+
# Read config and init model
|
| 90 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
| 91 |
+
OmegaConf.resolve(cfg)
|
| 92 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 93 |
+
_load_checkpoint(model, ckpt_path)
|
| 94 |
+
model = model.to(device)
|
| 95 |
+
if mode == "eval":
|
| 96 |
+
model.eval()
|
| 97 |
+
return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_sam2_video_predictor(
|
| 101 |
+
config_file,
|
| 102 |
+
ckpt_path=None,
|
| 103 |
+
device="cuda",
|
| 104 |
+
mode="eval",
|
| 105 |
+
hydra_overrides_extra=[],
|
| 106 |
+
apply_postprocessing=True,
|
| 107 |
+
vos_optimized=False,
|
| 108 |
+
**kwargs,
|
| 109 |
+
):
|
| 110 |
+
hydra_overrides = [
|
| 111 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
| 112 |
+
]
|
| 113 |
+
if vos_optimized:
|
| 114 |
+
hydra_overrides = [
|
| 115 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
|
| 116 |
+
"++model.compile_image_encoder=True", # Let sam2_base handle this
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
if apply_postprocessing:
|
| 120 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 121 |
+
hydra_overrides_extra += [
|
| 122 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 123 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 124 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 125 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 126 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
| 127 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
| 128 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
| 129 |
+
"++model.fill_hole_area=8",
|
| 130 |
+
]
|
| 131 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
| 132 |
+
|
| 133 |
+
# Read config and init model
|
| 134 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
| 135 |
+
OmegaConf.resolve(cfg)
|
| 136 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 137 |
+
_load_checkpoint(model, ckpt_path)
|
| 138 |
+
model = model.to(device)
|
| 139 |
+
if mode == "eval":
|
| 140 |
+
model.eval()
|
| 141 |
+
return model
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _hf_download(model_id):
|
| 145 |
+
from huggingface_hub import hf_hub_download
|
| 146 |
+
|
| 147 |
+
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
|
| 148 |
+
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
| 149 |
+
return config_name, ckpt_path
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def build_sam2_hf(model_id, **kwargs):
|
| 153 |
+
config_name, ckpt_path = _hf_download(model_id)
|
| 154 |
+
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
| 158 |
+
config_name, ckpt_path = _hf_download(model_id)
|
| 159 |
+
return build_sam2_video_predictor(
|
| 160 |
+
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _load_checkpoint(model, ckpt_path):
|
| 165 |
+
if ckpt_path is not None:
|
| 166 |
+
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
| 167 |
+
missing_keys, unexpected_keys = model.load_state_dict(sd)
|
| 168 |
+
if missing_keys:
|
| 169 |
+
logging.error(missing_keys)
|
| 170 |
+
raise RuntimeError()
|
| 171 |
+
if unexpected_keys:
|
| 172 |
+
logging.error(unexpected_keys)
|
| 173 |
+
raise RuntimeError()
|
| 174 |
+
logging.info("Loaded checkpoint sucessfully")
|
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [64, 64]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [64, 64]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
no_obj_embed_spatial: true
|
| 93 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 94 |
+
use_high_res_features_in_sam: true
|
| 95 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 96 |
+
multimask_output_in_sam: true
|
| 97 |
+
# SAM heads
|
| 98 |
+
iou_prediction_use_sigmoid: True
|
| 99 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 100 |
+
use_obj_ptrs_in_encoder: true
|
| 101 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 102 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 103 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [64, 64]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [64, 64]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
no_obj_embed_spatial: true
|
| 97 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 98 |
+
use_high_res_features_in_sam: true
|
| 99 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 100 |
+
multimask_output_in_sam: true
|
| 101 |
+
# SAM heads
|
| 102 |
+
iou_prediction_use_sigmoid: True
|
| 103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 104 |
+
use_obj_ptrs_in_encoder: true
|
| 105 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 109 |
+
# object occlusion prediction
|
| 110 |
+
pred_obj_scores: true
|
| 111 |
+
pred_obj_scores_mlp: true
|
| 112 |
+
fixed_no_obj_ptr: true
|
| 113 |
+
# multimask tracking settings
|
| 114 |
+
multimask_output_for_tracking: true
|
| 115 |
+
use_multimask_token_for_obj_ptr: true
|
| 116 |
+
multimask_min_pt_num: 0
|
| 117 |
+
multimask_max_pt_num: 1
|
| 118 |
+
use_mlp_for_obj_ptr_proj: true
|
| 119 |
+
# Compilation flag
|
| 120 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
no_obj_embed_spatial: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 105 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 106 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 107 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 108 |
+
# object occlusion prediction
|
| 109 |
+
pred_obj_scores: true
|
| 110 |
+
pred_obj_scores_mlp: true
|
| 111 |
+
fixed_no_obj_ptr: true
|
| 112 |
+
# multimask tracking settings
|
| 113 |
+
multimask_output_for_tracking: true
|
| 114 |
+
use_multimask_token_for_obj_ptr: true
|
| 115 |
+
multimask_min_pt_num: 0
|
| 116 |
+
multimask_max_pt_num: 1
|
| 117 |
+
use_mlp_for_obj_ptr_proj: true
|
| 118 |
+
# Compilation flag
|
| 119 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
no_obj_embed_spatial: true
|
| 97 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 98 |
+
use_high_res_features_in_sam: true
|
| 99 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 100 |
+
multimask_output_in_sam: true
|
| 101 |
+
# SAM heads
|
| 102 |
+
iou_prediction_use_sigmoid: True
|
| 103 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 104 |
+
use_obj_ptrs_in_encoder: true
|
| 105 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 106 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 107 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 108 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 109 |
+
# object occlusion prediction
|
| 110 |
+
pred_obj_scores: true
|
| 111 |
+
pred_obj_scores_mlp: true
|
| 112 |
+
fixed_no_obj_ptr: true
|
| 113 |
+
# multimask tracking settings
|
| 114 |
+
multimask_output_for_tracking: true
|
| 115 |
+
use_multimask_token_for_obj_ptr: true
|
| 116 |
+
multimask_min_pt_num: 0
|
| 117 |
+
multimask_max_pt_num: 1
|
| 118 |
+
use_mlp_for_obj_ptr_proj: true
|
| 119 |
+
# Compilation flag
|
| 120 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 121 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
scratch:
|
| 4 |
+
resolution: 1024
|
| 5 |
+
train_batch_size: 1
|
| 6 |
+
num_train_workers: 10
|
| 7 |
+
num_frames: 8
|
| 8 |
+
max_num_objects: 3
|
| 9 |
+
base_lr: 5.0e-6
|
| 10 |
+
vision_lr: 3.0e-06
|
| 11 |
+
phases_per_epoch: 1
|
| 12 |
+
num_epochs: 40
|
| 13 |
+
|
| 14 |
+
dataset:
|
| 15 |
+
# PATHS to Dataset
|
| 16 |
+
img_folder: null # PATH to MOSE JPEGImages folder
|
| 17 |
+
gt_folder: null # PATH to MOSE Annotations folder
|
| 18 |
+
file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
|
| 19 |
+
multiplier: 2
|
| 20 |
+
|
| 21 |
+
# Video transforms
|
| 22 |
+
vos:
|
| 23 |
+
train_transforms:
|
| 24 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
| 25 |
+
transforms:
|
| 26 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
| 27 |
+
consistent_transform: True
|
| 28 |
+
- _target_: training.dataset.transforms.RandomAffine
|
| 29 |
+
degrees: 25
|
| 30 |
+
shear: 20
|
| 31 |
+
image_interpolation: bilinear
|
| 32 |
+
consistent_transform: True
|
| 33 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
| 34 |
+
sizes: ${scratch.resolution}
|
| 35 |
+
square: true
|
| 36 |
+
consistent_transform: True
|
| 37 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 38 |
+
consistent_transform: True
|
| 39 |
+
brightness: 0.1
|
| 40 |
+
contrast: 0.03
|
| 41 |
+
saturation: 0.03
|
| 42 |
+
hue: null
|
| 43 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
| 44 |
+
p: 0.05
|
| 45 |
+
consistent_transform: True
|
| 46 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 47 |
+
consistent_transform: False
|
| 48 |
+
brightness: 0.1
|
| 49 |
+
contrast: 0.05
|
| 50 |
+
saturation: 0.05
|
| 51 |
+
hue: null
|
| 52 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
| 53 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
| 54 |
+
mean: [0.485, 0.456, 0.406]
|
| 55 |
+
std: [0.229, 0.224, 0.225]
|
| 56 |
+
|
| 57 |
+
trainer:
|
| 58 |
+
_target_: training.trainer.Trainer
|
| 59 |
+
mode: train_only
|
| 60 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
| 61 |
+
accelerator: cuda
|
| 62 |
+
seed_value: 123
|
| 63 |
+
|
| 64 |
+
model:
|
| 65 |
+
_target_: training.model.sam2.SAM2Train
|
| 66 |
+
image_encoder:
|
| 67 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 68 |
+
scalp: 1
|
| 69 |
+
trunk:
|
| 70 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 71 |
+
embed_dim: 112
|
| 72 |
+
num_heads: 2
|
| 73 |
+
drop_path_rate: 0.1
|
| 74 |
+
neck:
|
| 75 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 76 |
+
position_encoding:
|
| 77 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 78 |
+
num_pos_feats: 256
|
| 79 |
+
normalize: true
|
| 80 |
+
scale: null
|
| 81 |
+
temperature: 10000
|
| 82 |
+
d_model: 256
|
| 83 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 84 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 85 |
+
fpn_interp_model: nearest
|
| 86 |
+
|
| 87 |
+
memory_attention:
|
| 88 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 89 |
+
d_model: 256
|
| 90 |
+
pos_enc_at_input: true
|
| 91 |
+
layer:
|
| 92 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 93 |
+
activation: relu
|
| 94 |
+
dim_feedforward: 2048
|
| 95 |
+
dropout: 0.1
|
| 96 |
+
pos_enc_at_attn: false
|
| 97 |
+
self_attention:
|
| 98 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 99 |
+
rope_theta: 10000.0
|
| 100 |
+
feat_sizes: [64, 64]
|
| 101 |
+
embedding_dim: 256
|
| 102 |
+
num_heads: 1
|
| 103 |
+
downsample_rate: 1
|
| 104 |
+
dropout: 0.1
|
| 105 |
+
d_model: 256
|
| 106 |
+
pos_enc_at_cross_attn_keys: true
|
| 107 |
+
pos_enc_at_cross_attn_queries: false
|
| 108 |
+
cross_attention:
|
| 109 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 110 |
+
rope_theta: 10000.0
|
| 111 |
+
feat_sizes: [64, 64]
|
| 112 |
+
rope_k_repeat: True
|
| 113 |
+
embedding_dim: 256
|
| 114 |
+
num_heads: 1
|
| 115 |
+
downsample_rate: 1
|
| 116 |
+
dropout: 0.1
|
| 117 |
+
kv_in_dim: 64
|
| 118 |
+
num_layers: 4
|
| 119 |
+
|
| 120 |
+
memory_encoder:
|
| 121 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 122 |
+
out_dim: 64
|
| 123 |
+
position_encoding:
|
| 124 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 125 |
+
num_pos_feats: 64
|
| 126 |
+
normalize: true
|
| 127 |
+
scale: null
|
| 128 |
+
temperature: 10000
|
| 129 |
+
mask_downsampler:
|
| 130 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 131 |
+
kernel_size: 3
|
| 132 |
+
stride: 2
|
| 133 |
+
padding: 1
|
| 134 |
+
fuser:
|
| 135 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 136 |
+
layer:
|
| 137 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 138 |
+
dim: 256
|
| 139 |
+
kernel_size: 7
|
| 140 |
+
padding: 3
|
| 141 |
+
layer_scale_init_value: 1e-6
|
| 142 |
+
use_dwconv: True # depth-wise convs
|
| 143 |
+
num_layers: 2
|
| 144 |
+
|
| 145 |
+
num_maskmem: 7
|
| 146 |
+
image_size: ${scratch.resolution}
|
| 147 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 148 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 149 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 150 |
+
use_mask_input_as_output_without_sam: true
|
| 151 |
+
# Memory
|
| 152 |
+
directly_add_no_mem_embed: true
|
| 153 |
+
no_obj_embed_spatial: true
|
| 154 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 155 |
+
use_high_res_features_in_sam: true
|
| 156 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 157 |
+
multimask_output_in_sam: true
|
| 158 |
+
# SAM heads
|
| 159 |
+
iou_prediction_use_sigmoid: True
|
| 160 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 161 |
+
use_obj_ptrs_in_encoder: true
|
| 162 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 163 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 164 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 165 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 166 |
+
# object occlusion prediction
|
| 167 |
+
pred_obj_scores: true
|
| 168 |
+
pred_obj_scores_mlp: true
|
| 169 |
+
fixed_no_obj_ptr: true
|
| 170 |
+
# multimask tracking settings
|
| 171 |
+
multimask_output_for_tracking: true
|
| 172 |
+
use_multimask_token_for_obj_ptr: true
|
| 173 |
+
multimask_min_pt_num: 0
|
| 174 |
+
multimask_max_pt_num: 1
|
| 175 |
+
use_mlp_for_obj_ptr_proj: true
|
| 176 |
+
# Compilation flag
|
| 177 |
+
# compile_image_encoder: False
|
| 178 |
+
|
| 179 |
+
####### Training specific params #######
|
| 180 |
+
# box/point input and corrections
|
| 181 |
+
prob_to_use_pt_input_for_train: 0.5
|
| 182 |
+
prob_to_use_pt_input_for_eval: 0.0
|
| 183 |
+
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
|
| 184 |
+
prob_to_use_box_input_for_eval: 0.0
|
| 185 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
| 186 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
| 187 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
| 188 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
| 189 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
| 190 |
+
# maximum 2 initial conditioning frames
|
| 191 |
+
num_init_cond_frames_for_train: 2
|
| 192 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
| 193 |
+
num_correction_pt_per_frame: 7
|
| 194 |
+
use_act_ckpt_iterative_pt_sampling: false
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
| 199 |
+
forward_backbone_per_frame_for_eval: True
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
data:
|
| 203 |
+
train:
|
| 204 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
| 205 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
| 206 |
+
batch_sizes:
|
| 207 |
+
- ${scratch.train_batch_size}
|
| 208 |
+
|
| 209 |
+
datasets:
|
| 210 |
+
- _target_: training.dataset.utils.RepeatFactorWrapper
|
| 211 |
+
dataset:
|
| 212 |
+
_target_: training.dataset.utils.ConcatDataset
|
| 213 |
+
datasets:
|
| 214 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
| 215 |
+
transforms: ${vos.train_transforms}
|
| 216 |
+
training: true
|
| 217 |
+
video_dataset:
|
| 218 |
+
_target_: training.dataset.vos_raw_dataset.PNGRawDataset
|
| 219 |
+
img_folder: ${dataset.img_folder}
|
| 220 |
+
gt_folder: ${dataset.gt_folder}
|
| 221 |
+
file_list_txt: ${dataset.file_list_txt}
|
| 222 |
+
sampler:
|
| 223 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
| 224 |
+
num_frames: ${scratch.num_frames}
|
| 225 |
+
max_num_objects: ${scratch.max_num_objects}
|
| 226 |
+
multiplier: ${dataset.multiplier}
|
| 227 |
+
shuffle: True
|
| 228 |
+
num_workers: ${scratch.num_train_workers}
|
| 229 |
+
pin_memory: True
|
| 230 |
+
drop_last: True
|
| 231 |
+
collate_fn:
|
| 232 |
+
_target_: training.utils.data_utils.collate_fn
|
| 233 |
+
_partial_: true
|
| 234 |
+
dict_key: all
|
| 235 |
+
|
| 236 |
+
optim:
|
| 237 |
+
amp:
|
| 238 |
+
enabled: True
|
| 239 |
+
amp_dtype: bfloat16
|
| 240 |
+
|
| 241 |
+
optimizer:
|
| 242 |
+
_target_: torch.optim.AdamW
|
| 243 |
+
|
| 244 |
+
gradient_clip:
|
| 245 |
+
_target_: training.optimizer.GradientClipper
|
| 246 |
+
max_norm: 0.1
|
| 247 |
+
norm_type: 2
|
| 248 |
+
|
| 249 |
+
param_group_modifiers:
|
| 250 |
+
- _target_: training.optimizer.layer_decay_param_modifier
|
| 251 |
+
_partial_: True
|
| 252 |
+
layer_decay_value: 0.9
|
| 253 |
+
apply_to: 'image_encoder.trunk'
|
| 254 |
+
overrides:
|
| 255 |
+
- pattern: '*pos_embed*'
|
| 256 |
+
value: 1.0
|
| 257 |
+
|
| 258 |
+
options:
|
| 259 |
+
lr:
|
| 260 |
+
- scheduler:
|
| 261 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 262 |
+
start_value: ${scratch.base_lr}
|
| 263 |
+
end_value: ${divide:${scratch.base_lr},10}
|
| 264 |
+
- scheduler:
|
| 265 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 266 |
+
start_value: ${scratch.vision_lr}
|
| 267 |
+
end_value: ${divide:${scratch.vision_lr},10}
|
| 268 |
+
param_names:
|
| 269 |
+
- 'image_encoder.*'
|
| 270 |
+
weight_decay:
|
| 271 |
+
- scheduler:
|
| 272 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 273 |
+
value: 0.1
|
| 274 |
+
- scheduler:
|
| 275 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 276 |
+
value: 0.0
|
| 277 |
+
param_names:
|
| 278 |
+
- '*bias*'
|
| 279 |
+
module_cls_names: ['torch.nn.LayerNorm']
|
| 280 |
+
|
| 281 |
+
loss:
|
| 282 |
+
all:
|
| 283 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
| 284 |
+
weight_dict:
|
| 285 |
+
loss_mask: 20
|
| 286 |
+
loss_dice: 1
|
| 287 |
+
loss_iou: 1
|
| 288 |
+
loss_class: 1
|
| 289 |
+
supervise_all_iou: true
|
| 290 |
+
iou_use_l1_loss: true
|
| 291 |
+
pred_obj_scores: true
|
| 292 |
+
focal_gamma_obj_score: 0.0
|
| 293 |
+
focal_alpha_obj_score: -1.0
|
| 294 |
+
|
| 295 |
+
distributed:
|
| 296 |
+
backend: nccl
|
| 297 |
+
find_unused_parameters: True
|
| 298 |
+
|
| 299 |
+
logging:
|
| 300 |
+
tensorboard_writer:
|
| 301 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
| 302 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
| 303 |
+
flush_secs: 120
|
| 304 |
+
should_log: True
|
| 305 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
| 306 |
+
log_freq: 10
|
| 307 |
+
|
| 308 |
+
# initialize from a SAM 2 checkpoint
|
| 309 |
+
checkpoint:
|
| 310 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
| 311 |
+
save_freq: 0 # 0 only last checkpoint is saved.
|
| 312 |
+
model_weight_initializer:
|
| 313 |
+
_partial_: True
|
| 314 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
| 315 |
+
strict: True
|
| 316 |
+
ignore_unexpected_keys: null
|
| 317 |
+
ignore_missing_keys: null
|
| 318 |
+
|
| 319 |
+
state_dict:
|
| 320 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 321 |
+
checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
|
| 322 |
+
ckpt_state_dict_keys: ['model']
|
| 323 |
+
|
| 324 |
+
launcher:
|
| 325 |
+
num_nodes: 1
|
| 326 |
+
gpus_per_node: 8
|
| 327 |
+
experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}
|
| 328 |
+
|
| 329 |
+
# SLURM args if running on a cluster
|
| 330 |
+
submitit:
|
| 331 |
+
partition: null
|
| 332 |
+
account: null
|
| 333 |
+
qos: null
|
| 334 |
+
cpus_per_task: 10
|
| 335 |
+
use_cluster: false
|
| 336 |
+
timeout_hour: 24
|
| 337 |
+
name: null
|
| 338 |
+
port_range: [10000, 65000]
|
| 339 |
+
|
sam2_repo/sam2/configs/sam2/sam2_hiera_b+.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 112
|
| 12 |
+
num_heads: 2
|
| 13 |
+
neck:
|
| 14 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 15 |
+
position_encoding:
|
| 16 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 17 |
+
num_pos_feats: 256
|
| 18 |
+
normalize: true
|
| 19 |
+
scale: null
|
| 20 |
+
temperature: 10000
|
| 21 |
+
d_model: 256
|
| 22 |
+
backbone_channel_list: [896, 448, 224, 112]
|
| 23 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 24 |
+
fpn_interp_model: nearest
|
| 25 |
+
|
| 26 |
+
memory_attention:
|
| 27 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 28 |
+
d_model: 256
|
| 29 |
+
pos_enc_at_input: true
|
| 30 |
+
layer:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 32 |
+
activation: relu
|
| 33 |
+
dim_feedforward: 2048
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
pos_enc_at_attn: false
|
| 36 |
+
self_attention:
|
| 37 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 38 |
+
rope_theta: 10000.0
|
| 39 |
+
feat_sizes: [64, 64]
|
| 40 |
+
embedding_dim: 256
|
| 41 |
+
num_heads: 1
|
| 42 |
+
downsample_rate: 1
|
| 43 |
+
dropout: 0.1
|
| 44 |
+
d_model: 256
|
| 45 |
+
pos_enc_at_cross_attn_keys: true
|
| 46 |
+
pos_enc_at_cross_attn_queries: false
|
| 47 |
+
cross_attention:
|
| 48 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 49 |
+
rope_theta: 10000.0
|
| 50 |
+
feat_sizes: [64, 64]
|
| 51 |
+
rope_k_repeat: True
|
| 52 |
+
embedding_dim: 256
|
| 53 |
+
num_heads: 1
|
| 54 |
+
downsample_rate: 1
|
| 55 |
+
dropout: 0.1
|
| 56 |
+
kv_in_dim: 64
|
| 57 |
+
num_layers: 4
|
| 58 |
+
|
| 59 |
+
memory_encoder:
|
| 60 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 61 |
+
out_dim: 64
|
| 62 |
+
position_encoding:
|
| 63 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 64 |
+
num_pos_feats: 64
|
| 65 |
+
normalize: true
|
| 66 |
+
scale: null
|
| 67 |
+
temperature: 10000
|
| 68 |
+
mask_downsampler:
|
| 69 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 70 |
+
kernel_size: 3
|
| 71 |
+
stride: 2
|
| 72 |
+
padding: 1
|
| 73 |
+
fuser:
|
| 74 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 75 |
+
layer:
|
| 76 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 77 |
+
dim: 256
|
| 78 |
+
kernel_size: 7
|
| 79 |
+
padding: 3
|
| 80 |
+
layer_scale_init_value: 1e-6
|
| 81 |
+
use_dwconv: True # depth-wise convs
|
| 82 |
+
num_layers: 2
|
| 83 |
+
|
| 84 |
+
num_maskmem: 7
|
| 85 |
+
image_size: 1024
|
| 86 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 87 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 88 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 89 |
+
use_mask_input_as_output_without_sam: true
|
| 90 |
+
# Memory
|
| 91 |
+
directly_add_no_mem_embed: true
|
| 92 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 93 |
+
use_high_res_features_in_sam: true
|
| 94 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 95 |
+
multimask_output_in_sam: true
|
| 96 |
+
# SAM heads
|
| 97 |
+
iou_prediction_use_sigmoid: True
|
| 98 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 99 |
+
use_obj_ptrs_in_encoder: true
|
| 100 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 101 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 102 |
+
# object occlusion prediction
|
| 103 |
+
pred_obj_scores: true
|
| 104 |
+
pred_obj_scores_mlp: true
|
| 105 |
+
fixed_no_obj_ptr: true
|
| 106 |
+
# multimask tracking settings
|
| 107 |
+
multimask_output_for_tracking: true
|
| 108 |
+
use_multimask_token_for_obj_ptr: true
|
| 109 |
+
multimask_min_pt_num: 0
|
| 110 |
+
multimask_max_pt_num: 1
|
| 111 |
+
use_mlp_for_obj_ptr_proj: true
|
| 112 |
+
# Compilation flag
|
| 113 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2/sam2_hiera_l.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 144
|
| 12 |
+
num_heads: 2
|
| 13 |
+
stages: [2, 6, 36, 4]
|
| 14 |
+
global_att_blocks: [23, 33, 43]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
window_spec: [8, 4, 16, 8]
|
| 17 |
+
neck:
|
| 18 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 19 |
+
position_encoding:
|
| 20 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 21 |
+
num_pos_feats: 256
|
| 22 |
+
normalize: true
|
| 23 |
+
scale: null
|
| 24 |
+
temperature: 10000
|
| 25 |
+
d_model: 256
|
| 26 |
+
backbone_channel_list: [1152, 576, 288, 144]
|
| 27 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 28 |
+
fpn_interp_model: nearest
|
| 29 |
+
|
| 30 |
+
memory_attention:
|
| 31 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 32 |
+
d_model: 256
|
| 33 |
+
pos_enc_at_input: true
|
| 34 |
+
layer:
|
| 35 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 36 |
+
activation: relu
|
| 37 |
+
dim_feedforward: 2048
|
| 38 |
+
dropout: 0.1
|
| 39 |
+
pos_enc_at_attn: false
|
| 40 |
+
self_attention:
|
| 41 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 42 |
+
rope_theta: 10000.0
|
| 43 |
+
feat_sizes: [64, 64]
|
| 44 |
+
embedding_dim: 256
|
| 45 |
+
num_heads: 1
|
| 46 |
+
downsample_rate: 1
|
| 47 |
+
dropout: 0.1
|
| 48 |
+
d_model: 256
|
| 49 |
+
pos_enc_at_cross_attn_keys: true
|
| 50 |
+
pos_enc_at_cross_attn_queries: false
|
| 51 |
+
cross_attention:
|
| 52 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 53 |
+
rope_theta: 10000.0
|
| 54 |
+
feat_sizes: [64, 64]
|
| 55 |
+
rope_k_repeat: True
|
| 56 |
+
embedding_dim: 256
|
| 57 |
+
num_heads: 1
|
| 58 |
+
downsample_rate: 1
|
| 59 |
+
dropout: 0.1
|
| 60 |
+
kv_in_dim: 64
|
| 61 |
+
num_layers: 4
|
| 62 |
+
|
| 63 |
+
memory_encoder:
|
| 64 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 65 |
+
out_dim: 64
|
| 66 |
+
position_encoding:
|
| 67 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 68 |
+
num_pos_feats: 64
|
| 69 |
+
normalize: true
|
| 70 |
+
scale: null
|
| 71 |
+
temperature: 10000
|
| 72 |
+
mask_downsampler:
|
| 73 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 74 |
+
kernel_size: 3
|
| 75 |
+
stride: 2
|
| 76 |
+
padding: 1
|
| 77 |
+
fuser:
|
| 78 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 79 |
+
layer:
|
| 80 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 81 |
+
dim: 256
|
| 82 |
+
kernel_size: 7
|
| 83 |
+
padding: 3
|
| 84 |
+
layer_scale_init_value: 1e-6
|
| 85 |
+
use_dwconv: True # depth-wise convs
|
| 86 |
+
num_layers: 2
|
| 87 |
+
|
| 88 |
+
num_maskmem: 7
|
| 89 |
+
image_size: 1024
|
| 90 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2/sam2_hiera_s.yaml
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 92 |
+
use_mask_input_as_output_without_sam: true
|
| 93 |
+
# Memory
|
| 94 |
+
directly_add_no_mem_embed: true
|
| 95 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 96 |
+
use_high_res_features_in_sam: true
|
| 97 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 98 |
+
multimask_output_in_sam: true
|
| 99 |
+
# SAM heads
|
| 100 |
+
iou_prediction_use_sigmoid: True
|
| 101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 102 |
+
use_obj_ptrs_in_encoder: true
|
| 103 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 105 |
+
# object occlusion prediction
|
| 106 |
+
pred_obj_scores: true
|
| 107 |
+
pred_obj_scores_mlp: true
|
| 108 |
+
fixed_no_obj_ptr: true
|
| 109 |
+
# multimask tracking settings
|
| 110 |
+
multimask_output_for_tracking: true
|
| 111 |
+
use_multimask_token_for_obj_ptr: true
|
| 112 |
+
multimask_min_pt_num: 0
|
| 113 |
+
multimask_max_pt_num: 1
|
| 114 |
+
use_mlp_for_obj_ptr_proj: true
|
| 115 |
+
# Compilation flag
|
| 116 |
+
compile_image_encoder: False
|
sam2_repo/sam2/configs/sam2/sam2_hiera_t.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 7, 2]
|
| 14 |
+
global_att_blocks: [5, 7, 9]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
|
| 62 |
+
memory_encoder:
|
| 63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 64 |
+
out_dim: 64
|
| 65 |
+
position_encoding:
|
| 66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 67 |
+
num_pos_feats: 64
|
| 68 |
+
normalize: true
|
| 69 |
+
scale: null
|
| 70 |
+
temperature: 10000
|
| 71 |
+
mask_downsampler:
|
| 72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 73 |
+
kernel_size: 3
|
| 74 |
+
stride: 2
|
| 75 |
+
padding: 1
|
| 76 |
+
fuser:
|
| 77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 78 |
+
layer:
|
| 79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 80 |
+
dim: 256
|
| 81 |
+
kernel_size: 7
|
| 82 |
+
padding: 3
|
| 83 |
+
layer_scale_init_value: 1e-6
|
| 84 |
+
use_dwconv: True # depth-wise convs
|
| 85 |
+
num_layers: 2
|
| 86 |
+
|
| 87 |
+
num_maskmem: 7
|
| 88 |
+
image_size: 1024
|
| 89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 90 |
+
# SAM decoder
|
| 91 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 92 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 93 |
+
use_mask_input_as_output_without_sam: true
|
| 94 |
+
# Memory
|
| 95 |
+
directly_add_no_mem_embed: true
|
| 96 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 97 |
+
use_high_res_features_in_sam: true
|
| 98 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 99 |
+
multimask_output_in_sam: true
|
| 100 |
+
# SAM heads
|
| 101 |
+
iou_prediction_use_sigmoid: True
|
| 102 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 103 |
+
use_obj_ptrs_in_encoder: true
|
| 104 |
+
add_tpos_enc_to_obj_ptrs: false
|
| 105 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 106 |
+
# object occlusion prediction
|
| 107 |
+
pred_obj_scores: true
|
| 108 |
+
pred_obj_scores_mlp: true
|
| 109 |
+
fixed_no_obj_ptr: true
|
| 110 |
+
# multimask tracking settings
|
| 111 |
+
multimask_output_for_tracking: true
|
| 112 |
+
use_multimask_token_for_obj_ptr: true
|
| 113 |
+
multimask_min_pt_num: 0
|
| 114 |
+
multimask_max_pt_num: 1
|
| 115 |
+
use_mlp_for_obj_ptr_proj: true
|
| 116 |
+
# Compilation flag
|
| 117 |
+
# HieraT does not currently support compilation, should always be set to False
|
| 118 |
+
compile_image_encoder: False
|
sam2_repo/sam2/csrc/connected_components.cu
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// adapted from https://github.com/zsef123/Connected_components_PyTorch
|
| 8 |
+
// with license found in the LICENSE_cctorch file in the root directory.
|
| 9 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 10 |
+
#include <cuda.h>
|
| 11 |
+
#include <cuda_runtime.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <torch/script.h>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
// 2d
|
| 17 |
+
#define BLOCK_ROWS 16
|
| 18 |
+
#define BLOCK_COLS 16
|
| 19 |
+
|
| 20 |
+
namespace cc2d {
|
| 21 |
+
|
| 22 |
+
template <typename T>
|
| 23 |
+
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
|
| 24 |
+
return (bitmap >> pos) & 1;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
|
| 28 |
+
while (s_buf[n] != n)
|
| 29 |
+
n = s_buf[n];
|
| 30 |
+
return n;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
|
| 34 |
+
const int32_t id = n;
|
| 35 |
+
while (s_buf[n] != n) {
|
| 36 |
+
n = s_buf[n];
|
| 37 |
+
s_buf[id] = n;
|
| 38 |
+
}
|
| 39 |
+
return n;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
|
| 43 |
+
bool done;
|
| 44 |
+
do {
|
| 45 |
+
a = find(s_buf, a);
|
| 46 |
+
b = find(s_buf, b);
|
| 47 |
+
|
| 48 |
+
if (a < b) {
|
| 49 |
+
int32_t old = atomicMin(s_buf + b, a);
|
| 50 |
+
done = (old == b);
|
| 51 |
+
b = old;
|
| 52 |
+
} else if (b < a) {
|
| 53 |
+
int32_t old = atomicMin(s_buf + a, b);
|
| 54 |
+
done = (old == a);
|
| 55 |
+
a = old;
|
| 56 |
+
} else
|
| 57 |
+
done = true;
|
| 58 |
+
|
| 59 |
+
} while (!done);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
__global__ void
|
| 63 |
+
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
|
| 64 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 65 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 66 |
+
const uint32_t idx = row * W + col;
|
| 67 |
+
|
| 68 |
+
if (row < H && col < W)
|
| 69 |
+
label[idx] = idx;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
__global__ void
|
| 73 |
+
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
|
| 74 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 75 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 76 |
+
const uint32_t idx = row * W + col;
|
| 77 |
+
|
| 78 |
+
if (row >= H || col >= W)
|
| 79 |
+
return;
|
| 80 |
+
|
| 81 |
+
uint32_t P = 0;
|
| 82 |
+
|
| 83 |
+
if (img[idx])
|
| 84 |
+
P |= 0x777;
|
| 85 |
+
if (row + 1 < H && img[idx + W])
|
| 86 |
+
P |= 0x777 << 4;
|
| 87 |
+
if (col + 1 < W && img[idx + 1])
|
| 88 |
+
P |= 0x777 << 1;
|
| 89 |
+
|
| 90 |
+
if (col == 0)
|
| 91 |
+
P &= 0xEEEE;
|
| 92 |
+
if (col + 1 >= W)
|
| 93 |
+
P &= 0x3333;
|
| 94 |
+
else if (col + 2 >= W)
|
| 95 |
+
P &= 0x7777;
|
| 96 |
+
|
| 97 |
+
if (row == 0)
|
| 98 |
+
P &= 0xFFF0;
|
| 99 |
+
if (row + 1 >= H)
|
| 100 |
+
P &= 0xFF;
|
| 101 |
+
|
| 102 |
+
if (P > 0) {
|
| 103 |
+
// If need check about top-left pixel(if flag the first bit) and hit the
|
| 104 |
+
// top-left pixel
|
| 105 |
+
if (hasBit(P, 0) && img[idx - W - 1]) {
|
| 106 |
+
union_(label, idx, idx - 2 * W - 2); // top left block
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
|
| 110 |
+
union_(label, idx, idx - 2 * W); // top bottom block
|
| 111 |
+
|
| 112 |
+
if (hasBit(P, 3) && img[idx + 2 - W])
|
| 113 |
+
union_(label, idx, idx - 2 * W + 2); // top right block
|
| 114 |
+
|
| 115 |
+
if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
|
| 116 |
+
union_(label, idx, idx - 2); // just left block
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
|
| 121 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 122 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 123 |
+
const uint32_t idx = row * W + col;
|
| 124 |
+
|
| 125 |
+
if (row < H && col < W)
|
| 126 |
+
find_n_compress(label, idx);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
__global__ void final_labeling(
|
| 130 |
+
const uint8_t* img,
|
| 131 |
+
int32_t* label,
|
| 132 |
+
const int32_t W,
|
| 133 |
+
const int32_t H) {
|
| 134 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 135 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 136 |
+
const uint32_t idx = row * W + col;
|
| 137 |
+
|
| 138 |
+
if (row >= H || col >= W)
|
| 139 |
+
return;
|
| 140 |
+
|
| 141 |
+
int32_t y = label[idx] + 1;
|
| 142 |
+
|
| 143 |
+
if (img[idx])
|
| 144 |
+
label[idx] = y;
|
| 145 |
+
else
|
| 146 |
+
label[idx] = 0;
|
| 147 |
+
|
| 148 |
+
if (col + 1 < W) {
|
| 149 |
+
if (img[idx + 1])
|
| 150 |
+
label[idx + 1] = y;
|
| 151 |
+
else
|
| 152 |
+
label[idx + 1] = 0;
|
| 153 |
+
|
| 154 |
+
if (row + 1 < H) {
|
| 155 |
+
if (img[idx + W + 1])
|
| 156 |
+
label[idx + W + 1] = y;
|
| 157 |
+
else
|
| 158 |
+
label[idx + W + 1] = 0;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if (row + 1 < H) {
|
| 163 |
+
if (img[idx + W])
|
| 164 |
+
label[idx + W] = y;
|
| 165 |
+
else
|
| 166 |
+
label[idx + W] = 0;
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
__global__ void init_counting(
|
| 171 |
+
const int32_t* label,
|
| 172 |
+
int32_t* count_init,
|
| 173 |
+
const int32_t W,
|
| 174 |
+
const int32_t H) {
|
| 175 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 176 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 177 |
+
const uint32_t idx = row * W + col;
|
| 178 |
+
|
| 179 |
+
if (row >= H || col >= W)
|
| 180 |
+
return;
|
| 181 |
+
|
| 182 |
+
int32_t y = label[idx];
|
| 183 |
+
if (y > 0) {
|
| 184 |
+
int32_t count_idx = y - 1;
|
| 185 |
+
atomicAdd(count_init + count_idx, 1);
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
__global__ void final_counting(
|
| 190 |
+
const int32_t* label,
|
| 191 |
+
const int32_t* count_init,
|
| 192 |
+
int32_t* count_final,
|
| 193 |
+
const int32_t W,
|
| 194 |
+
const int32_t H) {
|
| 195 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 196 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 197 |
+
const uint32_t idx = row * W + col;
|
| 198 |
+
|
| 199 |
+
if (row >= H || col >= W)
|
| 200 |
+
return;
|
| 201 |
+
|
| 202 |
+
int32_t y = label[idx];
|
| 203 |
+
if (y > 0) {
|
| 204 |
+
int32_t count_idx = y - 1;
|
| 205 |
+
count_final[idx] = count_init[count_idx];
|
| 206 |
+
} else {
|
| 207 |
+
count_final[idx] = 0;
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
} // namespace cc2d
|
| 212 |
+
|
| 213 |
+
std::vector<torch::Tensor> get_connected_componnets(
|
| 214 |
+
const torch::Tensor& inputs) {
|
| 215 |
+
AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
|
| 216 |
+
AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
|
| 217 |
+
AT_ASSERTM(
|
| 218 |
+
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
|
| 219 |
+
|
| 220 |
+
const uint32_t N = inputs.size(0);
|
| 221 |
+
const uint32_t C = inputs.size(1);
|
| 222 |
+
const uint32_t H = inputs.size(2);
|
| 223 |
+
const uint32_t W = inputs.size(3);
|
| 224 |
+
|
| 225 |
+
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
|
| 226 |
+
AT_ASSERTM((H % 2) == 0, "height must be an even number");
|
| 227 |
+
AT_ASSERTM((W % 2) == 0, "width must be an even number");
|
| 228 |
+
|
| 229 |
+
// label must be uint32_t
|
| 230 |
+
auto label_options =
|
| 231 |
+
torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
|
| 232 |
+
torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
|
| 233 |
+
torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
|
| 234 |
+
torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
|
| 235 |
+
|
| 236 |
+
dim3 grid = dim3(
|
| 237 |
+
((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
|
| 238 |
+
((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
|
| 239 |
+
dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 240 |
+
dim3 grid_count =
|
| 241 |
+
dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
|
| 242 |
+
dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 243 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 244 |
+
|
| 245 |
+
for (int n = 0; n < N; n++) {
|
| 246 |
+
uint32_t offset = n * H * W;
|
| 247 |
+
|
| 248 |
+
cc2d::init_labeling<<<grid, block, 0, stream>>>(
|
| 249 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 250 |
+
cc2d::merge<<<grid, block, 0, stream>>>(
|
| 251 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 252 |
+
labels.data_ptr<int32_t>() + offset,
|
| 253 |
+
W,
|
| 254 |
+
H);
|
| 255 |
+
cc2d::compression<<<grid, block, 0, stream>>>(
|
| 256 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 257 |
+
cc2d::final_labeling<<<grid, block, 0, stream>>>(
|
| 258 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 259 |
+
labels.data_ptr<int32_t>() + offset,
|
| 260 |
+
W,
|
| 261 |
+
H);
|
| 262 |
+
|
| 263 |
+
// get the counting of each pixel
|
| 264 |
+
cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
|
| 265 |
+
labels.data_ptr<int32_t>() + offset,
|
| 266 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 267 |
+
W,
|
| 268 |
+
H);
|
| 269 |
+
cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
|
| 270 |
+
labels.data_ptr<int32_t>() + offset,
|
| 271 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 272 |
+
counts_final.data_ptr<int32_t>() + offset,
|
| 273 |
+
W,
|
| 274 |
+
H);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// returned values are [labels, counts]
|
| 278 |
+
std::vector<torch::Tensor> outputs;
|
| 279 |
+
outputs.push_back(labels);
|
| 280 |
+
outputs.push_back(counts_final);
|
| 281 |
+
return outputs;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 285 |
+
m.def(
|
| 286 |
+
"get_connected_componnets",
|
| 287 |
+
&get_connected_componnets,
|
| 288 |
+
"get_connected_componnets");
|
| 289 |
+
}
|
sam2_repo/sam2/modeling/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2_repo/sam2/modeling/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (165 Bytes). View file
|
|
|
sam2_repo/sam2/modeling/__pycache__/memory_attention.cpython-313.pyc
ADDED
|
Binary file (7.06 kB). View file
|
|
|
sam2_repo/sam2/modeling/__pycache__/memory_encoder.cpython-313.pyc
ADDED
|
Binary file (7.99 kB). View file
|
|
|
sam2_repo/sam2/modeling/__pycache__/position_encoding.cpython-313.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
sam2_repo/sam2/modeling/__pycache__/sam2_base.cpython-313.pyc
ADDED
|
Binary file (31 kB). View file
|
|
|
sam2_repo/sam2/modeling/__pycache__/sam2_utils.cpython-313.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
sam2_repo/sam2/modeling/backbones/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2_repo/sam2/modeling/backbones/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (175 Bytes). View file
|
|
|
sam2_repo/sam2/modeling/backbones/__pycache__/hieradet.cpython-313.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
sam2_repo/sam2/modeling/backbones/__pycache__/image_encoder.cpython-313.pyc
ADDED
|
Binary file (5.53 kB). View file
|
|
|
sam2_repo/sam2/modeling/backbones/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (4.06 kB). View file
|
|
|
sam2_repo/sam2/modeling/backbones/hieradet.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from iopath.common.file_io import g_pathmgr
|
| 15 |
+
|
| 16 |
+
from sam2.modeling.backbones.utils import (
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
window_partition,
|
| 19 |
+
window_unpartition,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from sam2.modeling.sam2_utils import DropPath, MLP
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
| 26 |
+
if pool is None:
|
| 27 |
+
return x
|
| 28 |
+
# (B, H, W, C) -> (B, C, H, W)
|
| 29 |
+
x = x.permute(0, 3, 1, 2)
|
| 30 |
+
x = pool(x)
|
| 31 |
+
# (B, C, H', W') -> (B, H', W', C)
|
| 32 |
+
x = x.permute(0, 2, 3, 1)
|
| 33 |
+
if norm:
|
| 34 |
+
x = norm(x)
|
| 35 |
+
|
| 36 |
+
return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiScaleAttention(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
dim: int,
|
| 43 |
+
dim_out: int,
|
| 44 |
+
num_heads: int,
|
| 45 |
+
q_pool: nn.Module = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
self.dim = dim
|
| 50 |
+
self.dim_out = dim_out
|
| 51 |
+
self.num_heads = num_heads
|
| 52 |
+
self.q_pool = q_pool
|
| 53 |
+
self.qkv = nn.Linear(dim, dim_out * 3)
|
| 54 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
B, H, W, _ = x.shape
|
| 58 |
+
# qkv with shape (B, H * W, 3, nHead, C)
|
| 59 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
| 60 |
+
# q, k, v with shape (B, H * W, nheads, C)
|
| 61 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 62 |
+
|
| 63 |
+
# Q pooling (for downsample at stage changes)
|
| 64 |
+
if self.q_pool:
|
| 65 |
+
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
| 66 |
+
H, W = q.shape[1:3] # downsampled shape
|
| 67 |
+
q = q.reshape(B, H * W, self.num_heads, -1)
|
| 68 |
+
|
| 69 |
+
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
| 70 |
+
x = F.scaled_dot_product_attention(
|
| 71 |
+
q.transpose(1, 2),
|
| 72 |
+
k.transpose(1, 2),
|
| 73 |
+
v.transpose(1, 2),
|
| 74 |
+
)
|
| 75 |
+
# Transpose back
|
| 76 |
+
x = x.transpose(1, 2)
|
| 77 |
+
x = x.reshape(B, H, W, -1)
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiScaleBlock(nn.Module):
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
dim_out: int,
|
| 89 |
+
num_heads: int,
|
| 90 |
+
mlp_ratio: float = 4.0,
|
| 91 |
+
drop_path: float = 0.0,
|
| 92 |
+
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
| 93 |
+
q_stride: Tuple[int, int] = None,
|
| 94 |
+
act_layer: nn.Module = nn.GELU,
|
| 95 |
+
window_size: int = 0,
|
| 96 |
+
):
|
| 97 |
+
super().__init__()
|
| 98 |
+
|
| 99 |
+
if isinstance(norm_layer, str):
|
| 100 |
+
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
| 101 |
+
|
| 102 |
+
self.dim = dim
|
| 103 |
+
self.dim_out = dim_out
|
| 104 |
+
self.norm1 = norm_layer(dim)
|
| 105 |
+
|
| 106 |
+
self.window_size = window_size
|
| 107 |
+
|
| 108 |
+
self.pool, self.q_stride = None, q_stride
|
| 109 |
+
if self.q_stride:
|
| 110 |
+
self.pool = nn.MaxPool2d(
|
| 111 |
+
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.attn = MultiScaleAttention(
|
| 115 |
+
dim,
|
| 116 |
+
dim_out,
|
| 117 |
+
num_heads=num_heads,
|
| 118 |
+
q_pool=self.pool,
|
| 119 |
+
)
|
| 120 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 121 |
+
|
| 122 |
+
self.norm2 = norm_layer(dim_out)
|
| 123 |
+
self.mlp = MLP(
|
| 124 |
+
dim_out,
|
| 125 |
+
int(dim_out * mlp_ratio),
|
| 126 |
+
dim_out,
|
| 127 |
+
num_layers=2,
|
| 128 |
+
activation=act_layer,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if dim != dim_out:
|
| 132 |
+
self.proj = nn.Linear(dim, dim_out)
|
| 133 |
+
|
| 134 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 135 |
+
shortcut = x # B, H, W, C
|
| 136 |
+
x = self.norm1(x)
|
| 137 |
+
|
| 138 |
+
# Skip connection
|
| 139 |
+
if self.dim != self.dim_out:
|
| 140 |
+
shortcut = do_pool(self.proj(x), self.pool)
|
| 141 |
+
|
| 142 |
+
# Window partition
|
| 143 |
+
window_size = self.window_size
|
| 144 |
+
if window_size > 0:
|
| 145 |
+
H, W = x.shape[1], x.shape[2]
|
| 146 |
+
x, pad_hw = window_partition(x, window_size)
|
| 147 |
+
|
| 148 |
+
# Window Attention + Q Pooling (if stage change)
|
| 149 |
+
x = self.attn(x)
|
| 150 |
+
if self.q_stride:
|
| 151 |
+
# Shapes have changed due to Q pooling
|
| 152 |
+
window_size = self.window_size // self.q_stride[0]
|
| 153 |
+
H, W = shortcut.shape[1:3]
|
| 154 |
+
|
| 155 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 156 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 157 |
+
pad_hw = (H + pad_h, W + pad_w)
|
| 158 |
+
|
| 159 |
+
# Reverse window partition
|
| 160 |
+
if self.window_size > 0:
|
| 161 |
+
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
| 162 |
+
|
| 163 |
+
x = shortcut + self.drop_path(x)
|
| 164 |
+
# MLP
|
| 165 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Hiera(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Reference: https://arxiv.org/abs/2306.00989
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
embed_dim: int = 96, # initial embed dim
|
| 177 |
+
num_heads: int = 1, # initial number of heads
|
| 178 |
+
drop_path_rate: float = 0.0, # stochastic depth
|
| 179 |
+
q_pool: int = 3, # number of q_pool stages
|
| 180 |
+
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
| 181 |
+
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
| 182 |
+
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
| 183 |
+
head_mul: float = 2.0, # head_mul factor at stage shift
|
| 184 |
+
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
| 185 |
+
# window size per stage, when not using global att.
|
| 186 |
+
window_spec: Tuple[int, ...] = (
|
| 187 |
+
8,
|
| 188 |
+
4,
|
| 189 |
+
14,
|
| 190 |
+
7,
|
| 191 |
+
),
|
| 192 |
+
# global attn in these blocks
|
| 193 |
+
global_att_blocks: Tuple[int, ...] = (
|
| 194 |
+
12,
|
| 195 |
+
16,
|
| 196 |
+
20,
|
| 197 |
+
),
|
| 198 |
+
weights_path=None,
|
| 199 |
+
return_interm_layers=True, # return feats from every stage
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
assert len(stages) == len(window_spec)
|
| 204 |
+
self.window_spec = window_spec
|
| 205 |
+
|
| 206 |
+
depth = sum(stages)
|
| 207 |
+
self.q_stride = q_stride
|
| 208 |
+
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
| 209 |
+
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
| 210 |
+
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
| 211 |
+
self.return_interm_layers = return_interm_layers
|
| 212 |
+
|
| 213 |
+
self.patch_embed = PatchEmbed(
|
| 214 |
+
embed_dim=embed_dim,
|
| 215 |
+
)
|
| 216 |
+
# Which blocks have global att?
|
| 217 |
+
self.global_att_blocks = global_att_blocks
|
| 218 |
+
|
| 219 |
+
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 220 |
+
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 221 |
+
self.pos_embed = nn.Parameter(
|
| 222 |
+
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
|
| 223 |
+
)
|
| 224 |
+
self.pos_embed_window = nn.Parameter(
|
| 225 |
+
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
dpr = [
|
| 229 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 230 |
+
] # stochastic depth decay rule
|
| 231 |
+
|
| 232 |
+
cur_stage = 1
|
| 233 |
+
self.blocks = nn.ModuleList()
|
| 234 |
+
|
| 235 |
+
for i in range(depth):
|
| 236 |
+
dim_out = embed_dim
|
| 237 |
+
# lags by a block, so first block of
|
| 238 |
+
# next stage uses an initial window size
|
| 239 |
+
# of previous stage and final window size of current stage
|
| 240 |
+
window_size = self.window_spec[cur_stage - 1]
|
| 241 |
+
|
| 242 |
+
if self.global_att_blocks is not None:
|
| 243 |
+
window_size = 0 if i in self.global_att_blocks else window_size
|
| 244 |
+
|
| 245 |
+
if i - 1 in self.stage_ends:
|
| 246 |
+
dim_out = int(embed_dim * dim_mul)
|
| 247 |
+
num_heads = int(num_heads * head_mul)
|
| 248 |
+
cur_stage += 1
|
| 249 |
+
|
| 250 |
+
block = MultiScaleBlock(
|
| 251 |
+
dim=embed_dim,
|
| 252 |
+
dim_out=dim_out,
|
| 253 |
+
num_heads=num_heads,
|
| 254 |
+
drop_path=dpr[i],
|
| 255 |
+
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
| 256 |
+
window_size=window_size,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
embed_dim = dim_out
|
| 260 |
+
self.blocks.append(block)
|
| 261 |
+
|
| 262 |
+
self.channel_list = (
|
| 263 |
+
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 264 |
+
if return_interm_layers
|
| 265 |
+
else [self.blocks[-1].dim_out]
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if weights_path is not None:
|
| 269 |
+
with g_pathmgr.open(weights_path, "rb") as f:
|
| 270 |
+
chkpt = torch.load(f, map_location="cpu")
|
| 271 |
+
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
| 272 |
+
|
| 273 |
+
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
| 274 |
+
h, w = hw
|
| 275 |
+
window_embed = self.pos_embed_window
|
| 276 |
+
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 277 |
+
pos_embed = pos_embed + window_embed.tile(
|
| 278 |
+
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 279 |
+
)
|
| 280 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 281 |
+
return pos_embed
|
| 282 |
+
|
| 283 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 284 |
+
x = self.patch_embed(x)
|
| 285 |
+
# x: (B, H, W, C)
|
| 286 |
+
|
| 287 |
+
# Add pos embed
|
| 288 |
+
x = x + self._get_pos_embed(x.shape[1:3])
|
| 289 |
+
|
| 290 |
+
outputs = []
|
| 291 |
+
for i, blk in enumerate(self.blocks):
|
| 292 |
+
x = blk(x)
|
| 293 |
+
if (i == self.stage_ends[-1]) or (
|
| 294 |
+
i in self.stage_ends and self.return_interm_layers
|
| 295 |
+
):
|
| 296 |
+
feats = x.permute(0, 3, 1, 2)
|
| 297 |
+
outputs.append(feats)
|
| 298 |
+
|
| 299 |
+
return outputs
|
| 300 |
+
|
| 301 |
+
def get_layer_id(self, layer_name):
|
| 302 |
+
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
| 303 |
+
num_layers = self.get_num_layers()
|
| 304 |
+
|
| 305 |
+
if layer_name.find("rel_pos") != -1:
|
| 306 |
+
return num_layers + 1
|
| 307 |
+
elif layer_name.find("pos_embed") != -1:
|
| 308 |
+
return 0
|
| 309 |
+
elif layer_name.find("patch_embed") != -1:
|
| 310 |
+
return 0
|
| 311 |
+
elif layer_name.find("blocks") != -1:
|
| 312 |
+
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
|
| 313 |
+
else:
|
| 314 |
+
return num_layers + 1
|
| 315 |
+
|
| 316 |
+
def get_num_layers(self) -> int:
|
| 317 |
+
return len(self.blocks)
|
sam2_repo/sam2/modeling/backbones/image_encoder.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ImageEncoder(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
trunk: nn.Module,
|
| 18 |
+
neck: nn.Module,
|
| 19 |
+
scalp: int = 0,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.trunk = trunk
|
| 23 |
+
self.neck = neck
|
| 24 |
+
self.scalp = scalp
|
| 25 |
+
assert (
|
| 26 |
+
self.trunk.channel_list == self.neck.backbone_channel_list
|
| 27 |
+
), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
|
| 28 |
+
|
| 29 |
+
def forward(self, sample: torch.Tensor):
|
| 30 |
+
# Forward through backbone
|
| 31 |
+
features, pos = self.neck(self.trunk(sample))
|
| 32 |
+
if self.scalp > 0:
|
| 33 |
+
# Discard the lowest resolution features
|
| 34 |
+
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
| 35 |
+
|
| 36 |
+
src = features[-1]
|
| 37 |
+
output = {
|
| 38 |
+
"vision_features": src,
|
| 39 |
+
"vision_pos_enc": pos,
|
| 40 |
+
"backbone_fpn": features,
|
| 41 |
+
}
|
| 42 |
+
return output
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class FpnNeck(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
A modified variant of Feature Pyramid Network (FPN) neck
|
| 48 |
+
(we remove output conv and also do bicubic interpolation similar to ViT
|
| 49 |
+
pos embed interpolation)
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
position_encoding: nn.Module,
|
| 55 |
+
d_model: int,
|
| 56 |
+
backbone_channel_list: List[int],
|
| 57 |
+
kernel_size: int = 1,
|
| 58 |
+
stride: int = 1,
|
| 59 |
+
padding: int = 0,
|
| 60 |
+
fpn_interp_model: str = "bilinear",
|
| 61 |
+
fuse_type: str = "sum",
|
| 62 |
+
fpn_top_down_levels: Optional[List[int]] = None,
|
| 63 |
+
):
|
| 64 |
+
"""Initialize the neck
|
| 65 |
+
:param trunk: the backbone
|
| 66 |
+
:param position_encoding: the positional encoding to use
|
| 67 |
+
:param d_model: the dimension of the model
|
| 68 |
+
:param neck_norm: the normalization to use
|
| 69 |
+
"""
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.position_encoding = position_encoding
|
| 72 |
+
self.convs = nn.ModuleList()
|
| 73 |
+
self.backbone_channel_list = backbone_channel_list
|
| 74 |
+
self.d_model = d_model
|
| 75 |
+
for dim in backbone_channel_list:
|
| 76 |
+
current = nn.Sequential()
|
| 77 |
+
current.add_module(
|
| 78 |
+
"conv",
|
| 79 |
+
nn.Conv2d(
|
| 80 |
+
in_channels=dim,
|
| 81 |
+
out_channels=d_model,
|
| 82 |
+
kernel_size=kernel_size,
|
| 83 |
+
stride=stride,
|
| 84 |
+
padding=padding,
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.convs.append(current)
|
| 89 |
+
self.fpn_interp_model = fpn_interp_model
|
| 90 |
+
assert fuse_type in ["sum", "avg"]
|
| 91 |
+
self.fuse_type = fuse_type
|
| 92 |
+
|
| 93 |
+
# levels to have top-down features in its outputs
|
| 94 |
+
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
| 95 |
+
# have top-down propagation, while outputs of level 0 and level 1 have only
|
| 96 |
+
# lateral features from the same backbone level.
|
| 97 |
+
if fpn_top_down_levels is None:
|
| 98 |
+
# default is to have top-down features on all levels
|
| 99 |
+
fpn_top_down_levels = range(len(self.convs))
|
| 100 |
+
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
| 101 |
+
|
| 102 |
+
def forward(self, xs: List[torch.Tensor]):
|
| 103 |
+
|
| 104 |
+
out = [None] * len(self.convs)
|
| 105 |
+
pos = [None] * len(self.convs)
|
| 106 |
+
assert len(xs) == len(self.convs)
|
| 107 |
+
# fpn forward pass
|
| 108 |
+
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
| 109 |
+
prev_features = None
|
| 110 |
+
# forward in top-down order (from low to high resolution)
|
| 111 |
+
n = len(self.convs) - 1
|
| 112 |
+
for i in range(n, -1, -1):
|
| 113 |
+
x = xs[i]
|
| 114 |
+
lateral_features = self.convs[n - i](x)
|
| 115 |
+
if i in self.fpn_top_down_levels and prev_features is not None:
|
| 116 |
+
top_down_features = F.interpolate(
|
| 117 |
+
prev_features.to(dtype=torch.float32),
|
| 118 |
+
scale_factor=2.0,
|
| 119 |
+
mode=self.fpn_interp_model,
|
| 120 |
+
align_corners=(
|
| 121 |
+
None if self.fpn_interp_model == "nearest" else False
|
| 122 |
+
),
|
| 123 |
+
antialias=False,
|
| 124 |
+
)
|
| 125 |
+
prev_features = lateral_features + top_down_features
|
| 126 |
+
if self.fuse_type == "avg":
|
| 127 |
+
prev_features /= 2
|
| 128 |
+
else:
|
| 129 |
+
prev_features = lateral_features
|
| 130 |
+
x_out = prev_features
|
| 131 |
+
out[i] = x_out
|
| 132 |
+
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
| 133 |
+
|
| 134 |
+
return out, pos
|
sam2_repo/sam2/modeling/backbones/utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Some utilities for backbones, in particular for windowing"""
|
| 8 |
+
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def window_partition(x, window_size):
|
| 17 |
+
"""
|
| 18 |
+
Partition into non-overlapping windows with padding if needed.
|
| 19 |
+
Args:
|
| 20 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 21 |
+
window_size (int): window size.
|
| 22 |
+
Returns:
|
| 23 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 24 |
+
(Hp, Wp): padded height and width before partition
|
| 25 |
+
"""
|
| 26 |
+
B, H, W, C = x.shape
|
| 27 |
+
|
| 28 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 29 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 30 |
+
if pad_h > 0 or pad_w > 0:
|
| 31 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 32 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 33 |
+
|
| 34 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 35 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
|
| 36 |
+
return windows, (Hp, Wp)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
|
| 40 |
+
"""
|
| 41 |
+
Window unpartition into original sequences and removing padding.
|
| 42 |
+
Args:
|
| 43 |
+
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 44 |
+
window_size (int): window size.
|
| 45 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 46 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 47 |
+
Returns:
|
| 48 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 49 |
+
"""
|
| 50 |
+
Hp, Wp = pad_hw
|
| 51 |
+
H, W = hw
|
| 52 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 53 |
+
x = windows.reshape(
|
| 54 |
+
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
| 55 |
+
)
|
| 56 |
+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
| 57 |
+
|
| 58 |
+
if Hp > H or Wp > W:
|
| 59 |
+
x = x[:, :H, :W, :]
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PatchEmbed(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Image to Patch Embedding.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
kernel_size: Tuple[int, ...] = (7, 7),
|
| 71 |
+
stride: Tuple[int, ...] = (4, 4),
|
| 72 |
+
padding: Tuple[int, ...] = (3, 3),
|
| 73 |
+
in_chans: int = 3,
|
| 74 |
+
embed_dim: int = 768,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Args:
|
| 78 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 79 |
+
stride (Tuple): stride of the projection layer.
|
| 80 |
+
padding (Tuple): padding size of the projection layer.
|
| 81 |
+
in_chans (int): Number of input image channels.
|
| 82 |
+
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 83 |
+
"""
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.proj = nn.Conv2d(
|
| 86 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
x = self.proj(x)
|
| 91 |
+
# B C H W -> B H W C
|
| 92 |
+
x = x.permute(0, 2, 3, 1)
|
| 93 |
+
return x
|
sam2_repo/sam2/modeling/memory_attention.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn, Tensor
|
| 11 |
+
|
| 12 |
+
from sam2.modeling.sam.transformer import RoPEAttention
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MemoryAttentionLayer(nn.Module):
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
activation: str,
|
| 22 |
+
cross_attention: nn.Module,
|
| 23 |
+
d_model: int,
|
| 24 |
+
dim_feedforward: int,
|
| 25 |
+
dropout: float,
|
| 26 |
+
pos_enc_at_attn: bool,
|
| 27 |
+
pos_enc_at_cross_attn_keys: bool,
|
| 28 |
+
pos_enc_at_cross_attn_queries: bool,
|
| 29 |
+
self_attention: nn.Module,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.d_model = d_model
|
| 33 |
+
self.dim_feedforward = dim_feedforward
|
| 34 |
+
self.dropout_value = dropout
|
| 35 |
+
self.self_attn = self_attention
|
| 36 |
+
self.cross_attn_image = cross_attention
|
| 37 |
+
|
| 38 |
+
# Implementation of Feedforward model
|
| 39 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 40 |
+
self.dropout = nn.Dropout(dropout)
|
| 41 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 42 |
+
|
| 43 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 44 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 45 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 46 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 47 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 48 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 49 |
+
|
| 50 |
+
self.activation_str = activation
|
| 51 |
+
self.activation = get_activation_fn(activation)
|
| 52 |
+
|
| 53 |
+
# Where to add pos enc
|
| 54 |
+
self.pos_enc_at_attn = pos_enc_at_attn
|
| 55 |
+
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
|
| 56 |
+
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
|
| 57 |
+
|
| 58 |
+
def _forward_sa(self, tgt, query_pos):
|
| 59 |
+
# Self-Attention
|
| 60 |
+
tgt2 = self.norm1(tgt)
|
| 61 |
+
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
| 62 |
+
tgt2 = self.self_attn(q, k, v=tgt2)
|
| 63 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 64 |
+
return tgt
|
| 65 |
+
|
| 66 |
+
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
|
| 67 |
+
kwds = {}
|
| 68 |
+
if num_k_exclude_rope > 0:
|
| 69 |
+
assert isinstance(self.cross_attn_image, RoPEAttention)
|
| 70 |
+
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
|
| 71 |
+
|
| 72 |
+
# Cross-Attention
|
| 73 |
+
tgt2 = self.norm2(tgt)
|
| 74 |
+
tgt2 = self.cross_attn_image(
|
| 75 |
+
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
| 76 |
+
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
| 77 |
+
v=memory,
|
| 78 |
+
**kwds,
|
| 79 |
+
)
|
| 80 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 81 |
+
return tgt
|
| 82 |
+
|
| 83 |
+
def forward(
|
| 84 |
+
self,
|
| 85 |
+
tgt,
|
| 86 |
+
memory,
|
| 87 |
+
pos: Optional[Tensor] = None,
|
| 88 |
+
query_pos: Optional[Tensor] = None,
|
| 89 |
+
num_k_exclude_rope: int = 0,
|
| 90 |
+
) -> torch.Tensor:
|
| 91 |
+
|
| 92 |
+
# Self-Attn, Cross-Attn
|
| 93 |
+
tgt = self._forward_sa(tgt, query_pos)
|
| 94 |
+
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
| 95 |
+
# MLP
|
| 96 |
+
tgt2 = self.norm3(tgt)
|
| 97 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 98 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 99 |
+
return tgt
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MemoryAttention(nn.Module):
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
d_model: int,
|
| 106 |
+
pos_enc_at_input: bool,
|
| 107 |
+
layer: nn.Module,
|
| 108 |
+
num_layers: int,
|
| 109 |
+
batch_first: bool = True, # Do layers expect batch first input?
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.d_model = d_model
|
| 113 |
+
self.layers = get_clones(layer, num_layers)
|
| 114 |
+
self.num_layers = num_layers
|
| 115 |
+
self.norm = nn.LayerNorm(d_model)
|
| 116 |
+
self.pos_enc_at_input = pos_enc_at_input
|
| 117 |
+
self.batch_first = batch_first
|
| 118 |
+
|
| 119 |
+
def forward(
|
| 120 |
+
self,
|
| 121 |
+
curr: torch.Tensor, # self-attention inputs
|
| 122 |
+
memory: torch.Tensor, # cross-attention inputs
|
| 123 |
+
curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
|
| 124 |
+
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
|
| 125 |
+
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
| 126 |
+
):
|
| 127 |
+
if isinstance(curr, list):
|
| 128 |
+
assert isinstance(curr_pos, list)
|
| 129 |
+
assert len(curr) == len(curr_pos) == 1
|
| 130 |
+
curr, curr_pos = (
|
| 131 |
+
curr[0],
|
| 132 |
+
curr_pos[0],
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
assert (
|
| 136 |
+
curr.shape[1] == memory.shape[1]
|
| 137 |
+
), "Batch size must be the same for curr and memory"
|
| 138 |
+
|
| 139 |
+
output = curr
|
| 140 |
+
if self.pos_enc_at_input and curr_pos is not None:
|
| 141 |
+
output = output + 0.1 * curr_pos
|
| 142 |
+
|
| 143 |
+
if self.batch_first:
|
| 144 |
+
# Convert to batch first
|
| 145 |
+
output = output.transpose(0, 1)
|
| 146 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 147 |
+
memory = memory.transpose(0, 1)
|
| 148 |
+
memory_pos = memory_pos.transpose(0, 1)
|
| 149 |
+
|
| 150 |
+
for layer in self.layers:
|
| 151 |
+
kwds = {}
|
| 152 |
+
if isinstance(layer.cross_attn_image, RoPEAttention):
|
| 153 |
+
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
|
| 154 |
+
|
| 155 |
+
output = layer(
|
| 156 |
+
tgt=output,
|
| 157 |
+
memory=memory,
|
| 158 |
+
pos=memory_pos,
|
| 159 |
+
query_pos=curr_pos,
|
| 160 |
+
**kwds,
|
| 161 |
+
)
|
| 162 |
+
normed_output = self.norm(output)
|
| 163 |
+
|
| 164 |
+
if self.batch_first:
|
| 165 |
+
# Convert back to seq first
|
| 166 |
+
normed_output = normed_output.transpose(0, 1)
|
| 167 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 168 |
+
|
| 169 |
+
return normed_output
|
sam2_repo/sam2/modeling/memory_encoder.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MaskDownSampler(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Progressively downsample a mask by total_stride, each time by stride.
|
| 20 |
+
Note that LayerNorm is applied per *token*, like in ViT.
|
| 21 |
+
|
| 22 |
+
With each downsample (by a factor stride**2), channel capacity increases by the same factor.
|
| 23 |
+
In the end, we linearly project to embed_dim channels.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
embed_dim=256,
|
| 29 |
+
kernel_size=4,
|
| 30 |
+
stride=4,
|
| 31 |
+
padding=0,
|
| 32 |
+
total_stride=16,
|
| 33 |
+
activation=nn.GELU,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
num_layers = int(math.log2(total_stride) // math.log2(stride))
|
| 37 |
+
assert stride**num_layers == total_stride
|
| 38 |
+
self.encoder = nn.Sequential()
|
| 39 |
+
mask_in_chans, mask_out_chans = 1, 1
|
| 40 |
+
for _ in range(num_layers):
|
| 41 |
+
mask_out_chans = mask_in_chans * (stride**2)
|
| 42 |
+
self.encoder.append(
|
| 43 |
+
nn.Conv2d(
|
| 44 |
+
mask_in_chans,
|
| 45 |
+
mask_out_chans,
|
| 46 |
+
kernel_size=kernel_size,
|
| 47 |
+
stride=stride,
|
| 48 |
+
padding=padding,
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
self.encoder.append(LayerNorm2d(mask_out_chans))
|
| 52 |
+
self.encoder.append(activation())
|
| 53 |
+
mask_in_chans = mask_out_chans
|
| 54 |
+
|
| 55 |
+
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return self.encoder(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
|
| 62 |
+
class CXBlock(nn.Module):
|
| 63 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
| 64 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 65 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 66 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
dim (int): Number of input channels.
|
| 70 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 71 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
dim,
|
| 77 |
+
kernel_size=7,
|
| 78 |
+
padding=3,
|
| 79 |
+
drop_path=0.0,
|
| 80 |
+
layer_scale_init_value=1e-6,
|
| 81 |
+
use_dwconv=True,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.dwconv = nn.Conv2d(
|
| 85 |
+
dim,
|
| 86 |
+
dim,
|
| 87 |
+
kernel_size=kernel_size,
|
| 88 |
+
padding=padding,
|
| 89 |
+
groups=dim if use_dwconv else 1,
|
| 90 |
+
) # depthwise conv
|
| 91 |
+
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 92 |
+
self.pwconv1 = nn.Linear(
|
| 93 |
+
dim, 4 * dim
|
| 94 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 95 |
+
self.act = nn.GELU()
|
| 96 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 97 |
+
self.gamma = (
|
| 98 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 99 |
+
if layer_scale_init_value > 0
|
| 100 |
+
else None
|
| 101 |
+
)
|
| 102 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
input = x
|
| 106 |
+
x = self.dwconv(x)
|
| 107 |
+
x = self.norm(x)
|
| 108 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 109 |
+
x = self.pwconv1(x)
|
| 110 |
+
x = self.act(x)
|
| 111 |
+
x = self.pwconv2(x)
|
| 112 |
+
if self.gamma is not None:
|
| 113 |
+
x = self.gamma * x
|
| 114 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 115 |
+
|
| 116 |
+
x = input + self.drop_path(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Fuser(nn.Module):
|
| 121 |
+
def __init__(self, layer, num_layers, dim=None, input_projection=False):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.proj = nn.Identity()
|
| 124 |
+
self.layers = get_clones(layer, num_layers)
|
| 125 |
+
|
| 126 |
+
if input_projection:
|
| 127 |
+
assert dim is not None
|
| 128 |
+
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
| 129 |
+
|
| 130 |
+
def forward(self, x):
|
| 131 |
+
# normally x: (N, C, H, W)
|
| 132 |
+
x = self.proj(x)
|
| 133 |
+
for layer in self.layers:
|
| 134 |
+
x = layer(x)
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class MemoryEncoder(nn.Module):
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
out_dim,
|
| 142 |
+
mask_downsampler,
|
| 143 |
+
fuser,
|
| 144 |
+
position_encoding,
|
| 145 |
+
in_dim=256, # in_dim of pix_feats
|
| 146 |
+
):
|
| 147 |
+
super().__init__()
|
| 148 |
+
|
| 149 |
+
self.mask_downsampler = mask_downsampler
|
| 150 |
+
|
| 151 |
+
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
| 152 |
+
self.fuser = fuser
|
| 153 |
+
self.position_encoding = position_encoding
|
| 154 |
+
self.out_proj = nn.Identity()
|
| 155 |
+
if out_dim != in_dim:
|
| 156 |
+
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
| 157 |
+
|
| 158 |
+
def forward(
|
| 159 |
+
self,
|
| 160 |
+
pix_feat: torch.Tensor,
|
| 161 |
+
masks: torch.Tensor,
|
| 162 |
+
skip_mask_sigmoid: bool = False,
|
| 163 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 164 |
+
## Process masks
|
| 165 |
+
# sigmoid, so that less domain shift from gt masks which are bool
|
| 166 |
+
if not skip_mask_sigmoid:
|
| 167 |
+
masks = F.sigmoid(masks)
|
| 168 |
+
masks = self.mask_downsampler(masks)
|
| 169 |
+
|
| 170 |
+
## Fuse pix_feats and downsampled masks
|
| 171 |
+
# in case the visual features are on CPU, cast them to CUDA
|
| 172 |
+
pix_feat = pix_feat.to(masks.device)
|
| 173 |
+
|
| 174 |
+
x = self.pix_feat_proj(pix_feat)
|
| 175 |
+
x = x + masks
|
| 176 |
+
x = self.fuser(x)
|
| 177 |
+
x = self.out_proj(x)
|
| 178 |
+
|
| 179 |
+
pos = self.position_encoding(x).to(x.dtype)
|
| 180 |
+
|
| 181 |
+
return {"vision_features": x, "vision_pos_enc": [pos]}
|
sam2_repo/sam2/modeling/position_encoding.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Any, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PositionEmbeddingSine(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 19 |
+
used by the Attention Is All You Need paper, generalized to work on images.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
num_pos_feats,
|
| 25 |
+
temperature: int = 10000,
|
| 26 |
+
normalize: bool = True,
|
| 27 |
+
scale: Optional[float] = None,
|
| 28 |
+
# Following settings only relevant
|
| 29 |
+
# for warmping up cache for compilation
|
| 30 |
+
warmup_cache: bool = True,
|
| 31 |
+
image_size: int = 1024,
|
| 32 |
+
strides: Tuple[int] = (4, 8, 16, 32),
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
| 36 |
+
self.num_pos_feats = num_pos_feats // 2
|
| 37 |
+
self.temperature = temperature
|
| 38 |
+
self.normalize = normalize
|
| 39 |
+
if scale is not None and normalize is False:
|
| 40 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 41 |
+
if scale is None:
|
| 42 |
+
scale = 2 * math.pi
|
| 43 |
+
self.scale = scale
|
| 44 |
+
|
| 45 |
+
self.cache = {}
|
| 46 |
+
if warmup_cache and torch.cuda.is_available():
|
| 47 |
+
# Warmup cache for cuda, to help with compilation
|
| 48 |
+
device = torch.device("cuda")
|
| 49 |
+
for stride in strides:
|
| 50 |
+
cache_key = (image_size // stride, image_size // stride)
|
| 51 |
+
self._pe(1, device, *cache_key)
|
| 52 |
+
|
| 53 |
+
def _encode_xy(self, x, y):
|
| 54 |
+
# The positions are expected to be normalized
|
| 55 |
+
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
| 56 |
+
x_embed = x * self.scale
|
| 57 |
+
y_embed = y * self.scale
|
| 58 |
+
|
| 59 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 60 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 61 |
+
|
| 62 |
+
pos_x = x_embed[:, None] / dim_t
|
| 63 |
+
pos_y = y_embed[:, None] / dim_t
|
| 64 |
+
pos_x = torch.stack(
|
| 65 |
+
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
|
| 66 |
+
).flatten(1)
|
| 67 |
+
pos_y = torch.stack(
|
| 68 |
+
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
|
| 69 |
+
).flatten(1)
|
| 70 |
+
return pos_x, pos_y
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
def encode_boxes(self, x, y, w, h):
|
| 74 |
+
pos_x, pos_y = self._encode_xy(x, y)
|
| 75 |
+
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
| 76 |
+
return pos
|
| 77 |
+
|
| 78 |
+
encode = encode_boxes # Backwards compatibility
|
| 79 |
+
|
| 80 |
+
@torch.no_grad()
|
| 81 |
+
def encode_points(self, x, y, labels):
|
| 82 |
+
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
|
| 83 |
+
assert bx == by and nx == ny and bx == bl and nx == nl
|
| 84 |
+
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
|
| 85 |
+
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
|
| 86 |
+
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
|
| 87 |
+
return pos
|
| 88 |
+
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def _pe(self, B, device, *cache_key):
|
| 91 |
+
H, W = cache_key
|
| 92 |
+
if cache_key in self.cache:
|
| 93 |
+
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
| 94 |
+
|
| 95 |
+
y_embed = (
|
| 96 |
+
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
| 97 |
+
.view(1, -1, 1)
|
| 98 |
+
.repeat(B, 1, W)
|
| 99 |
+
)
|
| 100 |
+
x_embed = (
|
| 101 |
+
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
| 102 |
+
.view(1, 1, -1)
|
| 103 |
+
.repeat(B, H, 1)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if self.normalize:
|
| 107 |
+
eps = 1e-6
|
| 108 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 109 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 110 |
+
|
| 111 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
|
| 112 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 113 |
+
|
| 114 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 115 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 116 |
+
pos_x = torch.stack(
|
| 117 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
| 118 |
+
).flatten(3)
|
| 119 |
+
pos_y = torch.stack(
|
| 120 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 121 |
+
).flatten(3)
|
| 122 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 123 |
+
self.cache[cache_key] = pos[0]
|
| 124 |
+
return pos
|
| 125 |
+
|
| 126 |
+
@torch.no_grad()
|
| 127 |
+
def forward(self, x: torch.Tensor):
|
| 128 |
+
B = x.shape[0]
|
| 129 |
+
cache_key = (x.shape[-2], x.shape[-1])
|
| 130 |
+
return self._pe(B, x.device, *cache_key)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PositionEmbeddingRandom(nn.Module):
|
| 134 |
+
"""
|
| 135 |
+
Positional encoding using random spatial frequencies.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 139 |
+
super().__init__()
|
| 140 |
+
if scale is None or scale <= 0.0:
|
| 141 |
+
scale = 1.0
|
| 142 |
+
self.register_buffer(
|
| 143 |
+
"positional_encoding_gaussian_matrix",
|
| 144 |
+
scale * torch.randn((2, num_pos_feats)),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 149 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 150 |
+
coords = 2 * coords - 1
|
| 151 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 152 |
+
coords = 2 * np.pi * coords
|
| 153 |
+
# outputs d_1 x ... x d_n x C shape
|
| 154 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 155 |
+
|
| 156 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 157 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 158 |
+
h, w = size
|
| 159 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 160 |
+
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
| 161 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 162 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 163 |
+
y_embed = y_embed / h
|
| 164 |
+
x_embed = x_embed / w
|
| 165 |
+
|
| 166 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 167 |
+
return pe.permute(2, 0, 1) # C x H x W
|
| 168 |
+
|
| 169 |
+
def forward_with_coords(
|
| 170 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 171 |
+
) -> torch.Tensor:
|
| 172 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
| 173 |
+
coords = coords_input.clone()
|
| 174 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 175 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 176 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Rotary Positional Encoding, adapted from:
|
| 180 |
+
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 181 |
+
# 2. https://github.com/naver-ai/rope-vit
|
| 182 |
+
# 3. https://github.com/lucidrains/rotary-embedding-torch
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def init_t_xy(end_x: int, end_y: int):
|
| 186 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
| 187 |
+
t_x = (t % end_x).float()
|
| 188 |
+
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
| 189 |
+
return t_x, t_y
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
| 193 |
+
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 194 |
+
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 195 |
+
|
| 196 |
+
t_x, t_y = init_t_xy(end_x, end_y)
|
| 197 |
+
freqs_x = torch.outer(t_x, freqs_x)
|
| 198 |
+
freqs_y = torch.outer(t_y, freqs_y)
|
| 199 |
+
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
| 200 |
+
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
|
| 201 |
+
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 205 |
+
ndim = x.ndim
|
| 206 |
+
assert 0 <= 1 < ndim
|
| 207 |
+
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
| 208 |
+
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
| 209 |
+
return freqs_cis.view(*shape)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def apply_rotary_enc(
|
| 213 |
+
xq: torch.Tensor,
|
| 214 |
+
xk: torch.Tensor,
|
| 215 |
+
freqs_cis: torch.Tensor,
|
| 216 |
+
repeat_freqs_k: bool = False,
|
| 217 |
+
):
|
| 218 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 219 |
+
xk_ = (
|
| 220 |
+
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 221 |
+
if xk.shape[-2] != 0
|
| 222 |
+
else None
|
| 223 |
+
)
|
| 224 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 225 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 226 |
+
if xk_ is None:
|
| 227 |
+
# no keys to rotate, due to dropout
|
| 228 |
+
return xq_out.type_as(xq).to(xq.device), xk
|
| 229 |
+
# repeat freqs along seq_len dim to match k seq_len
|
| 230 |
+
if repeat_freqs_k:
|
| 231 |
+
r = xk_.shape[-2] // xq_.shape[-2]
|
| 232 |
+
if freqs_cis.is_cuda:
|
| 233 |
+
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
| 234 |
+
else:
|
| 235 |
+
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
| 236 |
+
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
| 237 |
+
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
| 238 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 239 |
+
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
sam2_repo/sam2/modeling/sam/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2_repo/sam2/modeling/sam/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
sam2_repo/sam2/modeling/sam/__pycache__/mask_decoder.cpython-313.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
sam2_repo/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-313.pyc
ADDED
|
Binary file (9.69 kB). View file
|
|
|
sam2_repo/sam2/modeling/sam/__pycache__/transformer.cpython-313.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
sam2_repo/sam2/modeling/sam/mask_decoder.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from sam2.modeling.sam2_utils import LayerNorm2d, MLP
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MaskDecoder(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
*,
|
| 19 |
+
transformer_dim: int,
|
| 20 |
+
transformer: nn.Module,
|
| 21 |
+
num_multimask_outputs: int = 3,
|
| 22 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 23 |
+
iou_head_depth: int = 3,
|
| 24 |
+
iou_head_hidden_dim: int = 256,
|
| 25 |
+
use_high_res_features: bool = False,
|
| 26 |
+
iou_prediction_use_sigmoid=False,
|
| 27 |
+
dynamic_multimask_via_stability=False,
|
| 28 |
+
dynamic_multimask_stability_delta=0.05,
|
| 29 |
+
dynamic_multimask_stability_thresh=0.98,
|
| 30 |
+
pred_obj_scores: bool = False,
|
| 31 |
+
pred_obj_scores_mlp: bool = False,
|
| 32 |
+
use_multimask_token_for_obj_ptr: bool = False,
|
| 33 |
+
) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Predicts masks given an image and prompt embeddings, using a
|
| 36 |
+
transformer architecture.
|
| 37 |
+
|
| 38 |
+
Arguments:
|
| 39 |
+
transformer_dim (int): the channel dimension of the transformer
|
| 40 |
+
transformer (nn.Module): the transformer used to predict masks
|
| 41 |
+
num_multimask_outputs (int): the number of masks to predict
|
| 42 |
+
when disambiguating masks
|
| 43 |
+
activation (nn.Module): the type of activation to use when
|
| 44 |
+
upscaling masks
|
| 45 |
+
iou_head_depth (int): the depth of the MLP used to predict
|
| 46 |
+
mask quality
|
| 47 |
+
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
| 48 |
+
used to predict mask quality
|
| 49 |
+
"""
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.transformer_dim = transformer_dim
|
| 52 |
+
self.transformer = transformer
|
| 53 |
+
|
| 54 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 55 |
+
|
| 56 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
| 57 |
+
self.num_mask_tokens = num_multimask_outputs + 1
|
| 58 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
| 59 |
+
|
| 60 |
+
self.pred_obj_scores = pred_obj_scores
|
| 61 |
+
if self.pred_obj_scores:
|
| 62 |
+
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
| 63 |
+
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
| 64 |
+
|
| 65 |
+
self.output_upscaling = nn.Sequential(
|
| 66 |
+
nn.ConvTranspose2d(
|
| 67 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 68 |
+
),
|
| 69 |
+
LayerNorm2d(transformer_dim // 4),
|
| 70 |
+
activation(),
|
| 71 |
+
nn.ConvTranspose2d(
|
| 72 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 73 |
+
),
|
| 74 |
+
activation(),
|
| 75 |
+
)
|
| 76 |
+
self.use_high_res_features = use_high_res_features
|
| 77 |
+
if use_high_res_features:
|
| 78 |
+
self.conv_s0 = nn.Conv2d(
|
| 79 |
+
transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
|
| 80 |
+
)
|
| 81 |
+
self.conv_s1 = nn.Conv2d(
|
| 82 |
+
transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.output_hypernetworks_mlps = nn.ModuleList(
|
| 86 |
+
[
|
| 87 |
+
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
| 88 |
+
for i in range(self.num_mask_tokens)
|
| 89 |
+
]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.iou_prediction_head = MLP(
|
| 93 |
+
transformer_dim,
|
| 94 |
+
iou_head_hidden_dim,
|
| 95 |
+
self.num_mask_tokens,
|
| 96 |
+
iou_head_depth,
|
| 97 |
+
sigmoid_output=iou_prediction_use_sigmoid,
|
| 98 |
+
)
|
| 99 |
+
if self.pred_obj_scores:
|
| 100 |
+
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
| 101 |
+
if pred_obj_scores_mlp:
|
| 102 |
+
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
| 103 |
+
|
| 104 |
+
# When outputting a single mask, optionally we can dynamically fall back to the best
|
| 105 |
+
# multimask output token if the single mask output token gives low stability scores.
|
| 106 |
+
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
| 107 |
+
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
| 108 |
+
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
| 109 |
+
|
| 110 |
+
def forward(
|
| 111 |
+
self,
|
| 112 |
+
image_embeddings: torch.Tensor,
|
| 113 |
+
image_pe: torch.Tensor,
|
| 114 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 115 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 116 |
+
multimask_output: bool,
|
| 117 |
+
repeat_image: bool,
|
| 118 |
+
high_res_features: Optional[List[torch.Tensor]] = None,
|
| 119 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 120 |
+
"""
|
| 121 |
+
Predict masks given image and prompt embeddings.
|
| 122 |
+
|
| 123 |
+
Arguments:
|
| 124 |
+
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
| 125 |
+
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
| 126 |
+
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
| 127 |
+
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
| 128 |
+
multimask_output (bool): Whether to return multiple masks or a single
|
| 129 |
+
mask.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
torch.Tensor: batched predicted masks
|
| 133 |
+
torch.Tensor: batched predictions of mask quality
|
| 134 |
+
torch.Tensor: batched SAM token for mask output
|
| 135 |
+
"""
|
| 136 |
+
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
| 137 |
+
image_embeddings=image_embeddings,
|
| 138 |
+
image_pe=image_pe,
|
| 139 |
+
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
| 140 |
+
dense_prompt_embeddings=dense_prompt_embeddings,
|
| 141 |
+
repeat_image=repeat_image,
|
| 142 |
+
high_res_features=high_res_features,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Select the correct mask or masks for output
|
| 146 |
+
if multimask_output:
|
| 147 |
+
masks = masks[:, 1:, :, :]
|
| 148 |
+
iou_pred = iou_pred[:, 1:]
|
| 149 |
+
elif self.dynamic_multimask_via_stability and not self.training:
|
| 150 |
+
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
| 151 |
+
else:
|
| 152 |
+
masks = masks[:, 0:1, :, :]
|
| 153 |
+
iou_pred = iou_pred[:, 0:1]
|
| 154 |
+
|
| 155 |
+
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
| 156 |
+
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
| 157 |
+
else:
|
| 158 |
+
# Take the mask output token. Here we *always* use the token for single mask output.
|
| 159 |
+
# At test time, even if we track after 1-click (and using multimask_output=True),
|
| 160 |
+
# we still take the single mask token here. The rationale is that we always track
|
| 161 |
+
# after multiple clicks during training, so the past tokens seen during training
|
| 162 |
+
# are always the single mask token (and we'll let it be the object-memory token).
|
| 163 |
+
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
| 164 |
+
|
| 165 |
+
# Prepare output
|
| 166 |
+
return masks, iou_pred, sam_tokens_out, object_score_logits
|
| 167 |
+
|
| 168 |
+
def predict_masks(
|
| 169 |
+
self,
|
| 170 |
+
image_embeddings: torch.Tensor,
|
| 171 |
+
image_pe: torch.Tensor,
|
| 172 |
+
sparse_prompt_embeddings: torch.Tensor,
|
| 173 |
+
dense_prompt_embeddings: torch.Tensor,
|
| 174 |
+
repeat_image: bool,
|
| 175 |
+
high_res_features: Optional[List[torch.Tensor]] = None,
|
| 176 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 177 |
+
"""Predicts masks. See 'forward' for more details."""
|
| 178 |
+
# Concatenate output tokens
|
| 179 |
+
s = 0
|
| 180 |
+
if self.pred_obj_scores:
|
| 181 |
+
output_tokens = torch.cat(
|
| 182 |
+
[
|
| 183 |
+
self.obj_score_token.weight,
|
| 184 |
+
self.iou_token.weight,
|
| 185 |
+
self.mask_tokens.weight,
|
| 186 |
+
],
|
| 187 |
+
dim=0,
|
| 188 |
+
)
|
| 189 |
+
s = 1
|
| 190 |
+
else:
|
| 191 |
+
output_tokens = torch.cat(
|
| 192 |
+
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
| 193 |
+
)
|
| 194 |
+
output_tokens = output_tokens.unsqueeze(0).expand(
|
| 195 |
+
sparse_prompt_embeddings.size(0), -1, -1
|
| 196 |
+
)
|
| 197 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
| 198 |
+
|
| 199 |
+
# Expand per-image data in batch direction to be per-mask
|
| 200 |
+
if repeat_image:
|
| 201 |
+
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
| 202 |
+
else:
|
| 203 |
+
assert image_embeddings.shape[0] == tokens.shape[0]
|
| 204 |
+
src = image_embeddings
|
| 205 |
+
src = src + dense_prompt_embeddings
|
| 206 |
+
assert (
|
| 207 |
+
image_pe.size(0) == 1
|
| 208 |
+
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
| 209 |
+
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
| 210 |
+
b, c, h, w = src.shape
|
| 211 |
+
|
| 212 |
+
# Run the transformer
|
| 213 |
+
hs, src = self.transformer(src, pos_src, tokens)
|
| 214 |
+
iou_token_out = hs[:, s, :]
|
| 215 |
+
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
| 216 |
+
|
| 217 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
| 218 |
+
src = src.transpose(1, 2).view(b, c, h, w)
|
| 219 |
+
if not self.use_high_res_features:
|
| 220 |
+
upscaled_embedding = self.output_upscaling(src)
|
| 221 |
+
else:
|
| 222 |
+
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
| 223 |
+
feat_s0, feat_s1 = high_res_features
|
| 224 |
+
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
| 225 |
+
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
| 226 |
+
|
| 227 |
+
hyper_in_list: List[torch.Tensor] = []
|
| 228 |
+
for i in range(self.num_mask_tokens):
|
| 229 |
+
hyper_in_list.append(
|
| 230 |
+
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 231 |
+
)
|
| 232 |
+
hyper_in = torch.stack(hyper_in_list, dim=1)
|
| 233 |
+
b, c, h, w = upscaled_embedding.shape
|
| 234 |
+
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
| 235 |
+
|
| 236 |
+
# Generate mask quality predictions
|
| 237 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 238 |
+
if self.pred_obj_scores:
|
| 239 |
+
assert s == 1
|
| 240 |
+
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
| 241 |
+
else:
|
| 242 |
+
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
|
| 243 |
+
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
| 244 |
+
|
| 245 |
+
return masks, iou_pred, mask_tokens_out, object_score_logits
|
| 246 |
+
|
| 247 |
+
def _get_stability_scores(self, mask_logits):
|
| 248 |
+
"""
|
| 249 |
+
Compute stability scores of the mask logits based on the IoU between upper and
|
| 250 |
+
lower thresholds.
|
| 251 |
+
"""
|
| 252 |
+
mask_logits = mask_logits.flatten(-2)
|
| 253 |
+
stability_delta = self.dynamic_multimask_stability_delta
|
| 254 |
+
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
| 255 |
+
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
| 256 |
+
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
|
| 257 |
+
return stability_scores
|
| 258 |
+
|
| 259 |
+
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
| 260 |
+
"""
|
| 261 |
+
When outputting a single mask, if the stability score from the current single-mask
|
| 262 |
+
output (based on output token 0) falls below a threshold, we instead select from
|
| 263 |
+
multi-mask outputs (based on output token 1~3) the mask with the highest predicted
|
| 264 |
+
IoU score. This is intended to ensure a valid mask for both clicking and tracking.
|
| 265 |
+
"""
|
| 266 |
+
# The best mask from multimask output tokens (1~3)
|
| 267 |
+
multimask_logits = all_mask_logits[:, 1:, :, :]
|
| 268 |
+
multimask_iou_scores = all_iou_scores[:, 1:]
|
| 269 |
+
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
| 270 |
+
batch_inds = torch.arange(
|
| 271 |
+
multimask_iou_scores.size(0), device=all_iou_scores.device
|
| 272 |
+
)
|
| 273 |
+
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
| 274 |
+
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
| 275 |
+
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
| 276 |
+
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
| 277 |
+
|
| 278 |
+
# The mask from singlemask output token 0 and its stability score
|
| 279 |
+
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
| 280 |
+
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
| 281 |
+
stability_scores = self._get_stability_scores(singlemask_logits)
|
| 282 |
+
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
| 283 |
+
|
| 284 |
+
# Dynamically fall back to best multimask output upon low stability scores.
|
| 285 |
+
mask_logits_out = torch.where(
|
| 286 |
+
is_stable[..., None, None].expand_as(singlemask_logits),
|
| 287 |
+
singlemask_logits,
|
| 288 |
+
best_multimask_logits,
|
| 289 |
+
)
|
| 290 |
+
iou_scores_out = torch.where(
|
| 291 |
+
is_stable.expand_as(singlemask_iou_scores),
|
| 292 |
+
singlemask_iou_scores,
|
| 293 |
+
best_multimask_iou_scores,
|
| 294 |
+
)
|
| 295 |
+
return mask_logits_out, iou_scores_out
|
sam2_repo/sam2/modeling/sam/prompt_encoder.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Tuple, Type
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from sam2.modeling.position_encoding import PositionEmbeddingRandom
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_utils import LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PromptEncoder(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
embed_dim: int,
|
| 21 |
+
image_embedding_size: Tuple[int, int],
|
| 22 |
+
input_image_size: Tuple[int, int],
|
| 23 |
+
mask_in_chans: int,
|
| 24 |
+
activation: Type[nn.Module] = nn.GELU,
|
| 25 |
+
) -> None:
|
| 26 |
+
"""
|
| 27 |
+
Encodes prompts for input to SAM's mask decoder.
|
| 28 |
+
|
| 29 |
+
Arguments:
|
| 30 |
+
embed_dim (int): The prompts' embedding dimension
|
| 31 |
+
image_embedding_size (tuple(int, int)): The spatial size of the
|
| 32 |
+
image embedding, as (H, W).
|
| 33 |
+
input_image_size (int): The padded size of the image as input
|
| 34 |
+
to the image encoder, as (H, W).
|
| 35 |
+
mask_in_chans (int): The number of hidden channels used for
|
| 36 |
+
encoding input masks.
|
| 37 |
+
activation (nn.Module): The activation to use when encoding
|
| 38 |
+
input masks.
|
| 39 |
+
"""
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.embed_dim = embed_dim
|
| 42 |
+
self.input_image_size = input_image_size
|
| 43 |
+
self.image_embedding_size = image_embedding_size
|
| 44 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
| 45 |
+
|
| 46 |
+
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
| 47 |
+
point_embeddings = [
|
| 48 |
+
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
| 49 |
+
]
|
| 50 |
+
self.point_embeddings = nn.ModuleList(point_embeddings)
|
| 51 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
| 52 |
+
|
| 53 |
+
self.mask_input_size = (
|
| 54 |
+
4 * image_embedding_size[0],
|
| 55 |
+
4 * image_embedding_size[1],
|
| 56 |
+
)
|
| 57 |
+
self.mask_downscaling = nn.Sequential(
|
| 58 |
+
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
| 59 |
+
LayerNorm2d(mask_in_chans // 4),
|
| 60 |
+
activation(),
|
| 61 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
| 62 |
+
LayerNorm2d(mask_in_chans),
|
| 63 |
+
activation(),
|
| 64 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
| 65 |
+
)
|
| 66 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
| 67 |
+
|
| 68 |
+
def get_dense_pe(self) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Returns the positional encoding used to encode point prompts,
|
| 71 |
+
applied to a dense set of points the shape of the image encoding.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
torch.Tensor: Positional encoding with shape
|
| 75 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
| 76 |
+
"""
|
| 77 |
+
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
| 78 |
+
|
| 79 |
+
def _embed_points(
|
| 80 |
+
self,
|
| 81 |
+
points: torch.Tensor,
|
| 82 |
+
labels: torch.Tensor,
|
| 83 |
+
pad: bool,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Embeds point prompts."""
|
| 86 |
+
points = points + 0.5 # Shift to center of pixel
|
| 87 |
+
if pad:
|
| 88 |
+
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
| 89 |
+
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
| 90 |
+
points = torch.cat([points, padding_point], dim=1)
|
| 91 |
+
labels = torch.cat([labels, padding_label], dim=1)
|
| 92 |
+
point_embedding = self.pe_layer.forward_with_coords(
|
| 93 |
+
points, self.input_image_size
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
point_embedding = torch.where(
|
| 97 |
+
(labels == -1).unsqueeze(-1),
|
| 98 |
+
torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
|
| 99 |
+
point_embedding,
|
| 100 |
+
)
|
| 101 |
+
point_embedding = torch.where(
|
| 102 |
+
(labels == 0).unsqueeze(-1),
|
| 103 |
+
point_embedding + self.point_embeddings[0].weight,
|
| 104 |
+
point_embedding,
|
| 105 |
+
)
|
| 106 |
+
point_embedding = torch.where(
|
| 107 |
+
(labels == 1).unsqueeze(-1),
|
| 108 |
+
point_embedding + self.point_embeddings[1].weight,
|
| 109 |
+
point_embedding,
|
| 110 |
+
)
|
| 111 |
+
point_embedding = torch.where(
|
| 112 |
+
(labels == 2).unsqueeze(-1),
|
| 113 |
+
point_embedding + self.point_embeddings[2].weight,
|
| 114 |
+
point_embedding,
|
| 115 |
+
)
|
| 116 |
+
point_embedding = torch.where(
|
| 117 |
+
(labels == 3).unsqueeze(-1),
|
| 118 |
+
point_embedding + self.point_embeddings[3].weight,
|
| 119 |
+
point_embedding,
|
| 120 |
+
)
|
| 121 |
+
return point_embedding
|
| 122 |
+
|
| 123 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
| 124 |
+
"""Embeds box prompts."""
|
| 125 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
| 126 |
+
coords = boxes.reshape(-1, 2, 2)
|
| 127 |
+
corner_embedding = self.pe_layer.forward_with_coords(
|
| 128 |
+
coords, self.input_image_size
|
| 129 |
+
)
|
| 130 |
+
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
| 131 |
+
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
| 132 |
+
return corner_embedding
|
| 133 |
+
|
| 134 |
+
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
| 135 |
+
"""Embeds mask inputs."""
|
| 136 |
+
mask_embedding = self.mask_downscaling(masks)
|
| 137 |
+
return mask_embedding
|
| 138 |
+
|
| 139 |
+
def _get_batch_size(
|
| 140 |
+
self,
|
| 141 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 142 |
+
boxes: Optional[torch.Tensor],
|
| 143 |
+
masks: Optional[torch.Tensor],
|
| 144 |
+
) -> int:
|
| 145 |
+
"""
|
| 146 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
| 147 |
+
"""
|
| 148 |
+
if points is not None:
|
| 149 |
+
return points[0].shape[0]
|
| 150 |
+
elif boxes is not None:
|
| 151 |
+
return boxes.shape[0]
|
| 152 |
+
elif masks is not None:
|
| 153 |
+
return masks.shape[0]
|
| 154 |
+
else:
|
| 155 |
+
return 1
|
| 156 |
+
|
| 157 |
+
def _get_device(self) -> torch.device:
|
| 158 |
+
return self.point_embeddings[0].weight.device
|
| 159 |
+
|
| 160 |
+
def forward(
|
| 161 |
+
self,
|
| 162 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
| 163 |
+
boxes: Optional[torch.Tensor],
|
| 164 |
+
masks: Optional[torch.Tensor],
|
| 165 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 166 |
+
"""
|
| 167 |
+
Embeds different types of prompts, returning both sparse and dense
|
| 168 |
+
embeddings.
|
| 169 |
+
|
| 170 |
+
Arguments:
|
| 171 |
+
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
| 172 |
+
and labels to embed.
|
| 173 |
+
boxes (torch.Tensor or none): boxes to embed
|
| 174 |
+
masks (torch.Tensor or none): masks to embed
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
| 178 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
| 179 |
+
and boxes.
|
| 180 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
| 181 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
| 182 |
+
"""
|
| 183 |
+
bs = self._get_batch_size(points, boxes, masks)
|
| 184 |
+
sparse_embeddings = torch.empty(
|
| 185 |
+
(bs, 0, self.embed_dim), device=self._get_device()
|
| 186 |
+
)
|
| 187 |
+
if points is not None:
|
| 188 |
+
coords, labels = points
|
| 189 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
| 190 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
| 191 |
+
if boxes is not None:
|
| 192 |
+
box_embeddings = self._embed_boxes(boxes)
|
| 193 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
| 194 |
+
|
| 195 |
+
if masks is not None:
|
| 196 |
+
dense_embeddings = self._embed_masks(masks)
|
| 197 |
+
else:
|
| 198 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
| 199 |
+
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return sparse_embeddings, dense_embeddings
|