|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
import spaces |
|
|
import os |
|
|
import json |
|
|
import tempfile |
|
|
from typing import List, Optional |
|
|
|
|
|
from pixcribe_pipeline import PixcribePipeline |
|
|
from ui_manager import UIManager |
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# One-time start-up: build the pipeline and UI helper at import time so every
# request handled by this process reuses the same loaded objects.
# ---------------------------------------------------------------------------
print("Initializing Pixcribe V5 with Batch Processing...")
print("⏳ Loading models (this may take a while)...")
pipeline = PixcribePipeline(yolo_variant='l')
ui_manager = UIManager()
print("✅ All models loaded successfully!")

# Output of the most recent batch run. process_images_wrapper writes these and
# the export_*_handler functions read them; both stay None until a batch run
# has completed.
# NOTE(review): these are module globals, so they are presumably shared by
# every session served by this process - confirm that is acceptable.
latest_batch_results = None
latest_batch_images = None
|
|
|
|
|
# Most files accepted in a single request.
_MAX_IMAGES = 10


def _notice_html(message):
    """Return the red, centered one-line notice used for validation errors."""
    return f"<div style='color: #E74C3C; padding: 24px; text-align: center;'>{message}</div>"


def _failure(message_html):
    """Build the failure result: no image, the message, no batch HTML, export panel hidden."""
    return None, message_html, "", gr.update(visible=False)


def _report_exception(banner, title, exc):
    """Print the active traceback and return it rendered as an HTML error card.

    Must be called from inside an ``except`` block, because the traceback is
    taken from the exception currently being handled.

    Args:
        banner: Heading printed above the traceback on stdout.
        title: Heading shown at the top of the HTML card.
        exc: The caught exception; its message becomes the card summary.

    Returns:
        HTML string for the error card.
    """
    # Local imports keep this block independent of the file's import header.
    import html
    import traceback

    trace = traceback.format_exc()
    print("=" * 60)
    print(banner)
    print(trace)
    print("=" * 60)

    # Exception text and tracebacks can contain '<', '>' and '&' (reprs, file
    # names); escape them so they are displayed instead of parsed as markup.
    summary = html.escape(str(exc))
    details = html.escape(trace)
    return f"""
    <div style='background: #FADBD8; border: 2px solid #E74C3C; border-radius: 20px; padding: 28px; margin: 16px 0;'>
        <h3 style='color: #C0392B; margin-top: 0; font-size: 22px;'>{title}</h3>
        <p style='color: #E74C3C; font-weight: bold; font-size: 17px; margin-bottom: 16px;'>{summary}</p>
        <details style='margin-top: 12px;'>
            <summary style='cursor: pointer; color: #C0392B; font-weight: bold; font-size: 16px;'>View Full Error Trace</summary>
            <pre style='background: white; padding: 16px; border-radius: 12px; overflow-x: auto; font-size: 13px; color: #2C3E50; margin-top: 12px;'>{details}</pre>
        </details>
    </div>
    """


def _load_rgb_images(files):
    """Decode each upload into an in-memory RGB image, skipping unreadable ones.

    Args:
        files: Iterable of upload entries; each is either an object with a
            ``name`` attribute holding a path, or something ``Image.open``
            accepts directly.

    Returns:
        List of RGB ``PIL.Image.Image`` objects, in upload order, for the
        entries that could be decoded.
    """
    images = []
    for file in files:
        source = file.name if hasattr(file, 'name') else file
        try:
            # Image.open is lazy; convert() forces the full decode while we are
            # still inside the try, so truncated files are skipped here rather
            # than failing later, and the `with` closes the file handle.
            with Image.open(source) as img:
                images.append(img.convert('RGB'))
        except Exception as e:
            print(f"⚠️ Warning: Failed to load image {file}: {str(e)}")
    return images


