File size: 38,893 Bytes
695fbf0
 
 
 
 
 
 
 
 
9068525
695fbf0
 
 
 
 
42c372b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695fbf0
 
 
 
42c372b
 
 
 
695fbf0
 
 
 
 
 
 
 
 
 
 
42c372b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695fbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42c372b
 
 
 
 
695fbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9068525
 
 
 
 
695fbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
import os
import sys
import math
import torch
import torch.nn.functional as F
import numpy as np

import comfy.model_management as model_management
from .preset_loader import get as load_preset
from ..hard.mg_upscale_module import clear_gpu_and_ram_cache


# Process-wide cache for the Depth Anything v2 backend; written by
# _try_init_depth_anything and read by _build_depth_map.
_DEPTH_INIT = False    # not assigned anywhere in the code visible in this section
_DEPTH_MODEL = None    # loaded depth model, or None while no load has succeeded
_DEPTH_PROC = None     # True after a successful load, False after a failed attempt
_DEPTH_WARNED = False  # flipped to True once the one-time fallback warning was printed


def _find_custom_nodes_root() -> str | None:
    try:
        here = os.path.abspath(os.path.dirname(__file__))
        cur = here
        for _ in range(6):
            if os.path.basename(cur).lower() == 'custom_nodes':
                return cur
            parent = os.path.dirname(cur)
            if parent == cur:
                break
            cur = parent
    except Exception:
        return None
    return None


def _insert_aux_path():
    """Put a local ``comfyui_controlnet_aux`` checkout on ``sys.path``.

    The checkout is looked up next to the other custom nodes. Both its
    ``src`` sub-folder and its root are prepended to ``sys.path`` when they
    exist and are not already listed. Best-effort: every error is ignored.
    """
    try:
        nodes_root = _find_custom_nodes_root()
        if nodes_root is None:
            # heuristic fallback: go up four levels from .../mod/easy
            nodes_root = os.path.abspath(
                os.path.join(os.path.dirname(__file__), '..', '..', '..', '..')
            )
        aux_root = os.path.join(nodes_root, 'comfyui_controlnet_aux')
        candidates = (os.path.join(aux_root, 'src'), aux_root)
        for entry in candidates:
            if not os.path.isdir(entry):
                continue
            if entry in sys.path:
                continue
            sys.path.insert(0, entry)
    except Exception:
        pass


def _depth_ckpt_rank(path: str) -> int:
    """Rank a checkpoint path by encoder tag (vitl < vitb < vits < vitg < unknown)."""
    name = os.path.basename(path).lower()
    for rank, tag in enumerate(("vitl", "vitb", "vits", "vitg")):
        if tag in name:
            return rank
    return 100


def _resolve_depth_weights(model_path):
    """Resolve ``model_path`` to a Depth Anything v2 ``.pth`` file when possible.

    An existing file is returned unchanged. Otherwise ``'auto'`` (treated as
    empty) or a directory triggers a recursive scan of: the given directory,
    the vendored ``../depth-anything`` folder, and
    ``comfyui_controlnet_aux/ckpts/depth-anything``. The best-ranked match is
    returned; when nothing matches the (possibly emptied) input is returned.
    """
    mp = model_path
    if isinstance(mp, str) and mp.strip().lower() == "auto":
        mp = ""
    if mp and os.path.isfile(mp):
        return mp
    search_dirs = []
    if mp and os.path.isdir(mp):
        search_dirs.append(mp)
    # vendor depth-anything folder
    search_dirs.append(os.path.join(os.path.dirname(__file__), '..', 'depth-anything'))
    # comfyui_controlnet_aux ckpts/depth-anything (and deeper)
    base_custom = _find_custom_nodes_root()
    if base_custom is None:
        base_custom = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..'))
    search_dirs.append(os.path.join(base_custom, 'comfyui_controlnet_aux', 'ckpts', 'depth-anything'))
    cand = []
    for d in search_dirs:
        try:
            if not os.path.isdir(d):
                continue
            for root, _dirs, files in os.walk(d):
                for fn in files:
                    fnl = fn.lower()
                    key = fnl.replace('-', '_')
                    if fnl.endswith('.pth') and ('depth_anything' in key) and ('v2' in key):
                        cand.append(os.path.join(root, fn))
        except Exception:
            # An unreadable directory must not abort the scan of the others.
            pass
    if cand:
        # min() keeps the first of equally ranked candidates, like a stable sort.
        return min(cand, key=_depth_ckpt_rank)
    return mp


def _depth_v2_cfg(model_path) -> dict:
    """Return DepthAnythingV2 constructor kwargs guessed from the checkpoint filename.

    Only exact, known filenames are recognised; anything else falls back to
    the ViT-L configuration.
    """
    def _vitl():
        return dict(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])

    cfgs = {
        'depth_anything_v2_vits.pth': dict(encoder='vits', features=64, out_channels=[48, 96, 192, 384]),
        'depth_anything_v2_vitb.pth': dict(encoder='vitb', features=128, out_channels=[96, 192, 384, 768]),
        'depth_anything_v2_vitl.pth': _vitl(),
        'depth_anything_v2_vitg.pth': dict(encoder='vitg', features=384, out_channels=[1536, 1536, 1536, 1536]),
        'depth_anything_v2_metric_vkitti_vitl.pth': _vitl(),
        'depth_anything_v2_metric_hypersim_vitl.pth': _vitl(),
    }
    fname = os.path.basename(model_path or '')
    return cfgs.get(fname, cfgs['depth_anything_v2_vitl.pth'])


def _load_depth_v2(dpt_cls, model_path):
    """Instantiate ``dpt_cls``, load the checkpoint on CPU, and move it to the run device in eval mode."""
    cfg = _depth_v2_cfg(model_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = dpt_cls(**cfg)
    state = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state)
    return model.to(device).eval()


