Spaces:
Runtime error
Runtime error
Changed git commands over to Hugging Face
Browse files
- cartpole.py +37 -25
cartpole.py
CHANGED
|
@@ -11,7 +11,7 @@
|
|
| 11 |
# name: python3
|
| 12 |
# ---
|
| 13 |
|
| 14 |
-
# + id="QAY_RQOLcRtA" executionInfo={"status": "ok", "timestamp":
|
| 15 |
MAIN = __name__ == "__main__"
|
| 16 |
if MAIN:
|
| 17 |
print('Mounting drive...')
|
|
@@ -19,23 +19,32 @@ if MAIN:
|
|
| 19 |
drive.mount('/content/drive')
|
| 20 |
# %cd /content/drive/MyDrive/Colab Notebooks/cartpole-demo
|
| 21 |
|
| 22 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="GgSNZRJh4EjV" executionInfo={"status": "ok", "timestamp":
|
| 23 |
# !pip install einops
|
| 24 |
# !pip install wandb
|
| 25 |
# !pip install jupytext
|
| 26 |
# !pip install pygame
|
| 27 |
# !pip install torchtyping
|
| 28 |
# !pip install gradio
|
|
|
|
| 29 |
|
| 30 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="1g58HZUb8Ltl" executionInfo={"status": "ok", "timestamp":
|
|
|
|
|
|
|
| 31 |
# !git config --global user.email "[email protected]"
|
| 32 |
-
# !
|
| 33 |
-
# !cat pat.txt | xargs git remote set-url origin
|
| 34 |
# !jupytext --to py cartpole.ipynb
|
| 35 |
# !git fetch
|
|
|
|
| 36 |
# !git status
|
| 37 |
|
| 38 |
-
# + id="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
import os
|
| 40 |
import glob
|
| 41 |
import sys
|
|
@@ -66,7 +75,7 @@ from typeguard import typechecked
|
|
| 66 |
# + id="K7T8bs1Y76ZK" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="f59ffef0-7156-4f27-d992-a392d59a1c73"
|
| 67 |
# %env "WANDB_NOTEBOOK_NAME" "cartpole.py"
|
| 68 |
|
| 69 |
-
# + id="Q5E93-BGRjuy"
|
| 70 |
def make_env(
|
| 71 |
env_id: str, seed: int, idx: int, capture_video: bool, run_name: str
|
| 72 |
):
|
|
@@ -93,7 +102,7 @@ def make_env(
|
|
| 93 |
return thunk
|
| 94 |
|
| 95 |
|
| 96 |
-
# + id="Kf152ROwHjM_"
|
| 97 |
def test_minibatch_indexes(minibatch_indexes):
|
| 98 |
for n in range(5):
|
| 99 |
frac, minibatch_size = np.random.randint(1, 8, size=(2,))
|
|
@@ -105,7 +114,7 @@ def test_minibatch_indexes(minibatch_indexes):
|
|
| 105 |
np.testing.assert_equal(np.sort(np.stack(indices).flatten()), np.arange(batch_size))
|
| 106 |
|
| 107 |
|
| 108 |
-
# + id="mhvduVeOHkln"
|
| 109 |
def test_calc_entropy_bonus(calc_entropy_bonus):
|
| 110 |
probs = Categorical(logits=t.randn((3, 4)))
|
| 111 |
ent_coef = 0.5
|
|
@@ -114,7 +123,7 @@ def test_calc_entropy_bonus(calc_entropy_bonus):
|
|
| 114 |
t.testing.assert_close(expected, actual)
|
| 115 |
|
| 116 |
|
| 117 |
-
# + id="Aya60GeCGA5X"
|
| 118 |
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
| 119 |
t.nn.init.orthogonal_(layer.weight, std)
|
| 120 |
t.nn.init.constant_(layer.bias, bias_const)
|
|
@@ -146,7 +155,7 @@ class Agent(nn.Module):
|
|
| 146 |
|
| 147 |
|
| 148 |
|
| 149 |
-
# + id="6PwPZHlLGDYu"
|
| 150 |
# %%
|
| 151 |
@t.inference_mode()
|
| 152 |
def compute_advantages(
|
|
@@ -190,7 +199,7 @@ def compute_advantages(
|
|
| 190 |
|
| 191 |
|
| 192 |
|
| 193 |
-
# + id="uYSSMnF-GPvm"
|
| 194 |
# %%
|
| 195 |
@dataclass
|
| 196 |
class Minibatch:
|
|
@@ -252,7 +261,7 @@ def make_minibatches(
|
|
| 252 |
|
| 253 |
|
| 254 |
|
| 255 |
-
# + id="K7wXDJ9MGOWu"
|
| 256 |
# %%
|
| 257 |
def calc_policy_loss(
|
| 258 |
probs: Categorical, mb_action: t.Tensor, mb_advantages: t.Tensor,
|
|
@@ -277,7 +286,7 @@ def calc_policy_loss(
|
|
| 277 |
|
| 278 |
|
| 279 |
|
| 280 |
-
# + id="CmyxU6JWGMsG"
|
| 281 |
# %%
|
| 282 |
def calc_value_function_loss(
|
| 283 |
critic: nn.Sequential, mb_obs: t.Tensor, mb_returns: t.Tensor, v_coef: float
|
|
@@ -294,7 +303,7 @@ def calc_value_function_loss(
|
|
| 294 |
|
| 295 |
|
| 296 |
|
| 297 |
-
# + id="npyWs6xjGLkP"
|
| 298 |
# %%
|
| 299 |
def calc_entropy_loss(probs: Categorical, ent_coef: float):
|
| 300 |
'''Return the entropy loss term.
|
|
@@ -310,7 +319,7 @@ if MAIN:
|
|
| 310 |
test_calc_entropy_bonus(calc_entropy_loss)
|
| 311 |
|
| 312 |
|
| 313 |
-
# + id="nqJeg1kZGKSG"
|
| 314 |
# %%
|
| 315 |
class PPOScheduler:
|
| 316 |
def __init__(self, optimizer: optim.Adam, initial_lr: float, end_lr: float, num_updates: int):
|
|
@@ -345,7 +354,7 @@ def make_optimizer(
|
|
| 345 |
|
| 346 |
|
| 347 |
|
| 348 |
-
# + id="mgZ7-wsRCxJW"
|
| 349 |
@dataclass
|
| 350 |
class PPOArgs:
|
| 351 |
exp_name: str = 'cartpole.py'
|
|
@@ -373,7 +382,7 @@ class PPOArgs:
|
|
| 373 |
minibatch_size: int = 128
|
| 374 |
|
| 375 |
|
| 376 |
-
# + id="xeIu-J3ZwGyq"
|
| 377 |
def wandb_init(name: str, args: PPOArgs):
|
| 378 |
wandb.init(
|
| 379 |
project=args.wandb_project_name,
|
|
@@ -387,14 +396,14 @@ def wandb_init(name: str, args: PPOArgs):
|
|
| 387 |
)
|
| 388 |
|
| 389 |
|
| 390 |
-
# + id="gMYWqhsryYHy"
|
| 391 |
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The same integer is passed to ``random.seed``, ``np.random.seed`` and
    ``torch.manual_seed``.

    Args:
        seed: Value used to seed all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
| 395 |
|
| 396 |
|
| 397 |
-
# + id="T9j_L0Wpyrgz"
|
| 398 |
@typechecked
|
| 399 |
def rollout_phase(
|
| 400 |
next_obs: t.Tensor, next_done: t.Tensor,
|
|
@@ -472,14 +481,14 @@ def rollout_phase(
|
|
| 472 |
)
|
| 473 |
|
| 474 |
|
| 475 |
-
# + id="xdDhABIk5jyb"
|
| 476 |
def reset_env(envs, device):
    """Reset the environments and return the initial observation/done tensors.

    Args:
        envs: Object exposing ``reset()`` and a ``num_envs`` attribute.
        device: Device (or device string) both tensors are moved to.

    Returns:
        A ``(next_obs, next_done)`` pair: the observations returned by
        ``envs.reset()`` as a tensor, and a zero tensor of length
        ``envs.num_envs``.
    """
    reset_result = envs.reset()
    # Newer Gym/Gymnasium releases return ``(observations, info_dict)`` from
    # ``reset()``; older releases return the observations alone. Unwrap only
    # when the result has exactly that two-element shape, so the older
    # behaviour is left untouched.
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        reset_result = reset_result[0]
    next_obs = torch.Tensor(reset_result).to(device)
    next_done = torch.zeros(envs.num_envs).to(device)
    return next_obs, next_done
|
| 480 |
|
| 481 |
|
| 482 |
-
# + id="5CoMpUVU7rFT"
|
| 483 |
def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
| 484 |
action_shape = envs.single_action_space.shape
|
| 485 |
assert action_shape is not None
|
|
@@ -489,7 +498,7 @@ def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
|
| 489 |
return action_shape
|
| 490 |
|
| 491 |
|
| 492 |
-
# + id="FHmn5kSUGFFu"
|
| 493 |
# %%
|
| 494 |
def train_ppo(args: PPOArgs):
|
| 495 |
t0 = int(time.time())
|
|
@@ -628,8 +637,11 @@ if MAIN:
|
|
| 628 |
args = PPOArgs()
|
| 629 |
train_ppo(args)
|
| 630 |
|
| 631 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="xJW6KL7QIj4s" outputId="7c529849-6d46-4a6a-def5-e1c0ef652c64"
|
| 632 |
# !python demo.py
|
| 633 |
|
| 634 |
-
# + id="P7ZfUlAqImIr"
|
|
|
|
|
|
|
|
|
|
| 635 |
|
|
|
|
| 11 |
# name: python3
|
| 12 |
# ---
|
| 13 |
|
| 14 |
+
# + id="QAY_RQOLcRtA" executionInfo={"status": "ok", "timestamp": 1677945244865, "user_tz": 0, "elapsed": 19712, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="be179435-1667-40af-8a80-7bc63a472715"
|
| 15 |
MAIN = __name__ == "__main__"
|
| 16 |
if MAIN:
|
| 17 |
print('Mounting drive...')
|
|
|
|
| 19 |
drive.mount('/content/drive')
|
| 20 |
# %cd /content/drive/MyDrive/Colab Notebooks/cartpole-demo
|
| 21 |
|
| 22 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="GgSNZRJh4EjV" executionInfo={"status": "ok", "timestamp": 1677945316689, "user_tz": 0, "elapsed": 57846, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="6aeb7bf3-e186-449d-cdc4-c66f778244b2"
|
| 23 |
# !pip install einops
|
| 24 |
# !pip install wandb
|
| 25 |
# !pip install jupytext
|
| 26 |
# !pip install pygame
|
| 27 |
# !pip install torchtyping
|
| 28 |
# !pip install gradio
|
| 29 |
+
# !pip install huggingface_hub
|
| 30 |
|
| 31 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="1g58HZUb8Ltl" executionInfo={"status": "ok", "timestamp": 1677945458077, "user_tz": 0, "elapsed": 16862, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="62ffc9cd-ff0b-4473-c17a-4593a14526cf"
|
| 32 |
+
# !git config --global credential.helper store
|
| 33 |
+
# !git config --global user.name "skar0"
|
| 34 |
# !git config --global user.email "[email protected]"
|
| 35 |
+
# !huggingface-cli login
|
|
|
|
| 36 |
# !jupytext --to py cartpole.ipynb
|
| 37 |
# !git fetch
|
| 38 |
+
# # !chmod +x .git/hooks/pre-push
|
| 39 |
# !git status
|
| 40 |
|
| 41 |
+
# + id="dYeFdxVIWOqc" executionInfo={"status": "ok", "timestamp": 1677945546175, "user_tz": 0, "elapsed": 318, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="5xFqBnKzVN60" executionInfo={"status": "ok", "timestamp": 1677945556589, "user_tz": 0, "elapsed": 7558, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="535e6c5e-17f6-4342-8a9d-ff54f4c82187"
|
| 45 |
+
# !git push
|
| 46 |
+
|
| 47 |
+
# + id="vEczQ48wC40O"
|
| 48 |
import os
|
| 49 |
import glob
|
| 50 |
import sys
|
|
|
|
| 75 |
# + id="K7T8bs1Y76ZK" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="f59ffef0-7156-4f27-d992-a392d59a1c73"
|
| 76 |
# %env "WANDB_NOTEBOOK_NAME" "cartpole.py"
|
| 77 |
|
| 78 |
+
# + id="Q5E93-BGRjuy"
|
| 79 |
def make_env(
|
| 80 |
env_id: str, seed: int, idx: int, capture_video: bool, run_name: str
|
| 81 |
):
|
|
|
|
| 102 |
return thunk
|
| 103 |
|
| 104 |
|
| 105 |
+
# + id="Kf152ROwHjM_"
|
| 106 |
def test_minibatch_indexes(minibatch_indexes):
|
| 107 |
for n in range(5):
|
| 108 |
frac, minibatch_size = np.random.randint(1, 8, size=(2,))
|
|
|
|
| 114 |
np.testing.assert_equal(np.sort(np.stack(indices).flatten()), np.arange(batch_size))
|
| 115 |
|
| 116 |
|
| 117 |
+
# + id="mhvduVeOHkln"
|
| 118 |
def test_calc_entropy_bonus(calc_entropy_bonus):
|
| 119 |
probs = Categorical(logits=t.randn((3, 4)))
|
| 120 |
ent_coef = 0.5
|
|
|
|
| 123 |
t.testing.assert_close(expected, actual)
|
| 124 |
|
| 125 |
|
| 126 |
+
# + id="Aya60GeCGA5X"
|
| 127 |
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
| 128 |
t.nn.init.orthogonal_(layer.weight, std)
|
| 129 |
t.nn.init.constant_(layer.bias, bias_const)
|
|
|
|
| 155 |
|
| 156 |
|
| 157 |
|
| 158 |
+
# + id="6PwPZHlLGDYu"
|
| 159 |
# %%
|
| 160 |
@t.inference_mode()
|
| 161 |
def compute_advantages(
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
|
| 202 |
+
# + id="uYSSMnF-GPvm"
|
| 203 |
# %%
|
| 204 |
@dataclass
|
| 205 |
class Minibatch:
|
|
|
|
| 261 |
|
| 262 |
|
| 263 |
|
| 264 |
+
# + id="K7wXDJ9MGOWu"
|
| 265 |
# %%
|
| 266 |
def calc_policy_loss(
|
| 267 |
probs: Categorical, mb_action: t.Tensor, mb_advantages: t.Tensor,
|
|
|
|
| 286 |
|
| 287 |
|
| 288 |
|
| 289 |
+
# + id="CmyxU6JWGMsG"
|
| 290 |
# %%
|
| 291 |
def calc_value_function_loss(
|
| 292 |
critic: nn.Sequential, mb_obs: t.Tensor, mb_returns: t.Tensor, v_coef: float
|
|
|
|
| 303 |
|
| 304 |
|
| 305 |
|
| 306 |
+
# + id="npyWs6xjGLkP"
|
| 307 |
# %%
|
| 308 |
def calc_entropy_loss(probs: Categorical, ent_coef: float):
|
| 309 |
'''Return the entropy loss term.
|
|
|
|
| 319 |
test_calc_entropy_bonus(calc_entropy_loss)
|
| 320 |
|
| 321 |
|
| 322 |
+
# + id="nqJeg1kZGKSG"
|
| 323 |
# %%
|
| 324 |
class PPOScheduler:
|
| 325 |
def __init__(self, optimizer: optim.Adam, initial_lr: float, end_lr: float, num_updates: int):
|
|
|
|
| 354 |
|
| 355 |
|
| 356 |
|
| 357 |
+
# + id="mgZ7-wsRCxJW"
|
| 358 |
@dataclass
|
| 359 |
class PPOArgs:
|
| 360 |
exp_name: str = 'cartpole.py'
|
|
|
|
| 382 |
minibatch_size: int = 128
|
| 383 |
|
| 384 |
|
| 385 |
+
# + id="xeIu-J3ZwGyq"
|
| 386 |
def wandb_init(name: str, args: PPOArgs):
|
| 387 |
wandb.init(
|
| 388 |
project=args.wandb_project_name,
|
|
|
|
| 396 |
)
|
| 397 |
|
| 398 |
|
| 399 |
+
# + id="gMYWqhsryYHy"
|
| 400 |
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The same integer is passed to ``random.seed``, ``np.random.seed`` and
    ``torch.manual_seed``.

    Args:
        seed: Value used to seed all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
| 404 |
|
| 405 |
|
| 406 |
+
# + id="T9j_L0Wpyrgz"
|
| 407 |
@typechecked
|
| 408 |
def rollout_phase(
|
| 409 |
next_obs: t.Tensor, next_done: t.Tensor,
|
|
|
|
| 481 |
)
|
| 482 |
|
| 483 |
|
| 484 |
+
# + id="xdDhABIk5jyb"
|
| 485 |
def reset_env(envs, device):
    """Reset the environments and return the initial observation/done tensors.

    Args:
        envs: Object exposing ``reset()`` and a ``num_envs`` attribute.
        device: Device (or device string) both tensors are moved to.

    Returns:
        A ``(next_obs, next_done)`` pair: the observations returned by
        ``envs.reset()`` as a tensor, and a zero tensor of length
        ``envs.num_envs``.
    """
    reset_result = envs.reset()
    # Newer Gym/Gymnasium releases return ``(observations, info_dict)`` from
    # ``reset()``; older releases return the observations alone. Unwrap only
    # when the result has exactly that two-element shape, so the older
    # behaviour is left untouched.
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        reset_result = reset_result[0]
    next_obs = torch.Tensor(reset_result).to(device)
    next_done = torch.zeros(envs.num_envs).to(device)
    return next_obs, next_done
|
| 489 |
|
| 490 |
|
| 491 |
+
# + id="5CoMpUVU7rFT"
|
| 492 |
def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
| 493 |
action_shape = envs.single_action_space.shape
|
| 494 |
assert action_shape is not None
|
|
|
|
| 498 |
return action_shape
|
| 499 |
|
| 500 |
|
| 501 |
+
# + id="FHmn5kSUGFFu"
|
| 502 |
# %%
|
| 503 |
def train_ppo(args: PPOArgs):
|
| 504 |
t0 = int(time.time())
|
|
|
|
| 637 |
args = PPOArgs()
|
| 638 |
train_ppo(args)
|
| 639 |
|
| 640 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="xJW6KL7QIj4s" executionInfo={"status": "ok", "timestamp": 1677942639015, "user_tz": 0, "elapsed": 105286, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="7c529849-6d46-4a6a-def5-e1c0ef652c64"
|
| 641 |
# !python demo.py
|
| 642 |
|
| 643 |
+
# + id="P7ZfUlAqImIr"
|
| 644 |
+
# !pip freeze > requirements.txt
|
| 645 |
+
|
| 646 |
+
# + id="x_bhyL3GLnhr"
|
| 647 |
|