| import sys |
| from argparse import ArgumentParser |
| from pathlib import Path |
| from comet_ml import Experiment |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from PIL import Image |
| from skimage.color import gray2rgb |
| from skimage.io import imread |
| from skimage.transform import resize |
| from skimage.util import img_as_ubyte |
| from tqdm import tqdm |
|
|
# Make the parent of this script's directory importable (presumably the
# repository root) so that ``import climategan`` below resolves when the
# script is run directly.
sys.path.append(str(Path(__file__).resolve().parents[1]))

import climategan

# Experiment directory whose ``eval-metrics/pred`` sub-folder holds the
# precomputed masks read by ``load_ground`` for the "ground" model.
GROUND_MODEL = "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--ground"
|
|
|
|
def uint8(array):
    """Return a copy of ``array`` cast to ``np.uint8``.

    This is a plain numpy ``astype`` cast: no rounding or clipping is applied
    beforehand, so callers should pass values already in [0, 255].
    """
    target_dtype = np.dtype("uint8")
    return array.astype(target_dtype)
|
|
|
|
def crop_and_resize(image_path, label_path, side=640):
    """
    Resizes an image and its label so that aspect ratio is kept and the
    smallest dimension is ``side``, then crops both in their center so that
    the outputs are ``side x side`` without aspect ratio distortion.

    The image is resized with anti-aliasing; the label is resized with
    nearest-neighbour interpolation (``order=0``, no anti-aliasing) so that
    label values are not blended. Both are cast to uint8.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label
        side (int, optional): output height and width. Defaults to 640.

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label)
    """
    img = imread(image_path)
    lab = imread(label_path)

    if img.shape[:2] != lab.shape[:2]:
        # Path(...) so that the message also works when str paths are given.
        print(
            "\nWARNING: shape mismatch: im -> {}, lab -> {}".format(
                Path(image_path).name, Path(label_path).name
            )
        )

    # Target size is computed from the image only; the label is resized to
    # the same size so both crops line up even if their shapes differed.
    h, w = img.shape[:2]
    if h < w:
        size = (side, int(side * w / h))
    else:
        size = (int(side * h / w), side)

    r_img = uint8(resize(img, size, preserve_range=True, anti_aliasing=True))
    r_lab = uint8(
        resize(lab, size, preserve_range=True, anti_aliasing=False, order=0)
    )

    H, W = r_img.shape[:2]
    top = (H - side) // 2
    left = (W - side) // 2

    # Slicing only the two spatial axes keeps any trailing channel axis and
    # also works for 2-D (grayscale / single-channel) arrays.
    rc_img = r_img[top : top + side, left : left + side]
    rc_lab = r_lab[top : top + side, left : left + side]

    return rc_img, rc_lab
|
|
|
|
def load_ground(ground_output_path, ref_image_path, device="cuda"):
    """
    Load the precomputed ground mask matching a reference image.

    Looks for ``<ground_output_path>/eval-metrics/pred/<stem>.jpg`` or
    ``.png``, where ``<stem>`` is the reference image's stem. The match is
    resized/cropped with ``crop_and_resize`` using the reference image's
    geometry, then binarized with ``> 0``.

    Args:
        ground_output_path (Path or str): experiment directory containing
            ``eval-metrics/pred``.
        ref_image_path (Path or str): image the mask must correspond to.
        device (str or torch.device, optional): where the returned tensor is
            placed. Defaults to ``"cuda"`` (same as the former ``.cuda()``).

    Returns:
        torch.Tensor: float32 mask with two leading singleton dims added.

    Raises:
        ValueError: if no match, or more than one match, is found.
    """
    gop = Path(ground_output_path)
    rip = Path(ref_image_path)
    pred_dir = gop / "eval-metrics" / "pred"

    # Build the candidate paths directly instead of globbing: a stem that
    # contains glob metacharacters ("[", "*", "?") would otherwise be
    # interpreted as a pattern and miss (or over-match) files.
    candidates = [pred_dir / f"{rip.stem}{ext}" for ext in (".jpg", ".png")]
    ground_paths = [p for p in candidates if p.exists()]

    if len(ground_paths) == 0:
        raise ValueError(
            f"Could not find a ground match in {str(gop)} for image {str(rip)}"
        )
    elif len(ground_paths) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
            + f" {list(map(str, ground_paths))}"
        )

    _, ground = crop_and_resize(rip, ground_paths[0])
    ground = (ground > 0).astype(np.float32)
    return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).to(device)
