# Flash-Attn-Demo / app.py
import gradio as gr
import torch
from kernels import get_kernel
import io
import contextlib
# ============================================================
# ⚙️ Setup
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
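# The FlashAttention-3 kernel is fetched from the Hugging Face Hub by the
# `kernels` package and is only loaded when a CUDA GPU is present; the kernel
# has no CPU build, so the app falls back to the SDPA reference otherwise.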
if torch.cuda.is_available():
    flash_attn = get_kernel("kernels-community/vllm-flash-attn3")
    flash_attn_func = flash_attn.flash_attn_func
else:
    flash_attn_func = None  # CPU fallback
# ============================================================
# 🧠 Reference attention (PyTorch SDPA)
# ============================================================
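# SDPA expects tensors shaped (B, H, S, D); inputs arrive as (B, S, H, D), so
# transpose in and back out. The MATH backend forces the plain, deterministic
# implementation, which makes it a clean reference for the fused kernel.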
def reference_attention(query, key, value, causal=False):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
    return out.transpose(1, 2).contiguous()
# ============================================================
# 🎛️ Test function
# ============================================================
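# Builds random (B, S, H, D) tensors, runs the SDPA reference, and, when the
# kernel is available, checks FlashAttention against it with a tolerance loose
# enough for bfloat16 rounding. All prints are captured and returned as a log.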
def run_flash_attention(B=2, S=5, H=4, D=8, seed=42):
    B, S, H, D = int(B), int(S), int(H), int(D)
    torch.manual_seed(int(seed))
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    # q, k and v alias one random tensor; identical projections are fine for a numerical check
    q = k = v = torch.randn(B, S, H, D, device=device, dtype=dtype)

    log = io.StringIO()
    with contextlib.redirect_stdout(log):
        print(f"Running FlashAttention tests on device: {device}")
        print(f"Input shape: B={B}, S={S}, H={H}, D={D}\n")

        # Always run the PyTorch reference attention
        out_ref = reference_attention(q, k, v)
        print(f"✅ Reference attention OK: {out_ref.shape}\n")

        # Run FlashAttention only if the CUDA kernel was loaded
        if flash_attn_func is not None:
            try:
                out_flash, _ = flash_attn_func(q, k, v, causal=False)
                print("⚡ FlashAttention (non-causal):")
                print(f"   Output: {out_flash.shape}")
                print(f"   Close to reference: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}\n")

                out_flash_causal, _ = flash_attn_func(q, k, v, causal=True)
                out_ref_causal = reference_attention(q, k, v, causal=True)
                print("⚡ FlashAttention (causal):")
                print(f"   Output: {out_flash_causal.shape}")
                print(f"   Close to reference: {torch.allclose(out_flash_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}\n")
            except Exception as e:
                print("❌ FlashAttention test failed:")
                print(str(e))
        else:
            print("⚠️ CUDA not available; FlashAttention test skipped.")
            print("Using PyTorch SDPA (reference) only.\n")

    return log.getvalue()
# ============================================================
# 🧩 Gradio UI
# ============================================================
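# Minimal Blocks UI: sliders pick the tensor shape, a number box sets the seed,
# and the captured log is shown in a console-style textbox.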
with gr.Blocks(title="Flash Attention Kernel Tester") as demo:
    gr.Markdown("## ⚡ Flash Attention Kernel Tester")
    gr.Markdown(
        "Compare PyTorch SDPA against the FlashAttention-3 kernel interactively. "
        "Runs on CPU (SDPA reference only) and on CUDA (reference plus FlashAttention)."
    )

    with gr.Row():
        B = gr.Slider(1, 8, value=2, step=1, label="Batch Size (B)")
        S = gr.Slider(2, 10, value=5, step=1, label="Sequence Length (S)")
    with gr.Row():
        H = gr.Slider(1, 8, value=4, step=1, label="Number of Heads (H)")
        D = gr.Slider(4, 64, value=8, step=4, label="Head Dim (D)")
    seed = gr.Number(value=42, label="Random Seed")

    run_btn = gr.Button("🚀 Run Tests")
    output = gr.Textbox(label="Console Output", lines=25, show_copy_button=True)

    run_btn.click(run_flash_attention, inputs=[B, S, H, D, seed], outputs=output)
demo.launch()