rahul7star committed on
Commit 5479a3e Β· verified Β· 1 Parent(s): 88e95a7

Update app.py

Files changed (1)
  1. app.py +31 -35
app.py CHANGED
@@ -8,7 +8,12 @@ import contextlib
 # βš™οΈ Setup
 # ============================================================
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-flash_attn = get_kernel("kernels-community/vllm-flash-attn3")
+
+if torch.cuda.is_available():
+    flash_attn = get_kernel("kernels-community/vllm-flash-attn3")
+    flash_attn_func = flash_attn.flash_attn_func
+else:
+    flash_attn_func = None  # CPU fallback
 
 # ============================================================
 # 🧠 Reference attention (PyTorch SDPA)
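The hunk above makes the kernel fetch conditional so the Space can still start without a GPU. Below is a small standalone sketch of the same guard for local experimentation; the `from kernels import get_kernel` import is assumed to already exist near the top of app.py (only the tail of the setup block is visible in this hunk), and the `getattr` probe is added here purely as an illustrative safety check, not something the commit contains.

import torch

flash_attn_func = None  # CPU fallback, mirroring the diff
if torch.cuda.is_available():
    from kernels import get_kernel  # assumed to match app.py's existing import
    flash_attn = get_kernel("kernels-community/vllm-flash-attn3")
    # The new code reads the function as an attribute instead of the removed
    # dict-style flash_attn["flash_attn_func"] lookup.
    flash_attn_func = getattr(flash_attn, "flash_attn_func", None)

print("FlashAttention kernel available:", flash_attn_func is not None)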
@@ -19,20 +24,6 @@ def reference_attention(query, key, value, causal=False):
     out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
     return out.transpose(1, 2).contiguous()
 
-def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
-    batch_size = cu_seqlens_q.shape[0] - 1
-    total_tokens_q = q.shape[0]
-    out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
-    for b in range(batch_size):
-        start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
-        start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
-        q_slice = q[start_q:end_q].unsqueeze(0)
-        k_slice = k[start_k:end_k].unsqueeze(0)
-        v_slice = v[start_k:end_k].unsqueeze(0)
-        attn_out = reference_attention(q_slice, k_slice, v_slice, causal=causal)
-        out[start_q:end_q] = attn_out.squeeze(0)
-    return out
-
 # ============================================================
 # πŸŽ›οΈ Test function
 # ============================================================
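For context on the `var_reference_attention` helper deleted above: it handled packed variable-length batches, where `cu_seqlens_q`/`cu_seqlens_k` are cumulative sequence lengths and each batch element is a slice of the packed token dimension. A short illustrative sketch of that convention (the tensor names and sizes here are invented for the example, not taken from app.py):

import torch

# Two packed sequences of lengths 3 and 5: total_tokens = 8,
# and cu_seqlens marks the boundaries [0, 3, 8].
H, D = 4, 8
seq_lens = [3, 5]
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
q_packed = torch.randn(sum(seq_lens), H, D)

# The removed helper recovered each per-sequence view exactly like this
# before calling reference_attention on it:
for b in range(len(seq_lens)):
    start, end = cu_seqlens[b], cu_seqlens[b + 1]
    q_b = q_packed[start:end].unsqueeze(0)  # shape (1, seq_len_b, H, D)
    print(b, tuple(q_b.shape))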
@@ -48,25 +39,30 @@ def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
     print(f"Running FlashAttention Tests on device: {device}")
     print(f"Input shape: B={B}, S={S}, H={H}, D={D}\n")
 
-    # Standard attention
+    # Always run PyTorch reference attention
     out_ref = reference_attention(q, k, v)
-    try:
-        out_flash = flash_attn["flash_attn_func"](q, k, v, causal=False)
-    except TypeError:
-        out_flash, _ = flash_attn["flash_attn_func"](q, k, v, causal=False)
-    print("1. Standard attention:")
-    print(f" Reference: {out_ref.shape}, Flash: {out_flash.shape}")
-    print(f" Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}\n")
-
-    # Causal attention
-    out_ref_causal = reference_attention(q, k, v, causal=True)
-    try:
-        out_causal = flash_attn["flash_attn_func"](q, k, v, causal=True)
-    except TypeError:
-        out_causal, _ = flash_attn["flash_attn_func"](q, k, v, causal=True)
-    print("2. Causal attention:")
-    print(f" Reference: {out_ref_causal.shape}, Flash: {out_causal.shape}")
-    print(f" Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}\n")
+    print(f"βœ… Reference attention OK: {out_ref.shape}\n")
+
+    # Run FlashAttention if CUDA available
+    if flash_attn_func is not None:
+        try:
+            out_flash, _ = flash_attn_func(q, k, v, causal=False)
+            print("⚑ FlashAttention (non-causal):")
+            print(f" Output: {out_flash.shape}")
+            print(f" Close to reference: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}\n")
+
+            out_flash_causal, _ = flash_attn_func(q, k, v, causal=True)
+            out_ref_causal = reference_attention(q, k, v, causal=True)
+            print("⚑ FlashAttention (causal):")
+            print(f" Output: {out_flash_causal.shape}")
+            print(f" Close to reference: {torch.allclose(out_flash_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}\n")
+
+        except Exception as e:
+            print("❌ FlashAttention test failed:")
+            print(str(e))
+    else:
+        print("⚠️ CUDA not available β€” FlashAttention test skipped.")
+        print("Using PyTorch SDPA (reference) only.\n")
 
     return log.getvalue()
 
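The rewritten body above prints its results and returns `log.getvalue()`, so `log` is presumably an `io.StringIO` fed by `contextlib.redirect_stdout` (consistent with the `import contextlib` context in the first hunk header, though the capture setup itself sits in unchanged lines). A minimal sketch of that pattern, with the q/k/v construction shown only as an assumed example since it is not part of the diff:

import contextlib
import io

import torch

def run_sketch(B=2, S=5, H=4, D=8, seed=42):
    torch.manual_seed(seed)
    # Assumed (B, S, H, D) layout; the real app builds q/k/v in unchanged code
    # and may cast them to fp16 on CUDA for the flash kernel.
    q = k = v = torch.randn(B, S, H, D)
    log = io.StringIO()
    with contextlib.redirect_stdout(log):  # every print() below lands in `log`
        print(f"Input shape: B={B}, S={S}, H={H}, D={D}")
        print(f"dtype: {q.dtype}, device: {q.device}")
    return log.getvalue()  # this string is what the Gradio textbox displays

print(run_sketch())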
 
@@ -75,7 +71,7 @@ def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
 # ============================================================
 with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
     gr.Markdown("## ⚑ Flash Attention Kernel Tester")
-    gr.Markdown("Compare PyTorch SDPA vs FlashAttention implementations interactively.")
+    gr.Markdown("Compare PyTorch SDPA vs FlashAttention implementations interactively. Works on both CPU and CUDA.")
 
     with gr.Row():
         B = gr.Slider(1, 8, value=2, step=1, label="Batch Size (B)")
@@ -90,4 +86,4 @@ with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
 
     run_btn.click(run_flash_attention, inputs=[B, S, H, D, seed], outputs=output)
 
-demo.launch(share=True)
+demo.launch()
 