# Pixcribe / batch_processing_manager.py
# DawnC's picture
# Upload 5 files
# f3a4ad9 verified
import time
import json
import csv
import zipfile
from io import BytesIO
from typing import List, Dict, Optional, Callable
from PIL import Image
import traceback
class BatchProcessingManager:
    """
    Manages batch processing of multiple images with progress tracking,
    error handling, and result export functionality.

    Follows the Facade pattern by delegating actual image processing
    to the PixcribePipeline instance; this class only loops over the
    images, records per-image outcomes and timings, and serialises the
    collected results to JSON, CSV or ZIP.
    """

    # Largest batch process_batch accepts; the ValueError message quotes it.
    MAX_BATCH_SIZE = 10

    # Image modes Pillow's JPEG writer accepts as-is. Any other mode
    # (RGBA, LA, P, ...) is converted to RGB before saving in export_to_zip.
    _JPEG_SAFE_MODES = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')

    def __init__(self, pipeline=None):
        """
        Initialize the Batch Processing Manager.

        Args:
            pipeline: Reference to PixcribePipeline instance for processing
                images. Only its ``process_image`` method is called.
        """
        self.pipeline = pipeline
        self.results = {}  # Processing results keyed by 1-based image number
        self.timing_data = []  # Elapsed seconds for each processed image

    def process_batch(
        self,
        images: "List[Image.Image]",  # quoted: not evaluated at definition time
        platform: str = 'instagram',
        yolo_variant: str = 'l',
        language: str = 'zh',
        progress_callback: Optional[Callable] = None
    ) -> Dict:
        """
        Process a batch of images with progress tracking.

        Each image is passed to ``self.pipeline.process_image``. Any
        exception raised for one image (including the AttributeError that
        results from a missing pipeline) is caught, recorded as a failed
        entry with its type, message and traceback, and the loop continues
        with the next image.

        Args:
            images: List of PIL Image objects to process (max MAX_BATCH_SIZE)
            platform: Target social media platform
            yolo_variant: YOLO model variant ('m', 'l', 'x')
            language: Caption language ('zh', 'en')
            progress_callback: Optional callable invoked after every image
                with a dict holding 'current', 'total', 'percent',
                'estimated_remaining', 'latest_result' and 'image_index'.
                It runs outside the per-image try block, so an exception it
                raises aborts the batch.

        Returns:
            Dictionary with 'results' (the same dict stored on
            ``self.results``), 'total_processed', 'total_success',
            'total_failed', 'total_time' and 'average_time_per_image'.

        Raises:
            ValueError: If images list is empty or exceeds MAX_BATCH_SIZE
        """
        # Validate input
        if not images:
            raise ValueError("Images list cannot be empty")
        if len(images) > self.MAX_BATCH_SIZE:
            raise ValueError(
                f"Maximum {self.MAX_BATCH_SIZE} images allowed per batch"
            )

        # Start from a clean slate so earlier batches do not leak in
        self.results = {}
        self.timing_data = []
        total_images = len(images)

        batch_start_time = time.time()
        print(f"\n{'='*60}")
        print(f"Starting batch processing: {total_images} images")
        print(f"Platform: {platform} | Variant: {yolo_variant} | Language: {language}")
        print(f"{'='*60}\n")

        for image_index, image in enumerate(images, start=1):
            image_start_time = time.time()
            try:
                print(f"[{image_index}/{total_images}] Processing image {image_index}...")
                result = self.pipeline.process_image(
                    image=image,
                    platform=platform,
                    yolo_variant=yolo_variant,
                    language=language
                )
                self.results[image_index] = {
                    'status': 'success',
                    'result': result,
                    'image_index': image_index,
                    'error': None
                }
                print(f"✓ Image {image_index} processed successfully")
            except Exception as e:
                # Batch boundary: one bad image must not stop the others,
                # so the failure is recorded instead of propagated.
                error_trace = traceback.format_exc()
                self.results[image_index] = {
                    'status': 'failed',
                    'result': None,
                    'image_index': image_index,
                    'error': {
                        'type': type(e).__name__,
                        'message': str(e),
                        'traceback': error_trace
                    }
                }
                print(f"✗ Image {image_index} failed: {str(e)}")

            # Record processing time for this image (success or failure)
            self.timing_data.append(time.time() - image_start_time)

            # Estimate remaining time from the running average per image
            completed = image_index
            percent = (completed / total_images) * 100
            avg_time = sum(self.timing_data) / len(self.timing_data)
            estimated_remaining = avg_time * (total_images - completed)

            if progress_callback:
                progress_callback({
                    'current': completed,
                    'total': total_images,
                    'percent': percent,
                    'estimated_remaining': estimated_remaining,
                    'latest_result': self.results[image_index],
                    'image_index': image_index
                })

        # Calculate batch summary
        batch_elapsed = time.time() - batch_start_time
        total_processed = len(self.results)
        total_failed = sum(
            1 for r in self.results.values() if r['status'] == 'failed'
        )
        total_success = total_processed - total_failed
        # Computed once so the printed and the returned average always agree
        average_time = batch_elapsed / total_processed if total_processed > 0 else 0

        print(f"\n{'='*60}")
        print(f"Batch processing completed!")
        print(f"Total: {total_processed} | Success: {total_success} | Failed: {total_failed}")
        print(f"Total time: {batch_elapsed:.2f}s | Avg per image: {average_time:.2f}s")
        print(f"{'='*60}\n")

        return {
            'results': self.results,
            'total_processed': total_processed,
            'total_success': total_success,
            'total_failed': total_failed,
            'total_time': batch_elapsed,
            'average_time_per_image': average_time
        }

    def get_result(self, image_index: int) -> Optional[Dict]:
        """
        Get processing result for a specific image.

        Args:
            image_index: Index of the image (1-based)

        Returns:
            Result dictionary or None if index doesn't exist
        """
        return self.results.get(image_index)

    def get_all_results(self) -> Dict:
        """
        Get all processing results.

        Returns:
            The stored results dictionary itself (not a copy)
        """
        return self.results

    def clear_results(self):
        """Clear all stored results to free memory."""
        self.results = {}
        self.timing_data = []
        print("✓ Batch results cleared")

    @staticmethod
    def _object_names(result_data: Dict) -> List:
        """Return the 'class_name' of every entry under 'detections'."""
        return [det['class_name'] for det in result_data.get('detections', [])]

    @staticmethod
    def _brand_names(result_data: Dict) -> List:
        """Return brand names; a tuple entry contributes its first element."""
        return [
            brand[0] if isinstance(brand, tuple) else brand
            for brand in result_data.get('brands', [])
        ]

    def export_to_json(self, results: Dict, output_path: str) -> str:
        """
        Export batch results to JSON format.

        The file holds a 'batch_summary' object and an 'images' list with
        one entry per image: captions, detected object and brand names,
        scene and lighting data for successes; the error dict for failures.

        Args:
            results: Results dictionary from process_batch
            output_path: Path to save JSON file

        Returns:
            Path to the saved JSON file
        """
        export_data = {
            'batch_summary': {
                'total_processed': results.get('total_processed', 0),
                'total_success': results.get('total_success', 0),
                'total_failed': results.get('total_failed', 0),
                'total_time': results.get('total_time', 0),
                'average_time_per_image': results.get('average_time_per_image', 0)
            },
            'images': []
        }

        for img_idx, img_result in results.get('results', {}).items():
            if img_result['status'] == 'success':
                result_data = img_result['result']
                image_export = {
                    'image_index': img_idx,
                    'status': 'success',
                    'captions': result_data.get('captions', []),
                    'detected_objects': self._object_names(result_data),
                    'detected_brands': self._brand_names(result_data),
                    'scene_info': result_data.get('scene', {}),
                    'lighting': result_data.get('lighting', {})
                }
            else:
                image_export = {
                    'image_index': img_idx,
                    'status': 'failed',
                    'error': img_result.get('error', {})
                }
            export_data['images'].append(image_export)

        # ensure_ascii=False keeps non-ASCII caption text readable in the file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, ensure_ascii=False, indent=2)

        print(f"✓ Batch results exported to JSON: {output_path}")
        return output_path

    def export_to_csv(self, results: Dict, output_path: str) -> str:
        """
        Export batch results to CSV format.

        One row per image. Captions are routed to a column by substring
        match on their lower-cased 'tone' ('professional', 'creative',
        'authentic' or 'casual'); when several captions match the same
        column the last one wins. Hashtags of all captions are merged,
        de-duplicated in first-seen order, and written '#'-prefixed and
        space-separated. Failed images get a row with empty content columns.

        Args:
            results: Results dictionary from process_batch
            output_path: Path to save CSV file

        Returns:
            Path to the saved CSV file
        """
        headers = [
            'image_index',
            'status',
            'caption_professional',
            'caption_creative',
            'caption_authentic',
            'detected_objects',
            'detected_brands',
            'hashtags'
        ]

        rows = []
        for img_idx, img_result in results.get('results', {}).items():
            if img_result['status'] == 'success':
                result_data = img_result['result']
                caption_professional = ''
                caption_creative = ''
                caption_authentic = ''
                all_hashtags = []

                for cap in result_data.get('captions', []):
                    tone = cap.get('tone', '').lower()
                    caption_text = cap.get('caption', '')
                    if 'professional' in tone:
                        caption_professional = caption_text
                    elif 'creative' in tone:
                        caption_creative = caption_text
                    elif 'authentic' in tone or 'casual' in tone:
                        caption_authentic = caption_text
                    all_hashtags.extend(cap.get('hashtags', []))

                # dict.fromkeys drops duplicates while keeping first-seen
                # order; a set would make the column order vary between runs.
                unique_hashtags = list(dict.fromkeys(all_hashtags))

                row = {
                    'image_index': img_idx,
                    'status': 'success',
                    'caption_professional': caption_professional,
                    'caption_creative': caption_creative,
                    'caption_authentic': caption_authentic,
                    'detected_objects': ', '.join(self._object_names(result_data)),
                    'detected_brands': ', '.join(self._brand_names(result_data)),
                    'hashtags': ' '.join([f'#{tag}' for tag in unique_hashtags])
                }
            else:
                row = {
                    'image_index': img_idx,
                    'status': 'failed',
                    'caption_professional': '',
                    'caption_creative': '',
                    'caption_authentic': '',
                    'detected_objects': '',
                    'detected_brands': '',
                    'hashtags': ''
                }
            rows.append(row)

        # newline='' lets the csv module control line endings itself
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=headers)
            writer.writeheader()
            writer.writerows(rows)

        print(f"✓ Batch results exported to CSV: {output_path}")
        return output_path

    def export_to_zip(
        self,
        results: Dict,
        images: "List[Image.Image]",  # quoted: not evaluated at definition time
        output_path: str
    ) -> str:
        """
        Export batch results to ZIP archive with images and text files.

        For every successful result the archive receives
        ``image_NNN.jpg`` (the matching entry of ``images`` re-encoded as
        JPEG, quality 95) and ``image_NNN.txt`` (the formatted captions).
        Failed results are skipped. Images whose mode the JPEG writer
        cannot store (for example RGBA) are converted to RGB first, which
        discards any alpha channel.

        Args:
            results: Results dictionary from process_batch
            images: List of original PIL Image objects, in the same order
                they were passed to process_batch (result index N maps to
                images[N - 1])
            output_path: Path to save ZIP file

        Returns:
            Path to the saved ZIP file
        """
        with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for img_idx, img_result in results.get('results', {}).items():
                if img_result['status'] != 'success':
                    continue

                image_filename = f"image_{img_idx:03d}.jpg"
                image = images[img_idx - 1]
                if image.mode not in self._JPEG_SAFE_MODES:
                    # JPEG has no alpha/palette support; saving such an
                    # image directly raises OSError.
                    image = image.convert('RGB')
                img_buffer = BytesIO()
                image.save(img_buffer, format='JPEG', quality=95)
                zipf.writestr(image_filename, img_buffer.getvalue())

                text_filename = f"image_{img_idx:03d}.txt"
                text_content = self._format_result_as_text(img_result['result'])
                zipf.writestr(text_filename, text_content)

                print(f"✓ Added to ZIP: {image_filename} and {text_filename}")

        print(f"✓ Batch results exported to ZIP: {output_path}")
        return output_path

    def _format_result_as_text(self, result: Dict) -> str:
        """
        Format a single image result as plain text for ZIP export.

        Sections are emitted in this order: banner, one block per caption
        (tone, text, hashtags), detected objects, detected brands, scene
        analysis, footer. The objects, brands and scene sections are
        omitted when the corresponding data is empty or missing.

        Args:
            result: Single image processing result dictionary

        Returns:
            Formatted text string
        """
        lines = []
        lines.append("=" * 60)
        lines.append("PIXCRIBE - AI GENERATED SOCIAL MEDIA CONTENT")
        lines.append("=" * 60)
        lines.append("")

        # Captions section
        for i, cap in enumerate(result.get('captions', []), 1):
            tone = cap.get('tone', 'Unknown').upper()
            lines.append(f"CAPTION {i} - {tone} STYLE")
            lines.append("-" * 60)
            lines.append(cap.get('caption', ''))
            lines.append("")
            lines.append("Hashtags:")
            lines.append(' '.join([f'#{tag}' for tag in cap.get('hashtags', [])]))
            lines.append("")
            lines.append("")

        # Detected objects section
        object_names = self._object_names(result)
        if object_names:
            lines.append("DETECTED OBJECTS")
            lines.append("-" * 60)
            lines.append(', '.join(object_names))
            lines.append("")

        # Detected brands section
        brand_names = self._brand_names(result)
        if brand_names:
            lines.append("DETECTED BRANDS")
            lines.append("-" * 60)
            lines.append(', '.join(brand_names))
            lines.append("")

        # Scene information
        scene_info = result.get('scene', {})
        if scene_info:
            lines.append("SCENE ANALYSIS")
            lines.append("-" * 60)
            if 'lighting' in scene_info:
                lighting = scene_info['lighting'].get('top', 'Unknown')
                lines.append(f"Lighting: {lighting}")
            if 'mood' in scene_info:
                mood = scene_info['mood'].get('top', 'Unknown')
                lines.append(f"Mood: {mood}")
            lines.append("")

        lines.append("=" * 60)
        lines.append("Generated by Pixcribe V5 - AI Social Media Caption Generator")
        lines.append("=" * 60)
        return '\n'.join(lines)
# Module-level confirmation message; it is emitted every time this file is
# executed or imported, after the class definition above has been evaluated.
print("✓ BatchProcessingManager defined")