@spaces.GPU(duration=180)
def process_images_wrapper(files, yolo_variant, caption_language, progress=gr.Progress()):
    """
    Process single or multiple images with progress tracking.

    This function automatically detects whether to use single-image or batch
    processing based on the number of files that could be loaded. A batch run
    stores its results in the module-level ``latest_batch_results`` /
    ``latest_batch_images`` for the export handlers; a successful single-image
    run resets both to None.

    Args:
        files: List of uploaded file objects (or single file)
        yolo_variant: YOLO model variant ('m', 'l', 'x')
        caption_language: Caption language ('zh', 'en')
        progress: Gradio Progress object for progress updates

    Returns:
        Tuple of (visualized_image, caption_html, batch_results_html,
        export_panel_visibility). On any failure the image is None, the
        caption slot carries the error HTML and the export panel is hidden.
    """
    global latest_batch_results, latest_batch_images

    # Nothing uploaded (None or an empty collection).
    if not files:
        return _failure(_notice_html("Please upload at least one image"))

    # A single upload may arrive unwrapped; normalise to a sequence.
    if not isinstance(files, (list, tuple)):
        files = [files]

    if len(files) > _MAX_IMAGES:
        return _failure(_notice_html(
            f"Maximum {_MAX_IMAGES} images allowed. Please select fewer images."
        ))

    images = _load_rgb_images(files)
    if not images:
        return _failure(_notice_html(
            "No valid images found. Please upload valid image files."
        ))

    platform = 'instagram'

    # ---- Single image -------------------------------------------------------
    if len(images) == 1:
        try:
            results = pipeline.process_image(
                image=images[0],
                platform=platform,
                yolo_variant=yolo_variant,
                language=caption_language
            )

            if results is None:
                return _failure(_notice_html(
                    "Processing failed. Check terminal logs for details."
                ))

            # Fall back to the uploaded image when no visualisation is returned.
            visualized_image = results.get('visualized_image', images[0])
            captions_html = ui_manager.format_captions_with_copy(results['captions'])
        except Exception as e:
            return _failure(_report_exception("ERROR DETAILS:", "❌ Processing Error", e))

        # A single-image run leaves nothing to export.
        latest_batch_results = None
        latest_batch_images = None
        return visualized_image, captions_html, "", gr.update(visible=False)

    # ---- Batch --------------------------------------------------------------
    def update_progress(progress_info):
        """Forward the pipeline's progress dict to the Gradio progress bar."""
        current = progress_info['current']
        total = progress_info['total']
        percent = progress_info['percent']
        # 'percent' is divided by 100 to give the fraction passed to progress().
        progress(percent / 100, desc=f"Processing image {current}/{total}")

    try:
        batch_results = pipeline.process_batch(
            images=images,
            platform=platform,
            yolo_variant=yolo_variant,
            language=caption_language,
            progress_callback=update_progress
        )

        # Keep results and source images for the export handlers.
        latest_batch_results = batch_results
        latest_batch_images = images

        batch_html = ui_manager.format_batch_results_html(batch_results)
    except Exception as e:
        return _failure(_report_exception(
            "BATCH PROCESSING ERROR:", "❌ Batch Processing Error", e
        ))

    return None, "", batch_html, gr.update(visible=True)
|
|
|
|
|
|
|
|
def export_json_handler():
    """Export batch results to JSON file.

    Reads the module-level ``latest_batch_results`` and writes it through
    ``pipeline.batch_processor.export_to_json`` to
    ``pixcribe_batch_results.json`` in the system temp directory.

    Returns:
        The path of the written file, or None when no batch results are
        stored or the export raised (the error is printed, not re-raised).
    """
    # Nothing to export until a batch run has stored its results.
    if latest_batch_results is None:
        return None

    try:
        output_path = os.path.join(
            tempfile.gettempdir(), "pixcribe_batch_results.json"
        )
        pipeline.batch_processor.export_to_json(latest_batch_results, output_path)
    except Exception as e:
        # Best effort: report on stdout and signal "no file" to the caller.
        print(f"Export JSON error: {str(e)}")
        return None
    else:
        return output_path
|
|
|
|
|
|
|
|
def export_csv_handler():
    """Export batch results to CSV file.

    Reads the module-level ``latest_batch_results`` and writes it through
    ``pipeline.batch_processor.export_to_csv`` to
    ``pixcribe_batch_results.csv`` in the system temp directory.

    Returns:
        The path of the written file, or None when no batch results are
        stored or the export raised (the error is printed, not re-raised).
    """
    # Nothing to export until a batch run has stored its results.
    if latest_batch_results is None:
        return None

    try:
        output_path = os.path.join(
            tempfile.gettempdir(), "pixcribe_batch_results.csv"
        )
        pipeline.batch_processor.export_to_csv(latest_batch_results, output_path)
    except Exception as e:
        # Best effort: report on stdout and signal "no file" to the caller.
        print(f"Export CSV error: {str(e)}")
        return None
    else:
        return output_path