def _try_init_depth_anything(model_path: str) -> bool:
    """Load (or reuse) the Depth Anything v2 model and cache it in module globals.

    Args:
        model_path: Path to a ``.pth`` checkpoint, a directory to scan, or
            ``'auto'`` to scan only the default locations.

    Returns:
        ``True`` when a model is available in ``_DEPTH_MODEL``; ``False`` when
        no local weights were found or every backend failed to load them. In
        the failure case ``_DEPTH_MODEL`` is ``None`` and ``_DEPTH_PROC`` is
        ``False``. Never raises for load problems.
    """
    global _DEPTH_MODEL, _DEPTH_PROC, _DEPTH_WARNED
    # If already loaded, reuse
    if _DEPTH_MODEL is not None:
        return True
    try:
        model_path = _resolve_depth_weights(model_path)
    except Exception:
        # Keep the caller-supplied value; the file check below decides.
        pass
    # If no local weights resolved, bail out to pseudo-depth to avoid heavy auto-download/memory spikes
    try:
        have_weights = isinstance(model_path, str) and os.path.isfile(model_path)
    except Exception:
        have_weights = False
    if not have_weights:
        if not _DEPTH_WARNED:
            try:
                print("[ControlFusion][Depth/Easy] no local Depth Anything v2 weights found; using pseudo-depth fallback.")
            except Exception:
                pass
            _DEPTH_WARNED = True
        _DEPTH_MODEL = None
        _DEPTH_PROC = False
        return False

    failures = []
    # 1) Prefer our vendored implementation first
    try:
        from ...vendor.depth_anything_v2.dpt import DepthAnythingV2  # type: ignore
        _DEPTH_MODEL = _load_depth_v2(DepthAnythingV2, model_path)
        _DEPTH_PROC = True
        return True
    except Exception as exc:
        failures.append(f"vendored: {type(exc).__name__}: {exc}")
    # 2) Try local checkout of comfyui_controlnet_aux (if present)
    _insert_aux_path()
    try:
        from custom_controlnet_aux.depth_anything_v2.dpt import DepthAnythingV2  # type: ignore
        _DEPTH_MODEL = _load_depth_v2(DepthAnythingV2, model_path)
        _DEPTH_PROC = True
        return True
    except Exception as exc:
        failures.append(f"custom_controlnet_aux: {type(exc).__name__}: {exc}")
    # 3) Fallback: packaged auxiliary API
    try:
        # Both names are imported on purpose: the backend is only used when the
        # package exposes the same API surface the original code required.
        from controlnet_aux.depth_anything import DepthAnythingDetector, DepthAnythingV2  # type: ignore
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        _DEPTH_MODEL = DepthAnythingV2(model_path=model_path, device=device)
        _DEPTH_PROC = True
        return True
    except Exception as exc:
        failures.append(f"controlnet_aux: {type(exc).__name__}: {exc}")

    _DEPTH_MODEL = None
    _DEPTH_PROC = False
    if not _DEPTH_WARNED:
        # Weights exist but no backend could use them: say why once instead of
        # silently degrading to the flat fallback.
        try:
            print("[ControlFusion][Depth/Easy] could not load Depth Anything v2 weights ("
                  + "; ".join(failures) + "); using pseudo-depth fallback.")
        except Exception:
            pass
        _DEPTH_WARNED = True
    return False


def _build_depth_map(image_bhwc: torch.Tensor, res: int, model_path: str, hires_mode: bool = True) -> torch.Tensor:
    """Estimate a depth map for the first image of a batch.

    Args:
        image_bhwc: Image batch laid out as (B, H, W, C); only index 0 is used.
        res: Requested min-side working resolution; clamped to 16..1024.
        model_path: Checkpoint path forwarded to ``_try_init_depth_anything``.
        hires_mode: Accepted for interface compatibility; not read by this
            function.

    Returns:
        An (H, W) tensor in 0..1 on the input's device and dtype. When the
        model is unavailable or inference fails, a constant 0.5 map is
        returned instead.
    """
    _, H, W, _ = image_bhwc.shape
    dev = image_bhwc.device
    dtype = image_bhwc.dtype
    # DepthAnything v2 can be memory-hungry on large inputs; cap min-side at 1024
    cap = 1024
    target = int(max(16, min(cap, res)))
    if _try_init_depth_anything(model_path):
        try:
            # Work in float32 on CPU: half/bfloat16 inputs cannot reliably be
            # interpolated on CPU or converted to NumPy, and would otherwise
            # fall through to the flat fallback below.
            first = image_bhwc.detach()[0].to(device='cpu', dtype=torch.float32)
            x = first.movedim(-1, 0).unsqueeze(0)
            # keep aspect ratio: scale so that min(H,W) == target
            _, _, Ht, Wt = x.shape
            min_side = max(1, min(Ht, Wt))
            scale = float(target) / float(min_side)
            out_h = max(1, int(round(Ht * scale)))
            out_w = max(1, int(round(Wt * scale)))
            x = F.interpolate(x, size=(out_h, out_w), mode='bilinear', align_corners=False)
            # Clamp before the uint8 cast so out-of-range pixels saturate
            # instead of wrapping around.
            x = x.clamp(0.0, 1.0)
            # make channels-last and ensure contiguous layout for OpenCV
            arr = (x[0].movedim(0, -1).contiguous().numpy() * 255.0).astype('uint8')
            # Prefer direct DepthAnythingV2 inference if model has infer_image
            if hasattr(_DEPTH_MODEL, 'infer_image'):
                import cv2
                # Drive input_size from desired depth resolution (min side), let DA keep aspect
                input_sz = int(max(224, min(cap, res)))
                depth = _DEPTH_MODEL.infer_image(cv2.cvtColor(arr, cv2.COLOR_RGB2BGR), input_size=input_sz, max_depth=20.0)
                d = np.asarray(depth, dtype=np.float32)
                # Normalize DepthAnythingV2 output (0..max_depth) to 0..1
                d = d / 20.0
            else:
                depth = _DEPTH_MODEL(arr)
                d = np.asarray(depth, dtype=np.float32)
            # NOTE(review): values still above 1 are assumed to be on a 0..255
            # scale — confirm this holds for every supported backend.
            if d.max() > 1.0:
                d = d / 255.0
            d = torch.from_numpy(d)[None, None]  # 1,1,h,w
            d = F.interpolate(d, size=(H, W), mode='bilinear', align_corners=False)
            d = d[0, 0].to(device=dev, dtype=dtype)
            return d.clamp(0, 1)
        except Exception:
            # Best-effort: any inference problem degrades to the flat map.
            pass
    # Fallback depth: constant mid-gray (0.5) for lightweight, uniform guidance
    try:
        return torch.full((H, W), 0.5, device=dev, dtype=dtype)
    except Exception:
        # Build on CPU first in case direct allocation on `dev` is not possible.
        return torch.full((H, W), 0.5, device='cpu', dtype=torch.float32).to(device=dev, dtype=dtype)


