"""Gradio demo that predicts a video game's genre from its cover image.

Artifacts read from the working directory at import time:

* ``classes.txt`` -- one class label per non-blank line. The order must match
  the order of the model's output scores.
* ``best_model.pth`` -- a whole pickled model object (not a bare state_dict),
  because the script calls ``.eval()`` on it and invokes it directly.
"""

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

CLASSES_PATH = "classes.txt"
model_path = "best_model.pth"

device = torch.device("cpu")


def _load_classes(path):
    """Return the stripped, non-blank lines of *path* as class labels.

    Raises:
        ValueError: if the file contains no labels.
    """
    with open(path, "r", encoding="utf-8") as f:
        classes = [line.strip() for line in f if line.strip()]
    if not classes:
        raise ValueError(f"No class labels found in {path!r}.")
    return classes


def _load_model(path, map_device):
    """Load a whole pickled model from *path* onto *map_device*, in eval mode.

    Raises:
        TypeError: if the checkpoint is not a ``torch.nn.Module`` (for example,
            a bare state_dict, which needs the architecture rebuilt first).
    """
    # PyTorch >= 2.6 defaults to torch.load(weights_only=True), which refuses
    # to unpickle a full nn.Module, so full unpickling is requested explicitly.
    # SECURITY: weights_only=False executes arbitrary pickle code -- only load
    # checkpoints that you produced or otherwise trust.
    try:
        loaded = torch.load(path, map_location=map_device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only parameter (it always fully unpickles).
        loaded = torch.load(path, map_location=map_device)
    if not isinstance(loaded, torch.nn.Module):
        raise TypeError(
            f"{path!r} does not contain a torch.nn.Module (got "
            f"{type(loaded).__name__}); if it is a state_dict, build the model "
            "architecture and call load_state_dict() instead."
        )
    loaded.eval()
    return loaded


CLASSES = _load_classes(CLASSES_PATH)
model = _load_model(model_path, device)

# Resize to 224x224, convert to a tensor and normalize with the mean/std
# constants listed below. These must match whatever the model was trained
# with -- NOTE(review): confirm against the training pipeline.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict(image: Image.Image) -> dict:
    """Return a ``{label: probability}`` mapping for one uploaded cover image.

    Args:
        image: The uploaded image (any PIL mode; it is converted to RGB), or
            ``None`` when the user submits without uploading.

    Returns:
        A dict mapping every label in ``CLASSES`` to its softmax probability,
        in the format the ``gr.Label`` output expects.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
        ValueError: if the model's number of scores differs from the number
            of labels in ``CLASSES``.
    """
    if image is None:
        # Surface a real UI error instead of a fake "error" class at 100%.
        raise gr.Error("Please upload a game cover image.")

    # convert("RGB") normalizes RGBA/grayscale/palette images to 3 channels;
    # unsqueeze(0) adds the batch dimension the model is called with.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0].tolist()

    # Guard against a classes.txt that does not match the trained head:
    # otherwise labels would be mis-paired, dropped, or raise IndexError.
    if len(probs) != len(CLASSES):
        raise ValueError(
            f"Model produced {len(probs)} scores but {CLASSES_PATH!r} lists "
            f"{len(CLASSES)} classes."
        )
    return dict(zip(CLASSES, probs))


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a game cover"),
    outputs=gr.Label(num_top_classes=3, label="Predicted genre"),
    title="Video Game Genre Predictor",
    description="Upload a video game cover to predict its genre.",
)

if __name__ == "__main__":
    demo.launch()