# InstantSplat / app.py
# Long Hoang — commit d1eac9d: "attempt api fix"
import os, subprocess, shlex, sys, gc
import time
import torch
import numpy as np
import shutil
import argparse
import uuid
import spaces
import requests
from typing import Optional
import base64
import tempfile
from pydantic import BaseModel
# --- FIX: make gradio compatible by downgrading huggingface_hub -----------
# Gradio 5.0.1 requires huggingface_hub<1.0.0 due to its HfFolder import.
# `sys.executable -m pip` installs into the interpreter that is running this
# app, not whichever `pip` happens to be first on PATH. check=False keeps this
# best-effort: a failed install does not abort startup here.
# NOTE(review): `spaces` is imported above this point; if it already imported
# huggingface_hub, the reinstalled version only takes effect on restart — confirm.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "huggingface_hub<1.0.0"],
    check=False,
)
# --------------------------------------------------------------------------
import gradio as gr  # import AFTER the pip install above
import trimesh
from plyfile import PlyData

# Install the prebuilt custom wheels (cp310 / linux x86_64) used for Gaussian
# splatting. The paths are relative to the current working directory, and
# return codes are not checked (same best-effort policy as above).
for _wheel_path in (
    "wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
    "wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl",
):
    subprocess.run([sys.executable, "-m", "pip", "install", _wheel_path])

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make the vendored DUSt3R checkout importable as the top-level `dust3r` package.
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
from train_joint import training
from render_by_interp import render_sets

