Spaces:
Paused
Paused
| import os, subprocess, shlex, sys, gc | |
| import time | |
| import torch | |
| import numpy as np | |
| import shutil | |
| import argparse | |
| import uuid | |
| import spaces | |
| import requests | |
| from typing import Optional | |
| import base64 | |
| import tempfile | |
| from pydantic import BaseModel | |
| # --- FIX: make gradio compatible by downgrading huggingface_hub ----------- | |
| # Gradio 5.0.1 requires huggingface_hub<1.0.0 due to HfFolder import | |
| subprocess.run( | |
| shlex.split("pip install 'huggingface_hub<1.0.0'"), | |
| check=False, | |
| ) | |
| # -------------------------------------------------------------------------- | |
| import gradio as gr # import AFTER the pip install above | |
| import trimesh | |
| from plyfile import PlyData | |
| # install custom wheels for gaussian splatting | |
| subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl")) | |
| subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl")) | |
| subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl")) | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r"))) | |
| from dust3r.inference import inference | |
| from dust3r.model import AsymmetricCroCo3DStereo | |
| from dust3r.utils.device import to_numpy | |
| from dust3r.image_pairs import make_pairs | |
| from dust3r.cloud_opt import global_aligner, GlobalAlignerMode | |
| from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images | |
| from argparse import ArgumentParser | |
| from arguments import ModelParams, PipelineParams, OptimizationParams | |
| from train_joint import training | |
| from render_by_interp import render_sets | |
| GRADIO_CACHE_FOLDER = './gradio_cache_folder' | |
| ############################################################################################################################################# | |
def _supabase_storage_origin(supabase_url: str) -> str:
    """Return the direct storage origin (``https://<project>.storage.supabase.co``).

    Accepts the project URL with or without a scheme, trailing slash or path,
    and also an URL that already points at the storage host.
    """
    host = supabase_url.strip()
    # Drop the scheme (http or https), then anything after the host name.
    if "://" in host:
        host = host.split("://", 1)[1]
    host = host.split("/", 1)[0]
    # Strip the Supabase domain so only the project id is left.
    for suffix in (".storage.supabase.co", ".supabase.co"):
        if host.endswith(suffix):
            host = host[: -len(suffix)]
            break
    return f"https://{host}.storage.supabase.co"


def upload_to_supabase_storage(
    file_path: str,
    remote_path: str,
    supabase_url: str,
    supabase_key: str,
    bucket_name: str = "outputs",
    content_type: Optional[str] = None,
    max_retries: int = 3
) -> Optional[str]:
    """
    Upload a file to Supabase Storage using HTTP requests.

    Args:
        file_path: Local path to the file to upload
        remote_path: Path in the bucket (e.g., "folder/file.ply"); it is
            percent-encoded before being placed in the request URL
        supabase_url: Supabase project URL (e.g., "https://xxxxx.supabase.co")
        supabase_key: Supabase service role key or anon key with appropriate permissions
        bucket_name: Name of the storage bucket
        content_type: MIME type of the file (guessed from the extension if None)
        max_retries: Number of upload attempts before giving up

    Returns:
        Public URL of the uploaded file, or None if the file does not exist
        or every attempt failed
    """
    # Local import keeps this block self-contained (urllib is stdlib).
    from urllib.parse import quote

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None

    # Use direct storage hostname for better performance
    storage_url = _supabase_storage_origin(supabase_url)

    # Encode the object key so spaces, '#', '?' etc. cannot corrupt the URL;
    # '/' is kept because it separates folders inside the bucket.
    object_path = f"{bucket_name}/{quote(remote_path, safe='/')}"
    upload_url = f"{storage_url}/storage/v1/object/{object_path}"

    # Auto-detect content type if not provided
    if content_type is None:
        known_types = {
            '.ply': 'application/octet-stream',
            '.glb': 'model/gltf-binary',
            '.mp4': 'video/mp4',
        }
        extension = os.path.splitext(file_path)[1].lower()
        content_type = known_types.get(extension, 'application/octet-stream')

    headers = {
        'Authorization': f'Bearer {supabase_key}',
        'apikey': supabase_key,
        'Content-Type': content_type,
        'x-upsert': 'true',  # Overwrite if file exists
    }

    file_size = os.path.getsize(file_path)
    print(f"Uploading {file_path} ({file_size / (1024**2):.2f} MB) to Supabase Storage...")

    for attempt in range(max_retries):
        try:
            # Re-open on every attempt so a retry streams from byte 0 again.
            with open(file_path, 'rb') as f:
                response = requests.post(
                    upload_url,
                    headers=headers,
                    data=f,
                    timeout=600  # 10 minute timeout for large files
                )
        except requests.exceptions.Timeout:
            print(f"Upload timeout (attempt {attempt + 1}/{max_retries})")
        except requests.exceptions.RequestException as e:
            print(f"Upload error (attempt {attempt + 1}/{max_retries}): {e}")
        else:
            if response.status_code in (200, 201):
                # Construct public URL
                public_url = f"{storage_url}/storage/v1/object/public/{object_path}"
                print(f"Successfully uploaded to: {public_url}")
                return public_url
            print(f"Upload failed (attempt {attempt + 1}/{max_retries}): {response.status_code} - {response.text}")

        # Exponential backoff (1s, 2s, 4s, ...) between attempts, none after the last.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)

    print("Upload failed after all retries")
    return None
