| import gradio as gr |
| import spaces |
| import torch |
| import yaml |
| import numpy as np |
| from PIL import Image |
| from cdim.noise import get_noise |
| from cdim.operators import get_operator |
| from cdim.diffusion.scheduling_ddim import DDIMScheduler |
| from cdim.diffusion.diffusion_pipeline import run_diffusion |
| from diffusers import DiffusionPipeline |
|
|
| |
# Module-level model cache, populated lazily at request time. All four names
# start out as None: the loaded network, its DDIM scheduler, the model-type tag
# handed to the diffusion loop, and the Hub id of the checkpoint currently held.
model = ddim_scheduler = model_type = curr_model_name = None
|
|
|
|
def load_image(image_path, size=256):
    """Load an image file as a normalised RGB tensor.

    The image is converted to RGB, resized to ``size`` x ``size`` with bicubic
    resampling, and rescaled from [0, 255] to [-1, 1].

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (e.g. a file path).
        size: Side length in pixels of the square output. Defaults to 256,
            the resolution the original implementation hard-coded.

    Returns:
        Float tensor of shape (1, 3, size, size) with values in [-1, 1].
    """
    # The context manager closes the underlying file handle. convert("RGB")
    # guarantees exactly three channels for grayscale/palette/RGBA inputs; the
    # previous ``[:, :3]`` slice could not handle 2-D (single-channel) arrays.
    with Image.open(image_path) as image:
        rgb = image.convert("RGB").resize((size, size), Image.BICUBIC)

    # HWC uint8 array -> NCHW tensor with a leading batch dimension of 1.
    pixels = torch.from_numpy(np.array(rgb)).permute(2, 0, 1).unsqueeze(0)
    return pixels.to(torch.float) / 127.5 - 1.0
|
|
|
|
def load_yaml(file_path: str) -> dict:
    """Load a configuration mapping from a YAML file.

    Args:
        file_path: Path to a UTF-8 encoded YAML document.

    Returns:
        The parsed document. An empty file yields ``{}`` rather than ``None``,
        so callers can update or unpack the result unconditionally.
    """
    # safe_load only constructs plain YAML types (dicts, lists, scalars), which
    # is all a configuration file needs and avoids arbitrary object creation.
    with open(file_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config if config is not None else {}
|
|
|
|
def convert_to_np(torch_image):
    """Convert a CHW image tensor in [-1, 1] to an HWC ``uint8`` NumPy array.

    This inverts the ``x / 127.5 - 1`` normalisation used by ``load_image``.
    Values are clamped to [-1, 1], mapped to [0, 255], and rounded to the
    nearest integer.

    Args:
        torch_image: Tensor of shape (C, H, W). It is detached and copied to
            the CPU, so it may require grad or live on an accelerator.

    Returns:
        ``np.ndarray`` of shape (H, W, C) and dtype ``uint8``.
    """
    array = torch_image.detach().clamp(-1, 1).cpu().numpy().transpose(1, 2, 0)
    # Round rather than truncate: float error in (x + 1) * 127.5 can land just
    # below an integer (e.g. 2.9999998), which astype() alone would floor to 2.
    return np.rint((array + 1) * 127.5).astype(np.uint8)
|
|
|
|
# Dropdown label -> bundled sample image used as the ground-truth input.
_IMAGE_PATHS = {
    "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
    "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
    "CelebA HQ 3": "sample_images/celebhq_00000.jpg",
    "LSUN Church": "sample_images/lsun_church.png",
}

# Task label -> YAML config passed (as keyword arguments) to get_operator.
_OPERATOR_CONFIG_PATHS = {
    "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
    "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
    "Super Resolution": "operator_configs/super_resolution_config.yaml",
    "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml",
}

# YAML config passed (as keyword arguments) to get_noise; "sigma" is overridden.
_NOISE_CONFIG_PATH = "noise_configs/gaussian_noise_config.yaml"


def _ensure_model(model_name, device):
    """Load ``model_name``'s UNet and a fresh DDIM scheduler unless already cached."""
    global model, curr_model_name, ddim_scheduler, model_type
    if model is not None and curr_model_name == model_name:
        return
    model_type = "diffusers"
    # Only the pipeline's UNet is kept; the rest of the pipeline is discarded.
    model = DiffusionPipeline.from_pretrained(model_name).to(device).unet
    curr_model_name = model_name
    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    )


