Evgueni Poloukarov Claude committed on
Commit c85b8a5 · 1 Parent(s): ef3410d

fix: add PyTorch memory allocator config to prevent fragmentation

Root Cause Analysis:
- 256h context reduction DID work: PyTorch allocation dropped from 20.42 GB to 7.98 GB
- OOM persisted due to memory fragmentation: 12.61 GB reserved but unallocated
- Error: Could not allocate 12.44 GB contiguous block for attention matrix
- PyTorch error message: "try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
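
The fragmentation failure described above can be illustrated with a toy model. This is a sketch, not PyTorch's caching allocator: the individual free-block sizes are hypothetical and chosen only so that they sum to the 12.61 GB reported as reserved but unallocated.

```python
# Toy illustration only: this is NOT how PyTorch's caching allocator is implemented.
# The individual block sizes are hypothetical; only their total (12.61 GB) and the
# request size (12.44 GB) come from the commit message.

def can_allocate(free_blocks_gb, request_gb):
    """A request succeeds only if one single free block is large enough."""
    return any(block >= request_gb for block in free_blocks_gb)

free_blocks_gb = [4.00, 3.50, 3.11, 2.00]  # hypothetical fragmented free list
request_gb = 12.44

total_free_gb = round(sum(free_blocks_gb), 2)
largest_block_gb = max(free_blocks_gb)

print(f"total free: {total_free_gb} GB, largest block: {largest_block_gb} GB")
print("fits in total free memory:", total_free_gb >= request_gb)
print("fits in one contiguous block:", can_allocate(free_blocks_gb, request_gb))
```

In this model the request is smaller than the total free memory but larger than any single block, which matches the symptom of an OOM error despite sufficient reserved memory.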

Fix Applied:
- Set os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' BEFORE importing torch
- This enables PyTorch's expandable memory segments feature
- Reduces fragmentation by letting the allocator grow an existing segment on demand instead of reserving separate fixed-size blocks that cannot be merged
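
A slightly more defensive variant of the fix is sketched below. The helper name `enable_expandable_segments` is hypothetical and not part of this commit. The sketch assumes `PYTORCH_CUDA_ALLOC_CONF` holds comma-separated `key:value` options, keeps any options that are already set, and warns if torch appears to be imported already.

```python
import os
import sys
import warnings

def enable_expandable_segments(env=os.environ):
    """Hypothetical helper: add expandable_segments:True, keeping other options."""
    if "torch" in sys.modules:
        # The commit sets the variable before importing torch; warn if that order is violated.
        warnings.warn("torch is already imported; the allocator config may be ignored")
    existing = env.get("PYTORCH_CUDA_ALLOC_CONF", "")
    options = [opt.strip() for opt in existing.split(",") if opt.strip()]
    # Drop any earlier expandable_segments entry so the new value wins.
    options = [opt for opt in options if not opt.startswith("expandable_segments:")]
    options.append("expandable_segments:True")
    env["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(options)
    return env["PYTORCH_CUDA_ALLOC_CONF"]

# Demonstration on a plain dict so the real process environment is untouched.
demo_env = {"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128"}
print(enable_expandable_segments(demo_env))
# In the real module this call would come first, followed by `import torch`.
```

Merging rather than overwriting matters only if the deployment environment already sets other allocator options. The plain assignment used in this commit is sufficient otherwise.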

Expected Outcome:
- The 12.61 GB of reserved memory can now be used for the 12.44 GB attention allocation
- Should fit within the L4's 24 GB of VRAM: 7.98 GB + 12.44 GB = 20.42 GB, leaving about 3.58 GB of headroom
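
The budget arithmetic above can be checked with a few lines. This is a sketch using only the figures from this commit message. It treats the full nominal 24 GB as available, which is an optimistic assumption since the CUDA context and other processes also consume VRAM.

```python
# Figures taken from the commit message; 24 GB is the nominal L4 capacity.
ALLOCATED_GB = 7.98   # PyTorch allocation after the context reduction
ATTENTION_GB = 12.44  # block requested for the attention matrix
VRAM_GB = 24.0        # assumes the whole card is available (optimistic)

required_gb = round(ALLOCATED_GB + ATTENTION_GB, 2)
headroom_gb = round(VRAM_GB - required_gb, 2)

print(f"required: {required_gb} GB, headroom: {headroom_gb} GB")
# prints: required: 20.42 GB, headroom: 3.58 GB
```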

Testing Plan:
1. Deploy to HF Space
2. Run smoke test (AT_CZ border, 7-day forecast)
3. If successful: Run full Oct 1-7 forecast (7 borders)
4. Calculate D+1 MAE and verify <150 MW threshold
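
Step 4 could be checked along the following lines. This is a minimal sketch: the function name, the sample values, and the list-based input are hypothetical, and only the 150 MW threshold comes from the plan above.

```python
MAE_THRESHOLD_MW = 150.0  # threshold from the testing plan

def mean_absolute_error(forecast_mw, actual_mw):
    """Mean absolute error between two equal-length sequences of MW values."""
    if len(forecast_mw) != len(actual_mw) or len(forecast_mw) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    total = sum(abs(f - a) for f, a in zip(forecast_mw, actual_mw))
    return total / len(forecast_mw)

# Hypothetical D+1 values for one border, for illustration only.
forecast_mw = [1200.0, 1150.0, 980.0, 1010.0]
actual_mw = [1100.0, 1200.0, 1000.0, 900.0]

mae = mean_absolute_error(forecast_mw, actual_mw)
print(f"D+1 MAE: {mae:.1f} MW, below threshold: {mae < MAE_THRESHOLD_MW}")
```

For the full run, the same check would be applied per border to the D+1 horizon of the Oct 1-7 forecast.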

References:
- PyTorch memory management: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
- Chronos-2 paper: Context window directly affects attention memory (O(n²))

Co-Authored-By: Claude <[email protected]>

src/forecasting/chronos_inference.py CHANGED
@@ -10,6 +10,12 @@ import os
 import time
 from typing import List, Dict, Optional
 from datetime import datetime, timedelta
+
+# CRITICAL: Set PyTorch memory allocator config BEFORE importing torch
+# This prevents memory fragmentation issues that cause OOM even with sufficient free memory
+# See: https://pytorch.org/docs/stable/notes/cuda.html#environment-variables
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
 import polars as pl
 import pandas as pd
 import numpy as np