def get_dust3r_args_parser():
    """Build the argument parser holding the DUSt3R initialisation settings.

    Returns:
        argparse.ArgumentParser with the options image_size, model_path, device,
        batch_size, schedule, lr, niter, focal_avg, n_views and base_path.
    """

    def _str2bool(value):
        # argparse hands the raw string to `type`; bool("False") is True, so
        # the text has to be interpreted explicitly.
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    parser.add_argument("--focal_avg", type=_str2bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
    return parser
def convert_ply_to_glb(ply_path, glb_path):
    """Export the points of a PLY file as a coloured point cloud at glb_path.

    Positions come from the x/y/z properties of the first PLY element and
    colours from its f_dc_0..2 properties (the DC spherical-harmonics
    coefficients for R, G, B).

    Returns:
        True when the export succeeded, False when anything raised (the
        error is printed, not propagated).
    """
    # SH2RGB: sh * C0 + 0.5
    sh_c0 = 0.28209479177387814

    try:
        vertex = PlyData.read(ply_path).elements[0]

        positions = np.stack(
            [np.asarray(vertex[axis]) for axis in ("x", "y", "z")], axis=1
        )
        dc_terms = np.stack(
            [np.asarray(vertex[f"f_dc_{channel}"]) for channel in range(3)], axis=1
        )

        # Map to RGB, clamp to [0, 1], then quantise to uint8 for trimesh.
        rgb = np.clip(dc_terms * sh_c0 + 0.5, 0, 1)
        rgb_u8 = (rgb * 255).astype(np.uint8)

        cloud = trimesh.points.PointCloud(vertices=positions, colors=rgb_u8)
        cloud.export(glb_path)
    except Exception as e:
        print(f"Error converting PLY to GLB: {e}")
        return False
    return True
def process(inputfiles, input_path=None):
    """
    Process images and generate 3D Gaussian Splatting model.

    Stages: (1) DUSt3R geometric initialisation saved in a COLMAP-style
    folder, (2) Gaussian training, (3) interpolated video rendering,
    (4) PLY -> GLB conversion and optional upload to Supabase Storage.

    Args:
        inputfiles: Paths of the uploaded images. Ignored (rebuilt from the
            example folder) when input_path is given.
        input_path: Name of a folder under ./assets/example/ to use as input.

    Returns:
        (video_path, ply_url, ply_download, ply_model, glb_model, glb_url)
        Any exception is caught and reported instead of raised: the tuple is
        then (None, "ERROR: Error: <msg>", None, None, None, "ERROR: Error: <msg>").
    """
    import traceback  # only used on the failure path

    try:
        if input_path is not None:
            imgs_path = './assets/example/' + input_path
            inputfiles = []
            for imgs_name in sorted(os.listdir(imgs_path)):
                file_path = os.path.join(imgs_path, imgs_name)
                print(file_path)
                inputfiles.append(file_path)
            print(inputfiles)

        # Validate before the model is loaded: also rejects None / an empty list.
        if not inputfiles or len(inputfiles) < 2:
            raise gr.Error("The number of input images should be greater than 1.")

        # ------ (1) Coarse Geometric Initialization ------
        parser = get_dust3r_args_parser()
        # parse_known_args keeps command-line overrides working but ignores
        # flags that belong to the hosting server; parse_args() would raise
        # SystemExit on them, which the except clause below does not catch.
        opt, _ = parser.parse_known_args()
        tmp_user_folder = str(uuid.uuid4()).replace("-", "")
        opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
        img_folder_path = os.path.join(opt.img_base_path, "images")

        model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
        os.makedirs(img_folder_path, exist_ok=True)

        opt.n_views = len(inputfiles)
        print("Multiple images: ", inputfiles)

        # Example assets are copied (they are reused); uploads are moved.
        for image_path in inputfiles:
            if input_path is not None:
                shutil.copy(image_path, img_folder_path)
            else:
                shutil.move(image_path, img_folder_path)

        train_img_list = sorted(os.listdir(img_folder_path))
        if len(train_img_list) != opt.n_views:
            raise gr.Error(f"Number of images in the folder is not equal to {opt.n_views}")

        images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
        if len(set(imgs_resolution)) != 1:
            raise gr.Error("The resolution of the input image should be the same.")
        print("ori_size", ori_size)

        start_time = time.time()
        pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, model, opt.device, batch_size=opt.batch_size)

        output_colmap_path = os.path.join(opt.img_base_path, "sparse", "0")
        os.makedirs(output_colmap_path, exist_ok=True)

        scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
        compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
        scene = scene.clean_pointcloud()

        imgs = to_numpy(scene.imgs)
        focals = scene.get_focals()
        poses = to_numpy(scene.get_im_poses())
        pts3d = to_numpy(scene.get_pts3d())
        scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
        confidence_masks = to_numpy(scene.get_masks())
        intrinsics = to_numpy(scene.get_intrinsics())
        end_time = time.time()
        print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")

        save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
        save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)

        # Keep only the points / colours selected by the confidence masks.
        pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
        color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
        color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
        storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)

        pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
        np.save(os.path.join(output_colmap_path, "pts_4_3dgs_all.npy"), pts_4_3dgs_all)
        np.save(os.path.join(output_colmap_path, "focal.npy"), np.array(focals.cpu()))

        ### save VRAM
        # Nothing below uses the DUSt3R objects; drop them all (not only the
        # scene) and collect before emptying the CUDA cache.
        del scene, output, pairs, model
        gc.collect()
        torch.cuda.empty_cache()

        # ------ (2) Fast 3D-Gaussian Optimization ------
        parser = ArgumentParser(description="Training script parameters")
        lp = ModelParams(parser)
        op = OptimizationParams(parser)
        pp = PipelineParams(parser)
        parser.add_argument('--debug_from', type=int, default=-1)
        parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
        parser.add_argument("--start_checkpoint", type=str, default=None)
        # FIX: scene must be string, not int
        parser.add_argument("--scene", type=str, default="demo")
        parser.add_argument("--n_views", type=int, default=3)
        parser.add_argument("--get_video", action="store_true")
        parser.add_argument("--optim_pose", type=bool, default=True)
        parser.add_argument("--skip_train", action="store_true")
        parser.add_argument("--skip_test", action="store_true")

        # FIX: do NOT parse system argv
        args, _ = parser.parse_known_args([])
        args.save_iterations.append(args.iterations)
        args.model_path = opt.img_base_path + '/output/'
        args.source_path = opt.img_base_path
        args.iteration = 1000
        os.makedirs(args.model_path, exist_ok=True)

        training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

        # ------ (3) Render video by interpolation ------
        parser = ArgumentParser(description="Testing script parameters")
        render_model = ModelParams(parser, sentinel=True)
        pipeline = PipelineParams(parser)
        args.eval = True
        args.get_video = True
        args.n_views = opt.n_views

        render_sets(
            render_model.extract(args),
            args.iteration,
            pipeline.extract(args),
            args.skip_train,
            args.skip_test,
            args,
        )

        output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
        output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'

        # sanity checks
        if not os.path.exists(output_ply_path):
            print("PLY not found at:", output_ply_path)
            raise gr.Error(f"PLY file not found at {output_ply_path}")
        if not os.path.exists(output_video_path):
            print("Video not found at:", output_video_path)
            raise gr.Error(f"Video file not found at {output_video_path}")

        # Convert PLY to GLB for visualization
        output_glb_path = output_ply_path.replace('.ply', '.glb')
        if not convert_ply_to_glb(output_ply_path, output_glb_path):
            output_glb_path = None

        # ------ (4) upload .ply and .glb to Supabase Storage ------
        ply_url = None
        glb_url = None
        rel_remote_path_ply = f"{tmp_user_folder}_point_cloud.ply"
        rel_remote_path_glb = f"{tmp_user_folder}_point_cloud.glb"
        supabase_url = os.environ.get("SUPABASE_URL")
        supabase_key = os.environ.get("SUPABASE_KEY")
        supabase_bucket = os.environ.get("SUPABASE_BUCKET", "outputs")

        if supabase_url and supabase_key:
            try:
                # Upload PLY
                ply_url = upload_to_supabase_storage(
                    file_path=output_ply_path,
                    remote_path=rel_remote_path_ply,
                    supabase_url=supabase_url,
                    supabase_key=supabase_key,
                    bucket_name=supabase_bucket,
                    content_type='application/octet-stream'
                )
                if ply_url is None:
                    ply_url = "Error uploading PLY file"

                # Upload GLB if it exists
                if output_glb_path and os.path.exists(output_glb_path):
                    glb_url = upload_to_supabase_storage(
                        file_path=output_glb_path,
                        remote_path=rel_remote_path_glb,
                        supabase_url=supabase_url,
                        supabase_key=supabase_key,
                        bucket_name=supabase_bucket,
                        content_type='model/gltf-binary'
                    )
                    if glb_url is None:
                        glb_url = "Error uploading GLB file"
            except Exception as e:
                print("Failed to upload files to Supabase Storage:", e)
                ply_url = f"Error uploading: {e}"
        else:
            print("SUPABASE_URL or SUPABASE_KEY not found, skipping upload.")
            ply_url = "Supabase credentials not set"

        # return:
        # 1) video path (for gr.Video)
        # 2) ply URL (for API + textbox)
        # 3) ply file path (for gr.File download)
        # 4) ply file path (for gr.Model3D viewer)
        # 5) glb file path (for gr.Model3D viewer)
        # 6) glb URL (for API)
        return output_video_path, ply_url, output_ply_path, output_ply_path, output_glb_path, glb_url

    except Exception as e:
        # Catch all errors and return them in the API response
        error_msg = f"Error: {str(e)}"
        error_traceback = f"Traceback:\n{traceback.format_exc()}"
        full_error = f"{error_msg}\n\n{error_traceback}"
        print("=" * 80)
        print("ERROR IN PROCESS FUNCTION:")
        print("=" * 80)
        print(full_error)
        print("=" * 80)

        # Return error messages in the same format as successful returns
        # This allows API clients to see the exact error
        error_prefix = "ERROR: "
        return (
            None,  # video path
            f"{error_prefix}{error_msg}",  # ply_url
            None,  # ply_download
            None,  # ply_model
            None,  # glb_model
            f"{error_prefix}{error_msg}"  # glb_url
        )
