Spaces:
Sleeping
Sleeping
File size: 6,204 Bytes
44b73f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Smoke test for zero-shot inference pipeline
Tests:
1. Data loading and preparation
2. Chronos 2 model loading
3. Inference on single border (7 days)
4. Output validation
5. Performance metrics
"""
import sys
from pathlib import Path
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from inference.data_fetcher import DataFetcher
from inference.chronos_pipeline import ChronosForecaster
from datetime import datetime, timedelta
import torch
import pandas as pd
# Horizon (hours) forecast by the smoke test, and the longer horizon (hours)
# that the measured inference time is linearly extrapolated to.
PREDICTION_LENGTH_HOURS = 168
TARGET_HORIZON_HOURS = 336
# Wall-clock budget (seconds) for a forecast over TARGET_HORIZON_HOURS.
TARGET_RUNTIME_SEC = 300
# Exclusive upper plausibility bound (MW) for the mean forecast.
MAX_REASONABLE_MW = 20000


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used instead of the ``assert`` statement so the checks still run when
    Python is started with ``-O`` (which removes ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def main():
    """Run the end-to-end smoke test for the zero-shot inference pipeline.

    Steps, in order: report the PyTorch/CUDA environment, load data through
    ``DataFetcher``, prepare context/future frames for the first target
    border, validate them, load the model through ``ChronosForecaster``,
    forecast ``PREDICTION_LENGTH_HOURS`` hours, sanity-check the forecast,
    benchmark inference, and save the forecast to a parquet path.

    Progress is reported with ``print``; nothing is returned.

    Raises:
        AssertionError: if any data or forecast validation check fails.
    """
    forecast_days = PREDICTION_LENGTH_HOURS // 24
    target_days = TARGET_HORIZON_HOURS // 24
    budget_minutes = TARGET_RUNTIME_SEC // 60

    print("="*60)
    print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test")
    print("="*60)

    # Step 1: Check environment
    print("\n[1] Checking environment...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("Running on CPU (inference will be slower)")

    # Step 2: Initialize DataFetcher
    print("\n[2] Initializing DataFetcher...")
    fetcher = DataFetcher(
        use_local=True,  # Use local files for testing
        context_length=512  # Use 512 hours context
    )

    # Step 3: Load data
    print("\n[3] Loading unified features...")
    fetcher.load_data()

    # Get available date range
    min_date, max_date = fetcher.get_available_dates()
    print(f"Available data: {min_date} to {max_date}")

    # Select forecast date: 30 days before the last available date
    forecast_date = max_date - timedelta(days=30)
    print(f"Test forecast date: {forecast_date}")

    # Step 4: Prepare inference data (single border)
    print(f"\n[4] Preparing inference data (1 border, {forecast_days} days)...")
    test_border = fetcher.target_borders[0]  # Use first border
    print(f"Test border: {test_border}")
    context_df, future_df = fetcher.prepare_inference_data(
        forecast_date=forecast_date,
        prediction_length=PREDICTION_LENGTH_HOURS,
        borders=[test_border]
    )
    print(f"Context shape: {context_df.shape}")
    print(f"Future shape: {future_df.shape}")

    # Step 5: Validate data
    print("\n[5] Validating prepared data...")
    _require('timestamp' in context_df.columns, "Missing timestamp column")
    _require('border' in context_df.columns, "Missing border column")
    _require('target' in context_df.columns, "Missing target column")
    _require(len(context_df) > 0, "Empty context data")
    _require(len(future_df) > 0, "Empty future data")
    print("[+] Data validation passed!")

    # Check for NaN values (reported as a warning only, not a failure)
    context_nulls = context_df.isnull().sum().sum()
    future_nulls = future_df.isnull().sum().sum()
    print(f"Context NaN count: {context_nulls}")
    print(f"Future NaN count: {future_nulls}")
    if context_nulls > 0 or future_nulls > 0:
        print("[!] Warning: Data contains NaN values (will be handled by model)")

    # Step 6: Initialize Chronos 2 forecaster
    print("\n[6] Initializing Chronos 2 forecaster...")
    forecaster = ChronosForecaster(
        model_name="amazon/chronos-2-large",
        device="auto"  # Will use GPU if available
    )

    # Step 7: Load model
    print("\n[7] Loading Chronos 2 Large model...")
    print("(This may take a few minutes on first load)")
    forecaster.load_model()
    print("[+] Model loaded successfully!")

    # Step 8: Run inference
    print("\n[8] Running zero-shot inference...")
    print(f"Forecasting {test_border} for {forecast_days} days ({PREDICTION_LENGTH_HOURS} hours)")
    forecasts = forecaster.predict_single_border(
        border=test_border,
        context_df=context_df,
        future_df=future_df,
        prediction_length=PREDICTION_LENGTH_HOURS,
        num_samples=100  # 100 samples for probabilistic forecast
    )
    print(f"[+] Inference complete! Forecast shape: {forecasts.shape}")

    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    _require(len(forecasts) > 0, "Empty forecasts")
    _require(
        'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp',
        "Missing timestamp"
    )

    # Check for reasonable values (only possible when a 'mean' column exists)
    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f" Mean: {mean_forecast.mean():.2f} MW")
        print(f" Min: {mean_forecast.min():.2f} MW")
        print(f" Max: {mean_forecast.max():.2f} MW")
        print(f" Std: {mean_forecast.std():.2f} MW")
        # Sanity check: values should be reasonable for power capacity
        _require(mean_forecast.min() >= 0, "Negative forecasts detected")
        _require(mean_forecast.max() < MAX_REASONABLE_MW, "Unreasonably high forecasts")
    # NOTE(review): printed whether or not a 'mean' column was present;
    # confirm this matches the intended nesting.
    print("[+] Forecast validation passed!")

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=PREDICTION_LENGTH_HOURS
    )
    print("Performance metrics:")
    for key, value in metrics.items():
        print(f" {key}: {value}")

    # Linearly scale the measured time to the target horizon and compare it
    # against the runtime budget.
    estimated_target_time = metrics['inference_time_sec'] * (
        TARGET_HORIZON_HOURS / PREDICTION_LENGTH_HOURS
    )
    print(
        f"\nEstimated time for {target_days}-day forecast: "
        f"{estimated_target_time:.1f}s ({estimated_target_time/60:.1f} min)"
    )
    if estimated_target_time < TARGET_RUNTIME_SEC:
        print(f"[+] Performance target met! (<{budget_minutes} min for {target_days} days)")
    else:
        print(f"[!] Warning: May not meet {budget_minutes}-minute target for {target_days} days")

    # Step 11: Save test forecasts
    print("\n[11] Saving test forecasts...")
    output_path = "data/evaluation/smoke_test_forecast.parquet"
    forecaster.save_forecasts(forecasts, output_path)
    print(f"[+] Saved to: {output_path}")

    # Summary
    print("\n" + "="*60)
    print("SMOKE TEST SUMMARY")
    print("="*60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print(f"[+] Forecast length: {PREDICTION_LENGTH_HOURS} hours ({forecast_days} days)")
    print(f"[+] Inference time: {metrics['inference_time_sec']:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("="*60)
# Run the smoke test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|