# Root folder for per-request working directories (inputs, sparse model, outputs).
GRADIO_CACHE_FOLDER = './gradio_cache_folder'
#############################################################################################################################################
def upload_to_supabase_storage(
    file_path: str,
    remote_path: str,
    supabase_url: str,
    supabase_key: str,
    bucket_name: str = "outputs",
    content_type: Optional[str] = None,
    max_retries: int = 3
) -> Optional[str]:
    """
    Upload a file to Supabase Storage using HTTP requests.

    The object is written with ``x-upsert: true``, so an existing object at the
    same remote path is overwritten.

    Args:
        file_path: Local path to the file to upload
        remote_path: Path in the bucket (e.g., "folder/file.ply")
        supabase_url: Supabase project URL (e.g., "https://xxxxx.supabase.co");
            surrounding whitespace and trailing slashes are tolerated
        supabase_key: Supabase service role key or anon key with appropriate permissions
        bucket_name: Name of the storage bucket
        content_type: MIME type of the file (guessed from the extension if None)
        max_retries: Number of upload attempts; values below 1 are treated as 1

    Returns:
        Public URL of the uploaded file (only reachable if the bucket is public),
        or None if the local file is missing or the upload failed.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    # Upload through the direct storage hostname <project>.storage.supabase.co.
    # Normalise the project URL first: a trailing "/" would otherwise end up
    # inside the hostname ("https://xxxx/.storage.supabase.co").
    base_url = supabase_url.strip().rstrip("/")
    project_id = base_url.replace("https://", "").replace(".supabase.co", "")
    storage_url = f"https://{project_id}.storage.supabase.co"
    upload_url = f"{storage_url}/storage/v1/object/{bucket_name}/{remote_path}"
    public_url = f"{storage_url}/storage/v1/object/public/{bucket_name}/{remote_path}"

    if content_type is None:
        # Extension lookup is case-insensitive; anything unknown (incl. .ply)
        # is sent as a generic binary blob.
        known_types = {".glb": "model/gltf-binary", ".mp4": "video/mp4"}
        extension = os.path.splitext(file_path)[1].lower()
        content_type = known_types.get(extension, "application/octet-stream")

    headers = {
        'Authorization': f'Bearer {supabase_key}',
        'apikey': supabase_key,
        'Content-Type': content_type,
        'x-upsert': 'true',  # Overwrite if file exists
    }

    file_size = os.path.getsize(file_path)
    print(f"Uploading {file_path} ({file_size / (1024**2):.2f} MB) to Supabase Storage...")

    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            # Re-open the file on every attempt so each request streams from the start.
            with open(file_path, 'rb') as f:
                response = requests.post(
                    upload_url,
                    headers=headers,
                    data=f,
                    timeout=600,  # 10 minute timeout for large files
                )
        except requests.exceptions.Timeout:
            print(f"Upload timeout (attempt {attempt}/{attempts})")
        except requests.exceptions.RequestException as e:
            print(f"Upload error (attempt {attempt}/{attempts}): {e}")
        else:
            if response.status_code in (200, 201):
                print(f"Successfully uploaded to: {public_url}")
                return public_url
            print(f"Upload failed (attempt {attempt}/{attempts}): {response.status_code} - {response.text}")
            # Only timeouts, rate limiting and server errors can succeed on retry;
            # other 4xx answers (bad key, missing bucket, ...) will not change.
            retryable = response.status_code in (408, 429) or response.status_code >= 500
            if not retryable:
                print("Upload failed with a non-retryable status; giving up")
                return None
        if attempt < attempts:
            time.sleep(2 ** (attempt - 1))  # Exponential backoff: 1s, 2s, 4s, ...

    print("Upload failed after all retries")
    return None
def get_dust3r_args_parser():
    """
    Build the argument parser holding the DUSt3R coarse-initialization settings.

    Returns:
        argparse.ArgumentParser with image size, checkpoint path, device and
        global-alignment options (all with defaults).
    """
    def _str2bool(value):
        # argparse's `type=bool` turns every non-empty string (including
        # "False" and "0") into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    parser.add_argument("--focal_avg", type=_str2bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
    return parser
def convert_ply_to_glb(ply_path, glb_path):
    """
    Export a Gaussian-splat PLY file as a coloured GLB point cloud.

    Positions are the x/y/z properties of the first PLY element. Colours come
    from the degree-0 SH coefficients f_dc_0..f_dc_2 (R, G, B) mapped with
    ``rgb = sh * C0 + 0.5``, clipped to [0, 1] and scaled to 8-bit.

    Returns:
        True when the GLB was written, False (after printing the error) if any
        exception was raised while reading, converting or exporting.
    """
    sh_c0 = 0.28209479177387814  # zeroth-order SH basis constant, 1 / (2 * sqrt(pi))
    try:
        vertex_data = PlyData.read(ply_path).elements[0]
        positions = np.column_stack([np.asarray(vertex_data[axis]) for axis in ("x", "y", "z")])
        sh_dc = np.column_stack([np.asarray(vertex_data[f"f_dc_{channel}"]) for channel in range(3)])
        rgb_unit = np.clip(sh_dc * sh_c0 + 0.5, 0, 1)
        rgb_bytes = (rgb_unit * 255).astype(np.uint8)
        trimesh.points.PointCloud(vertices=positions, colors=rgb_bytes).export(glb_path)
        return True
    except Exception as e:
        print(f"Error converting PLY to GLB: {e}")
        return False
def _collect_example_images(input_path):
    """Return the sorted image paths inside the bundled example folder ./assets/example/<input_path>."""
    imgs_path = './assets/example/' + input_path
    example_files = []
    for imgs_name in sorted(os.listdir(imgs_path)):
        file_path = os.path.join(imgs_path, imgs_name)
        print(file_path)
        example_files.append(file_path)
    return example_files


def _upload_outputs_to_supabase(output_ply_path, output_glb_path, tmp_user_folder):
    """
    Upload the PLY (and the GLB, when one exists) to Supabase Storage.

    Credentials are read from SUPABASE_URL / SUPABASE_KEY, the bucket from
    SUPABASE_BUCKET (default "outputs").

    Returns:
        (ply_url, glb_url). On failure a human-readable message takes the place
        of the URL; glb_url stays None when no GLB upload was attempted.
    """
    ply_url = None
    glb_url = None
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_KEY")
    supabase_bucket = os.environ.get("SUPABASE_BUCKET", "outputs")
    if not (supabase_url and supabase_key):
        print("SUPABASE_URL or SUPABASE_KEY not found, skipping upload.")
        return "Supabase credentials not set", None
    try:
        ply_url = upload_to_supabase_storage(
            file_path=output_ply_path,
            remote_path=f"{tmp_user_folder}_point_cloud.ply",
            supabase_url=supabase_url,
            supabase_key=supabase_key,
            bucket_name=supabase_bucket,
            content_type='application/octet-stream',
        )
        if ply_url is None:
            ply_url = "Error uploading PLY file"
        if output_glb_path and os.path.exists(output_glb_path):
            glb_url = upload_to_supabase_storage(
                file_path=output_glb_path,
                remote_path=f"{tmp_user_folder}_point_cloud.glb",
                supabase_url=supabase_url,
                supabase_key=supabase_key,
                bucket_name=supabase_bucket,
                content_type='model/gltf-binary',
            )
            if glb_url is None:
                glb_url = "Error uploading GLB file"
    except Exception as e:
        print("Failed to upload files to Supabase Storage:", e)
        ply_url = f"Error uploading: {e}"
    return ply_url, glb_url


@spaces.GPU(duration=150)
def process(inputfiles, input_path=None):
    """
    Process images and generate a 3D Gaussian Splatting model.

    Stages: (1) DUSt3R coarse geometry/camera initialization written to
    <work>/sparse/0, (2) joint Gaussian optimization via `training`,
    (3) video rendering via `render_sets`, (4) PLY->GLB conversion and upload.

    Args:
        inputfiles: uploaded image paths; they are MOVED into a per-request
            folder under GRADIO_CACHE_FOLDER. Ignored when input_path is given.
        input_path: name of a bundled example folder under ./assets/example/;
            its images are COPIED so the example stays intact.

    Returns:
        (video_path, ply_url, ply_download, ply_model, glb_model, glb_url).
        Any exception (including the gr.Error validation failures raised below)
        is caught and returned as
        (None, "ERROR: Error: <msg>", None, None, None, "ERROR: Error: <msg>").
    """
    try:
        if input_path is not None:
            inputfiles = _collect_example_images(input_path)
            print(inputfiles)

        # Validate before any expensive work. `inputfiles` is None when RUN is
        # clicked without uploading anything; a single image is not enough.
        if len(inputfiles or []) < 2:
            raise gr.Error("The number of input images should be greater than 1.")

        # ------ (1) Coarse Geometric Initialization ------
        parser = get_dust3r_args_parser()
        # parse_known_args(): unrelated flags on the server's command line must
        # not make argparse call sys.exit() (SystemExit escapes `except Exception`).
        opt, _ = parser.parse_known_args()
        tmp_user_folder = uuid.uuid4().hex
        opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
        img_folder_path = os.path.join(opt.img_base_path, "images")
        opt.n_views = len(inputfiles)
        os.makedirs(img_folder_path, exist_ok=True)

        print("Multiple images: ", inputfiles)
        # Copy bundled examples (keep the originals); move uploaded temp files.
        transfer = shutil.copy if input_path is not None else shutil.move
        for image_path in inputfiles:
            transfer(image_path, img_folder_path)

        train_img_list = sorted(os.listdir(img_folder_path))
        # Explicit check rather than `assert` (stripped under `python -O`); a
        # mismatch means e.g. two inputs with the same file name collided.
        if len(train_img_list) != opt.n_views:
            raise gr.Error(f"Number of images in the folder is not equal to {opt.n_views}")

        images, ori_size, imgs_resolution = load_images(img_folder_path, size=opt.image_size)
        if len(set(imgs_resolution)) != 1:
            raise gr.Error("The resolution of the input image should be the same.")
        print("ori_size", ori_size)

        model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
        start_time = time.time()
        pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, model, opt.device, batch_size=opt.batch_size)

        output_colmap_path = os.path.join(opt.img_base_path, "sparse", "0")
        os.makedirs(output_colmap_path, exist_ok=True)

        scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
        scene = scene.clean_pointcloud()

        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())
        end_time = time.time()
        print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")

        save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
        save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)

        # Initial point cloud: only the points selected by each view's confidence mask.
        pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
        color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
        color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
        storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
        pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
        np.save(os.path.join(output_colmap_path, "pts_4_3dgs_all.npy"), pts_4_3dgs_all)
        np.save(os.path.join(output_colmap_path, "focal.npy"), np.array(focals.cpu()))

        ### save VRAM: drop every reference to the DUSt3R network and its
        # alignment results so their GPU memory can be released before training.
        del scene, model, output, pairs, focals
        gc.collect()
        torch.cuda.empty_cache()

        ##################################################################################################################################################
        # ------ (2) Fast 3D-Gaussian Optimization ------
        parser = ArgumentParser(description="Training script parameters")
        lp = ModelParams(parser)
        op = OptimizationParams(parser)
        pp = PipelineParams(parser)
        parser.add_argument('--debug_from', type=int, default=-1)
        parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--start_checkpoint", type=str, default=None)
        # FIX: scene must be string, not int
        parser.add_argument("--scene", type=str, default="demo")
        parser.add_argument("--n_views", type=int, default=3)
        parser.add_argument("--get_video", action="store_true")
        # Only the default is used below (argv is never parsed), so `type=bool` is harmless here.
        parser.add_argument("--optim_pose", type=bool, default=True)
        parser.add_argument("--skip_train", action="store_true")
        parser.add_argument("--skip_test", action="store_true")
        # FIX: do NOT parse system argv
        args, _ = parser.parse_known_args([])
        args.save_iterations.append(args.iterations)
        args.model_path = opt.img_base_path + '/output/'
        args.source_path = opt.img_base_path
        # NOTE(review): the render step and the PLY path below use iteration 1000;
        # this assumes `args.iterations` (appended to save_iterations) is 1000 — confirm.
        args.iteration = 1000
        os.makedirs(args.model_path, exist_ok=True)
        training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

        ##################################################################################################################################################
        # ------ (3) Render video by interpolation ------
        parser = ArgumentParser(description="Testing script parameters")
        model_params = ModelParams(parser, sentinel=True)
        pipeline_params = PipelineParams(parser)
        args.eval = True
        args.get_video = True
        args.n_views = opt.n_views
        render_sets(
            model_params.extract(args),
            args.iteration,
            pipeline_params.extract(args),
            args.skip_train,
            args.skip_test,
            args,
        )

        output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
        output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'

        # sanity checks
        if not os.path.exists(output_ply_path):
            print("PLY not found at:", output_ply_path)
            raise gr.Error(f"PLY file not found at {output_ply_path}")
        if not os.path.exists(output_video_path):
            print("Video not found at:", output_video_path)
            raise gr.Error(f"Video file not found at {output_video_path}")

        # Convert PLY to GLB for visualization (best effort: None on failure).
        output_glb_path = output_ply_path.replace('.ply', '.glb')
        if not convert_ply_to_glb(output_ply_path, output_glb_path):
            output_glb_path = None

        # ------ (4) upload .ply and .glb to Supabase Storage ------
        ply_url, glb_url = _upload_outputs_to_supabase(output_ply_path, output_glb_path, tmp_user_folder)

        # return:
        # 1) video path (for gr.Video)
        # 2) ply URL (for API + textbox)
        # 3) ply file path (for gr.File download)
        # 4) ply file path (for gr.Model3D viewer)
        # 5) glb file path (for gr.Model3D viewer)
        # 6) glb URL (for API)
        return output_video_path, ply_url, output_ply_path, output_ply_path, output_glb_path, glb_url
    except Exception as e:
        import traceback

        # Catch all errors and return them in the API response
        error_msg = f"Error: {str(e)}"
        error_traceback = f"Traceback:\n{traceback.format_exc()}"
        full_error = f"{error_msg}\n\n{error_traceback}"
        print("=" * 80)
        print("ERROR IN PROCESS FUNCTION:")
        print("=" * 80)
        print(full_error)
        print("=" * 80)
        # Return error messages in the same format as successful returns so
        # API clients can see the exact error.
        error_prefix = "ERROR: "
        return (
            None,  # video path
            f"{error_prefix}{error_msg}",  # ply_url
            None,  # ply_download
            None,  # ply_model
            None,  # glb_model
            f"{error_prefix}{error_msg}",  # glb_url
        )
##################################################################################################################################################
def process_api(inputfiles):
    """
    API-friendly wrapper around `process` that returns a JSON-serialisable summary.

    Args:
        inputfiles: List of image file paths (moved into a per-request folder by `process`).

    Returns:
        dict with keys:
            glb_url: Supabase URL of the GLB, or an error/placeholder string.
            ply_url: Supabase URL of the PLY, or an error/placeholder string.
            video_available: True when a rendered video path was produced.
            status: "success" only when the GLB was actually uploaded (glb_url
                is an http(s) URL) and no value carries the "ERROR:" prefix;
                otherwise "error".
    """
    video_path, ply_url, _, _, _, glb_url = process(inputfiles, input_path=None)

    # `process` reports pipeline failures with an "ERROR:" prefix ...
    is_error = any(
        isinstance(value, str) and value.startswith("ERROR:")
        for value in (glb_url, ply_url)
    )
    # ... but upload failures come back as plain messages such as
    # "Error uploading GLB file", so success requires a real URL.
    glb_uploaded = isinstance(glb_url, str) and glb_url.startswith(("http://", "https://"))

    return {
        "glb_url": glb_url if glb_url else "Upload failed",
        "ply_url": ply_url if ply_url else "Upload failed",
        "video_available": video_path is not None,
        "status": "success" if glb_uploaded and not is_error else "error",
    }
def process_base64_api(images_b64):
    """
    API entrypoint that accepts a list of base64-encoded images, decodes them
    to temporary files, and runs the full pipeline via `process_api`.

    Entries that are not strings, are not valid base64, or decode to zero
    bytes are skipped. An optional data-URL prefix ("data:<mime>;base64,") is
    stripped. The temporary upload folder is removed when the call finishes.

    Returns:
        The `process_api` result dict, or an error dict of the same shape when
        fewer than 2 images could be decoded.
    """
    # Create a temporary directory under the Gradio cache folder
    tmp_root = os.path.join(GRADIO_CACHE_FOLDER, "api_uploads")
    os.makedirs(tmp_root, exist_ok=True)
    tmp_dir = tempfile.mkdtemp(prefix="api_", dir=tmp_root)
    try:
        decoded_paths = []
        for idx, img_str in enumerate(images_b64 or []):
            if not isinstance(img_str, str):
                continue
            b64_data = img_str
            if img_str.startswith("data:"):
                # Handle optional data URL prefix
                _, sep, payload = img_str.partition(",")
                if sep:
                    b64_data = payload
            try:
                img_bytes = base64.b64decode(b64_data)
            except ValueError:
                # binascii.Error (bad padding) and non-ASCII input are ValueErrors.
                continue
            if not img_bytes:
                continue  # empty payload cannot be an image
            # NOTE(review): every image is saved with a .jpg suffix regardless of
            # its real format; this relies on the loader decoding by content — confirm.
            # Four-digit indices keep lexicographic order == submission order.
            out_path = os.path.join(tmp_dir, f"img_{idx:04d}.jpg")
            with open(out_path, "wb") as f:
                f.write(img_bytes)
            decoded_paths.append(out_path)

        if len(decoded_paths) < 2:
            return {
                "glb_url": "ERROR: Need at least 2 valid base64 images",
                "ply_url": "ERROR: Need at least 2 valid base64 images",
                "video_available": False,
                "status": "error",
            }
        # Reuse process_api to run the pipeline and format the response
        return process_api(decoded_paths)
    finally:
        # process() moves the decoded files into its own working folder, so
        # tmp_dir only holds leftovers (early return / failure); drop it either way.
        shutil.rmtree(tmp_dir, ignore_errors=True)
# Page title string. NOTE(review): _TITLE is not referenced anywhere else in this file.
_TITLE = '''InstantSplat'''
# Markdown/HTML header shown at the top of the page (rendered via gr.Markdown in the UI).
_DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center;">
<div style="width: 100%; text-align: center; font-size: 30px;">
<strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
</div>
</div>
<p></p>
<div align="center">
<a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>&nbsp;
<a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>&nbsp;
<a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
<a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</div>
<p></p>
* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
'''
# Gradio UI: inputs + API docs, outputs, the RUN handler and the examples.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(_DESCRIPTION)
    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            # Hidden: receives the example folder name when an example is clicked.
            input_path = gr.Textbox(visible=False, label="example_path")
            button_gen = gr.Button("RUN")
        with gr.Tab("API"):
            gr.Markdown("""
