FlashAttention: Algorithm and Pseudocode

How Tiling and Streaming Softmax Enable this GPU Kernel

January 27, 2026

FlashAttention computes O by tiling Q, K, and V.

This blog post will explain the FlashAttention algorithm, showing how it builds upon the concepts from two prior blog posts: tiling and streaming softmax. More specifically, this post will focus on a "whiteboard-level" understanding of the algorithm and, where helpful, will link out to a Python implementation that mimics how the algorithm would be coded in CUDA. We'll save a full CUDA walkthrough for a future blog post (though readers may find it a good exercise to implement FlashAttention in CUDA after reading this one, using the Python implementation as a starting point).

What is FlashAttention? Background and historical context

FlashAttention is a drop-in replacement for many of the steps of the Multi-Head Attention operation, which itself is the foundational building block for sequence modeling introduced in Attention is All You Need and is described in more detail (with a couple of reference implementations) in a prior blog post here. It was released in mid-2022 by a team led by Tri Dao, then at Stanford. This release, now known as "FlashAttention V1", already provided a 13% speedup over highly-optimized CUDA kernels written by NVIDIA for their own GPUs! These impressive speedups were achieved because FA V1 was the first to apply the concepts covered in this series of posts - tiling and streaming softmax - to the core Multi-Head Attention operation. Subsequent versions have employed much more sophisticated software features of GPUs, some of which only work on the most recent NVIDIA GPU generations, and have had varying degrees of involvement from Prof. Dao¹, but the basic ideas introduced in V1 remain core and are what produced the original "great leap forward" in Multi-Head Attention performance on GPUs.

How FlashAttention fits into Multi-Head Attention

FlashAttention is sometimes described as "a drop-in replacement for attention". That is not entirely accurate; it replaces many of the steps involved in Multi-Head Attention, but not all. Below is a reminder of all the steps involved in Multi-Head Attention:

Standard Multi-Head Attention requires computing entire attention matrices within each head.

FlashAttention fuses a key sequence of these operations into a single CUDA kernel. In particular, standard Multi-Head Attention requires "materializing" - defining and holding as one object in memory - an entire "sequence_length × sequence_length" attention matrix within each head, whereas FlashAttention computes the same final result while avoiding this memory-intensive materialization.

FlashAttention fuses several operations, avoiding ever materializing an attention matrix.

Note that, for example, the final step of multiplying the concatenated head outputs by \(W_O\) is not part of the FlashAttention kernel.
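To make that boundary concrete, here is a minimal NumPy sketch of standard Multi-Head Attention (the function name and shape conventions are illustrative choices of ours, not taken from any particular reference implementation). The marked region is the part FlashAttention fuses into one kernel; note how it materializes a full attention matrix for every head, and note that the final \(W_O\) projection sits outside it.

```python
import numpy as np

def naive_multi_head_attention(Q, K, V, W_O):
    """Reference Multi-Head Attention.

    Q, K, V: [batch_size, num_heads, sequence_length, head_dim]
    W_O:     [num_heads * head_dim, model_dim]
    """
    batch_size, num_heads, seq_len, head_dim = Q.shape

    # --- region fused by the FlashAttention kernel --------------------------
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # [B, H, N, N]
    scores -= scores.max(axis=-1, keepdims=True)              # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)            # softmax over the key axis
    O = weights @ V                                           # [B, H, N, head_dim]
    # -------------------------------------------------------------------------

    # Concatenate heads and apply the output projection; this step is *not*
    # part of the FlashAttention kernel.
    O = O.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads * head_dim)
    return O @ W_O
```

The scores and weights arrays above are exactly the "sequence_length × sequence_length" per-head objects that FlashAttention never creates.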

So how does it do it?

FlashAttention v1: Preliminaries

We'll start with three high-level comments on FlashAttention that can keep you oriented when going through the specific steps of the algorithm:

1. The "tiling + accumulation" approach used in FlashAttention is like the one used in matrix multiplication, but with three input matrices rather than two.

In a prior blog post, we described a technique for computing the output of an operation involving two input matrices by:

  1. Dividing the input matrices into "tiles".
  2. Successively feeding these tiles into fast "shared" memory on the GPU.
  3. For each set of tiles loaded in, doing some computation and adding the result to an "accumulation" that stores the results of all prior tile computations in the output.

The structure of what we do in FlashAttention is very similar (which is why reading through and making sure you understand that blog post is great scaffolding for understanding FlashAttention), except now we have three input matrices - \(Q\), \(K\), and \(V\) - and the operation we're trying to "emulate in a tiled fashion" is the attention operation:

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^T}{\sqrt{d}}\right)V \]

rather than simply a matrix multiplication.
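As a quick refresher on the two-matrix version of this pattern, here is a minimal sketch (tile sizes and names are illustrative, and "shared memory" is only simulated by slicing): each output tile is built up by streaming tiles of the inputs and accumulating partial products, which is exactly the load/compute/accumulate rhythm FlashAttention reuses.

```python
import numpy as np

def tiled_matmul(A, B, tile_m=32, tile_n=32, tile_k=32):
    """C = A @ B computed one tile at a time, accumulating partial products.

    In a real GPU kernel each (i, j) output tile would be owned by one thread
    block, and the A/B tiles would be staged through shared memory.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)

    for i in range(0, M, tile_m):          # output tile rows
        for j in range(0, N, tile_n):      # output tile columns
            C_tile_accum = np.zeros((min(tile_m, M - i), min(tile_n, N - j)),
                                    dtype=A.dtype)
            for k in range(0, K, tile_k):  # stream tiles along the shared dimension
                A_tile = A[i:i + tile_m, k:k + tile_k]  # "load into shared memory"
                B_tile = B[k:k + tile_k, j:j + tile_n]
                C_tile_accum += A_tile @ B_tile         # accumulate the partial product
            C[i:i + tile_m, j:j + tile_n] = C_tile_accum
    return C
```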

2. FlashAttention exploits the fact that the exact same computation happens for each batch and head element

That raises a question: where do these "two-dimensional Tensors" we deal with in FlashAttention come from? \(Q\), \(K\), and \(V\), after all, are typically four-dimensional Tensors of shape [batch_size, num_heads, sequence_length, head_dim] in attention implementations; even if you discount the batch dimension--operations in neural networks are always the same within each batch element, so they are really "batch_size" identical operations happening in parallel--you are still left with three-dimensional Tensors of shape [num_heads, sequence_length, head_dim]. Here we take advantage of a further parallel structure in Multi-Head Attention: the computations within each head are identical. Naturally, then, in FlashAttention v1, we launch a GPU kernel with a grid of "batch_size * num_heads" thread blocks, each of which operates on a single (batch_index, head_index) tuple. Each thread block will then operate on the three "sequence_length × head_dim" Tensors:

Q[batch_index, head_index, :, :], K[batch_index, head_index, :, :], and V[batch_index, head_index, :, :]

and produce the "sequence_length × head_dim" Tensor O[batch_index, head_index, :, :]. So, the algorithm we'll describe starting in the next section really is a "three 2D matrix" analogue of the "two 2D matrix" tiling algorithm described in the prior blog post; "the magic of CUDA" (specifically, being able to launch "batch_size * num_heads" blocks of threads at once) scales this up to the 4D Tensor operation we need.
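A hedged way to picture this in plain Python (the loop nest and the single_head_kernel argument are our own illustration, not the actual launch code): each iteration of the double loop below corresponds to one thread block in the "batch_size * num_heads" grid, and on the GPU all of those blocks run concurrently rather than one after another.

```python
import numpy as np

def multi_head_attention_via_grid(Q, K, V, single_head_kernel):
    """Q, K, V: [batch_size, num_heads, sequence_length, head_dim].

    Each (batch_index, head_index) pair plays the role of one thread block
    in the 'batch_size * num_heads' CUDA grid.
    """
    batch_size, num_heads, seq_len, head_dim = Q.shape
    O = np.empty_like(Q)
    for batch_index in range(batch_size):    # in CUDA: one block per (batch, head)
        for head_index in range(num_heads):  # these loops run as parallel blocks
            O[batch_index, head_index, :, :] = single_head_kernel(
                Q[batch_index, head_index, :, :],  # [sequence_length, head_dim]
                K[batch_index, head_index, :, :],
                V[batch_index, head_index, :, :],
            )
    return O
```

Passing either the naive single-head attention from earlier or the tiled sketch at the end of this post as single_head_kernel yields the same 4D output.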

Now that we understand that we can operate on two-dimensional slices of \(Q\), \(K\), and \(V\), we have to get to the core of the problem: how we are going to "tile" these to compute \(O\) without ever computing the full attention matrix.

3. The specific tiling of \(Q\), \(K\), and \(V\) that works in FlashAttention comes from the dependencies within attention itself.

If you dive into a single row of the output matrix \(O\), you'll see that the only part of \(Q\) it depends on is the corresponding row! It turns out each element of that row actually depends on all elements of \(K\) (due to the softmax operation) and on every element of the corresponding column of \(V\). Setting that detail about \(K\) and \(V\) aside, the key point is that the row collectively needs to "see" all rows of \(K\) and \(V\), whereas it only needs to see the corresponding row of \(Q\). This leads to the tiling strategy described in the next section.
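You can verify this dependency structure numerically; the quick check below (an illustration of ours) computes attention on all of \(Q\) and then on just one row of \(Q\), and confirms that the corresponding row of \(O\) comes out the same while still requiring all of \(K\) and \(V\).

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V for 2D matrices of shape [rows, head_dim]."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
N, D = 8, 4
Q, K, V = rng.normal(size=(3, N, D))

O_full = attention(Q, K, V)
i = 5
# Row i of O only needs row i of Q -- but still needs *all* of K and V.
O_row_i = attention(Q[i:i + 1], K, V)
assert np.allclose(O_full[i], O_row_i[0])
```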

FlashAttention: Attention on Tiles

In the diagram and the text below, we'll use \(N\) for sequence_length, \(D\) for head_dim, \(Tr\) for the number of rows of \(Q\) in each tile, and \(Tc\) for the number of rows of \(K\) and \(V\) in each tile.

Moreover, we'll refer to a Python re-write of the CUDA kernel that may be easier for readers to grok and/or hack on than the CUDA C++ code.

As with the blog post where we covered matrix multiplication, the picture showing the geometry of what is going on is critical. The key picture is below:

FlashAttention operates on tiles of Q, K, and V to produce accumulations of O_tile

  1. Outer loop (over query tiles): We load in a tile of \(Tr\) rows of \(Q\) — call it \(Q_{tile}\).

  2. Inner loop (stream over key/value tiles while \(Q_{tile}\) is loaded into shared memory):

    1. While this is loaded in, we loop through \(K\) and \(V\), \(Tc\) rows at a time (the c is for columns; you’ll see why shortly).

    2. For each \(K_{tile}\), we compute the matrix product of \(Q_{tile}\) with \(K_{tile}^T\); this creates a "\(Tr\) rows by \(Tc\) columns" matrix of dot-product sums that, in the \(N \times N\) attention matrix, sits in the same set of rows as \(Q_{tile}\) and the same set of columns as \(K_{tile}\). We call this sh_S--sums stored in shared memory--in the code. We also record the maximum value in each row of this block.

    3. We "rescale these sums", motivated by the ideas from the streaming softmax blog post. Remember that since softmax is itself a fraction, the attention operation (ignoring the scaling factor) can be written:

      \[ \frac{\operatorname{scaled\_exponentials}(QK^T)\cdot V} {\operatorname{sum\_of\_scaled\_exponentials}(QK^T)} \]

      We know how to compute this in a "streaming" fashion: we update the running max, then use the updated max to compute the "scaled exponentials" themselves, add them to the running sum, and rescale that sum if necessary. These two operations happen here and here respectively. For an individual row row_in_q_tile, we store the running denominator in row_sumexp[row_in_q_tile] and the running max in row_max[row_in_q_tile] to be used later.

    4. We multiply these now-scaled rows by the corresponding \(V_{tile}\) to get a "Tr × D" contribution to \(O_{tile}\) (the tile of \(O\) occupying the same rows as \(Q_{tile}\)). In code we call this accumulator O_tile_accum.

    5. As in tiled matrix multiplication, we increment these "Tr × D" elements of \(O_{tile}\), then move on to the next \(K\) and \(V\) tiles.

    6. After looping through all tiles of \(K\) and \(V\), we've accumulated the full numerator of the attention output for each row.

      So the final step is to divide each row of the accumulated numerator by its stored denominator, which we've stored as row_sumexp[row_in_q_tile].

  3. Back to the outer loop: Move on to the next \(Q_{tile}\) and repeat until all rows of \(O\) are filled.
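Putting the loop structure above into a condensed NumPy sketch for a single (batch_index, head_index) slice may help. This is our own simplification (fixed tile sizes, no masking or dropout, and everything kept as plain arrays rather than staged through shared memory), not the linked Python or CUDA implementation, but it reuses the same variable names (sh_S, row_max, row_sumexp, O_tile_accum):

```python
import numpy as np

def flash_attention_single_head(Q, K, V, Tr=16, Tc=16):
    """Tiled attention for one (batch_index, head_index) slice.

    Q, K, V: [N, D]. Never materializes the full N x N attention matrix;
    only Tr x Tc score tiles (sh_S) exist at any one time.
    """
    N, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    O = np.empty_like(Q)

    for q_start in range(0, N, Tr):            # outer loop: Q tiles
        Q_tile = Q[q_start:q_start + Tr]       # [Tr, D], "held in shared memory"
        rows = Q_tile.shape[0]

        O_tile_accum = np.zeros((rows, D))     # running numerator
        row_max = np.full(rows, -np.inf)       # running max per row
        row_sumexp = np.zeros(rows)            # running sum of scaled exponentials

        for k_start in range(0, N, Tc):        # inner loop: stream K/V tiles
            K_tile = K[k_start:k_start + Tc]   # [Tc, D]
            V_tile = V[k_start:k_start + Tc]

            sh_S = (Q_tile @ K_tile.T) * scale # [Tr, Tc] tile of scores
            tile_max = sh_S.max(axis=-1)       # max of this block, per row

            new_max = np.maximum(row_max, tile_max)  # update running max
            correction = np.exp(row_max - new_max)   # rescale old accumulations

            P = np.exp(sh_S - new_max[:, None])      # scaled exponentials
            row_sumexp = row_sumexp * correction + P.sum(axis=-1)
            O_tile_accum = O_tile_accum * correction[:, None] + P @ V_tile
            row_max = new_max

        # divide the accumulated numerator by the stored denominator
        O[q_start:q_start + Tr] = O_tile_accum / row_sumexp[:, None]

    return O
```

Comparing its output against a naive softmax(\(QK^T/\sqrt{d}\))\(V\) with np.allclose is a good sanity check before attempting the CUDA version.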

Computational and memory complexity

As should be clear from the diagrams and description of the algorithm above: FlashAttention consumes only a constant amount of extra memory in the forward pass (beyond the inputs and output themselves), whereas a naive Multi-Head Attention implementation would require \(O(N^2)\) memory for the materialized attention matrices! Though we won't go into the backward pass in detail in this blog post, it turns out the standard way of computing it, which offers a very good tradeoff between computation and memory, involves saving the per-row maxima and sums of scaled exponentials; these take up \(O(N)\) memory. Still a huge improvement over \(O(N^2)\)!
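To put rough numbers on that (a hypothetical example of ours, assuming fp16 activations and fp32 row statistics):

```python
N = 4096                                 # hypothetical sequence_length
attn_matrix_bytes = N * N * 2            # full N x N attention matrix in fp16
row_stats_bytes = N * 2 * 4              # per-row max + sum of exponentials in fp32
print(attn_matrix_bytes / 2**20, "MiB")  # 32.0 MiB per (batch, head)
print(row_stats_bytes / 2**10, "KiB")    # 32.0 KiB per (batch, head)
```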

Moreover, since this is the section on computational complexity: astute readers will notice that we're clearly "leaving some parallelism on the table". Since each \(Q_{tile}\) is independent, we could run the algorithm above on multiple \(Q_{tile}\)s at once! This is done in FlashAttention V2; this quote from the abstract, listing the key improvements in FA V2:

...(2) parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy...

describes exactly this.

Conclusion

This post covered how tiling and streaming softmax enable us to compute large parts of the Multi-Head Attention operation without ever materializing the entire attention matrix, along with pseudocode that should make the algorithm more concrete. In a future blog post we'll walk through the CUDA side of the implementation in more detail; for now, this should provide a clean implementation you can look through.

Footnotes

¹ FlashAttention v3 (published in July 2024) was led by a team at Colfax Research--especially Jay Shah--and NVIDIA, whereas FlashAttention v4 (which has only been semi-published, with Tri Dao giving a talk on an early version of it in August 2025) appears to be nearly 2,000 lines of Python CuTe DSL code written by Prof. Dao himself.