Evgueni Poloukarov and Claude committed
Commit c8d76da · 1 Parent(s): 572e6a8

perf: switch to bfloat16 precision for memory efficiency

Changes:
- Default dtype: float32 → bfloat16 (50% reduction in weight memory)
- Model memory: 16GB → ~8GB expected
- Enables 615-feature inference on L4 GPU (24GB VRAM)
- torch.inference_mode() + model.eval() + bfloat16 = full optimization stack (see the sketch below)
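
A minimal sketch of how the three pieces combine. The nn.Linear module is a
stand-in for the chronos-2 model, which the pipeline loads for real; nothing
below is this repo's code:

    import torch

    # Stand-in model; the actual pipeline loads chronos-2 weights instead
    model = torch.nn.Linear(512, 64, dtype=torch.bfloat16, device="cuda")
    model.eval()                  # disable dropout / other training-mode behavior

    with torch.inference_mode():  # no autograd graph, no gradient buffers kept
        batch = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda")
        out = model(batch)        # forward pass allocates activations only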

Memory calculation:
- Model (bfloat16): ~8GB
- Attention forward pass: 12.44GB
- Total: ~8GB + 12.44GB ≈ 20.5GB < 24GB L4 capacity (arithmetic check below)
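
A quick check of that budget. The ~4B parameter count is inferred from the
16GB float32 footprint, not a measured value:

    n_params = 16e9 / 4              # float32 footprint => ~4e9 parameters
    weights_gb = n_params * 2 / 1e9  # bfloat16: 2 bytes/param -> ~8.0 GB
    total_gb = weights_gb + 12.44    # plus the profiled attention forward pass
    assert total_gb < 24             # ~20.44 GB fits the L4's 24 GB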

Related commits: 572e6a8 (torch.inference_mode + model.eval)

Co-Authored-By: Claude <[email protected]>

src/forecasting/chronos_inference.py CHANGED
@@ -32,7 +32,7 @@ class ChronosInferencePipeline:
         self,
         model_name: str = "amazon/chronos-2",
         device: str = "cuda",
-        dtype: str = "float32"
+        dtype: str = "bfloat16"
     ):
         """
         Initialize inference pipeline.
@@ -40,7 +40,7 @@ class ChronosInferencePipeline:
         Args:
             model_name: HuggingFace model identifier (chronos-2 supports covariates)
             device: Device for inference ('cuda' or 'cpu')
-            dtype: Data type for model weights (float32 for chronos-2)
+            dtype: Data type for model weights (bfloat16 for memory efficiency)
         """
         self.model_name = model_name
         self.device = device
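
For reference, a usage sketch of the new default. The import path and the
string→torch-dtype resolution are assumptions, not code from this diff:

    import torch
    from forecasting.chronos_inference import ChronosInferencePipeline  # path assumed

    pipeline = ChronosInferencePipeline(
        model_name="amazon/chronos-2",
        device="cuda",
        dtype="bfloat16",  # new default; pass "float32" to restore old behavior
    )

    # The dtype string presumably resolves to a torch dtype along these lines:
    torch_dtype = getattr(torch, "bfloat16")  # -> torch.bfloat16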