DZRobo commited on
Commit
7dd5fc1
·
1 Parent(s): df38cb0

Add configurable CFG scheduling to guidance wrapper

Browse files

Introduces per-step CFG scheduling options (cosine, warmup, U-shape) to _wrap_model_with_guidance and exposes related parameters in ComfyAdaptiveDetailEnhancer25. Updates mg_cade25.cfg presets to utilize new scheduling controls and adjusts detail and denoise settings for improved flexibility.

mod/easy/mg_cade25_easy.py CHANGED
@@ -1320,7 +1320,9 @@ def _fdg_split_three(delta: torch.Tensor,
1320
  def _wrap_model_with_guidance(model, guidance_mode: str, rescale_multiplier: float, momentum_beta: float, cfg_curve: float, perp_damp: float, use_zero_init: bool=False, zero_init_steps: int=0, fdg_low: float = 0.6, fdg_high: float = 1.3, fdg_sigma: float = 1.0, ze_zero_steps: int = 0, ze_adaptive: bool = False, ze_r_switch_hi: float = 0.6, ze_r_switch_lo: float = 0.45, fdg_low_adaptive: bool = False, fdg_low_min: float = 0.45, fdg_low_max: float = 0.7, fdg_ema_beta: float = 0.8, use_local_mask: bool = False, mask_inside: float = 1.0, mask_outside: float = 1.0,
1321
  midfreq_enable: bool = False, midfreq_gain: float = 0.0, midfreq_sigma_lo: float = 0.8, midfreq_sigma_hi: float = 2.0,
1322
  mahiro_plus_enable: bool = False, mahiro_plus_strength: float = 0.5,
1323
- eps_scale_enable: bool = False, eps_scale: float = 0.0):
 
 
1324
 
1325
  """Clone model and attach a cfg mixing function implementing RescaleCFG/FDG, CFGZero*/FD, or hybrid ZeResFDG.
1326
  guidance_mode: 'default' | 'RescaleCFG' | 'RescaleFDG' | 'CFGZero*' | 'CFGZeroFD' | 'ZeResFDG'
@@ -1495,6 +1497,7 @@ def _wrap_model_with_guidance(model, guidance_mode: str, rescale_multiplier: flo
1495
  cond = uncond + delta
1496
 
1497
  cond_scale_eff = cond_scale
 
1498
  if cfg_curve > 0.0 and (sigma is not None):
1499
  s = sigma
1500
  if s.ndim > 1:
@@ -1513,10 +1516,56 @@ def _wrap_model_with_guidance(model, guidance_mode: str, rescale_multiplier: flo
1513
  t = t.clamp(0.0, 1.0)
1514
  k = 6.0 * float(cfg_curve)
1515
  s_curve = torch.tanh((t - 0.5) * k)
1516
- gain = 1.0 + 0.15 * float(cfg_curve) * s_curve
1517
- if gain.ndim > 0:
1518
- gain = gain.mean().item()
1519
- cond_scale_eff = cond_scale * float(gain)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1520
 
1521
  # Epsilon scaling (exposure bias correction): early steps get multiplier closer to (1 + eps_scale)
1522
  eps_mult = 1.0
@@ -2145,6 +2194,12 @@ class ComfyAdaptiveDetailEnhancer25:
2145
  clipseg_blend = str(pv("clipseg_blend", clipseg_blend))
2146
  clipseg_ref_gate = bool(pv("clipseg_ref_gate", clipseg_ref_gate))
2147
  clipseg_ref_threshold = float(pv("clipseg_ref_threshold", clipseg_ref_threshold))
 
 
 
 
 
 
2148
  # Latent buffer (internal-only; configured via presets)
2149
  latent_buffer = bool(pv("latent_buffer", True))
2150
  lb_inject = float(pv("lb_inject", 0.25))
@@ -2386,7 +2441,9 @@ class ComfyAdaptiveDetailEnhancer25:
2386
  fdg_low_adaptive=bool(fdg_low_adaptive), fdg_low_min=float(fdg_low_min), fdg_low_max=float(fdg_low_max), fdg_ema_beta=float(fdg_ema_beta),
2387
  use_local_mask=bool(onnx_local_guidance), mask_inside=float(onnx_mask_inside), mask_outside=float(onnx_mask_outside),
2388
  mahiro_plus_enable=bool(muse_blend), mahiro_plus_strength=float(muse_blend_strength),
2389
- eps_scale_enable=bool(eps_scale_enable), eps_scale=float(eps_scale)
 
 
2390
  )
2391
  # check once more right before the loop starts
2392
  model_management.throw_exception_if_processing_interrupted()
 
1320
  def _wrap_model_with_guidance(model, guidance_mode: str, rescale_multiplier: float, momentum_beta: float, cfg_curve: float, perp_damp: float, use_zero_init: bool=False, zero_init_steps: int=0, fdg_low: float = 0.6, fdg_high: float = 1.3, fdg_sigma: float = 1.0, ze_zero_steps: int = 0, ze_adaptive: bool = False, ze_r_switch_hi: float = 0.6, ze_r_switch_lo: float = 0.45, fdg_low_adaptive: bool = False, fdg_low_min: float = 0.45, fdg_low_max: float = 0.7, fdg_ema_beta: float = 0.8, use_local_mask: bool = False, mask_inside: float = 1.0, mask_outside: float = 1.0,
1321
  midfreq_enable: bool = False, midfreq_gain: float = 0.0, midfreq_sigma_lo: float = 0.8, midfreq_sigma_hi: float = 2.0,
1322
  mahiro_plus_enable: bool = False, mahiro_plus_strength: float = 0.5,
1323
+ eps_scale_enable: bool = False, eps_scale: float = 0.0,
1324
+ cfg_sched_type: str = "off", cfg_sched_min: float = 0.0, cfg_sched_max: float = 0.0,
1325
+ cfg_sched_gamma: float = 1.5, cfg_sched_u_pow: float = 1.0):
1326
 
1327
  """Clone model and attach a cfg mixing function implementing RescaleCFG/FDG, CFGZero*/FD, or hybrid ZeResFDG.
1328
  guidance_mode: 'default' | 'RescaleCFG' | 'RescaleFDG' | 'CFGZero*' | 'CFGZeroFD' | 'ZeResFDG'
 
1497
  cond = uncond + delta
1498
 
1499
  cond_scale_eff = cond_scale
1500
+ curve_gain = 1.0
1501
  if cfg_curve > 0.0 and (sigma is not None):
1502
  s = sigma
1503
  if s.ndim > 1:
 
1516
  t = t.clamp(0.0, 1.0)
1517
  k = 6.0 * float(cfg_curve)
1518
  s_curve = torch.tanh((t - 0.5) * k)
1519
+ g = 1.0 + 0.15 * float(cfg_curve) * s_curve
1520
+ if g.ndim > 0:
1521
+ g = g.mean().item()
1522
+ curve_gain = float(g)
1523
+ cond_scale_eff = cond_scale * curve_gain
1524
+
1525
+ # Per-step CFG schedule (cosine/warmup/U) using normalized sigma progress
1526
+ if isinstance(cfg_sched_type, str) and cfg_sched_type.lower() != "off" and (sigma is not None):
1527
+ try:
1528
+ s = sigma
1529
+ if s.ndim > 1:
1530
+ s = s.flatten()
1531
+ s_max = float(torch.max(s).item())
1532
+ s_min = float(torch.min(s).item())
1533
+ if sigma_seen["max"] is None:
1534
+ sigma_seen["max"] = s_max
1535
+ sigma_seen["min"] = s_min
1536
+ else:
1537
+ sigma_seen["max"] = max(sigma_seen["max"], s_max)
1538
+ sigma_seen["min"] = min(sigma_seen["min"], s_min)
1539
+ lo = max(1e-6, sigma_seen["min"])
1540
+ hi = max(lo * (1.0 + 1e-6), sigma_seen["max"])
1541
+ t = (torch.log(s + 1e-6) - torch.log(torch.tensor(lo, device=sigma.device))) / (torch.log(torch.tensor(hi, device=sigma.device)) - torch.log(torch.tensor(lo, device=sigma.device)) + 1e-6)
1542
+ t = t.clamp(0.0, 1.0)
1543
+ if t.ndim > 0:
1544
+ t_val = float(t.mean().item())
1545
+ else:
1546
+ t_val = float(t.item())
1547
+ cmin = float(max(0.0, cfg_sched_min))
1548
+ cmax = float(max(cmin, cfg_sched_max))
1549
+ tp = cfg_sched_type.lower()
1550
+ if tp == "cosine":
1551
+ import math
1552
+ cfg_val = cmax - (cmax - cmin) * 0.5 * (1.0 + math.cos(math.pi * t_val))
1553
+ elif tp in ("warmup", "warm-up", "linear"):
1554
+ g = float(max(0.0, min(1.0, t_val))) ** float(max(0.1, cfg_sched_gamma))
1555
+ cfg_val = cmin + (cmax - cmin) * g
1556
+ elif tp in ("u", "u-shape", "ushape"):
1557
+ # edges high, middle low; power to control concavity
1558
+ e = 4.0 * (t_val - 0.5) * (t_val - 0.5)
1559
+ e = float(min(1.0, max(0.0, e)))
1560
+ e = e ** float(max(0.1, cfg_sched_u_pow))
1561
+ cfg_val = cmin + (cmax - cmin) * e
1562
+ else:
1563
+ cfg_val = cond_scale_eff
1564
+ # Keep curve shaping as a multiplier on top of scheduled absolute value
1565
+ shape = (cond_scale_eff / float(cond_scale)) if float(cond_scale) != 0.0 else 1.0
1566
+ cond_scale_eff = float(cfg_val) * float(shape)
1567
+ except Exception:
1568
+ pass
1569
 
1570
  # Epsilon scaling (exposure bias correction): early steps get multiplier closer to (1 + eps_scale)
1571
  eps_mult = 1.0
 
2194
  clipseg_blend = str(pv("clipseg_blend", clipseg_blend))
2195
  clipseg_ref_gate = bool(pv("clipseg_ref_gate", clipseg_ref_gate))
2196
  clipseg_ref_threshold = float(pv("clipseg_ref_threshold", clipseg_ref_threshold))
2197
+ # CFG scheduling (internal-only; configured via presets)
2198
+ cfg_sched = str(pv("cfg_sched", "off"))
2199
+ cfg_sched_min = float(pv("cfg_sched_min", max(0.0, cfg * 0.5)))
2200
+ cfg_sched_max = float(pv("cfg_sched_max", cfg))
2201
+ cfg_sched_gamma = float(pv("cfg_sched_gamma", 1.5))
2202
+ cfg_sched_u_pow = float(pv("cfg_sched_u_pow", 1.0))
2203
  # Latent buffer (internal-only; configured via presets)
2204
  latent_buffer = bool(pv("latent_buffer", True))
2205
  lb_inject = float(pv("lb_inject", 0.25))
 
2441
  fdg_low_adaptive=bool(fdg_low_adaptive), fdg_low_min=float(fdg_low_min), fdg_low_max=float(fdg_low_max), fdg_ema_beta=float(fdg_ema_beta),
2442
  use_local_mask=bool(onnx_local_guidance), mask_inside=float(onnx_mask_inside), mask_outside=float(onnx_mask_outside),
2443
  mahiro_plus_enable=bool(muse_blend), mahiro_plus_strength=float(muse_blend_strength),
2444
+ eps_scale_enable=bool(eps_scale_enable), eps_scale=float(eps_scale),
2445
+ cfg_sched_type=str(cfg_sched), cfg_sched_min=float(cfg_sched_min), cfg_sched_max=float(cfg_sched_max),
2446
+ cfg_sched_gamma=float(cfg_sched_gamma), cfg_sched_u_pow=float(cfg_sched_u_pow)
2447
  )
