# Spaces: Runtime error
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

# Class (genre) names: one per line of classes.txt, whitespace-trimmed, with
# blank lines skipped. `predict` pairs these with the model's outputs by
# position, so the file order presumably mirrors the training label order —
# NOTE(review): confirm against the training script.
with open("classes.txt", "r", encoding="utf-8") as class_file:
    CLASSES = [label for label in map(str.strip, class_file) if label]
device = torch.device("cpu")

# Load the trained model. The code below calls `.eval()` on the loaded object
# and `predict` invokes it directly, so best_model.pth must hold a whole
# pickled nn.Module, not just a state_dict.
model_path = "best_model.pth"
# PyTorch >= 2.6 changed torch.load's default to weights_only=True, which
# refuses to unpickle arbitrary classes such as a full nn.Module and aborts
# startup, so opt out explicitly. (The `weights_only` argument needs
# torch >= 1.13.)
# SECURITY: weights_only=False runs the full pickle machinery and can execute
# arbitrary code embedded in the file. That is acceptable only because this
# checkpoint is a trusted artifact shipped with the app; never point
# `model_path` at user-supplied files.
model = torch.load(model_path, map_location=device, weights_only=False)
if isinstance(model, dict):
    # A state_dict (an OrderedDict) or dict-wrapped checkpoint has no
    # .eval() or forward(). Rebuilding the network would need its
    # architecture, which this file does not define, so fail with an
    # actionable message instead of an opaque AttributeError.
    raise RuntimeError(
        f"{model_path!r} contains a {type(model).__name__}, not a model "
        "object. It looks like a state_dict: build the network architecture "
        "and call load_state_dict() on it instead of torch.load() alone."
    )
model.eval()  # inference mode for layers such as dropout / batch-norm
# Input pipeline applied to each uploaded cover before inference.
# The 224x224 size and the mean/std values are the commonly used
# torchvision/ImageNet settings. NOTE(review): presumably they match how
# best_model.pth was trained; confirm, since a mismatch silently hurts
# accuracy rather than raising.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

preprocess = transforms.Compose(
    [
        # Resize to an exact (height, width); aspect ratio is not preserved.
        transforms.Resize(_INPUT_SIZE),
        # PIL image -> float tensor (C, H, W), 8-bit values scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel standardization: (x - mean) / std.
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
def predict(image: Image.Image):
    """Classify a game cover into genre probabilities.

    Args:
        image: The uploaded cover, or None when the input is empty.

    Returns:
        A ``{label: probability}`` dict for a ``gr.Label`` output, with one
        entry per name in ``CLASSES`` (softmax over the model's logits).
        When *image* is None, the placeholder ``{"error": 1.0}`` is returned.

    Raises:
        ValueError: If the number of model outputs differs from
            ``len(CLASSES)``, i.e. classes.txt does not match the checkpoint.
    """
    if image is None:
        return {"error": 1.0}

    # Force three channels: `preprocess` normalizes with three per-channel
    # statistics, so grayscale / RGBA / palette uploads must be converted.
    image = image.convert("RGB")
    batch = preprocess(image).unsqueeze(0).to(device)  # batch of one

    with torch.no_grad():
        logits = model(batch)
    # Softmax across the class axis, then take the single sample.
    probs = torch.softmax(logits, dim=1)[0]

    # Labels are paired with outputs by position. A length mismatch means
    # classes.txt and the model disagree: previously that caused an opaque
    # IndexError (too few outputs) or silently dropped classes (too many).
    if probs.numel() != len(CLASSES):
        raise ValueError(
            f"Model produced {probs.numel()} class scores but classes.txt "
            f"lists {len(CLASSES)} labels; they must match."
        )

    # Build {label: prob} dict for Gradio Label output.
    return dict(zip(CLASSES, probs.tolist()))
# Gradio UI wiring: one PIL image in, the top-3 genre scores from `predict` out.
_cover_input = gr.Image(type="pil", label="Upload a game cover")
_genre_output = gr.Label(num_top_classes=3, label="Predicted genre")

demo = gr.Interface(
    fn=predict,
    inputs=_cover_input,
    outputs=_genre_output,
    title="Video Game Genre Predictor",
    description="Upload a video game cover to predict its genre.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()