# Pixcribe / app.py
# DawnC's picture
# Upload 5 files
# f3a4ad9 verified
import gradio as gr
import torch
from PIL import Image
import spaces
import os
import json
import tempfile
from typing import List, Optional
from pixcribe_pipeline import PixcribePipeline
from ui_manager import UIManager
# Build the processing pipeline and the UI helper once, at import time.
# Constructing the pipeline is bracketed by progress messages because it is
# the slow step of start-up.
print("Initializing Pixcribe V5 with Batch Processing...")
print("⏳ Loading models (this may take a while)...")
pipeline = PixcribePipeline(yolo_variant='l')
ui_manager = UIManager()
print("✅ All models loaded successfully!")

# Most recent batch results and the images they were computed from, kept so
# the export handlers can write them out later.
# NOTE: these are module-level, so they are shared by every caller of this
# process rather than being per-session.
latest_batch_results = None
latest_batch_images = None
# Upper bound on the number of images accepted in one request.
_MAX_IMAGES_PER_REQUEST = 10


def _notice_html(message):
    """Wrap a short, fixed validation message in the red centred notice box."""
    return (
        "<div style='color: #E74C3C; padding: 24px; text-align: center;'>"
        f"{message}</div>"
    )


def _error_panel_html(title, exc, trace):
    """Render an exception and its traceback as a collapsible HTML panel.

    Both the exception text and the traceback are HTML-escaped: tracebacks
    routinely contain text such as ``<module>`` that would otherwise be parsed
    as markup, and exception messages may echo arbitrary input.
    """
    import html

    summary = html.escape(str(exc))
    details = html.escape(trace)
    return f"""
    <div style='background: #FADBD8; border: 2px solid #E74C3C; border-radius: 20px; padding: 28px; margin: 16px 0;'>
    <h3 style='color: #C0392B; margin-top: 0; font-size: 22px;'>❌ {title}</h3>
    <p style='color: #E74C3C; font-weight: bold; font-size: 17px; margin-bottom: 16px;'>{summary}</p>
    <details style='margin-top: 12px;'>
    <summary style='cursor: pointer; color: #C0392B; font-weight: bold; font-size: 16px;'>View Full Error Trace</summary>
    <pre style='background: white; padding: 16px; border-radius: 12px; overflow-x: auto; font-size: 13px; color: #2C3E50; margin-top: 12px;'>{details}</pre>
    </details>
    </div>
    """


def _load_rgb_image(file):
    """Open an uploaded file (object with ``.name`` or a plain path) as RGB.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been read; ``convert`` returns an independent in-memory
    image, so the result stays usable after the file is closed.
    """
    path = file.name if hasattr(file, 'name') else file
    with Image.open(path) as opened:
        return opened.convert('RGB')