2448
  # check once more right before the loop starts
2449
  model_management.throw_exception_if_processing_interrupted()
pressets/mg_cade25.cfg CHANGED
@@ -49,6 +49,14 @@ ref_preview: 512
49
  ref_threshold: 0.02
50
  ref_cooldown: 2
51
 
 
 
 
 
 
 
 
 
52
 
53
  # guidance
54
  guidance_mode: ZeResFDG
@@ -174,6 +182,14 @@ ref_preview: 512
174
  ref_threshold: 0.020
175
  ref_cooldown: 2
176
 
 
 
 
 
 
 
 
 
177
 
178
  # guidance
179
  guidance_mode: ZeResFDG
@@ -257,7 +273,7 @@ seed: 0
257
  control_after_generate: randomize
258
  steps: 10
259
  cfg: 5.0
260
- denoise: 0.55
261
  sampler_name: ddim
262
  scheduler: MGHybrid
263
  iterations: 2
@@ -303,6 +319,14 @@ ref_preview: 512
303
  ref_threshold: 0.020
304
  ref_cooldown: 2
305
 
 
 
 
 
 
 
 
 
306
 
307
  # guidance
308
  guidance_mode: ZeResFDG
@@ -414,13 +438,13 @@ lb_rebase_thresh: 0.10
414
  lb_rebase_rate: 0.25