|
|
|
|
def parse_args():
    """
    Parse this script's command-line arguments and print them.

    Returns:
        argparse.Namespace: with ``yaml`` (str, required), ``disable_loading``
        (bool), ``tags`` (list of str) and ``tasks`` (list of str).
    """
    parser = ArgumentParser()
    # Required: the script cannot do anything without the list of models.
    parser.add_argument(
        "-y", "--yaml", required=True, help="Path to a list of models"
    )
    parser.add_argument(
        "--disable_loading",
        action="store_true",
        default=False,
        help="Disable loading of existing inferences",
    )
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    parser.add_argument(
        "--tasks",
        nargs="*",
        help="Panels to show for each model, left to right "
        "(x: input, m: mask, mx: masked input, d: depth, s: segmentation, "
        "p: painted)",
        default=["x", "d", "s", "m", "mx", "p"],
        type=str,
    )
    args = parser.parse_args()

    print("Received args:")
    print(vars(args))

    return args
|
|
|
|
def load_images_and_labels(
    path="/miniscratch/_groups/ccai/data/omnigan/masker-test-set",
):
    """
    Find the test-set images and labels and prepare them for inference.

    Images come from ``<path>/imgs`` sorted by file name. Labels come from
    ``<path>/labels`` sorted by file name with ``"_labeled."`` replaced by
    ``"."``, presumably so that both lists line up index by index.

    Returns:
        tuple: (``PrepareInference`` output for images, ``PrepareInference``
        output for labels, sorted image paths, sorted label paths)
    """
    root = Path(path)

    def _by_name(p):
        return p.name

    def _by_unlabeled_name(p):
        return p.name.replace("_labeled.", ".")

    image_paths = sorted(climategan.utils.find_images(root / "imgs"), key=_by_name)
    label_paths = sorted(
        climategan.utils.find_images(root / "labels"), key=_by_unlabeled_name
    )

    prepared_images = climategan.transforms.PrepareInference()(image_paths)
    prepared_labels = climategan.transforms.PrepareInference(is_label=True)(
        label_paths
    )

    return prepared_images, prepared_labels, image_paths, label_paths
|
|
|
|
def load_inferences(inf_path, im_paths):
    """
    Load cached per-image inferences, in the same order as ``im_paths``.

    Each image ``p`` must have exactly one matching ``<inf_path>/<p.stem>.pt``
    file, and there must be no extra ``.pt`` files. Loading is best-effort:
    any problem (missing directory, stem mismatch, unreadable file) is
    printed and ``None`` is returned so that the caller can recompute.

    Args:
        inf_path (Path or str): directory holding the ``.pt`` files.
        im_paths (iterable of Path): images whose inferences to load.

    Returns:
        list or None: ``torch.load`` result for each entry of ``im_paths``,
        in the same order, or ``None`` if loading was aborted.
    """
    inf_path = Path(inf_path)
    # Materialize once: the paths are iterated twice below.
    im_paths = list(im_paths)
    try:
        if not inf_path.exists():
            raise FileNotFoundError(
                f"Inference directory {inf_path} does not exist"
            )
        expected = sorted(p.stem for p in im_paths)
        found = sorted(p.stem for p in inf_path.glob("*.pt"))
        if expected != found:
            raise ValueError(
                f"Inference files in {inf_path} do not match the images "
                f"({len(found)} .pt files for {len(expected)} images)"
            )
        # Load in im_paths order, not glob order (which is arbitrary), so
        # that result[i] corresponds to im_paths[i] as callers assume.
        return [torch.load(str(inf_path / f"{p.stem}.pt")) for p in tqdm(im_paths)]
    except Exception as e:
        # Deliberate best-effort: report the reason and let the caller
        # recompute the inferences.
        print()
        print(e)
        print("Aborting Loading")
        print()
        return None