def _pyracanny(image_bhwc: torch.Tensor,
               low: int,
               high: int,
               res: int,
               thin_iter: int = 0,
               edge_boost: float = 0.0,
               smart_tune: bool = False,
               smart_boost: float = 0.2,
               preserve_aspect: bool = True) -> torch.Tensor:
    """Build a multi-scale ("pyramid") Canny edge map for the first batch image.

    Canny runs at scales 1.0, 0.5 and 0.25 of the working resolution with the
    thresholds scaled alongside; the per-scale maps are resized to the input
    size and merged with an element-wise maximum.

    Args:
        image_bhwc: Image batch laid out as (B, H, W, C); only index 0 is used.
        low: Canny low threshold at full scale.
        high: Canny high threshold at full scale.
        res: Working min-side resolution, clamped to 64..4096.
        thin_iter: Minimum number of thinning passes (0 disables thinning
            unless ``smart_tune`` raises it).
        edge_boost: Accepted for interface compatibility; the boost amount is
            fixed internally and this value is not read.
        smart_tune: Derive extra thinning and a slightly adjusted boost from
            the resolution, edge density and Laplacian variance.
        smart_boost: Accepted for interface compatibility; not read.
        preserve_aspect: Keep the aspect ratio when resizing to the working
            resolution; otherwise use a square ``res`` x ``res`` canvas.

    Returns:
        An (H, W) tensor in 0..1 on the input's device and dtype. Without
        OpenCV a normalized Sobel magnitude of the first image is returned.
    """
    try:
        import cv2
    except Exception:
        # Fallback: simple Sobel magnitude
        x = image_bhwc.movedim(-1, 1)
        xg = x.mean(dim=1, keepdim=True)
        kx = torch.tensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]], dtype=x.dtype, device=x.device).unsqueeze(1)
        ky = torch.tensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]], dtype=x.dtype, device=x.device).unsqueeze(1)
        gx = F.conv2d(xg, kx, padding=1)
        gy = F.conv2d(xg, ky, padding=1)
        mag = torch.sqrt(gx * gx + gy * gy)
        mag = (mag - mag.amin()) / (mag.amax() - mag.amin() + 1e-6)
        return mag[0, 0].clamp(0, 1)

    _, H, W, _ = image_bhwc.shape
    # float32 + clamp before the uint8 cast: bfloat16 cannot be converted to
    # NumPy, and out-of-range pixels would wrap around instead of saturating.
    first = image_bhwc.detach()[0].to(device='cpu', dtype=torch.float32).clamp(0.0, 1.0)
    img = (first.contiguous().numpy() * 255.0).astype('uint8')
    cap = 4096
    target = int(max(64, min(cap, res)))
    if preserve_aspect:
        scale = float(target) / float(max(1, min(H, W)))
        out_h = max(8, int(round(H * scale)))
        out_w = max(8, int(round(W * scale)))
        img_res = cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    else:
        img_res = cv2.resize(img, (target, target), interpolation=cv2.INTER_LINEAR)
    gray = cv2.cvtColor(img_res, cv2.COLOR_RGB2GRAY)

    acc = None
    for s in (1.0, 0.5, 0.25):
        if preserve_aspect:
            sz = (max(8, int(round(img_res.shape[1] * s))), max(8, int(round(img_res.shape[0] * s))))
        else:
            sz = (max(8, int(target * s)), max(8, int(target * s)))
        g = cv2.resize(gray, sz, interpolation=cv2.INTER_AREA)
        g = cv2.GaussianBlur(g, (5, 5), 0)
        e = cv2.Canny(g, threshold1=int(low * s), threshold2=int(high * s))
        e = cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR)
        e = e.astype(np.float32) / 255.0
        acc = e if acc is None else np.maximum(acc, e)

    # Estimate density and sharpness for smart tuning
    try:
        edensity_pre = float(np.mean(acc))
    except Exception:
        edensity_pre = None
    try:
        g32 = gray.astype(np.float32) / 255.0
        lap_var = float(cv2.Laplacian(g32, cv2.CV_32F).var())
    except Exception:
        lap_var = None

    # optional thinning
    try:
        thin_iter_eff = int(thin_iter)
        if smart_tune:
            # simple heuristic: more thinning on high res and dense edges
            auto = 0
            if target >= 1024:
                auto += 1
            if target >= 1400:
                auto += 1
            if edensity_pre is not None and edensity_pre > 0.12:
                auto += 1
            if edensity_pre is not None and edensity_pre < 0.05:
                auto = max(0, auto - 1)
            thin_iter_eff = max(thin_iter_eff, min(3, auto))
        if thin_iter_eff > 0:
            if hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'thinning'):
                # ximgproc thinning is applied once, independent of the count.
                th = (acc.copy() * 255).astype('uint8')
                th = cv2.ximgproc.thinning(th)
                acc = th.astype(np.float32) / 255.0
            else:
                # simple erosion-based thinning approximation
                kernel = np.ones((3, 3), np.uint8)
                t = (acc * 255).astype('uint8')
                for _ in range(int(thin_iter_eff)):
                    t = cv2.erode(t, kernel, iterations=1)
                acc = t.astype(np.float32) / 255.0
    except Exception:
        # Thinning is cosmetic; keep the un-thinned map on any failure.
        pass

    # optional edge boost (unsharp on edge map)
    # We fix a gentle boost for micro-contrast; smart_tune may nudge it slightly
    boost_eff = 0.10
    if smart_tune:
        try:
            lv = 0.0 if lap_var is None else max(0.0, min(1.0, lap_var / 2.0))
            dens = 0.0 if edensity_pre is None else float(max(0.0, min(1.0, edensity_pre)))
            boost_eff = max(0.05, min(0.20, boost_eff + (1.0 - dens) * 0.05 + (1.0 - lv) * 0.02))
        except Exception:
            pass
    if boost_eff:
        try:
            blur = cv2.GaussianBlur(acc, (0, 0), sigmaX=1.0)
            acc = np.clip(acc + float(boost_eff) * (acc - blur), 0.0, 1.0)
        except Exception:
            pass
    ed = torch.from_numpy(acc).to(device=image_bhwc.device, dtype=image_bhwc.dtype)
    return ed.clamp(0, 1)


