rahul7star commited on
Commit
88e95a7
·
verified ·
1 Parent(s): 53ed132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -37,24 +37,33 @@ def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
37
  # πŸŽ›οΈ Test function
38
  # ============================================================
39
  def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
40
- torch.manual_seed(seed)
41
- q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.bfloat16)
42
- log = io.StringIO()
 
 
43
 
 
44
  with contextlib.redirect_stdout(log):
45
  print(f"Running FlashAttention Tests on device: {device}")
46
  print(f"Input shape: B={B}, S={S}, H={H}, D={D}\n")
47
 
48
  # Standard attention
49
  out_ref = reference_attention(q, k, v)
50
- out_flash, _ = flash_attn.flash_attn_func(q, k, v, causal=False)
 
 
 
51
  print("1. Standard attention:")
52
  print(f" Reference: {out_ref.shape}, Flash: {out_flash.shape}")
53
  print(f" Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}\n")
54
 
55
  # Causal attention
56
  out_ref_causal = reference_attention(q, k, v, causal=True)
57
- out_causal, _ = flash_attn.flash_attn_func(q, k, v, causal=True)
 
 
 
58
  print("2. Causal attention:")
59
  print(f" Reference: {out_ref_causal.shape}, Flash: {out_causal.shape}")
60
  print(f" Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}\n")
@@ -66,7 +75,7 @@ def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
66
  # ============================================================
67
  with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
68
  gr.Markdown("## ⚡ Flash Attention Kernel Tester")
69
- gr.Markdown("Run reference vs FlashAttention comparisons interactively.")
70
 
71
  with gr.Row():
72
  B = gr.Slider(1, 8, value=2, step=1, label="Batch Size (B)")
@@ -77,7 +86,7 @@ with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
77
  seed = gr.Number(value=42, label="Random Seed")
78
 
79
  run_btn = gr.Button("🚀 Run Tests")
80
- output = gr.Textbox(label="Console Output", lines=25)
81
 
82
  run_btn.click(run_flash_attention, inputs=[B, S, H, D, seed], outputs=output)
83
 
 
37
  # πŸŽ›οΈ Test function
38
  # ============================================================
39
  def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
40
+ B, S, H, D = int(B), int(S), int(H), int(D)
41
+ torch.manual_seed(int(seed))
42
+
43
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
44
+ q = k = v = torch.randn(B, S, H, D, device=device, dtype=dtype)
45
 
46
+ log = io.StringIO()
47
  with contextlib.redirect_stdout(log):
48
  print(f"Running FlashAttention Tests on device: {device}")
49
  print(f"Input shape: B={B}, S={S}, H={H}, D={D}\n")
50
 
51
  # Standard attention
52
  out_ref = reference_attention(q, k, v)
53
+ try:
54
+ out_flash = flash_attn["flash_attn_func"](q, k, v, causal=False)
55
+ except TypeError:
56
+ out_flash, _ = flash_attn["flash_attn_func"](q, k, v, causal=False)
57
  print("1. Standard attention:")
58
  print(f" Reference: {out_ref.shape}, Flash: {out_flash.shape}")
59
  print(f" Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}\n")
60
 
61
  # Causal attention
62
  out_ref_causal = reference_attention(q, k, v, causal=True)
63
+ try:
64
+ out_causal = flash_attn["flash_attn_func"](q, k, v, causal=True)
65
+ except TypeError:
66
+ out_causal, _ = flash_attn["flash_attn_func"](q, k, v, causal=True)
67
  print("2. Causal attention:")
68
  print(f" Reference: {out_ref_causal.shape}, Flash: {out_causal.shape}")
69
  print(f" Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}\n")
 
75
  # ============================================================
76
  with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
77
  gr.Markdown("## ⚡ Flash Attention Kernel Tester")
78
+ gr.Markdown("Compare PyTorch SDPA vs FlashAttention implementations interactively.")
79
 
80
  with gr.Row():
81
  B = gr.Slider(1, 8, value=2, step=1, label="Batch Size (B)")
 
86
  seed = gr.Number(value=42, label="Random Seed")
87
 
88
  run_btn = gr.Button("🚀 Run Tests")
89
+ output = gr.Textbox(label="Console Output", lines=25, show_copy_button=True)
90
 
91
  run_btn.click(run_flash_attention, inputs=[B, S, H, D, seed], outputs=output)
92