| ################################################################################################################################################## | |
def process_api(inputfiles):
    """
    API-friendly wrapper around process() that returns a JSON-style dict.

    Args:
        inputfiles: List of image files

    Returns:
        dict with:
            glb_url: URL of the uploaded GLB, or a failure message
            ply_url: URL of the uploaded PLY, or a failure message
            video_available: True when process() produced a video path
            status: "success" only when glb_url is an actual http(s) URL and
                process() did not report an error, otherwise "error"
    """
    video_path, ply_url, _, _, _glb_path, glb_url = process(inputfiles, input_path=None)

    # process() reports pipeline failures with an "ERROR:" prefix ...
    is_error = any(
        isinstance(value, str) and value.startswith("ERROR:")
        for value in (glb_url, ply_url)
    )
    # ... but upload failures come back as plain messages such as
    # "Error uploading GLB file", so a non-empty string is not proof of
    # success: require something that actually looks like a URL.
    has_glb_url = isinstance(glb_url, str) and glb_url.startswith(("http://", "https://"))

    return {
        "glb_url": glb_url if glb_url else "Upload failed",
        "ply_url": ply_url if ply_url else "Upload failed",
        "video_available": video_path is not None,
        "status": "success" if (has_glb_url and not is_error) else "error",
    }
def process_base64_api(images_b64):
    """
    API entrypoint that accepts a list of base64-encoded images,
    decodes them to temporary files, and runs the full pipeline.

    Args:
        images_b64: Iterable of base64 strings, each optionally prefixed with
            a data URL header ("data:<mime>;base64,"). Entries that are not
            strings, do not decode, or decode to nothing are skipped.

    Returns:
        The dict produced by process_api(), or an error dict of the same
        shape when fewer than two images could be decoded.
    """
    # Create a temporary directory under the Gradio cache folder
    tmp_root = os.path.join(GRADIO_CACHE_FOLDER, "api_uploads")
    os.makedirs(tmp_root, exist_ok=True)
    tmp_dir = tempfile.mkdtemp(prefix="api_", dir=tmp_root)

    try:
        decoded_paths = []
        for idx, img_str in enumerate(images_b64 or []):
            if not isinstance(img_str, str):
                continue

            # Handle optional data URL prefix
            b64_data = img_str
            if img_str.startswith("data:"):
                _header, separator, payload = img_str.partition(",")
                if separator:
                    b64_data = payload

            # Tolerate line-wrapped base64, then decode strictly: without
            # validate=True invalid characters are silently dropped and a
            # corrupt file would be written.
            b64_data = "".join(b64_data.split())
            try:
                img_bytes = base64.b64decode(b64_data, validate=True)
            except ValueError:  # binascii.Error is a ValueError subclass
                # Skip invalid entries
                continue
            if not img_bytes:
                continue

            # NOTE(review): the file is always named .jpg regardless of the
            # real image format — confirm the loader detects format by content.
            out_path = os.path.join(tmp_dir, f"img_{idx:02d}.jpg")
            with open(out_path, "wb") as f:
                f.write(img_bytes)
            decoded_paths.append(out_path)

        if len(decoded_paths) < 2:
            return {
                "glb_url": "ERROR: Need at least 2 valid base64 images",
                "ply_url": "ERROR: Need at least 2 valid base64 images",
                "video_available": False,
                "status": "error",
            }

        # Reuse process_api to run the pipeline and format the response
        return process_api(decoded_paths)
    finally:
        # process() moves the uploads into its own working folder, so this
        # directory is only scratch space and must not pile up per request.
        shutil.rmtree(tmp_dir, ignore_errors=True)
