File size: 2,499 Bytes
695fbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import comfy.utils
import torch
import gc
import logging
import comfy.model_management as model_management


def clear_gpu_and_ram_cache():
    """Run a garbage-collection pass, then flush CUDA caches when CUDA is available.

    Always calls ``gc.collect()``. If ``torch.cuda.is_available()`` reports
    True, it additionally calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.ipc_collect()``. Returns ``None``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing CUDA-related to release on this machine.
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def _smart_decode(vae, latent, tile_size=512):
    """Decode ``latent["samples"]`` with *vae*, retrying tiled on out-of-memory.

    Args:
        vae: Object providing ``decode``; for the fallback path it must also
            provide ``spacial_compression_decode`` and ``decode_tiled``.
        latent: Mapping holding the latent tensor under the ``"samples"`` key.
        tile_size: Tile edge length used by the fallback. It is divided by
            the value of ``vae.spacial_compression_decode()`` before being
            passed to ``decode_tiled``; the overlap is a quarter of it,
            divided the same way.

    Returns:
        The decoded result. A 5-D result is reshaped to 4-D by merging all
        leading dimensions into the first one; the last three are kept.
    """
    latent_samples = latent["samples"]
    try:
        decoded = vae.decode(latent_samples)
    except model_management.OOM_EXCEPTION:
        logging.warning("VAE decode OOM, using tiled decode")
        factor = vae.spacial_compression_decode()
        tile_edge = tile_size // factor
        decoded = vae.decode_tiled(
            latent_samples,
            tile_x=tile_edge,
            tile_y=tile_edge,
            overlap=(tile_size // 4) // factor,
        )

    dims = decoded.shape
    if len(dims) != 5:
        return decoded
    # Collapse the two leading axes so the caller always gets a 4-D batch.
    *_, dim_a, dim_b, dim_c = dims
    return decoded.reshape(-1, dim_a, dim_b, dim_c)


class MagicUpscaleModule:
    """Node that decodes a latent, resizes the image, and re-encodes it.

    Moved into mod/ as mg_upscale_module keeping class/key name.
    """

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    # Stride used when the VAE cannot report a usable compression factor.
    _DEFAULT_STRIDE = 8

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification (all four inputs are required)."""
        return {
            "required": {
                "samples": ("LATENT", {}),
                "vae": ("VAE", {}),
                "upscale_method": (cls.upscale_methods, {"default": "bilinear"}),
                "scale_by": ("FLOAT", {"default": 1.2, "min": 0.01, "max": 8.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT", "IMAGE")
    RETURN_NAMES = ("LATENT", "Upscaled Image")
    FUNCTION = "process_upscale"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _vae_stride(vae):
        """Return the VAE's spatial stride, or the default of 8 if unusable."""
        fallback = MagicUpscaleModule._DEFAULT_STRIDE
        try:
            stride = int(vae.spacial_compression_decode())
        except Exception:
            # Deliberate best-effort: a VAE that cannot report its stride
            # must not break the node, but leave a trace for debugging.
            logging.debug(
                "VAE stride unavailable, defaulting to %d", fallback, exc_info=True
            )
            return fallback
        return stride if stride > 0 else fallback

    @staticmethod
    def _aligned_size(width, height, scale_by, stride):
        """Scale (width, height) by *scale_by* and round each up to *stride*.

        Each result is at least one stride, so a very small ``scale_by``
        cannot produce a zero-sized target.
        """
        def align_up(value):
            # Ceil to a multiple of stride, then clamp to a non-empty size.
            return max(stride, -(-value // stride) * stride)

        return (
            align_up(int(round(width * scale_by))),
            align_up(int(round(height * scale_by))),
        )

    def process_upscale(self, samples, vae, upscale_method, scale_by):
        """Decode *samples*, resize by *scale_by*, and encode the result.

        Args:
            samples: Latent mapping; its ``"samples"`` entry is decoded.
            vae: VAE used for decoding, stride lookup, and encoding.
            upscale_method: Resampling method passed to
                ``comfy.utils.common_upscale``.
            scale_by: Multiplier applied to the decoded width and height.

        Returns:
            A tuple ``({"samples": encoded}, image)`` where ``image`` is the
            resized image and ``encoded`` is the VAE encoding of its first
            three channels.
        """
        clear_gpu_and_ram_cache()
        images = _smart_decode(vae, samples)
        # Channel axis goes from last to position 1 for the resize call.
        channels_first = images.movedim(-1, 1)

        # Align to VAE stride to avoid border artifacts/shape drift
        target_w, target_h = self._aligned_size(
            channels_first.shape[3],
            channels_first.shape[2],
            scale_by,
            self._vae_stride(vae),
        )

        up = comfy.utils.common_upscale(
            channels_first, target_w, target_h, upscale_method, "disabled"
        )
        up = up.movedim(1, -1)
        # Only the first three channels are encoded; the full image is returned.
        encoded = vae.encode(up[:, :, :, :3])
        return ({"samples": encoded}, up)