@spaces.GPU(duration=180)
def process_images_wrapper(files, yolo_variant, caption_language, progress=gr.Progress()):
    """
    Process one or several uploaded images and build the UI outputs.

    A single readable image is sent through ``pipeline.process_image``; two or
    more go through ``pipeline.process_batch`` with progress reporting. Files
    that cannot be opened as images are skipped with a warning on stdout.

    Args:
        files: List of uploaded file objects or paths (or a single one)
        yolo_variant: YOLO model variant ('m', 'l', 'x')
        caption_language: Caption language ('zh', 'en')
        progress: Gradio Progress object for progress updates

    Returns:
        Tuple of (visualized_image, caption_html, batch_results_html,
        export_panel_update). In batch mode the first two are ``None`` and
        ``""``; in single mode the batch HTML is ``""`` and the export panel
        is hidden. On any failure the second element carries an error message
        and the export panel is hidden.
    """
    global latest_batch_results, latest_batch_images
    import traceback

    # Validate input
    if files is None or (isinstance(files, list) and not files):
        return None, _notice_html("Please upload at least one image"), "", gr.update(visible=False)

    # Convert single file to list
    if not isinstance(files, list):
        files = [files]

    # Check maximum limit
    if len(files) > _MAX_IMAGES_PER_REQUEST:
        message = (
            f"Maximum {_MAX_IMAGES_PER_REQUEST} images allowed. "
            "Please select fewer images."
        )
        return None, _notice_html(message), "", gr.update(visible=False)

    # Load images, skipping (but reporting) anything that is not a readable image
    images = []
    for file in files:
        try:
            images.append(_load_rgb_image(file))
        except Exception as e:
            print(f"⚠️ Warning: Failed to load image {file}: {str(e)}")

    if not images:
        message = "No valid images found. Please upload valid image files."
        return None, _notice_html(message), "", gr.update(visible=False)

    platform = 'instagram'  # Fixed platform

    # Single image processing mode
    if len(images) == 1:
        try:
            results = pipeline.process_image(
                image=images[0],
                platform=platform,
                yolo_variant=yolo_variant,
                language=caption_language
            )
            if results is None:
                message = "Processing failed. Check terminal logs for details."
                return None, _notice_html(message), "", gr.update(visible=False)

            # Fall back to the input image when the pipeline returns no overlay
            visualized_image = results.get('visualized_image', images[0])

            # Format captions with copy functionality
            captions_html = ui_manager.format_captions_with_copy(results['captions'])

            # Clear batch results when in single mode
            latest_batch_results = None
            latest_batch_images = None

            return visualized_image, captions_html, "", gr.update(visible=False)
        except Exception as e:
            trace = traceback.format_exc()
            print("=" * 60)
            print("ERROR DETAILS:")
            print(trace)
            print("=" * 60)
            error_html = _error_panel_html("Processing Error", e, trace)
            return None, error_html, "", gr.update(visible=False)

    # Batch processing mode (2+ images)
    try:
        def update_progress(progress_info):
            """Forward the pipeline's progress dict to the Gradio progress bar."""
            current = progress_info['current']
            total = progress_info['total']
            percent = progress_info['percent']
            progress(percent / 100, desc=f"Processing image {current}/{total}")

        batch_results = pipeline.process_batch(
            images=images,
            platform=platform,
            yolo_variant=yolo_variant,
            language=caption_language,
            progress_callback=update_progress
        )

        # Store results globally for export
        latest_batch_results = batch_results
        latest_batch_images = images

        # Format batch results as HTML
        batch_html = ui_manager.format_batch_results_html(batch_results)

        # No single-image outputs in batch mode; reveal the export panel
        return None, "", batch_html, gr.update(visible=True)
    except Exception as e:
        trace = traceback.format_exc()
        print("=" * 60)
        print("BATCH PROCESSING ERROR:")
        print(trace)
        print("=" * 60)
        error_html = _error_panel_html("Batch Processing Error", e, trace)
        return None, error_html, "", gr.update(visible=False)
def export_json_handler():
    """Export the most recent batch results to a JSON file.

    Returns:
        Path of the written file, or ``None`` when there are no batch results
        to export or the export raised (the error is printed to stdout).
    """
    # Snapshot the module-level state once so a concurrent request that
    # resets it cannot slip in between the check and the export call.
    results = latest_batch_results
    if results is None:
        return None
    try:
        # A fresh directory per export keeps the download file name stable
        # while preventing concurrent exports from overwriting each other.
        export_dir = tempfile.mkdtemp(prefix="pixcribe_")
        output_path = os.path.join(export_dir, "pixcribe_batch_results.json")
        pipeline.batch_processor.export_to_json(results, output_path)
        return output_path
    except Exception as e:
        print(f"Export JSON error: {str(e)}")
        return None
def export_csv_handler():
    """Export the most recent batch results to a CSV file.

    Returns:
        Path of the written file, or ``None`` when there are no batch results
        to export or the export raised (the error is printed to stdout).
    """
    # Snapshot the module-level state once so a concurrent request that
    # resets it cannot slip in between the check and the export call.
    results = latest_batch_results
    if results is None:
        return None
    try:
        # A fresh directory per export keeps the download file name stable
        # while preventing concurrent exports from overwriting each other.
        export_dir = tempfile.mkdtemp(prefix="pixcribe_")
        output_path = os.path.join(export_dir, "pixcribe_batch_results.csv")
        pipeline.batch_processor.export_to_csv(results, output_path)
        return output_path
    except Exception as e:
        print(f"Export CSV error: {str(e)}")
        return None
def export_zip_handler():
    """Export the most recent batch results and their images to a ZIP archive.

    Returns:
        Path of the written archive, or ``None`` when there is no batch
        (results or images missing) or the export raised (the error is
        printed to stdout).
    """
    # Snapshot both pieces of module-level state together so a concurrent
    # request cannot change them between the check and the export call.
    results = latest_batch_results
    images = latest_batch_images
    if results is None or images is None:
        return None
    try:
        # A fresh directory per export keeps the download file name stable
        # while preventing concurrent exports from overwriting each other.
        export_dir = tempfile.mkdtemp(prefix="pixcribe_")
        output_path = os.path.join(export_dir, "pixcribe_batch_results.zip")
        pipeline.batch_processor.export_to_zip(results, images, output_path)
        return output_path
    except Exception as e:
        print(f"Export ZIP error: {str(e)}")
        return None
