Commit 2d135b5 · 1 parent: dc9b9db
Committed by Evgueni Poloukarov and Claude

fix: implement sub-batching to avoid CUDA OOM on T4 GPU


Problem:
- A batch of 38 borders requires 762 MB of GPU memory
- The T4 GPU has only 534 MB free after model load (14.22 GB used)
- Result: CUDA out-of-memory error

Solution:
- Process borders in sub-batches of 10 (4 sub-batches total; see the sizing check below)
- Clear the GPU cache between sub-batches
- Still far faster than sequential processing (4 batches of 10 vs 38 calls of 1)
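A rough sizing check (illustrative arithmetic only, using the figures above; actual memory use also depends on context length, prediction length, and num_samples):

    # Back-of-the-envelope check that a sub-batch of 10 fits in the free GPU memory.
    # Figures come from the problem description above; this is an estimate, not a measurement.
    full_batch_mb = 762                     # observed for a batch of 38 borders
    free_mb = 534                           # free on the T4 after the model is loaded
    per_border_mb = full_batch_mb / 38      # ~20 MB per border
    sub_batch_mb = 10 * per_border_mb       # ~200 MB for a sub-batch of 10
    assert sub_batch_mb < free_mb           # comfortably within the 534 MB headroom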

Implementation:
- Split the contexts into sub-batches of SUB_BATCH_SIZE=10
- Process each sub-batch independently
- Store all forecasts, then compute the quantiles afterwards
- Expected time: ~8-10 seconds (vs ~60 min sequential)

This balances the GPU memory constraint against the batch-processing speedup; a minimal sketch of the loop follows below.
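For reference, a minimal sketch of the chunking pattern (the helper name and standalone form are illustrative, not the committed code; the predict() call mirrors the one in the diff below, and the real code also carries the per-border names alongside the contexts):

    import torch

    def predict_in_sub_batches(pipeline, contexts, prediction_hours, num_samples, sub_batch_size=10):
        """Run pipeline.predict over `contexts` in chunks to cap peak GPU memory."""
        forecasts = []
        num_sub_batches = (len(contexts) + sub_batch_size - 1) // sub_batch_size  # ceil division
        for batch_idx in range(num_sub_batches):
            chunk = contexts[batch_idx * sub_batch_size:(batch_idx + 1) * sub_batch_size]
            batch_tensor = torch.stack(chunk)  # (chunk_size, context_hours)
            forecasts.append(pipeline.predict(
                inputs=batch_tensor,
                prediction_length=prediction_hours,
                num_samples=num_samples,
            ))
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks before the next chunk
        return forecasts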

Co-Authored-By: Claude <[email protected]>

Files changed (1):
  src/forecasting/chronos_inference.py (+113, -77)

src/forecasting/chronos_inference.py:
@@ -159,10 +159,13 @@ class ChronosInferencePipeline:
 
         total_start = time.time()
 
-        # BATCH INFERENCE: Collect all contexts first
+        # SUB-BATCH INFERENCE: Process borders in chunks to fit GPU memory
+        # T4 GPU has 14.74 GB total, model uses ~14 GB, so we need small batches
+        SUB_BATCH_SIZE = 10  # Process 10 borders at a time
+
         print(f"\n[BATCH] Preparing contexts for {len(forecast_borders)} borders...")
-        batch_contexts = []
-        border_names = []
+        all_contexts = []
+        all_border_names = []
 
         for i, border in enumerate(forecast_borders, 1):
             print(f" [{i}/{len(forecast_borders)}] Extracting context for {border}...", flush=True)
@@ -178,8 +181,8 @@ class ChronosInferencePipeline:
 
                 # Extract context values and convert to PyTorch tensor
                 context = torch.from_numpy(context_data[target_col].values).float()
-                batch_contexts.append(context)
-                border_names.append(border)
+                all_contexts.append(context)
+                all_border_names.append(border)
 
             except Exception as e:
                 import traceback
@@ -188,83 +191,116 @@ class ChronosInferencePipeline:
                 print(f" [ERROR] {border}: {error_msg}", flush=True)
                 results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
 
-        # Stack all contexts into a batch
-        if batch_contexts:
-            batch_tensor = torch.stack(batch_contexts)  # Shape: (num_borders, context_hours)
-            print(f"\n[BATCH] Running inference on batch of {batch_tensor.shape[0]} borders...")
-            print(f"[BATCH] Batch shape: {batch_tensor.shape}", flush=True)
+        # Process contexts in sub-batches
+        if all_contexts:
+            num_contexts = len(all_contexts)
+            num_sub_batches = (num_contexts + SUB_BATCH_SIZE - 1) // SUB_BATCH_SIZE
 
-            inference_start = time.time()
+            print(f"\n[BATCH] Running inference in {num_sub_batches} sub-batches of {SUB_BATCH_SIZE} borders...")
 
-            # Run batch inference
-            batch_forecasts = pipeline.predict(
-                inputs=batch_tensor,  # Chronos API uses 'inputs'
-                prediction_length=prediction_hours,
-                num_samples=num_samples
-            )
+            all_forecasts = []
+            total_inference_time = 0
+
+            for batch_idx in range(num_sub_batches):
+                start_idx = batch_idx * SUB_BATCH_SIZE
+                end_idx = min(start_idx + SUB_BATCH_SIZE, num_contexts)
+
+                # Get sub-batch
+                sub_batch_contexts = all_contexts[start_idx:end_idx]
+                sub_batch_names = all_border_names[start_idx:end_idx]
+
+                batch_tensor = torch.stack(sub_batch_contexts)
+                print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Processing {len(sub_batch_names)} borders: {sub_batch_names[0]} ... {sub_batch_names[-1]}", flush=True)
+                print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Batch shape: {batch_tensor.shape}", flush=True)
+
+                inference_start = time.time()
+
+                # Run batch inference
+                batch_forecasts = pipeline.predict(
+                    inputs=batch_tensor,
+                    prediction_length=prediction_hours,
+                    num_samples=num_samples
+                )
+
+                inference_time = time.time() - inference_start
+                total_inference_time += inference_time
+                print(f"[BATCH {batch_idx+1}/{num_sub_batches}] Complete in {inference_time:.1f}s ({inference_time/len(sub_batch_names):.2f}s per border)", flush=True)
+
+                # Store forecasts
+                all_forecasts.append(batch_forecasts)
+
+                # Clear GPU cache between sub-batches
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
-            inference_time = time.time() - inference_start
-            print(f"[BATCH] Inference complete in {inference_time:.1f}s ({inference_time/len(border_names):.2f}s per border)")
-            print(f"[BATCH] Forecast shape: {batch_forecasts.shape}", flush=True)
+            print(f"\n[BATCH] All inference complete in {total_inference_time:.1f}s total")
+            print(f"[BATCH] Average: {total_inference_time/num_contexts:.2f}s per border")
 
             # Process each border's forecast
-            for i, border in enumerate(border_names):
-                print(f"\n[{i+1}/{len(border_names)}] Processing forecast for {border}...", flush=True)
-                border_start = time.time()
-
-                try:
-                    # Extract this border's forecast from batch
-                    forecast = batch_forecasts[i]  # Extract from batch dimension
-
-                    # Calculate quantiles
-                    forecast_numpy = forecast.numpy()
-                    print(f"[DEBUG] Raw forecast shape: {forecast_numpy.shape}", flush=True)
-
-                    # Chronos may return (batch, num_samples, time) or (num_samples, time)
-                    # Squeeze any batch dimension (if present)
-                    if forecast_numpy.ndim == 3:
-                        print(f"[DEBUG] 3D forecast detected, squeezing batch dimension", flush=True)
-                        forecast_numpy = forecast_numpy.squeeze(axis=0)  # Remove batch dim
-
-                    print(f"[DEBUG] Forecast shape after squeeze: {forecast_numpy.shape}, Expected: ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})", flush=True)
-
-                    # Now forecast should be 2D: either (num_samples, time) or (time, num_samples)
-                    # Compute median along samples axis to get (time,) shape
-                    if forecast_numpy.shape[0] == num_samples and forecast_numpy.shape[1] == prediction_hours:
-                        # Shape is (num_samples, time) - use axis=0
-                        print(f"[DEBUG] Using axis=0 for shape (num_samples={num_samples}, time={prediction_hours})", flush=True)
-                        median = np.median(forecast_numpy, axis=0)
-                        q10 = np.quantile(forecast_numpy, 0.1, axis=0)
-                        q90 = np.quantile(forecast_numpy, 0.9, axis=0)
-                    elif forecast_numpy.shape[0] == prediction_hours and forecast_numpy.shape[1] == num_samples:
-                        # Shape is (time, num_samples) - use axis=1
-                        print(f"[DEBUG] Using axis=1 for shape (time={prediction_hours}, num_samples={num_samples})", flush=True)
-                        median = np.median(forecast_numpy, axis=1)
-                        q10 = np.quantile(forecast_numpy, 0.1, axis=1)
-                        q90 = np.quantile(forecast_numpy, 0.9, axis=1)
-                    else:
-                        raise ValueError(f"Unexpected forecast shape: {forecast_numpy.shape}, expected ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})")
-
-                    print(f"[DEBUG] Final median shape: {median.shape}, Expected: ({prediction_hours},)", flush=True)
-                    assert median.shape == (prediction_hours,), f"Median shape {median.shape} != expected ({prediction_hours},)"
-
-                    # Store results
-                    results['borders'][border] = {
-                        'median': median.tolist(),
-                        'q10': q10.tolist(),
-                        'q90': q90.tolist(),
-                        'inference_time_s': time.time() - border_start
-                    }
-
-                    print(f" [OK] Complete in {time.time() - border_start:.1f}s")
-
-                except Exception as e:
-                    import traceback
-                    error_msg = f"{type(e).__name__}: {str(e)}"
-                    traceback_str = traceback.format_exc()
-                    print(f" [ERROR] {error_msg}", flush=True)
-                    print(f"Traceback:\n{traceback_str}", flush=True)
-                    results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
+            forecast_idx = 0
+            for batch_idx, batch_forecasts in enumerate(all_forecasts):
+                start_idx = batch_idx * SUB_BATCH_SIZE
+                end_idx = min(start_idx + SUB_BATCH_SIZE, num_contexts)
+                sub_batch_names = all_border_names[start_idx:end_idx]
+
+                for i, border in enumerate(sub_batch_names):
+                    forecast_idx += 1
+                    print(f"\n[{forecast_idx}/{num_contexts}] Processing forecast for {border}...", flush=True)
+                    border_start = time.time()
+
+                    try:
+                        # Extract this border's forecast from batch
+                        forecast = batch_forecasts[i]  # Extract from batch dimension
+
+                        # Calculate quantiles
+                        forecast_numpy = forecast.numpy()
+                        print(f"[DEBUG] Raw forecast shape: {forecast_numpy.shape}", flush=True)
+
+                        # Chronos may return (batch, num_samples, time) or (num_samples, time)
+                        # Squeeze any batch dimension (if present)
+                        if forecast_numpy.ndim == 3:
+                            print(f"[DEBUG] 3D forecast detected, squeezing batch dimension", flush=True)
+                            forecast_numpy = forecast_numpy.squeeze(axis=0)  # Remove batch dim
+
+                        print(f"[DEBUG] Forecast shape after squeeze: {forecast_numpy.shape}, Expected: ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})", flush=True)
+
+                        # Now forecast should be 2D: either (num_samples, time) or (time, num_samples)
+                        # Compute median along samples axis to get (time,) shape
+                        if forecast_numpy.shape[0] == num_samples and forecast_numpy.shape[1] == prediction_hours:
+                            # Shape is (num_samples, time) - use axis=0
+                            print(f"[DEBUG] Using axis=0 for shape (num_samples={num_samples}, time={prediction_hours})", flush=True)
+                            median = np.median(forecast_numpy, axis=0)
+                            q10 = np.quantile(forecast_numpy, 0.1, axis=0)
+                            q90 = np.quantile(forecast_numpy, 0.9, axis=0)
+                        elif forecast_numpy.shape[0] == prediction_hours and forecast_numpy.shape[1] == num_samples:
+                            # Shape is (time, num_samples) - use axis=1
+                            print(f"[DEBUG] Using axis=1 for shape (time={prediction_hours}, num_samples={num_samples})", flush=True)
+                            median = np.median(forecast_numpy, axis=1)
+                            q10 = np.quantile(forecast_numpy, 0.1, axis=1)
+                            q90 = np.quantile(forecast_numpy, 0.9, axis=1)
+                        else:
+                            raise ValueError(f"Unexpected forecast shape: {forecast_numpy.shape}, expected ({num_samples}, {prediction_hours}) or ({prediction_hours}, {num_samples})")
+
+                        print(f"[DEBUG] Final median shape: {median.shape}, Expected: ({prediction_hours},)", flush=True)
+                        assert median.shape == (prediction_hours,), f"Median shape {median.shape} != expected ({prediction_hours},)"
+
+                        # Store results
+                        results['borders'][border] = {
+                            'median': median.tolist(),
+                            'q10': q10.tolist(),
+                            'q90': q90.tolist(),
+                            'inference_time_s': time.time() - border_start
+                        }
+
+                        print(f" [OK] Complete in {time.time() - border_start:.1f}s")
+
+                    except Exception as e:
+                        import traceback
+                        error_msg = f"{type(e).__name__}: {str(e)}"
+                        traceback_str = traceback.format_exc()
+                        print(f" [ERROR] {error_msg}", flush=True)
+                        print(f"Traceback:\n{traceback_str}", flush=True)
+                        results['borders'][border] = {'error': error_msg, 'traceback': traceback_str}
 
         # Add summary metadata
         results['metadata']['total_time_s'] = time.time() - total_start