|
|
|
|
def get_or_load_inferences(
    m_path, device, xs, is_ground, im_paths, ground_model, try_load=True
):
    """
    Return one inference dict per input, reusing cached results when possible.

    If ``try_load`` is true, cached results from ``<m_path>/inferences`` are
    returned when ``load_inferences`` succeeds. Otherwise a trainer is
    resumed from ``m_path``, or from ``ground_model`` when ``is_ground`` is
    set, and every input ``x`` is processed:

    - regular models produce their outputs with ``trainer.G.decode``;
    - the ground model takes its mask from ``load_ground(GROUND_MODEL, ...)``.

    In both cases ``"p"`` is ``trainer.G.paint`` applied to ``mask > 0.5``, and
    ``"x"`` is the input itself. Each dict is moved to CPU, collected, and
    saved to ``<m_path>/inferences/<image stem>.pt``.
    """
    inference_dir = Path(m_path) / "inferences"

    if try_load:
        print("Trying to load existing inferences:")
        cached = load_inferences(inference_dir, im_paths)
        if cached is not None:
            print("Successfully loaded existing inferences")
            return cached

    checkpoint = ground_model if is_ground else m_path
    trainer = climategan.trainer.Trainer.resume_from_path(
        checkpoint,
        inference=True,
        new_exp=None,
        device=device,
    )

    inference_dir.mkdir(exist_ok=True)
    results = []
    for idx, x in enumerate(tqdm(xs)):
        x = x.to(trainer.device)
        if is_ground:
            out = {"m": load_ground(GROUND_MODEL, im_paths[idx])}
        else:
            out = trainer.G.decode(x=x)
        out["p"] = trainer.G.paint(out["m"] > 0.5, x)
        out["x"] = x
        on_cpu = {key: tensor.cpu() for key, tensor in out.items()}
        results.append(on_cpu)
        torch.save(on_cpu, inference_dir / f"{im_paths[idx].stem}.pt")
    print()

    return results
|
|
|
|
def numpify(outputs):
    """
    Convert inference dicts of tensors into dicts of uint8 numpy images.

    For each output dict:
      - ``"x"`` and ``"p"``: first batch item, CHW -> HWC, rescaled with
        ``(v + 1) / 2`` (assumes values in [-1, 1] — TODO confirm);
      - ``"m"``: first batch item and channel, binarized at 0.5;
      - ``"s"`` (optional): decoded with
        ``climategan.data.decode_segmap_merged_labels``, divided by 255,
        first batch item as HWC;
      - ``"d"`` (optional): ``climategan.tutils.normalize_tensor`` then squeezed.
    Every resulting array is finally converted with ``img_as_ubyte``.
    """
    print("Numpifying...")

    def _first_as_hwc(t):
        return t[0].permute(1, 2, 0).numpy()

    converted = []
    for out in tqdm(outputs):
        image = (_first_as_hwc(out["x"]) + 1) / 2
        arrays = {
            "m": (out["m"][0, 0, :, :].numpy() > 0.5).astype(np.uint8),
            "p": (_first_as_hwc(out["p"]) + 1) / 2,
            "x": image,
        }
        if "s" in out:
            seg = (
                climategan.data.decode_segmap_merged_labels(out["s"], "r", False)
                / 255.0
            )
            arrays["s"] = _first_as_hwc(seg)
        if "d" in out:
            arrays["d"] = (
                climategan.tutils.normalize_tensor(out["d"]).squeeze().numpy()
            )
        converted.append({key: img_as_ubyte(val) for key, val in arrays.items()})
    return converted
