DawnC committed
Commit 7d583e3 · verified · 1 Parent(s): 2e3c747

Upload 11 files

Creating New Feature: Inpainting mode

app.py CHANGED
@@ -1,4 +1,5 @@
 import sys
+import traceback
 import warnings
 warnings.filterwarnings("ignore")

@@ -26,13 +27,11 @@ def launch_final_blend_sceneweaver(share: bool = True, debug: bool = False):
         return interface

     except ImportError as import_error:
-        import traceback
         print(f"❌ Import failed: {import_error}")
         print(f"Traceback: {traceback.format_exc()}")
         raise

     except Exception as e:
-        import traceback
         print(f"❌ Failed to launch: {e}")
         print(f"Full traceback: {traceback.format_exc()}")
         raise
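
The hunk above hoists `import traceback` to module scope so every handler can call `traceback.format_exc()` without re-importing it locally. A minimal sketch of the resulting pattern; `launch()` is a hypothetical stand-in for `launch_final_blend_sceneweaver`, not code from this commit:

import traceback  # bound once at module scope, as in the hunk above

def launch():
    try:
        raise RuntimeError("boom")  # stand-in for a failing Gradio launch
    except Exception as e:
        # traceback is already in scope for every handler in the module
        print(f"❌ Failed to launch: {e}")
        print(f"Full traceback: {traceback.format_exc()}")
        raise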
css_styles.py CHANGED
@@ -512,6 +512,235 @@ class CSSStyles:
         font-size: 0.95rem !important;
     }

+    /* ===== INPAINTING UI STYLES ===== */
+    .inpainting-header {
+        text-align: center !important;
+        padding: 16px !important;
+        margin-bottom: 16px !important;
+        background: linear-gradient(135deg, var(--primary-light) 0%, var(--accent-light) 100%) !important;
+        border-radius: var(--radius-lg) !important;
+        border: 1px solid var(--border-color) !important;
+    }
+
+    .inpainting-header h3 {
+        font-size: 1.4rem !important;
+        font-weight: 600 !important;
+        color: var(--primary-color) !important;
+        margin: 0 0 8px 0 !important;
+    }
+
+    .inpainting-header p {
+        font-size: 0.95rem !important;
+        color: var(--text-secondary) !important;
+        margin: 0 !important;
+    }
+
+    /* Main mode tabs styling */
+    #main-mode-tabs {
+        margin-bottom: 16px !important;
+    }
+
+    #main-mode-tabs > .tab-nav {
+        background: var(--bg-secondary) !important;
+        border-radius: var(--radius-lg) !important;
+        padding: 6px !important;
+        gap: 6px !important;
+        border: 1px solid var(--border-color) !important;
+    }
+
+    #main-mode-tabs > .tab-nav > button {
+        border-radius: var(--radius-md) !important;
+        padding: 12px 24px !important;
+        font-weight: 600 !important;
+        font-size: 1rem !important;
+        transition: all var(--transition-normal) !important;
+        border: none !important;
+        background: transparent !important;
+        color: var(--text-secondary) !important;
+    }
+
+    #main-mode-tabs > .tab-nav > button.selected {
+        background: linear-gradient(135deg, var(--accent-color) 0%, var(--accent-hover) 100%) !important;
+        color: white !important;
+        box-shadow: var(--shadow-md) !important;
+    }
+
+    #main-mode-tabs > .tab-nav > button:hover:not(.selected) {
+        background: var(--bg-primary) !important;
+        color: var(--text-primary) !important;
+    }
+
+    /* Mask editor styling */
+    .mask-editor-container {
+        border: 2px dashed var(--border-color) !important;
+        border-radius: var(--radius-lg) !important;
+        padding: 8px !important;
+        background: var(--bg-secondary) !important;
+        transition: border-color var(--transition-fast) !important;
+    }
+
+    .mask-editor-container:hover {
+        border-color: var(--accent-color) !important;
+    }
+
+    /* Inpainting template cards */
+    .inpainting-gallery {
+        margin: 16px 0 !important;
+    }
+
+    .inpainting-category {
+        margin-bottom: 20px !important;
+    }
+
+    .inpainting-category-title {
+        font-size: 0.9rem !important;
+        font-weight: 600 !important;
+        color: var(--text-secondary) !important;
+        margin-bottom: 12px !important;
+        padding-bottom: 8px !important;
+        border-bottom: 1px solid var(--border-color) !important;
+    }
+
+    .inpainting-grid {
+        display: grid !important;
+        grid-template-columns: repeat(auto-fill, minmax(120px, 1fr)) !important;
+        gap: 10px !important;
+    }
+
+    .inpainting-card {
+        display: flex !important;
+        flex-direction: column !important;
+        align-items: center !important;
+        justify-content: center !important;
+        padding: 14px 10px !important;
+        background: var(--bg-primary) !important;
+        border: 1px solid var(--border-color) !important;
+        border-radius: var(--radius-md) !important;
+        cursor: pointer !important;
+        transition: all var(--transition-normal) !important;
+        min-height: 80px !important;
+    }
+
+    .inpainting-card:hover {
+        background: var(--accent-light) !important;
+        border-color: var(--accent-color) !important;
+        transform: translateY(-2px) !important;
+        box-shadow: var(--shadow-md) !important;
+    }
+
+    .inpainting-card.selected {
+        background: var(--accent-light) !important;
+        border-color: var(--accent-color) !important;
+        box-shadow: 0 0 0 2px var(--accent-color) !important;
+    }
+
+    .inpainting-icon {
+        font-size: 1.6rem !important;
+        margin-bottom: 6px !important;
+    }
+
+    .inpainting-name {
+        font-size: 0.8rem !important;
+        font-weight: 500 !important;
+        color: var(--text-primary) !important;
+        text-align: center !important;
+        line-height: 1.2 !important;
+    }
+
+    .inpainting-desc {
+        font-size: 0.7rem !important;
+        color: var(--text-muted) !important;
+        text-align: center !important;
+        margin-top: 4px !important;
+    }
+
+    /* ControlNet mode toggle */
+    .controlnet-mode-toggle {
+        display: flex !important;
+        gap: 8px !important;
+        margin: 8px 0 !important;
+    }
+
+    .controlnet-mode-btn {
+        flex: 1 !important;
+        padding: 10px 16px !important;
+        border: 1px solid var(--border-color) !important;
+        border-radius: var(--radius-md) !important;
+        background: var(--bg-primary) !important;
+        color: var(--text-secondary) !important;
+        font-weight: 500 !important;
+        cursor: pointer !important;
+        transition: all var(--transition-fast) !important;
+    }
+
+    .controlnet-mode-btn:hover {
+        border-color: var(--accent-color) !important;
+        color: var(--accent-color) !important;
+    }
+
+    .controlnet-mode-btn.active {
+        background: var(--accent-color) !important;
+        border-color: var(--accent-color) !important;
+        color: white !important;
+    }
+
+    /* Preview badge */
+    .preview-badge {
+        display: inline-flex !important;
+        align-items: center !important;
+        gap: 4px !important;
+        padding: 4px 10px !important;
+        background: var(--warning-color) !important;
+        color: white !important;
+        font-size: 0.75rem !important;
+        font-weight: 600 !important;
+        border-radius: var(--radius-sm) !important;
+        text-transform: uppercase !important;
+    }
+
+    /* Quality score display */
+    .quality-score {
+        display: inline-flex !important;
+        align-items: center !important;
+        gap: 6px !important;
+        padding: 6px 12px !important;
+        border-radius: var(--radius-md) !important;
+        font-weight: 500 !important;
+        font-size: 0.9rem !important;
+    }
+
+    .quality-score.excellent {
+        background: rgba(16, 185, 129, 0.1) !important;
+        color: var(--success-color) !important;
+    }
+
+    .quality-score.good {
+        background: rgba(59, 130, 246, 0.1) !important;
+        color: var(--accent-color) !important;
+    }
+
+    .quality-score.warning {
+        background: rgba(245, 158, 11, 0.1) !important;
+        color: var(--warning-color) !important;
+    }
+
+    .quality-score.poor {
+        background: rgba(239, 68, 68, 0.1) !important;
+        color: var(--error-color) !important;
+    }
+
+    /* Responsive adjustments for inpainting */
+    @media (max-width: 768px) {
+        .inpainting-grid {
+            grid-template-columns: repeat(2, 1fr) !important;
+        }
+
+        #main-mode-tabs > .tab-nav > button {
+            padding: 10px 16px !important;
+            font-size: 0.9rem !important;
+        }
+    }
+
     /* ===== DROPDOWN POSITIONING FIX FOR GRADIO 4.x/5.x ===== */
     /* Fix dropdown list positioning issue in left column */
     .feature-card {
@@ -542,4 +771,4 @@ class CSSStyles:
         box-shadow: var(--shadow-lg) !important;
         margin-top: 4px !important;
     }
-    """
+    """
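
These selectors only take effect once matching `elem_id`/`elem_classes` values are set on the Gradio components; that wiring lives in the app layout, which is not part of this hunk. A minimal sketch of how the new classes might be attached (the `CSSStyles.get_css()` accessor and the layout below are assumptions, not shown in this commit):

import gradio as gr
from css_styles import CSSStyles

with gr.Blocks(css=CSSStyles.get_css()) as demo:  # accessor name is an assumption
    # Header block styled by the .inpainting-header rules
    gr.HTML(
        "<div class='inpainting-header'>"
        "<h3>Inpainting</h3>"
        "<p>Paint over the region you want to regenerate.</p>"
        "</div>"
    )
    # Tab bar styled by the #main-mode-tabs rules
    with gr.Tabs(elem_id="main-mode-tabs"):
        with gr.Tab("Background Generation"):
            pass
        with gr.Tab("Inpainting"):
            pass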
image_blender.py CHANGED
@@ -1,5 +1,6 @@
 import cv2
 import numpy as np
+import traceback
 from PIL import Image
 import logging
 from typing import Dict, Any, Optional, Tuple
@@ -7,24 +8,37 @@ from typing import Dict, Any, Optional, Tuple
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

+
 class ImageBlender:
     """
-    Advanced image blending with aggressive spill suppression and color replacement
-    Completely eliminates yellow edge residue while maintaining sharp edges
+    Advanced image blending with aggressive spill suppression and color replacement.
+
+    Supports two primary modes:
+    - Background generation: Foreground preservation with edge refinement
+    - Inpainting: Seamless blending with adaptive color correction
+
+    Attributes:
+        enable_multi_scale: Whether multi-scale edge refinement is enabled
    """

-    EDGE_EROSION_PIXELS = 1  # Pixels to erode from mask edge (reduced to protect more foreground)
-    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization (increased to keep more foreground)
-    DARK_LUMINANCE_THRESHOLD = 60  # Luminance threshold for dark foreground detection
-    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value above which pixels are strongly protected
-    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance for background color detection
+    EDGE_EROSION_PIXELS = 1  # Pixels to erode from mask edge
+    ALPHA_BINARIZE_THRESHOLD = 0.5  # Alpha threshold for binarization
+    DARK_LUMINANCE_THRESHOLD = 60  # Luminance threshold for dark foreground
+    FOREGROUND_PROTECTION_THRESHOLD = 140  # Mask value for strong protection
+    BACKGROUND_COLOR_TOLERANCE = 30  # DeltaE tolerance for background detection
+
+    # Inpainting-specific parameters
+    INPAINT_FEATHER_SCALE = 1.2  # Scale factor for inpainting feathering
+    INPAINT_COLOR_BLEND_RADIUS = 10  # Radius for color adaptation zone

     def __init__(self, enable_multi_scale: bool = True):
         """
         Initialize ImageBlender.

-        Args:
-            enable_multi_scale: Whether to enable multi-scale edge refinement (default True)
+        Parameters
+        ----------
+        enable_multi_scale : bool
+            Whether to enable multi-scale edge refinement (default True)
         """
         self.enable_multi_scale = enable_multi_scale
         self._debug_info = {}
@@ -499,7 +513,6 @@
             logger.info(f"🔍 Trimap regions - FG_CORE: {fg_core.sum()}, RING: {ring_zone.sum()}, BG: {bg_zone.sum()}")

         except Exception as e:
-            import traceback
             logger.error(f"❌ Trimap definition failed: {e}")
             logger.error(f"📍 Traceback: {traceback.format_exc()}")
             print(f"❌ TRIMAP ERROR: {e}")
@@ -678,7 +691,7 @@
         orig_linear = srgb_to_linear(orig_array)
         bg_linear = srgb_to_linear(bg_array)

-        # === Cartoon-optimized Alpha calculation ===
+        # Cartoon-optimized Alpha calculation
         alpha = mask_array.astype(np.float32) / 255.0

         # Core foreground region - fully opaque
@@ -800,3 +813,293 @@
             debug_images["adaptive_strength_heatmap"] = Image.fromarray(strength_heatmap_rgb)

         return debug_images
+
+    # INPAINTING-SPECIFIC BLENDING METHODS
+    def blend_inpainting(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image,
+        feather_radius: int = 8,
+        apply_color_correction: bool = True
+    ) -> Image.Image:
+        """
+        Blend inpainted region with original image.
+
+        Specialized blending for inpainting that focuses on seamless integration
+        rather than foreground protection. Performs blending in linear color space
+        with optional adaptive color correction at boundaries.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image before inpainting
+        generated : PIL.Image
+            Generated/inpainted result from the model
+        mask : PIL.Image
+            Inpainting mask (white = inpainted area)
+        feather_radius : int
+            Feathering radius for smooth transitions
+        apply_color_correction : bool
+            Whether to apply adaptive color correction at boundaries
+
+        Returns
+        -------
+        PIL.Image
+            Blended result
+        """
+        logger.info(f"Inpainting blend: feather={feather_radius}, color_correction={apply_color_correction}")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # Apply feathering to mask
+        if feather_radius > 0:
+            scaled_radius = int(feather_radius * self.INPAINT_FEATHER_SCALE)
+            kernel_size = scaled_radius * 2 + 1
+            mask_array = cv2.GaussianBlur(
+                mask_array,
+                (kernel_size, kernel_size),
+                scaled_radius / 2
+            )
+
+        # Apply adaptive color correction if enabled
+        if apply_color_correction:
+            gen_array = self._apply_inpaint_color_correction(
+                orig_array, gen_array, mask_array
+            )
+
+        # sRGB to linear conversion for accurate blending
+        def srgb_to_linear(img):
+            img_norm = img / 255.0
+            return np.where(
+                img_norm <= 0.04045,
+                img_norm / 12.92,
+                np.power((img_norm + 0.055) / 1.055, 2.4)
+            )
+
+        def linear_to_srgb(img):
+            img_clipped = np.clip(img, 0, 1)
+            return np.where(
+                img_clipped <= 0.0031308,
+                12.92 * img_clipped,
+                1.055 * np.power(img_clipped, 1/2.4) - 0.055
+            )
+
+        # Convert to linear space
+        orig_linear = srgb_to_linear(orig_array)
+        gen_linear = srgb_to_linear(gen_array)
+
+        # Alpha blending in linear space
+        alpha = mask_array[:, :, np.newaxis]
+        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
+
+        # Convert back to sRGB
+        result_srgb = linear_to_srgb(result_linear)
+        result_array = (result_srgb * 255).astype(np.uint8)
+
+        logger.debug("Inpainting blend completed in linear color space")
+
+        return Image.fromarray(result_array)
+
+    def _apply_inpaint_color_correction(
+        self,
+        original: np.ndarray,
+        generated: np.ndarray,
+        mask: np.ndarray
+    ) -> np.ndarray:
+        """
+        Apply adaptive color correction to match generated region with surroundings.
+
+        Analyzes the boundary region and adjusts the generated content's
+        luminance and color to better match the original context.
+
+        Parameters
+        ----------
+        original : np.ndarray
+            Original image (float32, 0-255)
+        generated : np.ndarray
+            Generated image (float32, 0-255)
+        mask : np.ndarray
+            Blend mask (float32, 0-1)
+
+        Returns
+        -------
+        np.ndarray
+            Color-corrected generated image
+        """
+        # Find boundary region
+        mask_binary = (mask > 0.5).astype(np.uint8)
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE,
+            (self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1, self.INPAINT_COLOR_BLEND_RADIUS * 2 + 1)
+        )
+        dilated = cv2.dilate(mask_binary, kernel, iterations=1)
+        boundary_zone = (dilated > 0) & (mask < 0.3)
+
+        if not np.any(boundary_zone):
+            return generated
+
+        # Convert to Lab for perceptual color matching
+        orig_lab = cv2.cvtColor(
+            original.astype(np.uint8), cv2.COLOR_RGB2LAB
+        ).astype(np.float32)
+        gen_lab = cv2.cvtColor(
+            generated.astype(np.uint8), cv2.COLOR_RGB2LAB
+        ).astype(np.float32)
+
+        # Calculate statistics in boundary zone (original)
+        boundary_orig_l = orig_lab[boundary_zone, 0]
+        boundary_orig_a = orig_lab[boundary_zone, 1]
+        boundary_orig_b = orig_lab[boundary_zone, 2]
+
+        orig_mean_l = np.median(boundary_orig_l)
+        orig_mean_a = np.median(boundary_orig_a)
+        orig_mean_b = np.median(boundary_orig_b)
+
+        # Calculate statistics in generated inpaint region
+        inpaint_zone = mask > 0.5
+        if not np.any(inpaint_zone):
+            return generated
+
+        gen_inpaint_l = gen_lab[inpaint_zone, 0]
+        gen_inpaint_a = gen_lab[inpaint_zone, 1]
+        gen_inpaint_b = gen_lab[inpaint_zone, 2]
+
+        gen_mean_l = np.median(gen_inpaint_l)
+        gen_mean_a = np.median(gen_inpaint_a)
+        gen_mean_b = np.median(gen_inpaint_b)
+
+        # Calculate correction deltas
+        delta_l = orig_mean_l - gen_mean_l
+        delta_a = orig_mean_a - gen_mean_a
+        delta_b = orig_mean_b - gen_mean_b
+
+        # Limit correction to avoid over-adjustment
+        max_correction = 15
+        delta_l = np.clip(delta_l, -max_correction, max_correction)
+        delta_a = np.clip(delta_a, -max_correction * 0.5, max_correction * 0.5)
+        delta_b = np.clip(delta_b, -max_correction * 0.5, max_correction * 0.5)
+
+        logger.debug(f"Color correction deltas: L={delta_l:.1f}, a={delta_a:.1f}, b={delta_b:.1f}")
+
+        # Apply correction with spatial falloff from boundary
+        # Create distance map from boundary
+        distance = cv2.distanceTransform(
+            mask_binary, cv2.DIST_L2, 5
+        )
+        max_dist = np.max(distance)
+        if max_dist > 0:
+            # Correction strength falls off from boundary toward center
+            correction_strength = 1.0 - np.clip(distance / (max_dist * 0.5), 0, 1)
+        else:
+            correction_strength = np.ones_like(distance)
+
+        # Apply correction to Lab channels
+        corrected_lab = gen_lab.copy()
+        corrected_lab[:, :, 0] += delta_l * correction_strength * 0.7
+        corrected_lab[:, :, 1] += delta_a * correction_strength * 0.5
+        corrected_lab[:, :, 2] += delta_b * correction_strength * 0.5
+
+        # Clip to valid Lab ranges
+        corrected_lab[:, :, 0] = np.clip(corrected_lab[:, :, 0], 0, 255)
+        corrected_lab[:, :, 1] = np.clip(corrected_lab[:, :, 1], 0, 255)
+        corrected_lab[:, :, 2] = np.clip(corrected_lab[:, :, 2], 0, 255)
+
+        # Convert back to RGB
+        corrected_rgb = cv2.cvtColor(
+            corrected_lab.astype(np.uint8), cv2.COLOR_LAB2RGB
+        ).astype(np.float32)
+
+        return corrected_rgb
+
+    def blend_inpainting_with_guided_filter(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image,
+        feather_radius: int = 8,
+        guide_radius: int = 8,
+        guide_eps: float = 0.01
+    ) -> Image.Image:
+        """
+        Blend inpainted region using guided filter for edge-aware transitions.
+
+        Combines standard alpha blending with guided filtering to preserve
+        edges in the original image while seamlessly integrating new content.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        generated : PIL.Image
+            Generated/inpainted result
+        mask : PIL.Image
+            Inpainting mask
+        feather_radius : int
+            Base feathering radius
+        guide_radius : int
+            Guided filter radius
+        guide_eps : float
+            Guided filter regularization
+
+        Returns
+        -------
+        PIL.Image
+            Blended result with edge-aware transitions
+        """
+        logger.info("Applying guided filter inpainting blend")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # Apply base feathering
+        if feather_radius > 0:
+            kernel_size = feather_radius * 2 + 1
+            mask_feathered = cv2.GaussianBlur(
+                mask_array,
+                (kernel_size, kernel_size),
+                feather_radius / 2
+            )
+        else:
+            mask_feathered = mask_array
+
+        # Use original image as guide for the filter
+        guide = cv2.cvtColor(orig_array.astype(np.uint8), cv2.COLOR_RGB2GRAY)
+        guide = guide.astype(np.float32) / 255.0
+
+        # Apply guided filter to the mask
+        try:
+            mask_guided = cv2.ximgproc.guidedFilter(
+                guide=guide,
+                src=mask_feathered,
+                radius=guide_radius,
+                eps=guide_eps
+            )
+            logger.debug("Guided filter applied successfully")
+        except Exception as e:
+            logger.warning(f"Guided filter failed: {e}, using standard feathering")
+            mask_guided = mask_feathered
+
+        # Alpha blending
+        alpha = mask_guided[:, :, np.newaxis]
+        result = gen_array * alpha + orig_array * (1 - alpha)
+        result = np.clip(result, 0, 255).astype(np.uint8)
+
+        return Image.fromarray(result)
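
Both new entry points blend with the mask as alpha; `blend_inpainting` additionally linearizes first, since the nested `srgb_to_linear`/`linear_to_srgb` helpers implement the standard sRGB transfer function (C/12.92 for C <= 0.04045, else ((C + 0.055)/1.055)^2.4, and its inverse), which avoids the darkened seams that naive averaging of sRGB values produces. A usage sketch with placeholder file names (the signatures match the diff above):

from PIL import Image
from image_blender import ImageBlender

original = Image.open("room.png").convert("RGB")            # placeholder paths
generated = Image.open("room_inpainted.png").convert("RGB")
mask = Image.open("mask.png").convert("L")                  # white = inpainted area

blender = ImageBlender()

# Linear-space alpha blend with boundary-matched color correction
result = blender.blend_inpainting(original, generated, mask, feather_radius=8)

# Edge-aware variant; falls back to plain Gaussian feathering when
# opencv-contrib (cv2.ximgproc) is not installed
result_guided = blender.blend_inpainting_with_guided_filter(
    original, generated, mask,
    feather_radius=8, guide_radius=8, guide_eps=0.01
)
result_guided.save("blended.png")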
inpainting_module.py ADDED
@@ -0,0 +1,1311 @@
1
+ import gc
2
+ import logging
3
+ import time
4
+ import traceback
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image, ImageFilter
12
+
13
+ from diffusers import ControlNetModel, DPMSolverMultistepScheduler
14
+ from diffusers import StableDiffusionXLControlNetInpaintPipeline
15
+ from diffusers import StableDiffusionXLInpaintPipeline
16
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
17
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.INFO)
21
+
22
+
23
+ @dataclass
24
+ class InpaintingConfig:
25
+ """Configuration for inpainting operations."""
26
+
27
+ # ControlNet settings
28
+ controlnet_conditioning_scale: float = 0.7
29
+ conditioning_type: str = "canny" # "canny" or "depth"
30
+
31
+ # Canny edge detection parameters
32
+ canny_low_threshold: int = 100
33
+ canny_high_threshold: int = 200
34
+
35
+ # Mask settings
36
+ feather_radius: int = 8
37
+ min_mask_coverage: float = 0.01
38
+ max_mask_coverage: float = 0.95
39
+
40
+ # Generation settings
41
+ num_inference_steps: int = 25
42
+ guidance_scale: float = 7.5
43
+ preview_steps: int = 15
44
+ preview_guidance_scale: float = 8.0
45
+
46
+ # Quality settings
47
+ enable_auto_optimization: bool = True
48
+ max_optimization_retries: int = 3
49
+ min_quality_score: float = 70.0
50
+
51
+ # Memory settings
52
+ enable_vae_tiling: bool = True
53
+ enable_attention_slicing: bool = True
54
+ max_resolution: int = 1024
55
+
56
+
57
+ @dataclass
58
+ class InpaintingResult:
59
+ """Result container for inpainting operations."""
60
+
61
+ success: bool
62
+ result_image: Optional[Image.Image] = None
63
+ preview_image: Optional[Image.Image] = None
64
+ control_image: Optional[Image.Image] = None
65
+ blended_image: Optional[Image.Image] = None
66
+ quality_score: float = 0.0
67
+ quality_details: Dict[str, Any] = field(default_factory=dict)
68
+ generation_time: float = 0.0
69
+ retries: int = 0
70
+ error_message: str = ""
71
+ metadata: Dict[str, Any] = field(default_factory=dict)
72
+
73
+
74
+ class InpaintingModule:
75
+ """
76
+ ControlNet-based Inpainting Module for SceneWeaver.
77
+
78
+ Implements StableDiffusionXLControlNetInpaintPipeline with support for
79
+ Canny edge and depth map conditioning. Features two-stage generation
80
+ (preview + full quality) and automatic quality optimization.
81
+
82
+ Attributes:
83
+ device: Computation device (cuda/mps/cpu)
84
+ config: InpaintingConfig instance
85
+ is_initialized: Whether pipeline is loaded
86
+
87
+ Example:
88
+ >>> module = InpaintingModule(device="cuda")
89
+ >>> module.load_inpainting_pipeline(progress_callback=my_callback)
90
+ >>> result = module.execute_inpainting(
91
+ ... image=my_image,
92
+ ... mask=my_mask,
93
+ ... prompt="a beautiful garden"
94
+ ... )
95
+ """
96
+
97
+ # Model identifiers
98
+ CONTROLNET_CANNY_MODEL = "diffusers/controlnet-canny-sdxl-1.0"
99
+ CONTROLNET_DEPTH_MODEL = "diffusers/controlnet-depth-sdxl-1.0"
100
+ DEPTH_MODEL_PRIMARY = "LiheYoung/depth-anything-small-hf"
101
+ DEPTH_MODEL_FALLBACK = "Intel/dpt-hybrid-midas"
102
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
103
+
104
+ def __init__(
105
+ self,
106
+ device: str = "auto",
107
+ config: Optional[InpaintingConfig] = None
108
+ ):
109
+ """
110
+ Initialize the InpaintingModule.
111
+
112
+ Parameters
113
+ ----------
114
+ device : str, optional
115
+ Computation device. "auto" for automatic detection.
116
+ config : InpaintingConfig, optional
117
+ Configuration object. Uses defaults if not provided.
118
+ """
119
+ self.device = self._setup_device(device)
120
+ self.config = config or InpaintingConfig()
121
+
122
+ # Pipeline instances (lazy loaded)
123
+ self._inpaint_pipeline = None
124
+ self._controlnet_canny = None
125
+ self._controlnet_depth = None
126
+ self._depth_estimator = None
127
+ self._depth_processor = None
128
+
129
+ # State tracking
130
+ self.is_initialized = False
131
+ self._current_conditioning_type = None
132
+ self._last_seed = None
133
+ self._cached_latents = None
134
+ self._use_controlnet = True # Track if ControlNet is available
135
+
136
+ # Reference to model manager (set by SceneWeaverCore)
137
+ self._model_manager = None
138
+
139
+ logger.info(f"InpaintingModule initialized on {self.device}")
140
+
141
+ def _setup_device(self, device: str) -> str:
142
+ """
143
+ Setup computation device.
144
+
145
+ Parameters
146
+ ----------
147
+ device : str
148
+ Device specification or "auto"
149
+
150
+ Returns
151
+ -------
152
+ str
153
+ Resolved device name
154
+ """
155
+ if device == "auto":
156
+ if torch.cuda.is_available():
157
+ return "cuda"
158
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
159
+ return "mps"
160
+ return "cpu"
161
+ return device
162
+
163
+ def set_model_manager(self, manager: Any) -> None:
164
+ """
165
+ Set reference to ModelManager for coordinated model lifecycle.
166
+
167
+ Parameters
168
+ ----------
169
+ manager : ModelManager
170
+ The global model manager instance
171
+ """
172
+ self._model_manager = manager
173
+ logger.info("ModelManager reference set for InpaintingModule")
174
+
175
+ def _memory_cleanup(self, aggressive: bool = False) -> None:
176
+ """
177
+ Perform memory cleanup.
178
+
179
+ Parameters
180
+ ----------
181
+ aggressive : bool
182
+ If True, perform multiple GC rounds and sync CUDA
183
+ """
184
+ rounds = 5 if aggressive else 2
185
+ for _ in range(rounds):
186
+ gc.collect()
187
+
188
+ if torch.cuda.is_available():
189
+ torch.cuda.empty_cache()
190
+ if aggressive:
191
+ torch.cuda.ipc_collect()
192
+ torch.cuda.synchronize()
193
+
194
+ logger.debug(f"Memory cleanup completed (aggressive={aggressive})")
195
+
196
+ def _check_memory_status(self) -> Dict[str, float]:
197
+ """
198
+ Check current GPU memory status.
199
+
200
+ Returns
201
+ -------
202
+ dict
203
+ Memory statistics including allocated, total, and usage ratio
204
+ """
205
+ if not torch.cuda.is_available():
206
+ return {"available": True, "usage_ratio": 0.0}
207
+
208
+ allocated = torch.cuda.memory_allocated() / 1024**3
209
+ total = torch.cuda.get_device_properties(0).total_memory / 1024**3
210
+ usage_ratio = allocated / total
211
+
212
+ return {
213
+ "allocated_gb": round(allocated, 2),
214
+ "total_gb": round(total, 2),
215
+ "free_gb": round(total - allocated, 2),
216
+ "usage_ratio": round(usage_ratio, 3),
217
+ "available": usage_ratio < 0.9
218
+ }
219
+
220
+ def load_inpainting_pipeline(
221
+ self,
222
+ conditioning_type: str = "canny",
223
+ progress_callback: Optional[Callable[[str, int], None]] = None
224
+ ) -> Tuple[bool, str]:
225
+ """
226
+ Load the ControlNet inpainting pipeline.
227
+
228
+ Implements mutual exclusion with background generation pipeline.
229
+ Only one pipeline can be loaded at a time.
230
+
231
+ Parameters
232
+ ----------
233
+ conditioning_type : str
234
+ Type of ControlNet conditioning: "canny" or "depth"
235
+ progress_callback : callable, optional
236
+ Function(message, percentage) for progress updates
237
+
238
+ Returns
239
+ -------
240
+ tuple
241
+ (success: bool, error_message: str)
242
+ """
243
+ if self.is_initialized and self._current_conditioning_type == conditioning_type:
244
+ logger.info(f"Inpainting pipeline already loaded with {conditioning_type}")
245
+ return True, ""
246
+
247
+ logger.info(f"Loading inpainting pipeline with {conditioning_type} conditioning...")
248
+
249
+ try:
250
+ self._memory_cleanup(aggressive=True)
251
+
252
+ if progress_callback:
253
+ progress_callback("Preparing to load inpainting models...", 5)
254
+
255
+ # Unload existing pipeline if different conditioning type
256
+ if self._inpaint_pipeline is not None:
257
+ self._unload_pipeline()
258
+
259
+ # Use ControlNet inpainting by default
260
+ use_controlnet_inpaint = True
261
+ logger.info("Using StableDiffusionXLControlNetInpaintPipeline")
262
+
263
+ if progress_callback:
264
+ progress_callback("Loading ControlNet model...", 20)
265
+
266
+ # Load appropriate ControlNet
267
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
268
+ controlnet = None
269
+
270
+ if use_controlnet_inpaint:
271
+ if conditioning_type == "canny":
272
+ controlnet = ControlNetModel.from_pretrained(
273
+ self.CONTROLNET_CANNY_MODEL,
274
+ torch_dtype=dtype,
275
+ use_safetensors=True
276
+ )
277
+ self._controlnet_canny = controlnet
278
+ logger.info("Loaded ControlNet Canny model")
279
+
280
+ elif conditioning_type == "depth":
281
+ controlnet = ControlNetModel.from_pretrained(
282
+ self.CONTROLNET_DEPTH_MODEL,
283
+ torch_dtype=dtype,
284
+ use_safetensors=True
285
+ )
286
+ self._controlnet_depth = controlnet
287
+
288
+ # Load depth estimator
289
+ if progress_callback:
290
+ progress_callback("Loading depth estimation model...", 35)
291
+ self._load_depth_estimator()
292
+ logger.info("Loaded ControlNet Depth model")
293
+ else:
294
+ raise ValueError(f"Unknown conditioning type: {conditioning_type}")
295
+ else:
296
+ # Skip ControlNet loading for fallback mode
297
+ logger.info(f"Skipping ControlNet loading (fallback mode)")
298
+
299
+ if progress_callback:
300
+ progress_callback("Loading SDXL Inpainting pipeline...", 50)
301
+
302
+ # Load the inpainting pipeline
303
+ if use_controlnet_inpaint and controlnet is not None:
304
+ self._inpaint_pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
305
+ self.BASE_MODEL,
306
+ controlnet=controlnet,
307
+ torch_dtype=dtype,
308
+ use_safetensors=True,
309
+ variant="fp16" if dtype == torch.float16 else None
310
+ )
311
+ else:
312
+ # Fallback: Use dedicated inpainting model without ControlNet
313
+ self._inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
314
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
315
+ torch_dtype=dtype,
316
+ use_safetensors=True,
317
+ variant="fp16" if dtype == torch.float16 else None
318
+ )
319
+ self._use_controlnet = False
320
+
321
+ # Track ControlNet usage
322
+ self._use_controlnet = use_controlnet_inpaint and controlnet is not None
323
+
324
+ if progress_callback:
325
+ progress_callback("Configuring scheduler...", 70)
326
+
327
+ # Configure scheduler for faster generation
328
+ self._inpaint_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
329
+ self._inpaint_pipeline.scheduler.config
330
+ )
331
+
332
+ # Move to device
333
+ self._inpaint_pipeline = self._inpaint_pipeline.to(self.device)
334
+
335
+ if progress_callback:
336
+ progress_callback("Applying optimizations...", 85)
337
+
338
+ # Apply memory optimizations
339
+ self._apply_pipeline_optimizations()
340
+
341
+ # Set eval mode
342
+ self._inpaint_pipeline.unet.eval()
343
+ if hasattr(self._inpaint_pipeline, 'vae'):
344
+ self._inpaint_pipeline.vae.eval()
345
+
346
+ self.is_initialized = True
347
+ self._current_conditioning_type = conditioning_type if self._use_controlnet else "none"
348
+
349
+ if progress_callback:
350
+ progress_callback("Inpainting pipeline ready!", 100)
351
+
352
+ # Log memory status
353
+ mem_status = self._check_memory_status()
354
+ logger.info(f"Pipeline loaded. GPU memory: {mem_status.get('allocated_gb', 0):.1f}GB used")
355
+
356
+ return True, ""
357
+
358
+ except Exception as e:
359
+ error_msg = str(e)
360
+ logger.error(f"Failed to load inpainting pipeline: {error_msg}")
361
+ traceback.print_exc()
362
+ self._unload_pipeline()
363
+ return False, error_msg
364
+
365
+ def _load_depth_estimator(self) -> None:
366
+ """
367
+ Load depth estimation model with fallback strategy.
368
+
369
+ Tries Depth-Anything first, falls back to MiDaS if unavailable.
370
+ """
371
+ try:
372
+ logger.info(f"Attempting to load depth model: {self.DEPTH_MODEL_PRIMARY}")
373
+
374
+ self._depth_processor = AutoImageProcessor.from_pretrained(
375
+ self.DEPTH_MODEL_PRIMARY
376
+ )
377
+ self._depth_estimator = AutoModelForDepthEstimation.from_pretrained(
378
+ self.DEPTH_MODEL_PRIMARY,
379
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
380
+ )
381
+ self._depth_estimator.to(self.device)
382
+ self._depth_estimator.eval()
383
+
384
+ logger.info("Successfully loaded Depth-Anything model")
385
+
386
+ except Exception as e:
387
+ logger.warning(f"Primary depth model failed: {e}, trying fallback...")
388
+
389
+ try:
390
+ self._depth_processor = DPTImageProcessor.from_pretrained(
391
+ self.DEPTH_MODEL_FALLBACK
392
+ )
393
+ self._depth_estimator = DPTForDepthEstimation.from_pretrained(
394
+ self.DEPTH_MODEL_FALLBACK,
395
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
396
+ )
397
+ self._depth_estimator.to(self.device)
398
+ self._depth_estimator.eval()
399
+
400
+ logger.info("Successfully loaded MiDaS fallback model")
401
+
402
+ except Exception as fallback_e:
403
+ logger.error(f"Fallback depth model also failed: {fallback_e}")
404
+ raise RuntimeError("Unable to load any depth estimation model")
405
+
406
+ def _apply_pipeline_optimizations(self) -> None:
407
+ """Apply memory and performance optimizations to the pipeline."""
408
+ if self._inpaint_pipeline is None:
409
+ return
410
+
411
+ # Try xformers first
412
+ try:
413
+ self._inpaint_pipeline.enable_xformers_memory_efficient_attention()
414
+ logger.info("Enabled xformers memory efficient attention")
415
+ except Exception:
416
+ try:
417
+ self._inpaint_pipeline.enable_attention_slicing()
418
+ logger.info("Enabled attention slicing")
419
+ except Exception:
420
+ logger.warning("No attention optimization available")
421
+
422
+ # VAE optimizations
423
+ if self.config.enable_vae_tiling:
424
+ if hasattr(self._inpaint_pipeline, 'enable_vae_tiling'):
425
+ self._inpaint_pipeline.enable_vae_tiling()
426
+ logger.debug("Enabled VAE tiling")
427
+
428
+ if hasattr(self._inpaint_pipeline, 'enable_vae_slicing'):
429
+ self._inpaint_pipeline.enable_vae_slicing()
430
+ logger.debug("Enabled VAE slicing")
431
+
432
+ def _unload_pipeline(self) -> None:
433
+ """Unload the inpainting pipeline and free memory."""
434
+ logger.info("Unloading inpainting pipeline...")
435
+
436
+ if self._inpaint_pipeline is not None:
437
+ del self._inpaint_pipeline
438
+ self._inpaint_pipeline = None
439
+
440
+ if self._controlnet_canny is not None:
441
+ del self._controlnet_canny
442
+ self._controlnet_canny = None
443
+
444
+ if self._controlnet_depth is not None:
445
+ del self._controlnet_depth
446
+ self._controlnet_depth = None
447
+
448
+ if self._depth_estimator is not None:
449
+ del self._depth_estimator
450
+ self._depth_estimator = None
451
+
452
+ if self._depth_processor is not None:
453
+ del self._depth_processor
454
+ self._depth_processor = None
455
+
456
+ self.is_initialized = False
457
+ self._current_conditioning_type = None
458
+ self._cached_latents = None
459
+
460
+ self._memory_cleanup(aggressive=True)
461
+ logger.info("Inpainting pipeline unloaded")
462
+
463
+ def prepare_control_image(
464
+ self,
465
+ image: Image.Image,
466
+ mode: str = "canny"
467
+ ) -> Image.Image:
468
+ """
469
+ Generate ControlNet conditioning image.
470
+
471
+ Parameters
472
+ ----------
473
+ image : PIL.Image
474
+ Input image
475
+ mode : str
476
+ Conditioning mode: "canny" or "depth"
477
+
478
+ Returns
479
+ -------
480
+ PIL.Image
481
+ Generated control image (edges or depth map)
482
+ """
483
+ logger.info(f"Preparing control image with mode: {mode}")
484
+
485
+ # Convert to RGB if needed
486
+ if image.mode != 'RGB':
487
+ image = image.convert('RGB')
488
+
489
+ img_array = np.array(image)
490
+
491
+ if mode == "canny":
492
+ return self._generate_canny_edges(img_array)
493
+ elif mode == "depth":
494
+ return self._generate_depth_map(image)
495
+ else:
496
+ raise ValueError(f"Unknown control mode: {mode}")
497
+
498
+ def _generate_canny_edges(self, img_array: np.ndarray) -> Image.Image:
499
+ """
500
+ Generate Canny edge detection image.
501
+
502
+ Parameters
503
+ ----------
504
+ img_array : np.ndarray
505
+ Input image as RGB numpy array
506
+
507
+ Returns
508
+ -------
509
+ PIL.Image
510
+ Edge detection result as grayscale image
511
+ """
512
+ # Convert to grayscale
513
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
514
+
515
+ # Apply Gaussian blur to reduce noise
516
+ blurred = cv2.GaussianBlur(gray, (5, 5), 1.4)
517
+
518
+ # Canny edge detection
519
+ edges = cv2.Canny(
520
+ blurred,
521
+ self.config.canny_low_threshold,
522
+ self.config.canny_high_threshold
523
+ )
524
+
525
+ # Convert to 3-channel for ControlNet
526
+ edges_3ch = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
527
+
528
+ logger.debug(f"Generated Canny edges with thresholds "
529
+ f"{self.config.canny_low_threshold}/{self.config.canny_high_threshold}")
530
+
531
+ return Image.fromarray(edges_3ch)
532
+
533
+ def _generate_depth_map(self, image: Image.Image) -> Image.Image:
534
+ """
535
+ Generate depth map using depth estimation model.
536
+
537
+ Parameters
538
+ ----------
539
+ image : PIL.Image
540
+ Input RGB image
541
+
542
+ Returns
543
+ -------
544
+ PIL.Image
545
+ Depth map as grayscale image
546
+ """
547
+ if self._depth_estimator is None or self._depth_processor is None:
548
+ raise RuntimeError("Depth estimator not loaded")
549
+
550
+ # Preprocess
551
+ inputs = self._depth_processor(images=image, return_tensors="pt")
552
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
553
+
554
+ # Inference
555
+ with torch.no_grad():
556
+ outputs = self._depth_estimator(**inputs)
557
+ predicted_depth = outputs.predicted_depth
558
+
559
+ # Interpolate to original size
560
+ prediction = torch.nn.functional.interpolate(
561
+ predicted_depth.unsqueeze(1),
562
+ size=image.size[::-1], # (H, W)
563
+ mode="bicubic",
564
+ align_corners=False
565
+ )
566
+
567
+ # Normalize to 0-255
568
+ depth_array = prediction.squeeze().cpu().numpy()
569
+ depth_min = depth_array.min()
570
+ depth_max = depth_array.max()
571
+
572
+ if depth_max - depth_min > 0:
573
+ depth_normalized = ((depth_array - depth_min) / (depth_max - depth_min) * 255)
574
+ else:
575
+ depth_normalized = np.zeros_like(depth_array)
576
+
577
+ depth_normalized = depth_normalized.astype(np.uint8)
578
+
579
+ # Convert to 3-channel for ControlNet
580
+ depth_3ch = cv2.cvtColor(depth_normalized, cv2.COLOR_GRAY2RGB)
581
+
582
+ logger.debug(f"Generated depth map, range: {depth_min:.2f} - {depth_max:.2f}")
583
+
584
+ return Image.fromarray(depth_3ch)
585
+
586
+ def prepare_mask(
587
+ self,
588
+ mask: Image.Image,
589
+ target_size: Tuple[int, int],
590
+ feather_radius: Optional[int] = None
591
+ ) -> Tuple[Image.Image, Dict[str, Any]]:
592
+ """
593
+ Prepare and validate mask for inpainting.
594
+
595
+ Parameters
596
+ ----------
597
+ mask : PIL.Image
598
+ Input mask (white = inpaint area)
599
+ target_size : tuple
600
+ Target (width, height) to match input image
601
+ feather_radius : int, optional
602
+ Feathering radius in pixels. Uses config default if None.
603
+
604
+ Returns
605
+ -------
606
+ tuple
607
+ (processed_mask, validation_info)
608
+
609
+ Raises
610
+ ------
611
+ ValueError
612
+ If mask coverage is outside acceptable range
613
+ """
614
+ feather = feather_radius if feather_radius is not None else self.config.feather_radius
615
+
616
+ # Convert to grayscale
617
+ if mask.mode != 'L':
618
+ mask = mask.convert('L')
619
+
620
+ # Resize to match target
621
+ if mask.size != target_size:
622
+ mask = mask.resize(target_size, Image.LANCZOS)
623
+
624
+ # Convert to array for processing
625
+ mask_array = np.array(mask)
626
+
627
+ # Calculate coverage
628
+ total_pixels = mask_array.size
629
+ white_pixels = np.count_nonzero(mask_array > 127)
630
+ coverage = white_pixels / total_pixels
631
+
632
+ validation_info = {
633
+ "coverage": coverage,
634
+ "white_pixels": white_pixels,
635
+ "total_pixels": total_pixels,
636
+ "feather_radius": feather,
637
+ "valid": True,
638
+ "warning": ""
639
+ }
640
+
641
+ # Validate coverage
642
+ if coverage < self.config.min_mask_coverage:
643
+ validation_info["valid"] = False
644
+ validation_info["warning"] = (
645
+ f"Mask coverage too low ({coverage:.1%}). "
646
+ f"Please select a larger area to inpaint."
647
+ )
648
+ logger.warning(f"Mask coverage {coverage:.1%} below minimum {self.config.min_mask_coverage:.1%}")
649
+
650
+ elif coverage > self.config.max_mask_coverage:
651
+ validation_info["valid"] = False
652
+ validation_info["warning"] = (
653
+ f"Mask coverage too high ({coverage:.1%}). "
654
+ f"Consider using background generation instead."
655
+ )
656
+ logger.warning(f"Mask coverage {coverage:.1%} above maximum {self.config.max_mask_coverage:.1%}")
657
+
658
+ # Apply feathering
659
+ if feather > 0:
660
+ mask_array = cv2.GaussianBlur(
661
+ mask_array,
662
+ (feather * 2 + 1, feather * 2 + 1),
663
+ feather / 2
664
+ )
665
+ logger.debug(f"Applied {feather}px feathering to mask")
666
+
667
+ processed_mask = Image.fromarray(mask_array, mode='L')
668
+
669
+ return processed_mask, validation_info
670
+
671
+ def enhance_prompt_for_inpainting(
672
+ self,
673
+ prompt: str,
674
+ image: Image.Image,
675
+ mask: Image.Image
676
+ ) -> Tuple[str, str]:
677
+ """
678
+ Enhance prompt based on non-masked region analysis.
679
+
680
+ Analyzes the surrounding context to generate appropriate
681
+ lighting and color descriptors.
682
+
683
+ Parameters
684
+ ----------
685
+ prompt : str
686
+ User-provided prompt
687
+ image : PIL.Image
688
+ Original image
689
+ mask : PIL.Image
690
+ Inpainting mask
691
+
692
+ Returns
693
+ -------
694
+ tuple
695
+ (enhanced_prompt, negative_prompt)
696
+ """
697
+ logger.info("Enhancing prompt for inpainting context...")
698
+
699
+ # Convert to arrays
700
+ img_array = np.array(image.convert('RGB'))
701
+ mask_array = np.array(mask.convert('L'))
702
+
703
+ # Analyze non-masked regions
704
+ non_masked = mask_array < 127
705
+
706
+ if not np.any(non_masked):
707
+ # No context available
708
+ enhanced_prompt = f"{prompt}, high quality, detailed, photorealistic"
709
+ negative_prompt = self._get_inpainting_negative_prompt()
710
+ return enhanced_prompt, negative_prompt
711
+
712
+ # Extract context pixels
713
+ context_pixels = img_array[non_masked]
714
+
715
+ # Convert to Lab for analysis
716
+ context_lab = cv2.cvtColor(
717
+ context_pixels.reshape(-1, 1, 3),
718
+ cv2.COLOR_RGB2LAB
719
+ ).reshape(-1, 3)
720
+
721
+ # Use robust statistics (median) to avoid outlier influence
722
+ median_l = np.median(context_lab[:, 0])
723
+ median_a = np.median(context_lab[:, 1])
724
+ median_b = np.median(context_lab[:, 2])
725
+
726
+ # Analyze lighting conditions
727
+ lighting_descriptors = []
728
+
729
+ if median_l > 170:
730
+ lighting_descriptors.append("bright")
731
+ elif median_l > 130:
732
+ lighting_descriptors.append("well-lit")
733
+ elif median_l > 80:
734
+ lighting_descriptors.append("moderate lighting")
735
+ else:
736
+ lighting_descriptors.append("dim lighting")
737
+
738
+ # Analyze color temperature (b channel: blue(-) to yellow(+))
739
+ if median_b > 140:
740
+ lighting_descriptors.append("warm golden tones")
741
+ elif median_b > 120:
742
+ lighting_descriptors.append("warm afternoon light")
743
+ elif median_b < 110:
744
+ lighting_descriptors.append("cool neutral tones")
745
+
746
+ # Calculate saturation from context
747
+ hsv = cv2.cvtColor(context_pixels.reshape(-1, 1, 3), cv2.COLOR_RGB2HSV)
748
+ median_saturation = np.median(hsv[:, :, 1])
749
+
750
+ if median_saturation > 150:
751
+ lighting_descriptors.append("vibrant colors")
752
+ elif median_saturation < 80:
753
+ lighting_descriptors.append("subtle muted colors")
754
+
755
+ # Build enhanced prompt
756
+ lighting_desc = ", ".join(lighting_descriptors) if lighting_descriptors else ""
757
+ quality_suffix = "high quality, detailed, photorealistic, seamless integration"
758
+
759
+ if lighting_desc:
760
+ enhanced_prompt = f"{prompt}, {lighting_desc}, {quality_suffix}"
761
+ else:
762
+ enhanced_prompt = f"{prompt}, {quality_suffix}"
763
+
764
+ negative_prompt = self._get_inpainting_negative_prompt()
765
+
766
+ logger.info(f"Enhanced prompt with context: {lighting_desc}")
767
+
768
+ return enhanced_prompt, negative_prompt
769
+
770
+ def _get_inpainting_negative_prompt(self) -> str:
771
+ """Get standard negative prompt for inpainting."""
772
+ return (
773
+ "inconsistent lighting, wrong perspective, mismatched colors, "
774
+ "visible seams, blending artifacts, color bleeding, "
775
+ "blurry, low quality, distorted, deformed, "
776
+ "harsh edges, unnatural transition"
777
+ )
778
+
779
+ def execute_inpainting(
780
+ self,
781
+ image: Image.Image,
782
+ mask: Image.Image,
783
+ prompt: str,
784
+ preview_only: bool = False,
785
+ seed: Optional[int] = None,
786
+ progress_callback: Optional[Callable[[str, int], None]] = None,
787
+ **kwargs
788
+ ) -> InpaintingResult:
789
+ """
790
+ Execute the inpainting operation.
791
+
792
+ Implements two-stage generation: fast preview followed by
793
+ full quality generation if requested.
794
+
795
+ Parameters
796
+ ----------
797
+ image : PIL.Image
798
+ Original image to inpaint
799
+ mask : PIL.Image
800
+ Inpainting mask (white = area to regenerate)
801
+ prompt : str
802
+ Text description of desired content
803
+ preview_only : bool
804
+ If True, only generate preview (faster)
805
+ seed : int, optional
806
+ Random seed for reproducibility
807
+ progress_callback : callable, optional
808
+ Progress update function(message, percentage)
809
+ **kwargs
810
+ Additional parameters:
811
+ - controlnet_conditioning_scale: float
812
+ - feather_radius: int
813
+ - num_inference_steps: int
814
+ - guidance_scale: float
815
+
816
+ Returns
817
+ -------
818
+ InpaintingResult
819
+ Result container with generated images and metadata
820
+ """
821
+ start_time = time.time()
822
+
823
+ if not self.is_initialized:
824
+ return InpaintingResult(
825
+ success=False,
826
+ error_message="Inpainting pipeline not initialized. Call load_inpainting_pipeline() first."
827
+ )
828
+
829
+ logger.info(f"Starting inpainting: prompt='{prompt[:50]}...', preview_only={preview_only}")
830
+
831
+ try:
832
+ # Update config with kwargs
833
+ conditioning_scale = kwargs.get(
834
+                'controlnet_conditioning_scale',
+                self.config.controlnet_conditioning_scale
+            )
+            feather_radius = kwargs.get('feather_radius', self.config.feather_radius)
+
+            if progress_callback:
+                progress_callback("Preparing images...", 5)
+
+            # Prepare image
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
+
+            # Ensure dimensions are multiple of 8
+            width, height = image.size
+            new_width = (width // 8) * 8
+            new_height = (height // 8) * 8
+
+            if new_width != width or new_height != height:
+                image = image.resize((new_width, new_height), Image.LANCZOS)
+
+            # Check and potentially reduce resolution for memory
+            max_res = self.config.max_resolution
+            if max(new_width, new_height) > max_res:
+                scale = max_res / max(new_width, new_height)
+                new_width = int(new_width * scale) // 8 * 8
+                new_height = int(new_height * scale) // 8 * 8
+                image = image.resize((new_width, new_height), Image.LANCZOS)
+                logger.info(f"Reduced resolution to {new_width}x{new_height} for memory")
+
+            # Prepare mask
+            if progress_callback:
+                progress_callback("Processing mask...", 10)
+
+            processed_mask, mask_info = self.prepare_mask(
+                mask,
+                (new_width, new_height),
+                feather_radius
+            )
+
+            if not mask_info["valid"]:
+                return InpaintingResult(
+                    success=False,
+                    error_message=mask_info["warning"]
+                )
+
+            # Generate control image
+            if progress_callback:
+                progress_callback("Generating control image...", 20)
+
+            control_image = self.prepare_control_image(
+                image,
+                self._current_conditioning_type
+            )
+
+            # Enhance prompt
+            if progress_callback:
+                progress_callback("Enhancing prompt...", 25)
+
+            enhanced_prompt, negative_prompt = self.enhance_prompt_for_inpainting(
+                prompt, image, processed_mask
+            )
+
+            # Setup generator for reproducibility
+            if seed is None:
+                seed = int(time.time() * 1000) % (2**32)
+            self._last_seed = seed
+            generator = torch.Generator(device=self.device).manual_seed(seed)
+
+            # Stage 1: Preview generation
+            if progress_callback:
+                progress_callback("Generating preview...", 30)
+
+            preview_result = self._generate_inpaint(
+                image=image,
+                mask=processed_mask,
+                control_image=control_image,
+                prompt=enhanced_prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=self.config.preview_steps,
+                guidance_scale=self.config.preview_guidance_scale,
+                controlnet_conditioning_scale=conditioning_scale,
+                generator=generator
+            )
+
+            if preview_only:
+                generation_time = time.time() - start_time
+
+                return InpaintingResult(
+                    success=True,
+                    preview_image=preview_result,
+                    control_image=control_image,
+                    generation_time=generation_time,
+                    metadata={
+                        "seed": seed,
+                        "prompt": enhanced_prompt,
+                        "conditioning_type": self._current_conditioning_type,
+                        "conditioning_scale": conditioning_scale,
+                        "preview_only": True
+                    }
+                )
+
+            # Stage 2: Full quality generation
+            if progress_callback:
+                progress_callback("Generating full quality...", 60)
+
+            # Use same seed for reproducibility
+            generator = torch.Generator(device=self.device).manual_seed(seed)
+
+            num_steps = kwargs.get('num_inference_steps', self.config.num_inference_steps)
+            guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
+
+            full_result = self._generate_inpaint(
+                image=image,
+                mask=processed_mask,
+                control_image=control_image,
+                prompt=enhanced_prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_steps,
+                guidance_scale=guidance,
+                controlnet_conditioning_scale=conditioning_scale,
+                generator=generator
+            )
+
+            if progress_callback:
+                progress_callback("Blending result...", 90)
+
+            # Blend result
+            blended = self.blend_result(image, full_result, processed_mask)
+
+            generation_time = time.time() - start_time
+
+            if progress_callback:
+                progress_callback("Complete!", 100)
+
+            return InpaintingResult(
+                success=True,
+                result_image=full_result,
+                preview_image=preview_result,
+                control_image=control_image,
+                blended_image=blended,
+                generation_time=generation_time,
+                metadata={
+                    "seed": seed,
+                    "prompt": enhanced_prompt,
+                    "negative_prompt": negative_prompt,
+                    "conditioning_type": self._current_conditioning_type,
+                    "conditioning_scale": conditioning_scale,
+                    "num_inference_steps": num_steps,
+                    "guidance_scale": guidance,
+                    "feather_radius": feather_radius,
+                    "mask_coverage": mask_info["coverage"],
+                    "preview_only": False
+                }
+            )
+
+        except torch.cuda.OutOfMemoryError:
+            logger.error("CUDA out of memory during inpainting")
+            self._memory_cleanup(aggressive=True)
+            return InpaintingResult(
+                success=False,
+                error_message="GPU memory exhausted. Try reducing image size or closing other applications."
+            )
+
+        except Exception as e:
+            logger.error(f"Inpainting failed: {e}")
+            logger.error(traceback.format_exc())
+            return InpaintingResult(
+                success=False,
+                error_message=f"Inpainting failed: {str(e)}"
+            )
+
+    def _generate_inpaint(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        control_image: Image.Image,
+        prompt: str,
+        negative_prompt: str,
+        num_inference_steps: int,
+        guidance_scale: float,
+        controlnet_conditioning_scale: float,
+        generator: torch.Generator
+    ) -> Image.Image:
+        """
+        Internal method to run the inpainting pipeline.
+
+        Supports both ControlNet and non-ControlNet pipelines.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Processed mask
+        control_image : PIL.Image
+            ControlNet conditioning image (ignored if ControlNet not available)
+        prompt : str
+            Enhanced prompt
+        negative_prompt : str
+            Negative prompt
+        num_inference_steps : int
+            Number of denoising steps
+        guidance_scale : float
+            Classifier-free guidance scale
+        controlnet_conditioning_scale : float
+            ControlNet influence strength (ignored if ControlNet not available)
+        generator : torch.Generator
+            Random generator for reproducibility
+
+        Returns
+        -------
+        PIL.Image
+            Generated image
+        """
+        with torch.inference_mode():
+            if self._use_controlnet:
+                # Full ControlNet inpainting pipeline
+                result = self._inpaint_pipeline(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    image=image,
+                    mask_image=mask,
+                    control_image=control_image,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    controlnet_conditioning_scale=controlnet_conditioning_scale,
+                    generator=generator
+                )
+            else:
+                # Fallback: Standard SDXL inpainting without ControlNet
+                result = self._inpaint_pipeline(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    image=image,
+                    mask_image=mask,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    generator=generator
+                )
+
+        return result.images[0]
+
+    def blend_result(
+        self,
+        original: Image.Image,
+        generated: Image.Image,
+        mask: Image.Image
+    ) -> Image.Image:
+        """
+        Blend generated content with original image.
+
+        Uses linear color space blending for accurate results.
+
+        Parameters
+        ----------
+        original : PIL.Image
+            Original image
+        generated : PIL.Image
+            Generated inpainted image
+        mask : PIL.Image
+            Blending mask (white = use generated)
+
+        Returns
+        -------
+        PIL.Image
+            Blended result
+        """
+        logger.info("Blending inpainting result...")
+
+        # Ensure same size
+        if generated.size != original.size:
+            generated = generated.resize(original.size, Image.LANCZOS)
+        if mask.size != original.size:
+            mask = mask.resize(original.size, Image.LANCZOS)
+
+        # Convert to arrays
+        orig_array = np.array(original.convert('RGB')).astype(np.float32)
+        gen_array = np.array(generated.convert('RGB')).astype(np.float32)
+        mask_array = np.array(mask.convert('L')).astype(np.float32) / 255.0
+
+        # sRGB to linear conversion
+        def srgb_to_linear(img):
+            img_norm = img / 255.0
+            return np.where(
+                img_norm <= 0.04045,
+                img_norm / 12.92,
+                np.power((img_norm + 0.055) / 1.055, 2.4)
+            )
+
+        def linear_to_srgb(img):
+            img_clipped = np.clip(img, 0, 1)
+            return np.where(
+                img_clipped <= 0.0031308,
+                12.92 * img_clipped,
+                1.055 * np.power(img_clipped, 1/2.4) - 0.055
+            )
+
+        # Convert to linear space
+        orig_linear = srgb_to_linear(orig_array)
+        gen_linear = srgb_to_linear(gen_array)
+
+        # Alpha blending in linear space
+        alpha = mask_array[:, :, np.newaxis]
+        result_linear = gen_linear * alpha + orig_linear * (1 - alpha)
+
+        # Convert back to sRGB
+        result_srgb = linear_to_srgb(result_linear)
+        result_array = (result_srgb * 255).astype(np.uint8)
+
+        logger.debug("Blending completed in linear color space")
+
+        return Image.fromarray(result_array)
+
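A note on why blend_result round-trips through linear light: averaging gamma-encoded sRGB values directly darkens feathered edges, so the mask is applied only after decoding. In LaTeX form, this restates exactly what the code above computes (the standard IEC 61966-2-1 transfer functions), with O and G the original and generated images and alpha the normalized mask:

    % per-pixel alpha blend performed in linear light
    C_{\mathrm{lin}} = \alpha\, G_{\mathrm{lin}} + (1 - \alpha)\, O_{\mathrm{lin}},
    \qquad
    \mathrm{lin}(c) =
    \begin{cases}
      c / 12.92, & c \le 0.04045 \\
      \left(\dfrac{c + 0.055}{1.055}\right)^{2.4}, & c > 0.04045
    \end{cases}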
+    def execute_with_auto_optimization(
+        self,
+        image: Image.Image,
+        mask: Image.Image,
+        prompt: str,
+        quality_checker: Any,
+        progress_callback: Optional[Callable[[str, int], None]] = None,
+        **kwargs
+    ) -> InpaintingResult:
+        """
+        Execute inpainting with automatic quality-based optimization.
+
+        Retries with adjusted parameters if quality score is below threshold.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image
+        mask : PIL.Image
+            Inpainting mask
+        prompt : str
+            Text prompt
+        quality_checker : QualityChecker
+            Quality assessment instance
+        progress_callback : callable, optional
+            Progress update function
+        **kwargs
+            Additional inpainting parameters
+
+        Returns
+        -------
+        InpaintingResult
+            Best result achieved (may include retry information)
+        """
+        if not self.config.enable_auto_optimization:
+            return self.execute_inpainting(
+                image, mask, prompt,
+                progress_callback=progress_callback,
+                **kwargs
+            )
+
+        best_result = None
+        best_score = 0.0
+        retry_count = 0
+        prev_score = 0.0
+
+        # Mutable parameters for optimization
+        current_feather = kwargs.get('feather_radius', self.config.feather_radius)
+        current_scale = kwargs.get(
+            'controlnet_conditioning_scale',
+            self.config.controlnet_conditioning_scale
+        )
+        current_guidance = kwargs.get('guidance_scale', self.config.guidance_scale)
+        current_prompt = prompt
+
+        while retry_count <= self.config.max_optimization_retries:
+            if progress_callback and retry_count > 0:
+                progress_callback(f"Optimizing (attempt {retry_count + 1})...", 5)
+
+            # Execute inpainting
+            result = self.execute_inpainting(
+                image, mask, current_prompt,
+                preview_only=False,
+                feather_radius=current_feather,
+                controlnet_conditioning_scale=current_scale,
+                guidance_scale=current_guidance,
+                progress_callback=progress_callback if retry_count == 0 else None,
+                **{k: v for k, v in kwargs.items()
+                   if k not in ['feather_radius', 'controlnet_conditioning_scale',
+                                'guidance_scale']}
+            )
+
+            if not result.success:
+                return result
+
+            # Evaluate quality
+            if result.blended_image is not None:
+                quality_results = quality_checker.run_all_checks(
+                    foreground=image,
+                    background=result.result_image,
+                    mask=mask,
+                    combined=result.blended_image
+                )
+                quality_score = quality_results.get("overall_score", 0)
+            else:
+                quality_results = {}  # Keep defined so the checks lookup below never raises NameError
+                quality_score = 50.0  # Default if no blended image
+
+            result.quality_score = quality_score
+            result.quality_details = quality_results
+            result.retries = retry_count
+
+            logger.info(f"Quality score: {quality_score:.1f} (attempt {retry_count + 1})")
+
+            # Track best result
+            if quality_score > best_score:
+                best_score = quality_score
+                best_result = result
+
+            # Check if quality is acceptable
+            if quality_score >= self.config.min_quality_score:
+                logger.info(f"Quality threshold met: {quality_score:.1f}")
+                return best_result
+
+            # Check for minimal improvement (early termination)
+            if retry_count > 0 and abs(quality_score - prev_score) < 5.0:
+                logger.info("Minimal improvement, stopping optimization")
+                return best_result
+
+            prev_score = quality_score
+            retry_count += 1
+
+            if retry_count > self.config.max_optimization_retries:
+                break
+
+            # Adjust parameters based on quality issues
+            checks = quality_results.get("checks", {})
+
+            edge_score = checks.get("edge_continuity", {}).get("score", 100)
+            harmony_score = checks.get("color_harmony", {}).get("score", 100)
+
+            if edge_score < 60:
+                # Edge issues: increase feathering, decrease control strength
+                current_feather = min(20, current_feather + 3)
+                current_scale = max(0.5, current_scale - 0.1)
+                logger.debug(f"Adjusting for edges: feather={current_feather}, scale={current_scale}")
+
+            if harmony_score < 60:
+                # Color harmony issues: emphasize consistency in prompt
+                if "color consistent" not in current_prompt.lower():
+                    current_prompt = f"{current_prompt}, color consistent with surroundings, matching lighting"
+                current_guidance = min(12.0, current_guidance + 1.0)
+                logger.debug(f"Adjusting for harmony: guidance={current_guidance}")
+
+            if edge_score < 60 and harmony_score < 60:
+                # Both issues: stronger guidance
+                current_guidance = min(12.0, current_guidance + 1.5)
+
+        logger.info(f"Optimization complete. Best score: {best_score:.1f}")
+        return best_result
+
+    def get_status(self) -> Dict[str, Any]:
+        """
+        Get current module status.
+
+        Returns
+        -------
+        dict
+            Status information including initialization state and memory usage
+        """
+        status = {
+            "initialized": self.is_initialized,
+            "device": self.device,
+            "conditioning_type": self._current_conditioning_type,
+            "last_seed": self._last_seed,
+            "config": {
+                "controlnet_conditioning_scale": self.config.controlnet_conditioning_scale,
+                "feather_radius": self.config.feather_radius,
+                "num_inference_steps": self.config.num_inference_steps,
+                "guidance_scale": self.config.guidance_scale
+            }
+        }
+
+        status["memory"] = self._check_memory_status()
+
+        return status
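Taken together, execute_inpainting is a two-stage call: a fast low-step preview (preview_only=True) followed by a full-quality pass that reuses the same seed for a reproducible layout. A minimal usage sketch, assuming the surrounding class is instantiated as engine (the class name, its constructor, and the file names below are placeholders not shown in this diff):

    from PIL import Image

    engine = ...  # hypothetical: an initialized inpainting module exposing execute_inpainting()

    image = Image.open("scene.png")  # placeholder input image
    mask = Image.open("mask.png")    # white = region to repaint

    # Stage 1: quick preview using the config's preview_steps
    preview = engine.execute_inpainting(image, mask, "a potted fern", preview_only=True)
    if preview.success:
        seed = preview.metadata["seed"]
        # Stage 2: full-quality pass; the same seed reproduces the preview composition
        final = engine.execute_inpainting(image, mask, "a potted fern", seed=seed)
        if final.success:
            final.blended_image.save("result.png")
        else:
            print(final.error_message)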
inpainting_templates.py ADDED
@@ -0,0 +1,707 @@
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class InpaintingTemplate:
+     """Data class representing an inpainting template."""
+
+     key: str
+     name: str
+     category: str
+     icon: str
+     description: str
+
+     # Prompt templates
+     prompt_template: str
+     negative_prompt: str
+
+     # Recommended parameters
+     controlnet_conditioning_scale: float = 0.7
+     feather_radius: int = 8
+     guidance_scale: float = 7.5
+     num_inference_steps: int = 25
+
+     # Conditioning type preference
+     preferred_conditioning: str = "canny"  # "canny" or "depth"
+
+     # Tips for users
+     usage_tips: List[str] = field(default_factory=list)
+
+
+ class InpaintingTemplateManager:
+     """
+     Manages inpainting templates for various use cases.
+
+     Provides categorized presets optimized for different inpainting scenarios
+     including object replacement, removal, style transfer, and enhancement.
+
+     Attributes:
+         TEMPLATES: Dictionary of all available templates
+         CATEGORIES: List of category names in display order
+
+     Example:
+         >>> manager = InpaintingTemplateManager()
+         >>> template = manager.get_template("object_replacement")
+         >>> print(template.prompt_template)
+     """
+
+     TEMPLATES: Dict[str, InpaintingTemplate] = {
+         # Object Replacement Category
+         "object_replacement": InpaintingTemplate(
+             key="object_replacement",
+             name="Object Replacement",
+             category="Replacement",
+             icon="🔄",
+             description="Replace selected objects with new content while preserving context",
+             prompt_template="{content}, seamlessly integrated into scene, matching lighting and perspective, realistic placement",
+             negative_prompt=(
+                 "inconsistent lighting, wrong perspective, mismatched colors, "
+                 "visible seams, floating objects, unrealistic placement, original object, "
+                 "poorly integrated, disconnected from scene, keeping original"
+             ),
+             controlnet_conditioning_scale=0.48,
+             feather_radius=14,
+             guidance_scale=10.0,
+             num_inference_steps=38,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw mask PRECISELY around the object to replace (include small margin)",
+                 "Be specific: Instead of 'plant', use 'large green potted fern with detailed leaves'",
+                 "📝 Example prompts:",
+                 "  • 'elegant white ceramic teacup with delicate gold rim and floral pattern'",
+                 "  • 'modern silver laptop computer with sleek metallic finish'",
+                 "  • 'vintage wooden desk lamp with warm brass details'"
+             ]
+         ),
+
+         "face_swap": InpaintingTemplate(
+             key="face_swap",
+             name="Face Enhancement",
+             category="Replacement",
+             icon="👤",
+             description="Enhance or modify facial features while maintaining identity cues",
+             prompt_template="{content}, natural skin texture, proper facial proportions, realistic lighting, detailed facial features",
+             negative_prompt=(
+                 "deformed face, asymmetric features, unnatural skin, "
+                 "plastic appearance, wrong eye direction, blurry features, "
+                 "artificial smoothing, uncanny valley, distorted proportions"
+             ),
+             controlnet_conditioning_scale=0.88,
+             feather_radius=6,
+             guidance_scale=8.5,
+             num_inference_steps=35,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw mask CAREFULLY around face outline (avoid hair and neck)",
+                 "High conditioning preserves facial structure - good for subtle enhancements",
+                 "📝 Example prompts:",
+                 "  • 'warm friendly smile with bright eyes and natural expression'",
+                 "  • 'professional headshot with confident neutral expression'",
+                 "  • 'gentle smile with soft natural lighting on face'"
+             ]
+         ),
+
+         "clothing_change": InpaintingTemplate(
+             key="clothing_change",
+             name="Clothing Change",
+             category="Replacement",
+             icon="👕",
+             description="Change clothing color, pattern, style, and material (moderate changes)",
+             prompt_template="transform to {content}, vivid fabric texture, natural folds and wrinkles, correct fit, proper color saturation",
+             negative_prompt=(
+                 "wrong body proportions, floating fabric, unrealistic wrinkles, "
+                 "mismatched lighting, visible edges, original clothing style, "
+                 "keeping same color, faded colors, unchanged appearance, partial change"
+             ),
+             controlnet_conditioning_scale=0.42,
+             feather_radius=14,
+             guidance_scale=10.5,
+             num_inference_steps=38,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "Best for: Similar color changes (black→navy, white→cream) and pattern changes",
+                 "For extreme changes (black↔white), use 'Dramatic Color Change' template instead",
+                 "📝 Example prompts:",
+                 "  • 'deep navy blue formal blazer with fine texture and professional fit'",
+                 "  • 'bright red polo shirt with clean collar and soft fabric'",
+                 "  • 'charcoal gray sweater with ribbed knit texture'"
+             ]
+         ),
+
+         "dramatic_color_change": InpaintingTemplate(
+             key="dramatic_color_change",
+             name="Dramatic Color Change",
+             category="Replacement",
+             icon="🎨",
+             description="Extreme color transformations (dark↔light, black↔white)",
+             prompt_template="completely transform to {content}, vivid color saturation, high contrast, sharp color definition, clear distinct color",
+             negative_prompt=(
+                 "original color, keeping same shade, partial change, color bleeding, "
+                 "faded colors, mixed tones, subtle change, gradual transition, "
+                 "original appearance, unchanged, dark remnants, light patches"
+             ),
+             controlnet_conditioning_scale=0.38,
+             feather_radius=16,
+             guidance_scale=12.0,
+             num_inference_steps=42,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "Optimized for: black→white, white→black, dark→light extreme transformations",
+                 "Use VERY vivid descriptors: 'PURE WHITE', 'JET BLACK', 'BRIGHT crimson red'",
+                 "📝 Example prompts:",
+                 "  • 'pure white dress shirt with crisp texture and sharp collar'",
+                 "  • 'jet black leather jacket with smooth matte finish'",
+                 "  • 'bright crimson red blazer with vivid saturated color'"
+             ]
+         ),
+
+         "clothing_addition": InpaintingTemplate(
+             key="clothing_addition",
+             name="Add Accessories",
+             category="Replacement",
+             icon="👔",
+             description="Add ties, pockets, buttons, or accessories to clothing (Advanced)",
+             prompt_template="{content}, clearly visible, highly detailed accessory, seamlessly integrated into clothing, proper placement and perspective",
+             negative_prompt=(
+                 "missing details, incomplete, floating objects, disconnected, "
+                 "unrealistic placement, wrong perspective, blurry, poorly integrated, "
+                 "invisible, faint, unclear, hidden, absent, not visible"
+             ),
+             controlnet_conditioning_scale=0.25,
+             feather_radius=12,
+             guidance_scale=14.0,
+             num_inference_steps=50,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "⚡ Advanced feature: For adding neckties, pocket squares, badges, or buttons",
+                 "Draw mask from NECK to CHEST (vertical strip) for ties, not just collar area",
+                 "📝 Example prompts:",
+                 "  • 'burgundy silk necktie with diagonal stripes and Windsor knot, hanging down from collar to chest'",
+                 "  • 'white pocket square with neat fold, visible in breast pocket'",
+                 "  • 'silver lapel pin with detailed engraving on left collar'",
+                 "⚠️ TIP: For ties, mask should cover the entire length where tie should appear"
+             ]
+         ),
+
+         # Object Removal Category
+         "object_removal": InpaintingTemplate(
+             key="object_removal",
+             name="Object Removal",
+             category="Removal",
+             icon="🗑️",
+             description="Remove unwanted objects and fill with matching background",
+             prompt_template="clean background, seamless continuation, {content}",
+             negative_prompt=(
+                 "visible patches, color mismatch, texture inconsistency, "
+                 "ghost artifacts, blur spots, repeated patterns, visible seams"
+             ),
+             controlnet_conditioning_scale=0.48,
+             feather_radius=14,
+             guidance_scale=8.0,
+             num_inference_steps=30,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw mask slightly BEYOND object edges for better blending",
+                 "Describe what background SHOULD look like (e.g., 'grass lawn', 'wooden floor')",
+                 "📝 Example prompts:",
+                 "  • 'clean grass lawn with natural green color'",
+                 "  • 'smooth wooden floor with consistent grain pattern'",
+                 "  • 'plain white wall with even texture'"
+             ]
+         ),
+
+         "watermark_removal": InpaintingTemplate(
+             key="watermark_removal",
+             name="Watermark Removal",
+             category="Removal",
+             icon="💧",
+             description="Remove watermarks and text overlays",
+             prompt_template="clean image, no text, seamless background, {content}",
+             negative_prompt=(
+                 "text, watermark, logo, signature, letters, numbers, visible artifacts, "
+                 "color inconsistency, blur, remnants, ghost text"
+             ),
+             controlnet_conditioning_scale=0.45,
+             feather_radius=12,
+             guidance_scale=8.5,
+             num_inference_steps=30,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw mask covering ALL watermark/text areas precisely",
+                 "Describe what SHOULD be there instead (e.g., 'sky', 'fabric texture')",
+                 "📝 Example prompts:",
+                 "  • 'clean blue sky with smooth gradient'",
+                 "  • 'natural skin texture without marks'",
+                 "  • 'smooth fabric surface with consistent color'"
+             ]
+         ),
+
+         "blemish_removal": InpaintingTemplate(
+             key="blemish_removal",
+             name="Blemish Removal",
+             category="Removal",
+             icon="✨",
+             description="Remove skin blemishes, scratches, or small imperfections",
+             prompt_template="clean smooth surface, natural texture, {content}",
+             negative_prompt=(
+                 "artificial smoothing, plastic texture, visible editing, "
+                 "color patches, unnatural appearance, over-processed"
+             ),
+             controlnet_conditioning_scale=0.6,
+             feather_radius=6,
+             guidance_scale=6.5,
+             num_inference_steps=20,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw small precise masks for EACH blemish/imperfection",
+                 "Lower guidance (6.5) preserves natural skin texture",
+                 "📝 Example prompts:",
+                 "  • 'natural clean skin with smooth texture'",
+                 "  • 'smooth surface without scratches or marks'",
+                 "  • 'clear skin with natural pores and texture'"
+             ]
+         ),
+
+         # Style Transfer Category
+         "style_artistic": InpaintingTemplate(
+             key="style_artistic",
+             name="Artistic Style",
+             category="Style",
+             icon="🎨",
+             description="Apply artistic style to selected region",
+             prompt_template="{content}, distinctive artistic style, strong painterly effect, creative interpretation, visible brushstrokes",
+             negative_prompt=(
+                 "photorealistic, plain, boring, low contrast, unchanged, "
+                 "inconsistent style, harsh transitions, original appearance, realistic photo"
+             ),
+             controlnet_conditioning_scale=0.52,
+             feather_radius=12,
+             guidance_scale=11.5,
+             num_inference_steps=38,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Works best on larger areas (faces, clothing, backgrounds) for visible transformation",
+                 "Be VERY specific about art style for best results",
+                 "📝 Example prompts:",
+                 "  • 'impressionist oil painting with visible thick brushstrokes and vibrant colors'",
+                 "  • 'watercolor painting with soft edges and delicate color washes'",
+                 "  • 'Van Gogh style with swirling brushstrokes and bold color contrasts'"
+             ]
+         ),
+
+         "style_vintage": InpaintingTemplate(
+             key="style_vintage",
+             name="Vintage Look",
+             category="Style",
+             icon="📻",
+             description="Apply vintage or retro aesthetic to selected area",
+             prompt_template="{content}, strong vintage aesthetic, warm sepia tones, film grain texture, nostalgic atmosphere",
+             negative_prompt=(
+                 "modern, digital, cold colors, harsh contrast, "
+                 "oversaturated, neon colors, contemporary look, clean digital, crisp"
+             ),
+             controlnet_conditioning_scale=0.55,
+             feather_radius=14,
+             guidance_scale=10.5,
+             num_inference_steps=35,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Works best on medium to large regions for visible aesthetic change",
+                 "Specify era and style for best results",
+                 "📝 Example prompts:",
+                 "  • '1920s sepia photograph with faded brown tones and soft grain'",
+                 "  • '1970s vintage photo with warm orange tones and slight film grain'",
+                 "  • '1950s Kodachrome with saturated warm colors and nostalgic feel'"
+             ]
+         ),
+
+         "style_anime": InpaintingTemplate(
+             key="style_anime",
+             name="Anime Style",
+             category="Style",
+             icon="🎌",
+             description="Transform selected region to anime/illustration style",
+             prompt_template="{content}, anime illustration style, clean sharp lines, vibrant saturated colors, cel-shaded with flat colors",
+             negative_prompt=(
+                 "photorealistic, blurry lines, muddy colors, realistic photo, "
+                 "3D render, uncanny valley, western cartoon, gradient shading, photographic"
+             ),
+             controlnet_conditioning_scale=0.48,
+             feather_radius=10,
+             guidance_scale=12.5,
+             num_inference_steps=40,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "⚡ DRAMATIC transformation - best for portraits and characters",
+                 "Expect significant stylistic changes from realistic to anime",
+                 "📝 Example prompts:",
+                 "  • 'modern anime style with large expressive eyes and vibrant colors'",
+                 "  • 'Studio Ghibli style with soft features and warm color palette'",
+                 "  • 'manga style with clean black lines and cel-shaded coloring'"
+             ]
+         ),
+
+         # Detail Enhancement Category
+         "detail_enhance": InpaintingTemplate(
+             key="detail_enhance",
+             name="Detail Enhancement",
+             category="Enhancement",
+             icon="🔍",
+             description="Add fine details and textures to selected area",
+             prompt_template="{content}, highly detailed, intricate textures, fine details, sharp focus",
+             negative_prompt=(
+                 "blurry, smooth, low detail, soft focus, "
+                 "oversimplified, lacking texture"
+             ),
+             controlnet_conditioning_scale=0.85,
+             feather_radius=4,
+             guidance_scale=8.0,
+             num_inference_steps=30,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "High conditioning (0.85) preserves overall structure while adding detail",
+                 "Best for adding fine details to existing objects",
+                 "📝 Example prompts:",
+                 "  • 'highly detailed fabric with visible weave and fine threads'",
+                 "  • 'intricate wood grain with natural knots and detailed texture'",
+                 "  • 'sharp facial features with fine skin pores and detail'"
+             ]
+         ),
+
+         "texture_add": InpaintingTemplate(
+             key="texture_add",
+             name="Texture Addition",
+             category="Enhancement",
+             icon="🧱",
+             description="Add or enhance surface textures",
+             prompt_template="{content} texture, realistic surface detail, natural material appearance",
+             negative_prompt=(
+                 "flat, smooth, unrealistic, plastic, "
+                 "wrong material, inconsistent texture"
+             ),
+             controlnet_conditioning_scale=0.8,
+             feather_radius=5,
+             guidance_scale=7.5,
+             num_inference_steps=25,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "Specify material type clearly for best results",
+                 "Depth conditioning preserves 3D form while changing texture",
+                 "📝 Example prompts:",
+                 "  • 'rough wood texture with natural grain and knots'",
+                 "  • 'soft cotton fabric with gentle weave pattern'",
+                 "  • 'smooth marble surface with subtle veining'"
+             ]
+         ),
+
+         "lighting_fix": InpaintingTemplate(
+             key="lighting_fix",
+             name="Lighting Correction",
+             category="Enhancement",
+             icon="💡",
+             description="Correct or enhance lighting in selected area",
+             prompt_template="{content}, proper lighting, natural shadows, balanced exposure",
+             negative_prompt=(
+                 "harsh shadows, overexposed, underexposed, "
+                 "flat lighting, unnatural highlights"
+             ),
+             controlnet_conditioning_scale=0.65,
+             feather_radius=15,
+             guidance_scale=7.0,
+             num_inference_steps=25,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "Use large feather (15px) for smooth lighting transitions",
+                 "Best for fixing uneven lighting or adding natural light",
+                 "📝 Example prompts:",
+                 "  • 'soft natural lighting from window, gentle shadows'",
+                 "  • 'balanced exposure with warm golden hour light'",
+                 "  • 'even studio lighting with soft diffused shadows'"
+             ]
+         ),
+
+         # Background Category
+         "background_extend": InpaintingTemplate(
+             key="background_extend",
+             name="Background Extension",
+             category="Background",
+             icon="📐",
+             description="Extend image background seamlessly",
+             prompt_template="seamless background extension, {content}, consistent style and lighting",
+             negative_prompt=(
+                 "visible seams, style mismatch, lighting inconsistency, "
+                 "repeated elements, unnatural continuation, abrupt changes"
+             ),
+             controlnet_conditioning_scale=0.55,
+             feather_radius=20,
+             guidance_scale=8.0,
+             num_inference_steps=32,
+             preferred_conditioning="canny",
+             usage_tips=[
+                 "Draw mask on area to extend (edges of image)",
+                 "Large feather (20px) ensures smooth blending with existing background",
+                 "📝 Example prompts:",
+                 "  • 'continue the wooden floor with same grain pattern'",
+                 "  • 'extend blue sky with matching clouds and lighting'",
+                 "  • 'seamless continuation of brick wall texture'"
+             ]
+         ),
+
+         "background_replace": InpaintingTemplate(
+             key="background_replace",
+             name="Background Replacement",
+             category="Background",
+             icon="🖼️",
+             description="Replace background while keeping subject intact",
+             prompt_template="{content}, professional background scene, seamless integration with subject, matching lighting and atmosphere",
+             negative_prompt=(
+                 "floating subject, inconsistent lighting, disconnected, "
+                 "wrong perspective, visible edges, color mismatch, original background, "
+                 "poor integration, obvious composite"
+             ),
+             controlnet_conditioning_scale=0.60,
+             feather_radius=12,
+             guidance_scale=9.5,
+             num_inference_steps=35,
+             preferred_conditioning="depth",
+             usage_tips=[
+                 "Draw mask around ENTIRE background (leave subject unmasked with small margin)",
+                 "Include lighting description to match subject for natural results",
+                 "📝 Example prompts:",
+                 "  • 'professional photography studio with white backdrop and soft lighting'",
+                 "  • 'modern minimalist office with white walls and bright natural lighting'",
+                 "  • 'sunny beach with blue ocean and golden hour lighting'"
+             ]
+         ),
+     }
+
+     # Category display order
+     CATEGORIES = ["Replacement", "Removal", "Style", "Enhancement", "Background"]
+
+     def __init__(self):
+         """Initialize the InpaintingTemplateManager."""
+         logger.info(f"InpaintingTemplateManager initialized with {len(self.TEMPLATES)} templates")
+
+     def get_all_templates(self) -> Dict[str, InpaintingTemplate]:
+         """
+         Get all available templates.
+
+         Returns
+         -------
+         dict
+             Dictionary of all templates keyed by template key
+         """
+         return self.TEMPLATES
+
+     def get_template(self, key: str) -> Optional[InpaintingTemplate]:
+         """
+         Get a specific template by key.
+
+         Parameters
+         ----------
+         key : str
+             Template identifier
+
+         Returns
+         -------
+         InpaintingTemplate or None
+             Template if found, None otherwise
+         """
+         return self.TEMPLATES.get(key)
+
+     def get_templates_by_category(self, category: str) -> List[InpaintingTemplate]:
+         """
+         Get all templates in a specific category.
+
+         Parameters
+         ----------
+         category : str
+             Category name
+
+         Returns
+         -------
+         list
+             List of templates in the category
+         """
+         return [t for t in self.TEMPLATES.values() if t.category == category]
+
+     def get_categories(self) -> List[str]:
+         """
+         Get list of all categories in display order.
+
+         Returns
+         -------
+         list
+             Category names
+         """
+         return self.CATEGORIES
+
+     def get_template_choices_sorted(self) -> List[str]:
+         """
+         Get template choices formatted for Gradio dropdown.
+
+         Returns list of display strings sorted by category then A-Z.
+         Format: "icon Name"
+
+         Returns
+         -------
+         list
+             Formatted display strings for dropdown
+         """
+         display_list = []
+
+         for category in self.CATEGORIES:
+             templates = self.get_templates_by_category(category)
+             for template in sorted(templates, key=lambda t: t.name):
+                 display_name = f"{template.icon} {template.name}"
+                 display_list.append(display_name)
+
+         return display_list
+
+     def get_template_key_from_display(self, display_name: str) -> Optional[str]:
+         """
+         Get template key from display name.
+
+         Parameters
+         ----------
+         display_name : str
+             Display string like "🔄 Object Replacement"
+
+         Returns
+         -------
+         str or None
+             Template key if found
+         """
+         if not display_name:
+             return None
+
+         for key, template in self.TEMPLATES.items():
+             if f"{template.icon} {template.name}" == display_name:
+                 return key
+         return None
+
+     def get_parameters_for_template(self, key: str) -> Dict[str, Any]:
+         """
+         Get recommended parameters for a template.
+
+         Parameters
+         ----------
+         key : str
+             Template key
+
+         Returns
+         -------
+         dict
+             Dictionary of parameter names and values
+         """
+         template = self.get_template(key)
+         if not template:
+             return {}
+
+         return {
+             "controlnet_conditioning_scale": template.controlnet_conditioning_scale,
+             "feather_radius": template.feather_radius,
+             "guidance_scale": template.guidance_scale,
+             "num_inference_steps": template.num_inference_steps,
+             "preferred_conditioning": template.preferred_conditioning
+         }
+
+     def build_prompt(self, key: str, content: str) -> str:
+         """
+         Build complete prompt from template and user content.
+
+         Parameters
+         ----------
+         key : str
+             Template key
+         content : str
+             User-provided content description
+
+         Returns
+         -------
+         str
+             Formatted prompt with content inserted
+         """
+         template = self.get_template(key)
+         if not template:
+             return content
+
+         return template.prompt_template.format(content=content)
+
+     def get_negative_prompt(self, key: str) -> str:
+         """
+         Get negative prompt for a template.
+
+         Parameters
+         ----------
+         key : str
+             Template key
+
+         Returns
+         -------
+         str
+             Negative prompt string
+         """
+         template = self.get_template(key)
+         if not template:
+             return ""
+         return template.negative_prompt
+
+     def get_usage_tips(self, key: str) -> List[str]:
+         """
+         Get usage tips for a template.
+
+         Parameters
+         ----------
+         key : str
+             Template key
+
+         Returns
+         -------
+         list
+             List of tip strings
+         """
+         template = self.get_template(key)
+         if not template:
+             return []
+         return template.usage_tips
+
+     def build_gallery_html(self) -> str:
+         """
+         Build HTML for template gallery display.
+
+         Returns
+         -------
+         str
+             HTML string for Gradio display
+         """
+         html_parts = ['<div class="inpainting-gallery">']
+
+         for category in self.CATEGORIES:
+             templates = self.get_templates_by_category(category)
+             if not templates:
+                 continue
+
+             html_parts.append(f'''
+             <div class="inpainting-category">
+                 <h4 class="inpainting-category-title">{category}</h4>
+                 <div class="inpainting-grid">
+             ''')
+
+             for template in sorted(templates, key=lambda t: t.name):
+                 html_parts.append(f'''
+                 <div class="inpainting-card" data-template="{template.key}">
+                     <span class="inpainting-icon">{template.icon}</span>
+                     <span class="inpainting-name">{template.name}</span>
+                     <span class="inpainting-desc">{template.description[:50]}...</span>
+                 </div>
+                 ''')
+
+             html_parts.append('</div></div>')
+
+         html_parts.append('</div>')
+         return ''.join(html_parts)
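Since inpainting_templates.py depends only on the standard library, the manager can be exercised directly. A short usage sketch using only the methods defined above:

    from inpainting_templates import InpaintingTemplateManager

    manager = InpaintingTemplateManager()

    # Resolve a Gradio dropdown display string back to a template key
    key = manager.get_template_key_from_display("🔄 Object Replacement")  # -> "object_replacement"

    # Build the prompt/negative-prompt pair from user content
    prompt = manager.build_prompt(key, "large green potted fern with detailed leaves")
    negative = manager.get_negative_prompt(key)

    # Recommended generation parameters for this template
    params = manager.get_parameters_for_template(key)
    print(params["controlnet_conditioning_scale"], params["preferred_conditioning"])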
mask_generator.py CHANGED
@@ -1,5 +1,6 @@
  import cv2
  import numpy as np
+ import traceback
  from PIL import Image, ImageFilter, ImageDraw
  import logging
  from typing import Optional, Tuple
@@ -222,7 +223,6 @@ class MaskGenerator:
 
          except Exception as e:
              logger.error(f"❌ BiRefNet mask generation failed: {e}")
-             import traceback
              logger.error(f"📍 Traceback: {traceback.format_exc()}")
              return None
 
@@ -380,7 +380,6 @@ class MaskGenerator:
              return is_cartoon
 
          except Exception as e:
-             import traceback
              logger.error(f"❌ Cartoon detection failed: {e}")
              logger.error(f"📍 Traceback: {traceback.format_exc()}")
              print(f"❌ CARTOON DETECTION ERROR: {e}")
@@ -437,7 +436,6 @@ class MaskGenerator:
              return enhanced_alpha
 
          except Exception as e:
-             import traceback
              logger.error(f"❌ Cartoon mask enhancement failed: {e}")
              logger.error(f"📍 Traceback: {traceback.format_exc()}")
              print(f"❌ CARTOON MASK ENHANCEMENT ERROR: {e}")
model_manager.py CHANGED
@@ -1,7 +1,8 @@
1
  import logging
2
  import gc
3
  import time
4
- from typing import Dict, Any, Optional, Callable
 
5
  from dataclasses import dataclass, field
6
  from threading import Lock
7
  import torch
@@ -10,13 +11,41 @@ logger = logging.getLogger(__name__)
10
  logger.setLevel(logging.INFO)
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @dataclass
14
  class ModelInfo:
15
- """Information about a registered model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  name: str
17
  loader: Callable[[], Any]
18
- is_critical: bool = False # Critical models are not unloaded under memory pressure
 
19
  estimated_memory_gb: float = 0.0
 
20
  is_loaded: bool = False
21
  last_used: float = 0.0
22
  model_instance: Any = None
@@ -25,12 +54,34 @@ class ModelInfo:
25
  class ModelManager:
26
  """
27
  Singleton model manager for unified model lifecycle management.
28
- Handles lazy loading, caching, and intelligent memory management.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
 
31
  _instance = None
32
  _lock = Lock()
33
 
 
 
 
34
  def __new__(cls):
35
  if cls._instance is None:
36
  with cls._lock:
@@ -45,9 +96,11 @@ class ModelManager:
45
 
46
  self._models: Dict[str, ModelInfo] = {}
47
  self._memory_threshold = 0.80 # Trigger cleanup at 80% GPU memory usage
 
48
  self._device = self._detect_device()
 
49
 
50
- logger.info(f"🔧 ModelManager initialized on {self._device}")
51
  self._initialized = True
52
 
53
  def _detect_device(self) -> str:
@@ -63,44 +116,73 @@ class ModelManager:
63
  name: str,
64
  loader: Callable[[], Any],
65
  is_critical: bool = False,
66
- estimated_memory_gb: float = 0.0
 
 
67
  ):
68
  """
69
  Register a model for managed loading.
70
 
71
- Args:
72
- name: Unique model identifier
73
- loader: Callable that returns the loaded model
74
- is_critical: If True, model won't be unloaded under memory pressure
75
- estimated_memory_gb: Estimated GPU memory usage in GB
 
 
 
 
 
 
 
 
 
76
  """
77
  if name in self._models:
78
- logger.warning(f"⚠️ Model '{name}' already registered, updating")
 
 
 
 
79
 
80
  self._models[name] = ModelInfo(
81
  name=name,
82
  loader=loader,
83
  is_critical=is_critical,
 
84
  estimated_memory_gb=estimated_memory_gb,
 
85
  is_loaded=False,
86
  last_used=0.0,
87
  model_instance=None
88
  )
89
- logger.info(f"📝 Registered model: {name} (critical={is_critical}, ~{estimated_memory_gb:.1f}GB)")
90
 
91
- def load_model(self, name: str) -> Any:
92
  """
93
  Load a model by name. Returns cached instance if already loaded.
94
 
95
- Args:
96
- name: Model identifier
97
 
98
- Returns:
 
 
 
 
 
 
 
 
 
99
  Loaded model instance
100
 
101
- Raises:
102
- KeyError: If model not registered
103
- RuntimeError: If loading fails
 
 
 
104
  """
105
  if name not in self._models:
106
  raise KeyError(f"Model '{name}' not registered")
@@ -110,15 +192,21 @@ class ModelManager:
110
  # Return cached instance
111
  if model_info.is_loaded and model_info.model_instance is not None:
112
  model_info.last_used = time.time()
113
- logger.debug(f"📦 Using cached model: {name}")
 
 
114
  return model_info.model_instance
115
 
 
 
 
 
116
  # Check memory pressure before loading
117
  self.check_memory_pressure()
118
 
119
  # Load the model
120
  try:
121
- logger.info(f"📥 Loading model: {name}")
122
  start_time = time.time()
123
 
124
  model_instance = model_info.loader()
@@ -127,32 +215,64 @@ class ModelManager:
127
  model_info.is_loaded = True
128
  model_info.last_used = time.time()
129
 
 
 
 
 
 
 
 
130
  load_time = time.time() - start_time
131
- logger.info(f"Model '{name}' loaded in {load_time:.1f}s")
132
 
133
  return model_instance
134
 
135
  except Exception as e:
136
- logger.error(f"Failed to load model '{name}': {e}")
137
  raise RuntimeError(f"Model loading failed: {e}")
138
 
139
- def unload_model(self, name: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  """
141
  Unload a specific model to free memory.
142
 
143
- Args:
144
- name: Model identifier
 
 
 
 
 
 
 
145
  """
146
  if name not in self._models:
147
- return
148
 
149
  model_info = self._models[name]
150
 
151
  if not model_info.is_loaded:
152
- return
153
 
154
  try:
155
- logger.info(f"🗑️ Unloading model: {name}")
156
 
157
  # Delete model instance
158
  if model_info.model_instance is not None:
@@ -161,21 +281,33 @@ class ModelManager:
161
  model_info.model_instance = None
162
  model_info.is_loaded = False
163
 
 
 
 
 
164
  # Cleanup
165
  gc.collect()
166
  if torch.cuda.is_available():
167
  torch.cuda.empty_cache()
 
168
 
169
- logger.info(f"Model '{name}' unloaded")
 
170
 
171
  except Exception as e:
172
- logger.error(f"Error unloading model '{name}': {e}")
 
173
 
174
  def check_memory_pressure(self) -> bool:
175
  """
176
- Check GPU memory usage and unload least-used non-critical models if needed.
177
 
178
- Returns:
 
 
 
 
 
179
  True if cleanup was performed
180
  """
181
  if not torch.cuda.is_available():
@@ -188,18 +320,19 @@ class ModelManager:
188
  if usage_ratio < self._memory_threshold:
189
  return False
190
 
191
- logger.warning(f"⚠️ Memory pressure detected: {usage_ratio:.1%} used")
192
 
193
- # Find non-critical models sorted by last used time
194
- unloadable = [
 
195
  (name, info) for name, info in self._models.items()
196
- if info.is_loaded and not info.is_critical
197
  ]
198
- unloadable.sort(key=lambda x: x[1].last_used)
199
 
200
- # Unload oldest non-critical models
201
  cleaned = False
202
- for name, info in unloadable:
203
  self.unload_model(name)
204
  cleaned = True
205
 
@@ -210,13 +343,21 @@ class ModelManager:
210
 
211
  return cleaned
212
 
213
- def force_cleanup(self):
214
- """Force cleanup all non-critical models and clear caches."""
215
- logger.info("🧹 Force cleanup initiated")
 
 
 
 
 
 
 
216
 
217
- # Unload all non-critical models
218
- for name, info in self._models.items():
219
- if info.is_loaded and not info.is_critical:
 
220
  self.unload_model(name)
221
 
222
  # Aggressive garbage collection
@@ -228,7 +369,81 @@ class ModelManager:
228
  torch.cuda.ipc_collect()
229
  torch.cuda.synchronize()
230
 
231
- logger.info("Force cleanup completed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  def get_memory_status(self) -> Dict[str, Any]:
234
  """
 
1
  import logging
2
  import gc
3
  import time
4
+ from enum import IntEnum
5
+ from typing import Dict, Any, Optional, Callable, List
6
  from dataclasses import dataclass, field
7
  from threading import Lock
8
  import torch
 
11
  logger.setLevel(logging.INFO)
12
 
13
 
14
+ class ModelPriority(IntEnum):
15
+ """
16
+ Model priority levels for memory management.
17
+
18
+ Higher priority models are kept loaded longer under memory pressure.
19
+ """
20
+ CRITICAL = 100 # Never unload (e.g., OpenCLIP for analysis)
21
+ HIGH = 80 # Currently active pipeline
22
+ MEDIUM = 50 # Recently used models
23
+ LOW = 20 # Inactive pipelines, can be evicted
24
+ DISPOSABLE = 0 # Temporary models, evict first
25
+
26
+
27
  @dataclass
28
  class ModelInfo:
29
+ """
30
+ Information about a registered model.
31
+
32
+ Attributes:
33
+ name: Unique model identifier
34
+ loader: Callable that returns the loaded model
35
+ is_critical: If True, model won't be unloaded under memory pressure
36
+ priority: ModelPriority level for eviction decisions
37
+ estimated_memory_gb: Estimated GPU memory usage
38
+ model_group: Group name for mutual exclusion (e.g., "pipeline")
39
+ is_loaded: Whether model is currently loaded
40
+ last_used: Timestamp of last use
41
+ model_instance: The actual model object
42
+ """
43
  name: str
44
  loader: Callable[[], Any]
45
+ is_critical: bool = False
46
+ priority: int = ModelPriority.MEDIUM
47
  estimated_memory_gb: float = 0.0
48
+ model_group: str = "" # For mutual exclusion (e.g., "pipeline")
49
  is_loaded: bool = False
50
  last_used: float = 0.0
51
  model_instance: Any = None
 
54
  class ModelManager:
55
  """
56
  Singleton model manager for unified model lifecycle management.
57
+
58
+ Handles lazy loading, caching, priority-based eviction, and mutual
59
+ exclusion for pipeline models. Designed for memory-constrained
60
+ environments like Google Colab and HuggingFace Spaces.
61
+
62
+ Features:
63
+ - Priority-based model eviction under memory pressure
64
+ - Mutual exclusion for pipeline models (only one active at a time)
65
+ - Automatic memory monitoring and cleanup
66
+ - Support for model groups and dependencies
67
+
68
+ Example:
69
+ >>> manager = get_model_manager()
70
+ >>> manager.register_model(
71
+ ... name="sdxl_pipeline",
72
+ ... loader=load_sdxl,
73
+ ... priority=ModelPriority.HIGH,
74
+ ... model_group="pipeline"
75
+ ... )
76
+ >>> pipeline = manager.load_model("sdxl_pipeline")
77
  """
78
 
79
  _instance = None
80
  _lock = Lock()
81
 
82
+ # Known model groups for mutual exclusion
83
+ PIPELINE_GROUP = "pipeline" # Only one pipeline can be loaded at a time
84
+
85
  def __new__(cls):
86
  if cls._instance is None:
87
  with cls._lock:
 
96
 
97
  self._models: Dict[str, ModelInfo] = {}
98
  self._memory_threshold = 0.80 # Trigger cleanup at 80% GPU memory usage
99
+ self._high_memory_threshold = 0.90 # Critical threshold for aggressive cleanup
100
  self._device = self._detect_device()
101
+ self._active_pipeline: Optional[str] = None # Track currently active pipeline
102
 
103
+ logger.info(f"ModelManager initialized on {self._device}")
104
  self._initialized = True
105
 
106
  def _detect_device(self) -> str:
 
116
  name: str,
117
  loader: Callable[[], Any],
118
  is_critical: bool = False,
119
+ priority: int = ModelPriority.MEDIUM,
120
+ estimated_memory_gb: float = 0.0,
121
+ model_group: str = ""
122
  ):
123
  """
124
  Register a model for managed loading.
125
 
126
+ Parameters
127
+ ----------
128
+ name : str
129
+ Unique model identifier
130
+ loader : callable
131
+ Function that returns the loaded model
132
+ is_critical : bool
133
+ If True, model won't be unloaded under memory pressure
134
+ priority : int
135
+ ModelPriority level for eviction decisions
136
+ estimated_memory_gb : float
137
+ Estimated GPU memory usage in GB
138
+ model_group : str
139
+ Group name for mutual exclusion (e.g., "pipeline")
140
  """
141
  if name in self._models:
142
+ logger.warning(f"Model '{name}' already registered, updating")
143
+
144
+ # Critical models always have highest priority
145
+ if is_critical:
146
+ priority = ModelPriority.CRITICAL
147
 
148
  self._models[name] = ModelInfo(
149
  name=name,
150
  loader=loader,
151
  is_critical=is_critical,
152
+ priority=priority,
153
  estimated_memory_gb=estimated_memory_gb,
154
+ model_group=model_group,
155
  is_loaded=False,
156
  last_used=0.0,
157
  model_instance=None
158
  )
159
+ logger.info(f"Registered model: {name} (priority={priority}, group={model_group}, ~{estimated_memory_gb:.1f}GB)")
160
 
161
+ def load_model(self, name: str, update_priority: Optional[int] = None) -> Any:
162
  """
163
  Load a model by name. Returns cached instance if already loaded.
164
 
165
+ Implements mutual exclusion for pipeline models - loading a new
166
+ pipeline will unload any existing pipeline first.
167
 
168
+ Parameters
169
+ ----------
170
+ name : str
171
+ Model identifier
172
+ update_priority : int, optional
173
+ If provided, update the model's priority after loading
174
+
175
+ Returns
176
+ -------
177
+ Any
178
  Loaded model instance
179
 
180
+ Raises
181
+ ------
182
+ KeyError
183
+ If model not registered
184
+ RuntimeError
185
+ If loading fails
186
  """
187
  if name not in self._models:
188
  raise KeyError(f"Model '{name}' not registered")
 
192
  # Return cached instance
193
  if model_info.is_loaded and model_info.model_instance is not None:
194
  model_info.last_used = time.time()
195
+ if update_priority is not None:
196
+ model_info.priority = update_priority
197
+ logger.debug(f"Using cached model: {name}")
198
  return model_info.model_instance
199
 
200
+ # Handle mutual exclusion for pipeline group
201
+ if model_info.model_group == self.PIPELINE_GROUP:
202
+ self._ensure_pipeline_exclusion(name)
203
+
204
  # Check memory pressure before loading
205
  self.check_memory_pressure()
206
 
207
  # Load the model
208
  try:
209
+ logger.info(f"Loading model: {name}")
210
  start_time = time.time()
211
 
212
  model_instance = model_info.loader()
 
215
  model_info.is_loaded = True
216
  model_info.last_used = time.time()
217
 
218
+ if update_priority is not None:
219
+ model_info.priority = update_priority
220
+
221
+ # Track active pipeline
222
+ if model_info.model_group == self.PIPELINE_GROUP:
223
+ self._active_pipeline = name
224
+
225
  load_time = time.time() - start_time
226
+ logger.info(f"Model '{name}' loaded in {load_time:.1f}s")
227
 
228
  return model_instance
229
 
230
  except Exception as e:
231
+ logger.error(f"Failed to load model '{name}': {e}")
232
  raise RuntimeError(f"Model loading failed: {e}")
233
 
234
+ def _ensure_pipeline_exclusion(self, new_pipeline: str) -> None:
235
+ """
236
+ Ensure only one pipeline is loaded at a time.
237
+
238
+ Unloads any existing pipeline before loading a new one.
239
+
240
+ Parameters
241
+ ----------
242
+ new_pipeline : str
243
+ Name of the pipeline about to be loaded
244
+ """
245
+ for name, info in self._models.items():
246
+ if (info.model_group == self.PIPELINE_GROUP and
247
+ info.is_loaded and
248
+ name != new_pipeline):
249
+ logger.info(f"Unloading {name} to make room for {new_pipeline}")
250
+ self.unload_model(name)
251
+
252
+ def unload_model(self, name: str) -> bool:
253
  """
254
  Unload a specific model to free memory.
255
 
256
+ Parameters
257
+ ----------
258
+ name : str
259
+ Model identifier
260
+
261
+ Returns
262
+ -------
263
+ bool
264
+ True if model was unloaded successfully
265
  """
266
  if name not in self._models:
267
+ return False
268
 
269
  model_info = self._models[name]
270
 
271
  if not model_info.is_loaded:
272
+ return True
273
 
274
  try:
275
+ logger.info(f"Unloading model: {name}")
276
 
277
  # Delete model instance
278
  if model_info.model_instance is not None:
 
281
  model_info.model_instance = None
282
  model_info.is_loaded = False
283
 
284
+ # Update active pipeline tracking
285
+ if self._active_pipeline == name:
286
+ self._active_pipeline = None
287
+
288
  # Cleanup
289
  gc.collect()
290
  if torch.cuda.is_available():
291
  torch.cuda.empty_cache()
292
+ torch.cuda.ipc_collect()
293
 
294
+ logger.info(f"Model '{name}' unloaded")
295
+ return True
296
 
297
  except Exception as e:
298
+ logger.error(f"Error unloading model '{name}': {e}")
299
+ return False
300
 
301
  def check_memory_pressure(self) -> bool:
302
  """
303
+ Check GPU memory usage and unload low-priority models if needed.
304
 
305
+ Uses priority-based eviction: lower priority models are unloaded first,
306
+ then falls back to least-recently-used within same priority tier.
307
+
308
+ Returns
309
+ -------
310
+ bool
311
  True if cleanup was performed
312
  """
313
  if not torch.cuda.is_available():
 
320
  if usage_ratio < self._memory_threshold:
321
  return False
322
 
323
+ logger.warning(f"Memory pressure detected: {usage_ratio:.1%} used")
324
 
325
+ # Find evictable models (not critical, loaded)
326
+ # Sort by priority (ascending) then by last_used (ascending)
327
+ evictable = [
328
  (name, info) for name, info in self._models.items()
329
+ if info.is_loaded and info.priority < ModelPriority.CRITICAL
330
  ]
331
+ evictable.sort(key=lambda x: (x[1].priority, x[1].last_used))
332
 
333
+ # Unload models starting from lowest priority
334
  cleaned = False
335
+ for name, info in evictable:
336
  self.unload_model(name)
337
  cleaned = True
338
 
 
343
 
344
  return cleaned
345
 
346
+ def force_cleanup(self, keep_critical_only: bool = True):
347
+ """
348
+ Force cleanup models and clear caches.
349
+
350
+ Parameters
351
+ ----------
352
+ keep_critical_only : bool
353
+ If True, only keep CRITICAL priority models loaded
354
+ """
355
+ logger.info("Force cleanup initiated")
356
 
357
+ # Unload models based on priority
358
+ threshold = ModelPriority.CRITICAL if keep_critical_only else ModelPriority.HIGH
359
+ for name, info in list(self._models.items()):
360
+ if info.is_loaded and info.priority < threshold:
361
  self.unload_model(name)
362
 
363
  # Aggressive garbage collection
 
369
  torch.cuda.ipc_collect()
370
  torch.cuda.synchronize()
371
 
372
+ logger.info("Force cleanup completed")
373
+
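Note the direction of the flag: keep_critical_only=True compares against the CRITICAL threshold and evicts everything below it, while False lowers the threshold to HIGH and is therefore the gentler cleanup. Usage is a one-liner:

    manager.force_cleanup()                          # evict all models below CRITICAL
    manager.force_cleanup(keep_critical_only=False)  # evict only models below HIGH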
374
+ def update_priority(self, name: str, priority: int) -> bool:
375
+ """
376
+ Update a model's priority level.
377
+
378
+ Parameters
379
+ ----------
380
+ name : str
381
+ Model identifier
382
+ priority : int
383
+ New priority level
384
+
385
+ Returns
386
+ -------
387
+ bool
388
+ True if priority was updated
389
+ """
390
+ if name not in self._models:
391
+ return False
392
+
393
+ self._models[name].priority = priority
394
+ logger.debug(f"Updated priority for {name} to {priority}")
395
+ return True
396
+
397
+ def get_active_pipeline(self) -> Optional[str]:
398
+ """
399
+ Get the name of the currently active pipeline.
400
+
401
+ Returns
402
+ -------
403
+ str or None
404
+ Name of active pipeline, or None if no pipeline is loaded
405
+ """
406
+ return self._active_pipeline
407
+
408
+ def switch_to_pipeline(
409
+ self,
410
+ name: str,
411
+ loader: Optional[Callable[[], Any]] = None
412
+ ) -> Any:
413
+ """
414
+ Switch to a different pipeline, unloading the current one.
415
+
416
+ This is a convenience method for pipeline switching that handles
417
+ mutual exclusion automatically.
418
+
419
+ Parameters
420
+ ----------
421
+ name : str
422
+ Pipeline name to switch to
423
+ loader : callable, optional
424
+ Loader function if pipeline not already registered
425
+
426
+ Returns
427
+ -------
428
+ Any
429
+ The loaded pipeline instance
430
+
431
+ Raises
432
+ ------
433
+ KeyError
434
+ If pipeline not registered and no loader provided
435
+ """
436
+ # Register if needed
437
+ if name not in self._models and loader is not None:
438
+ self.register_model(
439
+ name=name,
440
+ loader=loader,
441
+ priority=ModelPriority.HIGH,
442
+ model_group=self.PIPELINE_GROUP
443
+ )
444
+
445
+ # Load will handle unloading of current pipeline
446
+ return self.load_model(name, update_priority=ModelPriority.HIGH)
447
 
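Taken together, these methods give callers a one-line pipeline swap. A hedged usage sketch (load_inpaint_pipeline is a hypothetical loader function, not part of this commit):

    manager = get_model_manager()
    pipe = manager.switch_to_pipeline("inpainting_pipeline", loader=load_inpaint_pipeline)
    # The previously loaded pipeline in the group has been evicted;
    # get_active_pipeline() should now report "inpainting_pipeline".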
448
  def get_memory_status(self) -> Dict[str, Any]:
449
  """
quality_checker.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
- from typing import Dict, Any, Tuple, Optional
6
  from dataclasses import dataclass
7
 
8
  logger = logging.getLogger(__name__)
@@ -407,3 +407,452 @@ class QualityChecker:
407
  summary += f"\nNotes: {'; '.join(results['warnings'])}"
408
 
409
  return summary
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
+ from typing import Dict, Any, Tuple, Optional, List
6
  from dataclasses import dataclass
7
 
8
  logger = logging.getLogger(__name__)
 
407
  summary += f"\nNotes: {'; '.join(results['warnings'])}"
408
 
409
  return summary
410
+
411
+ # =========================================================================
412
+ # INPAINTING-SPECIFIC QUALITY CHECKS
413
+ # =========================================================================
414
+
415
+ def check_inpainting_edge_continuity(
416
+ self,
417
+ original: Image.Image,
418
+ inpainted: Image.Image,
419
+ mask: Image.Image,
420
+ ring_width: int = 5
421
+ ) -> QualityResult:
422
+ """
423
+ Check edge continuity at the inpainting boundary.
424
+
425
+ Calculates color distribution similarity between the ring zones
426
+ on each side of the mask boundary in Lab color space.
427
+
428
+ Parameters
429
+ ----------
430
+ original : PIL.Image
431
+ Original image before inpainting
432
+ inpainted : PIL.Image
433
+ Result after inpainting
434
+ mask : PIL.Image
435
+ Inpainting mask (white = inpainted area)
436
+ ring_width : int
437
+ Width in pixels for the ring zones on each side
438
+
439
+ Returns
440
+ -------
441
+ QualityResult
442
+ Edge continuity assessment
443
+ """
444
+ try:
445
+ # Convert to arrays
446
+ orig_array = np.array(original.convert('RGB'))
447
+ inpaint_array = np.array(inpainted.convert('RGB'))
448
+ mask_array = np.array(mask.convert('L'))
449
+
450
+ # Find boundary using morphological gradient
451
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
452
+ dilated = cv2.dilate(mask_array, kernel, iterations=ring_width)
453
+ eroded = cv2.erode(mask_array, kernel, iterations=ring_width)
454
+
455
+ # Inner ring (inside inpainted region, near boundary)
456
+ inner_ring = (mask_array > 127) & (eroded <= 127)
457
+
458
+ # Outer ring (outside inpainted region, near boundary)
459
+ outer_ring = (mask_array <= 127) & (dilated > 127)
460
+
461
+ if not np.any(inner_ring) or not np.any(outer_ring):
462
+ return QualityResult(
463
+ score=50,
464
+ passed=True,
465
+ issue="Unable to detect boundary rings",
466
+ details={"ring_width": ring_width}
467
+ )
468
+
469
+ # Convert to Lab for perceptual comparison
470
+ inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)
471
+
472
+ # Get Lab values for each ring from the inpainted image
473
+ inner_lab = inpaint_lab[inner_ring]
474
+ outer_lab = inpaint_lab[outer_ring]
475
+
476
+ # Calculate statistics for each channel
477
+ inner_mean = np.mean(inner_lab, axis=0)
478
+ outer_mean = np.mean(outer_lab, axis=0)
479
+ inner_std = np.std(inner_lab, axis=0)
480
+ outer_std = np.std(outer_lab, axis=0)
481
+
482
+ # Calculate differences
483
+ mean_diff = np.abs(inner_mean - outer_mean)
484
+ std_diff = np.abs(inner_std - outer_std)
485
+
486
+ # Calculate Delta E (simplified)
487
+ delta_e = np.sqrt(np.sum(mean_diff ** 2))
488
+
489
+ # Score calculation
490
+ # Low Delta E = good continuity
491
+ # Target: Delta E < 10 is excellent, < 20 is good
492
+ if delta_e < 5:
493
+ continuity_score = 100
494
+ elif delta_e < 10:
495
+ continuity_score = 90
496
+ elif delta_e < 20:
497
+ continuity_score = 75
498
+ elif delta_e < 30:
499
+ continuity_score = 60
500
+ elif delta_e < 50:
501
+ continuity_score = 40
502
+ else:
503
+ continuity_score = max(20, 100 - delta_e)
504
+
505
+ # Penalize for large std differences (inconsistent textures)
506
+ std_penalty = min(20, np.mean(std_diff) * 0.5)
507
+ final_score = max(0, continuity_score - std_penalty)
508
+
509
+ passed = final_score >= 60
510
+ issue = ""
511
+
512
+ if final_score < 60:
513
+ if delta_e > 30:
514
+ issue = f"Visible color discontinuity at boundary (Delta E: {delta_e:.1f})"
515
+ elif np.mean(std_diff) > 20:
516
+ issue = "Texture mismatch at boundary"
517
+ else:
518
+ issue = "Poor edge blending"
519
+
520
+ return QualityResult(
521
+ score=final_score,
522
+ passed=passed,
523
+ issue=issue,
524
+ details={
525
+ "delta_e": delta_e,
526
+ "mean_diff_l": mean_diff[0],
527
+ "mean_diff_a": mean_diff[1],
528
+ "mean_diff_b": mean_diff[2],
529
+ "std_diff_avg": np.mean(std_diff),
530
+ "inner_pixels": np.count_nonzero(inner_ring),
531
+ "outer_pixels": np.count_nonzero(outer_ring)
532
+ }
533
+ )
534
+
535
+ except Exception as e:
536
+ logger.error(f"Inpainting edge continuity check failed: {e}")
537
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
538
+
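For intuition, the simplified Delta E above is just the Euclidean distance between the mean Lab vectors of the two rings. A stand-alone computation with illustrative values:

    import numpy as np

    inner_mean = np.array([62.0, 128.0, 140.0])  # hypothetical ring means (L, a, b)
    outer_mean = np.array([60.0, 127.0, 133.0])
    delta_e = np.sqrt(np.sum((inner_mean - outer_mean) ** 2))
    # ~7.35, which falls in the 5 <= dE < 10 band and scores 90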
539
+ def check_inpainting_color_harmony(
540
+ self,
541
+ original: Image.Image,
542
+ inpainted: Image.Image,
543
+ mask: Image.Image
544
+ ) -> QualityResult:
545
+ """
546
+ Check color harmony between inpainted region and surrounding area.
547
+
548
+ Compares color statistics of the inpainted region with adjacent
549
+ non-inpainted regions to assess visual coherence.
550
+
551
+ Parameters
552
+ ----------
553
+ original : PIL.Image
554
+ Original image
555
+ inpainted : PIL.Image
556
+ Inpainted result
557
+ mask : PIL.Image
558
+ Inpainting mask
559
+
560
+ Returns
561
+ -------
562
+ QualityResult
563
+ Color harmony assessment
564
+ """
565
+ try:
566
+ inpaint_array = np.array(inpainted.convert('RGB'))
567
+ mask_array = np.array(mask.convert('L'))
568
+
569
+ # Define regions
570
+ inpaint_region = mask_array > 127
571
+
572
+ # Get adjacent region (dilated mask minus original mask)
573
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
574
+ dilated = cv2.dilate(mask_array, kernel, iterations=2)
575
+ adjacent_region = (dilated > 127) & (mask_array <= 127)
576
+
577
+ if not np.any(inpaint_region) or not np.any(adjacent_region):
578
+ return QualityResult(
579
+ score=50,
580
+ passed=True,
581
+ issue="Insufficient regions for comparison",
582
+ details={}
583
+ )
584
+
585
+ # Convert to Lab
586
+ inpaint_lab = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2LAB).astype(np.float32)
587
+
588
+ # Extract region colors
589
+ inpaint_colors = inpaint_lab[inpaint_region]
590
+ adjacent_colors = inpaint_lab[adjacent_region]
591
+
592
+ # Calculate color statistics
593
+ inpaint_mean = np.mean(inpaint_colors, axis=0)
594
+ adjacent_mean = np.mean(adjacent_colors, axis=0)
595
+
596
+ inpaint_std = np.std(inpaint_colors, axis=0)
597
+ adjacent_std = np.std(adjacent_colors, axis=0)
598
+
599
+ # Color histogram comparison
600
+ hist_scores = []
601
+ for i in range(3): # L, a, b channels
602
+ hist_inpaint, _ = np.histogram(
603
+ inpaint_colors[:, i], bins=32, range=(0, 255)
604
+ )
605
+ hist_adjacent, _ = np.histogram(
606
+ adjacent_colors[:, i], bins=32, range=(0, 255)
607
+ )
608
+
609
+ # Normalize
610
+ hist_inpaint = hist_inpaint.astype(np.float32) / (np.sum(hist_inpaint) + 1e-6)
611
+ hist_adjacent = hist_adjacent.astype(np.float32) / (np.sum(hist_adjacent) + 1e-6)
612
+
613
+ # Bhattacharyya coefficient (1 = identical, 0 = completely different)
614
+ bc = np.sum(np.sqrt(hist_inpaint * hist_adjacent))
615
+ hist_scores.append(bc)
616
+
617
+ avg_hist_score = np.mean(hist_scores)
618
+
619
+ # Calculate harmony score
620
+ mean_diff = np.linalg.norm(inpaint_mean - adjacent_mean)
621
+
622
+ if mean_diff < 10 and avg_hist_score > 0.8:
623
+ harmony_score = 100
624
+ elif mean_diff < 20 and avg_hist_score > 0.7:
625
+ harmony_score = 85
626
+ elif mean_diff < 30 and avg_hist_score > 0.6:
627
+ harmony_score = 70
628
+ elif mean_diff < 50:
629
+ harmony_score = 55
630
+ else:
631
+ harmony_score = max(30, 100 - mean_diff)
632
+
633
+ # Boost score if histogram similarity is high
634
+ histogram_bonus = (avg_hist_score - 0.5) * 20 # -10 to +10
635
+ final_score = max(0, min(100, harmony_score + histogram_bonus))
636
+
637
+ passed = final_score >= 60
638
+ issue = ""
639
+
640
+ if final_score < 60:
641
+ if mean_diff > 40:
642
+ issue = "Significant color mismatch with surrounding area"
643
+ elif avg_hist_score < 0.5:
644
+ issue = "Color distribution differs from context"
645
+ else:
646
+ issue = "Poor color integration"
647
+
648
+ return QualityResult(
649
+ score=final_score,
650
+ passed=passed,
651
+ issue=issue,
652
+ details={
653
+ "mean_color_diff": mean_diff,
654
+ "histogram_similarity": avg_hist_score,
655
+ "inpaint_luminance": inpaint_mean[0],
656
+ "adjacent_luminance": adjacent_mean[0]
657
+ }
658
+ )
659
+
660
+ except Exception as e:
661
+ logger.error(f"Inpainting color harmony check failed: {e}")
662
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
663
+
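The Bhattacharyya coefficient used for the histogram comparison is simply the dot product of square-rooted, normalized histograms. In isolation:

    import numpy as np

    h1 = np.array([0.5, 0.3, 0.2])  # toy normalized histograms
    h2 = np.array([0.4, 0.4, 0.2])
    bc = float(np.sum(np.sqrt(h1 * h2)))  # 1.0 identical, 0.0 disjoint; here ~0.99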
664
+ def check_inpainting_artifact_detection(
665
+ self,
666
+ inpainted: Image.Image,
667
+ mask: Image.Image
668
+ ) -> QualityResult:
669
+ """
670
+ Detect common inpainting artifacts like blurriness or color bleeding.
671
+
672
+ Parameters
673
+ ----------
674
+ inpainted : PIL.Image
675
+ Inpainted result
676
+ mask : PIL.Image
677
+ Inpainting mask
678
+
679
+ Returns
680
+ -------
681
+ QualityResult
682
+ Artifact detection results
683
+ """
684
+ try:
685
+ inpaint_array = np.array(inpainted.convert('RGB'))
686
+ mask_array = np.array(mask.convert('L'))
687
+
688
+ inpaint_region = mask_array > 127
689
+
690
+ if not np.any(inpaint_region):
691
+ return QualityResult(
692
+ score=50,
693
+ passed=True,
694
+ issue="No inpainted region detected",
695
+ details={}
696
+ )
697
+
698
+ # Extract inpainted region pixels
699
+ gray = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2GRAY)
700
+
701
+ # Calculate sharpness (Laplacian variance)
702
+ laplacian = cv2.Laplacian(gray, cv2.CV_64F)
703
+ inpaint_laplacian = laplacian[inpaint_region]
704
+ sharpness = np.var(inpaint_laplacian)
705
+
706
+ # Get surrounding region for comparison
707
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
708
+ dilated = cv2.dilate(mask_array, kernel, iterations=1)
709
+ surrounding = (dilated > 127) & (mask_array <= 127)
710
+
711
+ if np.any(surrounding):
712
+ surrounding_laplacian = laplacian[surrounding]
713
+ surrounding_sharpness = np.var(surrounding_laplacian)
714
+ sharpness_ratio = sharpness / (surrounding_sharpness + 1e-6)
715
+ else:
716
+ sharpness_ratio = 1.0
717
+
718
+ # Check for color bleeding (abnormal saturation at edges)
719
+ hsv = cv2.cvtColor(inpaint_array, cv2.COLOR_RGB2HSV)
720
+ saturation = hsv[:, :, 1]
721
+
722
+ # Find boundary pixels
723
+ boundary_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
724
+ boundary = cv2.morphologyEx(mask_array, cv2.MORPH_GRADIENT, boundary_kernel) > 0
725
+
726
+ if np.any(boundary):
727
+ boundary_saturation = saturation[boundary]
728
+ saturation_std = np.std(boundary_saturation)
729
+ else:
730
+ saturation_std = 0
731
+
732
+ # Calculate score
733
+ sharpness_score = 100
734
+ if sharpness_ratio < 0.3:
735
+ sharpness_score = 40 # Much blurrier than surroundings
736
+ elif sharpness_ratio < 0.6:
737
+ sharpness_score = 60
738
+ elif sharpness_ratio < 0.8:
739
+ sharpness_score = 80
740
+
741
+ bleeding_penalty = min(20, saturation_std * 0.5)
742
+
743
+ final_score = max(0, sharpness_score - bleeding_penalty)
744
+ passed = final_score >= 60
745
+
746
+ issue = ""
747
+ if sharpness_ratio < 0.5:
748
+ issue = "Inpainted region appears blurry"
749
+ elif saturation_std > 40:
750
+ issue = "Possible color bleeding at edges"
751
+ elif final_score < 60:
752
+ issue = "Detected visual artifacts"
753
+
754
+ return QualityResult(
755
+ score=final_score,
756
+ passed=passed,
757
+ issue=issue,
758
+ details={
759
+ "sharpness": sharpness,
760
+ "sharpness_ratio": sharpness_ratio,
761
+ "boundary_saturation_std": saturation_std
762
+ }
763
+ )
764
+
765
+ except Exception as e:
766
+ logger.error(f"Inpainting artifact detection failed: {e}")
767
+ return QualityResult(score=0, passed=False, issue=str(e), details={})
768
+
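The blur test above compares Laplacian variance inside and outside the mask; the measurement on its own looks like this (array contents are illustrative):

    import cv2
    import numpy as np

    gray = np.random.randint(0, 256, (64, 64), dtype=np.uint8)
    lap = cv2.Laplacian(gray, cv2.CV_64F)
    region = np.zeros((64, 64), dtype=bool)
    region[16:48, 16:48] = True
    ratio = np.var(lap[region]) / (np.var(lap[~region]) + 1e-6)
    # a ratio below ~0.5 flags the region as blurrier than its surroundings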
769
+ def run_inpainting_checks(
770
+ self,
771
+ original: Image.Image,
772
+ inpainted: Image.Image,
773
+ mask: Image.Image
774
+ ) -> Dict[str, Any]:
775
+ """
776
+ Run all inpainting-specific quality checks.
777
+
778
+ Parameters
779
+ ----------
780
+ original : PIL.Image
781
+ Original image before inpainting
782
+ inpainted : PIL.Image
783
+ Result after inpainting
784
+ mask : PIL.Image
785
+ Inpainting mask
786
+
787
+ Returns
788
+ -------
789
+ dict
790
+ Comprehensive quality assessment for inpainting
791
+ """
792
+ logger.info("Running inpainting quality checks...")
793
+
794
+ results = {
795
+ "checks": {},
796
+ "overall_score": 0,
797
+ "passed": True,
798
+ "warnings": [],
799
+ "errors": []
800
+ }
801
+
802
+ # Run inpainting-specific checks
803
+ edge_result = self.check_inpainting_edge_continuity(original, inpainted, mask)
804
+ results["checks"]["edge_continuity"] = {
805
+ "score": edge_result.score,
806
+ "passed": edge_result.passed,
807
+ "issue": edge_result.issue,
808
+ "details": edge_result.details
809
+ }
810
+
811
+ harmony_result = self.check_inpainting_color_harmony(original, inpainted, mask)
812
+ results["checks"]["color_harmony"] = {
813
+ "score": harmony_result.score,
814
+ "passed": harmony_result.passed,
815
+ "issue": harmony_result.issue,
816
+ "details": harmony_result.details
817
+ }
818
+
819
+ artifact_result = self.check_inpainting_artifact_detection(inpainted, mask)
820
+ results["checks"]["artifact_detection"] = {
821
+ "score": artifact_result.score,
822
+ "passed": artifact_result.passed,
823
+ "issue": artifact_result.issue,
824
+ "details": artifact_result.details
825
+ }
826
+
827
+ # Calculate overall score (weighted)
828
+ weights = {
829
+ "edge_continuity": 0.4,
830
+ "color_harmony": 0.35,
831
+ "artifact_detection": 0.25
832
+ }
833
+
834
+ total_score = (
835
+ edge_result.score * weights["edge_continuity"] +
836
+ harmony_result.score * weights["color_harmony"] +
837
+ artifact_result.score * weights["artifact_detection"]
838
+ )
839
+ results["overall_score"] = round(total_score, 1)
840
+
841
+ # Determine overall pass/fail
842
+ results["passed"] = all([
843
+ edge_result.passed,
844
+ harmony_result.passed,
845
+ artifact_result.passed
846
+ ])
847
+
848
+ # Collect issues
849
+ for check_name, check_data in results["checks"].items():
850
+ if check_data["issue"]:
851
+ if check_data["passed"]:
852
+ results["warnings"].append(f"{check_name}: {check_data['issue']}")
853
+ else:
854
+ results["errors"].append(f"{check_name}: {check_data['issue']}")
855
+
856
+ logger.info(f"Inpainting quality: {results['overall_score']:.1f}, Passed: {results['passed']}")
857
+
858
+ return results
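A minimal driver for the full inpainting check suite might look like this (file names are placeholders):

    from PIL import Image
    from quality_checker import QualityChecker

    checker = QualityChecker()
    report = checker.run_inpainting_checks(
        original=Image.open("original.png"),
        inpainted=Image.open("inpainted.png"),
        mask=Image.open("mask.png"),
    )
    print(report["overall_score"], report["passed"])
    for issue in report["errors"] + report["warnings"]:
        print("-", issue)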
scene_templates.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
4
 
5
  logger = logging.getLogger(__name__)
6
 
7
-
8
  @dataclass
9
  class SceneTemplate:
10
  """Data class representing a scene template"""
@@ -25,7 +24,7 @@ class SceneTemplateManager:
25
 
26
  # Scene template definitions
27
  TEMPLATES: Dict[str, SceneTemplate] = {
28
- # === Professional Category ===
29
  "office_modern": SceneTemplate(
30
  key="office_modern",
31
  name="Modern Office",
@@ -72,7 +71,7 @@ class SceneTemplateManager:
72
  guidance_scale=7.5
73
  ),
74
 
75
- # === Nature Category ===
76
  "beach_sunset": SceneTemplate(
77
  key="beach_sunset",
78
  name="Sunset Beach",
@@ -128,7 +127,7 @@ class SceneTemplateManager:
128
  guidance_scale=7.0
129
  ),
130
 
131
- # === Urban Category ===
132
  "city_skyline": SceneTemplate(
133
  key="city_skyline",
134
  name="City Skyline",
@@ -175,7 +174,7 @@ class SceneTemplateManager:
175
  guidance_scale=7.5
176
  ),
177
 
178
- # === Artistic Category ===
179
  "gradient_soft": SceneTemplate(
180
  key="gradient_soft",
181
  name="Soft Gradient",
@@ -213,7 +212,7 @@ class SceneTemplateManager:
213
  guidance_scale=6.5
214
  ),
215
 
216
- # === Seasonal Category ===
217
  "autumn_foliage": SceneTemplate(
218
  key="autumn_foliage",
219
  name="Autumn Foliage",
 
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
7
  @dataclass
8
  class SceneTemplate:
9
  """Data class representing a scene template"""
 
24
 
25
  # Scene template definitions
26
  TEMPLATES: Dict[str, SceneTemplate] = {
27
+ # Professional Category
28
  "office_modern": SceneTemplate(
29
  key="office_modern",
30
  name="Modern Office",
 
71
  guidance_scale=7.5
72
  ),
73
 
74
+ # Nature Category
75
  "beach_sunset": SceneTemplate(
76
  key="beach_sunset",
77
  name="Sunset Beach",
 
127
  guidance_scale=7.0
128
  ),
129
 
130
+ # Urban Category
131
  "city_skyline": SceneTemplate(
132
  key="city_skyline",
133
  name="City Skyline",
 
174
  guidance_scale=7.5
175
  ),
176
 
177
+ # Artistic Category
178
  "gradient_soft": SceneTemplate(
179
  key="gradient_soft",
180
  name="Soft Gradient",
 
212
  guidance_scale=6.5
213
  ),
214
 
215
+ # Seasonal Category
216
  "autumn_foliage": SceneTemplate(
217
  key="autumn_foliage",
218
  name="Autumn Foliage",
scene_weaver_core.py CHANGED
@@ -5,25 +5,47 @@ from PIL import Image
5
  import logging
6
  import gc
7
  import time
8
- from typing import Optional, Dict, Any, Tuple, List
9
  from pathlib import Path
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
14
  import open_clip
 
15
  from mask_generator import MaskGenerator
16
  from image_blender import ImageBlender
17
  from quality_checker import QualityChecker
 
 
 
18
 
19
  logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.INFO)
21
 
22
  class SceneWeaverCore:
23
  """
24
- SceneWeaver with perfect background generation + fixed blending + memory optimization
25
  """
26
 
27
  # Style presets for diversity generation mode
28
  STYLE_PRESETS = {
29
  "professional": {
@@ -82,7 +104,17 @@ class SceneWeaverCore:
82
  self.image_blender = ImageBlender()
83
  self.quality_checker = QualityChecker()
84
 
85
- logger.info(f"OptimizedSceneWeaver initialized on {self.device}")
 
 
 
 
 
 
 
 
 
 
86
 
87
  def _setup_device(self, device: str) -> str:
88
  """Setup computation device"""
@@ -277,7 +309,7 @@ class SceneWeaverCore:
277
  # Analyze image characteristics
278
  img_array = np.array(foreground_image.convert('RGB'))
279
 
280
- # === Analyze color temperature ===
281
  # Convert to LAB to analyze color temperature
282
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
283
  avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
@@ -286,12 +318,12 @@ class SceneWeaverCore:
286
  # Determine warm/cool tone
287
  is_warm = avg_b > 128 # b > 128 means more yellow/warm
288
 
289
- # === Analyze brightness ===
290
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
291
  avg_brightness = np.mean(gray)
292
  is_bright = avg_brightness > 127
293
 
294
- # === Get subject type from CLIP ===
295
  clip_analysis = self.analyze_image_with_clip(foreground_image)
296
  subject_type = "unknown"
297
 
@@ -306,7 +338,7 @@ class SceneWeaverCore:
306
  elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
307
  subject_type = "nature"
308
 
309
- # === Build prompt fragments library ===
310
  lighting_options = {
311
  "warm_bright": "warm golden hour lighting, soft natural light",
312
  "warm_dark": "warm ambient lighting, cozy atmosphere",
@@ -325,7 +357,7 @@ class SceneWeaverCore:
325
 
326
  quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
327
 
328
- # === Select appropriate fragments ===
329
  # Lighting based on color temperature and brightness
330
  if is_warm and is_bright:
331
  lighting = lighting_options["warm_bright"]
@@ -339,7 +371,7 @@ class SceneWeaverCore:
339
  # Atmosphere based on subject type
340
  atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
341
 
342
- # === Check for conflicts in user prompt ===
343
  user_prompt_lower = user_prompt.lower()
344
 
345
  # Avoid adding conflicting descriptions
@@ -348,7 +380,7 @@ class SceneWeaverCore:
348
  if "dark" in user_prompt_lower or "night" in user_prompt_lower:
349
  lighting = lighting.replace("bright", "").replace("daylight", "")
350
 
351
- # === Combine enhanced prompt ===
352
  fragments = [user_prompt]
353
 
354
  if lighting:
@@ -592,7 +624,6 @@ class SceneWeaverCore:
592
  }
593
 
594
  except Exception as e:
595
- import traceback
596
  error_traceback = traceback.format_exc()
597
  logger.error(f"❌ Generation and combination failed: {str(e)}")
598
  logger.error(f"📍 Full traceback:\n{error_traceback}")
@@ -806,3 +837,341 @@ class SceneWeaverCore:
806
  })
807
 
808
  return status
5
  import logging
6
  import gc
7
  import time
8
+ from typing import Optional, Dict, Any, Tuple, List, Callable
9
  from pathlib import Path
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
14
  import open_clip
15
+ import traceback
16
  from mask_generator import MaskGenerator
17
  from image_blender import ImageBlender
18
  from quality_checker import QualityChecker
19
+ from model_manager import get_model_manager, ModelPriority
20
+ from inpainting_module import InpaintingModule
21
+ from inpainting_templates import InpaintingTemplateManager
22
 
23
  logger = logging.getLogger(__name__)
24
  logger.setLevel(logging.INFO)
25
 
26
  class SceneWeaverCore:
27
  """
28
+ SceneWeaver Core Engine - Facade for all AI generation subsystems.
29
+
30
+ Integrates SDXL pipeline, OpenCLIP analysis, mask generation, image blending,
31
+ and inpainting functionality into a unified interface.
32
+
33
+ Attributes:
34
+ device: Computation device (cuda/mps/cpu)
35
+ is_initialized: Whether models are loaded
36
+ inpainting_module: Optional InpaintingModule instance
37
+
38
+ Example:
39
+ >>> core = SceneWeaverCore()
40
+ >>> core.load_models()
41
+ >>> result = core.generate_and_combine(image, prompt="sunset beach")
42
  """
43
 
44
+ # Model registry names
45
+ MODEL_SDXL_PIPELINE = "sdxl_background_pipeline"
46
+ MODEL_OPENCLIP = "openclip_analyzer"
47
+ MODEL_INPAINTING_PIPELINE = "inpainting_pipeline"
48
+
49
  # Style presets for diversity generation mode
50
  STYLE_PRESETS = {
51
  "professional": {
 
104
  self.image_blender = ImageBlender()
105
  self.quality_checker = QualityChecker()
106
 
107
+ # Model manager reference
108
+ self._model_manager = get_model_manager()
109
+
110
+ # Inpainting module (lazy loaded)
111
+ self._inpainting_module = None
112
+ self._inpainting_initialized = False
113
+
114
+ # Current mode tracking
115
+ self._current_mode = "background" # "background" or "inpainting"
116
+
117
+ logger.info(f"SceneWeaverCore initialized on {self.device}")
118
 
119
  def _setup_device(self, device: str) -> str:
120
  """Setup computation device"""
 
309
  # Analyze image characteristics
310
  img_array = np.array(foreground_image.convert('RGB'))
311
 
312
+ # Analyze color temperature
313
  # Convert to LAB to analyze color temperature
314
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
315
  avg_a = np.mean(lab[:, :, 1]) # a channel: green(-) to red(+)
 
318
  # Determine warm/cool tone
319
  is_warm = avg_b > 128 # b > 128 means more yellow/warm
320
 
321
+ # Analyze brightness
322
  gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
323
  avg_brightness = np.mean(gray)
324
  is_bright = avg_brightness > 127
325
 
326
+ # Get subject type from CLIP
327
  clip_analysis = self.analyze_image_with_clip(foreground_image)
328
  subject_type = "unknown"
329
 
 
338
  elif "nature" in clip_analysis.lower() or "landscape" in clip_analysis.lower():
339
  subject_type = "nature"
340
 
341
+ # Build prompt fragments library
342
  lighting_options = {
343
  "warm_bright": "warm golden hour lighting, soft natural light",
344
  "warm_dark": "warm ambient lighting, cozy atmosphere",
 
357
 
358
  quality_modifiers = "high quality, detailed, sharp focus, photorealistic"
359
 
360
+ # Select appropriate fragments
361
  # Lighting based on color temperature and brightness
362
  if is_warm and is_bright:
363
  lighting = lighting_options["warm_bright"]
 
371
  # Atmosphere based on subject type
372
  atmosphere = atmosphere_options.get(subject_type, atmosphere_options["unknown"])
373
 
374
+ # Check for conflicts in user prompt
375
  user_prompt_lower = user_prompt.lower()
376
 
377
  # Avoid adding conflicting descriptions
 
380
  if "dark" in user_prompt_lower or "night" in user_prompt_lower:
381
  lighting = lighting.replace("bright", "").replace("daylight", "")
382
 
383
+ # Combine enhanced prompt
384
  fragments = [user_prompt]
385
 
386
  if lighting:
 
624
  }
625
 
626
  except Exception as e:
 
627
  error_traceback = traceback.format_exc()
628
  logger.error(f"❌ Generation and combination failed: {str(e)}")
629
  logger.error(f"📍 Full traceback:\n{error_traceback}")
 
837
  })
838
 
839
  return status
840
+
841
+ # INPAINTING FACADE METHODS
842
+ def get_inpainting_module(self):
843
+ """
844
+ Get or create the InpaintingModule instance.
845
+
846
+ Implements lazy loading - module is only created when first accessed.
847
+
848
+ Returns
849
+ -------
850
+ InpaintingModule
851
+ The inpainting module instance
852
+ """
853
+ if self._inpainting_module is None:
854
+ self._inpainting_module = InpaintingModule(device=self.device)
855
+ self._inpainting_module.set_model_manager(self._model_manager)
856
+ logger.info("InpaintingModule created (lazy load)")
857
+
858
+ return self._inpainting_module
859
+
860
+ def switch_to_inpainting_mode(
861
+ self,
862
+ conditioning_type: str = "canny",
863
+ progress_callback: Optional[Callable[[str, int], None]] = None
864
+ ) -> bool:
865
+ """
866
+ Switch to inpainting mode, unloading the background pipeline.
867
+
868
+ Implements mutual exclusion between pipelines to conserve memory.
869
+
870
+ Parameters
871
+ ----------
872
+ conditioning_type : str
873
+ ControlNet conditioning type: "canny" or "depth"
874
+ progress_callback : callable, optional
875
+ Progress update function(message, percentage)
876
+
877
+ Returns
878
+ -------
879
+ bool
880
+ True if switch was successful
881
+ """
882
+ logger.info(f"Switching to inpainting mode (conditioning: {conditioning_type})")
883
+
884
+ try:
885
+ # Unload background pipeline first
886
+ if self.pipeline is not None:
887
+ if progress_callback:
888
+ progress_callback("Unloading background pipeline...", 10)
889
+
890
+ del self.pipeline
891
+ self.pipeline = None
892
+ self._ultra_memory_cleanup()
893
+ logger.info("Background pipeline unloaded")
894
+
895
+ # Load inpainting pipeline
896
+ if progress_callback:
897
+ progress_callback("Loading inpainting pipeline...", 20)
898
+
899
+ inpaint_module = self.get_inpainting_module()
900
+
901
+ def inpaint_progress(msg, pct):
902
+ if progress_callback:
903
+ # Map inpainting progress (0-100) to (20-90)
904
+ mapped_pct = 20 + int(pct * 0.7)
905
+ progress_callback(msg, mapped_pct)
906
+
907
+ success, error_msg = inpaint_module.load_inpainting_pipeline(
908
+ conditioning_type=conditioning_type,
909
+ progress_callback=inpaint_progress
910
+ )
911
+
912
+ if success:
913
+ self._current_mode = "inpainting"
914
+ self._inpainting_initialized = True
915
+
916
+ if progress_callback:
917
+ progress_callback("Inpainting mode ready!", 100)
918
+
919
+ logger.info("Successfully switched to inpainting mode")
920
+ else:
921
+ self._last_inpainting_error = error_msg
922
+ logger.error(f"Failed to load inpainting pipeline: {error_msg}")
923
+
924
+ return success
925
+
926
+ except Exception as e:
927
+ traceback.print_exc()
928
+ self._last_inpainting_error = str(e)
929
+ logger.error(f"Failed to switch to inpainting mode: {e}")
930
+ if progress_callback:
931
+ progress_callback(f"Error: {str(e)}", 0)
932
+ return False
933
+
934
+ def switch_to_background_mode(
935
+ self,
936
+ progress_callback: Optional[Callable[[str, int], None]] = None
937
+ ) -> bool:
938
+ """
939
+ Switch back to background generation mode.
940
+
941
+ Parameters
942
+ ----------
943
+ progress_callback : callable, optional
944
+ Progress update function
945
+
946
+ Returns
947
+ -------
948
+ bool
949
+ True if switch was successful
950
+ """
951
+ logger.info("Switching to background generation mode")
952
+
953
+ try:
954
+ # Unload inpainting pipeline
955
+ if self._inpainting_module is not None and self._inpainting_module.is_initialized:
956
+ if progress_callback:
957
+ progress_callback("Unloading inpainting pipeline...", 10)
958
+
959
+ self._inpainting_module._unload_pipeline()
960
+ self._ultra_memory_cleanup()
961
+
962
+ # Reload background pipeline
963
+ if progress_callback:
964
+ progress_callback("Loading background pipeline...", 30)
965
+
966
+ # Reset initialization flag to force reload
967
+ self.is_initialized = False
968
+ self.load_models(progress_callback=progress_callback)
969
+
970
+ self._current_mode = "background"
971
+
972
+ if progress_callback:
973
+ progress_callback("Background mode ready!", 100)
974
+
975
+ return True
976
+
977
+ except Exception as e:
978
+ logger.error(f"Failed to switch to background mode: {e}")
979
+ return False
980
+
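From calling code, switching modes is a single call plus an optional progress callback; a hedged sketch:

    core = SceneWeaverCore()

    def on_progress(message: str, pct: int) -> None:
        print(f"[{pct:3d}%] {message}")

    if core.switch_to_inpainting_mode("canny", progress_callback=on_progress):
        print(core.get_current_mode())  # "inpainting"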
981
+ def execute_inpainting(
982
+ self,
983
+ image: Image.Image,
984
+ mask: Image.Image,
985
+ prompt: str,
986
+ preview_only: bool = False,
987
+ template_key: Optional[str] = None,
988
+ progress_callback: Optional[Callable[[str, int], None]] = None,
989
+ **kwargs
990
+ ) -> Dict[str, Any]:
991
+ """
992
+ Execute inpainting operation through the Facade.
993
+
994
+ This is the main entry point for inpainting functionality.
995
+
996
+ Parameters
997
+ ----------
998
+ image : PIL.Image
999
+ Original image to inpaint
1000
+ mask : PIL.Image
1001
+ Inpainting mask (white = area to regenerate)
1002
+ prompt : str
1003
+ Text description of desired content
1004
+ preview_only : bool
1005
+ If True, generate quick preview only
1006
+ template_key : str, optional
1007
+ Inpainting template key to use
1008
+ progress_callback : callable, optional
1009
+ Progress update function
1010
+ **kwargs
1011
+ Additional inpainting parameters
1012
+
1013
+ Returns
1014
+ -------
1015
+ dict
1016
+ Result dictionary with images and metadata
1017
+ """
1018
+ # Ensure inpainting mode is active
1019
+ if self._current_mode != "inpainting" or not self._inpainting_initialized:
1020
+ conditioning = kwargs.get('conditioning_type', 'canny')
1021
+ if not self.switch_to_inpainting_mode(conditioning, progress_callback):
1022
+ error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
1023
+ return {
1024
+ "success": False,
1025
+ "error": f"Failed to initialize inpainting mode: {error_detail}"
1026
+ }
1027
+
1028
+ inpaint_module = self.get_inpainting_module()
1029
+
1030
+ # Apply template if specified
1031
+ if template_key:
1032
+ template_mgr = InpaintingTemplateManager()
1033
+ template = template_mgr.get_template(template_key)
1034
+
1035
+ if template:
1036
+ # Build prompt from template
1037
+ prompt = template_mgr.build_prompt(template_key, prompt)
1038
+ # Apply template parameters as defaults
1039
+ params = template_mgr.get_parameters_for_template(template_key)
1040
+ for key, value in params.items():
1041
+ if key not in kwargs:
1042
+ kwargs[key] = value
1043
+
1044
+ # Execute inpainting
1045
+ result = inpaint_module.execute_inpainting(
1046
+ image=image,
1047
+ mask=mask,
1048
+ prompt=prompt,
1049
+ preview_only=preview_only,
1050
+ progress_callback=progress_callback,
1051
+ **kwargs
1052
+ )
1053
+
1054
+ # Convert InpaintingResult to dictionary format
1055
+ return {
1056
+ "success": result.success,
1057
+ "combined_image": result.blended_image or result.result_image,
1058
+ "generated_image": result.result_image,
1059
+ "preview_image": result.preview_image,
1060
+ "control_image": result.control_image,
1061
+ "original_image": image,
1062
+ "mask": mask,
1063
+ "quality_score": result.quality_score,
1064
+ "generation_time": result.generation_time,
1065
+ "metadata": result.metadata,
1066
+ "error": result.error_message if not result.success else None
1067
+ }
1068
+
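End to end, the Facade call then looks like this (file paths are placeholders; template keys, if used, come from InpaintingTemplateManager):

    from PIL import Image

    result = core.execute_inpainting(
        image=Image.open("photo.png"),
        mask=Image.open("mask.png"),  # white = area to regenerate
        prompt="a wooden bookshelf",
        conditioning_type="canny",
    )
    if result["success"]:
        result["combined_image"].save("inpainted.png")
    else:
        print("inpainting failed:", result["error"])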
1069
+ def execute_inpainting_with_optimization(
1070
+ self,
1071
+ image: Image.Image,
1072
+ mask: Image.Image,
1073
+ prompt: str,
1074
+ progress_callback: Optional[Callable[[str, int], None]] = None,
1075
+ **kwargs
1076
+ ) -> Dict[str, Any]:
1077
+ """
1078
+ Execute inpainting with automatic quality optimization.
1079
+
1080
+ Retries with adjusted parameters if quality is below threshold.
1081
+
1082
+ Parameters
1083
+ ----------
1084
+ image : PIL.Image
1085
+ Original image
1086
+ mask : PIL.Image
1087
+ Inpainting mask
1088
+ prompt : str
1089
+ Text prompt
1090
+ progress_callback : callable, optional
1091
+ Progress callback
1092
+ **kwargs
1093
+ Additional parameters
1094
+
1095
+ Returns
1096
+ -------
1097
+ dict
1098
+ Optimized result dictionary
1099
+ """
1100
+ # Ensure inpainting mode
1101
+ if self._current_mode != "inpainting" or not self._inpainting_initialized:
1102
+ conditioning = kwargs.get('conditioning_type', 'canny')
1103
+ if not self.switch_to_inpainting_mode(conditioning, progress_callback):
1104
+ error_detail = getattr(self, '_last_inpainting_error', 'Unknown error')
1105
+ return {
1106
+ "success": False,
1107
+ "error": f"Failed to initialize inpainting mode: {error_detail}"
1108
+ }
1109
+
1110
+ inpaint_module = self.get_inpainting_module()
1111
+
1112
+ result = inpaint_module.execute_with_auto_optimization(
1113
+ image=image,
1114
+ mask=mask,
1115
+ prompt=prompt,
1116
+ quality_checker=self.quality_checker,
1117
+ progress_callback=progress_callback,
1118
+ **kwargs
1119
+ )
1120
+
1121
+ return {
1122
+ "success": result.success,
1123
+ "combined_image": result.blended_image or result.result_image,
1124
+ "generated_image": result.result_image,
1125
+ "preview_image": result.preview_image,
1126
+ "control_image": result.control_image,
1127
+ "quality_score": result.quality_score,
1128
+ "quality_details": result.quality_details,
1129
+ "retries": result.retries,
1130
+ "generation_time": result.generation_time,
1131
+ "metadata": result.metadata,
1132
+ "error": result.error_message if not result.success else None
1133
+ }
1134
+
1135
+ def get_current_mode(self) -> str:
1136
+ """
1137
+ Get current operation mode.
1138
+
1139
+ Returns
1140
+ -------
1141
+ str
1142
+ "background" or "inpainting"
1143
+ """
1144
+ return self._current_mode
1145
+
1146
+ def is_inpainting_ready(self) -> bool:
1147
+ """
1148
+ Check if inpainting is ready to use.
1149
+
1150
+ Returns
1151
+ -------
1152
+ bool
1153
+ True if inpainting module is loaded and ready
1154
+ """
1155
+ return (
1156
+ self._inpainting_module is not None and
1157
+ self._inpainting_module.is_initialized
1158
+ )
1159
+
1160
+ def get_inpainting_status(self) -> Dict[str, Any]:
1161
+ """
1162
+ Get inpainting module status.
1163
+
1164
+ Returns
1165
+ -------
1166
+ dict
1167
+ Status information
1168
+ """
1169
+ if self._inpainting_module is None:
1170
+ return {
1171
+ "initialized": False,
1172
+ "mode": self._current_mode
1173
+ }
1174
+
1175
+ status = self._inpainting_module.get_status()
1176
+ status["mode"] = self._current_mode
1177
+ return status
ui_manager.py CHANGED
@@ -1,7 +1,8 @@
1
  import logging
2
  import time
 
3
  from pathlib import Path
4
- from typing import Optional, Tuple
5
  from PIL import Image
6
  import numpy as np
7
  import cv2
@@ -11,6 +12,7 @@ import spaces
11
  from scene_weaver_core import SceneWeaverCore
12
  from css_styles import CSSStyles
13
  from scene_templates import SceneTemplateManager
 
14
 
15
  logger = logging.getLogger(__name__)
16
  logger.setLevel(logging.INFO)
@@ -23,13 +25,26 @@ logging.basicConfig(
23
 
24
 
25
  class UIManager:
26
- """Gradio UI with enhanced memory management and professional design"""
 
 
 
 
 
 
 
 
 
 
27
 
28
  def __init__(self):
29
  self.sceneweaver = SceneWeaverCore()
30
  self.template_manager = SceneTemplateManager()
 
31
  self.generation_history = []
 
32
  self._preview_sensitivity = 0.5
 
33
 
34
  def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
35
  """
@@ -234,7 +249,6 @@ class UIManager:
234
  return None, None, None, f"Error: {error_msg}", gr.update(visible=False)
235
 
236
  except Exception as e:
237
- import traceback
238
  error_traceback = traceback.format_exc()
239
  logger.error(f"Generation handler error: {str(e)}")
240
  logger.error(f"Traceback:\n{error_traceback}")
@@ -249,18 +263,12 @@ class UIManager:
249
  self._gradio_version = gr.__version__
250
  self._gradio_major = int(self._gradio_version.split('.')[0])
251
 
252
- # Gradio 5.x: css/theme moved to launch(), Blocks() has minimal params
253
- # Gradio 4.x: css/theme are in Blocks() constructor
254
- if self._gradio_major >= 5:
255
- blocks_kwargs = {
256
- "title": "SceneWeaver - AI Background Generator"
257
- }
258
- else:
259
- blocks_kwargs = {
260
- "css": self._css,
261
- "title": "SceneWeaver - AI Background Generator",
262
- "theme": gr.themes.Soft()
263
- }
264
 
265
  with gr.Blocks(**blocks_kwargs) as interface:
266
 
@@ -271,177 +279,232 @@ class UIManager:
271
  <span class="title-emoji">🎨</span>
272
  SceneWeaver
273
  </h1>
274
- <p class="main-subtitle">AI-powered background generation with professional edge processing</p>
275
  </div>
276
  """)
277
 
278
- with gr.Row():
279
- # Left Column - Input controls
280
- with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
281
- gr.HTML("""
282
- <div class="card-content">
283
- <h3 class="card-title">
284
- <span class="section-emoji">📸</span>
285
- Upload & Generate
286
- </h3>
287
- </div>
288
- """)
289
 
290
- uploaded_image = gr.Image(
291
- label="Upload Your Image",
292
- type="pil",
293
- height=280,
294
- elem_classes=["input-field"]
295
- )
296
 
297
- # Scene Template Selector (without Accordion to fix dropdown positioning in Gradio 5.x)
298
- template_dropdown = gr.Dropdown(
299
- label="Scene Templates",
300
- choices=[""] + self.template_manager.get_template_choices_sorted(),
301
- value="",
302
- info="24 curated scenes sorted A-Z (optional)",
303
- elem_classes=["template-dropdown"]
304
- )
305
 
306
- prompt_input = gr.Textbox(
307
- label="Background Scene Description",
308
- placeholder="Select a template above or describe your own scene...",
309
- lines=3,
310
- elem_classes=["input-field"]
311
- )
 
 
312
 
313
- combination_mode = gr.Dropdown(
314
- label="Composition Mode",
315
- choices=["center", "left_half", "right_half", "full"],
316
- value="center",
317
- info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
318
- elem_classes=["input-field"]
319
- )
320
 
321
- focus_mode = gr.Dropdown(
322
- label="Focus Mode",
323
- choices=["person", "scene"],
324
- value="person",
325
- info="person=Tight Crop | scene=Include Surrounding Objects",
326
- elem_classes=["input-field"]
327
- )
328
 
329
- with gr.Accordion("Advanced Options", open=False):
330
- negative_prompt = gr.Textbox(
331
- label="Negative Prompt",
332
- value="blurry, low quality, distorted, people, characters",
333
- lines=2,
334
- elem_classes=["input-field"]
335
- )
336
 
337
- steps_slider = gr.Slider(
338
- label="Quality Steps",
339
- minimum=15,
340
- maximum=50,
341
- value=25,
342
- step=5,
343
- elem_classes=["input-field"]
344
- )
345
 
346
- guidance_slider = gr.Slider(
347
- label="Guidance Scale",
348
- minimum=5.0,
349
- maximum=15.0,
350
- value=7.5,
351
- step=0.5,
352
- elem_classes=["input-field"]
353
- )
354
 
355
- generate_btn = gr.Button(
356
- "Generate Background",
357
- variant="primary",
358
- size="lg",
359
- elem_classes=["primary-button"]
360
  )
361
 
362
- # Right Column - Results display
363
- with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
364
- gr.HTML("""
365
- <div class="card-content">
366
- <h3 class="card-title">
367
- <span class="section-emoji">🎭</span>
368
- Results Gallery
369
- </h3>
370
- </div>
371
- """)
372
-
373
- # Loading notice
374
- gr.HTML("""
375
- <div class="loading-notice">
376
- <span class="loading-notice-icon">⏱️</span>
377
- <span class="loading-notice-text">
378
- <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
379
- Subsequent generations are much faster (~30s).
380
- </span>
381
- </div>
382
- """)
383
-
384
- # Quick start guide
385
- gr.HTML("""
386
- <details class="user-guidance-panel">
387
- <summary class="guidance-summary">
388
- <span class="emoji-enhanced">💡</span>
389
- Quick Start Guide
390
- </summary>
391
- <div class="guidance-content">
392
- <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
393
- <p><strong>Step 2:</strong> Describe or Choose your desired background scene</p>
394
- <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
395
- <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
396
- <p><strong>Tip:</strong> For dark clothing, ensure good lighting in original photo.</p>
397
- </div>
398
- </details>
399
- """)
400
 
401
- with gr.Tabs():
402
- with gr.TabItem("Final Result"):
403
- combined_output = gr.Image(
404
- label="Your Generated Image",
405
- elem_classes=["result-gallery"],
406
- show_label=False
407
- )
408
- with gr.TabItem("Background"):
409
- generated_output = gr.Image(
410
- label="Generated Background",
411
- elem_classes=["result-gallery"],
412
- show_label=False
413
- )
414
- with gr.TabItem("Original"):
415
- original_output = gr.Image(
416
- label="Processed Original",
417
- elem_classes=["result-gallery"],
418
- show_label=False
419
- )
420
 
421
- status_output = gr.Textbox(
422
- label="Status",
423
- value="Ready to create! Upload an image and describe your vision.",
424
- interactive=False,
425
- elem_classes=["status-panel", "status-ready"]
426
  )
427
 
428
- with gr.Row():
429
- download_btn = gr.DownloadButton(
430
- "Download Result",
431
- value=None,
432
- visible=False,
433
- elem_classes=["secondary-button"]
434
- )
435
- clear_btn = gr.Button(
436
- "Clear All",
437
- elem_classes=["secondary-button"]
438
- )
439
- memory_btn = gr.Button(
440
- "Clean Memory",
441
- elem_classes=["secondary-button"]
442
- )
443
 
444
- # Footer with tech credits
445
  gr.HTML("""
446
  <div class="app-footer">
447
  <div class="footer-powered">
@@ -464,57 +527,13 @@ class UIManager:
464
  </div>
465
  """)
466
 
467
- # Event handlers
468
- # Template selection handler
469
- template_dropdown.change(
470
- fn=self.apply_template,
471
- inputs=[template_dropdown, negative_prompt],
472
- outputs=[prompt_input, negative_prompt, guidance_slider]
473
- )
474
-
475
- generate_btn.click(
476
- fn=self.generate_handler,
477
- inputs=[
478
- uploaded_image,
479
- prompt_input,
480
- combination_mode,
481
- focus_mode,
482
- negative_prompt,
483
- steps_slider,
484
- guidance_slider
485
- ],
486
- outputs=[
487
- combined_output,
488
- generated_output,
489
- original_output,
490
- status_output,
491
- download_btn
492
- ]
493
- )
494
-
495
- clear_btn.click(
496
- fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
497
- outputs=[combined_output, generated_output, original_output, status_output, download_btn]
498
- )
499
-
500
- memory_btn.click(
501
- fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
502
- outputs=[status_output]
503
- )
504
-
505
- combined_output.change(
506
- fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
507
- inputs=[combined_output],
508
- outputs=[download_btn]
509
- )
510
-
511
  return interface
512
 
513
  def launch(self, share: bool = True, debug: bool = False):
514
  """Launch the UI interface"""
515
  interface = self.create_interface()
516
 
517
- # Build launch kwargs based on Gradio version
518
  launch_kwargs = {
519
  "share": share,
520
  "debug": debug,
@@ -522,10 +541,450 @@ class UIManager:
522
  "quiet": False
523
  }
524
 
525
- # Gradio 5.x: css/theme are passed to launch()
526
- # Gradio 4.x: these were already set in Blocks()
527
- if self._gradio_major >= 5:
528
- launch_kwargs["css"] = self._css
529
- launch_kwargs["theme"] = gr.themes.Soft()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
- return interface.launch(**launch_kwargs)
1
  import logging
2
  import time
3
+ import traceback
4
  from pathlib import Path
5
+ from typing import Optional, Tuple, Dict, Any, List
6
  from PIL import Image
7
  import numpy as np
8
  import cv2
 
12
  from scene_weaver_core import SceneWeaverCore
13
  from css_styles import CSSStyles
14
  from scene_templates import SceneTemplateManager
15
+ from inpainting_templates import InpaintingTemplateManager
16
 
17
  logger = logging.getLogger(__name__)
18
  logger.setLevel(logging.INFO)
 
25
 
26
 
27
  class UIManager:
28
+ """
29
+ Gradio UI Manager with support for background generation and inpainting.
30
+
31
+ Provides a professional interface with mode switching, template selection,
32
+ and advanced parameter controls.
33
+
34
+ Attributes:
35
+ sceneweaver: SceneWeaverCore instance
36
+ template_manager: Scene template manager
37
+ inpainting_template_manager: Inpainting template manager
38
+ """
39
 
40
  def __init__(self):
41
  self.sceneweaver = SceneWeaverCore()
42
  self.template_manager = SceneTemplateManager()
43
+ self.inpainting_template_manager = InpaintingTemplateManager()
44
  self.generation_history = []
45
+ self.inpainting_history = []
46
  self._preview_sensitivity = 0.5
47
+ self._current_mode = "background" # "background" or "inpainting"
48
 
49
  def apply_template(self, display_name: str, current_negative: str) -> Tuple[str, str, float]:
50
  """
 
249
  return None, None, None, f"Error: {error_msg}", gr.update(visible=False)
250
 
251
  except Exception as e:
 
252
  error_traceback = traceback.format_exc()
253
  logger.error(f"Generation handler error: {str(e)}")
254
  logger.error(f"Traceback:\n{error_traceback}")
 
263
  self._gradio_version = gr.__version__
264
  self._gradio_major = int(self._gradio_version.split('.')[0])
265
 
266
+ # Both Gradio 4.x and 5.x use css/theme in Blocks() constructor
267
+ blocks_kwargs = {
268
+ "css": self._css,
269
+ "title": "SceneWeaver - AI Background Generator",
270
+ "theme": gr.themes.Soft()
271
+ }
 
 
 
 
 
 
272
 
273
  with gr.Blocks(**blocks_kwargs) as interface:
274
 
 
279
  <span class="title-emoji">🎨</span>
280
  SceneWeaver
281
  </h1>
282
+ <p class="main-subtitle">AI-powered background generation and inpainting with professional edge processing</p>
283
  </div>
284
  """)
285
 
286
+ # Main Tabs for Mode Selection
287
+ with gr.Tabs(elem_id="main-mode-tabs") as main_tabs:
288
 
289
+ # Background Generation Tab
290
+ with gr.Tab("Background Generation", elem_id="bg-gen-tab"):
291
 
292
+ with gr.Row():
293
+ # Left Column - Input controls
294
+ with gr.Column(scale=1, min_width=350, elem_classes=["feature-card"]):
295
+ gr.HTML("""
296
+ <div class="card-content">
297
+ <h3 class="card-title">
298
+ <span class="section-emoji">📸</span>
299
+ Upload & Generate
300
+ </h3>
301
+ </div>
302
+ """)
303
+
304
+ uploaded_image = gr.Image(
305
+ label="Upload Your Image",
306
+ type="pil",
307
+ height=280,
308
+ elem_classes=["input-field"]
309
+ )
310
 
311
+ # Scene Template Selector (without Accordion to fix dropdown positioning in Gradio 5.x)
312
+ template_dropdown = gr.Dropdown(
313
+ label="Scene Templates",
314
+ choices=[""] + self.template_manager.get_template_choices_sorted(),
315
+ value="",
316
+ info="24 curated scenes sorted A-Z (optional)",
317
+ elem_classes=["template-dropdown"]
318
+ )
319
 
320
+ prompt_input = gr.Textbox(
321
+ label="Background Scene Description",
322
+ placeholder="Select a template above or describe your own scene...",
323
+ lines=3,
324
+ elem_classes=["input-field"]
325
+ )
 
326
 
327
+ combination_mode = gr.Dropdown(
328
+ label="Composition Mode",
329
+ choices=["center", "left_half", "right_half", "full"],
330
+ value="center",
331
+ info="center=Smart Center | left_half=Left Half | right_half=Right Half | full=Full Image",
332
+ elem_classes=["input-field"]
333
+ )
334
 
335
+ focus_mode = gr.Dropdown(
336
+ label="Focus Mode",
337
+ choices=["person", "scene"],
338
+ value="person",
339
+ info="person=Tight Crop | scene=Include Surrounding Objects",
340
+ elem_classes=["input-field"]
341
+ )
342
 
343
+ with gr.Accordion("Advanced Options", open=False):
344
+ negative_prompt = gr.Textbox(
345
+ label="Negative Prompt",
346
+ value="blurry, low quality, distorted, people, characters",
347
+ lines=2,
348
+ elem_classes=["input-field"]
349
+ )
350
+
351
+ steps_slider = gr.Slider(
352
+ label="Quality Steps",
353
+ minimum=15,
354
+ maximum=50,
355
+ value=25,
356
+ step=5,
357
+ elem_classes=["input-field"]
358
+ )
359
+
360
+ guidance_slider = gr.Slider(
361
+ label="Guidance Scale",
362
+ minimum=5.0,
363
+ maximum=15.0,
364
+ value=7.5,
365
+ step=0.5,
366
+ elem_classes=["input-field"]
367
+ )
368
+
369
+ generate_btn = gr.Button(
370
+ "Generate Background",
371
+ variant="primary",
372
+ size="lg",
373
+ elem_classes=["primary-button"]
374
+ )
375
 
376
+ # Right Column - Results display
377
+ with gr.Column(scale=2, elem_classes=["feature-card"], elem_id="results-gallery-centered"):
378
+ gr.HTML("""
379
+ <div class="card-content">
380
+ <h3 class="card-title">
381
+ <span class="section-emoji">🎭</span>
382
+ Results Gallery
383
+ </h3>
384
+ </div>
385
+ """)
386
+
387
+ # Loading notice
388
+ gr.HTML("""
389
+ <div class="loading-notice">
390
+ <span class="loading-notice-icon">⏱️</span>
391
+ <span class="loading-notice-text">
392
+ <strong>First-time users:</strong> Initial model loading takes 1-2 minutes.
393
+ Subsequent generations are much faster (~30s).
394
+ </span>
395
+ </div>
396
+ """)
397
+
398
+ # Quick start guide
399
+ gr.HTML("""
400
+ <details class="user-guidance-panel">
401
+ <summary class="guidance-summary">
402
+ <span class="emoji-enhanced">💡</span>
403
+ Quick Start Guide
404
+ </summary>
405
+ <div class="guidance-content">
406
+ <p><strong>Step 1:</strong> Upload any image with a clear subject</p>
407
+ <p><strong>Step 2:</strong> Describe or choose your desired background scene</p>
408
+ <p><strong>Step 3:</strong> Choose composition mode (center works best)</p>
409
+ <p><strong>Step 4:</strong> Click Generate and wait for the magic!</p>
410
+ <p><strong>Tip:</strong> For dark clothing, ensure good lighting in original photo.</p>
411
+ </div>
412
+ </details>
413
+ """)
414
+
415
+ with gr.Tabs():
416
+ with gr.TabItem("Final Result"):
417
+ combined_output = gr.Image(
418
+ label="Your Generated Image",
419
+ elem_classes=["result-gallery"],
420
+ show_label=False
421
+ )
422
+ with gr.TabItem("Background"):
423
+ generated_output = gr.Image(
424
+ label="Generated Background",
425
+ elem_classes=["result-gallery"],
426
+ show_label=False
427
+ )
428
+ with gr.TabItem("Original"):
429
+ original_output = gr.Image(
430
+ label="Processed Original",
431
+ elem_classes=["result-gallery"],
432
+ show_label=False
433
+ )
434
+
435
+ status_output = gr.Textbox(
436
+ label="Status",
437
+ value="Ready to create! Upload an image and describe your vision.",
438
+ interactive=False,
439
+ elem_classes=["status-panel", "status-ready"]
440
+ )
441
 
442
+                            with gr.Row():
+                                download_btn = gr.DownloadButton(
+                                    "Download Result",
+                                    value=None,
+                                    visible=False,
+                                    elem_classes=["secondary-button"]
+                                )
+                                clear_btn = gr.Button(
+                                    "Clear All",
+                                    elem_classes=["secondary-button"]
+                                )
+                                memory_btn = gr.Button(
+                                    "Clean Memory",
+                                    elem_classes=["secondary-button"]
+                                )
+
+                    # Event handlers for Background Generation Tab
+                    # Template selection handler
+                    template_dropdown.change(
+                        fn=self.apply_template,
+                        inputs=[template_dropdown, negative_prompt],
+                        outputs=[prompt_input, negative_prompt, guidance_slider]
                     )

+                    generate_btn.click(
+                        fn=self.generate_handler,
+                        inputs=[
+                            uploaded_image,
+                            prompt_input,
+                            combination_mode,
+                            focus_mode,
+                            negative_prompt,
+                            steps_slider,
+                            guidance_slider
+                        ],
+                        outputs=[
+                            combined_output,
+                            generated_output,
+                            original_output,
+                            status_output,
+                            download_btn
+                        ]
+                    )
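The click wiring above implies a handler with seven inputs and five outputs. `generate_handler` is defined elsewhere in this class; a sketch of the shape the wiring assumes (parameter names here are illustrative):

```python
# Hypothetical signature matching the inputs/outputs lists above.
from typing import Optional, Tuple
import gradio as gr
from PIL import Image

def generate_handler(
    image: Optional[Image.Image],   # uploaded_image
    prompt: str,                    # prompt_input
    combination_mode: str,
    focus_mode: str,
    negative_prompt: str,
    steps: int,                     # steps_slider
    guidance: float,                # guidance_slider
) -> Tuple[Optional[Image.Image], Optional[Image.Image],
           Optional[Image.Image], str, dict]:
    # One return value per output component: combined image, background,
    # processed original, status text, and a download-button update.
    return None, None, None, "Sketch only - not the real handler", gr.update(visible=False)
```
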
+                    clear_btn.click(
+                        fn=lambda: (None, None, None, "Ready to create!", gr.update(visible=False)),
+                        outputs=[combined_output, generated_output, original_output, status_output, download_btn]
+                    )
+                    memory_btn.click(
+                        fn=lambda: self.sceneweaver._ultra_memory_cleanup() or "Memory cleaned!",
+                        outputs=[status_output]
                     )
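The lambda above leans on a small Python idiom: `_ultra_memory_cleanup()` presumably returns `None`, so the `or` expression falls through to the status string. A minimal illustration:

```python
# If cleanup() returns None (falsy), `or` yields the right-hand string.
def cleanup() -> None:
    pass  # e.g. drop cached pipelines, then empty the CUDA cache

status = cleanup() or "Memory cleaned!"
print(status)  # -> Memory cleaned!
```
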

+                    combined_output.change(
+                        fn=lambda img: gr.update(value="outputs/latest_combined.png", visible=True) if (img is not None) else gr.update(visible=False),
+                        inputs=[combined_output],
+                        outputs=[download_btn]
+                    )
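Note that this handler points the download button at the hardcoded path `outputs/latest_combined.png`, which assumes the generation step always writes there. If that convention ever changes, a variant that saves whatever image triggered the event would be more robust (a sketch, not the commit's code):

```python
# Save the image that fired the change event to a temp file and expose it.
import tempfile
import numpy as np
import gradio as gr
from PIL import Image

def refresh_download(img):
    # gr.Image delivers a numpy array by default; None means the image was cleared.
    if img is None:
        return gr.update(visible=False)
    pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
    pil.save(path)
    return gr.update(value=path, visible=True)
```
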
+
+                # End of Background Generation Tab
+
+                # Inpainting Tab
+                self.create_inpainting_tab()

+            # Footer with tech credits (outside tabs)
             gr.HTML("""
             <div class="app-footer">
                 <div class="footer-powered">

                 </div>
             """)

         return interface

     def launch(self, share: bool = True, debug: bool = False):
         """Launch the UI interface"""
         interface = self.create_interface()

+        # Launch kwargs compatible with both Gradio 4.x and 5.x
         launch_kwargs = {
             "share": share,
             "debug": debug,

             "quiet": False
         }

+        return interface.launch(**launch_kwargs)
+
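The comment above promises compatibility with both Gradio 4.x and 5.x. Should a launch argument ever exist in only one major version, a defensive variant (a sketch, not the commit's code) can filter the kwargs against the installed `launch()` signature:

```python
# Keep only the kwargs the installed Gradio's launch() actually accepts.
import inspect

def safe_launch(interface, **desired):
    accepted = set(inspect.signature(interface.launch).parameters)
    return interface.launch(**{k: v for k, v in desired.items() if k in accepted})

# safe_launch(interface, share=True, debug=False, quiet=False)
```
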
+    # INPAINTING UI METHODS
+    def apply_inpainting_template(
+        self,
+        display_name: str,
+        current_prompt: str
+    ) -> Tuple[str, float, int, str]:
+        """
+        Apply an inpainting template to the UI fields.
+
+        Parameters
+        ----------
+        display_name : str
+            Template display name from dropdown
+        current_prompt : str
+            Current prompt content
+
+        Returns
+        -------
+        tuple
+            (prompt, conditioning_scale, feather_radius, conditioning_type)
+        """
+        if not display_name:
+            return current_prompt, 0.7, 8, "canny"
+
+        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
+        if not template_key:
+            return current_prompt, 0.7, 8, "canny"
+
+        template = self.inpainting_template_manager.get_template(template_key)
+        if template:
+            params = self.inpainting_template_manager.get_parameters_for_template(template_key)
+            return (
+                current_prompt,
+                params.get('controlnet_conditioning_scale', 0.7),
+                params.get('feather_radius', 8),
+                params.get('preferred_conditioning', 'canny')
+            )
+
+        return current_prompt, 0.7, 8, "canny"
+
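`InpaintingTemplateManager` itself is not part of this hunk. To make the `params.get(...)` defaults above concrete, one plausible shape for a template entry (entirely hypothetical; the real keys live in the manager module):

```python
# Hypothetical template entry consistent with the lookups above.
INPAINTING_TEMPLATES = {
    "replace_sky": {
        "display_name": "Replace Sky",
        "parameters": {
            "controlnet_conditioning_scale": 0.65,
            "feather_radius": 12,
            "preferred_conditioning": "depth",
        },
        "usage_tips": [
            "Mask the whole sky region, slightly past the horizon",
            "Describe weather and time of day in the prompt",
        ],
    },
}
```
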
+    def extract_mask_from_editor(self, editor_output: Dict[str, Any]) -> Optional[Image.Image]:
+        """
+        Extract mask from Gradio ImageEditor output.
+
+        Handles different Gradio versions' output formats.
+
+        Parameters
+        ----------
+        editor_output : dict, numpy.ndarray, or PIL.Image
+            Output from the gr.ImageEditor component
+
+        Returns
+        -------
+        PIL.Image or None
+            Extracted mask as grayscale image
+        """
+        if editor_output is None:
+            return None
+
+        try:
+            # Gradio 5.x format
+            if isinstance(editor_output, dict):
+                # Check for 'layers' key (Gradio 5.x ImageEditor)
+                if 'layers' in editor_output and editor_output['layers']:
+                    # Get the first layer as mask
+                    layer = editor_output['layers'][0]
+                    if isinstance(layer, np.ndarray):
+                        mask_array = layer
+                    elif isinstance(layer, Image.Image):
+                        mask_array = np.array(layer)
+                    else:
+                        return None
+
+                # Check for 'composite' key
+                elif 'composite' in editor_output:
+                    composite = editor_output['composite']
+                    if isinstance(composite, np.ndarray):
+                        mask_array = composite
+                    elif isinstance(composite, Image.Image):
+                        mask_array = np.array(composite)
+                    else:
+                        return None
+                else:
+                    return None
+
+            elif isinstance(editor_output, np.ndarray):
+                mask_array = editor_output
+            elif isinstance(editor_output, Image.Image):
+                mask_array = np.array(editor_output)
+            else:
+                logger.warning(f"Unexpected editor output type: {type(editor_output)}")
+                return None
+
+            # Convert to grayscale if needed
+            if len(mask_array.shape) == 3:
+                if mask_array.shape[2] == 4:
+                    # RGBA - use alpha channel
+                    mask_gray = mask_array[:, :, 3]
+                else:
+                    # RGB - convert to grayscale
+                    mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
+            else:
+                mask_gray = mask_array
+
+            return Image.fromarray(mask_gray.astype(np.uint8), mode='L')
+
+        except Exception as e:
+            logger.error(f"Failed to extract mask from editor: {e}")
+            return None
+
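A quick way to sanity-check the extractor is to feed it a synthetic Gradio-5.x-style dict. Here `ui` stands for an instance of this interface class (an assumption, since construction is outside this hunk):

```python
# Smoke test: a transparent RGBA layer with one painted (opaque) square.
import numpy as np

layer = np.zeros((64, 64, 4), dtype=np.uint8)
layer[16:48, 16:48, 3] = 255   # brush strokes end up in the alpha channel

mask = ui.extract_mask_from_editor({"layers": [layer], "composite": None})
assert mask is not None and mask.mode == "L"
assert np.array(mask).max() == 255   # the painted square survives extraction
```
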
+    def inpainting_handler(
+        self,
+        image: Optional[Image.Image],
+        mask_editor: Dict[str, Any],
+        prompt: str,
+        template_dropdown: str,
+        conditioning_type: str,
+        conditioning_scale: float,
+        feather_radius: int,
+        guidance_scale: float,
+        num_steps: int,
+        progress: gr.Progress = gr.Progress()
+    ) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image], str]:
+        """
+        Handle inpainting generation request.
+
+        Parameters
+        ----------
+        image : PIL.Image
+            Original image to inpaint
+        mask_editor : dict
+            Mask editor output
+        prompt : str
+            Text description of desired content
+        template_dropdown : str
+            Selected template (optional)
+        conditioning_type : str
+            ControlNet conditioning type
+        conditioning_scale : float
+            ControlNet influence strength
+        feather_radius : int
+            Mask feathering radius
+        guidance_scale : float
+            Guidance scale for generation
+        num_steps : int
+            Number of inference steps
+        progress : gr.Progress
+            Progress callback
+
+        Returns
+        -------
+        tuple
+            (result_image, preview_image, control_image, status_message)
+        """
+        if image is None:
+            return None, None, None, "Please upload an image first"
+
+        # Extract mask
+        mask = self.extract_mask_from_editor(mask_editor)
+        if mask is None:
+            return None, None, None, "Please draw a mask on the image"
+
+        # Validate mask
+        mask_array = np.array(mask)
+        coverage = np.count_nonzero(mask_array > 127) / mask_array.size
+        if coverage < 0.01:
+            return None, None, None, "Mask too small - please select a larger area"
+        if coverage > 0.95:
+            return None, None, None, "Mask too large - consider using background generation instead"
+
+        def progress_callback(msg: str, pct: int):
+            progress(pct / 100, desc=msg)
+
+        try:
+            start_time = time.time()
+
+            # Get template key if selected
+            template_key = None
+            if template_dropdown:
+                template_key = self.inpainting_template_manager.get_template_key_from_display(
+                    template_dropdown
+                )
+
+            # Execute inpainting through SceneWeaverCore facade
+            result = self.sceneweaver.execute_inpainting(
+                image=image,
+                mask=mask,
+                prompt=prompt,
+                preview_only=False,
+                template_key=template_key,
+                conditioning_type=conditioning_type,
+                controlnet_conditioning_scale=conditioning_scale,
+                feather_radius=feather_radius,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                progress_callback=progress_callback
+            )
+
+            elapsed = time.time() - start_time
+
+            if result.get('success'):
+                # Store in history
+                self.inpainting_history.append({
+                    'result': result.get('combined_image'),
+                    'prompt': prompt,
+                    'time': elapsed
+                })
+                if len(self.inpainting_history) > 3:
+                    self.inpainting_history.pop(0)
+
+                quality_score = result.get('quality_score', 0)
+                status = f"Inpainting complete in {elapsed:.1f}s"
+                if quality_score > 0:
+                    status += f" | Quality: {quality_score:.0f}/100"
+
+                return (
+                    result.get('combined_image'),
+                    result.get('preview_image'),
+                    result.get('control_image'),
+                    status
+                )
+            else:
+                error_msg = result.get('error', 'Unknown error')
+                return None, None, None, f"Inpainting failed: {error_msg}"
+
+        except Exception as e:
+            logger.error(f"Inpainting handler error: {e}")
+            logger.error(traceback.format_exc())
+            return None, None, None, f"Error: {str(e)}"
+
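For intuition about the coverage gate in this handler: coverage is the fraction of mask pixels brighter than 127, and must land between 1% and 95%. A worked example:

```python
# A 100x100 painted square on a 512x512 mask covers
# 10000 / 262144 ≈ 3.8%, inside the 1%-95% window.
import numpy as np

mask_array = np.zeros((512, 512), dtype=np.uint8)
mask_array[100:200, 100:200] = 255
coverage = np.count_nonzero(mask_array > 127) / mask_array.size
print(round(coverage, 3))  # 0.038 -> accepted
```
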
+    def create_inpainting_tab(self) -> gr.Tab:
+        """
+        Create the inpainting tab UI.
+
+        Returns
+        -------
+        gr.Tab
+            Configured inpainting tab component
+        """
+        with gr.Tab("Inpainting", elem_id="inpainting-tab") as tab:
+            gr.HTML("""
+            <div class="inpainting-header">
+                <h3 style="display: flex; align-items: center; gap: 10px; margin-bottom: 8px;">
+                    ControlNet Inpainting
+                    <span style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+                                 color: white;
+                                 padding: 3px 10px;
+                                 border-radius: 12px;
+                                 font-size: 0.65em;
+                                 font-weight: 700;
+                                 letter-spacing: 0.5px;
+                                 box-shadow: 0 2px 4px rgba(102, 126, 234, 0.3);">
+                        BETA
+                    </span>
+                </h3>
+                <p style="color: #666; margin-bottom: 12px;">Draw a mask to select the area you want to regenerate</p>
+                <div style="background: linear-gradient(to right, #FFF4E6, #FFE8CC);
+                            border-left: 4px solid #FF9500;
+                            padding: 12px 15px;
+                            border-radius: 6px;
+                            margin-top: 10px;
+                            box-shadow: 0 2px 4px rgba(255, 149, 0, 0.1);">
+                    <p style="color: #8B4513; font-size: 0.9em; margin: 0; line-height: 1.5;">
+                        <strong>⚠️ Beta Feature - Under Active Improvement</strong><br>
+                        Results may vary with scene complexity. Use templates and detailed prompts for best results.
+                        Advanced features (like Add Accessories) may require multiple attempts.
+                    </p>
+                </div>
+            </div>
+            """)
+
+            with gr.Row():
+                # Left column - Input
+                with gr.Column(scale=1):
+                    # Image upload
+                    inpaint_image = gr.Image(
+                        label="Upload Image",
+                        type="pil",
+                        height=300
+                    )
+
+                    # Mask editor
+                    mask_editor = gr.ImageEditor(
+                        label="Draw Mask (white = area to inpaint)",
+                        type="pil",
+                        height=300,
+                        brush=gr.Brush(colors=["#FFFFFF"], default_size=20),
+                        eraser=gr.Eraser(default_size=20),
+                        layers=False
+                    )
+
+                    # Template selection
+                    with gr.Accordion("Inpainting Templates", open=False):
+                        inpaint_template = gr.Dropdown(
+                            choices=[""] + self.inpainting_template_manager.get_template_choices_sorted(),
+                            value="",
+                            label="Select Template",
+                            elem_classes=["template-dropdown"]
+                        )
+                        template_tips = gr.Markdown("")
+
+                    # Prompt
+                    inpaint_prompt = gr.Textbox(
+                        label="Prompt",
+                        placeholder="Describe what you want to generate in the masked area...",
+                        lines=2
+                    )
+
+                # Right column - Settings and Output
+                with gr.Column(scale=1):
+                    # Settings
+                    with gr.Accordion("Generation Settings", open=True):
+                        conditioning_type = gr.Radio(
+                            choices=["canny", "depth"],
+                            value="canny",
+                            label="ControlNet Mode"
+                        )
+
+                        conditioning_scale = gr.Slider(
+                            minimum=0.5,
+                            maximum=1.0,
+                            value=0.7,
+                            step=0.05,
+                            label="ControlNet Strength"
+                        )
+
+                        feather_radius = gr.Slider(
+                            minimum=0,
+                            maximum=20,
+                            value=8,
+                            step=1,
+                            label="Feather Radius (px)"
+                        )
+
+                    with gr.Accordion("Advanced Settings", open=False):
+                        inpaint_guidance = gr.Slider(
+                            minimum=5.0,
+                            maximum=15.0,
+                            value=7.5,
+                            step=0.5,
+                            label="Guidance Scale"
+                        )
+
+                        inpaint_steps = gr.Slider(
+                            minimum=15,
+                            maximum=50,
+                            value=25,
+                            step=5,
+                            label="Inference Steps"
+                        )
+
+                    # Generate button
+                    inpaint_btn = gr.Button(
+                        "Generate Inpainting",
+                        variant="primary",
+                        elem_classes=["primary-button"]
+                    )
+
+                    # Status
+                    inpaint_status = gr.Textbox(
+                        label="Status",
+                        value="Ready for inpainting",
+                        interactive=False
+                    )
+
+            # Output row
+            with gr.Row():
+                with gr.Column(scale=1):
+                    inpaint_result = gr.Image(
+                        label="Result",
+                        type="pil",
+                        height=400
+                    )
+
+                with gr.Column(scale=1):
+                    with gr.Tabs():
+                        with gr.Tab("Preview"):
+                            inpaint_preview = gr.Image(
+                                label="Preview",
+                                type="pil",
+                                height=350
+                            )
+                        with gr.Tab("Control Image"):
+                            inpaint_control = gr.Image(
+                                label="Control Image",
+                                type="pil",
+                                height=350
+                            )
+
+            # Event handlers
+            inpaint_template.change(
+                fn=self.apply_inpainting_template,
+                inputs=[inpaint_template, inpaint_prompt],
+                outputs=[inpaint_prompt, conditioning_scale, feather_radius, conditioning_type]
+            )
+
+            inpaint_template.change(
+                fn=lambda x: self._get_template_tips(x),
+                inputs=[inpaint_template],
+                outputs=[template_tips]
+            )
+
+            # Copy uploaded image to mask editor
+            inpaint_image.change(
+                fn=lambda x: x,
+                inputs=[inpaint_image],
+                outputs=[mask_editor]
+            )
+
+            inpaint_btn.click(
+                fn=self.inpainting_handler,
+                inputs=[
+                    inpaint_image,
+                    mask_editor,
+                    inpaint_prompt,
+                    inpaint_template,
+                    conditioning_type,
+                    conditioning_scale,
+                    feather_radius,
+                    inpaint_guidance,
+                    inpaint_steps
+                ],
+                outputs=[
+                    inpaint_result,
+                    inpaint_preview,
+                    inpaint_control,
+                    inpaint_status
+                ]
+            )
+
+        return tab
+
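Because the method builds and wires everything inside its own `gr.Tab` context, it can in principle be previewed in isolation during development. A hypothetical harness, assuming an already-constructed instance `ui` of this class:

```python
# Preview just the inpainting tab; Gradio wraps a bare Tab in a Tabs container.
import gradio as gr

with gr.Blocks() as demo:
    ui.create_inpainting_tab()

demo.launch()
```
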
+    def _get_template_tips(self, display_name: str) -> str:
+        """Get usage tips for selected template."""
+        if not display_name:
+            return ""
+
+        template_key = self.inpainting_template_manager.get_template_key_from_display(display_name)
+        if not template_key:
+            return ""

+        tips = self.inpainting_template_manager.get_usage_tips(template_key)
+        if tips:
+            return "**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips)
+        return ""
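For a template with two tips, the Markdown produced above renders as a small bulleted list (the tip strings below are made up for illustration):

```python
tips = ["Mask slightly beyond the object's edge", "Keep the prompt concrete"]
print("**Tips:**\n" + "\n".join(f"- {tip}" for tip in tips))
# **Tips:**
# - Mask slightly beyond the object's edge
# - Keep the prompt concrete
```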