415
 
416
  # detail controls
417
- ids_strength: 0.30
418
  upscale_method: lanczos
419
  scale_by: 1.5
420
  scale_delta: 0.1
421
  noise_offset: 0.0035
422
  threshold: 1.000
423
- Sharpnes_strenght: 0.3
424
  accumulation: fp32+fp32
425
 
426
  # reference clean
@@ -429,6 +453,14 @@ ref_preview: 512
429
  ref_threshold: 0.200
430
  ref_cooldown: 2
431
 
 
 
 
 
 
 
 
 
432
 
433
  # guidance
434
  guidance_mode: ZeResFDG
@@ -449,7 +481,7 @@ zero_init_steps: 0
449
 
450
  # FDG / ZE thresholds
451
  fdg_low: 0.35
452
- fdg_high: 0.7
453
  fdg_sigma: 1.20
454
  ze_res_zero_steps: 10
455
  ze_adaptive: true
 
49
  ref_threshold: 0.02
50
  ref_cooldown: 2
51
 
52
+ # cfg schedule (internal)
53
+ #cfg_sched: off | cosine | warmup | u
54
+ cfg_sched: warmup
55
+ #cfg_sched_min: 4.0
56
+ #cfg_sched_max: 8.0
57
+ cfg_sched_gamma: 1.5
58
+ #cfg_sched_u_pow: 1.0
59
+
60
 
