Spaces: Running on Zero

Commit: Upload 11 files
Creating New Feature: Inpainting mode
- app.py +1 -2
- css_styles.py +230 -1
- image_blender.py +314 -11
- inpainting_module.py +1311 -0
- inpainting_templates.py +707 -0
- mask_generator.py +1 -3
- model_manager.py +263 -48
- quality_checker.py +450 -1
- scene_templates.py +5 -6
- scene_weaver_core.py +380 -11
- ui_manager.py +677 -218
app.py
CHANGED

@@ -1,4 +1,5 @@
 import sys
+import traceback
 import warnings
 warnings.filterwarnings("ignore")

@@ -26,13 +27,11 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
         return interface

     except ImportError as import_error:
-        import traceback
         print(f"❌ Import failed: {import_error}")
         print(f"Traceback: {traceback.format_exc()}")
         raise

     except Exception as e:
-        import traceback
         print(f"❌ Failed to launch: {e}")
         print(f"Full traceback: {traceback.format_exc()}")
         raise
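The change hoists `import traceback` to module scope so both handlers share one import instead of re-importing inside each `except` block. For orientation, a minimal entry point for the launcher named in the hunk header might look like this (only the function name and its defaults come from the hunk header; the `__main__` guard itself is an assumption):

    # Hypothetical entry point; launch_final_blend_sceneweaver and its
    # defaults come from the hunk header above, the guard is an assumption.
    if __name__ == "__main__":
        launch_final_blend_sceneweaver(share=True, debug=False)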
css_styles.py
CHANGED

@@ -512,6 +512,235 @@ class CSSStyles:
     font-size: 0.95rem !important;
 }

+/* ===== INPAINTING UI STYLES ===== */
+.inpainting-header {
+    text-align: center !important;
+    padding: 16px !important;
+    margin-bottom: 16px !important;
+    background: linear-gradient(135deg, var(--primary-light) 0%, var(--accent-light) 100%) !important;
+    border-radius: var(--radius-lg) !important;
+    border: 1px solid var(--border-color) !important;
+}
+
+.inpainting-header h3 {
+    font-size: 1.4rem !important;
+    font-weight: 600 !important;
+    color: var(--primary-color) !important;
+    margin: 0 0 8px 0 !important;
+}
+
+.inpainting-header p {
+    font-size: 0.95rem !important;
+    color: var(--text-secondary) !important;
+    margin: 0 !important;
+}
+
+/* Main mode tabs styling */
+#main-mode-tabs {
+    margin-bottom: 16px !important;
+}
+
+#main-mode-tabs > .tab-nav {
+    background: var(--bg-secondary) !important;
+    border-radius: var(--radius-lg) !important;
+    padding: 6px !important;
+    gap: 6px !important;
+    border: 1px solid var(--border-color) !important;
+}
+
+#main-mode-tabs > .tab-nav > button {
+    border-radius: var(--radius-md) !important;
+    padding: 12px 24px !important;
+    font-weight: 600 !important;
+    font-size: 1rem !important;
+    transition: all var(--transition-normal) !important;
+    border: none !important;
+    background: transparent !important;
+    color: var(--text-secondary) !important;
+}
+
+#main-mode-tabs > .tab-nav > button.selected {
+    background: linear-gradient(135deg, var(--accent-color) 0%, var(--accent-hover) 100%) !important;
+    color: white !important;
+    box-shadow: var(--shadow-md) !important;
+}
+
+#main-mode-tabs > .tab-nav > button:hover:not(.selected) {
+    background: var(--bg-primary) !important;
+    color: var(--text-primary) !important;
+}
+
+/* Mask editor styling */
+.mask-editor-container {
+    border: 2px dashed var(--border-color) !important;
+    border-radius: var(--radius-lg) !important;
+    padding: 8px !important;
+    background: var(--bg-secondary) !important;
+    transition: border-color var(--transition-fast) !important;
+}
+
+.mask-editor-container:hover {
+    border-color: var(--accent-color) !important;
+}
+
+/* Inpainting template cards */
+.inpainting-gallery {
+    margin: 16px 0 !important;
+}
+
+.inpainting-category {
+    margin-bottom: 20px !important;
+}
+
+.inpainting-category-title {
+    font-size: 0.9rem !important;
+    font-weight: 600 !important;
+    color: var(--text-secondary) !important;
+    margin-bottom: 12px !important;
+    padding-bottom: 8px !important;
+    border-bottom: 1px solid var(--border-color) !important;
+}
+
+.inpainting-grid {
+    display: grid !important;
+    grid-template-columns: repeat(auto-fill, minmax(120px, 1fr)) !important;
+    gap: 10px !important;
+}
+
+.inpainting-card {
+    display: flex !important;
+    flex-direction: column !important;
+    align-items: center !important;
+    justify-content: center !important;
+    padding: 14px 10px !important;
+    background: var(--bg-primary) !important;
+    border: 1px solid var(--border-color) !important;
+    border-radius: var(--radius-md) !important;
+    cursor: pointer !important;
+    transition: all var(--transition-normal) !important;
+    min-height: 80px !important;
+}
+
+.inpainting-card:hover {
+    background: var(--accent-light) !important;
+    border-color: var(--accent-color) !important;
+    transform: translateY(-2px) !important;
+    box-shadow: var(--shadow-md) !important;
+}
+
+.inpainting-card.selected {
+    background: var(--accent-light) !important;
+    border-color: var(--accent-color) !important;
+    box-shadow: 0 0 0 2px var(--accent-color) !important;
+}
+
+.inpainting-icon {
+    font-size: 1.6rem !important;
+    margin-bottom: 6px !important;
+}
+
+.inpainting-name {
+    font-size: 0.8rem !important;
+    font-weight: 500 !important;
+    color: var(--text-primary) !important;
+    text-align: center !important;
+    line-height: 1.2 !important;
+}
+
+.inpainting-desc {
+    font-size: 0.7rem !important;
+    color: var(--text-muted) !important;
+    text-align: center !important;
+    margin-top: 4px !important;
+}
+
+/* ControlNet mode toggle */
+.controlnet-mode-toggle {
+    display: flex !important;
+    gap: 8px !important;
+    margin: 8px 0 !important;
+}
+
+.controlnet-mode-btn {
+    flex: 1 !important;
+    padding: 10px 16px !important;
+    border: 1px solid var(--border-color) !important;
+    border-radius: var(--radius-md) !important;
+    background: var(--bg-primary) !important;
+    color: var(--text-secondary) !important;
+    font-weight: 500 !important;
+    cursor: pointer !important;
+    transition: all var(--transition-fast) !important;
+}
+
+.controlnet-mode-btn:hover {
+    border-color: var(--accent-color) !important;
+    color: var(--accent-color) !important;
+}
+
+.controlnet-mode-btn.active {
+    background: var(--accent-color) !important;
+    border-color: var(--accent-color) !important;
+    color: white !important;
+}
+
+/* Preview badge */
+.preview-badge {
+    display: inline-flex !important;
+    align-items: center !important;
+    gap: 4px !important;
+    padding: 4px 10px !important;
+    background: var(--warning-color) !important;
+    color: white !important;
+    font-size: 0.75rem !important;
+    font-weight: 600 !important;
+    border-radius: var(--radius-sm) !important;
+    text-transform: uppercase !important;
+}
+
+/* Quality score display */
+.quality-score {
+    display: inline-flex !important;
+    align-items: center !important;
+    gap: 6px !important;
+    padding: 6px 12px !important;
+    border-radius: var(--radius-md) !important;
+    font-weight: 500 !important;
+    font-size: 0.9rem !important;
+}
+
+.quality-score.excellent {
+    background: rgba(16, 185, 129, 0.1) !important;
+    color: var(--success-color) !important;
+}
+
+.quality-score.good {
+    background: rgba(59, 130, 246, 0.1) !important;
+    color: var(--accent-color) !important;
+}
+
+.quality-score.warning {
+    background: rgba(245, 158, 11, 0.1) !important;
+    color: var(--warning-color) !important;
+}
+
+.quality-score.poor {
+    background: rgba(239, 68, 68, 0.1) !important;
+    color: var(--error-color) !important;
+}
+
+/* Responsive adjustments for inpainting */
+@media (max-width: 768px) {
+    .inpainting-grid {
+        grid-template-columns: repeat(2, 1fr) !important;
+    }
+
+    #main-mode-tabs > .tab-nav > button {
+        padding: 10px 16px !important;
+        font-size: 0.9rem !important;
+    }
+}
+
 /* ===== DROPDOWN POSITIONING FIX FOR GRADIO 4.x/5.x ===== */
 /* Fix dropdown list positioning issue in left column */
 .feature-card {

@@ -542,4 +771,4 @@ class CSSStyles:
     box-shadow: var(--shadow-lg) !important;
     margin-top: 4px !important;
 }
-"""
+"""
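The new rules target Gradio elem_id/elem_classes hooks (`#main-mode-tabs`, `.inpainting-card`, `.inpainting-header`). A minimal wiring sketch, assuming CSSStyles exposes the stylesheet string via an attribute (the exact accessor is not shown in this diff, so `CSSStyles.STYLES` here is hypothetical):

    # Hedged sketch of attaching the new hooks in Gradio 4.x/5.x.
    # CSSStyles.STYLES is an assumed accessor; only the selector names
    # come from the diff above.
    import gradio as gr
    from css_styles import CSSStyles

    with gr.Blocks(css=CSSStyles.STYLES) as demo:
        gr.HTML(
            "<div class='inpainting-header'>"
            "<h3>Inpainting</h3><p>Edit a masked region</p></div>"
        )
        with gr.Tabs(elem_id="main-mode-tabs"):
            with gr.Tab("Background"):
                ...
            with gr.Tab("Inpainting"):
                gr.Button("Remove Object", elem_classes=["inpainting-card"])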
image_blender.py
CHANGED

@@ -1,5 +1,6 @@
 import cv2
 import numpy as np
+import traceback
 from PIL import Image
 import logging
 from typing import Dict, Any, Optional, Tuple

@@ -7,24 +8,37 @@ from typing import Dict, Any, Optional, Tuple
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

+
 class ImageBlender:
     """
-    Advanced image blending with aggressive spill suppression and color replacement
-
+    Advanced image blending with aggressive spill suppression and color replacement.
+
+    Supports two primary modes:
+    - Background generation: Foreground preservation with edge refinement
+    - Inpainting: Seamless blending with adaptive color correction
+
+    Attributes:
+        enable_multi_scale: Whether multi-scale edge refinement is enabled
     """

-    EDGE_EROSION_PIXELS = 1  # Pixels to erode from mask edge
-    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization
-    DARK_LUMINANCE_THRESHOLD = 60  # Luminance threshold for dark foreground
-    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value
-    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance for background
+    EDGE_EROSION_PIXELS = 1  # Pixels to erode from mask edge
+    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization
+    DARK_LUMINANCE_THRESHOLD = 60  # Luminance threshold for dark foreground
+    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value for strong protection
+    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance for background detection
+
+    # Inpainting-specific parameters
+    INPAINT_FEATHER_SCALE = 1.2  # Scale factor for inpainting feathering
+    INPAINT_COLOR_BLEND_RADIUS = 10  # Radius for color adaptation zone

     def __init__(self, enable_multi_scale: bool = True):
         """
         Initialize ImageBlender.

-
-
+        Parameters
+        ----------
+        enable_multi_scale : bool
+            Whether to enable multi-scale edge refinement (default True)
         """
         self.enable_multi_scale = enable_multi_scale
         self._debug_info = {}

@@ -499,7 +513,6 @@ class ImageBlender:
             logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")

         except Exception as e:
-            import traceback
             logger.error(f"❌ Trimap definition failed: {e}")
             logger.error(f"📍 Traceback: {traceback.format_exc()}")
             print(f"❌ TRIMAP ERROR: {e}")

@@ -678,7 +691,7 @@ class ImageBlender:
         orig_linear = srgb_to_linear(orig_array)
         bg_linear = srgb_to_linear(bg_array)

-        #
+        # Cartoon-optimized Alpha calculation
         alpha = mask_array.astype(np.float32) / 255.0

         # Core foreground region - fully opaque

@@ -800,3 +813,293 @@ class ImageBlender:
             debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)

         return debug_images
+
+    # INPAINTING-SPECIFIC BLENDING METHODS
+    def blend_inpainting(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image,
+        feather_radius: int = 8,
+        apply_color_correction: bool = True
+    ) -> Image.Image:
+        """
+        Blend inpainted region with original image.
+
+        Specialized blending for inpainting that focuses on seamless integration
+        rather than foreground protection. Performs blending in linear color space
+        with optional adaptive color correction at boundaries.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image before inpainting
+        generated : PIL.Image
+            Generated/inpainted result from the model
+        mask : PIL.Image
+            Inpainting mask (white = inpainted area)
+        feather_radius : int
+            Feathering radius for smooth transitions
+        apply_color_correction : bool
+            Whether to apply adaptive color correction at boundaries
+
+        Returns
+        -------
+        PIL.Image
+            Blended result
+        """
+        logger.info(f"Inpainting blend: feather={feather_radius}, color_correction={apply_color_correction}")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # Apply feathering to mask
+        if feather_radius > 0:
+            scaled_radius = int(feather_radius * self.INPAINT_FEATHER_SCALE)
+            kernel_size = scaled_radius * 2 + 1
+            mask_array = cv2.GaussianBlur(
+                mask_array,
+                (kernel_size, kernel_size),
+                scaled_radius / 2
+            )
+
+        # Apply adaptive color correction if enabled
+        if apply_color_correction:
+            gen_array = self._apply_inpaint_color_correction(
+                orig_array, gen_array, mask_array
+            )
+
+        # sRGB to linear conversion for accurate blending
+        def srgb_to_linear(img):
+            img_norm = img / 255.0
+            return np.where(
+                img_norm <= 0.04045,
+                img_norm / 12.92,
+                np.power((img_norm + 0.055) / 1.055, 2.4)
+            )
+
+        def linear_to_srgb(img):
+            img_clipped = np.clip(img, 0, 1)
+            return np.where(
+                img_clipped <= 0.0031308,
+                12.92 * img_clipped,
+                1.055 * np.power(img_clipped, 1/2.4) - 0.055
+            )
+
+        # Convert to linear space
+        orig_linear = srgb_to_linear(orig_array)
+        gen_linear = srgb_to_linear(gen_array)
+
+        # Alpha blending in linear space
+        alpha = mask_array[:, :, np.newaxis]
+        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
+
+        # Convert back to sRGB
+        result_srgb = linear_to_srgb(result_linear)
+        result_array = (result_srgb * 255).astype(np.uint8)
+
+        logger.debug("Inpainting blend completed in linear color space")
+
+        return Image.fromarray(result_array)
+
+    def _apply_inpaint_color_correction(
+        self,
+        original: np.ndarray,
+        generated: np.ndarray,
+        mask: np.ndarray
+    ) -> np.ndarray:
+        """
+        Apply adaptive color correction to match generated region with surroundings.
+
+        Analyzes the boundary region and adjusts the generated content's
+        luminance and color to better match the original context.
+
+        Parameters
+        ----------
+        original : np.ndarray
+            Original image (float32, 0-255)
+        generated : np.ndarray
+            Generated image (float32, 0-255)
+        mask : np.ndarray
+            Blend mask (float32, 0-1)
+
+        Returns
+        -------
+        np.ndarray
+            Color-corrected generated image
+        """
+        # Find boundary region
+        mask_binary = (mask > 0.5).astype(np.uint8)
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE,
+            (self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1, self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1)
+        )
+        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
+        boundary_zone = (dilated > 0) & (mask < 0.3)
+
+        if not np.any(boundary_zone):
+            return generated
+
+        # Convert to Lab for perceptual color matching
+        orig_lab = cv2.cvtColor(
+            original.astype(np.uint8), cv2.COLOR_RGB2LAB
+        ).astype(np.float32)
+        gen_lab = cv2.cvtColor(
+            generated.astype(np.uint8), cv2.COLOR_RGB2LAB
+        ).astype(np.float32)
+
+        # Calculate statistics in boundary zone (original)
+        boundary_orig_l = orig_lab[boundary_zone, 0]
+        boundary_orig_a = orig_lab[boundary_zone, 1]
+        boundary_orig_b = orig_lab[boundary_zone, 2]
+
+        orig_mean_l = np.median(boundary_orig_l)
+        orig_mean_a = np.median(boundary_orig_a)
+        orig_mean_b = np.median(boundary_orig_b)
+
+        # Calculate statistics in generated inpaint region
+        inpaint_zone = mask > 0.5
+        if not np.any(inpaint_zone):
+            return generated
+
+        gen_inpaint_l = gen_lab[inpaint_zone, 0]
+        gen_inpaint_a = gen_lab[inpaint_zone, 1]
+        gen_inpaint_b = gen_lab[inpaint_zone, 2]
+
+        gen_mean_l = np.median(gen_inpaint_l)
+        gen_mean_a = np.median(gen_inpaint_a)
+        gen_mean_b = np.median(gen_inpaint_b)
+
+        # Calculate correction deltas
+        delta_l = orig_mean_l - gen_mean_l
+        delta_a = orig_mean_a - gen_mean_a
+        delta_b = orig_mean_b - gen_mean_b
+
+        # Limit correction to avoid over-adjustment
+        max_correction = 15
+        delta_l = np.clip(delta_l, -max_correction, max_correction)
+        delta_a = np.clip(delta_a, -max_correction * 0.5, max_correction * 0.5)
+        delta_b = np.clip(delta_b, -max_correction * 0.5, max_correction * 0.5)
+
+        logger.debug(f"Color correction deltas: L={delta_l:.1f}, a={delta_a:.1f}, b={delta_b:.1f}")
+
+        # Apply correction with spatial falloff from boundary
+        # Create distance map from boundary
+        distance = cv2.distanceTransform(
+            mask_binary, cv2.DIST_L2, 5
+        )
+        max_dist = np.max(distance)
+        if max_dist > 0:
+            # Correction strength falls off from boundary toward center
+            correction_strength = 1.0 - np.clip(distance / (max_dist * 0.5), 0, 1)
+        else:
+            correction_strength = np.ones_like(distance)
+
+        # Apply correction to Lab channels
+        corrected_lab = gen_lab.copy()
+        corrected_lab[:, :, 0] += delta_l * correction_strength * 0.7
+        corrected_lab[:, :, 1] += delta_a * correction_strength * 0.5
+        corrected_lab[:, :, 2] += delta_b * correction_strength * 0.5
+
+        # Clip to valid Lab ranges
+        corrected_lab[:, :, 0] = np.clip(corrected_lab[:, :, 0], 0, 255)
+        corrected_lab[:, :, 1] = np.clip(corrected_lab[:, :, 1], 0, 255)
+        corrected_lab[:, :, 2] = np.clip(corrected_lab[:, :, 2], 0, 255)
+
+        # Convert back to RGB
+        corrected_rgb = cv2.cvtColor(
+            corrected_lab.astype(np.uint8), cv2.COLOR_LAB2RGB
+        ).astype(np.float32)
+
+        return corrected_rgb
+
+    def blend_inpainting_with_guided_filter(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image,
+        feather_radius: int = 8,
+        guide_radius: int = 8,
+        guide_eps: float = 0.01
+    ) -> Image.Image:
+        """
+        Blend inpainted region using guided filter for edge-aware transitions.
+
+        Combines standard alpha blending with guided filtering to preserve
+        edges in the original image while seamlessly integrating new content.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        generated : PIL.Image
+            Generated/inpainted result
+        mask : PIL.Image
+            Inpainting mask
+        feather_radius : int
+            Base feathering radius
+        guide_radius : int
+            Guided filter radius
+        guide_eps : float
+            Guided filter regularization
+
+        Returns
+        -------
+        PIL.Image
+            Blended result with edge-aware transitions
+        """
+        logger.info("Applying guided filter inpainting blend")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # Apply base feathering
+        if feather_radius > 0:
+            kernel_size = feather_radius * 2 + 1
+            mask_feathered = cv2.GaussianBlur(
+                mask_array,
+                (kernel_size, kernel_size),
+                feather_radius / 2
+            )
+        else:
+            mask_feathered = mask_array
+
+        # Use original image as guide for the filter
+        guide = cv2.cvtColor(orig_array.astype(np.uint8), cv2.COLOR_RGB2GRAY)
+        guide = guide.astype(np.float32) / 255.0
+
+        # Apply guided filter to the mask
+        try:
+            mask_guided = cv2.ximgproc.guidedFilter(
+                guide=guide,
+                src=mask_feathered,
+                radius=guide_radius,
+                eps=guide_eps
+            )
+            logger.debug("Guided filter applied successfully")
+        except Exception as e:
+            logger.warning(f"Guided filter failed: {e}, using standard feathering")
+            mask_guided = mask_feathered
+
+        # Alpha blending
+        alpha = mask_guided[:, :, np.newaxis]
+        result = gen_array * alpha + orig_array * (1 - alpha)
+        result = np.clip(result, 0, 255).astype(np.uint8)
+
+        return Image.fromarray(result)
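For reference, a usage sketch of the two new blend entry points. The signatures come from the diff above; the file paths are placeholders, and the module path `image_blender` is inferred from the file name:

    # Usage sketch for the new inpainting blend methods.
    from PIL import Image
    from image_blender import ImageBlender

    blender = ImageBlender()
    original = Image.open("original.png")    # placeholder path
    generated = Image.open("inpainted.png")  # placeholder path
    mask = Image.open("mask.png")            # white = inpainted region

    # Linear-space blend with feathering and boundary color correction
    result = blender.blend_inpainting(original, generated, mask, feather_radius=8)

    # Edge-aware variant; per the method body it falls back to Gaussian
    # feathering when cv2.ximgproc (opencv-contrib) is unavailable
    result_guided = blender.blend_inpainting_with_guided_filter(
        original, generated, mask, guide_radius=8, guide_eps=0.01
    )
    result.save("blended.png")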
inpainting_module.py
ADDED
|
@@ -0,0 +1,1311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
import traceback
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image, ImageFilter
|
| 12 |
+
|
| 13 |
+
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
| 14 |
+
from diffusers import StableDiffusionXLControlNetInpaintPipeline
|
| 15 |
+
from diffusers import StableDiffusionXLInpaintPipeline
|
| 16 |
+
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
|
| 17 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
logger.setLevel(logging.INFO)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class InpaintingConfig:
|
| 25 |
+
"""Configuration for inpainting operations."""
|
| 26 |
+
|
| 27 |
+
# ControlNet settings
|
| 28 |
+
controlnet_conditioning_scale: float = 0.7
|
| 29 |
+
conditioning_type: str = "canny" # "canny" or "depth"
|
| 30 |
+
|
| 31 |
+
# Canny edge detection parameters
|
| 32 |
+
canny_low_threshold: int = 100
|
| 33 |
+
canny_high_threshold: int = 200
|
| 34 |
+
|
| 35 |
+
# Mask settings
|
| 36 |
+
feather_radius: int = 8
|
| 37 |
+
min_mask_coverage: float = 0.01
|
| 38 |
+
max_mask_coverage: float = 0.95
|
| 39 |
+
|
| 40 |
+
# Generation settings
|
| 41 |
+
num_inference_steps: int = 25
|
| 42 |
+
guidance_scale: float = 7.5
|
| 43 |
+
preview_steps: int = 15
|
| 44 |
+
preview_guidance_scale: float = 8.0
|
| 45 |
+
|
| 46 |
+
# Quality settings
|
| 47 |
+
enable_auto_optimization: bool = True
|
| 48 |
+
max_optimization_retries: int = 3
|
| 49 |
+
min_quality_score: float = 70.0
|
| 50 |
+
|
| 51 |
+
# Memory settings
|
| 52 |
+
enable_vae_tiling: bool = True
|
| 53 |
+
enable_attention_slicing: bool = True
|
| 54 |
+
max_resolution: int = 1024
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class InpaintingResult:
|
| 59 |
+
"""Result container for inpainting operations."""
|
| 60 |
+
|
| 61 |
+
success: bool
|
| 62 |
+
result_image: Optional[Image.Image] = None
|
| 63 |
+
preview_image: Optional[Image.Image] = None
|
| 64 |
+
control_image: Optional[Image.Image] = None
|
| 65 |
+
blended_image: Optional[Image.Image] = None
|
| 66 |
+
quality_score: float = 0.0
|
| 67 |
+
quality_details: Dict[str, Any] = field(default_factory=dict)
|
| 68 |
+
generation_time: float = 0.0
|
| 69 |
+
retries: int = 0
|
| 70 |
+
error_message: str = ""
|
| 71 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class InpaintingModule:
|
| 75 |
+
"""
|
| 76 |
+
ControlNet-based Inpainting Module for SceneWeaver.
|
| 77 |
+
|
| 78 |
+
Implements StableDiffusionXLControlNetInpaintPipeline with support for
|
| 79 |
+
Canny edge and depth map conditioning. Features two-stage generation
|
| 80 |
+
(preview + full quality) and automatic quality optimization.
|
| 81 |
+
|
| 82 |
+
Attributes:
|
| 83 |
+
device: Computation device (cuda/mps/cpu)
|
| 84 |
+
config: InpaintingConfig instance
|
| 85 |
+
is_initialized: Whether pipeline is loaded
|
| 86 |
+
|
| 87 |
+
Example:
|
| 88 |
+
>>> module = InpaintingModule(device="cuda")
|
| 89 |
+
>>> module.load_inpainting_pipeline(progress_callback=my_callback)
|
| 90 |
+
>>> result = module.execute_inpainting(
|
| 91 |
+
... image=my_image,
|
| 92 |
+
... mask=my_mask,
|
| 93 |
+
... prompt="a beautiful garden"
|
| 94 |
+
... )
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
# Model identifiers
|
| 98 |
+
CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
|
| 99 |
+
CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
|
| 100 |
+
DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
|
| 101 |
+
DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
|
| 102 |
+
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 103 |
+
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
device: str = "auto",
|
| 107 |
+
config: Optional[InpaintingConfig] = None
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Initialize the InpaintingModule.
|
| 111 |
+
|
| 112 |
+
Parameters
|
| 113 |
+
----------
|
| 114 |
+
device : str, optional
|
| 115 |
+
Computation device. "auto" for automatic detection.
|
| 116 |
+
config : InpaintingConfig, optional
|
| 117 |
+
Configuration object. Uses defaults if not provided.
|
| 118 |
+
"""
|
| 119 |
+
self.device = self._setup_device(device)
|
| 120 |
+
self.config = config or InpaintingConfig()
|
| 121 |
+
|
| 122 |
+
# Pipeline instances (lazy loaded)
|
| 123 |
+
self._inpaint_pipeline = None
|
| 124 |
+
self._controlnet_canny = None
|
| 125 |
+
self._controlnet_depth = None
|
| 126 |
+
self._depth_estimator = None
|
| 127 |
+
self._depth_processor = None
|
| 128 |
+
|
| 129 |
+
# State tracking
|
| 130 |
+
self.is_initialized = False
|
| 131 |
+
self._current_conditioning_type = None
|
| 132 |
+
self._last_seed = None
|
| 133 |
+
self._cached_latents = None
|
| 134 |
+
self._use_controlnet = True # Track if ControlNet is available
|
| 135 |
+
|
| 136 |
+
# Reference to model manager (set by SceneWeaverCore)
|
| 137 |
+
self._model_manager = None
|
| 138 |
+
|
| 139 |
+
logger.info(f"InpaintingModule initialized on {self.device}")
|
| 140 |
+
|
| 141 |
+
def _setup_device(self, device: str) -> str:
|
| 142 |
+
"""
|
| 143 |
+
Setup computation device.
|
| 144 |
+
|
| 145 |
+
Parameters
|
| 146 |
+
----------
|
| 147 |
+
device : str
|
| 148 |
+
Device specification or "auto"
|
| 149 |
+
|
| 150 |
+
Returns
|
| 151 |
+
-------
|
| 152 |
+
str
|
| 153 |
+
Resolved device name
|
| 154 |
+
"""
|
| 155 |
+
if device == "auto":
|
| 156 |
+
if torch.cuda.is_available():
|
| 157 |
+
return "cuda"
|
| 158 |
+
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 159 |
+
return "mps"
|
| 160 |
+
return "cpu"
|
| 161 |
+
return device
|
| 162 |
+
|
| 163 |
+
def set_model_manager(self, manager: Any) -> None:
|
| 164 |
+
"""
|
| 165 |
+
Set reference to ModelManager for coordinated model lifecycle.
|
| 166 |
+
|
| 167 |
+
Parameters
|
| 168 |
+
----------
|
| 169 |
+
manager : ModelManager
|
| 170 |
+
The global model manager instance
|
| 171 |
+
"""
|
| 172 |
+
self._model_manager = manager
|
| 173 |
+
logger.info("ModelManager reference set for InpaintingModule")
|
| 174 |
+
|
| 175 |
+
def _memory_cleanup(self, aggressive: bool = False) -> None:
|
| 176 |
+
"""
|
| 177 |
+
Perform memory cleanup.
|
| 178 |
+
|
| 179 |
+
Parameters
|
| 180 |
+
----------
|
| 181 |
+
aggressive : bool
|
| 182 |
+
If True, perform multiple GC rounds and sync CUDA
|
| 183 |
+
"""
|
| 184 |
+
rounds = 5 if aggressive else 2
|
| 185 |
+
for _ in range(rounds):
|
| 186 |
+
gc.collect()
|
| 187 |
+
|
| 188 |
+
if torch.cuda.is_available():
|
| 189 |
+
torch.cuda.empty_cache()
|
| 190 |
+
if aggressive:
|
| 191 |
+
torch.cuda.ipc_collect()
|
| 192 |
+
torch.cuda.synchronize()
|
| 193 |
+
|
| 194 |
+
logger.debug(f"Memory cleanup completed (aggressive={aggressive})")
|
| 195 |
+
|
| 196 |
+
def _check_memory_status(self) -> Dict[str, float]:
|
| 197 |
+
"""
|
| 198 |
+
Check current GPU memory status.
|
| 199 |
+
|
| 200 |
+
Returns
|
| 201 |
+
-------
|
| 202 |
+
dict
|
| 203 |
+
Memory statistics including allocated, total, and usage ratio
|
| 204 |
+
"""
|
| 205 |
+
if not torch.cuda.is_available():
|
| 206 |
+
return {"available": True, "usage_ratio": 0.0}
|
| 207 |
+
|
| 208 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
| 209 |
+
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 210 |
+
usage_ratio = allocated / total
|
| 211 |
+
|
| 212 |
+
return {
|
| 213 |
+
"allocated_gb": round(allocated, 2),
|
| 214 |
+
"total_gb": round(total, 2),
|
| 215 |
+
"free_gb": round(total - allocated, 2),
|
| 216 |
+
"usage_ratio": round(usage_ratio, 3),
|
| 217 |
+
"available": usage_ratio < 0.9
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
def load_inpainting_pipeline(
|
| 221 |
+
self,
|
| 222 |
+
conditioning_type: str = "canny",
|
| 223 |
+
progress_callback: Optional[Callable[[str, int], None]] = None
|
| 224 |
+
) -> Tuple[bool, str]:
|
| 225 |
+
"""
|
| 226 |
+
Load the ControlNet inpainting pipeline.
|
| 227 |
+
|
| 228 |
+
Implements mutual exclusion with background generation pipeline.
|
| 229 |
+
Only one pipeline can be loaded at a time.
|
| 230 |
+
|
| 231 |
+
Parameters
|
| 232 |
+
----------
|
| 233 |
+
conditioning_type : str
|
| 234 |
+
Type of ControlNet conditioning: "canny" or "depth"
|
| 235 |
+
progress_callback : callable, optional
|
| 236 |
+
Function(message, percentage) for progress updates
|
| 237 |
+
|
| 238 |
+
Returns
|
| 239 |
+
-------
|
| 240 |
+
tuple
|
| 241 |
+
(success: bool, error_message: str)
|
| 242 |
+
"""
|
| 243 |
+
if self.is_initialized and self._current_conditioning_type == conditioning_type:
|
| 244 |
+
logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
|
| 245 |
+
return True, ""
|
| 246 |
+
|
| 247 |
+
logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
self._memory_cleanup(aggressive=True)
|
| 251 |
+
|
| 252 |
+
if progress_callback:
|
| 253 |
+
progress_callback("Preparing to load inpainting models...", 5)
|
| 254 |
+
|
| 255 |
+
# Unload existing pipeline if different conditioning type
|
| 256 |
+
if self._inpaint_pipeline is not None:
|
| 257 |
+
self._unload_pipeline()
|
| 258 |
+
|
| 259 |
+
# Use ControlNet inpainting by default
|
| 260 |
+
use_controlnet_inpaint = True
|
| 261 |
+
logger.info("Using StableDiffusionXLControlNetInpaintPipeline")
|
| 262 |
+
|
| 263 |
+
if progress_callback:
|
| 264 |
+
progress_callback("Loading ControlNet model...", 20)
|
| 265 |
+
|
| 266 |
+
# Load appropriate ControlNet
|
| 267 |
+
dtype = torch.float16 if self.device == "cuda" else torch.float32
|
| 268 |
+
controlnet = None
|
| 269 |
+
|
| 270 |
+
if use_controlnet_inpaint:
|
| 271 |
+
if conditioning_type == "canny":
|
| 272 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 273 |
+
self.CONTROLNET_CANNY_MODEL,
|
| 274 |
+
torch_dtype=dtype,
|
| 275 |
+
use_safetensors=True
|
| 276 |
+
)
|
| 277 |
+
self._controlnet_canny = controlnet
|
| 278 |
+
logger.info("Loaded ControlNet Canny model")
|
| 279 |
+
|
| 280 |
+
elif conditioning_type == "depth":
|
| 281 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 282 |
+
self.CONTROLNET_DEPTH_MODEL,
|
| 283 |
+
torch_dtype=dtype,
|
| 284 |
+
use_safetensors=True
|
| 285 |
+
)
|
| 286 |
+
self._controlnet_depth = controlnet
|
| 287 |
+
|
| 288 |
+
# Load depth estimator
|
| 289 |
+
if progress_callback:
|
| 290 |
+
progress_callback("Loading depth estimation model...", 35)
|
| 291 |
+
self._load_depth_estimator()
|
| 292 |
+
logger.info("Loaded ControlNet Depth model")
|
| 293 |
+
else:
|
| 294 |
+
raise ValueError(f"Unknown conditioning type: {conditioning_type}")
|
| 295 |
+
else:
|
| 296 |
+
# Skip ControlNet loading for fallback mode
|
| 297 |
+
logger.info(f"Skipping ControlNet loading (fallback mode)")
|
| 298 |
+
|
| 299 |
+
if progress_callback:
|
| 300 |
+
progress_callback("Loading SDXL Inpainting pipeline...", 50)
|
| 301 |
+
|
| 302 |
+
# Load the inpainting pipeline
|
| 303 |
+
if use_controlnet_inpaint and controlnet is not None:
|
| 304 |
+
self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
|
| 305 |
+
self.BASE_MODEL,
|
| 306 |
+
controlnet=controlnet,
|
| 307 |
+
torch_dtype=dtype,
|
| 308 |
+
use_safetensors=True,
|
| 309 |
+
variant="fp16" if dtype == torch.float16 else None
|
| 310 |
+
)
|
| 311 |
+
else:
|
| 312 |
+
# Fallback: Use dedicated inpainting model without ControlNet
|
| 313 |
+
self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
|
| 314 |
+
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
| 315 |
+
torch_dtype=dtype,
|
| 316 |
+
use_safetensors=True,
|
| 317 |
+
variant="fp16" if dtype == torch.float16 else None
|
| 318 |
+
)
|
| 319 |
+
self._use_controlnet = False
|
| 320 |
+
|
| 321 |
+
# Track ControlNet usage
|
| 322 |
+
self._use_controlnet = use_controlnet_inpaint and controlnet is not None
|
| 323 |
+
|
| 324 |
+
if progress_callback:
|
| 325 |
+
progress_callback("Configuring scheduler...", 70)
|
| 326 |
+
|
| 327 |
+
# Configure scheduler for faster generation
|
| 328 |
+
self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 329 |
+
self._inpaint_pipeline.scheduler.config
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Move to device
|
| 333 |
+
self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)
|
| 334 |
+
|
| 335 |
+
if progress_callback:
|
| 336 |
+
progress_callback("Applying optimizations...", 85)
|
| 337 |
+
|
| 338 |
+
# Apply memory optimizations
|
| 339 |
+
self._apply_pipeline_optimizations()
|
| 340 |
+
|
| 341 |
+
# Set eval mode
|
| 342 |
+
self._inpaint_pipeline.unet.eval()
|
| 343 |
+
if hasattr(self._inpaint_pipeline, 'vae'):
|
| 344 |
+
self._inpaint_pipeline.vae.eval()
|
| 345 |
+
|
| 346 |
+
self.is_initialized = True
|
| 347 |
+
self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"
|
| 348 |
+
|
| 349 |
+
if progress_callback:
|
| 350 |
+
progress_callback("Inpainting pipeline ready!", 100)
|
| 351 |
+
|
| 352 |
+
# Log memory status
|
| 353 |
+
mem_status = self._check_memory_status()
|
| 354 |
+
logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
|
| 355 |
+
|
| 356 |
+
return True, ""
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
error_msg = str(e)
|
| 360 |
+
logger.error(f"Failed to load inpainting pipeline: {error_msg}")
|
| 361 |
+
traceback.print_exc()
|
| 362 |
+
self._unload_pipeline()
|
| 363 |
+
return False, error_msg
|
| 364 |
+
|
| 365 |
+
def _load_depth_estimator(self) -> None:
|
| 366 |
+
"""
|
| 367 |
+
Load depth estimation model with fallback strategy.
|
| 368 |
+
|
| 369 |
+
Tries Depth-Anything first, falls back to MiDaS if unavailable.
|
| 370 |
+
"""
|
| 371 |
+
try:
|
| 372 |
+
logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
|
| 373 |
+
|
| 374 |
+
self._depth_processor = AutoImageProcessor.from_pretrained(
|
| 375 |
+
self.DEPTH_MODEL_PRIMARY
|
| 376 |
+
)
|
| 377 |
+
self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
|
| 378 |
+
self.DEPTH_MODEL_PRIMARY,
|
| 379 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
|
| 380 |
+
)
|
| 381 |
+
self._depth_estimator.to(self.device)
|
| 382 |
+
self._depth_estimator.eval()
|
| 383 |
+
|
| 384 |
+
logger.info("Successfully loaded Depth-Anything model")
|
| 385 |
+
|
| 386 |
+
except Exception as e:
|
| 387 |
+
logger.warning(f"Primary depth model failed: {e}, trying fallback...")
|
| 388 |
+
|
| 389 |
+
try:
|
| 390 |
+
self._depth_processor = DPTImageProcessor.from_pretrained(
|
| 391 |
+
self.DEPTH_MODEL_FALLBACK
|
| 392 |
+
)
|
| 393 |
+
self._depth_estimator = DPTForDepthEstimation.from_pretrained(
|
| 394 |
+
self.DEPTH_MODEL_FALLBACK,
|
| 395 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
|
| 396 |
+
)
|
| 397 |
+
self._depth_estimator.to(self.device)
|
| 398 |
+
self._depth_estimator.eval()
|
| 399 |
+
|
| 400 |
+
logger.info("Successfully loaded MiDaS fallback model")
|
| 401 |
+
|
| 402 |
+
except Exception as fallback_e:
|
| 403 |
+
logger.error(f"Fallback depth model also failed: {fallback_e}")
|
| 404 |
+
raise RuntimeError("Unable to load any depth estimation model")
|
| 405 |
+
|
| 406 |
+
def _apply_pipeline_optimizations(self) -> None:
|
| 407 |
+
"""Apply memory and performance optimizations to the pipeline."""
|
| 408 |
+
if self._inpaint_pipeline is None:
|
| 409 |
+
return
|
| 410 |
+
|
| 411 |
+
# Try xformers first
|
| 412 |
+
try:
|
| 413 |
+
self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
|
| 414 |
+
logger.info("Enabled xformers memory efficient attention")
|
| 415 |
+
except Exception:
|
| 416 |
+
try:
|
| 417 |
+
self._inpaint_pipeline.enable_attention_slicing()
|
| 418 |
+
logger.info("Enabled attention slicing")
|
| 419 |
+
except Exception:
|
| 420 |
+
logger.warning("No attention optimization available")
|
| 421 |
+
|
| 422 |
+
# VAE optimizations
|
| 423 |
+
if self.config.enable_vae_tiling:
|
| 424 |
+
if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
|
| 425 |
+
self._inpaint_pipeline.enable_vae_tiling()
|
| 426 |
+
logger.debug("Enabled VAE tiling")
|
| 427 |
+
|
| 428 |
+
if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
|
| 429 |
+
self._inpaint_pipeline.enable_vae_slicing()
|
| 430 |
+
logger.debug("Enabled VAE slicing")
|
| 431 |
+
|
| 432 |
+
def _unload_pipeline(self) -> None:
|
| 433 |
+
"""Unload the inpainting pipeline and free memory."""
|
| 434 |
+
logger.info("Unloading inpainting pipeline...")
|
| 435 |
+
|
| 436 |
+
if self._inpaint_pipeline is not None:
|
| 437 |
+
del self._inpaint_pipeline
|
| 438 |
+
self._inpaint_pipeline = None
|
| 439 |
+
|
| 440 |
+
if self._controlnet_canny is not None:
|
| 441 |
+
del self._controlnet_canny
|
| 442 |
+
self._controlnet_canny = None
|
| 443 |
+
|
| 444 |
+
if self._controlnet_depth is not None:
|
| 445 |
+
del self._controlnet_depth
|
| 446 |
+
self._controlnet_depth = None
|
| 447 |
+
|
| 448 |
+
if self._depth_estimator is not None:
|
| 449 |
+
del self._depth_estimator
|
| 450 |
+
self._depth_estimator = None
|
| 451 |
+
|
| 452 |
+
if self._depth_processor is not None:
|
| 453 |
+
del self._depth_processor
|
| 454 |
+
self._depth_processor = None
|
| 455 |
+
|
| 456 |
+
self.is_initialized = False
|
| 457 |
+
self._current_conditioning_type = None
|
| 458 |
+
self._cached_latents = None
|
| 459 |
+
|
| 460 |
+
self._memory_cleanup(aggressive=True)
|
| 461 |
+
logger.info("Inpainting pipeline unloaded")
|
| 462 |
+
|
+    def prepare_control_image(
+        self,
+        image: Image.Image,
+        mode: str = "canny"
+    ) -> Image.Image:
+        """
+        Generate ControlNet conditioning image.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input image
+        mode : str
+            Conditioning mode: "canny" or "depth"
+
+        Returns
+        -------
+        PIL.Image
+            Generated control image (edges or depth map)
+        """
+        logger.info(f"Preparing control image with mode: {mode}")
+
+        # Convert to RGB if needed
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        img_array = np.array(image)
+
+        if mode == "canny":
+            return self._generate_canny_edges(img_array)
+        elif mode == "depth":
+            return self._generate_depth_map(image)
+        else:
+            raise ValueError(f"Unknown control mode: {mode}")
+
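A minimal usage sketch for the control-image step above; the `inpainter` instance and the file names are hypothetical stand-ins for an initialized module and real files:

from PIL import Image

source = Image.open("room.jpg")  # hypothetical input photo
edges = inpainter.prepare_control_image(source, mode="canny")  # 3-channel edge map
depth = inpainter.prepare_control_image(source, mode="depth")  # 3-channel depth map
edges.save("control_canny.png")
depth.save("control_depth.png")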
+    def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
+        """
+        Generate Canny edge detection image.
+
+        Parameters
+        ----------
+        img_array : np.ndarray
+            Input image as RGB numpy array
+
+        Returns
+        -------
+        PIL.Image
+            Edge map replicated to 3 channels for ControlNet
+        """
+        # Convert to grayscale
+        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
+
+        # Apply Gaussian blur to reduce noise
+        blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
+
+        # Canny edge detection
+        edges = cv2.Canny(
+            blurred,
+            self.config.canny_low_threshold,
+            self.config.canny_high_threshold
+        )
+
+        # Convert to 3-channel for ControlNet
+        edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
+
+        logger.debug(f"Generated Canny edges with thresholds "
+                     f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
+
+        return Image.fromarray(edges_3ch)
+
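For reference, Canny is commonly run with a high:low threshold ratio around 2:1 to 3:1. The standalone sketch below mirrors the preprocessing above with illustrative thresholds; the module's actual config defaults are not visible in this hunk:

import cv2

gray = cv2.cvtColor(cv2.imread("room.jpg"), cv2.COLOR_BGR2GRAY)
blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)      # same 5x5 blur as above
edges = cv2.Canny(blurred, 100, 200)               # illustrative low/high thresholds
control = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)  # ControlNet expects 3 channels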
+    def _generate_depth_map(self, image: Image.Image) -> Image.Image:
+        """
+        Generate depth map using depth estimation model.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Input RGB image
+
+        Returns
+        -------
+        PIL.Image
+            Normalized depth map replicated to 3 channels for ControlNet
+        """
+        if self._depth_estimator is None or self._depth_processor is None:
+            raise RuntimeError("Depth estimator not loaded")
+
+        # Preprocess
+        inputs = self._depth_processor(images=image, return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # Inference
+        with torch.no_grad():
+            outputs = self._depth_estimator(**inputs)
+            predicted_depth = outputs.predicted_depth
+
+        # Interpolate to original size
+        prediction = torch.nn.functional.interpolate(
+            predicted_depth.unsqueeze(1),
+            size=image.size[::-1],  # (H, W)
+            mode="bicubic",
+            align_corners=False
+        )
+
+        # Normalize to 0-255
+        depth_array = prediction.squeeze().cpu().numpy()
+        depth_min = depth_array.min()
+        depth_max = depth_array.max()
+
+        if depth_max - depth_min > 0:
+            depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
+        else:
+            depth_normalized = np.zeros_like(depth_array)
+
+        depth_normalized = depth_normalized.astype(np.uint8)
+
+        # Convert to 3-channel for ControlNet
+        depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
+
+        logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
+
+        return Image.fromarray(depth_3ch)
+
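A standalone sketch of the same interpolate-and-normalize flow using Hugging Face transformers. The model id is an assumption for illustration (the loader is outside this hunk), and the file name is a placeholder:

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

processor = AutoImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")  # assumed model id
model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").eval()

image = Image.open("room.jpg").convert("RGB")  # hypothetical input
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    depth = model(**inputs).predicted_depth            # (1, h, w) at model resolution
depth = torch.nn.functional.interpolate(
    depth.unsqueeze(1), size=image.size[::-1],         # PIL size is (W, H); interpolate wants (H, W)
    mode="bicubic", align_corners=False
).squeeze().numpy()
span = depth.max() - depth.min()
depth8 = ((depth - depth.min()) / span * 255).astype(np.uint8) if span > 0 \
    else np.zeros_like(depth, dtype=np.uint8)          # same zero-range guard as above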
+    def prepare_mask(
+        self,
+        mask: Image.Image,
+        target_size: Tuple[int, int],
+        feather_radius: Optional[int] = None
+    ) -> Tuple[Image.Image, Dict[str, Any]]:
+        """
+        Prepare and validate mask for inpainting.
+
+        Parameters
+        ----------
+        mask : PIL.Image
+            Input mask (white = inpaint area)
+        target_size : tuple
+            Target (width, height) to match input image
+        feather_radius : int, optional
+            Feathering radius in pixels. Uses config default if None.
+
+        Returns
+        -------
+        tuple
+            (processed_mask, validation_info)
+
+        Notes
+        -----
+        Coverage outside the acceptable range is reported via
+        ``validation_info["valid"]`` and ``validation_info["warning"]``
+        rather than raised as an exception.
+        """
+        feather = feather_radius if feather_radius is not None else self.config.feather_radius
+
+        # Convert to grayscale
+        if mask.mode != 'L':
+            mask = mask.convert('L')
+
+        # Resize to match target
+        if mask.size != target_size:
+            mask = mask.resize(target_size, Image.LANCZOS)
+
+        # Convert to array for processing
+        mask_array = np.array(mask)
+
+        # Calculate coverage
+        total_pixels = mask_array.size
+        white_pixels = np.count_nonzero(mask_array > 127)
+        coverage = white_pixels / total_pixels
+
+        validation_info = {
+            "coverage": coverage,
+            "white_pixels": white_pixels,
+            "total_pixels": total_pixels,
+            "feather_radius": feather,
+            "valid": True,
+            "warning": ""
+        }
+
+        # Validate coverage
+        if coverage < self.config.min_mask_coverage:
+            validation_info["valid"] = False
+            validation_info["warning"] = (
+                f"Mask coverage too low ({coverage:.1%}). "
+                f"Please select a larger area to inpaint."
+            )
+            logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
+
+        elif coverage > self.config.max_mask_coverage:
+            validation_info["valid"] = False
+            validation_info["warning"] = (
+                f"Mask coverage too high ({coverage:.1%}). "
+                f"Consider using background generation instead."
+            )
+            logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")
+
+        # Apply feathering
+        if feather > 0:
+            mask_array = cv2.GaussianBlur(
+                mask_array,
+                (feather * 2 + 1, feather * 2 + 1),
+                feather / 2
+            )
+            logger.debug(f"Applied {feather}px feathering to mask")
+
+        processed_mask = Image.fromarray(mask_array, mode='L')
+
+        return processed_mask, validation_info
+
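A minimal check of the coverage and feathering arithmetic above, run on a synthetic mask (the kernel size formula feather*2+1 keeps the Gaussian kernel odd, as OpenCV requires):

import cv2
import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255                           # white square = inpaint region

coverage = np.count_nonzero(mask > 127) / mask.size    # 256*256 / 512*512 = 0.25

feather = 8
kernel = feather * 2 + 1                               # 17: odd kernel size
soft = cv2.GaussianBlur(mask, (kernel, kernel), feather / 2)
print(f"coverage={coverage:.1%}, mask value at the square's edge after feathering: {soft[128, 256]}")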
+    def enhance_prompt_for_inpainting(
+        self,
+        prompt: str,
+        image: Image.Image,
+        mask: Image.Image
+    ) -> Tuple[str, str]:
+        """
+        Enhance prompt based on non-masked region analysis.
+
+        Analyzes the surrounding context to generate appropriate
+        lighting and color descriptors.
+
+        Parameters
+        ----------
+        prompt : str
+            User-provided prompt
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Inpainting mask
+
+        Returns
+        -------
+        tuple
+            (enhanced_prompt, negative_prompt)
+        """
+        logger.info("Enhancing prompt for inpainting context...")
+
+        # Convert to arrays
+        img_array = np.array(image.convert('RGB'))
+        mask_array = np.array(mask.convert('L'))
+
+        # Analyze non-masked regions
+        non_masked = mask_array < 127
+
+        if not np.any(non_masked):
+            # No context available
+            enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
+            negative_prompt = self._get_inpainting_negative_prompt()
+            return enhanced_prompt, negative_prompt
+
+        # Extract context pixels
+        context_pixels = img_array[non_masked]
+
+        # Convert to Lab for analysis
+        context_lab = cv2.cvtColor(
+            context_pixels.reshape(-1, 1, 3),
+            cv2.COLOR_RGB2LAB
+        ).reshape(-1, 3)
+
+        # Use robust statistics (median) to avoid outlier influence
+        median_l = np.median(context_lab[:, 0])
+        median_a = np.median(context_lab[:, 1])
+        median_b = np.median(context_lab[:, 2])
+
+        # Analyze lighting conditions
+        lighting_descriptors = []
+
+        if median_l > 170:
+            lighting_descriptors.append("bright")
+        elif median_l > 130:
+            lighting_descriptors.append("well-lit")
+        elif median_l > 80:
+            lighting_descriptors.append("moderate lighting")
+        else:
+            lighting_descriptors.append("dim lighting")
+
+        # Analyze color temperature (b channel: blue(-) to yellow(+))
+        if median_b > 140:
+            lighting_descriptors.append("warm golden tones")
+        elif median_b > 120:
+            lighting_descriptors.append("warm afternoon light")
+        elif median_b < 110:
+            lighting_descriptors.append("cool neutral tones")
+
+        # Calculate saturation from context
+        hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
+        median_saturation = np.median(hsv[:, :, 1])
+
+        if median_saturation > 150:
+            lighting_descriptors.append("vibrant colors")
+        elif median_saturation < 80:
+            lighting_descriptors.append("subtle muted colors")
+
+        # Build enhanced prompt
+        lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
+        quality_suffix = "high quality, detailed, photorealistic, seamless integration"
+
+        if lighting_desc:
+            enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
+        else:
+            enhanced_prompt = f"{prompt}, {quality_suffix}"
+
+        negative_prompt = self._get_inpainting_negative_prompt()
+
+        logger.info(f"Enhanced prompt with context: {lighting_desc}")
+
+        return enhanced_prompt, negative_prompt
+
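A sketch of why the thresholds above sit near the middle of the 8-bit range: in OpenCV's 8-bit Lab, L spans 0-255 and the a/b channels are offset so that 128 is neutral, so "warm" begins above roughly 128 in the b channel. The sample pixel here is illustrative:

import cv2
import numpy as np

pixels = np.array([[200, 170, 120]], dtype=np.uint8)   # a warm, bright RGB sample
lab = cv2.cvtColor(pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2LAB).reshape(-1, 3)
l_med, _, b_med = np.median(lab, axis=0)

descriptors = []
# Abbreviated version of the 4-way lightness ladder above
descriptors.append("bright" if l_med > 170 else "well-lit" if l_med > 130 else "dim lighting")
if b_med > 140:
    descriptors.append("warm golden tones")
elif b_med < 110:
    descriptors.append("cool neutral tones")
print(descriptors)  # expected: ['bright', 'warm golden tones'] for this sample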
+    def _get_inpainting_negative_prompt(self) -> str:
+        """Get standard negative prompt for inpainting."""
+        return (
+            "inconsistent lighting, wrong perspective, mismatched colors, "
+            "visible seams, blending artifacts, color bleeding, "
+            "blurry, low quality, distorted, deformed, "
+            "harsh edges, unnatural transition"
+        )
+
+    def execute_inpainting(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        prompt: str,
+        preview_only: bool = False,
+        seed: Optional[int] = None,
+        progress_callback: Optional[Callable[[str, int], None]] = None,
+        **kwargs
+    ) -> InpaintingResult:
+        """
+        Execute the inpainting operation.
+
+        Implements two-stage generation: fast preview followed by
+        full quality generation if requested.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image to inpaint
+        mask : PIL.Image
+            Inpainting mask (white = area to regenerate)
+        prompt : str
+            Text description of desired content
+        preview_only : bool
+            If True, only generate preview (faster)
+        seed : int, optional
+            Random seed for reproducibility
+        progress_callback : callable, optional
+            Progress update function(message, percentage)
+        **kwargs
+            Additional parameters:
+            - controlnet_conditioning_scale: float
+            - feather_radius: int
+            - num_inference_steps: int
+            - guidance_scale: float
+
+        Returns
+        -------
+        InpaintingResult
+            Result container with generated images and metadata
+        """
+        start_time = time.time()
+
+        if not self.is_initialized:
+            return InpaintingResult(
+                success=False,
+                error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
+            )
+
+        logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")
+
+        try:
+            # Update config with kwargs
+            conditioning_scale = kwargs.get(
+                'controlnet_conditioning_scale',
+                self.config.controlnet_conditioning_scale
+            )
+            feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
+
+            if progress_callback:
+                progress_callback("Preparing images...", 5)
+
+            # Prepare image
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
+
+            # Ensure dimensions are multiple of 8
+            width, height = image.size
+            new_width = (width // 8) * 8
+            new_height = (height // 8) * 8
+
+            if new_width != width or new_height != height:
+                image = image.resize((new_width, new_height), Image.LANCZOS)
+
+            # Check and potentially reduce resolution for memory
+            max_res = self.config.max_resolution
+            if max(new_width, new_height) > max_res:
+                scale = max_res / max(new_width, new_height)
+                new_width = int(new_width * scale) // 8 * 8
+                new_height = int(new_height * scale) // 8 * 8
+                image = image.resize((new_width, new_height), Image.LANCZOS)
+                logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")
+
+            # Prepare mask
+            if progress_callback:
+                progress_callback("Processing mask...", 10)
+
+            processed_mask, mask_info = self.prepare_mask(
+                mask,
+                (new_width, new_height),
+                feather_radius
+            )
+
+            if not mask_info["valid"]:
+                return InpaintingResult(
+                    success=False,
+                    error_message=mask_info["warning"]
+                )
+
+            # Generate control image
+            if progress_callback:
+                progress_callback("Generating control image...", 20)
+
+            control_image = self.prepare_control_image(
+                image,
+                self._current_conditioning_type
+            )
+
+            # Enhance prompt
+            if progress_callback:
+                progress_callback("Enhancing prompt...", 25)
+
+            enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
+                prompt, image, processed_mask
+            )
+
+            # Setup generator for reproducibility
+            if seed is None:
+                seed = int(time.time() * 1000) % (2**32)
+            self._last_seed = seed
+            generator = torch.Generator(device=self.device).manual_seed(seed)
+
+            # Stage 1: Preview generation
+            if progress_callback:
+                progress_callback("Generating preview...", 30)
+
+            preview_result = self._generate_inpaint(
+                image=image,
+                mask=processed_mask,
+                control_image=control_image,
+                prompt=enhanced_prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=self.config.preview_steps,
+                guidance_scale=self.config.preview_guidance_scale,
+                controlnet_conditioning_scale=conditioning_scale,
+                generator=generator
+            )
+
+            if preview_only:
+                generation_time = time.time() - start_time
+
+                return InpaintingResult(
+                    success=True,
+                    preview_image=preview_result,
+                    control_image=control_image,
+                    generation_time=generation_time,
+                    metadata={
+                        "seed": seed,
+                        "prompt": enhanced_prompt,
+                        "conditioning_type": self._current_conditioning_type,
+                        "conditioning_scale": conditioning_scale,
+                        "preview_only": True
+                    }
+                )
+
+            # Stage 2: Full quality generation
+            if progress_callback:
+                progress_callback("Generating full quality...", 60)
+
+            # Use same seed for reproducibility
+            generator = torch.Generator(device=self.device).manual_seed(seed)
+
+            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
+            guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
+
+            full_result = self._generate_inpaint(
+                image=image,
+                mask=processed_mask,
+                control_image=control_image,
+                prompt=enhanced_prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_steps,
+                guidance_scale=guidance,
+                controlnet_conditioning_scale=conditioning_scale,
+                generator=generator
+            )
+
+            if progress_callback:
+                progress_callback("Blending result...", 90)
+
+            # Blend result
+            blended = self.blend_result(image, full_result, processed_mask)
+
+            generation_time = time.time() - start_time
+
+            if progress_callback:
+                progress_callback("Complete!", 100)
+
+            return InpaintingResult(
+                success=True,
+                result_image=full_result,
+                preview_image=preview_result,
+                control_image=control_image,
+                blended_image=blended,
+                generation_time=generation_time,
+                metadata={
+                    "seed": seed,
+                    "prompt": enhanced_prompt,
+                    "negative_prompt": negative_prompt,
+                    "conditioning_type": self._current_conditioning_type,
+                    "conditioning_scale": conditioning_scale,
+                    "num_inference_steps": num_steps,
+                    "guidance_scale": guidance,
+                    "feather_radius": feather_radius,
+                    "mask_coverage": mask_info["coverage"],
+                    "preview_only": False
+                }
+            )
+
+        except torch.cuda.OutOfMemoryError:
+            logger.error("CUDA out of memory during inpainting")
+            self._memory_cleanup(aggressive=True)
+            return InpaintingResult(
+                success=False,
+                error_message="GPU memory exhausted. Try reducing image size or closing other applications."
+            )
+
+        except Exception as e:
+            logger.error(f"Inpainting failed: {e}")
+            logger.error(traceback.format_exc())
+            return InpaintingResult(
+                success=False,
+                error_message=f"Inpainting failed: {str(e)}"
+            )
+
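A hypothetical call sequence for the two-stage flow above; `inpainter` stands in for an initialized module instance and the file names are placeholders. Passing a fixed seed keeps the preview and full pass aligned:

from PIL import Image

def on_progress(message: str, pct: int) -> None:
    print(f"[{pct:3d}%] {message}")

image = Image.open("photo.jpg")
mask = Image.open("mask.png")           # white = region to regenerate

result = inpainter.execute_inpainting(
    image, mask,
    prompt="vintage wooden desk lamp with warm brass details",
    preview_only=False,
    seed=1234,                          # reproducible preview + full generation
    progress_callback=on_progress,
)
if result.success:
    result.blended_image.save("out.png")
else:
    print(result.error_message)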
+    def _generate_inpaint(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        control_image: Image.Image,
+        prompt: str,
+        negative_prompt: str,
+        num_inference_steps: int,
+        guidance_scale: float,
+        controlnet_conditioning_scale: float,
+        generator: torch.Generator
+    ) -> Image.Image:
+        """
+        Internal method to run the inpainting pipeline.
+
+        Supports both ControlNet and non-ControlNet pipelines.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Processed mask
+        control_image : PIL.Image
+            ControlNet conditioning image (ignored if ControlNet not available)
+        prompt : str
+            Enhanced prompt
+        negative_prompt : str
+            Negative prompt
+        num_inference_steps : int
+            Number of denoising steps
+        guidance_scale : float
+            Classifier-free guidance scale
+        controlnet_conditioning_scale : float
+            ControlNet influence strength (ignored if ControlNet not available)
+        generator : torch.Generator
+            Random generator for reproducibility
+
+        Returns
+        -------
+        PIL.Image
+            Generated image
+        """
+        with torch.inference_mode():
+            if self._use_controlnet:
+                # Full ControlNet inpainting pipeline
+                result = self._inpaint_pipeline(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    image=image,
+                    mask_image=mask,
+                    control_image=control_image,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    controlnet_conditioning_scale=controlnet_conditioning_scale,
+                    generator=generator
+                )
+            else:
+                # Fallback: Standard SDXL inpainting without ControlNet
+                result = self._inpaint_pipeline(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    image=image,
+                    mask_image=mask,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    generator=generator
+                )
+
+        return result.images[0]
+
+    def blend_result(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image
+    ) -> Image.Image:
+        """
+        Blend generated content with original image.
+
+        Uses linear color space blending for accurate results.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        generated : PIL.Image
+            Generated inpainted image
+        mask : PIL.Image
+            Blending mask (white = use generated)
+
+        Returns
+        -------
+        PIL.Image
+            Blended result
+        """
+        logger.info("Blending inpainting result...")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # sRGB to linear conversion
+        def srgb_to_linear(img):
+            img_norm = img / 255.0
+            return np.where(
+                img_norm <= 0.04045,
+                img_norm / 12.92,
+                np.power((img_norm + 0.055) / 1.055, 2.4)
+            )
+
+        def linear_to_srgb(img):
+            img_clipped = np.clip(img, 0, 1)
+            return np.where(
+                img_clipped <= 0.0031308,
+                12.92 * img_clipped,
+                1.055 * np.power(img_clipped, 1/2.4) - 0.055
+            )
+
+        # Convert to linear space
+        orig_linear = srgb_to_linear(orig_array)
+        gen_linear = srgb_to_linear(gen_array)
+
+        # Alpha blending in linear space
+        alpha = mask_array[:, :, np.newaxis]
+        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
+
+        # Convert back to sRGB
+        result_srgb = linear_to_srgb(result_linear)
+        result_array = (result_srgb * 255).astype(np.uint8)
+
+        logger.debug("Blending completed in linear color space")
+
+        return Image.fromarray(result_array)
+
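A worked example of why the blend above converts to linear light first: averaging sRGB-encoded values directly darkens midtones, whereas a 50/50 mix of black and white should be mid-gray in light. Using the same transfer functions as above:

import numpy as np

def srgb_to_linear(v):  # v in [0, 1]
    return np.where(v <= 0.04045, v / 12.92, ((v + 0.055) / 1.055) ** 2.4)

def linear_to_srgb(v):
    v = np.clip(v, 0, 1)
    return np.where(v <= 0.0031308, 12.92 * v, 1.055 * v ** (1 / 2.4) - 0.055)

naive = 0.5 * 0.0 + 0.5 * 1.0                       # blending encoded values -> 128
linear = linear_to_srgb(0.5 * srgb_to_linear(np.array(0.0)) +
                        0.5 * srgb_to_linear(np.array(1.0)))
print(round(naive * 255), round(float(linear) * 255))  # 128 vs ~188: linear is brighter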
+    def execute_with_auto_optimization(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        prompt: str,
+        quality_checker: Any,
+        progress_callback: Optional[Callable[[str, int], None]] = None,
+        **kwargs
+    ) -> InpaintingResult:
+        """
+        Execute inpainting with automatic quality-based optimization.
+
+        Retries with adjusted parameters if quality score is below threshold.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Inpainting mask
+        prompt : str
+            Text prompt
+        quality_checker : QualityChecker
+            Quality assessment instance
+        progress_callback : callable, optional
+            Progress update function
+        **kwargs
+            Additional inpainting parameters
+
+        Returns
+        -------
+        InpaintingResult
+            Best result achieved (may include retry information)
+        """
+        if not self.config.enable_auto_optimization:
+            return self.execute_inpainting(
+                image, mask, prompt,
+                progress_callback=progress_callback,
+                **kwargs
+            )
+
+        best_result = None
+        best_score = 0.0
+        retry_count = 0
+        prev_score = 0.0
+
+        # Mutable parameters for optimization
+        current_feather = kwargs.get('feather_radius', self.config.feather_radius)
+        current_scale = kwargs.get(
+            'controlnet_conditioning_scale',
+            self.config.controlnet_conditioning_scale
+        )
+        current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
+        current_prompt = prompt
+
+        while retry_count <= self.config.max_optimization_retries:
+            if progress_callback and retry_count > 0:
+                progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)
+
+            # Execute inpainting
+            result = self.execute_inpainting(
+                image, mask, current_prompt,
+                preview_only=False,
+                feather_radius=current_feather,
+                controlnet_conditioning_scale=current_scale,
+                guidance_scale=current_guidance,
+                progress_callback=progress_callback if retry_count == 0 else None,
+                **{k: v for k, v in kwargs.items()
+                   if k not in ['feather_radius', 'controlnet_conditioning_scale',
+                                'guidance_scale']}
+            )
+
+            if not result.success:
+                return result
+
+            # Evaluate quality
+            if result.blended_image is not None:
+                quality_results = quality_checker.run_all_checks(
+                    foreground=image,
+                    background=result.result_image,
+                    mask=mask,
+                    combined=result.blended_image
+                )
+                quality_score = quality_results.get("overall_score", 0)
+            else:
+                quality_results = {}  # Keep defined so the parameter adjustment below cannot fail
+                quality_score = 50.0  # Default if no blended image
+
+            result.quality_score = quality_score
+            result.quality_details = quality_results if result.blended_image else {}
+            result.retries = retry_count
+
+            logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")
+
+            # Track best result
+            if quality_score > best_score:
+                best_score = quality_score
+                best_result = result
+
+            # Check if quality is acceptable
+            if quality_score >= self.config.min_quality_score:
+                logger.info(f"Quality threshold met: {quality_score:.1f}")
+                return best_result
+
+            # Check for minimal improvement (early termination)
+            if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
+                logger.info("Minimal improvement, stopping optimization")
+                return best_result
+
+            prev_score = quality_score
+            retry_count += 1
+
+            if retry_count > self.config.max_optimization_retries:
+                break
+
+            # Adjust parameters based on quality issues
+            checks = quality_results.get("checks", {})
+
+            edge_score = checks.get("edge_continuity", {}).get("score", 100)
+            harmony_score = checks.get("color_harmony", {}).get("score", 100)
+
+            if edge_score < 60:
+                # Edge issues: increase feathering, decrease control strength
+                current_feather = min(20, current_feather + 3)
+                current_scale = max(0.5, current_scale - 0.1)
+                logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
+
+            if harmony_score < 60:
+                # Color harmony issues: emphasize consistency in prompt
+                if "color consistent" not in current_prompt.lower():
+                    current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
+                current_guidance = min(12.0, current_guidance + 1.0)
+                logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
+
+            if edge_score < 60 and harmony_score < 60:
+                # Both issues: stronger guidance
+                current_guidance = min(12.0, current_guidance + 1.5)
+
+        logger.info(f"Optimization complete. Best score: {best_score:.1f}")
+        return best_result
+
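A hedged sketch of driving the optimization loop with a stand-in checker; the real QualityChecker lives in quality_checker.py and is not shown in this hunk, and `inpainter` and the file names are hypothetical:

from PIL import Image

class StubQualityChecker:
    def run_all_checks(self, foreground, background, mask, combined):
        # A real checker would score edge continuity, color harmony, etc.
        return {"overall_score": 78.0, "checks": {}}

image = Image.open("photo.jpg")
mask = Image.open("mask.png")

result = inpainter.execute_with_auto_optimization(
    image, mask,
    prompt="clean grass lawn with natural green color",
    quality_checker=StubQualityChecker(),
)
print(result.quality_score, "after", result.retries, "retries")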
+    def get_status(self) -> Dict[str, Any]:
+        """
+        Get current module status.
+
+        Returns
+        -------
+        dict
+            Status information including initialization state and memory usage
+        """
+        status = {
+            "initialized": self.is_initialized,
+            "device": self.device,
+            "conditioning_type": self._current_conditioning_type,
+            "last_seed": self._last_seed,
+            "config": {
+                "controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
+                "feather_radius": self.config.feather_radius,
+                "num_inference_steps": self.config.num_inference_steps,
+                "guidance_scale": self.config.guidance_scale
+            }
+        }
+
+        status["memory"] = self._check_memory_status()
+
+        return status
inpainting_templates.py
ADDED
@@ -0,0 +1,707 @@
+import logging
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InpaintingTemplate:
+    """Data class representing an inpainting template."""
+
+    key: str
+    name: str
+    category: str
+    icon: str
+    description: str
+
+    # Prompt templates
+    prompt_template: str
+    negative_prompt: str
+
+    # Recommended parameters
+    controlnet_conditioning_scale: float = 0.7
+    feather_radius: int = 8
+    guidance_scale: float = 7.5
+    num_inference_steps: int = 25
+
+    # Conditioning type preference
+    preferred_conditioning: str = "canny"  # "canny" or "depth"
+
+    # Tips for users
+    usage_tips: List[str] = field(default_factory=list)
+
+
+class InpaintingTemplateManager:
+    """
+    Manages inpainting templates for various use cases.
+
+    Provides categorized presets optimized for different inpainting scenarios
+    including object replacement, removal, style transfer, and enhancement.
+
+    Attributes:
+        TEMPLATES: Dictionary of all available templates
+        CATEGORIES: List of category names in display order
+
+    Example:
+        >>> manager = InpaintingTemplateManager()
+        >>> template = manager.get_template("object_replacement")
+        >>> print(template.prompt_template)
+    """
+
+    TEMPLATES: Dict[str, InpaintingTemplate] = {
+        # Object Replacement Category
+        "object_replacement": InpaintingTemplate(
+            key="object_replacement",
+            name="Object Replacement",
+            category="Replacement",
+            icon="🔄",
+            description="Replace selected objects with new content while preserving context",
+            prompt_template="{content}, seamlessly integrated into scene, matching lighting and perspective, realistic placement",
+            negative_prompt=(
+                "inconsistent lighting, wrong perspective, mismatched colors, "
+                "visible seams, floating objects, unrealistic placement, original object, "
+                "poorly integrated, disconnected from scene, keeping original"
+            ),
+            controlnet_conditioning_scale=0.48,
+            feather_radius=14,
+            guidance_scale=10.0,
+            num_inference_steps=38,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw mask PRECISELY around the object to replace (include small margin)",
+                "Be specific: Instead of 'plant', use 'large green potted fern with detailed leaves'",
+                "📝 Example prompts:",
+                "  • 'elegant white ceramic teacup with delicate gold rim and floral pattern'",
+                "  • 'modern silver laptop computer with sleek metallic finish'",
+                "  • 'vintage wooden desk lamp with warm brass details'"
+            ]
+        ),
+
+        "face_swap": InpaintingTemplate(
+            key="face_swap",
+            name="Face Enhancement",
+            category="Replacement",
+            icon="👤",
+            description="Enhance or modify facial features while maintaining identity cues",
+            prompt_template="{content}, natural skin texture, proper facial proportions, realistic lighting, detailed facial features",
+            negative_prompt=(
+                "deformed face, asymmetric features, unnatural skin, "
+                "plastic appearance, wrong eye direction, blurry features, "
+                "artificial smoothing, uncanny valley, distorted proportions"
+            ),
+            controlnet_conditioning_scale=0.88,
+            feather_radius=6,
+            guidance_scale=8.5,
+            num_inference_steps=35,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw mask CAREFULLY around face outline (avoid hair and neck)",
+                "High conditioning preserves facial structure - good for subtle enhancements",
+                "📝 Example prompts:",
+                "  • 'warm friendly smile with bright eyes and natural expression'",
+                "  • 'professional headshot with confident neutral expression'",
+                "  • 'gentle smile with soft natural lighting on face'"
+            ]
+        ),
+
+        "clothing_change": InpaintingTemplate(
+            key="clothing_change",
+            name="Clothing Change",
+            category="Replacement",
+            icon="👕",
+            description="Change clothing color, pattern, style, and material (moderate changes)",
+            prompt_template="transform to {content}, vivid fabric texture, natural folds and wrinkles, correct fit, proper color saturation",
+            negative_prompt=(
+                "wrong body proportions, floating fabric, unrealistic wrinkles, "
+                "mismatched lighting, visible edges, original clothing style, "
+                "keeping same color, faded colors, unchanged appearance, partial change"
+            ),
+            controlnet_conditioning_scale=0.42,
+            feather_radius=14,
+            guidance_scale=10.5,
+            num_inference_steps=38,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "Best for: Similar color changes (black→navy, white→cream) and pattern changes",
+                "For extreme changes (black↔white), use 'Dramatic Color Change' template instead",
+                "📝 Example prompts:",
+                "  • 'deep navy blue formal blazer with fine texture and professional fit'",
+                "  • 'bright red polo shirt with clean collar and soft fabric'",
+                "  • 'charcoal gray sweater with ribbed knit texture'"
+            ]
+        ),
+
+        "dramatic_color_change": InpaintingTemplate(
+            key="dramatic_color_change",
+            name="Dramatic Color Change",
+            category="Replacement",
+            icon="🎨",
+            description="Extreme color transformations (dark↔light, black↔white)",
+            prompt_template="completely transform to {content}, vivid color saturation, high contrast, sharp color definition, clear distinct color",
+            negative_prompt=(
+                "original color, keeping same shade, partial change, color bleeding, "
+                "faded colors, mixed tones, subtle change, gradual transition, "
+                "original appearance, unchanged, dark remnants, light patches"
+            ),
+            controlnet_conditioning_scale=0.38,
+            feather_radius=16,
+            guidance_scale=12.0,
+            num_inference_steps=42,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "Optimized for: black→white, white→black, dark→light extreme transformations",
+                "Use VERY vivid descriptors: 'PURE WHITE', 'JET BLACK', 'BRIGHT crimson red'",
+                "📝 Example prompts:",
+                "  • 'pure white dress shirt with crisp texture and sharp collar'",
+                "  • 'jet black leather jacket with smooth matte finish'",
+                "  • 'bright crimson red blazer with vivid saturated color'"
+            ]
+        ),
+
+        "clothing_addition": InpaintingTemplate(
+            key="clothing_addition",
+            name="Add Accessories",
+            category="Replacement",
+            icon="👔",
+            description="Add ties, pockets, buttons, or accessories to clothing (Advanced)",
+            prompt_template="{content}, clearly visible, highly detailed accessory, seamlessly integrated into clothing, proper placement and perspective",
+            negative_prompt=(
+                "missing details, incomplete, floating objects, disconnected, "
+                "unrealistic placement, wrong perspective, blurry, poorly integrated, "
+                "invisible, faint, unclear, hidden, absent, not visible"
+            ),
+            controlnet_conditioning_scale=0.25,
+            feather_radius=12,
+            guidance_scale=14.0,
+            num_inference_steps=50,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "⚡ Advanced feature: For adding neckties, pocket squares, badges, or buttons",
+                "Draw mask from NECK to CHEST (vertical strip) for ties, not just collar area",
+                "📝 Example prompts:",
+                "  • 'burgundy silk necktie with diagonal stripes and Windsor knot, hanging down from collar to chest'",
+                "  • 'white pocket square with neat fold, visible in breast pocket'",
+                "  • 'silver lapel pin with detailed engraving on left collar'",
+                "⚠️ TIP: For ties, mask should cover the entire length where tie should appear"
+            ]
+        ),
+
+        # Object Removal Category
+        "object_removal": InpaintingTemplate(
+            key="object_removal",
+            name="Object Removal",
+            category="Removal",
+            icon="🗑️",
+            description="Remove unwanted objects and fill with matching background",
+            prompt_template="clean background, seamless continuation, {content}",
+            negative_prompt=(
+                "visible patches, color mismatch, texture inconsistency, "
+                "ghost artifacts, blur spots, repeated patterns, visible seams"
+            ),
+            controlnet_conditioning_scale=0.48,
+            feather_radius=14,
+            guidance_scale=8.0,
+            num_inference_steps=30,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw mask slightly BEYOND object edges for better blending",
+                "Describe what background SHOULD look like (e.g., 'grass lawn', 'wooden floor')",
+                "📝 Example prompts:",
+                "  • 'clean grass lawn with natural green color'",
+                "  • 'smooth wooden floor with consistent grain pattern'",
+                "  • 'plain white wall with even texture'"
+            ]
+        ),
+
+        "watermark_removal": InpaintingTemplate(
+            key="watermark_removal",
+            name="Watermark Removal",
+            category="Removal",
+            icon="💧",
+            description="Remove watermarks and text overlays",
+            prompt_template="clean image, no text, seamless background, {content}",
+            negative_prompt=(
+                "text, watermark, logo, signature, letters, numbers, visible artifacts, "
+                "color inconsistency, blur, remnants, ghost text"
+            ),
+            controlnet_conditioning_scale=0.45,
+            feather_radius=12,
+            guidance_scale=8.5,
+            num_inference_steps=30,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw mask covering ALL watermark/text areas precisely",
+                "Describe what SHOULD be there instead (e.g., 'sky', 'fabric texture')",
+                "📝 Example prompts:",
+                "  • 'clean blue sky with smooth gradient'",
+                "  • 'natural skin texture without marks'",
+                "  • 'smooth fabric surface with consistent color'"
+            ]
+        ),
+
+        "blemish_removal": InpaintingTemplate(
+            key="blemish_removal",
+            name="Blemish Removal",
+            category="Removal",
+            icon="✨",
+            description="Remove skin blemishes, scratches, or small imperfections",
+            prompt_template="clean smooth surface, natural texture, {content}",
+            negative_prompt=(
+                "artificial smoothing, plastic texture, visible editing, "
+                "color patches, unnatural appearance, over-processed"
+            ),
+            controlnet_conditioning_scale=0.6,
+            feather_radius=6,
+            guidance_scale=6.5,
+            num_inference_steps=20,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw small precise masks for EACH blemish/imperfection",
+                "Lower guidance (6.5) preserves natural skin texture",
+                "📝 Example prompts:",
+                "  • 'natural clean skin with smooth texture'",
+                "  • 'smooth surface without scratches or marks'",
+                "  • 'clear skin with natural pores and texture'"
+            ]
+        ),
+
+        # Style Transfer Category
+        "style_artistic": InpaintingTemplate(
+            key="style_artistic",
+            name="Artistic Style",
+            category="Style",
+            icon="🎨",
+            description="Apply artistic style to selected region",
+            prompt_template="{content}, distinctive artistic style, strong painterly effect, creative interpretation, visible brushstrokes",
+            negative_prompt=(
+                "photorealistic, plain, boring, low contrast, unchanged, "
+                "inconsistent style, harsh transitions, original appearance, realistic photo"
+            ),
+            controlnet_conditioning_scale=0.52,
+            feather_radius=12,
+            guidance_scale=11.5,
+            num_inference_steps=38,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Works best on larger areas (faces, clothing, backgrounds) for visible transformation",
+                "Be VERY specific about art style for best results",
+                "📝 Example prompts:",
+                "  • 'impressionist oil painting with visible thick brushstrokes and vibrant colors'",
+                "  • 'watercolor painting with soft edges and delicate color washes'",
+                "  • 'Van Gogh style with swirling brushstrokes and bold color contrasts'"
+            ]
+        ),
+
+        "style_vintage": InpaintingTemplate(
+            key="style_vintage",
+            name="Vintage Look",
+            category="Style",
+            icon="📻",
+            description="Apply vintage or retro aesthetic to selected area",
+            prompt_template="{content}, strong vintage aesthetic, warm sepia tones, film grain texture, nostalgic atmosphere",
+            negative_prompt=(
+                "modern, digital, cold colors, harsh contrast, "
+                "oversaturated, neon colors, contemporary look, clean digital, crisp"
+            ),
+            controlnet_conditioning_scale=0.55,
+            feather_radius=14,
+            guidance_scale=10.5,
+            num_inference_steps=35,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Works best on medium to large regions for visible aesthetic change",
+                "Specify era and style for best results",
+                "📝 Example prompts:",
+                "  • '1920s sepia photograph with faded brown tones and soft grain'",
+                "  • '1970s vintage photo with warm orange tones and slight film grain'",
+                "  • '1950s Kodachrome with saturated warm colors and nostalgic feel'"
+            ]
+        ),
+
+        "style_anime": InpaintingTemplate(
+            key="style_anime",
+            name="Anime Style",
+            category="Style",
+            icon="🎌",
+            description="Transform selected region to anime/illustration style",
+            prompt_template="{content}, anime illustration style, clean sharp lines, vibrant saturated colors, cel-shaded with flat colors",
+            negative_prompt=(
+                "photorealistic, blurry lines, muddy colors, realistic photo, "
+                "3D render, uncanny valley, western cartoon, gradient shading, photographic"
+            ),
+            controlnet_conditioning_scale=0.48,
+            feather_radius=10,
+            guidance_scale=12.5,
+            num_inference_steps=40,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "⚡ DRAMATIC transformation - best for portraits and characters",
+                "Expect significant stylistic changes from realistic to anime",
+                "📝 Example prompts:",
+                "  • 'modern anime style with large expressive eyes and vibrant colors'",
+                "  • 'Studio Ghibli style with soft features and warm color palette'",
+                "  • 'manga style with clean black lines and cel-shaded coloring'"
+            ]
+        ),
+
+        # Detail Enhancement Category
+        "detail_enhance": InpaintingTemplate(
+            key="detail_enhance",
+            name="Detail Enhancement",
+            category="Enhancement",
+            icon="🔍",
+            description="Add fine details and textures to selected area",
+            prompt_template="{content}, highly detailed, intricate textures, fine details, sharp focus",
+            negative_prompt=(
+                "blurry, smooth, low detail, soft focus, "
+                "oversimplified, lacking texture"
+            ),
+            controlnet_conditioning_scale=0.85,
+            feather_radius=4,
+            guidance_scale=8.0,
+            num_inference_steps=30,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "High conditioning (0.85) preserves overall structure while adding detail",
+                "Best for adding fine details to existing objects",
+                "📝 Example prompts:",
+                "  • 'highly detailed fabric with visible weave and fine threads'",
+                "  • 'intricate wood grain with natural knots and detailed texture'",
+                "  • 'sharp facial features with fine skin pores and detail'"
+            ]
+        ),
+
+        "texture_add": InpaintingTemplate(
+            key="texture_add",
+            name="Texture Addition",
+            category="Enhancement",
+            icon="🧱",
+            description="Add or enhance surface textures",
+            prompt_template="{content} texture, realistic surface detail, natural material appearance",
+            negative_prompt=(
+                "flat, smooth, unrealistic, plastic, "
+                "wrong material, inconsistent texture"
+            ),
+            controlnet_conditioning_scale=0.8,
+            feather_radius=5,
+            guidance_scale=7.5,
+            num_inference_steps=25,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "Specify material type clearly for best results",
+                "Depth conditioning preserves 3D form while changing texture",
+                "📝 Example prompts:",
+                "  • 'rough wood texture with natural grain and knots'",
+                "  • 'soft cotton fabric with gentle weave pattern'",
+                "  • 'smooth marble surface with subtle veining'"
+            ]
+        ),
+
+        "lighting_fix": InpaintingTemplate(
+            key="lighting_fix",
+            name="Lighting Correction",
+            category="Enhancement",
+            icon="💡",
+            description="Correct or enhance lighting in selected area",
+            prompt_template="{content}, proper lighting, natural shadows, balanced exposure",
+            negative_prompt=(
+                "harsh shadows, overexposed, underexposed, "
+                "flat lighting, unnatural highlights"
+            ),
+            controlnet_conditioning_scale=0.65,
+            feather_radius=15,
+            guidance_scale=7.0,
+            num_inference_steps=25,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "Use large feather (15px) for smooth lighting transitions",
+                "Best for fixing uneven lighting or adding natural light",
+                "📝 Example prompts:",
+                "  • 'soft natural lighting from window, gentle shadows'",
+                "  • 'balanced exposure with warm golden hour light'",
+                "  • 'even studio lighting with soft diffused shadows'"
+            ]
+        ),
+
+        # Background Category
+        "background_extend": InpaintingTemplate(
+            key="background_extend",
+            name="Background Extension",
+            category="Background",
+            icon="📐",
+            description="Extend image background seamlessly",
+            prompt_template="seamless background extension, {content}, consistent style and lighting",
+            negative_prompt=(
+                "visible seams, style mismatch, lighting inconsistency, "
+                "repeated elements, unnatural continuation, abrupt changes"
+            ),
+            controlnet_conditioning_scale=0.55,
+            feather_radius=20,
+            guidance_scale=8.0,
+            num_inference_steps=32,
+            preferred_conditioning="canny",
+            usage_tips=[
+                "Draw mask on area to extend (edges of image)",
+                "Large feather (20px) ensures smooth blending with existing background",
+                "📝 Example prompts:",
+                "  • 'continue the wooden floor with same grain pattern'",
+                "  • 'extend blue sky with matching clouds and lighting'",
+                "  • 'seamless continuation of brick wall texture'"
+            ]
+        ),
+
+        "background_replace": InpaintingTemplate(
+            key="background_replace",
+            name="Background Replacement",
+            category="Background",
+            icon="🖼️",
+            description="Replace background while keeping subject intact",
+            prompt_template="{content}, professional background scene, seamless integration with subject, matching lighting and atmosphere",
+            negative_prompt=(
+                "floating subject, inconsistent lighting, disconnected, "
+                "wrong perspective, visible edges, color mismatch, original background, "
+                "poor integration, obvious composite"
+            ),
+            controlnet_conditioning_scale=0.60,
+            feather_radius=12,
+            guidance_scale=9.5,
+            num_inference_steps=35,
+            preferred_conditioning="depth",
+            usage_tips=[
+                "Draw mask around ENTIRE background (leave subject unmasked with small margin)",
+                "Include lighting description to match subject for natural results",
+                "📝 Example prompts:",
+                "  • 'professional photography studio with white backdrop and soft lighting'",
+                "  • 'modern minimalist office with white walls and bright natural lighting'",
+                "  • 'sunny beach with blue ocean and golden hour lighting'"
+            ]
+        ),
+    }
+
| 482 |
+
# Category display order
|
| 483 |
+
CATEGORIES = ["Replacement", "Removal", "Style", "Enhancement", "Background"]
|
| 484 |
+
|
| 485 |
+
def __init__(self):
|
| 486 |
+
"""Initialize the InpaintingTemplateManager."""
|
| 487 |
+
logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
|
| 488 |
+
|
| 489 |
+
def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
|
| 490 |
+
"""
|
| 491 |
+
Get all available templates.
|
| 492 |
+
|
| 493 |
+
Returns
|
| 494 |
+
-------
|
| 495 |
+
dict
|
| 496 |
+
Dictionary of all templates keyed by template key
|
| 497 |
+
"""
|
| 498 |
+
return self.TEMPLATES
|
| 499 |
+
|
| 500 |
+
def get_template(self, key: str) -> Optional[InpaintingTemplate]:
|
| 501 |
+
"""
|
| 502 |
+
Get a specific template by key.
|
| 503 |
+
|
| 504 |
+
Parameters
|
| 505 |
+
----------
|
| 506 |
+
key : str
|
| 507 |
+
Template identifier
|
| 508 |
+
|
| 509 |
+
Returns
|
| 510 |
+
-------
|
| 511 |
+
InpaintingTemplate or None
|
| 512 |
+
Template if found, None otherwise
|
| 513 |
+
"""
|
| 514 |
+
return self.TEMPLATES.get(key)
|
| 515 |
+
|
| 516 |
+
def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
|
| 517 |
+
"""
|
| 518 |
+
Get all templates in a specific category.
|
| 519 |
+
|
| 520 |
+
Parameters
|
| 521 |
+
----------
|
| 522 |
+
category : str
|
| 523 |
+
Category name
|
| 524 |
+
|
| 525 |
+
Returns
|
| 526 |
+
-------
|
| 527 |
+
list
|
| 528 |
+
List of templates in the category
|
| 529 |
+
"""
|
| 530 |
+
return [t for t in self.TEMPLATES.values() if t.category == category]
|
| 531 |
+
|
| 532 |
+
def get_categories(self) -> List[str]:
|
| 533 |
+
"""
|
| 534 |
+
Get list of all categories in display order.
|
| 535 |
+
|
| 536 |
+
Returns
|
| 537 |
+
-------
|
| 538 |
+
list
|
| 539 |
+
Category names
|
| 540 |
+
"""
|
| 541 |
+
return self.CATEGORIES
|
| 542 |
+
|
| 543 |
+
def get_template_choices_sorted(self) -> List[str]:
|
| 544 |
+
"""
|
| 545 |
+
Get template choices formatted for Gradio dropdown.
|
| 546 |
+
|
| 547 |
+
Returns list of display strings sorted by category then A-Z.
|
| 548 |
+
Format: "icon Name"
|
| 549 |
+
|
| 550 |
+
Returns
|
| 551 |
+
-------
|
| 552 |
+
list
|
| 553 |
+
Formatted display strings for dropdown
|
| 554 |
+
"""
|
| 555 |
+
display_list = []
|
| 556 |
+
|
| 557 |
+
for category in self.CATEGORIES:
|
| 558 |
+
templates = self.get_templates_by_category(category)
|
| 559 |
+
for template in sorted(templates, key=lambda t: t.name):
|
| 560 |
+
display_name = f"{template.icon} {template.name}"
|
| 561 |
+
display_list.append(display_name)
|
| 562 |
+
|
| 563 |
+
return display_list
|
| 564 |
+
|
| 565 |
+
def get_template_key_from_display(self, display_name: str) -> Optional[str]:
|
| 566 |
+
"""
|
| 567 |
+
Get template key from display name.
|
| 568 |
+
|
| 569 |
+
Parameters
|
| 570 |
+
----------
|
| 571 |
+
display_name : str
|
| 572 |
+
Display string like "🔄 Object Replacement"
|
| 573 |
+
|
| 574 |
+
Returns
|
| 575 |
+
-------
|
| 576 |
+
str or None
|
| 577 |
+
Template key if found
|
| 578 |
+
"""
|
| 579 |
+
if not display_name:
|
| 580 |
+
return None
|
| 581 |
+
|
| 582 |
+
for key, template in self.TEMPLATES.items():
|
| 583 |
+
if f"{template.icon} {template.name}" == display_name:
|
| 584 |
+
return key
|
| 585 |
+
return None
|
| 586 |
+
|
| 587 |
+
def get_parameters_for_template(self, key: str) -> Dict[str, any]:
|
| 588 |
+
"""
|
| 589 |
+
Get recommended parameters for a template.
|
| 590 |
+
|
| 591 |
+
Parameters
|
| 592 |
+
----------
|
| 593 |
+
key : str
|
| 594 |
+
Template key
|
| 595 |
+
|
| 596 |
+
Returns
|
| 597 |
+
-------
|
| 598 |
+
dict
|
| 599 |
+
Dictionary of parameter names and values
|
| 600 |
+
"""
|
| 601 |
+
template = self.get_template(key)
|
| 602 |
+
if not template:
|
| 603 |
+
return {}
|
| 604 |
+
|
| 605 |
+
return {
|
| 606 |
+
"controlnet_conditioning_scale": template.controlnet_conditioning_scale,
|
| 607 |
+
"feather_radius": template.feather_radius,
|
| 608 |
+
"guidance_scale": template.guidance_scale,
|
| 609 |
+
"num_inference_steps": template.num_inference_steps,
|
| 610 |
+
"preferred_conditioning": template.preferred_conditioning
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
def build_prompt(self, key: str, content: str) -> str:
|
| 614 |
+
"""
|
| 615 |
+
Build complete prompt from template and user content.
|
| 616 |
+
|
| 617 |
+
Parameters
|
| 618 |
+
----------
|
| 619 |
+
key : str
|
| 620 |
+
Template key
|
| 621 |
+
content : str
|
| 622 |
+
User-provided content description
|
| 623 |
+
|
| 624 |
+
Returns
|
| 625 |
+
-------
|
| 626 |
+
str
|
| 627 |
+
Formatted prompt with content inserted
|
| 628 |
+
"""
|
| 629 |
+
template = self.get_template(key)
|
| 630 |
+
if not template:
|
| 631 |
+
return content
|
| 632 |
+
|
| 633 |
+
return template.prompt_template.format(content=content)
|
| 634 |
+
|
| 635 |
+
def get_negative_prompt(self, key: str) -> str:
|
| 636 |
+
"""
|
| 637 |
+
Get negative prompt for a template.
|
| 638 |
+
|
| 639 |
+
Parameters
|
| 640 |
+
----------
|
| 641 |
+
key : str
|
| 642 |
+
Template key
|
| 643 |
+
|
| 644 |
+
Returns
|
| 645 |
+
-------
|
| 646 |
+
str
|
| 647 |
+
Negative prompt string
|
| 648 |
+
"""
|
| 649 |
+
template = self.get_template(key)
|
| 650 |
+
if not template:
|
| 651 |
+
return ""
|
| 652 |
+
return template.negative_prompt
|
| 653 |
+
|
| 654 |
+
def get_usage_tips(self, key: str) -> List[str]:
|
| 655 |
+
"""
|
| 656 |
+
Get usage tips for a template.
|
| 657 |
+
|
| 658 |
+
Parameters
|
| 659 |
+
----------
|
| 660 |
+
key : str
|
| 661 |
+
Template key
|
| 662 |
+
|
| 663 |
+
Returns
|
| 664 |
+
-------
|
| 665 |
+
list
|
| 666 |
+
List of tip strings
|
| 667 |
+
"""
|
| 668 |
+
template = self.get_template(key)
|
| 669 |
+
if not template:
|
| 670 |
+
return []
|
| 671 |
+
return template.usage_tips
|
| 672 |
+
|
| 673 |
+
def build_gallery_html(self) -> str:
|
| 674 |
+
"""
|
| 675 |
+
Build HTML for template gallery display.
|
| 676 |
+
|
| 677 |
+
Returns
|
| 678 |
+
-------
|
| 679 |
+
str
|
| 680 |
+
HTML string for Gradio display
|
| 681 |
+
"""
|
| 682 |
+
html_parts = ['<div class="inpainting-gallery">']
|
| 683 |
+
|
| 684 |
+
for category in self.CATEGORIES:
|
| 685 |
+
templates = self.get_templates_by_category(category)
|
| 686 |
+
if not templates:
|
| 687 |
+
continue
|
| 688 |
+
|
| 689 |
+
html_parts.append(f'''
|
| 690 |
+
<div class="inpainting-category">
|
| 691 |
+
<h4 class="inpainting-category-title">{category}</h4>
|
| 692 |
+
<div class="inpainting-grid">
|
| 693 |
+
''')
|
| 694 |
+
|
| 695 |
+
for template in sorted(templates, key=lambda t: t.name):
|
| 696 |
+
html_parts.append(f'''
|
| 697 |
+
<div class="inpainting-card" data-template="{template.key}">
|
| 698 |
+
<span class="inpainting-icon">{template.icon}</span>
|
| 699 |
+
<span class="inpainting-name">{template.name}</span>
|
| 700 |
+
<span class="inpainting-desc">{template.description[:50]}...</span>
|
| 701 |
+
</div>
|
| 702 |
+
''')
|
| 703 |
+
|
| 704 |
+
html_parts.append('</div></div>')
|
| 705 |
+
|
| 706 |
+
html_parts.append('</div>')
|
| 707 |
+
return ''.join(html_parts)
|
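
For reference, a minimal sketch of how a caller might chain this API together (the dropdown string and user text below are made up for illustration; only methods defined in the file above are used):

from inpainting_templates import InpaintingTemplateManager

manager = InpaintingTemplateManager()

# The Gradio dropdown hands back a display string like "💡 Lighting Correction";
# map it back to the underlying template key first.
key = manager.get_template_key_from_display("💡 Lighting Correction")  # -> "lighting_fix"

# Merge the user's description into the template scaffold and pull the
# recommended generation parameters for that template.
prompt = manager.build_prompt(key, "soft window light from the left")
negative = manager.get_negative_prompt(key)
params = manager.get_parameters_for_template(key)

print(prompt)
# soft window light from the left, proper lighting, natural shadows, balanced exposure
print(params["feather_radius"])  # 15 for this template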
mask_generator.py
CHANGED
@@ -1,5 +1,6 @@
 import cv2
 import numpy as np
+import traceback
 from PIL import Image, ImageFilter, ImageDraw
 import logging
 from typing import Optional, Tuple
@@ -222,7 +223,6 @@ class MaskGenerator:
 
         except Exception as e:
             logger.error(f"❌ BiRefNet mask generation failed: {e}")
-            import traceback
            logger.error(f"📍 Traceback: {traceback.format_exc()}")
            return None
 
@@ -380,7 +380,6 @@
             return is_cartoon
 
         except Exception as e:
-            import traceback
            logger.error(f"❌ Cartoon detection failed: {e}")
            logger.error(f"📍 Traceback: {traceback.format_exc()}")
            print(f"❌ CARTOON DETECTION ERROR: {e}")
@@ -437,7 +436,6 @@
             return enhanced_alpha
 
         except Exception as e:
-            import traceback
            logger.error(f"❌ Cartoon mask enhancement failed: {e}")
            logger.error(f"📍 Traceback: {traceback.format_exc()}")
            print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
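
The same fix repeats in all three hunks: traceback moves to a single module-level import so every except block can call traceback.format_exc() without re-importing it locally. A minimal sketch of the resulting pattern (illustrative function and logger names, not from this file):

import logging
import traceback  # hoisted once; available in every handler below

logger = logging.getLogger(__name__)

def risky_step():
    raise ValueError("demo failure")

try:
    risky_step()
except Exception as e:
    # No per-handler `import traceback` needed any more.
    logger.error(f"step failed: {e}")
    logger.error(f"traceback: {traceback.format_exc()}")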
model_manager.py
CHANGED
@@ -1,7 +1,8 @@
 import logging
 import gc
 import time
-from typing import Dict, Any, Optional, Callable
+from enum import IntEnum
+from typing import Dict, Any, Optional, Callable, List
 from dataclasses import dataclass, field
 from threading import Lock
 import torch
@@ -10,13 +11,41 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
+class ModelPriority(IntEnum):
+    """
+    Model priority levels for memory management.
+
+    Higher priority models are kept loaded longer under memory pressure.
+    """
+    CRITICAL = 100   # Never unload (e.g., OpenCLIP for analysis)
+    HIGH = 80        # Currently active pipeline
+    MEDIUM = 50      # Recently used models
+    LOW = 20         # Inactive pipelines, can be evicted
+    DISPOSABLE = 0   # Temporary models, evict first
+
+
 @dataclass
 class ModelInfo:
-    """…"""
+    """
+    Information about a registered model.
+
+    Attributes:
+        name: Unique model identifier
+        loader: Callable that returns the loaded model
+        is_critical: If True, model won't be unloaded under memory pressure
+        priority: ModelPriority level for eviction decisions
+        estimated_memory_gb: Estimated GPU memory usage
+        model_group: Group name for mutual exclusion (e.g., "pipeline")
+        is_loaded: Whether model is currently loaded
+        last_used: Timestamp of last use
+        model_instance: The actual model object
+    """
     name: str
     loader: Callable[[], Any]
     is_critical: bool = False
+    priority: int = ModelPriority.MEDIUM
     estimated_memory_gb: float = 0.0
+    model_group: str = ""  # For mutual exclusion (e.g., "pipeline")
     is_loaded: bool = False
     last_used: float = 0.0
     model_instance: Any = None
@@ -25,12 +54,34 @@
 
 class ModelManager:
     """
     Singleton model manager for unified model lifecycle management.
-    …
+
+    Handles lazy loading, caching, priority-based eviction, and mutual
+    exclusion for pipeline models. Designed for memory-constrained
+    environments like Google Colab and HuggingFace Spaces.
+
+    Features:
+    - Priority-based model eviction under memory pressure
+    - Mutual exclusion for pipeline models (only one active at a time)
+    - Automatic memory monitoring and cleanup
+    - Support for model groups and dependencies
+
+    Example:
+        >>> manager = get_model_manager()
+        >>> manager.register_model(
+        ...     name="sdxl_pipeline",
+        ...     loader=load_sdxl,
+        ...     priority=ModelPriority.HIGH,
+        ...     model_group="pipeline"
+        ... )
+        >>> pipeline = manager.load_model("sdxl_pipeline")
     """
 
     _instance = None
     _lock = Lock()
 
+    # Known model groups for mutual exclusion
+    PIPELINE_GROUP = "pipeline"  # Only one pipeline can be loaded at a time
+
     def __new__(cls):
         if cls._instance is None:
             with cls._lock:
@@ -45,9 +96,11 @@
 
         self._models: Dict[str, ModelInfo] = {}
         self._memory_threshold = 0.80  # Trigger cleanup at 80% GPU memory usage
+        self._high_memory_threshold = 0.90  # Critical threshold for aggressive cleanup
         self._device = self._detect_device()
+        self._active_pipeline: Optional[str] = None  # Track currently active pipeline
 
-        logger.info(f"…")
+        logger.info(f"ModelManager initialized on {self._device}")
         self._initialized = True
 
     def _detect_device(self) -> str:
@@ -63,44 +116,73 @@
         name: str,
         loader: Callable[[], Any],
         is_critical: bool = False,
-        estimated_memory_gb: float = 0.0
+        priority: int = ModelPriority.MEDIUM,
+        estimated_memory_gb: float = 0.0,
+        model_group: str = ""
     ):
         """
         Register a model for managed loading.
 
-        …
+        Parameters
+        ----------
+        name : str
+            Unique model identifier
+        loader : callable
+            Function that returns the loaded model
+        is_critical : bool
+            If True, model won't be unloaded under memory pressure
+        priority : int
+            ModelPriority level for eviction decisions
+        estimated_memory_gb : float
+            Estimated GPU memory usage in GB
+        model_group : str
+            Group name for mutual exclusion (e.g., "pipeline")
         """
         if name in self._models:
-            logger.warning(f"…")
+            logger.warning(f"Model '{name}' already registered, updating")
+
+        # Critical models always have highest priority
+        if is_critical:
+            priority = ModelPriority.CRITICAL
 
         self._models[name] = ModelInfo(
             name=name,
             loader=loader,
             is_critical=is_critical,
+            priority=priority,
             estimated_memory_gb=estimated_memory_gb,
+            model_group=model_group,
             is_loaded=False,
             last_used=0.0,
             model_instance=None
         )
-        logger.info(f"…")
+        logger.info(f"Registered model: {name} (priority={priority}, group={model_group}, ~{estimated_memory_gb:.1f}GB)")
 
-    def load_model(self, name: str) -> Any:
+    def load_model(self, name: str, update_priority: Optional[int] = None) -> Any:
         """
         Load a model by name. Returns cached instance if already loaded.
 
-        …
+        Implements mutual exclusion for pipeline models - loading a new
+        pipeline will unload any existing pipeline first.
 
-        …
+        Parameters
+        ----------
+        name : str
+            Model identifier
+        update_priority : int, optional
+            If provided, update the model's priority after loading
+
+        Returns
+        -------
+        Any
            Loaded model instance
 
        Raises
-        …
+        ------
+        KeyError
+            If model not registered
+        RuntimeError
+            If loading fails
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")
@@ -110,15 +192,21 @@
         # Return cached instance
         if model_info.is_loaded and model_info.model_instance is not None:
             model_info.last_used = time.time()
-            …
+            if update_priority is not None:
+                model_info.priority = update_priority
+            logger.debug(f"Using cached model: {name}")
             return model_info.model_instance
 
+        # Handle mutual exclusion for pipeline group
+        if model_info.model_group == self.PIPELINE_GROUP:
+            self._ensure_pipeline_exclusion(name)
+
         # Check memory pressure before loading
         self.check_memory_pressure()
 
         # Load the model
         try:
-            logger.info(f"…")
+            logger.info(f"Loading model: {name}")
             start_time = time.time()
 
             model_instance = model_info.loader()
@@ -127,32 +215,64 @@
             model_info.is_loaded = True
             model_info.last_used = time.time()
 
+            if update_priority is not None:
+                model_info.priority = update_priority
+
+            # Track active pipeline
+            if model_info.model_group == self.PIPELINE_GROUP:
+                self._active_pipeline = name
+
             load_time = time.time() - start_time
-            logger.info(f"…")
+            logger.info(f"Model '{name}' loaded in {load_time:.1f}s")
 
             return model_instance
 
         except Exception as e:
-            logger.error(f"…")
+            logger.error(f"Failed to load model '{name}': {e}")
             raise RuntimeError(f"Model loading failed: {e}")
 
-    def unload_model(self, name: str):
+    def _ensure_pipeline_exclusion(self, new_pipeline: str) -> None:
+        """
+        Ensure only one pipeline is loaded at a time.
+
+        Unloads any existing pipeline before loading a new one.
+
+        Parameters
+        ----------
+        new_pipeline : str
+            Name of the pipeline about to be loaded
+        """
+        for name, info in self._models.items():
+            if (info.model_group == self.PIPELINE_GROUP and
+                    info.is_loaded and
+                    name != new_pipeline):
+                logger.info(f"Unloading {name} to make room for {new_pipeline}")
+                self.unload_model(name)
+
+    def unload_model(self, name: str) -> bool:
         """
         Unload a specific model to free memory.
 
-        …
+        Parameters
+        ----------
+        name : str
+            Model identifier
+
+        Returns
+        -------
+        bool
+            True if model was unloaded successfully
         """
         if name not in self._models:
-            return
+            return False
 
         model_info = self._models[name]
 
         if not model_info.is_loaded:
-            return
+            return True
 
         try:
-            logger.info(f"…")
+            logger.info(f"Unloading model: {name}")
 
             # Delete model instance
             if model_info.model_instance is not None:
@@ -161,21 +281,33 @@
             model_info.model_instance = None
             model_info.is_loaded = False
 
+            # Update active pipeline tracking
+            if self._active_pipeline == name:
+                self._active_pipeline = None
+
             # Cleanup
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
 
-            logger.info(f"…")
+            logger.info(f"Model '{name}' unloaded")
+            return True
 
         except Exception as e:
-            logger.error(f"…")
+            logger.error(f"Error unloading model '{name}': {e}")
+            return False
 
     def check_memory_pressure(self) -> bool:
         """
-        Check GPU memory usage and unload …
+        Check GPU memory usage and unload low-priority models if needed.
 
-        …
+        Uses priority-based eviction: lower priority models are unloaded first,
+        then falls back to least-recently-used within same priority tier.
+
+        Returns
+        -------
+        bool
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
@@ -188,18 +320,19 @@
         if usage_ratio < self._memory_threshold:
             return False
 
-        logger.warning(f"…")
+        logger.warning(f"Memory pressure detected: {usage_ratio:.1%} used")
 
-        # Find …
-        …
+        # Find evictable models (not critical, loaded)
+        # Sort by priority (ascending) then by last_used (ascending)
+        evictable = [
             (name, info) for name, info in self._models.items()
-            if info.is_loaded and …
+            if info.is_loaded and info.priority < ModelPriority.CRITICAL
         ]
-        …
+        evictable.sort(key=lambda x: (x[1].priority, x[1].last_used))
 
-        # Unload …
+        # Unload models starting from lowest priority
         cleaned = False
-        for name, info in …
+        for name, info in evictable:
             self.unload_model(name)
             cleaned = True
 
@@ -210,13 +343,21 @@
 
         return cleaned
 
-    def force_cleanup(self):
-        """
-        …
+    def force_cleanup(self, keep_critical_only: bool = True):
+        """
+        Force cleanup models and clear caches.
+
+        Parameters
+        ----------
+        keep_critical_only : bool
+            If True, only keep CRITICAL priority models loaded
+        """
+        logger.info("Force cleanup initiated")
 
-        # Unload …
-        …
-        …
+        # Unload models based on priority
+        threshold = ModelPriority.CRITICAL if keep_critical_only else ModelPriority.HIGH
+        for name, info in list(self._models.items()):
+            if info.is_loaded and info.priority < threshold:
                 self.unload_model(name)
 
         # Aggressive garbage collection
@@ -228,7 +369,81 @@
             torch.cuda.ipc_collect()
             torch.cuda.synchronize()
 
-        logger.info("…")
+        logger.info("Force cleanup completed")
+
+    def update_priority(self, name: str, priority: int) -> bool:
+        """
+        Update a model's priority level.
+
+        Parameters
+        ----------
+        name : str
+            Model identifier
+        priority : int
+            New priority level
+
+        Returns
+        -------
+        bool
+            True if priority was updated
+        """
+        if name not in self._models:
+            return False
+
+        self._models[name].priority = priority
+        logger.debug(f"Updated priority for {name} to {priority}")
+        return True
+
+    def get_active_pipeline(self) -> Optional[str]:
+        """
+        Get the name of currently active pipeline.
+
+        Returns
+        -------
+        str or None
+            Name of active pipeline, or None if no pipeline is loaded
+        """
+        return self._active_pipeline
+
+    def switch_to_pipeline(
+        self,
+        name: str,
+        loader: Optional[Callable[[], Any]] = None
+    ) -> Any:
+        """
+        Switch to a different pipeline, unloading current one.
+
+        This is a convenience method for pipeline switching that handles
+        mutual exclusion automatically.
+
+        Parameters
+        ----------
+        name : str
+            Pipeline name to switch to
+        loader : callable, optional
+            Loader function if pipeline not already registered
+
+        Returns
+        -------
+        Any
+            The loaded pipeline instance
+
+        Raises
+        ------
+        KeyError
+            If pipeline not registered and no loader provided
+        """
+        # Register if needed
+        if name not in self._models and loader is not None:
+            self.register_model(
+                name=name,
+                loader=loader,
+                priority=ModelPriority.HIGH,
+                model_group=self.PIPELINE_GROUP
+            )
+
+        # Load will handle unloading of current pipeline
+        return self.load_model(name, update_priority=ModelPriority.HIGH)
 
     def get_memory_status(self) -> Dict[str, Any]:
         """
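
To make the new eviction and exclusion behavior concrete, here is a small hypothetical usage sketch (the loader functions and their returned strings are stand-ins; real loaders would build diffusers pipelines):

from model_manager import ModelManager, ModelPriority

def load_sdxl():
    return "sdxl-pipeline-object"      # placeholder for a real pipeline load

def load_inpaint():
    return "inpaint-pipeline-object"   # placeholder

manager = ModelManager()  # __new__ always hands back the shared singleton

# Entries in PIPELINE_GROUP are mutually exclusive: at most one stays resident.
manager.register_model(
    name="sdxl",
    loader=load_sdxl,
    priority=ModelPriority.HIGH,
    estimated_memory_gb=7.0,
    model_group=ModelManager.PIPELINE_GROUP,
)

sdxl = manager.load_model("sdxl")
assert manager.get_active_pipeline() == "sdxl"

# switch_to_pipeline registers the new pipeline on the fly; load_model then
# calls _ensure_pipeline_exclusion, which unloads "sdxl" before loading.
inpaint = manager.switch_to_pipeline("inpaint", loader=load_inpaint)
assert manager.get_active_pipeline() == "inpaint"

Under memory pressure, check_memory_pressure() sorts loaded non-critical models by (priority, last_used) ascending, so the lowest-priority, least-recently-used entries are evicted first.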
quality_checker.py
CHANGED
@@ -2,7 +2,7 @@ import logging
 import numpy as np
 import cv2
 from PIL import Image
-from typing import Dict, Any, Tuple, Optional
+from typing import Dict, Any, Tuple, Optional, List
 from dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
@@ -407,3 +407,452 @@
             summary += f"\nNotes: {'; '.join(results['warnings'])}"
 
         return summary
+
+    # =========================================================================
+    # INPAINTING-SPECIFIC QUALITY CHECKS
+    # =========================================================================
+
+    def check_inpainting_edge_continuity(
+        self,
+        original: Image.Image,
+        inpainted: Image.Image,
+        mask: Image.Image,
+        ring_width: int = 5
+    ) -> QualityResult:
+        """
+        Check edge continuity at inpainting boundary.
+
+        Calculates color distribution similarity between the ring zones
+        on each side of the mask boundary in Lab color space.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image before inpainting
+        inpainted : PIL.Image
+            Result after inpainting
+        mask : PIL.Image
+            Inpainting mask (white = inpainted area)
+        ring_width : int
+            Width in pixels for the ring zones on each side
+
+        Returns
+        -------
+        QualityResult
+            Edge continuity assessment
+        """
+        try:
+            # Convert to arrays
+            orig_array = np.array(original.convert('RGB'))
+            inpaint_array = np.array(inpainted.convert('RGB'))
+            mask_array = np.array(mask.convert('L'))
+
+            # Find boundary using morphological gradient
+            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+            dilated = cv2.dilate(mask_array, kernel, iterations=ring_width)
+            eroded = cv2.erode(mask_array, kernel, iterations=ring_width)
+
+            # Inner ring (inside inpainted region, near boundary)
+            inner_ring = (mask_array > 127) & (eroded <= 127)
+
+            # Outer ring (outside inpainted region, near boundary)
+            outer_ring = (mask_array <= 127) & (dilated > 127)
+
+            if not np.any(inner_ring) or not np.any(outer_ring):
+                return QualityResult(
+                    score=50,
+                    passed=True,
+                    issue="Unable to detect boundary rings",
+                    details={"ring_width": ring_width}
+                )
+
+            # Convert to Lab for perceptual comparison
+            inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)
+
+            # Get Lab values for each ring from the inpainted image
+            inner_lab = inpaint_lab[inner_ring]
+            outer_lab = inpaint_lab[outer_ring]
+
+            # Calculate statistics for each channel
+            inner_mean = np.mean(inner_lab, axis=0)
+            outer_mean = np.mean(outer_lab, axis=0)
+            inner_std = np.std(inner_lab, axis=0)
+            outer_std = np.std(outer_lab, axis=0)
+
+            # Calculate differences
+            mean_diff = np.abs(inner_mean - outer_mean)
+            std_diff = np.abs(inner_std - outer_std)
+
+            # Calculate Delta E (simplified)
+            delta_e = np.sqrt(np.sum(mean_diff ** 2))
+
+            # Score calculation
+            # Low Delta E = good continuity
+            # Target: Delta E < 10 is excellent, < 20 is good
+            if delta_e < 5:
+                continuity_score = 100
+            elif delta_e < 10:
+                continuity_score = 90
+            elif delta_e < 20:
+                continuity_score = 75
+            elif delta_e < 30:
+                continuity_score = 60
+            elif delta_e < 50:
+                continuity_score = 40
+            else:
+                continuity_score = max(20, 100 - delta_e)
+
+            # Penalize for large std differences (inconsistent textures)
+            std_penalty = min(20, np.mean(std_diff) * 0.5)
+            final_score = max(0, continuity_score - std_penalty)
+
+            passed = final_score >= 60
+            issue = ""
+
+            if final_score < 60:
+                if delta_e > 30:
+                    issue = f"Visible color discontinuity at boundary (Delta E: {delta_e:.1f})"
+                elif np.mean(std_diff) > 20:
+                    issue = "Texture mismatch at boundary"
+                else:
+                    issue = "Poor edge blending"
+
+            return QualityResult(
+                score=final_score,
+                passed=passed,
+                issue=issue,
+                details={
+                    "delta_e": delta_e,
+                    "mean_diff_l": mean_diff[0],
+                    "mean_diff_a": mean_diff[1],
+                    "mean_diff_b": mean_diff[2],
+                    "std_diff_avg": np.mean(std_diff),
+                    "inner_pixels": np.count_nonzero(inner_ring),
+                    "outer_pixels": np.count_nonzero(outer_ring)
+                }
+            )
+
+        except Exception as e:
+            logger.error(f"Inpainting edge continuity check failed: {e}")
+            return QualityResult(score=0, passed=False, issue=str(e), details={})
+
+    def check_inpainting_color_harmony(
+        self,
+        original: Image.Image,
+        inpainted: Image.Image,
+        mask: Image.Image
+    ) -> QualityResult:
+        """
+        Check color harmony between inpainted region and surrounding area.
+
+        Compares color statistics of the inpainted region with adjacent
+        non-inpainted regions to assess visual coherence.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        inpainted : PIL.Image
+            Inpainted result
+        mask : PIL.Image
+            Inpainting mask
+
+        Returns
+        -------
+        QualityResult
+            Color harmony assessment
+        """
+        try:
+            inpaint_array = np.array(inpainted.convert('RGB'))
+            mask_array = np.array(mask.convert('L'))
+
+            # Define regions
+            inpaint_region = mask_array > 127
+
+            # Get adjacent region (dilated mask minus original mask)
+            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
+            dilated = cv2.dilate(mask_array, kernel, iterations=2)
+            adjacent_region = (dilated > 127) & (mask_array <= 127)
+
+            if not np.any(inpaint_region) or not np.any(adjacent_region):
+                return QualityResult(
+                    score=50,
+                    passed=True,
+                    issue="Insufficient regions for comparison",
+                    details={}
+                )
+
+            # Convert to Lab
+            inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)
+
+            # Extract region colors
+            inpaint_colors = inpaint_lab[inpaint_region]
+            adjacent_colors = inpaint_lab[adjacent_region]
+
+            # Calculate color statistics
+            inpaint_mean = np.mean(inpaint_colors, axis=0)
+            adjacent_mean = np.mean(adjacent_colors, axis=0)
+
+            inpaint_std = np.std(inpaint_colors, axis=0)
+            adjacent_std = np.std(adjacent_colors, axis=0)
+
+            # Color histogram comparison
+            hist_scores = []
+            for i in range(3):  # L, a, b channels
+                hist_inpaint, _ = np.histogram(
+                    inpaint_colors[:, i], bins=32, range=(0, 255)
+                )
+                hist_adjacent, _ = np.histogram(
+                    adjacent_colors[:, i], bins=32, range=(0, 255)
+                )
+
+                # Normalize
+                hist_inpaint = hist_inpaint.astype(np.float32) / (np.sum(hist_inpaint) + 1e-6)
+                hist_adjacent = hist_adjacent.astype(np.float32) / (np.sum(hist_adjacent) + 1e-6)
+
+                # Bhattacharyya coefficient (1 = identical, 0 = completely different)
+                bc = np.sum(np.sqrt(hist_inpaint * hist_adjacent))
+                hist_scores.append(bc)
+
+            avg_hist_score = np.mean(hist_scores)
+
+            # Calculate harmony score
+            mean_diff = np.linalg.norm(inpaint_mean - adjacent_mean)
+
+            if mean_diff < 10 and avg_hist_score > 0.8:
+                harmony_score = 100
+            elif mean_diff < 20 and avg_hist_score > 0.7:
+                harmony_score = 85
+            elif mean_diff < 30 and avg_hist_score > 0.6:
+                harmony_score = 70
+            elif mean_diff < 50:
+                harmony_score = 55
+            else:
+                harmony_score = max(30, 100 - mean_diff)
+
+            # Boost score if histogram similarity is high
+            histogram_bonus = (avg_hist_score - 0.5) * 20  # -10 to +10
+            final_score = max(0, min(100, harmony_score + histogram_bonus))
+
+            passed = final_score >= 60
+            issue = ""
+
+            if final_score < 60:
+                if mean_diff > 40:
+                    issue = "Significant color mismatch with surrounding area"
+                elif avg_hist_score < 0.5:
+                    issue = "Color distribution differs from context"
+                else:
+                    issue = "Poor color integration"
+
+            return QualityResult(
+                score=final_score,
+                passed=passed,
+                issue=issue,
+                details={
+                    "mean_color_diff": mean_diff,
+                    "histogram_similarity": avg_hist_score,
+                    "inpaint_luminance": inpaint_mean[0],
+                    "adjacent_luminance": adjacent_mean[0]
+                }
+            )
+
+        except Exception as e:
+            logger.error(f"Inpainting color harmony check failed: {e}")
+            return QualityResult(score=0, passed=False, issue=str(e), details={})
+
+    def check_inpainting_artifact_detection(
+        self,
+        inpainted: Image.Image,
+        mask: Image.Image
+    ) -> QualityResult:
+        """
+        Detect common inpainting artifacts like blurriness or color bleeding.
+
+        Parameters
+        ----------
+        inpainted : PIL.Image
+            Inpainted result
+        mask : PIL.Image
+            Inpainting mask
+
+        Returns
+        -------
+        QualityResult
+            Artifact detection results
+        """
+        try:
+            inpaint_array = np.array(inpainted.convert('RGB'))
+            mask_array = np.array(mask.convert('L'))
+
+            inpaint_region = mask_array > 127
+
+            if not np.any(inpaint_region):
+                return QualityResult(
+                    score=50,
+                    passed=True,
+                    issue="No inpainted region detected",
+                    details={}
+                )
+
+            # Extract inpainted region pixels
+            gray = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2GRAY)
+
+            # Calculate sharpness (Laplacian variance)
+            laplacian = cv2.Laplacian(gray, cv2.CV_64F)
+            inpaint_laplacian = laplacian[inpaint_region]
+            sharpness = np.var(inpaint_laplacian)
+
+            # Get surrounding region for comparison
+            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
+            dilated = cv2.dilate(mask_array, kernel, iterations=1)
+            surrounding = (dilated > 127) & (mask_array <= 127)
+
+            if np.any(surrounding):
+                surrounding_laplacian = laplacian[surrounding]
+                surrounding_sharpness = np.var(surrounding_laplacian)
+                sharpness_ratio = sharpness / (surrounding_sharpness + 1e-6)
+            else:
+                sharpness_ratio = 1.0
+
+            # Check for color bleeding (abnormal saturation at edges)
+            hsv = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2HSV)
+            saturation = hsv[:, :, 1]
+
+            # Find boundary pixels
+            boundary_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+            boundary = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, boundary_kernel) > 0
+
+            if np.any(boundary):
+                boundary_saturation = saturation[boundary]
+                saturation_std = np.std(boundary_saturation)
+            else:
+                saturation_std = 0
+
+            # Calculate score
+            sharpness_score = 100
+            if sharpness_ratio < 0.3:
+                sharpness_score = 40  # Much blurrier than surroundings
+            elif sharpness_ratio < 0.6:
+                sharpness_score = 60
+            elif sharpness_ratio < 0.8:
+                sharpness_score = 80
+
+            bleeding_penalty = min(20, saturation_std * 0.5)
+
+            final_score = max(0, sharpness_score - bleeding_penalty)
+            passed = final_score >= 60
+
+            issue = ""
+            if sharpness_ratio < 0.5:
+                issue = "Inpainted region appears blurry"
+            elif saturation_std > 40:
+                issue = "Possible color bleeding at edges"
+            elif final_score < 60:
+                issue = "Detected visual artifacts"
+
+            return QualityResult(
+                score=final_score,
+                passed=passed,
+                issue=issue,
+                details={
+                    "sharpness": sharpness,
+                    "sharpness_ratio": sharpness_ratio,
+                    "boundary_saturation_std": saturation_std
+                }
+            )
+
+        except Exception as e:
+            logger.error(f"Inpainting artifact detection failed: {e}")
+            return QualityResult(score=0, passed=False, issue=str(e), details={})
+
+    def run_inpainting_checks(
+        self,
+        original: Image.Image,
+        inpainted: Image.Image,
+        mask: Image.Image
+    ) -> Dict[str, Any]:
+        """
+        Run all inpainting-specific quality checks.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image before inpainting
+        inpainted : PIL.Image
+            Result after inpainting
+        mask : PIL.Image
+            Inpainting mask
+
+        Returns
+        -------
+        dict
+            Comprehensive quality assessment for inpainting
+        """
+        logger.info("Running inpainting quality checks...")
+
+        results = {
+            "checks": {},
+            "overall_score": 0,
+            "passed": True,
+            "warnings": [],
+            "errors": []
+        }
+
+        # Run inpainting-specific checks
+        edge_result = self.check_inpainting_edge_continuity(original, inpainted, mask)
+        results["checks"]["edge_continuity"] = {
+            "score": edge_result.score,
+            "passed": edge_result.passed,
+            "issue": edge_result.issue,
+            "details": edge_result.details
+        }
+
+        harmony_result = self.check_inpainting_color_harmony(original, inpainted, mask)
+        results["checks"]["color_harmony"] = {
+            "score": harmony_result.score,
+            "passed": harmony_result.passed,
+            "issue": harmony_result.issue,
+            "details": harmony_result.details
+        }
+
+        artifact_result = self.check_inpainting_artifact_detection(inpainted, mask)
+        results["checks"]["artifact_detection"] = {
+            "score": artifact_result.score,
+            "passed": artifact_result.passed,
+            "issue": artifact_result.issue,
+            "details": artifact_result.details
+        }
+
+        # Calculate overall score (weighted)
+        weights = {
+            "edge_continuity": 0.4,
+            "color_harmony": 0.35,
+            "artifact_detection": 0.25
+        }
+
+        total_score = (
+            edge_result.score * weights["edge_continuity"] +
+            harmony_result.score * weights["color_harmony"] +
+            artifact_result.score * weights["artifact_detection"]
+        )
+        results["overall_score"] = round(total_score, 1)
+
+        # Determine overall pass/fail
+        results["passed"] = all([
+            edge_result.passed,
+            harmony_result.passed,
+            artifact_result.passed
+        ])
+
+        # Collect issues
+        for check_name, check_data in results["checks"].items():
+            if check_data["issue"]:
+                if check_data["passed"]:
+                    results["warnings"].append(f"{check_name}: {check_data['issue']}")
+                else:
+                    results["errors"].append(f"{check_name}: {check_data['issue']}")
+
+        logger.info(f"Inpainting quality: {results['overall_score']:.1f}, Passed: {results['passed']}")
+
+        return results
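
A quick way to exercise these checks end-to-end is with synthetic images. The sketch below is only a wiring test, not a quality benchmark (flat-color images and sizes are made up). For reference, the histogram comparison inside uses the Bhattacharyya coefficient BC = Σᵢ √(pᵢ·qᵢ) over the normalized per-channel histograms p and q, so 1.0 means identical distributions:

from PIL import Image, ImageDraw
from quality_checker import QualityChecker

# Near-identical flat fills keep the edge and harmony checks happy; note the
# sharpness comparison is degenerate on flat images, so treat the artifact
# score here purely as wiring output.
original = Image.new("RGB", (256, 256), (120, 120, 120))
inpainted = Image.new("RGB", (256, 256), (118, 122, 121))

# White square marks the inpainted region, as the checks expect.
mask = Image.new("L", (256, 256), 0)
ImageDraw.Draw(mask).rectangle([64, 64, 192, 192], fill=255)

checker = QualityChecker()
report = checker.run_inpainting_checks(original, inpainted, mask)

print(report["overall_score"])              # weighted 0.40 / 0.35 / 0.25 blend
print(report["checks"]["edge_continuity"])  # per-check score, issue, details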
scene_templates.py
CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 
 logger = logging.getLogger(__name__)
 
-
 @dataclass
 class SceneTemplate:
     """Data class representing a scene template"""
@@ -25,7 +24,7 @@ class SceneTemplateManager:
 
     # Scene template definitions
     TEMPLATES: Dict[str, SceneTemplate] = {
-        # …
+        # Professional Category
         "office_modern": SceneTemplate(
             key="office_modern",
             name="Modern Office",
@@ -72,7 +71,7 @@ class SceneTemplateManager:
             guidance_scale=7.5
         ),
 
-        # …
+        # Nature Category
         "beach_sunset": SceneTemplate(
             key="beach_sunset",
             name="Sunset Beach",
@@ -128,7 +127,7 @@ class SceneTemplateManager:
             guidance_scale=7.0
         ),
 
-        # …
+        # Urban Category
         "city_skyline": SceneTemplate(
             key="city_skyline",
             name="City Skyline",
@@ -175,7 +174,7 @@ class SceneTemplateManager:
             guidance_scale=7.5
         ),
 
-        # …
+        # Artistic Category
         "gradient_soft": SceneTemplate(
             key="gradient_soft",
             name="Soft Gradient",
@@ -213,7 +212,7 @@ class SceneTemplateManager:
             guidance_scale=6.5
         ),
 
-        # …
+        # Seasonal Category
        "autumn_foliage": SceneTemplate(
            key="autumn_foliage",
            name="Autumn Foliage",
scene_weaver_core.py
CHANGED
@@ -5,25 +5,47 @@ from PIL import Image
 import logging
 import gc
 import time
-from typing import Optional, Dict, Any, Tuple, List
+from typing import Optional, Dict, Any, Tuple, List, Callable
 from pathlib import Path
 import warnings
 warnings.filterwarnings("ignore")

 from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
 import open_clip
+import traceback
 from mask_generator import MaskGenerator
 from image_blender import ImageBlender
 from quality_checker import QualityChecker
+from model_manager import get_model_manager, ModelPriority
+from inpainting_module import InpaintingModule
+from inpainting_templates import InpaintingTemplateManager

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

 class SceneWeaverCore:
     """
-    SceneWeaver
+    SceneWeaver Core Engine - Facade for all AI generation subsystems.
+
+    Integrates SDXL pipeline, OpenCLIP analysis, mask generation, image blending,
+    and inpainting functionality into a unified interface.
+
+    Attributes:
+        device: Computation device (cuda/mps/cpu)
+        is_initialized: Whether models are loaded
+        inpainting_module: Optional InpaintingModule instance
+
+    Example:
+        >>> core = SceneWeaverCore()
+        >>> core.load_models()
+        >>> result = core.generate_and_combine(image, prompt="sunset beach")
     """

+    # Model registry names
+    MODEL_SDXL_PIPELINE = "sdxl_background_pipeline"
+    MODEL_OPENCLIP = "openclip_analyzer"
+    MODEL_INPAINTING_PIPELINE = "inpainting_pipeline"
+
     # Style presets for diversity generation mode
     STYLE_PRESETS = {
         "professional": {

@@ -82,7 +104,17 @@ class SceneWeaverCore:
         self.image_blender = ImageBlender()
         self.quality_checker = QualityChecker()

-
+        # Model manager reference
+        self._model_manager = get_model_manager()
+
+        # Inpainting module (lazy loaded)
+        self._inpainting_module = None
+        self._inpainting_initialized = False
+
+        # Current mode tracking
+        self._current_mode = "background"  # "background" or "inpainting"
+
+        logger.info(f"SceneWeaverCore initialized on {self.device}")

     def _setup_device(self, device: str) -> str:
         """Setup computation device"""

@@ -277,7 +309,7 @@ class SceneWeaverCore:
         # Analyze image characteristics
         img_array = np.array(foreground_image.convert('RGB'))

-        #
+        # Analyze color temperature
         # Convert to LAB to analyze color temperature
         lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
         avg_a = np.mean(lab[:, :, 1])  # a channel: green(-) to red(+)

@@ -286,12 +318,12 @@ class SceneWeaverCore:
         # Determine warm/cool tone
         is_warm = avg_b > 128  # b > 128 means more yellow/warm

-        #
+        # Analyze brightness
         gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
         avg_brightness = np.mean(gray)
         is_bright = avg_brightness > 127

-        #
+        # Get subject type from CLIP
         clip_analysis = self.analyze_image_with_clip(foreground_image)
         subject_type = "unknown"

@@ -306,7 +338,7 @@ class SceneWeaverCore:
         elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
             subject_type = "nature"

-        #
+        # Build prompt fragments library
         lighting_options = {
             "warm_bright": "warm golden hour lighting, soft natural light",
             "warm_dark": "warm ambient lighting, cozy atmosphere",

@@ -325,7 +357,7 @@ class SceneWeaverCore:

         quality_modifiers = "high quality, detailed, sharp focus, photorealistic"

-        #
+        # Select appropriate fragments
         # Lighting based on color temperature and brightness
         if is_warm and is_bright:
             lighting = lighting_options["warm_bright"]

@@ -339,7 +371,7 @@ class SceneWeaverCore:
         # Atmosphere based on subject type
         atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])

-        #
+        # Check for conflicts in user prompt
         user_prompt_lower = user_prompt.lower()

         # Avoid adding conflicting descriptions

@@ -348,7 +380,7 @@ class SceneWeaverCore:
         if "dark" in user_prompt_lower or "night" in user_prompt_lower:
             lighting = lighting.replace("bright", "").replace("daylight", "")

-        #
+        # Combine enhanced prompt
         fragments = [user_prompt]

         if lighting:
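The prompt-enhancement hunks above boil down to two cheap image statistics plus a CLIP label. A standalone sketch of just the warm/bright tests, reusing the thresholds from the diff (128 on the LAB b-channel, 127 on mean grayscale); the function name is illustrative:

    # Standalone sketch of the tone analysis used by the enhanced-prompt builder.
    import numpy as np
    import cv2
    from PIL import Image

    def analyze_tone(foreground_image: Image.Image):
        img_array = np.array(foreground_image.convert('RGB'))

        # LAB b-channel above 128 leans yellow, i.e. a warm image
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        is_warm = bool(np.mean(lab[:, :, 2]) > 128)

        # Mean grayscale above the midpoint counts as a bright image
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        is_bright = bool(np.mean(gray) > 127)

        return is_warm, is_bright

    # (is_warm, is_bright) then selects a fragment such as "warm_bright"
    # from the lighting_options dict shown in the hunks above.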
@@ -592,7 +624,6 @@ class SceneWeaverCore:
         }

     except Exception as e:
-        import traceback
         error_traceback = traceback.format_exc()
         logger.error(f"❌ Generation and combination failed: {str(e)}")
         logger.error(f"📍 Full traceback:\n{error_traceback}")

@@ -806,3 +837,341 @@ class SceneWeaverCore:
         })

         return status
+
+    # INPAINTING FACADE METHODS
+    def get_inpainting_module(self):
+        """
+        Get or create the InpaintingModule instance.
+
+        Implements lazy loading - module is only created when first accessed.
+
+        Returns
+        -------
+        InpaintingModule
+            The inpainting module instance
+        """
+        if self._inpainting_module is None:
+            self._inpainting_module = InpaintingModule(device=self.device)
+            self._inpainting_module.set_model_manager(self._model_manager)
+            logger.info("InpaintingModule created (lazy load)")
+
+        return self._inpainting_module
+
+    def switch_to_inpainting_mode(
+        self,
+        conditioning_type: str = "canny",
+        progress_callback: Optional[Callable[[str, int], None]] = None
+    ) -> bool:
+        """
+        Switch to inpainting mode, unloading background pipeline.
+
+        Implements mutual exclusion between pipelines to conserve memory.
+
+        Parameters
+        ----------
+        conditioning_type : str
+            ControlNet conditioning type: "canny" or "depth"
+        progress_callback : callable, optional
+            Progress update function(message, percentage)
+
+        Returns
+        -------
+        bool
+            True if switch was successful
+        """
+        logger.info(f"Switching to inpainting mode (conditioning: {conditioning_type})")
+
+        try:
+            # Unload background pipeline first
+            if self.pipeline is not None:
+                if progress_callback:
+                    progress_callback("Unloading background pipeline...", 10)
+
+                del self.pipeline
+                self.pipeline = None
+                self._ultra_memory_cleanup()
+                logger.info("Background pipeline unloaded")
+
+            # Load inpainting pipeline
+            if progress_callback:
+                progress_callback("Loading inpainting pipeline...", 20)
+
+            inpaint_module = self.get_inpainting_module()
+
+            def inpaint_progress(msg, pct):
+                if progress_callback:
+                    # Map inpainting progress (0-100) to (20-90)
+                    mapped_pct = 20 + int(pct * 0.7)
+                    progress_callback(msg, mapped_pct)
+
+            success, error_msg = inpaint_module.load_inpainting_pipeline(
+                conditioning_type=conditioning_type,
+                progress_callback=inpaint_progress
+            )
+
+            if success:
+                self._current_mode = "inpainting"
+                self._inpainting_initialized = True
+
+                if progress_callback:
+                    progress_callback("Inpainting mode ready!", 100)
+
+                logger.info("Successfully switched to inpainting mode")
+            else:
+                self._last_inpainting_error = error_msg
+                logger.error(f"Failed to load inpainting pipeline: {error_msg}")
+
+            return success
+
+        except Exception as e:
+            traceback.print_exc()
+            self._last_inpainting_error = str(e)
+            logger.error(f"Failed to switch to inpainting mode: {e}")
+            if progress_callback:
+                progress_callback(f"Error: {str(e)}", 0)
+            return False
+
+    def switch_to_background_mode(
+        self,
+        progress_callback: Optional[Callable[[str, int], None]] = None
+    ) -> bool:
+        """
+        Switch back to background generation mode.
+
+        Parameters
+        ----------
+        progress_callback : callable, optional
+            Progress update function
+
+        Returns
+        -------
+        bool
+            True if switch was successful
+        """
+        logger.info("Switching to background generation mode")
+
+        try:
+            # Unload inpainting pipeline
+            if self._inpainting_module is not None and self._inpainting_module.is_initialized:
+                if progress_callback:
+                    progress_callback("Unloading inpainting pipeline...", 10)
+
+                self._inpainting_module._unload_pipeline()
+                self._ultra_memory_cleanup()
+
+            # Reload background pipeline
+            if progress_callback:
+                progress_callback("Loading background pipeline...", 30)
+
+            # Reset initialization flag to force reload
+            self.is_initialized = False
+            self.load_models(progress_callback=progress_callback)
+
+            self._current_mode = "background"
+
+            if progress_callback:
+                progress_callback("Background mode ready!", 100)
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to switch to background mode: {e}")
+            return False
+
+    def execute_inpainting(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        prompt: str,
+        preview_only: bool = False,
+        template_key: Optional[str] = None,
+        progress_callback: Optional[Callable[[str, int], None]] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Execute inpainting operation through the Facade.
+
+        This is the main entry point for inpainting functionality.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image to inpaint
+        mask : PIL.Image
+            Inpainting mask (white = area to regenerate)
+        prompt : str
+            Text description of desired content
+        preview_only : bool
+            If True, generate quick preview only
+        template_key : str, optional
+            Inpainting template key to use
+        progress_callback : callable, optional
+            Progress update function
+        **kwargs
+            Additional inpainting parameters
+
+        Returns
+        -------
+        dict
+            Result dictionary with images and metadata
+        """
+        # Ensure inpainting mode is active
+        if self._current_mode != "inpainting" or not self._inpainting_initialized:
+            conditioning = kwargs.get('conditioning_type', 'canny')
+            if not self.switch_to_inpainting_mode(conditioning, progress_callback):
+                error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
+                return {
+                    "success": False,
+                    "error": f"Failed to initialize inpainting mode: {error_detail}"
+                }
+
+        inpaint_module = self.get_inpainting_module()
+
+        # Apply template if specified
+        if template_key:
+            template_mgr = InpaintingTemplateManager()
+            template = template_mgr.get_template(template_key)
+
+            if template:
+                # Build prompt from template
+                prompt = template_mgr.build_prompt(template_key, prompt)
+                # Apply template parameters as defaults
+                params = template_mgr.get_parameters_for_template(template_key)
+                for key, value in params.items():
+                    if key not in kwargs:
+                        kwargs[key] = value
+
+        # Execute inpainting
+        result = inpaint_module.execute_inpainting(
+            image=image,
+            mask=mask,
+            prompt=prompt,
+            preview_only=preview_only,
+            progress_callback=progress_callback,
+            **kwargs
+        )
+
+        # Convert InpaintingResult to dictionary format
+        return {
+            "success": result.success,
+            "combined_image": result.blended_image or result.result_image,
+            "generated_image": result.result_image,
+            "preview_image": result.preview_image,
+            "control_image": result.control_image,
+            "original_image": image,
+            "mask": mask,
+            "quality_score": result.quality_score,
+            "generation_time": result.generation_time,
+            "metadata": result.metadata,
+            "error": result.error_message if not result.success else None
+        }
+
+    def execute_inpainting_with_optimization(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        prompt: str,
+        progress_callback: Optional[Callable[[str, int], None]] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Execute inpainting with automatic quality optimization.
+
+        Retries with adjusted parameters if quality is below threshold.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Inpainting mask
+        prompt : str
+            Text prompt
+        progress_callback : callable, optional
+            Progress callback
+        **kwargs
+            Additional parameters
+
+        Returns
+        -------
+        dict
+            Optimized result dictionary
+        """
+        # Ensure inpainting mode
+        if self._current_mode != "inpainting" or not self._inpainting_initialized:
+            conditioning = kwargs.get('conditioning_type', 'canny')
+            if not self.switch_to_inpainting_mode(conditioning, progress_callback):
+                error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
+                return {
+                    "success": False,
+                    "error": f"Failed to initialize inpainting mode: {error_detail}"
+                }
+
+        inpaint_module = self.get_inpainting_module()
+
+        result = inpaint_module.execute_with_auto_optimization(
+            image=image,
+            mask=mask,
+            prompt=prompt,
+            quality_checker=self.quality_checker,
+            progress_callback=progress_callback,
+            **kwargs
+        )
+
+        return {
+            "success": result.success,
+            "combined_image": result.blended_image or result.result_image,
+            "generated_image": result.result_image,
+            "preview_image": result.preview_image,
+            "control_image": result.control_image,
+            "quality_score": result.quality_score,
+            "quality_details": result.quality_details,
+            "retries": result.retries,
+            "generation_time": result.generation_time,
+            "metadata": result.metadata,
+            "error": result.error_message if not result.success else None
+        }
+
+    def get_current_mode(self) -> str:
+        """
+        Get current operation mode.
+
+        Returns
+        -------
+        str
+            "background" or "inpainting"
+        """
+        return self._current_mode
+
+    def is_inpainting_ready(self) -> bool:
+        """
+        Check if inpainting is ready to use.
+
+        Returns
+        -------
+        bool
+            True if inpainting module is loaded and ready
+        """
+        return (
+            self._inpainting_module is not None and
+            self._inpainting_module.is_initialized
+        )
+
+    def get_inpainting_status(self) -> Dict[str, Any]:
+        """
+        Get inpainting module status.
+
+        Returns
+        -------
+        dict
+            Status information
+        """
+        if self._inpainting_module is None:
+            return {
+                "initialized": False,
+                "mode": self._current_mode
+            }
+
+        status = self._inpainting_module.get_status()
+        status["mode"] = self._current_mode
+        return status
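Taken together, the new methods give callers a single mode-switching surface over both pipelines. A usage sketch; the method names come from the diff above, but the flow and file names are illustrative, not a tested recipe:

    from PIL import Image
    from scene_weaver_core import SceneWeaverCore

    core = SceneWeaverCore()

    # Switching modes unloads the other pipeline first (mutual exclusion),
    # so only one SDXL-sized pipeline is resident at a time.
    ok = core.switch_to_inpainting_mode(
        conditioning_type="canny",
        progress_callback=lambda msg, pct: print(f"{pct:3d}% {msg}"),
    )

    if ok:
        image = Image.open("photo.png")             # assumed input files
        mask = Image.open("mask.png").convert("L")  # white = area to regenerate
        result = core.execute_inpainting(image=image, mask=mask,
                                         prompt="red brick wall")
        if result["success"]:
            result["combined_image"].save("inpainted.png")

    # Returning to background mode reloads the SDXL background pipeline.
    core.switch_to_background_mode()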
ui_manager.py
CHANGED
@@ -1,7 +1,8 @@
 import logging
 import time
+import traceback
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict, Any, List
 from PIL import Image
 import numpy as np
 import cv2

@@ -11,6 +12,7 @@ import spaces
 from scene_weaver_core import SceneWeaverCore
 from css_styles import CSSStyles
 from scene_templates import SceneTemplateManager
+from inpainting_templates import InpaintingTemplateManager

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

@@ -23,13 +25,26 @@ logging.basicConfig(


 class UIManager:
-    """
+    """
+    Gradio UI Manager with support for background generation and inpainting.
+
+    Provides a professional interface with mode switching, template selection,
+    and advanced parameter controls.
+
+    Attributes:
+        sceneweaver: SceneWeaverCore instance
+        template_manager: Scene template manager
+        inpainting_template_manager: Inpainting template manager
+    """

     def __init__(self):
         self.sceneweaver = SceneWeaverCore()
         self.template_manager = SceneTemplateManager()
+        self.inpainting_template_manager = InpaintingTemplateManager()
         self.generation_history = []
+        self.inpainting_history = []
         self._preview_sensitivity = 0.5
+        self._current_mode = "background"  # "background" or "inpainting"

     def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
         """

@@ -234,7 +249,6 @@ class UIManager:
         return None, None, None, f"Error: {error_msg}", gr.update(visible=False)

     except Exception as e:
-        import traceback
         error_traceback = traceback.format_exc()
         logger.error(f"Generation handler error: {str(e)}")
         logger.error(f"Traceback:\n{error_traceback}")

@@ -249,18 +263,12 @@ class UIManager:
         self._gradio_version = gr.__version__
         self._gradio_major = int(self._gradio_version.split('.')[0])

-        # Gradio 5.x
-        ...
-        else:
-            blocks_kwargs = {
-                "css": self._css,
-                "title": "SceneWeaver - AI Background Generator",
-                "theme": gr.themes.Soft()
-            }
+        # Both Gradio 4.x and 5.x use css/theme in Blocks() constructor
+        blocks_kwargs = {
+            "css": self._css,
+            "title": "SceneWeaver - AI Background Generator",
+            "theme": gr.themes.Soft()
+        }

         with gr.Blocks(**blocks_kwargs) as interface:

@@ -271,177 +279,232 @@ class UIManager:
                     <span class="title-emoji">🎨</span>
                     SceneWeaver
                 </h1>
-                <p class="main-subtitle">AI-powered background generation with professional edge processing</p>
+                <p class="main-subtitle">AI-powered background generation and inpainting with professional edge processing</p>
                 </div>
             """)

-            with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
-                gr.HTML("""
-                    <div class="card-content">
-                        <h3 class="card-title">
-                            <span class="section-emoji">📸</span>
-                            Upload & Generate
-                        </h3>
-                    </div>
-                """)
-                ...
-                # Quick start guide
-                gr.HTML("""
-                    <details class="user-guidance-panel">
-                        <summary class="guidance-summary">
-                            <span class="emoji-enhanced">💡</span>
-                            Quick Start Guide
-                        </summary>
-                        <div class="guidance-content">
-                            <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
-                            <p><strong>Step 2:</strong> Describe or Choose your desired background scene</p>
-                            <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
-                            <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
-                            <p><strong>Tip:</strong> For dark clothing, ensure good lighting in original photo.</p>
-                        </div>
-                    </details>
-                """)
-                ...
-                with gr.TabItem("Background"):
-                    generated_output = gr.Image(
-                        label="Generated Background",
-                        elem_classes=["result-gallery"],
-                        show_label=False
-                    )
-                with gr.TabItem("Original"):
-                    original_output = gr.Image(
-                        label="Processed Original",
-                        elem_classes=["result-gallery"],
-                        show_label=False
-                    )
-                ...
-            memory_btn = gr.Button(
-                "Clean Memory",
-                elem_classes=["secondary-button"]
-            )
+            # Main Tabs for Mode Selection
+            with gr.Tabs(elem_id="main-mode-tabs") as main_tabs:
+
+                # Background Generation Tab
+                with gr.Tab("Background Generation", elem_id="bg-gen-tab"):
+
+                    with gr.Row():
+                        # Left Column - Input controls
+                        with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
+                            gr.HTML("""
+                                <div class="card-content">
+                                    <h3 class="card-title">
+                                        <span class="section-emoji">📸</span>
+                                        Upload & Generate
+                                    </h3>
+                                </div>
+                            """)
+
+                            uploaded_image = gr.Image(
+                                label="Upload Your Image",
+                                type="pil",
+                                height=280,
+                                elem_classes=["input-field"]
+                            )
+
+                            # Scene Template Selector (without Accordion to fix dropdown positioning in Gradio 5.x)
+                            template_dropdown = gr.Dropdown(
+                                label="Scene Templates",
+                                choices=[""] + self.template_manager.get_template_choices_sorted(),
+                                value="",
+                                info="24 curated scenes sorted A-Z (optional)",
+                                elem_classes=["template-dropdown"]
+                            )
+
+                            prompt_input = gr.Textbox(
+                                label="Background Scene Description",
+                                placeholder="Select a template above or describe your own scene...",
+                                lines=3,
+                                elem_classes=["input-field"]
+                            )
+
+                            combination_mode = gr.Dropdown(
+                                label="Composition Mode",
+                                choices=["center", "left_half", "right_half", "full"],
+                                value="center",
+                                info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
+                                elem_classes=["input-field"]
+                            )
+
+                            focus_mode = gr.Dropdown(
+                                label="Focus Mode",
+                                choices=["person", "scene"],
+                                value="person",
+                                info="person=Tight Crop | scene=Include Surrounding Objects",
+                                elem_classes=["input-field"]
+                            )
+
+                            with gr.Accordion("Advanced Options", open=False):
+                                negative_prompt = gr.Textbox(
+                                    label="Negative Prompt",
+                                    value="blurry, low quality, distorted, people, characters",
+                                    lines=2,
+                                    elem_classes=["input-field"]
+                                )
+
+                                steps_slider = gr.Slider(
+                                    label="Quality Steps",
+                                    minimum=15,
+                                    maximum=50,
+                                    value=25,
+                                    step=5,
+                                    elem_classes=["input-field"]
+                                )
+
+                                guidance_slider = gr.Slider(
+                                    label="Guidance Scale",
+                                    minimum=5.0,
+                                    maximum=15.0,
+                                    value=7.5,
+                                    step=0.5,
+                                    elem_classes=["input-field"]
+                                )
+
+                            generate_btn = gr.Button(
+                                "Generate Background",
+                                variant="primary",
+                                size="lg",
+                                elem_classes=["primary-button"]
+                            )
+
+                        # Right Column - Results display
+                        with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
+                            gr.HTML("""
+                                <div class="card-content">
+                                    <h3 class="card-title">
+                                        <span class="section-emoji">🎭</span>
+                                        Results Gallery
+                                    </h3>
+                                </div>
+                            """)
+
+                            # Loading notice
+                            gr.HTML("""
+                                <div class="loading-notice">
+                                    <span class="loading-notice-icon">⏱️</span>
+                                    <span class="loading-notice-text">
+                                        <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
+                                        Subsequent generations are much faster (~30s).
+                                    </span>
+                                </div>
+                            """)
+
+                            # Quick start guide
+                            gr.HTML("""
+                                <details class="user-guidance-panel">
+                                    <summary class="guidance-summary">
+                                        <span class="emoji-enhanced">💡</span>
+                                        Quick Start Guide
+                                    </summary>
+                                    <div class="guidance-content">
+                                        <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
+                                        <p><strong>Step 2:</strong> Describe or Choose your desired background scene</p>
+                                        <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
+                                        <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
+                                        <p><strong>Tip:</strong> For dark clothing, ensure good lighting in original photo.</p>
+                                    </div>
+                                </details>
+                            """)
+
+                            with gr.Tabs():
+                                with gr.TabItem("Final Result"):
+                                    combined_output = gr.Image(
+                                        label="Your Generated Image",
+                                        elem_classes=["result-gallery"],
+                                        show_label=False
+                                    )
+                                with gr.TabItem("Background"):
+                                    generated_output = gr.Image(
+                                        label="Generated Background",
+                                        elem_classes=["result-gallery"],
+                                        show_label=False
+                                    )
+                                with gr.TabItem("Original"):
+                                    original_output = gr.Image(
+                                        label="Processed Original",
+                                        elem_classes=["result-gallery"],
+                                        show_label=False
+                                    )
+
+                            status_output = gr.Textbox(
+                                label="Status",
+                                value="Ready to create! Upload an image and describe your vision.",
+                                interactive=False,
+                                elem_classes=["status-panel", "status-ready"]
+                            )
+
+                            with gr.Row():
+                                download_btn = gr.DownloadButton(
+                                    "Download Result",
+                                    value=None,
+                                    visible=False,
+                                    elem_classes=["secondary-button"]
+                                )
+                                clear_btn = gr.Button(
+                                    "Clear All",
+                                    elem_classes=["secondary-button"]
+                                )
+                                memory_btn = gr.Button(
+                                    "Clean Memory",
+                                    elem_classes=["secondary-button"]
+                                )
+
+                    # Event handlers for Background Generation Tab
+                    # Template selection handler
+                    template_dropdown.change(
+                        fn=self.apply_template,
+                        inputs=[template_dropdown, negative_prompt],
+                        outputs=[prompt_input, negative_prompt, guidance_slider]
+                    )
+
+                    generate_btn.click(
+                        fn=self.generate_handler,
+                        inputs=[
+                            uploaded_image,
+                            prompt_input,
+                            combination_mode,
+                            focus_mode,
+                            negative_prompt,
+                            steps_slider,
+                            guidance_slider
+                        ],
+                        outputs=[
+                            combined_output,
+                            generated_output,
+                            original_output,
+                            status_output,
+                            download_btn
+                        ]
+                    )
+
+                    clear_btn.click(
+                        fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
+                        outputs=[combined_output, generated_output, original_output, status_output, download_btn]
+                    )
+
+                    memory_btn.click(
+                        fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
+                        outputs=[status_output]
+                    )
+
+                    combined_output.change(
+                        fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
+                        inputs=[combined_output],
+                        outputs=[download_btn]
+                    )
+
+                    # End of Background Generation Tab
+
+                # Inpainting Tab
+                self.create_inpainting_tab()

-        # Footer with tech credits
+            # Footer with tech credits (outside tabs)
             gr.HTML("""
                 <div class="app-footer">
                     <div class="footer-powered">

@@ -464,57 +527,13 @@ class UIManager:
                 </div>
             """)

-        # Event handlers
-        # Template selection handler
-        template_dropdown.change(
-            fn=self.apply_template,
-            inputs=[template_dropdown, negative_prompt],
-            outputs=[prompt_input, negative_prompt, guidance_slider]
-        )
-
-        generate_btn.click(
-            fn=self.generate_handler,
-            inputs=[
-                uploaded_image,
-                prompt_input,
-                combination_mode,
-                focus_mode,
-                negative_prompt,
-                steps_slider,
-                guidance_slider
-            ],
-            outputs=[
-                combined_output,
-                generated_output,
-                original_output,
-                status_output,
-                download_btn
-            ]
-        )
-
-        clear_btn.click(
-            fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
-            outputs=[combined_output, generated_output, original_output, status_output, download_btn]
-        )
-
-        memory_btn.click(
-            fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
-            outputs=[status_output]
-        )
-
-        combined_output.change(
-            fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
-            inputs=[combined_output],
-            outputs=[download_btn]
-        )
-
         return interface

     def launch(self, share: bool = True, debug: bool = False):
         """Launch the UI interface"""
         interface = self.create_interface()

-        #
+        # Launch kwargs compatible with both Gradio 4.x and 5.x
         launch_kwargs = {
             "share": share,
             "debug": debug,

@@ -522,10 +541,450 @@ class UIManager:
             "quiet": False
         }

-        ...
+        return interface.launch(**launch_kwargs)
+
+    # INPAINTING UI METHODS
+    def apply_inpainting_template(
+        self,
+        display_name: str,
+        current_prompt: str
+    ) -> Tuple[str, float, int, str]:
+        """
+        Apply an inpainting template to the UI fields.
+
+        Parameters
+        ----------
+        display_name : str
+            Template display name from dropdown
+        current_prompt : str
+            Current prompt content
+
+        Returns
+        -------
+        tuple
+            (prompt, conditioning_scale, feather_radius, conditioning_type)
+        """
+        if not display_name:
+            return current_prompt, 0.7, 8, "canny"
+
+        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
+        if not template_key:
+            return current_prompt, 0.7, 8, "canny"
+
+        template = self.inpainting_template_manager.get_template(template_key)
+        if template:
+            params = self.inpainting_template_manager.get_parameters_for_template(template_key)
+            return (
+                current_prompt,
+                params.get('controlnet_conditioning_scale', 0.7),
+                params.get('feather_radius', 8),
+                params.get('preferred_conditioning', 'canny')
+            )
+
+        return current_prompt, 0.7, 8, "canny"
+
+    def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
+        """
+        Extract mask from Gradio ImageEditor output.
+
+        Handles different Gradio versions' output formats.
+
+        Parameters
+        ----------
+        editor_output : dict
+            Output from gr.ImageEditor component
+
+        Returns
+        -------
+        PIL.Image or None
+            Extracted mask as grayscale image
+        """
+        if editor_output is None:
+            return None
+
+        try:
+            # Gradio 5.x format
+            if isinstance(editor_output, dict):
+                # Check for 'layers' key (Gradio 5.x ImageEditor)
+                if 'layers' in editor_output and editor_output['layers']:
+                    # Get the first layer as mask
+                    layer = editor_output['layers'][0]
+                    if isinstance(layer, np.ndarray):
+                        mask_array = layer
+                    elif isinstance(layer, Image.Image):
+                        mask_array = np.array(layer)
+                    else:
+                        return None
+
+                # Check for 'composite' key
+                elif 'composite' in editor_output:
+                    composite = editor_output['composite']
+                    if isinstance(composite, np.ndarray):
+                        mask_array = composite
+                    elif isinstance(composite, Image.Image):
+                        mask_array = np.array(composite)
+                    else:
+                        return None
+                else:
+                    return None
+
+            elif isinstance(editor_output, np.ndarray):
+                mask_array = editor_output
+            elif isinstance(editor_output, Image.Image):
+                mask_array = np.array(editor_output)
+            else:
+                logger.warning(f"Unexpected editor output type: {type(editor_output)}")
+                return None
+
+            # Convert to grayscale if needed
+            if len(mask_array.shape) == 3:
+                if mask_array.shape[2] == 4:
+                    # RGBA - use alpha channel
+                    mask_gray = mask_array[:, :, 3]
+                else:
+                    # RGB - convert to grayscale
+                    mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
+            else:
+                mask_gray = mask_array
+
+            return Image.fromarray(mask_gray.astype(np.uint8), mode='L')
+
+        except Exception as e:
+            logger.error(f"Failed to extract mask from editor: {e}")
+            return None
+
+    def inpainting_handler(
+        self,
+        image: Optional[Image.Image],
+        mask_editor: Dict[str, Any],
+        prompt: str,
+        template_dropdown: str,
+        conditioning_type: str,
+        conditioning_scale: float,
+        feather_radius: int,
+        guidance_scale: float,
+        num_steps: int,
+        progress: gr.Progress = gr.Progress()
+    ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
+        """
+        Handle inpainting generation request.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image to inpaint
+        mask_editor : dict
+            Mask editor output
+        prompt : str
+            Text description of desired content
+        template_dropdown : str
+            Selected template (optional)
+        conditioning_type : str
+            ControlNet conditioning type
+        conditioning_scale : float
+            ControlNet influence strength
+        feather_radius : int
+            Mask feathering radius
+        guidance_scale : float
+            Guidance scale for generation
+        num_steps : int
+            Number of inference steps
+        progress : gr.Progress
+            Progress callback
+
+        Returns
+        -------
+        tuple
+            (result_image, preview_image, control_image, status_message)
+        """
+        if image is None:
+            return None, None, None, "Please upload an image first"
+
+        # Extract mask
+        mask = self.extract_mask_from_editor(mask_editor)
+        if mask is None:
+            return None, None, None, "Please draw a mask on the image"
+
+        # Validate mask
+        mask_array = np.array(mask)
+        coverage = np.count_nonzero(mask_array > 127) / mask_array.size
+        if coverage < 0.01:
+            return None, None, None, "Mask too small - please select a larger area"
+        if coverage > 0.95:
+            return None, None, None, "Mask too large - consider using background generation instead"
+
+        def progress_callback(msg: str, pct: int):
+            progress(pct / 100, desc=msg)
+
+        try:
+            start_time = time.time()
+
+            # Get template key if selected
+            template_key = None
+            if template_dropdown:
+                template_key = self.inpainting_template_manager.get_template_key_from_display(
+                    template_dropdown
+                )
+
+            # Execute inpainting through SceneWeaverCore facade
+            result = self.sceneweaver.execute_inpainting(
+                image=image,
+                mask=mask,
+                prompt=prompt,
+                preview_only=False,
+                template_key=template_key,
+                conditioning_type=conditioning_type,
+                controlnet_conditioning_scale=conditioning_scale,
+                feather_radius=feather_radius,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                progress_callback=progress_callback
+            )
+
+            elapsed = time.time() - start_time
+
+            if result.get('success'):
+                # Store in history
+                self.inpainting_history.append({
+                    'result': result.get('combined_image'),
+                    'prompt': prompt,
+                    'time': elapsed
+                })
+                if len(self.inpainting_history) > 3:
+                    self.inpainting_history.pop(0)
+
+                quality_score = result.get('quality_score', 0)
+                status = f"Inpainting complete in {elapsed:.1f}s"
+                if quality_score > 0:
+                    status += f" | Quality: {quality_score:.0f}/100"
+
+                return (
+                    result.get('combined_image'),
+                    result.get('preview_image'),
+                    result.get('control_image'),
+                    status
+                )
+            else:
+                error_msg = result.get('error', 'Unknown error')
+                return None, None, None, f"Inpainting failed: {error_msg}"
+
+        except Exception as e:
+            logger.error(f"Inpainting handler error: {e}")
+            logger.error(traceback.format_exc())
+            return None, None, None, f"Error: {str(e)}"
+
+    def create_inpainting_tab(self) -> gr.Tab:
+        """
+        Create the inpainting tab UI.
+
+        Returns
+        -------
+        gr.Tab
+            Configured inpainting tab component
+        """
+        with gr.Tab("Inpainting", elem_id="inpainting-tab") as tab:
+            gr.HTML("""
+                <div class="inpainting-header">
+                    <h3 style="display: flex; align-items: center; gap: 10px; margin-bottom: 8px;">
+                        ControlNet Inpainting
+                        <span style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+                                     color: white;
+                                     padding: 3px 10px;
+                                     border-radius: 12px;
+                                     font-size: 0.65em;
+                                     font-weight: 700;
+                                     letter-spacing: 0.5px;
+                                     box-shadow: 0 2px 4px rgba(102, 126, 234, 0.3);">
+                            BETA
+                        </span>
+                    </h3>
+                    <p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
+                    <div style="background: linear-gradient(to right, #FFF4E6, #FFE8CC);
+                                border-left: 4px solid #FF9500;
+                                padding: 12px 15px;
+                                border-radius: 6px;
+                                margin-top: 10px;
+                                box-shadow: 0 2px 4px rgba(255, 149, 0, 0.1);">
+                        <p style="color: #8B4513; font-size: 0.9em; margin: 0; line-height: 1.5;">
+                            <strong>⚠️ Beta Feature - Continuously Optimizing</strong><br>
+                            Results may vary depending on complexity. Use templates and detailed prompts for best results.
+                            Advanced features (like Add Accessories) may require multiple attempts.
+                        </p>
+                    </div>
+                </div>
+            """)
+
+            with gr.Row():
+                # Left column - Input
+                with gr.Column(scale=1):
+                    # Image upload
+                    inpaint_image = gr.Image(
+                        label="Upload Image",
+                        type="pil",
+                        height=300
+                    )
+
+                    # Mask editor
+                    mask_editor = gr.ImageEditor(
+                        label="Draw Mask (white = area to inpaint)",
+                        type="pil",
+                        height=300,
+                        brush=gr.Brush(colors=["#FFFFFF"], default_size=20),
+                        eraser=gr.Eraser(default_size=20),
+                        layers=False
+                    )
+
+                    # Template selection
+                    with gr.Accordion("Inpainting Templates", open=False):
+                        inpaint_template = gr.Dropdown(
+                            choices=[""] + self.inpainting_template_manager.get_template_choices_sorted(),
+                            value="",
+                            label="Select Template",
+                            elem_classes=["template-dropdown"]
+                        )
+                        template_tips = gr.Markdown("")
+
+                    # Prompt
+                    inpaint_prompt = gr.Textbox(
+                        label="Prompt",
+                        placeholder="Describe what you want to generate in the masked area...",
+                        lines=2
+                    )
+
+                # Right column - Settings and Output
+                with gr.Column(scale=1):
+                    # Settings
+                    with gr.Accordion("Generation Settings", open=True):
+                        conditioning_type = gr.Radio(
+                            choices=["canny", "depth"],
+                            value="canny",
+                            label="ControlNet Mode"
+                        )
+
+                        conditioning_scale = gr.Slider(
+                            minimum=0.5,
+                            maximum=1.0,
+                            value=0.7,
+                            step=0.05,
+                            label="ControlNet Strength"
+                        )
+
+                        feather_radius = gr.Slider(
+                            minimum=0,
+                            maximum=20,
+                            value=8,
+                            step=1,
+                            label="Feather Radius (px)"
+                        )
+
+                    with gr.Accordion("Advanced Settings", open=False):
+                        inpaint_guidance = gr.Slider(
+                            minimum=5.0,
+                            maximum=15.0,
+                            value=7.5,
+                            step=0.5,
+                            label="Guidance Scale"
+                        )
+
+                        inpaint_steps = gr.Slider(
+                            minimum=15,
+                            maximum=50,
+                            value=25,
+                            step=5,
+                            label="Inference Steps"
+                        )
+
+                    # Generate button
+                    inpaint_btn = gr.Button(
+                        "Generate Inpainting",
+                        variant="primary",
+                        elem_classes=["primary-button"]
+                    )
+
+                    # Status
+                    inpaint_status = gr.Textbox(
+                        label="Status",
+                        value="Ready for inpainting",
+                        interactive=False
+                    )
+
+            # Output row
+            with gr.Row():
+                with gr.Column(scale=1):
+                    inpaint_result = gr.Image(
+                        label="Result",
+                        type="pil",
+                        height=400
+                    )
+
+                with gr.Column(scale=1):
+                    with gr.Tabs():
+                        with gr.Tab("Preview"):
+                            inpaint_preview = gr.Image(
+                                label="Preview",
+                                type="pil",
+                                height=350
+                            )
+                        with gr.Tab("Control Image"):
+                            inpaint_control = gr.Image(
+                                label="Control Image",
+                                type="pil",
+                                height=350
+                            )
+
+            # Event handlers
+            inpaint_template.change(
+                fn=self.apply_inpainting_template,
+                inputs=[inpaint_template, inpaint_prompt],
+                outputs=[inpaint_prompt, conditioning_scale, feather_radius, conditioning_type]
+            )
+
+            inpaint_template.change(
+                fn=lambda x: self._get_template_tips(x),
+                inputs=[inpaint_template],
+                outputs=[template_tips]
+            )
+
+            # Copy uploaded image to mask editor
+            inpaint_image.change(
+                fn=lambda x: x,
+                inputs=[inpaint_image],
+                outputs=[mask_editor]
+            )
+
+            inpaint_btn.click(
+                fn=self.inpainting_handler,
+                inputs=[
+                    inpaint_image,
+                    mask_editor,
+                    inpaint_prompt,
+                    inpaint_template,
+                    conditioning_type,
+                    conditioning_scale,
+                    feather_radius,
+                    inpaint_guidance,
+                    inpaint_steps
+                ],
+                outputs=[
+                    inpaint_result,
+                    inpaint_preview,
+                    inpaint_control,
+                    inpaint_status
+                ]
+            )
+
+        return tab
+
+    def _get_template_tips(self, display_name: str) -> str:
+        """Get usage tips for selected template."""
+        if not display_name:
+            return ""
+
+        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
+        if not template_key:
+            return ""
+
+        tips = self.inpainting_template_manager.get_usage_tips(template_key)
+        if tips:
+            return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
+        return ""
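extract_mask_from_editor above normalizes several editor payload shapes down to one grayscale mask, and inpainting_handler then rejects masks covering less than 1% or more than 95% of the image. A self-contained sketch of that normalization and guard, run on a synthetic RGBA layer instead of a live gr.ImageEditor payload (and using a cv2-free grayscale fallback):

    import numpy as np
    from PIL import Image

    def to_gray_mask(layer: np.ndarray) -> Image.Image:
        if layer.ndim == 3 and layer.shape[2] == 4:
            gray = layer[:, :, 3]                       # RGBA: alpha holds the strokes
        elif layer.ndim == 3:
            gray = layer.mean(axis=2).astype(np.uint8)  # RGB fallback (cv2-free)
        else:
            gray = layer
        return Image.fromarray(gray.astype(np.uint8), mode='L')

    layer = np.zeros((256, 256, 4), dtype=np.uint8)
    layer[64:192, 64:192, 3] = 255                      # one painted square

    mask = to_gray_mask(layer)
    arr = np.array(mask)
    coverage = np.count_nonzero(arr > 127) / arr.size   # 0.25 for this square
    assert 0.01 <= coverage <= 0.95, "mask outside the accepted coverage band"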
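End to end, nothing changes about how the app is started; a minimal local-launch sketch matching the launch() signature above:

    from ui_manager import UIManager

    if __name__ == "__main__":
        UIManager().launch(share=False, debug=True)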