Spaces:
Runtime error
Runtime error
da03
commited on
Commit
·
29a0aca
1
Parent(s):
dba2df7
online_data_generation.py
CHANGED
|
@@ -541,8 +541,8 @@ def main():
|
|
| 541 |
posterior = autoencoder.encode(padding_tensor)
|
| 542 |
latent = posterior.sample()
|
| 543 |
latent = torch.zeros_like(latent).squeeze(0)
|
| 544 |
-
np.save(os.path.join(OUTPUT_DIR, 'padding.npy
|
| 545 |
-
os.rename(os.path.join(OUTPUT_DIR, 'padding.npy
|
| 546 |
# Initialize database
|
| 547 |
initialize_database()
|
| 548 |
|
|
|
|
| 541 |
posterior = autoencoder.encode(padding_tensor)
|
| 542 |
latent = posterior.sample()
|
| 543 |
latent = torch.zeros_like(latent).squeeze(0)
|
| 544 |
+
np.save(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), latent.cpu().numpy())
|
| 545 |
+
os.rename(os.path.join(OUTPUT_DIR, 'padding.tmp.npy'), os.path.join(OUTPUT_DIR, 'padding.npy'))
|
| 546 |
# Initialize database
|
| 547 |
initialize_database()
|
| 548 |
|