|
|
|
|
def concat_npy_for_model(data, tasks):
    """
    Build one horizontal strip of panels for a single model and image.

    Args:
        data (dict): numpy images with at least ``"x"`` (input), ``"m"``
            (mask) and ``"p"`` (painted); ``"d"`` (depth) and ``"s"``
            (segmentation) are optional.
        tasks (list of str): panels to include, left to right:
            ``"x"`` input; ``"m"`` mask scaled by 255 and converted to RGB;
            ``"mx"`` input multiplied by ``1 - mask``; ``"d"`` depth resized
            to the input size; ``"s"`` segmentation resized to the input
            size; ``"p"`` painted image. A missing ``"d"`` or ``"s"`` is
            shown as a white panel. Unknown names are ignored.

    Returns:
        np.ndarray: the selected panels concatenated along the width axis.

    Raises:
        KeyError: if ``"m"``, ``"x"`` or ``"p"`` is missing from ``data``.
        ValueError: if ``tasks`` selects no known panel.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = [key for key in ("m", "x", "p") if key not in data]
    if missing:
        raise KeyError(f"concat_npy_for_model: data is missing keys {missing}")

    x = data["x"]

    def _mask():
        return (gray2rgb(data["m"]) * 255).astype(np.uint8)

    def _masked():
        return (1 - gray2rgb(data["m"])) * x

    def _depth():
        if "d" not in data:
            return np.ones_like(x) * 255
        return img_as_ubyte(
            gray2rgb(resize(data["d"], x.shape[:2], anti_aliasing=True, order=1))
        )

    def _seg():
        if "s" not in data:
            return np.ones_like(x) * 255
        # order=0 (nearest) so segmentation colors are not blended.
        return img_as_ubyte(
            resize(data["s"], x.shape[:2], anti_aliasing=False, order=0)
        )

    # Panels are built lazily so the resizes only run for requested tasks.
    builders = {
        "x": lambda: x,
        "m": _mask,
        "mx": _masked,
        "d": _depth,
        "s": _seg,
        "p": lambda: data["p"],
    }
    concats = [builders[t]() for t in tasks if t in builders]
    if not concats:
        raise ValueError(
            f"No known task in {tasks!r}; expected some of {list(builders)}"
        )

    return np.concatenate(concats, axis=1)
|
|
|
|
if __name__ == "__main__":
    args = parse_args()

    # The YAML holds either the list of model directories directly or a
    # mapping whose "models" key holds that list.
    with open(args.yaml, "r") as f:
        maskers = yaml.safe_load(f)
    if "models" in maskers:
        maskers = maskers["models"]

    load = not args.disable_loading
    tags = args.tags
    tasks = args.tasks

    # Any non-ground model provides the painter used for the ground masks.
    # A model counts as "ground" when its path contains "--ground", the same
    # rule used below to name it. (This previously tested membership in the
    # whole `maskers` list instead of the current path `m`.)
    ground_model = next((m for m in maskers if "--ground" not in m), None)
    if ground_model is None:
        raise ValueError("Could not find a non-ground model to get a painter")

    device = torch.device("cuda:0")
    torch.set_grad_enabled(False)

    xs, ys, im_paths, lab_paths = load_images_and_labels()

    # Outputs are kept per model index, parallel to `names`. Two models with
    # the same tag set get the same display name, and keying by name would
    # make the later one silently overwrite the earlier one.
    names = []
    model_outputs = []

    for m_path in maskers:
        opt_path = Path(m_path) / "opts.yaml"
        with opt_path.open("r") as f:
            opt = yaml.safe_load(f)

        is_ground = "--ground" in m_path
        if is_ground:
            name = "ground"
        else:
            name = ", ".join(
                t
                for t in sorted(opt["comet"]["tags"])
                if "branch" not in t and "ablation" not in t and "trash" not in t
            )
        names.append(name)

        print("#" * 100)
        print("\n>>> Processing", name)
        print()

        outputs = get_or_load_inferences(
            m_path, device, xs, is_ground, im_paths, ground_model, load
        )
        model_outputs.append(numpify(outputs))

    exp = Experiment(project_name="climategan-inferences", display_summary_level=0)
    exp.log_parameter("names", names)
    exp.add_tags(tags)

    for i in tqdm(range(len(xs))):
        # One strip per model for image i, stacked vertically.
        full_im = np.concatenate(
            [concat_npy_for_model(nps[i], tasks) for nps in model_outputs], axis=0
        )
        pil_im = Image.fromarray(full_im)
        exp.log_image(pil_im, name=im_paths[i].stem.replace(".", "_"), step=i)
|
|