## 🚀 API Access
Submit images programmatically and get back the Supabase GLB URL.
### Quick Start (Python)
```bash
pip install gradio_client
```
```python
from gradio_client import Client, handle_file
# Connect to this Space
client = Client("your-username/InstantSplat")
# Submit images (file inputs must be wrapped with handle_file)
result = client.predict(
    [handle_file("image1.jpg"), handle_file("image2.jpg"), handle_file("image3.jpg")],
    api_name="/process"
)
# Get GLB URL (it's the 6th element)
glb_url = result[5]
print(f"GLB URL: {glb_url}")
```
### Response Format
The API returns a tuple with 6 elements:
- `[0]` - Video path
- `[1]` - PLY URL (Supabase)
- `[2]` - PLY download path
- `[3]` - PLY model path
- `[4]` - GLB model path
- `[5]` - **GLB URL (Supabase)** ← This is what you want!
### CLI Tool
Use the included `api_client.py`:
```bash
python api_client.py img1.jpg img2.jpg img3.jpg
```
### Full Documentation
See `API_GUIDE.md` for complete documentation including:
- JavaScript/TypeScript examples
- Error handling
- Batch processing
- Complete workflows
### Requirements
- **Minimum**: 2 images (3+ recommended)
- **Same resolution**: All images must have matching dimensions
- **Formats**: JPG, PNG
""")
    with gr.Row(variant='panel'):
        with gr.Tab("Output"):
            with gr.Column(scale=2):
                with gr.Group():
                    output_model_glb = gr.Model3D(
                        label="3D Model (GLB Point Cloud)",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],
                    )
                    output_model_ply = gr.Model3D(
                        label="Original PLY (Gaussian Splat)",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],
                    )
                    gr.Markdown(
                        """
