# Extraction metadata (not source): file size 2,499 bytes, commit 695fbf0
import comfy.utils
import torch
import gc
import logging
import comfy.model_management as model_management
def clear_gpu_and_ram_cache():
    """Free host and GPU memory: run the Python GC, then drop CUDA caches.

    Safe to call on CPU-only machines; the CUDA steps are skipped when no
    GPU is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def _smart_decode(vae, latent, tile_size=512):
try:
images = vae.decode(latent["samples"])
except model_management.OOM_EXCEPTION:
logging.warning("VAE decode OOM, using tiled decode")
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(
latent["samples"],
tile_x=tile_size // compression,
tile_y=tile_size // compression,
overlap=(tile_size // 4) // compression,
)
if len(images.shape) == 5:
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return images
class MagicUpscaleModule:
    """Decode a latent with the VAE, upscale the image, and re-encode it.

    Moved into mod/ as mg_upscale_module keeping class/key name.
    """

    # Resampling filters forwarded to comfy.utils.common_upscale.
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", {}),
                "vae": ("VAE", {}),
                "upscale_method": (cls.upscale_methods, {"default": "bilinear"}),
                "scale_by": ("FLOAT", {"default": 1.2, "min": 0.01, "max": 8.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT", "IMAGE")
    RETURN_NAMES = ("LATENT", "Upscaled Image")
    FUNCTION = "process_upscale"
    CATEGORY = "MagicNodes"

    def process_upscale(self, samples, vae, upscale_method, scale_by):
        """Decode ``samples``, resize by ``scale_by``, and re-encode.

        Returns:
            tuple: ({"samples": encoded_latent}, upscaled_image) where the
            image tensor is channels-last (NHWC, as produced by the VAE).
        """
        clear_gpu_and_ram_cache()
        images = _smart_decode(vae, samples)
        samples_t = images.movedim(-1, 1)  # NHWC -> NCHW for common_upscale
        # Fix: with tiny inputs and scale_by near the 0.01 minimum, round()
        # could yield 0, making the upscale target empty. Clamp to >= 1 so
        # stride alignment always produces a valid (>= stride) dimension.
        width = max(1, round(samples_t.shape[3] * scale_by))
        height = max(1, round(samples_t.shape[2] * scale_by))
        # Align to VAE stride to avoid border artifacts/shape drift
        try:
            stride = int(vae.spacial_compression_decode())
        except Exception:
            stride = 8  # common VAE downscale factor; fallback if API missing
        if stride <= 0:
            stride = 8

        def _align_up(x, s):
            # Round x up to the nearest multiple of s.
            return int(((x + s - 1) // s) * s)

        up = comfy.utils.common_upscale(
            samples_t,
            _align_up(width, stride),
            _align_up(height, stride),
            upscale_method,
            "disabled",
        )
        up = up.movedim(1, -1)  # back to NHWC
        # Encode RGB channels only; drops a possible alpha channel.
        encoded = vae.encode(up[:, :, :, :3])
        return ({"samples": encoded}, up)
|