Mirko Trasciatti committed
Commit e7cbaa4 · 1 Parent(s): 962d5a0

Clean rebuild: Updated README, app.py, and requirements with pydantic fix

Files changed (2):
  1. README.md +1 -1
  2. app.py +54 -17
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎥
  colorFrom: blue
  colorTo: purple
  sdk: gradio
- sdk_version: 5.6.0
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
  license: apache-2.0
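
The only README change is the Gradio SDK pin, rolled back from 5.6.0 to 4.44.0. A hedged sketch of a startup guard (not part of this commit) that catches a mismatch between the installed package and the pinned sdk_version:

```python
# Hedged sketch: fail fast if the runtime Gradio doesn't match the README's
# pinned sdk_version (4.44.0).
import gradio

assert gradio.__version__.startswith("4.44."), (
    f"Expected Gradio 4.44.x, got {gradio.__version__}"
)
```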
app.py CHANGED
@@ -1,5 +1,6 @@
  """
- SAM2 Video Segmentation Space - Minimal Working Version
+ SAM2 Video Segmentation Space - SIMPLIFIED VERSION
+ Removes background from videos by tracking specified objects.
  """

  import gradio as gr
@@ -7,7 +8,9 @@ import torch
  import numpy as np
  import cv2
  import tempfile
+ import json
  import os
+ from typing import List, Tuple, Optional, Dict, Any
  from transformers import Sam2VideoModel, Sam2VideoProcessor
  from PIL import Image
  import spaces
@@ -25,7 +28,9 @@ def initialize_model():

      if torch.cuda.is_available():
          device = torch.device("cuda")
-         dtype = torch.float32
+         dtype = torch.float32  # Use float32 for universal GPU compatibility
+         print(f"CUDA available: {torch.cuda.is_available()}")
+         print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
      elif torch.backends.mps.is_available():
          device = torch.device("mps")
          dtype = torch.float32
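
The `float32` choice above deliberately trades memory for portability. For reference, an opt-in half-precision path (a sketch of an alternative, not what this commit does) would gate on hardware support:

```python
# Hedged alternative: bfloat16 roughly halves activation memory on GPUs that
# support it; fall back to the commit's float32 everywhere else.
import torch

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    device, dtype = torch.device("cuda"), torch.bfloat16
else:
    device, dtype = torch.device("cpu"), torch.float32
```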
@@ -33,7 +38,7 @@ def initialize_model():
          device = torch.device("cpu")
          dtype = torch.float32

-     print(f"Loading SAM2 model on {device}...")
+     print(f"Loading SAM2 model on {device} with dtype {dtype}...")

      model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype)
      processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME)
@@ -43,7 +48,7 @@


  def load_video_cv2(video_path):
-     """Load video using OpenCV."""
+     """Load video using OpenCV to preserve orientation."""
      cap = cv2.VideoCapture(video_path)
      frames = []
      fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
@@ -52,6 +57,7 @@ def load_video_cv2(video_path):
          ret, frame = cap.read()
          if not ret:
              break
+         # Convert BGR to RGB
          frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
          frames.append(Image.fromarray(frame_rgb))

@@ -68,27 +74,33 @@ def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
      initialize_model()

      try:
+         # Handle video_file - gr.File passes it as a string path directly
          if video_file is None:
              return None, "❌ Error: No video file provided"

+         # gr.File returns the file path as a string
          video_path = str(video_file)
+
          if not os.path.exists(video_path):
-             return None, f"❌ Error: Video file not found"
+             return None, f"❌ Error: Video file not found: {video_path}"
+
+         print(f"Processing video from: {video_path}")

          # Convert inputs
          point_x = int(float(point_x))
          point_y = int(float(point_y))
          frame_idx = int(float(frame_idx))

-         # Load video
+         # Load video using OpenCV to preserve orientation
          video_frames, video_info = load_video_cv2(video_path)
          fps = video_info.get('fps', 30.0)

          # Initialize inference session
+         dtype = torch.float32  # Use float32 for universal compatibility
          inference_session = processor.init_video_session(
              video=video_frames,
              inference_device=device,
-             dtype=torch.float32,
+             dtype=dtype,
          )

          # Add annotation
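
The textbox inputs are parsed with `int(float(...))` so values like "320.0" still work. Since any fixed default point only matches one resolution, a helper (hypothetical, not in this commit) could derive the default click point from the actual first frame:

```python
# Hedged sketch: pick the centre of the first frame as the default click
# point instead of a hard-coded coordinate. Name is illustrative.
import cv2

def default_center_point(video_path: str) -> tuple[int, int]:
    cap = cv2.VideoCapture(video_path)
    ok, frame = cap.read()
    cap.release()
    if not ok:
        return 320, 240  # fall back to the UI defaults
    height, width = frame.shape[:2]
    return width // 2, height // 2
```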
@@ -100,8 +112,11 @@ def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
              input_labels=[[[1]]],
          )

-         # Run inference
-         model(inference_session=inference_session, frame_idx=frame_idx)
+         # Run inference on first frame
+         outputs = model(
+             inference_session=inference_session,
+             frame_idx=frame_idx,
+         )

          # Propagate through video
          video_segments = {}
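
The diff elides the body of the propagation loop and the per-frame compositing that follows it. For context, a minimal sketch of the background-removal step those elided lines presumably perform (function and argument names assumed):

```python
# Hedged sketch: keep pixels inside the binary object mask, black out the rest.
import numpy as np

def remove_background_frame(frame_rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """frame_rgb: HxWx3 uint8; mask: HxW bool or 0/1."""
    return np.where(mask.astype(bool)[..., None], frame_rgb, 0)
```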
@@ -114,7 +129,7 @@ def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
              video_segments[sam2_output.frame_idx] = video_res_masks

          # Create output video
-         output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+         output_path = tempfile.mktemp(suffix=".mp4")
          first_frame = np.array(video_frames[0])
          height, width = first_frame.shape[:2]
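One caveat on the hunk above: `tempfile.mktemp()` is deprecated in the standard library because another process can claim the returned name before it is used. The pattern the old line used, `NamedTemporaryFile(delete=False)`, reserves the path atomically:

```python
# Safer equivalent of tempfile.mktemp(suffix=".mp4"): the file is created
# atomically and kept after the context exits; cv2.VideoWriter overwrites it.
import tempfile

with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
    output_path = tmp.name
```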
 
@@ -147,14 +162,16 @@ def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):

          out.release()

+         # Return the video file path (Gradio will handle it)
          if os.path.exists(output_path):
              return output_path, f"✅ Success! Processed {len(video_segments)} frames"
          else:
-             return None, "❌ Error: Output file was not created"
+             return None, f"❌ Error: Output file was not created"

      except Exception as e:
          import traceback
-         traceback.print_exc()
+         error_details = traceback.format_exc()
+         print(f"Error in segment_video_simple: {error_details}")
          return None, f"❌ Error: {str(e)}"

@@ -163,16 +180,25 @@ def create_app():
      initialize_model()

      with gr.Blocks(title="SAM2 Video Background Remover") as app:
-         gr.Markdown("# 🎥 SAM2 Video Background Remover")
-         gr.Markdown("Remove backgrounds from videos by tracking objects with SAM2")
+         gr.Markdown("""
+         # 🎥 SAM2 Video Background Remover
+
+         Remove backgrounds from videos by tracking objects with Meta's SAM2.
+
+         **How to use:**
+         1. Upload a video
+         2. Enter X, Y coordinates of the object to track (from first frame)
+         3. Click "Process Video"
+         """)

          with gr.Row():
              with gr.Column():
+                 # Using gr.File instead of gr.Video for better API compatibility
                  video_input = gr.File(label="Upload Video", file_types=["video"])

                  with gr.Row():
-                     point_x = gr.Textbox(label="Point X", value="360")
-                     point_y = gr.Textbox(label="Point Y", value="640")
+                     point_x = gr.Textbox(label="Point X", value="320")
+                     point_y = gr.Textbox(label="Point Y", value="240")

                  frame_idx = gr.Textbox(label="Frame Index", value="0")
                  remove_bg = gr.Checkbox(label="Remove Background", value=True)
@@ -188,10 +214,21 @@ def create_app():
              inputs=[video_input, point_x, point_y, frame_idx, remove_bg],
              outputs=[output_video, status_text]
          )
+
+         gr.Markdown("""
+         ### Tips:
+         - Point X, Y: Coordinates of the object in the video
+         - For a 720x1280 portrait video, center is typically X=360, Y=640
+         - For a 1920x1080 landscape video, center is typically X=960, Y=540
+         - Frame Index: Usually 0 (first frame)
+         - Processing time depends on video length (CPU processing is slow)
+         - Portrait and landscape videos are both supported!
+         """)

      return app


  if __name__ == "__main__":
      app = create_app()
-     app.launch(share=True)
+     app.launch(share=True, show_error=True)
+
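
With `gr.File` as the input component, the Space can also be driven programmatically. A hedged sketch using `gradio_client` (the Space id and `api_name` below are placeholders; check the Space's "Use via API" panel for the real values):

```python
# Hedged sketch: call the Space over its API. Space id and api_name are
# assumptions, not taken from this commit.
from gradio_client import Client, handle_file

client = Client("your-username/sam2-video-background-remover")
output_path, status = client.predict(
    handle_file("input.mp4"),  # video_input (gr.File)
    "320",                     # point_x
    "240",                     # point_y
    "0",                       # frame_idx
    True,                      # remove_bg
    api_name="/predict",
)
print(status, output_path)
```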