# MagicNodes / mod/hard/mg_zesmart_sampler_v1_1.py
# DZRobo — commit e69f3b7: "Add Z_image support and improve latent/channel handling"
from __future__ import annotations
import math
import torch
import torch.nn.functional as F # noqa: F401
import comfy.utils as _utils
import comfy.sample as _sample
import comfy.samplers as _samplers
from comfy.k_diffusion import sampling as _kds
import nodes # latent preview callback
def _smoothstep01(x: torch.Tensor) -> torch.Tensor:
return x * x * (3.0 - 2.0 * x)
def _align_tail(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Trim two sigma tracks to a common length, keeping the tail (end) of the longer one."""
    m = min(a.shape[0], b.shape[0])
    # Explicit start index (rather than a[-m:]) so m == 0 yields empty tensors, not full copies.
    return a[a.shape[0] - m:], b[b.shape[0] - m:]


def _tail_blend_weight(n: int, mix: float, device=None, dtype=None) -> torch.Tensor:
    """Return n smoothstep weights in [0, 1] ramping toward 1 over the last `mix` fraction.

    `mix` is clamped to [0, 1]. mix=0 gives all zeros (pure karras); mix=1 ramps across the
    whole track and reaches 1 only at the final point.
    """
    t = torch.linspace(0.0, 1.0, n, device=device, dtype=dtype)
    m = float(max(0.0, min(1.0, mix)))
    eps = 1e-6 if m < 1e-6 else m  # avoid division by zero when mix == 0
    return _smoothstep01(torch.clamp((t - (1.0 - m)) / eps, 0.0, 1.0))


def _mode_track(mode: str, sig_k: torch.Tensor, sig_b: torch.Tensor, mix: float) -> torch.Tensor:
    """Pick the karras or beta track, or (for 'hybrid') blend karras toward beta at the tail."""
    if mode == "karras":
        return sig_k
    if mode == "beta":
        return sig_b
    w = _tail_blend_weight(sig_k.shape[0], mix, sig_k.device, sig_k.dtype)
    return sig_k * (1.0 - w) + sig_b * w


def _auto_tail_blend(sig: torch.Tensor, sig_k_base: torch.Tensor, sig_b_base: torch.Tensor,
                     mode: str, mix: float, strength: float) -> torch.Tensor:
    """Re-blend the karras/beta bases, adding beta weight on the last ~30% where sigma jumps.

    The added weight is the min-max normalized relative step size |s[i+1]-s[i]| / s[i],
    faded in by a smoothstep ramp over t in [0.7, 1], and scaled by `strength`.
    """
    n = sig.shape[0]
    dev, dt = sig.device, sig.dtype
    if mode == "hybrid":
        w_mode = _tail_blend_weight(n, mix, dev, dt)
    elif mode == "beta":
        w_mode = torch.ones(n, device=dev, dtype=dt)
    else:
        w_mode = torch.zeros(n, device=dev, dtype=dt)

    rel = (sig[1:] - sig[:-1]).abs() / sig[:-1].abs().clamp_min(1e-8)
    rel = torch.cat([rel, rel[-1:]], dim=0)  # repeat last so it has n entries
    rel = (rel - rel.min()) / (rel.max() - rel.min() + 1e-8)
    t = torch.linspace(0.0, 1.0, n, device=dev, dtype=dt)
    ramp = _smoothstep01(torch.clamp((t - 0.7) / 0.3, 0.0, 1.0))
    w_auto = rel * ramp
    g = float(max(0.0, min(1.0, strength)))
    # Soft OR: 1 - (1 - w_mode)(1 - g*w_auto); never below the mode's own beta weight.
    u = w_mode + g * w_auto - w_mode * g * w_auto
    return sig_k_base * (1.0 - u) + sig_b_base * u


def _smooth_tail(sig: torch.Tensor, strength: float) -> torch.Tensor:
    """Pull each sigma toward its (already smoothed) successor, more strongly near the tail.

    Weights grow as t^2 * strength with the per-step pull capped at 0.5; since each value
    moves at most halfway toward a smaller successor, a descending schedule stays descending.
    The final sigma is never modified.
    """
    ts = float(max(0.0, min(1.0, strength)))
    if ts <= 0.0 or sig.numel() <= 2:
        return sig
    s = sig.clone()
    n = int(s.shape[0])
    t = torch.linspace(0.0, 1.0, n, device=s.device, dtype=s.dtype)
    w = (t.pow(2) * ts).clamp(0.0, 1.0)
    alphas = [min(0.5, 0.5 * v) for v in w.tolist()]  # hoisted out of the loop
    # Walk backwards so s[i + 1] is already smoothed when s[i] is updated.
    for i in range(n - 2, -1, -1):
        a = alphas[i]
        s[i] = (1.0 - a) * s[i] + a * s[i + 1]
    return s


def _print_sigma_summary(sig: torch.Tensor) -> None:
    """Print length, strict-descent check, first/last and head/tail values of a schedule."""
    vals = [float(v) for v in sig.tolist()]
    if not vals:
        print("[ZeSmart][dbg] sigmas len=0 (empty schedule, nothing to sample)")
        return
    desc_ok = all(a > b for a, b in zip(vals, vals[1:]))
    head = ", ".join(f"{v:.4g}" for v in vals[:3])
    tail = ", ".join(f"{v:.4g}" for v in vals[-3:])
    print(f"[ZeSmart][dbg] sigmas len={len(vals)} desc={desc_ok} first={vals[0]:.6g} last={vals[-1]:.6g}")
    print(f"[ZeSmart][dbg] head: [{head}] tail: [{tail}]")


def _build_hybrid_sigmas(model, steps: int, base_sampler: str, mode: str,
                         mix: float, denoise: float, jitter: float, seed: int,
                         _debug: bool = False, tail_smooth: float = 0.0,
                         auto_hybrid_tail: bool = True, auto_tail_strength: float = 0.35):
    """Build the sigma schedule used by MG_ZeSmartSampler.

    Pipeline:
    1. Build the karras/beta/hybrid base track.
    2. Apply ComfyUI-style denoise (keep only the tail).
    3. Optionally blend extra beta into the tail (auto-hybrid tail).
    4. Optionally add seeded jitter and re-sort.
    5. Force the final sigma to exactly 0.
    6. Optionally smooth the tail.
    7. Drop the penultimate sigma for samplers in
       ``KSampler.DISCARD_PENULTIMATE_SIGMA_SAMPLERS``.

    Args:
        model: object exposing ``get_model_object("model_sampling")`` and ``load_device``.
        steps: number of sampling steps (>= 1).
        base_sampler: sampler name, checked against the discard-penultimate set.
        mode: 'karras' | 'beta' | 'hybrid'; any other value is treated as 'hybrid'.
        mix: hybrid blend width in [0, 1]. It is the fraction of the track, measured from
            its end, that ramps toward beta.
        denoise: None, or >= 0.999999, uses the full schedule. A value in
            (0, 0.999999) keeps the last steps+1 sigmas of an int(steps / denoise) track.
            A value <= 0 returns an empty schedule (nothing to sample).
        jitter: sigma jitter amount, clamped to [0, 0.1].
        seed: seeds the jitter noise.
        _debug: print a short schedule summary.
        tail_smooth: tail smoothing strength, clamped to [0, 1].
        auto_hybrid_tail: blend extra beta into the tail where relative sigma steps jump.
        auto_tail_strength: strength of that blend, clamped to [0, 1].

    Returns:
        1D tensor on ``model.load_device`` whose last value is 0. It is normally steps+1
        long. It can be shorter when the karras and beta tracks differ in length, and it is
        empty when denoise <= 0. Only the jitter path sorts explicitly; use ``_debug`` to
        check strict descent.

    Raises:
        ValueError: if steps < 1.
    """
    steps = int(steps)
    if steps < 1:
        raise ValueError(f"[ZeSmart] steps must be >= 1, got {steps}")
    mode = str(mode).lower()
    if mode not in ("karras", "beta"):
        # Treat every other value as 'hybrid' consistently (base track AND auto tail).
        mode = "hybrid"

    # ComfyUI semantics: denoise <= 0 means "nothing to do" -> empty schedule
    # (the sampler returns the latent unchanged) instead of a full denoise.
    if denoise is not None and float(denoise) <= 0.0:
        empty = torch.zeros(0, device=model.load_device)
        if _debug:
            _print_sigma_summary(empty)
        return empty

    ms = model.get_model_object("model_sampling")

    # Like ComfyUI's KSampler.calculate_sigmas: for samplers that discard the penultimate
    # sigma, build one extra step first so the sampler still runs `steps` steps
    # (without this, steps=1 collapsed to a schedule of just [0]).
    discard_penultimate = base_sampler in _samplers.KSampler.DISCARD_PENULTIMATE_SIGMA_SAMPLERS
    if discard_penultimate:
        steps += 1

    # --- base tracks; for partial denoise build a longer "full" track and keep its tail ---
    partial = denoise is not None and float(denoise) < 0.999999
    total_steps = max(1, int(steps / max(1e-6, float(denoise)))) if partial else steps
    sk, sb = _align_tail(_samplers.calculate_sigmas(ms, "karras", total_steps),
                         _samplers.calculate_sigmas(ms, "beta", total_steps))
    need = steps + 1
    # [-need:] is a no-op when a track is already <= need long (full schedule or short track).
    sig = _mode_track(mode, sk, sb, mix)[-need:]
    sig_k_base = sk[-need:]
    sig_b_base = sb[-need:]

    # --- auto-hybrid tail: blend beta into the tail when the steps become brittle ---
    if bool(auto_hybrid_tail) and sig.numel() > 2:
        sig = _auto_tail_blend(sig, sig_k_base, sig_b_base, mode, mix, auto_tail_strength)

    # --- tiny, seeded schedule jitter (amplitude relative to the sigma range) ---
    j = float(max(0.0, min(0.1, float(jitter))))
    if j > 0.0 and sig.numel() > 1:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) ^ 0x5EEDCAFE)
        noise = torch.randn(sig.shape, generator=gen, device="cpu").to(sig.device, sig.dtype)
        amp = j * float(sig[0].item() - sig[-1].item()) * 1e-3
        sig, _ = torch.sort(sig + noise * amp, descending=True)

    # --- hard guarantee of ending with exact zero ---
    if sig[-1].abs() > 1e-12:
        sig = torch.cat([sig[:-1], sig.new_zeros(1)], dim=0)

    # --- gentle smoothing of the sigma tail ---
    sig = _smooth_tail(sig, tail_smooth)

    # --- only now drop the penultimate sigma (kept late so denoise slicing stays exact) ---
    # Guard > 2 so we never reduce the schedule to a lone [0].
    if discard_penultimate and sig.numel() > 2:
        sig = torch.cat([sig[:-2], sig[-1:]], dim=0)

    sig = sig.to(model.load_device)

    if _debug:
        try:
            _print_sigma_summary(sig)
        except Exception as exc:  # debug output must never break sampling
            print(f"[ZeSmart][dbg] sigma summary failed: {exc!r}")
    return sig
class MG_ZeSmartSampler:
    """ComfyUI node: run a built-in sampler over a custom karras/beta/hybrid sigma schedule.

    Every tweak (hybrid blend, auto beta tail, jitter, tail smoothing) lives in the schedule
    returned by ``_build_hybrid_sigmas``. Sampling itself goes through
    ``comfy.sample.sample_custom`` with the selected sampler object.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 2**63-1, "control_after_generate": True}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 4096}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "base_sampler": (_samplers.KSampler.SAMPLERS, {"default": "dpmpp_2m"}),
                "schedule": (["karras", "beta", "hybrid"], {"default": "hybrid", "tooltip": "Sigma curve: karras — soft start; beta — stable tail; hybrid — their mix."}),
                "positive": ("CONDITIONING", {}),
                "negative": ("CONDITIONING", {}),
                "latent": ("LATENT", {}),
            },
            "optional": {
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Path shortening: 1.0 = full; <1.0 = take the last steps only. Useful for inpaint/mixing."}),
                "hybrid_mix": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For schedule=hybrid: tail fraction blended toward beta (0=karras, 1=beta)."}),
                "jitter_sigma": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 0.1, "step": 0.001, "tooltip": "Tiny sigma jitter to kill moiré/banding on backgrounds. 0–0.02 is usually enough."}),
                "tail_smooth": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Smooth the sigma tail — reduces wobble/banding. Too high may soften details."}),
                "auto_hybrid_tail": ("BOOLEAN", {"default": True, "tooltip": "Auto‑blend beta on the tail when steps become brittle."}),
                "auto_tail_strength": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Strength of auto beta‑mix on the tail (0=off, 1=max)."}),
                "debug_probe": ("BOOLEAN", {"default": False, "tooltip": "Print sigma summary (length, first/last, head/tail)."}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("LATENT",)
    FUNCTION = "apply"
    CATEGORY = "MagicNodes/Experimental"

    # NOTE: the optional-argument defaults below must match the "default" values in
    # INPUT_TYPES. They apply when a prompt omits the input, and should then behave like the UI.
    def apply(self, model, seed, steps, cfg, base_sampler, schedule,
              positive, negative, latent, denoise=1.0, hybrid_mix=0.5,
              jitter_sigma=0.01, tail_smooth=0.15,
              auto_hybrid_tail=True, auto_tail_strength=0.4,
              debug_probe=False):
        """Sample `latent` with the custom schedule.

        Returns:
            A 1-tuple holding a shallow copy of `latent` whose "samples" entry is replaced
            by the sampled tensor. All other keys are carried over unchanged.
        """
        # Prepare latent + noise (batch_index keeps per-item noise consistent with ComfyUI).
        lat_img = _sample.fix_empty_latent_channels(model, latent["samples"])
        noise = _sample.prepare_noise(lat_img, seed, latent.get("batch_index", None))
        noise_mask = latent.get("noise_mask", None)

        sigmas = _build_hybrid_sigmas(model, int(steps), str(base_sampler), str(schedule),
                                      float(hybrid_mix), float(denoise), float(jitter_sigma), int(seed),
                                      _debug=bool(debug_probe), tail_smooth=float(tail_smooth),
                                      auto_hybrid_tail=bool(auto_hybrid_tail),
                                      auto_tail_strength=float(auto_tail_strength))

        # Native sampler; all tweaks happen in the sigma schedule only.
        sampler_obj = _samplers.sampler_object(str(base_sampler))
        # The schedule can have a different length than `steps` (track alignment, discarded
        # penultimate sigma), so size the preview/progress from the schedule actually built.
        preview_steps = max(1, int(sigmas.shape[-1]) - 1)
        callback = nodes.latent_preview.prepare_callback(model, preview_steps)
        disable_pbar = not _utils.PROGRESS_BAR_ENABLED
        samples = _sample.sample_custom(model, noise, float(cfg), sampler_obj, sigmas,
                                        positive, negative, lat_img,
                                        noise_mask=noise_mask, callback=callback,
                                        disable_pbar=disable_pbar, seed=seed)
        out = {**latent}
        out["samples"] = samples
        return (out,)