|
|
from collections.abc import Sequence
from typing import NamedTuple

import torch
import torch.nn.functional as F
|
|
|
|
|
|
|
|
def get_mean_shifted_latents(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    channels: Sequence[int] = (0, 1, 1, 0),
) -> torch.Tensor:
    """Offset selected channels until their share of positive values moves by ``shift``.

    For each channel position ``idx`` whose entry ``channels[idx]`` is non-zero,
    the whole channel is repeatedly offset by ``delta_shift * sign`` until the
    fraction of elements greater than zero has moved by ``shift * sign``:
    upwards for a positive sign, downwards for a negative sign. At least one
    step is always taken. Channels with a zero entry are left untouched, and
    the input tensor itself is never modified.

    Args:
        latents: Tensor indexed as ``latents[:, idx, :, :]``, i.e. with the
            channel axis at dim 1.
        shift: Desired change of the positive-value ratio of each selected
            channel. The resulting target ratio is clamped to ``[0, 1]`` so it
            never lies outside the range of possible ratios.
        delta_shift: Offset applied (multiplied by the sign) per iteration;
            must be positive.
        channels: Per-channel direction, indexed by channel position. ``0``
            skips the channel; the sign of a non-zero entry picks the
            direction, and its magnitude scales both ``shift`` and
            ``delta_shift``.

    Returns:
        A shifted copy of ``latents``.

    Raises:
        ValueError: If ``delta_shift`` is not positive (the loop could then
            never reach its target).
    """
    if delta_shift <= 0:
        raise ValueError(f"delta_shift must be positive, got {delta_shift}")

    shifted_latents = latents.clone()

    for idx, sign in enumerate(channels):
        if sign == 0:
            continue

        # A view into ``shifted_latents``: the in-place ``+=`` below writes
        # straight into the clone, so no explicit write-back is needed.
        latent_channel = shifted_latents[:, idx, :, :]

        positive_ratio = (latent_channel > 0).float().mean()
        # Without the clamp, a ratio already close to 1 (or 0) would ask for
        # an impossible target and the loop below would never terminate.
        target_ratio = (positive_ratio + shift * sign).clamp(0.0, 1.0)

        while True:
            latent_channel += delta_shift * sign
            new_positive_ratio = (latent_channel > 0).float().mean()
            # The stopping test must follow the shift direction: a negative
            # sign moves the ratio down towards a target below the start.
            if sign > 0:
                reached = new_positive_ratio >= target_ratio
            else:
                reached = new_positive_ratio <= target_ratio
            if reached:
                break

    return shifted_latents