61
  # guidance
62
  guidance_mode: ZeResFDG
 
182
  ref_threshold: 0.020
183
  ref_cooldown: 2
184
 
185
+ # cfg schedule (internal)
186
+ #cfg_sched: off | cosine | warmup | u
187
+ cfg_sched: cosine
188
+ #cfg_sched_min: 3.0
189
+ #cfg_sched_max: 6.5
190
+ #cfg_sched_gamma: 1.5
191
+ #cfg_sched_u_pow: 1.0
192
+
193
 
194
  # guidance
195
  guidance_mode: ZeResFDG
 
273
  control_after_generate: randomize
274
  steps: 10
275
  cfg: 5.0
276
+ denoise: 0.40
277
  sampler_name: ddim
278
  scheduler: MGHybrid
279
  iterations: 2
 
319
  ref_threshold: 0.020
320
  ref_cooldown: 2
321
 
322
+ # cfg schedule (internal)
323
+ #cfg_sched: off | cosine | warmup | u
324
+ cfg_sched: warmup
325
+ cfg_sched_min: 4.5
326
+ cfg_sched_max: 5.0
327
+ cfg_sched_gamma: 1.5
328
+ cfg_sched_u_pow: 1.2
329
+
330
 
331
  # guidance
332
  guidance_mode: ZeResFDG
 
438
  lb_rebase_rate: 0.25
439
 
440
  # detail controls
441
+ ids_strength: 0.35
442
  upscale_method: lanczos
443
  scale_by: 1.5
444
  scale_delta: 0.1
445
  noise_offset: 0.0035
446
  threshold: 1.000
447
+ Sharpnes_strenght: 0.24
448
  accumulation: fp32+fp32
449
 
450
  # reference clean
 
453
  ref_threshold: 0.200
454
  ref_cooldown: 2
455
 
456
+ # cfg schedule (internal)
457
+ #cfg_sched: off | cosine | warmup | u
458
+ cfg_sched: cosine
459
+ cfg_sched_min: 3.2
460
+ cfg_sched_max: 5.6
461
+ cfg_sched_gamma: 1.5
462
+ cfg_sched_u_pow: 1.0
463
+
464
 
465
  # guidance
466
  guidance_mode: ZeResFDG
 
481
 
482
  # FDG / ZE thresholds
483
  fdg_low: 0.35
484
+ fdg_high: 1.15
485
  fdg_sigma: 1.20
486
  ze_res_zero_steps: 10
487
  ze_adaptive: true