|
|
|
|
|
|
|
|
def export_zip_handler():
    """Export batch results to ZIP archive.

    Reads the module-level ``latest_batch_results`` and
    ``latest_batch_images`` and passes both to
    ``pipeline.batch_processor.export_to_zip``, targeting
    ``pixcribe_batch_results.zip`` in the system temp directory.

    Returns:
        The path of the written archive, or None when either stored value is
        missing or the export raised (the error is printed, not re-raised).
    """
    # The archive needs both the results and the images they were built from.
    if latest_batch_results is None or latest_batch_images is None:
        return None

    try:
        output_path = os.path.join(
            tempfile.gettempdir(), "pixcribe_batch_results.zip"
        )
        pipeline.batch_processor.export_to_zip(
            latest_batch_results,
            latest_batch_images,
            output_path
        )
    except Exception as e:
        # Best effort: report on stdout and signal "no file" to the caller.
        print(f"Export ZIP error: {str(e)}")
        return None
    else:
        return output_path
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio interface: layout first, event wiring last.
# NOTE(review): the nesting below was reconstructed from the `with` statements
# because the original indentation was lost - confirm it against the intended
# layout. Likewise the emoji on "Generated Captions", "Download JSON" and
# "Download CSV" were unrecoverable and are best guesses.
# ---------------------------------------------------------------------------
with gr.Blocks(css=ui_manager.custom_css, title="Pixcribe V5 - AI Social Media Captions") as app:

    ui_manager.create_header()

    ui_manager.create_info_banner()

    # Top row: upload card on the left, detection preview on the right.
    with gr.Row(elem_classes="main-row"):

        with gr.Column(scale=1):
            with gr.Group(elem_classes="upload-card"):
                image_input = gr.File(
                    file_count="multiple",
                    file_types=["image"],
                    label="Upload Images (Max 10)",
                    elem_classes="upload-area"
                )

        with gr.Column(scale=1):
            with gr.Group(elem_classes="results-card"):
                gr.Markdown("### Detected Objects", elem_classes="section-title")
                visualized_image = gr.Image(
                    label="",
                    elem_classes="image-container"
                )

    # Settings: caption language and detection model size.
    with gr.Group(elem_classes="settings-container"):
        gr.Markdown("### Settings", elem_classes="section-title-left")

        with gr.Row(elem_classes="settings-row"):
            caption_language = gr.Radio(
                choices=[
                    ('繁體中文', 'zh'),
                    ('English', 'en')
                ],
                value='en',
                label="Caption Language",
                elem_classes="radio-group-inline"
            )

            yolo_variant = gr.Radio(
                choices=[
                    ('Fast (m)', 'm'),
                    ('Balanced (l)', 'l'),
                    ('Accurate (x)', 'x')
                ],
                value='l',
                label="Detection Mode",
                elem_classes="radio-group-inline"
            )

    with gr.Row(elem_classes="button-row"):
        analyze_btn = gr.Button(
            "Generate Captions",
            variant="primary",
            elem_classes="generate-button"
        )

    gr.HTML("""
        <div style="text-align: center; margin-top: 16px; color: #7F8C8D; font-size: 14px;">
            <span style="opacity: 0.8;">⚡ Please be patient - AI processing may take some time</span>
        </div>
    """)

    # Single-image output.
    with gr.Group(elem_classes="caption-results-container"):
        gr.Markdown("### 📝 Generated Captions", elem_classes="section-title")
        caption_output = gr.HTML(
            label="",
            elem_id="caption-results"
        )

    # Batch output.
    batch_results_output = gr.HTML(
        label="",
        visible=True
    )

    # Starts hidden; the fourth output of process_images_wrapper toggles it.
    with gr.Group(elem_classes="export-panel", visible=False) as export_panel:
        gr.Markdown("### 📥 Export Batch Results", elem_classes="section-title-left")

        with gr.Row():
            json_btn = gr.Button("📄 Download JSON", variant="secondary")
            csv_btn = gr.Button("📊 Download CSV", variant="secondary")
            zip_btn = gr.Button("📦 Download ZIP", variant="secondary")

        # Hidden download slots, revealed after the matching button is clicked.
        json_file = gr.File(label="JSON Export", visible=False)
        csv_file = gr.File(label="CSV Export", visible=False)
        zip_file = gr.File(label="ZIP Export", visible=False)

    ui_manager.create_footer()

    gr.HTML('''
        <div style="
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 20px;
            padding: 20px 0;
        ">
            <p style="
                font-family: 'Arial', sans-serif;
                font-size: 14px;
                font-weight: 500;
                letter-spacing: 2px;
                background: linear-gradient(90deg, #555, #007ACC);
                -webkit-background-clip: text;
                -webkit-text-fill-color: transparent;
                margin: 0;
                text-transform: uppercase;
                display: inline-block;
            ">EXPLORE THE CODE →</p>
            <a href="https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/Pixcribe" style="text-decoration: none;">
                <img src="https://img.shields.io/badge/GitHub-Pixcribe-007ACC?logo=github&style=for-the-badge">
            </a>
        </div>
    ''')

    # ---- Event wiring -------------------------------------------------------
    analyze_btn.click(
        fn=process_images_wrapper,
        inputs=[image_input, yolo_variant, caption_language],
        outputs=[visualized_image, caption_output, batch_results_output, export_panel]
    )

    # Each export button runs its handler into the matching file slot, then
    # makes that slot visible.
    for export_btn, export_fn, export_slot in (
        (json_btn, export_json_handler, json_file),
        (csv_btn, export_csv_handler, csv_file),
        (zip_btn, export_zip_handler, zip_file),
    ):
        export_btn.click(
            fn=export_fn,
            inputs=[],
            outputs=[export_slot]
        ).then(
            lambda: gr.update(visible=True),
            outputs=[export_slot]
        )
|
|
|
|
|
# Start the web server only when this file is run as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link - confirm
    # that is wanted where this app is deployed.
    app.launch(share=True)
|
|
|