| _TITLE = '''InstantSplat''' | |
| _DESCRIPTION = ''' | |
| <div style="display: flex; justify-content: center; align-items: center;"> | |
| <div style="width: 100%; text-align: center; font-size: 30px;"> | |
| <strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong> | |
| </div> | |
| </div> | |
| <p></p> | |
| <div align="center"> | |
| <a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a> | |
| <a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a> | |
| <a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a> | |
| <a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | |
| <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> | |
| </a> | |
| </div> | |
| <p></p> | |
| * Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/). | |
| * Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data. | |
| * Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090). | |
| ''' | |
# ---------------------------------------------------------------------------
# Gradio UI: description header, an input panel (upload tab + API help tab),
# an output panel (3D viewers, download widgets, video) and the event wiring.
# ---------------------------------------------------------------------------
block = gr.Blocks().queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            # Hidden field that receives the example folder name from gr.Examples.
            input_path = gr.Textbox(visible=False, label="example_path")
            button_gen = gr.Button("RUN")

        with gr.Tab("API"):
            gr.Markdown("""
## 🚀 API Access
Submit images programmatically and get back the Supabase GLB URL.
### Quick Start (Python)
```bash
pip install gradio_client
```
```python
from gradio_client import Client
# Connect to this Space
client = Client("your-username/InstantSplat")
# Submit images
result = client.predict(
["image1.jpg", "image2.jpg", "image3.jpg"],
api_name="/predict"
)
# Get GLB URL (it's the 6th element)
glb_url = result[5]
print(f"GLB URL: {glb_url}")
```
### Response Format
The API returns a tuple with 6 elements:
- `[0]` - Video path
- `[1]` - PLY URL (Supabase)
- `[2]` - PLY download path
- `[3]` - PLY model path
- `[4]` - GLB model path
- `[5]` - **GLB URL (Supabase)** ← This is what you want!
### CLI Tool
Use the included `api_client.py`:
```bash
python api_client.py img1.jpg img2.jpg img3.jpg
```
### Full Documentation
See `API_GUIDE.md` for complete documentation including:
- JavaScript/TypeScript examples
- Error handling
- Batch processing
- Complete workflows
### Requirements
- **Minimum**: 2 images (3+ recommended)
- **Same resolution**: All images must have matching dimensions
- **Formats**: JPG, PNG
""")

    with gr.Row(variant='panel'):
        with gr.Tab("Output"):
            with gr.Column(scale=2):
                with gr.Group():
                    output_model_glb = gr.Model3D(
                        label="3D Model (GLB Point Cloud)",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],
                    )
                    output_model_ply = gr.Model3D(
                        label="Original PLY (Gaussian Splat)",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],
                    )
                    gr.Markdown(
                        """
<div class="model-description">
Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
</div>
"""
                    )
                    output_file = gr.Textbox(
                        label="PLY download URL",
                        interactive=False,
                    )
                    output_download = gr.File(
                        label="Download PLY",
                        interactive=False,
                    )
            with gr.Column(scale=1):
                output_video = gr.Video(label="video")

    # Hidden output for GLB URL (for API access)
    output_glb_url = gr.Textbox(visible=False, label="GLB URL")

    # Output order must match the 6-tuple returned by process().
    button_gen.click(
        process,
        inputs=[inputfiles],
        outputs=[output_video, output_file, output_download, output_model_ply, output_model_glb, output_glb_url]
    )

    gr.Examples(
        examples=[
            "sora-santorini-3-views",
        ],
        inputs=[input_path],
        outputs=[output_video, output_file, output_download, output_model_ply, output_model_glb, output_glb_url],
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=False,  # Disabled for faster startup
        # With caching off, fn is only executed on click when run_on_click is
        # set; otherwise the click merely fills the hidden input_path textbox
        # and the RUN button (which only reads inputfiles) cannot use it.
        run_on_click=True,
        label='Sparse-view Examples'
    )
# Request body for the base64 endpoint: {"images": ["<base64 string>", ...]}.
class Base64Request(BaseModel):
    images: list[str]


# Application object exposed by the Blocks instance.
fastapi_app = block.app


# NOTE(review): no route decorator or add_api_route call is visible in this
# file, so this coroutine does not appear to be registered on fastapi_app —
# confirm how (or whether) it is exposed.
async def api_base64_endpoint(req: Base64Request):
    """
    Handle a request carrying base64-encoded images.

    Passes req.images to process_base64_api and returns its result unchanged
    (the same dict structure process_api produces).
    """
    return process_base64_api(req.images)


block.launch(server_name="0.0.0.0", share=False, show_api=True)