def _blend(depth: torch.Tensor, edges: torch.Tensor, mode: str, factor: float) -> torch.Tensor:
    """Fuse a depth map and an edge map into a single 0..1 hint.

    Args:
        depth: Depth map; clamped to 0..1 before use.
        edges: Edge map; clamped to 0..1 before use.
        mode: ``'max'`` takes the element-wise maximum,
            ``'edge_over_depth'`` lets edges override depth, and any other
            value performs a linear mix.
        factor: Edge weight for the linear mix, clamped to 0..1; ignored by
            the other two modes.

    Returns:
        The fused map, clamped to 0..1.
    """
    d = depth.clamp(0, 1)
    e = edges.clamp(0, 1)
    if mode == 'max':
        return torch.maximum(d, e)
    if mode == 'edge_over_depth':
        # edge == 1 forces the output to 1; edge == 0 leaves depth untouched
        fused = d * (1.0 - e) + e
    else:
        weight = float(max(0.0, min(1.0, factor)))
        fused = d * (1.0 - weight) + e * weight
    return fused.clamp(0, 1)


def _apply_controlnet_separate(positive, negative, control_net, image_bhwc: torch.Tensor,
                               strength_pos: float, strength_neg: float,
                               start_percent: float, end_percent: float, vae=None,
                               apply_to_uncond: bool = False,
                               stack_prev_control: bool = False):
    """Attach a ControlNet hint to both conditionings with separate strengths.

    Each conditioning entry is read as ``entry[0]`` (kept as is) and
    ``entry[1]`` (a dict that is shallow-copied). For every entry a fresh copy
    of ``control_net`` is configured with the hint image, the strength for its
    side, and the ``(start_percent, end_percent)`` window, then stored under
    ``'control'``. With ``stack_prev_control`` the entry's existing
    ``'control'`` value is chained as the previous controlnet; otherwise the
    previous controlnet is set to ``None``.

    Returns:
        ``(out_pos, out_neg)``: new lists of ``[entry[0], patched_dict]``.
    """
    # The hint is handed over with the last axis moved to position 1.
    hint = image_bhwc.movedim(-1, 1)

    def _attach(conditioning, strength):
        patched = []
        for entry in conditioning:
            opts = entry[1].copy()
            previous = opts.get('control', None) if stack_prev_control else None
            net = control_net.copy().set_cond_hint(
                hint, float(strength), (start_percent, end_percent), vae=vae, extra_concat=[]
            )
            net.set_previous_controlnet(previous)
            opts['control'] = net
            opts['control_apply_to_uncond'] = bool(apply_to_uncond)
            patched.append([entry[0], opts])
        return patched

    out_pos = _attach(positive, strength_pos)
    out_neg = _attach(negative, strength_neg)
    return out_pos, out_neg