<div class="model-description">
&nbsp;&nbsp;Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
</div>
"""
                    )
                output_file = gr.Textbox(
                    label="PLY download URL",
                    interactive=False,
                )
                output_download = gr.File(
                    label="Download PLY",
                    interactive=False,
                )
            with gr.Column(scale=1):
                output_video = gr.Video(label="video")
    # Hidden output for GLB URL (for API access)
    output_glb_url = gr.Textbox(visible=False, label="GLB URL")

    process_outputs = [output_video, output_file, output_download, output_model_ply, output_model_glb, output_glb_url]
    # Pin the endpoint name so the documented `api_name="/process"` stays valid
    # regardless of how the GPU decorator names the wrapped function.
    button_gen.click(
        process,
        inputs=[inputfiles],
        outputs=process_outputs,
        api_name="process",
    )
    gr.Examples(
        examples=[
            "sora-santorini-3-views",
        ],
        inputs=[input_path],
        outputs=process_outputs,
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=False,  # Disabled for faster startup
        # With caching disabled, clicking an example only fills the inputs
        # unless run_on_click is set; the pipeline must actually run here.
        run_on_click=True,
        label='Sparse-view Examples'
    )
class Base64Request(BaseModel):
    """JSON body for POST /base64: base64-encoded images (data-URL prefixes allowed)."""

    images: list[str]


# NOTE(review): Gradio may create or replace `block.app` inside launch(); if so,
# this route is registered on an app that is never served (or block.app does not
# exist yet). Confirm, or mount the Blocks onto a FastAPI app via gr.mount_gradio_app.
fastapi_app = block.app


@fastapi_app.post("/base64")
def api_base64_endpoint(req: Base64Request):
    """
    FastAPI endpoint that accepts base64-encoded images and returns the same
    JSON structure as process_api/process_base64_api.

    Declared with plain `def` (not `async def`) so FastAPI runs it in its
    threadpool: the pipeline blocks for the whole run and would otherwise
    stall the event loop that serves the entire Gradio app.
    """
    return process_base64_api(req.images)


block.launch(server_name="0.0.0.0", share=False, show_api=True)