Spaces:
Runtime error
Runtime error
| import cv2 | |
| import streamlit as st | |
| import time | |
| from huggingface_sb3 import load_from_hub | |
| from stable_baselines3 import PPO | |
| from stable_baselines3.common.env_util import make_atari_env | |
| from stable_baselines3.common.vec_env import VecFrameStack | |
| from stable_baselines3.common.env_util import make_atari_env | |
| # Page heading for the Streamlit app. | |
| st.title("Atari Environments Live Model") | |
# NOTE: intentionally not decorated with @st.cache -- the original author
# found that this function is not cacheable.
def load_env(env_name):
    """Build the frame-stacked Atari environment for *env_name*.

    Creates a single-environment vectorized Atari env with
    ``make_atari_env(env_name, n_envs=1)`` and wraps it in
    ``VecFrameStack`` with ``n_stack=4``.

    Args:
        env_name: Environment id passed straight to ``make_atari_env``.

    Returns:
        The ``VecFrameStack``-wrapped environment.
    """
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
# NOTE: intentionally not decorated with @st.cache -- the original author
# found that this function is not cacheable.
def load_model(env_name):
    """Load the PPO checkpoint published for *env_name*.

    The checkpoint is obtained with ``load_from_hub`` from the repository
    ``ThomasSimonini/ppo-<env_name>`` (file ``ppo-<env_name>.zip``) and
    handed to ``PPO.load``.

    Args:
        env_name: Environment id used to build the repository id and the
            checkpoint file name.

    Returns:
        Whatever ``PPO.load`` returns for the fetched checkpoint.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    archive_name = f"ppo-{env_name}.zip"
    checkpoint_path = load_from_hub(repo_id, archive_name)

    # Passed to PPO.load as ``custom_objects``: constant-zero stand-ins for
    # the learning rate, its schedule, and the clip range.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=lambda _: 0.0,
        clip_range=lambda _: 0.0,
    )
    return PPO.load(checkpoint_path, custom_objects=overrides)
# User controls: which game to watch and how many episodes to play.
env_name = st.selectbox(
    "Select environment",
    (
        "SpaceInvadersNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "QbertNoFrameskip-v4",
    ),
)
num_episodes = st.slider("Number of Episodes", 1, 20, 5)

env = load_env(env_name)
try:
    model = load_model(env_name)

    # Draw every frame inside one st.empty() placeholder so the image is
    # updated in place rather than appended below the previous one.
    with st.empty():
        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                st.image(frame, width=400)

                # predict() on a vectorized observation already returns a
                # batched action, so it is passed to step() unchanged
                # (wrapping it in another list adds a spurious dimension).
                action, _states = model.predict(obs)
                obs, _rewards, dones, _infos = env.step(action)

                # load_env() creates the env with n_envs=1, so the only
                # done flag of interest is element 0.
                done = bool(dones[0])

                # Pause 0.1 s between steps so the playback is watchable.
                time.sleep(0.1)
finally:
    # Streamlit re-executes the script on every widget interaction; close
    # the environment so each run does not leave an emulator behind.
    env.close()