# Create Gradio Interface
# NOTE(review): the nesting below is reconstructed from the section comments;
# confirm it against the deployed layout.
with gr.Blocks(css=ui_manager.custom_css, title="Pixcribe V5 - AI Social Media Captions") as app:
    # Header
    ui_manager.create_header()

    # Info Banner - Loading Time Notice
    ui_manager.create_info_banner()

    # Top Row - Upload Images & Detected Objects
    with gr.Row(elem_classes="main-row"):
        # Left - Upload Card
        with gr.Column(scale=1):
            with gr.Group(elem_classes="upload-card"):
                image_input = gr.File(
                    file_count="multiple",
                    file_types=["image"],
                    label="Upload Images (Max 10)",
                    elem_classes="upload-area"
                )

        # Right - Detected Objects (Single Image Mode)
        with gr.Column(scale=1):
            with gr.Group(elem_classes="results-card"):
                gr.Markdown("### Detected Objects", elem_classes="section-title")
                visualized_image = gr.Image(
                    label="",
                    elem_classes="image-container"
                )

    # Bottom - Settings Section (Full Width)
    with gr.Group(elem_classes="settings-container"):
        gr.Markdown("### Settings", elem_classes="section-title-left")
        with gr.Row(elem_classes="settings-row"):
            caption_language = gr.Radio(
                choices=[
                    ('繁體中文', 'zh'),
                    ('English', 'en')
                ],
                value='en',
                label="Caption Language",
                elem_classes="radio-group-inline"
            )
            yolo_variant = gr.Radio(
                choices=[
                    ('Fast (m)', 'm'),
                    ('Balanced (l)', 'l'),
                    ('Accurate (x)', 'x')
                ],
                value='l',
                label="Detection Mode",
                elem_classes="radio-group-inline"
            )

    # Generate Button (Centered)
    with gr.Row(elem_classes="button-row"):
        analyze_btn = gr.Button(
            "Generate Captions",
            variant="primary",
            elem_classes="generate-button"
        )

    # Processing Time Notice
    gr.HTML("""
    <div style="text-align: center; margin-top: 16px; color: #7F8C8D; font-size: 14px;">
    <span style="opacity: 0.8;">⚡ Please be patient - AI processing may take some time</span>
    </div>
    """)

    # Single Image Caption Results (Full Width)
    with gr.Group(elem_classes="caption-results-container"):
        gr.Markdown("### 📝 Generated Captions", elem_classes="section-title")
        caption_output = gr.HTML(
            label="",
            elem_id="caption-results"
        )

    # Batch Results Display (always visible; the processing callback fills it
    # with HTML in batch mode and with an empty string otherwise)
    batch_results_output = gr.HTML(
        label="",
        visible=True
    )

    # Export Panel (Initially Hidden; revealed by the processing callback
    # after a successful batch run)
    with gr.Group(elem_classes="export-panel", visible=False) as export_panel:
        gr.Markdown("### 📥 Export Batch Results", elem_classes="section-title-left")
        with gr.Row():
            json_btn = gr.Button("📄 Download JSON", variant="secondary")
            csv_btn = gr.Button("📊 Download CSV", variant="secondary")
            zip_btn = gr.Button("📦 Download ZIP", variant="secondary")
        json_file = gr.File(label="JSON Export", visible=False)
        csv_file = gr.File(label="CSV Export", visible=False)
        zip_file = gr.File(label="ZIP Export", visible=False)

    # Footer
    ui_manager.create_footer()
    gr.HTML('''
    <div style="
        display: flex;
        align-items: center;
        justify-content: center;
        gap: 20px;
        padding: 20px 0;
    ">
        <p style="
            font-family: 'Arial', sans-serif;
            font-size: 14px;
            font-weight: 500;
            letter-spacing: 2px;
            background: linear-gradient(90deg, #555, #007ACC);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            margin: 0;
            text-transform: uppercase;
            display: inline-block;
        ">EXPLORE THE CODE →</p>
        <a href="https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/Pixcribe" style="text-decoration: none;">
            <img src="https://img.shields.io/badge/GitHub-Pixcribe-007ACC?logo=github&style=for-the-badge">
        </a>
    </div>
    ''')

    # Connect button to processing function
    analyze_btn.click(
        fn=process_images_wrapper,
        inputs=[image_input, yolo_variant, caption_language],
        outputs=[visualized_image, caption_output, batch_results_output, export_panel]
    )

    # Connect export buttons: each one runs its handler into the matching
    # (initially hidden) file component, then makes that component visible.
    export_wiring = (
        (json_btn, export_json_handler, json_file),
        (csv_btn, export_csv_handler, csv_file),
        (zip_btn, export_zip_handler, zip_file),
    )
    for export_btn, export_fn, export_file in export_wiring:
        export_btn.click(
            fn=export_fn,
            inputs=[],
            outputs=[export_file]
        ).then(
            lambda: gr.update(visible=True),
            outputs=[export_file]
        )

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    app.launch(share=True)