Spaces:
Runtime error
Runtime error
| import glob | |
| import gradio as gr | |
| import gym | |
| import sys | |
| from torch.utils.tensorboard import SummaryWriter | |
| import yaml | |
| import torch | |
| from cartpole import ( | |
| make_env, reset_env, Agent, rollout_phase, get_action_shape | |
| ) | |
# True only when this file is executed directly rather than imported.
MAIN = "__main__" == __name__

# Sample inputs handed to the Gradio interface as clickable examples.
examples = [
    0,
    1,
    31415,
    'Hello, World!',
    'This is a seed...',
]
def generate_video(
    string: str, wandb_path='wandb/run-20230303_211416-ox4d1p0u/files'
):
    """Roll out the saved agent with a seed derived from ``string``.

    The run configuration (``num_envs``, ``num_steps``) is read from
    ``{wandb_path}/config.yaml`` and the agent weights from
    ``{wandb_path}/model_state_dict.pt``. One rollout phase is executed on
    a ``SyncVectorEnv`` of CartPole-v1 environments, after which the first
    ``.mp4`` found under ``videos/seed<seed>/`` is returned.

    Args:
        string: Arbitrary text (or any value convertible with ``str``)
            that is hashed into a non-negative integer seed.
        wandb_path: Directory holding ``config.yaml`` and
            ``model_state_dict.pt``.

    Returns:
        Path of the generated video file.

    Raises:
        FileNotFoundError: If the rollout produced no ``.mp4`` file.
    """
    # Imported locally so this function does not depend on a file-level
    # import being added.
    import hashlib

    with open(f'{wandb_path}/config.yaml') as f_cfg:
        config = yaml.safe_load(f_cfg)
    num_envs = config['num_envs']['value']
    num_steps = config['num_steps']['value']

    # The built-in hash() of a str is salted per interpreter process, so it
    # would map the same text to a different seed after every restart.
    # A SHA-256 digest is stable; the modulo keeps the previous seed range.
    digest = hashlib.sha256(str(string).encode('utf-8')).digest()
    seed = int.from_bytes(digest, 'big') % ((sys.maxsize + 1) * 2)

    run_name = f'seed{seed}'
    log_dir = f'generate/{run_name}'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    writer = SummaryWriter(log_dir)
    try:
        # NOTE(review): the positional ``True`` is presumably make_env's
        # capture-video flag -- confirm against cartpole.make_env.
        envs = gym.vector.SyncVectorEnv([
            make_env("CartPole-v1", seed, i, True, run_name)
            for i in range(num_envs)
        ])
        try:
            action_shape = get_action_shape(envs)
            next_obs, next_done = reset_env(envs, device)
            global_step = 0
            agent = Agent(envs).to(device)
            # map_location lets a checkpoint saved on a CUDA machine load
            # on a CPU-only host.
            agent.load_state_dict(
                torch.load(
                    f'{wandb_path}/model_state_dict.pt',
                    map_location=device,
                )
            )
            rollout_phase(
                next_obs, next_done, agent, envs, writer, device,
                global_step, action_shape, num_envs, num_steps,
            )
        finally:
            # Release the environments before looking for the video so any
            # recording still in progress is finalised first.
            envs.close()
    finally:
        writer.close()

    # Sort so the chosen file does not depend on directory listing order.
    video_paths = sorted(glob.glob(f'videos/{run_name}/*.mp4'))
    if not video_paths:
        raise FileNotFoundError(
            f'no video was written to videos/{run_name}/'
        )
    return video_paths[0]
if MAIN:
    # Build the two widgets first, then wire them to generate_video.
    seed_box = gr.components.Textbox(lines=1, label="Seed")
    video_out = gr.components.Video(label="Generated Video")
    demo = gr.Interface(
        fn=generate_video,
        inputs=[seed_box],
        outputs=video_out,
        examples=examples,
    )
    # Start the web UI.
    demo.launch()