NariLabs committed on
Commit
a087083
·
verified ·
1 Parent(s): 04d9952

Update dia2/runtime/context.py

Browse files
Files changed (1) hide show
  1. dia2/runtime/context.py +15 -9
dia2/runtime/context.py CHANGED
@@ -46,7 +46,6 @@ def build_runtime(
46
  device_obj = torch.device(device)
47
  if device_obj.type == "cuda":
48
  cuda_matmul = torch.backends.cuda.matmul
49
- cudnn_conv = torch.backends.cudnn.conv
50
  if hasattr(cuda_matmul, "fp32_precision"):
51
  cuda_matmul.fp32_precision = "tf32"
52
  with warnings.catch_warnings():
@@ -57,15 +56,22 @@ def build_runtime(
57
  torch.backends.cuda.matmul.allow_tf32 = True
58
  else: # pragma: no cover - compatibility with older PyTorch
59
  torch.backends.cuda.matmul.allow_tf32 = True
60
- if hasattr(cudnn_conv, "fp32_precision"):
61
- cudnn_conv.fp32_precision = "tf32"
62
- with warnings.catch_warnings():
63
- warnings.filterwarnings(
64
- "ignore",
65
- message="Please use the new API settings",
66
- )
 
 
 
 
 
 
67
  torch.backends.cudnn.allow_tf32 = True
68
- else: # pragma: no cover
 
69
  torch.backends.cudnn.allow_tf32 = True
70
  precision = resolve_precision(dtype_pref, device_obj)
71
  config = load_config(config_path)
 
46
  device_obj = torch.device(device)
47
  if device_obj.type == "cuda":
48
  cuda_matmul = torch.backends.cuda.matmul
 
49
  if hasattr(cuda_matmul, "fp32_precision"):
50
  cuda_matmul.fp32_precision = "tf32"
51
  with warnings.catch_warnings():
 
56
  torch.backends.cuda.matmul.allow_tf32 = True
57
  else: # pragma: no cover - compatibility with older PyTorch
58
  torch.backends.cuda.matmul.allow_tf32 = True
59
+
60
+ # Handle cuDNN conv TF32 settings (check if conv attribute exists first)
61
+ if hasattr(torch.backends.cudnn, "conv"):
62
+ cudnn_conv = torch.backends.cudnn.conv
63
+ if hasattr(cudnn_conv, "fp32_precision"):
64
+ cudnn_conv.fp32_precision = "tf32"
65
+ with warnings.catch_warnings():
66
+ warnings.filterwarnings(
67
+ "ignore",
68
+ message="Please use the new API settings",
69
+ )
70
+ torch.backends.cudnn.allow_tf32 = True
71
+ else:
72
  torch.backends.cudnn.allow_tf32 = True
73
+ else:
74
+ # For older PyTorch versions without the conv attribute
75
  torch.backends.cudnn.allow_tf32 = True
76
  precision = resolve_precision(dtype_pref, device_obj)
77
  config = load_config(config_path)