def _build_measurement_model(noise_sigma, operator_key, device):
    """Return ``(noise_function, operator)`` built from the YAML configs."""
    noise_config = load_yaml(_NOISE_CONFIG_PATH)
    noise_config["sigma"] = noise_sigma
    noise_function = get_noise(**noise_config)

    operator_config = load_yaml(_OPERATOR_CONFIG_PATHS[operator_key])
    operator_config["device"] = device
    operator = get_operator(**operator_config)
    return noise_function, operator


@spaces.GPU
def process_image(image_choice, noise_sigma, operator_key, T, stopping_sigma):
    """Degrade a sample image with the chosen operator and noise, then restore it.

    Args:
        image_choice: Key of ``_IMAGE_PATHS``. Choices containing "CelebA"
            use the CelebA-HQ checkpoint; all others use the church checkpoint.
        noise_sigma: Value written to the noise config's ``"sigma"`` entry.
        operator_key: Key of ``_OPERATOR_CONFIG_PATHS`` selecting the task.
        T: Number of inference steps forwarded to ``run_diffusion``.
        stopping_sigma: Forwarded unchanged to ``run_diffusion``.

    Returns:
        Tuple ``(noisy_image, output_image)`` of PIL images: the noisy
        measurement and the diffusion-model reconstruction.

    Raises:
        KeyError: If ``image_choice`` or ``operator_key`` is not a known label.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = (
        "google/ddpm-celebahq-256"
        if "CelebA" in image_choice
        else "google/ddpm-church-256"
    )
    _ensure_model(model_name, device)

    original_image = load_image(_IMAGE_PATHS[image_choice]).to(device)
    noise_function, operator = _build_measurement_model(
        noise_sigma, operator_key, device
    )

    noisy_measurement = noise_function(operator(original_image))
    noisy_image = Image.fromarray(convert_to_np(noisy_measurement[0]))

    # Defensive int(): the step count comes from a UI slider and may arrive as
    # a float; a step count is inherently integral.
    output_image = run_diffusion(
        model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
        stopping_sigma, num_inference_steps=int(T), model_type=model_type,
    )
    return noisy_image, Image.fromarray(convert_to_np(output_image[0]))
|
|
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# Noisy Image Restoration with Diffusion Models")

    # Sampler and measurement-noise controls share a single row.
    with gr.Row():
        T = gr.Slider(
            minimum=4,
            maximum=200,
            value=25,
            step=1,
            label="Number of Inference Steps (T)",
        )
        stopping_sigma = gr.Slider(
            minimum=0.1,
            maximum=5.0,
            value=0.1,
            step=0.1,
            label="Stopping Sigma (c)",
        )
        noise_sigma = gr.Slider(
            minimum=0,
            maximum=0.6,
            value=0.05,
            step=0.01,
            label="Measurement Noise Sigma (σ_y)",
        )

    image_select = gr.Dropdown(
        label="Select Input Image",
        choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3", "LSUN Church"],
        value="CelebA HQ 1",
    )
    operator_select = gr.Dropdown(
        label="Select Task",
        choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"],
        value="Random Inpainting",
    )

    run_button = gr.Button("Run Inference")
    noisy_image = gr.Image(label="Noisy Image")
    restored_image = gr.Image(label="Restored Image")

    # Inputs are listed in process_image's positional-parameter order.
    run_button.click(
        fn=process_image,
        inputs=[image_select, noise_sigma, operator_select, T, stopping_sigma],
        outputs=[noisy_image, restored_image],
    )
|
|
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")
|
|