import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from model import Generator

# Number of images produced per request.
NUM_IMAGES = 5
# Length of each random noise vector fed to the generator.
# NOTE(review): presumably must match the latent size the checkpoint was
# trained with -- confirm against the Generator definition.
LATENT_DIM = 100

# Load the model
device = torch.device("cpu")
generator = Generator()
generator.load_state_dict(torch.load("generator_digit.pth", map_location=device))
generator.eval()


def generate_images(digit):
    """Generate a row of images for ``digit`` and return them as one figure.

    Args:
        digit: The class label chosen in the UI. It is converted with
            ``int()`` because a UI slider may deliver the value as a float
            (e.g. ``3.0``), which would otherwise produce a floating-point
            label tensor.

    Returns:
        A matplotlib figure with ``NUM_IMAGES`` grayscale subplots, one per
        generated sample, with axes hidden.
    """
    label = int(digit)

    noise = torch.randn(NUM_IMAGES, LATENT_DIM)
    # Explicit integer dtype so the labels stay valid class indices no matter
    # how the input number was typed.
    labels = torch.full((NUM_IMAGES,), label, dtype=torch.long)

    with torch.no_grad():
        # assumes the generator output squeezes to (NUM_IMAGES, H, W) --
        # TODO confirm against the Generator definition
        images = generator(noise, labels).squeeze().numpy()

    # Plot the images side by side in one figure
    fig, axs = plt.subplots(1, NUM_IMAGES, figsize=(10, 2))
    for ax, image in zip(axs, images):
        ax.imshow(image, cmap="gray")
        ax.axis("off")

    # pyplot keeps a global reference to every figure it creates; drop it so
    # repeated requests do not accumulate figures in memory. The Figure object
    # itself stays usable for rendering by the caller.
    plt.close(fig)
    return fig


# Gradio Interface using modern syntax
demo = gr.Interface(
    fn=generate_images,
    inputs=gr.Slider(0, 9, step=1, label="Digit (0–9)"),
    outputs=gr.Plot(label="Generated Images"),
    title="MNIST Digit Generator",
    description="Generates 5 handwritten images of the selected digit using a trained GAN.",
)

demo.launch()