class MG_ControlFusion:
    """Node that fuses a depth map and PyraCanny edges into a ControlNet hint.

    The class follows the INPUT_TYPES / RETURN_TYPES / FUNCTION node
    convention (ComfyUI-style custom node). ``apply`` builds the two masks,
    post-processes the edges, blends them, attaches the ControlNet to the
    positive and negative conditionings, and returns a preview image.
    """

    # Names of `apply` parameters that keep the caller's value when
    # `custom_override=True`, even if the selected preset defines them.
    # Empty for the full node; wrappers with a reduced UI list their visible
    # controls here.
    _CUSTOM_OVERRIDE_FIELDS = ()

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema (required / optional sockets and widgets)."""
        default_depth_weights = os.path.join(
            os.path.dirname(__file__), '..', 'depth-anything', 'depth_anything_v2_vitl.pth'
        )
        return {
            "required": {
                "preset_step": (["Custom", "Step 2", "Step 3", "Step 4"], {"default": "Custom", "tooltip": "Apply preset values from pressets/mg_controlfusion.cfg. Preset values override the UI values; choose Custom to use the UI values."}),
                "image": ("IMAGE", {"tooltip": "Input RGB image (B,H,W,3) in 0..1."}),
                "positive": ("CONDITIONING", {"tooltip": "Positive conditioning to apply ControlNet to."}),
                "negative": ("CONDITIONING", {"tooltip": "Negative conditioning to apply ControlNet to."}),
                "control_net": ("CONTROL_NET", {"tooltip": "ControlNet module receiving the fused mask as hint."}),
                "vae": ("VAE", {"tooltip": "VAE used by ControlNet when encoding the hint."}),
            },
            "optional": {
                "enable_depth": ("BOOLEAN", {"default": True, "tooltip": "Enable depth map fusion (Depth Anything v2 if available)."}),
                "depth_model_path": ("STRING", {"default": default_depth_weights, "tooltip": "Path to Depth Anything v2 .pth weights (vits/vitb/vitl/vitg)."}),
                "depth_resolution": ("INT", {"default": 768, "min": 64, "max": 1024, "step": 64, "tooltip": "Depth min-side resolution (cap 1024). In Hi‑Res mode drives DepthAnything input_size."}),
                "enable_pyra": ("BOOLEAN", {"default": True, "tooltip": "Enable PyraCanny edge detector."}),
                "pyra_low": ("INT", {"default": 109, "min": 0, "max": 255, "tooltip": "Canny low threshold (0..255)."}),
                "pyra_high": ("INT", {"default": 147, "min": 0, "max": 255, "tooltip": "Canny high threshold (0..255)."}),
                "pyra_resolution": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 64, "tooltip": "Working resolution for edges (min side, keeps aspect)."}),
                "edge_thin_iter": ("INT", {"default": 0, "min": 0, "max": 10, "step": 1, "tooltip": "Thinning iterations for edges (skeletonize). 0 = off."}),
                "edge_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity for edges before blending (0..1)."}),
                "edge_boost": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Deprecated: internal boost fixed (~0.10); use edge_alpha instead."}),
                "smart_tune": ("BOOLEAN", {"default": False, "tooltip": "Auto-adjust thinning/boost from image edge density and sharpness."}),
                "smart_boost": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Scale for auto edge boost when Smart Tune is on."}),
                "blend_mode": (["normal","max","edge_over_depth"], {"default": "normal", "tooltip": "Depth+edges merge: normal (mix), max (strongest), edge_over_depth (edges overlay)."}),
                "blend_factor": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Blend strength for edges into depth (depends on mode)."}),
                "strength_pos": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "ControlNet strength for positive branch."}),
                "strength_neg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "ControlNet strength for negative branch."}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Start percentage along the sampling schedule."}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "End percentage along the sampling schedule."}),
                "preview_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64, "tooltip": "Preview minimum side (keeps aspect ratio)."}),
                "mask_brightness": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Preview brightness multiplier (visualization only)."}),
                "preview_show_strength": ("BOOLEAN", {"default": True, "tooltip": "Multiply preview by ControlNet strength for visualization."}),
                "preview_strength_branch": (["positive","negative","max","avg"], {"default": "max", "tooltip": "Which strength to reflect in preview (display only)."}),
                "hires_mask_auto": ("BOOLEAN", {"default": True, "tooltip": "High‑res mask: keep aspect ratio, scale by minimal side for depth/edges, and drive DepthAnything with your depth_resolution (no 2K cap)."}),
                "apply_to_uncond": ("BOOLEAN", {"default": False, "tooltip": "Apply ControlNet hint to the unconditional branch as well (stronger global hold on very large images)."}),
                "stack_prev_control": ("BOOLEAN", {"default": False, "tooltip": "Chain with any previously attached ControlNet in the conditioning (advanced). Off = replace to avoid memory bloat."}),
                # Split apply: chain Depth and Edges with separate schedules/strengths (fixed order: depth -> edges)
                "split_apply": ("BOOLEAN", {"default": False, "tooltip": "Apply Depth and Edges as two chained ControlNets (fixed order: depth then edges)."}),
                "edge_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Edges start percent (when split is enabled)."}),
                "edge_end_percent": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Edges end percent (when split is enabled)."}),
                "depth_start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Depth start percent (when split is enabled)."}),
                "depth_end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Depth end percent (when split is enabled)."}),
                "edge_strength_mul": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.01, "tooltip": "Multiply global strength for Edges when split is enabled."}),
                "depth_strength_mul": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.01, "tooltip": "Multiply global strength for Depth when split is enabled."}),
                # Extra edge controls (bottom)
                "edge_width": ("FLOAT", {"default": 0.0, "min": -0.5, "max": 1.5, "step": 0.05, "tooltip": "Edge thickness adjust: negative thins, positive thickens."}),
                "edge_smooth": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, "tooltip": "Small smooth on edges to reduce pixelation (0..1)."}),
                "edge_single_line": ("BOOLEAN", {"default": False, "tooltip": "Try to collapse double outlines into a single centerline."}),
                "edge_single_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Strength of single-line collapse (0..1). 0 = off, 1 = strong."}),
                "edge_depth_gate": ("BOOLEAN", {"default": False, "tooltip": "Weigh edges by depth so distant lines are fainter."}),
                "edge_depth_gamma": ("FLOAT", {"default": 1.5, "min": 0.2, "max": 4.0, "step": 0.1, "tooltip": "Gamma for depth gating: edges *= depth^gamma."}),
            }
        }

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","IMAGE")
    RETURN_NAMES = ("positive","negative","Mask_Preview")
    FUNCTION = "apply"
    CATEGORY = "MagicNodes"

    def apply(self, image, positive, negative, control_net, vae,
              enable_depth=True, depth_model_path="", depth_resolution=1024,
              enable_pyra=True, pyra_low=109, pyra_high=147, pyra_resolution=1024,
              edge_thin_iter=0, edge_alpha=1.0, edge_boost=0.0,
              smart_tune=False, smart_boost=0.2,
              blend_mode="normal", blend_factor=0.02,
              strength_pos=1.0, strength_neg=1.0, start_percent=0.0, end_percent=1.0,
              preview_res=1024, mask_brightness=1.0,
              preview_show_strength=True, preview_strength_branch="max",
              hires_mask_auto=True, apply_to_uncond=False, stack_prev_control=False,
              edge_width=0.0, edge_smooth=0.0, edge_single_line=False, edge_single_strength=0.0,
              edge_depth_gate=False, edge_depth_gamma=1.5,
              split_apply=False, edge_start_percent=0.0, edge_end_percent=0.6,
              depth_start_percent=0.0, depth_end_percent=1.0,
              edge_strength_mul=1.0, depth_strength_mul=1.0,
              preset_step="Step 2", custom_override=False):
        """Build the fused depth/edge hint and attach the ControlNet.

        Parameter resolution:
          * If ``preset_step`` is a string other than "custom" (case
            insensitive), the preset ``load_preset("mg_controlfusion", step)``
            is loaded and every key it defines REPLACES the passed-in value.
            A preset that fails to load is treated as empty.
          * With ``custom_override=True`` the parameters named in
            ``_CUSTOM_OVERRIDE_FIELDS`` keep the passed-in value instead; all
            other parameters still come from the preset.

        Processing:
          1. Depth map (``enable_depth``) and PyraCanny edges (``enable_pyra``);
             if both are disabled the inputs are returned with a black preview.
          2. Edge post-processing (width, single-line collapse, smoothing),
             optional depth gating ``edges *= depth ** edge_depth_gamma`` and
             ``edge_alpha`` scaling.
          3. ``_blend`` of depth and edges into the fused mask.
          4. ControlNet attachment: either one fused hint, or (``split_apply``)
             depth first and edges chained on top, each with its own schedule
             and strength multiplier.
          5. Preview of the fused mask resized so its shorter side equals
             ``preview_res`` (limited to 256..2048).

        Returns:
            ``(positive, negative, preview)`` where ``preview`` is a
            ``(1, h, w, 3)`` tensor on the input image's device/dtype (or
            ``(B, H, W, 3)`` zeros on the nothing-to-do path).
        """

        # Merge preset values (if selected) with the passed-in values.
        # NOTE: the preset wins for every key it defines, because the UI always
        # supplies a value for each widget and the preset would otherwise never
        # take effect.
        try:
            if isinstance(preset_step, str) and preset_step.lower() != "custom":
                p = load_preset("mg_controlfusion", preset_step) or {}
            else:
                p = {}
        except Exception:
            # Best-effort: a missing/broken preset falls back to passed-in values.
            p = {}
        # Fields that must keep the caller's value despite the preset.
        keep_ui = frozenset(getattr(self, '_CUSTOM_OVERRIDE_FIELDS', ())) if custom_override else frozenset()

        def pv(name, cur):
            if name in keep_ui:
                return cur
            return p.get(name, cur)
        enable_depth = bool(pv('enable_depth', enable_depth))
        depth_model_path = str(pv('depth_model_path', depth_model_path))
        depth_resolution = int(pv('depth_resolution', depth_resolution))
        enable_pyra = bool(pv('enable_pyra', enable_pyra))
        pyra_low = int(pv('pyra_low', pyra_low))
        pyra_high = int(pv('pyra_high', pyra_high))
        pyra_resolution = int(pv('pyra_resolution', pyra_resolution))
        edge_thin_iter = int(pv('edge_thin_iter', edge_thin_iter))
        edge_alpha = float(pv('edge_alpha', edge_alpha))
        edge_boost = float(pv('edge_boost', edge_boost))
        smart_tune = bool(pv('smart_tune', smart_tune))
        smart_boost = float(pv('smart_boost', smart_boost))
        blend_mode = str(pv('blend_mode', blend_mode))
        blend_factor = float(pv('blend_factor', blend_factor))
        strength_pos = float(pv('strength_pos', strength_pos))
        strength_neg = float(pv('strength_neg', strength_neg))
        start_percent = float(pv('start_percent', start_percent))
        end_percent = float(pv('end_percent', end_percent))
        preview_res = int(pv('preview_res', preview_res))
        mask_brightness = float(pv('mask_brightness', mask_brightness))
        preview_show_strength = bool(pv('preview_show_strength', preview_show_strength))
        preview_strength_branch = str(pv('preview_strength_branch', preview_strength_branch))
        hires_mask_auto = bool(pv('hires_mask_auto', hires_mask_auto))
        apply_to_uncond = bool(pv('apply_to_uncond', apply_to_uncond))
        stack_prev_control = bool(pv('stack_prev_control', stack_prev_control))
        split_apply = bool(pv('split_apply', split_apply))
        edge_start_percent = float(pv('edge_start_percent', edge_start_percent))
        edge_end_percent = float(pv('edge_end_percent', edge_end_percent))
        depth_start_percent = float(pv('depth_start_percent', depth_start_percent))
        depth_end_percent = float(pv('depth_end_percent', depth_end_percent))
        edge_strength_mul = float(pv('edge_strength_mul', edge_strength_mul))
        depth_strength_mul = float(pv('depth_strength_mul', depth_strength_mul))
        edge_width = float(pv('edge_width', edge_width))
        edge_smooth = float(pv('edge_smooth', edge_smooth))
        edge_single_line = bool(pv('edge_single_line', edge_single_line))
        edge_single_strength = float(pv('edge_single_strength', edge_single_strength))
        edge_depth_gate = bool(pv('edge_depth_gate', edge_depth_gate))
        edge_depth_gamma = float(pv('edge_depth_gamma', edge_depth_gamma))

        dev = image.device
        dtype = image.dtype
        B, H, W, _ = image.shape
        # Build depth/edges
        depth = None
        edges = None
        if enable_depth:
            # Empty path falls back to the bundled default weights location.
            model_path = depth_model_path or os.path.join(os.path.dirname(__file__), '..','depth-anything','depth_anything_v2_vitl.pth')
            depth = _build_depth_map(image, int(depth_resolution), model_path, bool(hires_mask_auto))
        if enable_pyra:
            edges = _pyracanny(image,
                               int(pyra_low), int(pyra_high), int(pyra_resolution),
                               int(edge_thin_iter), float(edge_boost),
                               bool(smart_tune), float(smart_boost), bool(hires_mask_auto))
        if depth is None and edges is None:
            # Nothing to do: return inputs and zero preview
            prev = torch.zeros((B, max(H,1), max(W,1), 3), device=dev, dtype=dtype)
            return positive, negative, prev

        # A disabled source contributes an all-zero map of matching shape.
        if depth is None:
            depth = torch.zeros_like(edges)
        if edges is None:
            edges = torch.zeros_like(depth)

        # Edge post-process: width/single-line/smooth
        def _edges_post(acc_t: torch.Tensor) -> torch.Tensor:
            """Post-process the edge map with OpenCV; fall back to a torch blur.

            Any failure in the OpenCV path (including a missing cv2/numpy)
            drops to the torch fallback, which only applies smoothing.
            """
            try:
                import cv2
                import numpy as _np
                acc = acc_t.detach().to('cpu').numpy()
                img = (acc*255.0).astype(_np.uint8)
                k = _np.ones((3,3), _np.uint8)
                # Adjust thickness: whole steps are full dilate/erode passes,
                # the fractional remainder is a linear mix with one extra pass.
                w = float(edge_width)
                if abs(w) > 1e-6:
                    it = int(abs(w))
                    frac = abs(w) - it
                    op = cv2.dilate if w > 0 else cv2.erode
                    y = img.copy()
                    for _ in range(max(0, it)):
                        y = op(y, k, iterations=1)
                    if frac > 1e-6:
                        y2 = op(y, k, iterations=1)
                        y = ((1.0-frac)*y.astype(_np.float32) + frac*y2.astype(_np.float32)).astype(_np.uint8)
                    img = y
                # Collapse double lines to single centerline
                if bool(edge_single_line) and float(edge_single_strength) > 1e-6:
                    try:
                        s = float(edge_single_strength)
                        close = cv2.morphologyEx(img, cv2.MORPH_CLOSE, k, iterations=1)
                        if hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'thinning'):
                            sk = cv2.ximgproc.thinning(close)
                        else:
                            # limited-iteration morphological skeletonization
                            iters = max(1, int(round(2 + 6*s)))
                            sk = _np.zeros_like(close)
                            src = close.copy()
                            elem = cv2.getStructuringElement(cv2.MORPH_CROSS, (3,3))
                            for _ in range(iters):
                                er = cv2.erode(src, elem, iterations=1)
                                opened = cv2.morphologyEx(er, cv2.MORPH_OPEN, elem)
                                tmp = cv2.subtract(er, opened)
                                sk = cv2.bitwise_or(sk, tmp)
                                src = er
                                if not _np.any(src):
                                    break
                        # Blend skeleton back with original according to strength
                        img = ((_np.float32(1.0 - s) * img.astype(_np.float32)) + (_np.float32(s) * sk.astype(_np.float32))).astype(_np.uint8)
                    except Exception:
                        # Best-effort: keep the un-collapsed edges on failure.
                        pass
                # Smooth
                if float(edge_smooth) > 1e-6:
                    sigma = max(0.1, min(2.0, float(edge_smooth) * 1.2))
                    img = cv2.GaussianBlur(img, (0,0), sigmaX=sigma)
                out = torch.from_numpy((img.astype(_np.float32)/255.0)).to(device=acc_t.device, dtype=acc_t.dtype)
                return out.clamp(0,1)
            except Exception:
                # Torch fallback: light blur-only
                if float(edge_smooth) > 1e-6:
                    s = max(1, int(round(float(edge_smooth)*2)))
                    return F.avg_pool2d(acc_t.unsqueeze(0).unsqueeze(0), kernel_size=2*s+1, stride=1, padding=s)[0,0].clamp(0,1)
                return acc_t

        edges = _edges_post(edges)

        # Depth gating of edges
        if bool(edge_depth_gate):
            # Inverted gating per feedback: use depth^gamma (nearer = stronger if depth is larger)
            g = (depth.clamp(0,1)) ** float(edge_depth_gamma)
            edges = (edges * g).clamp(0,1)

        # Apply edge alpha before blending
        edges = (edges * float(edge_alpha)).clamp(0,1)

        fused = _blend(depth, edges, str(blend_mode), float(blend_factor))
        # Apply as split (Depth then Edges) or single fused hint.
        # The masks are used as 2-D maps below; unsqueeze(-1) + repeat with four
        # sizes yields a 3-channel hint with a leading batch dim of 1.
        if bool(split_apply):
            # Fixed order for determinism: Depth first, then Edges
            hint_edges = edges.unsqueeze(-1).repeat(1,1,1,3)
            hint_depth = depth.unsqueeze(-1).repeat(1,1,1,3)
            # Depth first: honour the user's stack_prev_control so a previously
            # attached ControlNet is only chained when explicitly requested.
            pos_mid, neg_mid = _apply_controlnet_separate(
                positive, negative, control_net, hint_depth,
                float(strength_pos) * float(depth_strength_mul),
                float(strength_neg) * float(depth_strength_mul),
                float(depth_start_percent), float(depth_end_percent), vae,
                bool(apply_to_uncond), bool(stack_prev_control)
            )
            # Then edges: always stacked, so the depth net just attached is kept.
            pos_out, neg_out = _apply_controlnet_separate(
                pos_mid, neg_mid, control_net, hint_edges,
                float(strength_pos) * float(edge_strength_mul),
                float(strength_neg) * float(edge_strength_mul),
                float(edge_start_percent), float(edge_end_percent), vae,
                bool(apply_to_uncond), True
            )
        else:
            hint = fused.unsqueeze(-1).repeat(1,1,1,3)
            pos_out, neg_out = _apply_controlnet_separate(
                positive, negative, control_net, hint,
                float(strength_pos), float(strength_neg),
                float(start_percent), float(end_percent), vae,
                bool(apply_to_uncond), bool(stack_prev_control)
            )
        # Build preview: keep aspect ratio, set minimal side
        prev_res = int(max(256, min(2048, preview_res)))
        # Guard against a degenerate zero-sized image side.
        scale = prev_res / float(max(1, min(H, W)))
        out_h = max(1, int(round(H * scale)))
        out_w = max(1, int(round(W * scale)))
        prev = F.interpolate(fused.unsqueeze(0).unsqueeze(0), size=(out_h, out_w), mode='bilinear', align_corners=False)[0,0]
        # Optionally reflect ControlNet strength in preview (display only)
        if bool(preview_show_strength):
            br = str(preview_strength_branch)
            sp = float(strength_pos)
            sn = float(strength_neg)
            if br == 'negative':
                s_vis = sn
            elif br == 'max':
                s_vis = max(sp, sn)
            elif br == 'avg':
                s_vis = 0.5 * (sp + sn)
            else:
                s_vis = sp
            # clamp for display range
            s_vis = max(0.0, min(1.0, s_vis))
            prev = prev * s_vis
        # Apply visualization brightness only for preview
        prev = (prev * float(mask_brightness)).clamp(0.0, 1.0)
        prev = prev.unsqueeze(-1).repeat(1,1,3).to(device=dev, dtype=dtype).unsqueeze(0)
        # Best-effort cleanup of caches to avoid sticky VRAM/RAM after node finishes
        try:
            clear_gpu_and_ram_cache()
        except Exception:
            pass
        return (pos_out, neg_out, prev)


# === Easy UI wrapper: simplified controls + Step/Custom preset logic ===
class MG_ControlFusionEasyUI(MG_ControlFusion):
    """Reduced-UI variant of MG_ControlFusion driven by Step presets.

    Only four controls are exposed. With ``custom`` off, the Step preset
    decides every value it defines. With ``custom`` on, the four visible
    controls keep the UI values while all other parameters still come from
    the Step preset.
    """

    # Visible controls that win over the Step preset when `custom` is enabled.
    _CUSTOM_OVERRIDE_FIELDS = ("enable_depth", "enable_pyra", "edge_alpha", "blend_factor")

    @classmethod
    def INPUT_TYPES(cls):
        """Return the reduced input schema (all inputs are required)."""
        return {
            "required": {
                # Step preset first for emphasis
                "preset_step": (["Step 2", "Step 3", "Step 4"], {"default": "Step 2", "tooltip": "Choose the Step preset. Toggle Custom below to apply UI values; otherwise Step preset values are used."}),
                # Custom toggle: when enabled, UI values override the Step for visible controls
                "custom": ("BOOLEAN", {"default": False, "tooltip": "Custom override: when enabled, your UI values override the selected Step for visible controls; hidden parameters still come from the Step preset."}),
                # Connectors
                "image": ("IMAGE", {"tooltip": "Input RGB image (B,H,W,3) in 0..1."}),
                "positive": ("CONDITIONING", {"tooltip": "Positive conditioning to apply ControlNet to."}),
                "negative": ("CONDITIONING", {"tooltip": "Negative conditioning to apply ControlNet to."}),
                "control_net": ("CONTROL_NET", {"tooltip": "ControlNet module receiving the fused mask as hint."}),
                "vae": ("VAE", {"tooltip": "VAE used by ControlNet when encoding the hint."}),
                # Minimal surface controls
                "enable_depth": ("BOOLEAN", {"default": True, "tooltip": "Enable depth map fusion (Depth Anything v2 if available)."}),
                "enable_pyra": ("BOOLEAN", {"default": True, "tooltip": "Enable PyraCanny edge detector."}),
                "edge_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity for edges before blending (0..1)."}),
                "blend_factor": ("FLOAT", {"default": 0.02, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Blend strength for edges into depth (depends on mode)."}),
            },
            "optional": {}
        }

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","IMAGE")
    RETURN_NAMES = ("positive","negative","Mask_Preview")
    FUNCTION = "apply_easy"

    def apply_easy(self, preset_step, custom, image, positive, negative, control_net, vae,
                   enable_depth=True, enable_pyra=True, edge_alpha=1.0, blend_factor=0.02):
        """Run the base ``apply`` with the selected Step preset.

        The Step is always forwarded so hidden parameters come from the
        preset; ``custom`` is passed as ``custom_override`` so the visible
        controls listed in ``_CUSTOM_OVERRIDE_FIELDS`` keep their UI values.

        Returns the same ``(positive, negative, preview)`` tuple as ``apply``.
        """
        return super().apply(
            image=image, positive=positive, negative=negative, control_net=control_net, vae=vae,
            enable_depth=bool(enable_depth), enable_pyra=bool(enable_pyra), edge_alpha=float(edge_alpha), blend_factor=float(blend_factor),
            preset_step=str(preset_step),
            custom_override=bool(custom),
        )