|
|
|
|
|
|
|
|
def get_2d_gaussian(
    latent_height: int,
    latent_width: int,
    std_dev: float,
    device: torch.device,
    center_x: float = 0.0,
    center_y: float = 0.0,
    factor: int = 8,
) -> torch.Tensor:
    """Build an isotropic 2-D Gaussian on a coarse grid spanning ``[-1, 1]``.

    The grid has ``latent_height // factor`` rows and ``latent_width // factor``
    columns, with each axis sampled uniformly from -1 to 1 (both inclusive).
    A grid point ``(x, y)`` gets the value
    ``exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * std_dev**2))``,
    which is at most 1.0 and reaches it where a grid point hits the center.

    Args:
        latent_height: Full-resolution height; divided by ``factor`` for rows.
        latent_width: Full-resolution width; divided by ``factor`` for columns.
        std_dev: Standard deviation in grid coordinates (the grid spans 2.0
            per axis). Must be non-zero.
        device: Device on which the tensor is created.
        center_x: Horizontal center of the Gaussian in grid coordinates.
        center_y: Vertical center of the Gaussian in grid coordinates.
        factor: Downsampling factor from full resolution to the grid; >= 1.

    Returns:
        Tensor of shape ``(1, 1, latent_height // factor, latent_width // factor)``
        in torch's default floating dtype.

    Raises:
        ValueError: If ``factor`` < 1, if the grid would have no rows or
            columns, or if ``std_dev`` is zero.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")

    grid_height = latent_height // factor
    grid_width = latent_width // factor
    # An empty grid would yield an empty mask that only fails later (e.g. in
    # an interpolation step), so reject it here with a clear message.
    if grid_height < 1 or grid_width < 1:
        raise ValueError(
            f"latent size {latent_height}x{latent_width} is too small for "
            f"factor={factor}; both dimensions must be >= factor"
        )
    # A zero std_dev divides by zero and produces NaN at the center.
    if std_dev == 0:
        raise ValueError("std_dev must be non-zero")

    y = torch.linspace(-1, 1, steps=grid_height, device=device)
    x = torch.linspace(-1, 1, steps=grid_width, device=device)

    # "ij" indexing: rows follow y, columns follow x.
    y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")

    x_grid = x_grid - center_x
    y_grid = y_grid - center_y

    gauss = torch.exp(-((x_grid**2 + y_grid**2) / (2 * std_dev**2)))

    # Add leading batch and channel axes so the result is 4-D.
    return gauss[None, None, :, :]
|
|
|
|
|
|
|
|
def apply_tkg_noise(
    latents: torch.Tensor,
    shift: float = 0.11,
    delta_shift: float = 0.1,
    std_dev: float = 0.5,
    factor: int = 8,
    channels: Sequence[int] = (0, 1, 1, 0),
) -> torch.Tensor:
    """Blend mean-shifted latents with the originals through a centered Gaussian mask.

    The latents are first shifted channel-wise with ``get_mean_shifted_latents``.
    A Gaussian centered at (0, 0) is then built on a grid downsampled by
    ``factor`` and bilinearly upsampled to the latent resolution. The output is
    ``shifted * (1 - mask) + latents * mask``. Where the mask is close to 1
    (the center) the original latents dominate; where it is close to 0 (the
    borders) the shifted latents dominate.

    Args:
        latents: 4-D tensor unpacked as (batch, channels, height, width).
        shift: Passed to ``get_mean_shifted_latents``.
        delta_shift: Passed to ``get_mean_shifted_latents``.
        std_dev: Gaussian standard deviation in the mask's [-1, 1] grid
            coordinates.
        factor: Downsampling factor for the coarse Gaussian grid.
        channels: Per-channel shift directions, passed to
            ``get_mean_shifted_latents``.

    Returns:
        Tensor with the same shape and dtype as ``latents``.
    """
    # Unpacking also enforces that ``latents`` is 4-D.
    _, _, latent_height, latent_width = latents.shape

    shifted_latents = get_mean_shifted_latents(
        latents,
        shift=shift,
        delta_shift=delta_shift,
        channels=channels,
    )

    gauss_mask = get_2d_gaussian(
        latent_height=latent_height,
        latent_width=latent_width,
        std_dev=std_dev,
        center_x=0.0,
        center_y=0.0,
        factor=factor,
        device=latents.device,
    )
    gauss_mask = F.interpolate(
        gauss_mask,
        size=(latent_height, latent_width),
        mode="bilinear",
        align_corners=False,
    )

    # The (1, 1, H, W) mask broadcasts over batch and channel axes.
    noised_latents = shifted_latents * (1 - gauss_mask) + latents * gauss_mask

    # The mask is float32, so half-precision latents would otherwise be
    # silently promoted; return the caller's dtype.
    return noised_latents.to(latents.dtype)
|
|
|
|
|
|
|
|
# Immutable (name, channels) record for a named color preset.
#   name:     lookup key for the preset.
#   channels: per-channel integer directions; presumably meant for the
#             ``channels`` argument of the latent-shift helpers (TODO confirm).
ColorSet = NamedTuple(
    "ColorSet",
    [
        ("name", str),
        ("channels", list[int]),
    ],
)
|
|
|
|
|
|
|
|
|
|
|
# Built-in color presets, in declaration order. Each list has one integer per
# channel position; presumably these feed the ``channels`` argument of the
# latent-shift helpers (TODO confirm against callers).
COLOR_SETS: list[ColorSet] = [
    ColorSet(preset_name, preset_channels)
    for preset_name, preset_channels in (
        ("green", [0, 1, 1, 0]),
        ("cyan", [0, 1, 0, 0]),
        ("magenta", [0, -1, -1, -1]),
        ("purple", [0, 0, -1, -1]),
        ("black", [-1, 0, 0, 1]),
        ("orange", [-1, -1, 1, 0]),
        ("white", [0, 0, 0, -1]),
        ("yellow", [0, -1, 1, -1]),
    )
]

# Lookup table from preset name to its ColorSet.
COLOR_SET_MAP: dict[str, ColorSet] = dict(
    zip((color_set.name for color_set in COLOR_SETS), COLOR_SETS)
)
|
|
|