import os, subprocess, shlex, sys, gc
import time
import torch
import numpy as np
import shutil
import argparse
import uuid
import spaces
import requests
from typing import Optional
import base64
import tempfile
from pydantic import BaseModel

# --- FIX: make gradio compatible by downgrading huggingface_hub -----------
# Gradio 5.0.1 requires huggingface_hub<1.0.0 due to the HfFolder import
subprocess.run(
    shlex.split("pip install 'huggingface_hub<1.0.0'"),
    check=False,
)
# --------------------------------------------------------------------------

import gradio as gr  # import AFTER the pip install above
import trimesh
from plyfile import PlyData

# install custom wheels for gaussian splatting
subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl"))
subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl"))

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import (
    compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
)

from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
from train_joint import training
from render_by_interp import render_sets

GRADIO_CACHE_FOLDER = './gradio_cache_folder'

#############################################################################################################################################
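# Overview: a few helpers (Supabase upload, DUSt3R argument parser, PLY -> GLB
# conversion) feed the main `process` pipeline below, which runs
# (1) coarse geometric initialization with DUSt3R,
# (2) fast 3D Gaussian optimization,
# (3) video rendering by interpolation, and
# (4) upload of the resulting .ply / .glb artifacts to Supabase Storage.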
def upload_to_supabase_storage(
    file_path: str,
    remote_path: str,
    supabase_url: str,
    supabase_key: str,
    bucket_name: str = "outputs",
    content_type: Optional[str] = None,
    max_retries: int = 3
) -> Optional[str]:
    """
    Upload a file to Supabase Storage using HTTP requests.

    Args:
        file_path: Local path to the file to upload
        remote_path: Path in the bucket (e.g., "folder/file.ply")
        supabase_url: Supabase project URL (e.g., "https://xxxxx.supabase.co")
        supabase_key: Supabase service role key or anon key with appropriate permissions
        bucket_name: Name of the storage bucket
        content_type: MIME type of the file (auto-detected if None)
        max_retries: Number of retry attempts for failed uploads

    Returns:
        Public URL of the uploaded file, or None if the upload failed.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    # Use direct storage hostname for better performance:
    # extract the project ID from the project URL.
    project_id = supabase_url.replace("https://", "").replace(".supabase.co", "")
    storage_url = f"https://{project_id}.storage.supabase.co"

    # Construct the upload endpoint
    upload_url = f"{storage_url}/storage/v1/object/{bucket_name}/{remote_path}"

    # Auto-detect content type if not provided
    if content_type is None:
        if file_path.endswith('.ply'):
            content_type = 'application/octet-stream'
        elif file_path.endswith('.glb'):
            content_type = 'model/gltf-binary'
        elif file_path.endswith('.mp4'):
            content_type = 'video/mp4'
        else:
            content_type = 'application/octet-stream'

    headers = {
        'Authorization': f'Bearer {supabase_key}',
        'apikey': supabase_key,
        'Content-Type': content_type,
        'x-upsert': 'true',  # Overwrite if the file already exists
    }

    file_size = os.path.getsize(file_path)
    print(f"Uploading {file_path} ({file_size / (1024**2):.2f} MB) to Supabase Storage...")

    for attempt in range(max_retries):
        try:
            with open(file_path, 'rb') as f:
                response = requests.post(
                    upload_url,
                    headers=headers,
                    data=f,
                    timeout=600  # 10 minute timeout for large files
                )

            if response.status_code in (200, 201):
                # Construct the public URL
                public_url = f"{storage_url}/storage/v1/object/public/{bucket_name}/{remote_path}"
                print(f"Successfully uploaded to: {public_url}")
                return public_url
            else:
                print(f"Upload failed (attempt {attempt + 1}/{max_retries}): {response.status_code} - {response.text}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                continue

        except requests.exceptions.Timeout:
            print(f"Upload timeout (attempt {attempt + 1}/{max_retries})")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            continue
        except requests.exceptions.RequestException as e:
            print(f"Upload error (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            continue

    print("Upload failed after all retries")
    return None


def get_dust3r_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str,
                        default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
                        help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    parser.add_argument("--focal_avg", type=bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
    return parser
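# Note on the color conversion below: 3DGS stores color as spherical-harmonic
# coefficients; the DC (0th-order) term maps to RGB via rgb = C0 * f_dc + 0.5,
# where C0 = 1 / (2 * sqrt(pi)) ~= 0.2821. For example, f_dc = 0 yields mid-gray
# (0.5), while f_dc ~= 1.77 saturates to 1.0 after clipping.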
def convert_ply_to_glb(ply_path, glb_path):
    try:
        plydata = PlyData.read(ply_path)
        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])), axis=1)

        # Extract DC features (colors)
        # f_dc_0, f_dc_1, f_dc_2 are SH coefficients for R, G, B
        f_dc_0 = np.asarray(plydata.elements[0]["f_dc_0"])
        f_dc_1 = np.asarray(plydata.elements[0]["f_dc_1"])
        f_dc_2 = np.asarray(plydata.elements[0]["f_dc_2"])

        # SH2RGB: sh * C0 + 0.5, with C0 = 0.28209479177387814
        C0 = 0.28209479177387814
        r = f_dc_0 * C0 + 0.5
        g = f_dc_1 * C0 + 0.5
        b = f_dc_2 * C0 + 0.5
        colors = np.stack((r, g, b), axis=1)

        # Clip to [0, 1] and convert to uint8 for trimesh
        colors = np.clip(colors, 0, 1)
        colors = (colors * 255).astype(np.uint8)

        # Create the point cloud and export it
        pcd = trimesh.points.PointCloud(vertices=xyz, colors=colors)
        pcd.export(glb_path)
        return True
    except Exception as e:
        print(f"Error converting PLY to GLB: {e}")
        return False


@spaces.GPU(duration=150)
def process(inputfiles, input_path=None):
    """
    Process images and generate a 3D Gaussian Splatting model.

    Returns:
        (video_path, ply_url, ply_download, ply_model, glb_model, glb_url)
    """
    try:
        if input_path is not None:
            imgs_path = './assets/example/' + input_path
            imgs_names = sorted(os.listdir(imgs_path))
            inputfiles = []
            for imgs_name in imgs_names:
                file_path = os.path.join(imgs_path, imgs_name)
                print(file_path)
                inputfiles.append(file_path)
            print(inputfiles)

        # ------ (1) Coarse Geometric Initialization ------
        parser = get_dust3r_args_parser()
        opt = parser.parse_args()

        tmp_user_folder = str(uuid.uuid4()).replace("-", "")
        opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
        img_folder_path = os.path.join(opt.img_base_path, "images")

        model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
        os.makedirs(img_folder_path, exist_ok=True)

        opt.n_views = len(inputfiles)
        if opt.n_views == 1:
            raise gr.Error("The number of input images should be greater than 1.")
        print("Multiple images: ", inputfiles)

        for image_path in inputfiles:
            if input_path is not None:
                shutil.copy(image_path, img_folder_path)
            else:
                shutil.move(image_path, img_folder_path)
        train_img_list = sorted(os.listdir(img_folder_path))
        assert len(train_img_list) == opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"

        images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
        resolutions_are_equal = len(set(imgs_resolution)) == 1
        if not resolutions_are_equal:
            raise gr.Error("All input images must have the same resolution.")
        print("ori_size", ori_size)

        start_time = time.time()

        pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, model, opt.device, batch_size=opt.batch_size)

        output_colmap_path = img_folder_path.replace("images", "sparse/0")
        os.makedirs(output_colmap_path, exist_ok=True)

        scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter,
                                        schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
        scene = scene.clean_pointcloud()

        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())

        end_time = time.time()
        print(f"Time taken for {opt.n_views} views: {end_time - start_time} seconds")

        save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
        save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
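        # Keep only points that pass DUSt3R's per-pixel confidence masks; the filtered
        # points and their colors become the initial 3DGS point cloud (points3D.ply),
        # while the unfiltered point set and the focals are saved as .npy for later stages.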
        pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
        color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
        color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
        storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)

        pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
        np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
        np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))

        ### save VRAM
        del scene
        torch.cuda.empty_cache()
        gc.collect()

        ##################################################################################################################################################
        # ------ (2) Fast 3D-Gaussian Optimization ------
        parser = ArgumentParser(description="Training script parameters")
        lp = ModelParams(parser)
        op = OptimizationParams(parser)
        pp = PipelineParams(parser)
        parser.add_argument('--debug_from', type=int, default=-1)
        parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--start_checkpoint", type=str, default=None)
        # FIX: scene must be a string, not an int
        parser.add_argument("--scene", type=str, default="demo")
        parser.add_argument("--n_views", type=int, default=3)
        parser.add_argument("--get_video", action="store_true")
        parser.add_argument("--optim_pose", type=bool, default=True)
        parser.add_argument("--skip_train", action="store_true")
        parser.add_argument("--skip_test", action="store_true")
        # FIX: do NOT parse system argv
        args, _ = parser.parse_known_args([])
        args.save_iterations.append(args.iterations)
        args.model_path = opt.img_base_path + '/output/'
        args.source_path = opt.img_base_path
        args.iteration = 1000
        os.makedirs(args.model_path, exist_ok=True)
        training(lp.extract(args), op.extract(args), pp.extract(args),
                 args.test_iterations, args.save_iterations, args.checkpoint_iterations,
                 args.start_checkpoint, args.debug_from, args)

        ##################################################################################################################################################
        # ------ (3) Render video by interpolation ------
        parser = ArgumentParser(description="Testing script parameters")
        model = ModelParams(parser, sentinel=True)
        pipeline = PipelineParams(parser)
        args.eval = True
        args.get_video = True
        args.n_views = opt.n_views
        render_sets(
            model.extract(args),
            args.iteration,
            pipeline.extract(args),
            args.skip_train,
            args.skip_test,
            args,
        )

        output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
        output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'

        # sanity checks
        if not os.path.exists(output_ply_path):
            print("PLY not found at:", output_ply_path)
            raise gr.Error(f"PLY file not found at {output_ply_path}")
        if not os.path.exists(output_video_path):
            print("Video not found at:", output_video_path)
            raise gr.Error(f"Video file not found at {output_video_path}")

        # Convert PLY to GLB for visualization
        output_glb_path = output_ply_path.replace('.ply', '.glb')
        if not convert_ply_to_glb(output_ply_path, output_glb_path):
            output_glb_path = None

        # ------ (4) upload .ply and .glb to Supabase Storage ------
        ply_url = None
        glb_url = None
        rel_remote_path_ply = f"{tmp_user_folder}_point_cloud.ply"
        rel_remote_path_glb = f"{tmp_user_folder}_point_cloud.glb"
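        # Supabase credentials are read from environment variables (e.g. Space secrets);
        # if they are missing, the upload is skipped and a placeholder message is returned.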
        supabase_url = os.environ.get("SUPABASE_URL")
        supabase_key = os.environ.get("SUPABASE_KEY")
        supabase_bucket = os.environ.get("SUPABASE_BUCKET", "outputs")

        if supabase_url and supabase_key:
            try:
                # Upload PLY
                ply_url = upload_to_supabase_storage(
                    file_path=output_ply_path,
                    remote_path=rel_remote_path_ply,
                    supabase_url=supabase_url,
                    supabase_key=supabase_key,
                    bucket_name=supabase_bucket,
                    content_type='application/octet-stream'
                )
                if ply_url is None:
                    ply_url = "Error uploading PLY file"

                # Upload GLB if it exists
                if output_glb_path and os.path.exists(output_glb_path):
                    glb_url = upload_to_supabase_storage(
                        file_path=output_glb_path,
                        remote_path=rel_remote_path_glb,
                        supabase_url=supabase_url,
                        supabase_key=supabase_key,
                        bucket_name=supabase_bucket,
                        content_type='model/gltf-binary'
                    )
                    if glb_url is None:
                        glb_url = "Error uploading GLB file"
            except Exception as e:
                print("Failed to upload files to Supabase Storage:", e)
                ply_url = f"Error uploading: {e}"
        else:
            print("SUPABASE_URL or SUPABASE_KEY not found, skipping upload.")
            ply_url = "Supabase credentials not set"

        # return:
        #   1) video path (for gr.Video)
        #   2) ply URL (for API + textbox)
        #   3) ply file path (for gr.File download)
        #   4) ply file path (for gr.Model3D viewer)
        #   5) glb file path (for gr.Model3D viewer)
        #   6) glb URL (for API)
        return output_video_path, ply_url, output_ply_path, output_ply_path, output_glb_path, glb_url

    except Exception as e:
        # Catch all errors and return them in the API response
        error_msg = f"Error: {str(e)}"
        error_traceback = f"Traceback:\n{__import__('traceback').format_exc()}"
        full_error = f"{error_msg}\n\n{error_traceback}"
        print("=" * 80)
        print("ERROR IN PROCESS FUNCTION:")
        print("=" * 80)
        print(full_error)
        print("=" * 80)

        # Return error messages in the same format as successful returns
        # so that API clients can see the exact error.
        error_prefix = "ERROR: "
        return (
            None,                           # video path
            f"{error_prefix}{error_msg}",   # ply_url
            None,                           # ply_download
            None,                           # ply_model
            None,                           # glb_model
            f"{error_prefix}{error_msg}"    # glb_url
        )


##################################################################################################################################################
def process_api(inputfiles):
    """
    API-friendly wrapper around `process`.

    Args:
        inputfiles: List of image files

    Returns:
        dict with glb_url, ply_url, video_available, and status
    """
    result = process(inputfiles, input_path=None)
    video_path, ply_url, _, _, glb_path, glb_url = result

    # Detect error responses by prefix
    is_error = False
    for v in (glb_url, ply_url):
        if isinstance(v, str) and v.startswith("ERROR:"):
            is_error = True
            break

    return {
        "glb_url": glb_url if glb_url else "Upload failed",
        "ply_url": ply_url if ply_url else "Upload failed",
        "video_available": video_path is not None,
        "status": "error" if is_error else ("success" if glb_url else "error"),
    }


def process_base64_api(images_b64):
    """
    API entrypoint that accepts a list of base64-encoded images,
    decodes them to temporary files, and runs the full pipeline.
    """
    # Create a temporary directory under the Gradio cache folder
    tmp_root = os.path.join(GRADIO_CACHE_FOLDER, "api_uploads")
    os.makedirs(tmp_root, exist_ok=True)
    tmp_dir = tempfile.mkdtemp(prefix="api_", dir=tmp_root)

    decoded_paths = []
    for idx, img_str in enumerate(images_b64):
        if not isinstance(img_str, str):
            continue

        # Handle an optional data URL prefix (e.g. "data:image/jpeg;base64,...")
        if img_str.startswith("data:"):
            try:
                header, b64_data = img_str.split(",", 1)
            except ValueError:
                b64_data = img_str
        else:
            b64_data = img_str

        try:
            img_bytes = base64.b64decode(b64_data)
        except Exception:
            # Skip invalid entries
            continue

        out_path = os.path.join(tmp_dir, f"img_{idx:02d}.jpg")
        with open(out_path, "wb") as f:
            f.write(img_bytes)
        decoded_paths.append(out_path)

    if len(decoded_paths) < 2:
        return {
            "glb_url": "ERROR: Need at least 2 valid base64 images",
            "ply_url": "ERROR: Need at least 2 valid base64 images",
            "video_available": False,
            "status": "error",
        }

    # Reuse process_api to run the pipeline and format the response
    return process_api(decoded_paths)


_TITLE = '''InstantSplat'''

_DESCRIPTION = '''