
DeFT: Decoding with Flash Tree-Attention for Efficient Tree-structured LLM Inference

Jinwei Yao1,4,  Kaiqi Chen2,∗
Kexun Zhang3  Jiaxuan You4  Binhang Yuan5  Zeke Wang2,†  Tao Lin1,

1Westlake University  2Zhejiang University  
3Carnegie Mellon University  4University of Illinois Urbana-Champaign  
5Hong Kong University of Science and Technology

jinwei.yao1114@gmail.com;  {chiaki_cage,wangzeke}@zju.edu.cn;
kexunz@andrew.cmu.edu;  jiaxuan@illinois.edu;
biyuan@ust.hk;  lintao@westlake.edu.cn
∗Equal contribution. Work was done during Jinwei's visit to Westlake University. †Corresponding author.
Abstract

Given the increasing demand for tree-structured interactions with LLMs, we introduce DeFT (Decoding with Flash Tree-Attention), an IO-aware tree attention algorithm tailored for tree-structured inference. Unlike traditional sequence-based decoding, tree-structured decoding better accommodates modern task requirements, including self-consistency, few-shot prompting, multi-step reasoning, and multi-model/head coordination. However, existing sequence-based inference systems are ill-suited for tree-structured decoding, resulting in redundant computation, memory footprint, and memory access, which undermines inference efficiency. To address this challenge, DeFT keeps attention calculation memory-efficient, with a low memory footprint, through two key stages: (1) QKV Preparation: we propose a KV-Guided Grouping Strategy with Tree Split to intelligently group QKV, optimizing GPU resource utilization while minimizing KV cache reads/writes between GPU global memory and on-chip shared memory; (2) Attention Calculation: we compute the partial attention of each QKV group in a fused kernel and employ a Tree-topology-aware Global Reduction strategy to obtain the final attention. By reducing KV cache IO by 73-99% and IO for partial results during attention calculation (e.g., Softmax) by nearly 100%, DeFT achieves up to 2.52×/3.82× speedup in end-to-end/attention latency over state-of-the-art attention algorithms across three practical tree-based workloads: few-shot prompting, multi-step reasoning, and speculative decoding.

1 Introduction

Large language models (LLMs) [1, 34, 35] are widely used for tasks such as chatbots [31], code generation [26], and reasoning [42, 4, 28]. To meet growing demands for service quality across a wide range of applications, interactions with LLMs are becoming increasingly complex, moving from simple sequence-structured patterns such as multi-turn chat to tree-structured patterns such as self-consistency [37], few-shot prompting [25], multi-step reasoning [42, 11, 41], and multi-model/head coordination [27, 5]. Unfortunately, higher service quality is not a free lunch: it sacrifices efficiency, because more tokens must be generated to provide a large space for tree search [10, 23, 21] or selection, as shown in Table 1.

The mismatch between existing sequence-based inference systems [20, 29, 16] and tree-structured interactions exacerbates the efficiency problem. Most current inference systems are designed for sequence-based decoding, which samples a single sequence of tokens at a time, whereas tree-based decoding maintains multiple sequences with common prefixes as a tree structure, as shown in Figure 1. Since tree nodes can be shared in both computation and memory while sequences cannot, applying tree-structured tasks directly to sequence-based decoding causes three levels of redundancy: (1) memory storage, especially the KV cache [20, 45]; (2) computation, especially for prompts common to the sequences in a batch [45]; and (3) memory access.

Figure 1: An illustration of Sequence-based decoding and Tree-based decoding.

Existing tree-based inference systems [45, 9] focus on the first two levels while largely ignoring the third, yet most important, one: memory access, given the memory-bound nature of LLM inference [32, 5, 19]. Meanwhile, sequence-based decoding methods optimize memory access for partial results (i.e., $\mathbf{Q}\mathbf{K}^{\top}$) during attention calculation [6, 7, 15]. However, their effectiveness in tree-based decoding is limited. In particular, these optimizations cannot address the potential bottleneck posed by KV cache IO when dealing with a large number of tokens, as illustrated in Table 1.

As a remedy, in this paper we focus on a key component of the decoding process: attention. Orthogonal to the traditional attention mechanisms of sequence-based decoding, tree attention [27, 5], which is specifically designed to handle hierarchical or tree-structured tokens in tasks such as parallel decoding, can reduce kernel launching, computation, and KV cache storage overheads for attention calculation. However, this line of research does not further leverage the tree topology to reduce IO when calculating attention, and is thus still not fully IO-aware for either (i) partial results (i.e., $\mathbf{Q}\mathbf{K}^{\top}$) [5], due to the lack of tiling and kernel fusion [6]; or (ii) the tree-structured KV cache [27]. These limitations hinder their effectiveness in optimizing memory access during tree-based decoding.

Table 1: Comparison of efficiency in sequence-based (CoT [38]) and tree-based (ToT [42]) decoding for a reasoning task: sorting 128 numbers, from [4]. CoT generates only 525 tokens in total, whereas ToT generates 38,315, resulting in much higher end-to-end latency (seconds) and IO (TB). IO consists mainly of two parts: (i) the KV cache (IO-KV); and (ii) partial results during attention calculation, such as $QK^{T}$ and softmax (IO-PA). Baselines: (i) Flash-Decoding [7]; (ii) Tree Attention: the tree attention in Medusa [5].
Method | Latency (s) | IO-KV (TB) | IO-PA (TB)
Flash-Decoding + CoT | 21 | 0.6 | 0
Flash-Decoding + ToT | 429.65 | 59.96 | 0
Tree Attention + ToT | 380.87 | 12.40 | 3.69
DeFT-Flatten (ours) + ToT | 94.61 | 12.40 | 0
Speedup over best baseline | 4.02× | - | -

To bridge the above gap, we propose DeFT, an IO-aware tree attention algorithm built on two key insights. First, the IO overhead of queries (Q) is negligible compared with that of the KV cache, primarily because the maximum query length typically corresponds to the number of root-to-leaf paths in the tree, resulting in relatively short queries (e.g., dozens of tokens) compared with the KV cache length of each node (e.g., hundreds or thousands of tokens). Second, in sequence-based decoding each KV cache entry corresponds to a unique query, whereas in tree-based decoding multiple queries can share their common ancestors' KV cache during attention calculation, which reduces not only KV cache storage but also IO.

Building upon these two insights, in the first phase of DeFT, QKV Preparation, we split the KV cache of the decoding tree in one of two ways: (i) split by node (DeFT-Node), which is simple and requires no causal mask; or (ii) flatten the tree KV and then split it evenly (DeFT-Flatten), which yields more stable speedups thanks to balanced GPU workloads, at the small cost of bit causal mask IO. We then group the KV cache of each node with all queries that share it in the decoding tree, minimizing KV cache IO with negligible query IO overhead. In the second phase of DeFT, Attention Calculation, we adopt a fused kernel to compute the partial attention, together with its LogSumExp, of each QKV group from phase 1, and then perform a tree-topology-aware global reduction inspired by Flash-Decoding [7]. We summarize our contributions as follows:

  • We propose a simple but hardware-efficient tree attention algorithm, DeFT, which is IO-aware for both the tree-structured KV cache and partial results (i.e., $\mathbf{Q}\mathbf{K}^{\top}$ and Softmax). We offer two implementations: DeFT-Node is straightforward and requires no mask, while DeFT-Flatten delivers more stable speedups across various tree topologies, with minimal extra IO cost for masks.

  • We implement DeFT in OpenAI Triton [33] to gain precise control over memory access and fuse all attention operations into a single GPU kernel.

  • We theoretically justify the superiority of DeFT over the existing attention algorithms [40, 7, 5, 27] in terms of IO complexity.

  • We empirically verify its effectiveness on few-shot prompting, multi-step reasoning, and speculative decoding tasks. Compared with the baseline implementations [7, 5, 45], DeFT achieves wall-clock speedups of 1.3× for few-shot prompting, 2.5× for speculative decoding, and 1.1× for multi-step reasoning, owing to up to 3.82× faster attention calculation.

2 Related Work

Tree-based Decoding. Tree-based decoding, exemplified by beam search [10], has been pivotal in NLP for handling lexical and logical constraints [2, 30, 13], mitigating gender bias [24], achieving communicative goals [14], and improving alignment [21]. Based on the structure of the queries and KV cache, tree-based decoding falls into two patterns. (i) Tree-structured past KV with parallel queries, usually in multi-step reasoning [42, 4, 28], which uses search trees with parallel hypothesis generation and selection based on scoring functions; some methods score candidates per token [8, 24, 23], others per reasoning step [39, 36, 41]. (ii) Past KV in sequence with tree-structured queries, usually in speculative decoding [5, 27]: a token tree of queries is generated by different draft models [27] or heads [5], and these tokens are then verified in parallel via tree-based decoding. Details of these two patterns are discussed in Appendix A.2. Despite the application of various search algorithms, such as A* [23] and Monte-Carlo Tree Search [21], the efficiency of tree-based decoding remains underexplored.

Memory-efficient Attention Algorithms. Existing memory-efficient attention algorithms target sequence-based decoding. FlashAttention [6] improves self-attention computation in LLM training via tiling and kernel fusion, reducing IO. Flash-Decoding [7] extends this by splitting K and V to increase parallelism and introducing a global reduction to gather partial attention results, enabling efficient decoding for long sequences. Unfortunately, applying these memory-efficient algorithms to tree-based decoding overlooks the redundant IO of the tree-structured KV cache, which is the focus of DeFT.

Tree Attention. Integrated into LLM inference, tree attention reduces computation, storage, and kernel launching overheads [27]. Tree-structured token candidates undergo parallel decoding, with SpecInfer [27] introducing a topology-aware causal masked tree attention algorithm, dynamically updating a causal mask to capture relationships among tokens. Medusa [5] uses a similar mechanism with a static causal mask, while other works [44, 22] adopt analogous approaches to enhance attention calculation efficiency. However, unlike DeFT, these existing works utilizing tree attention do not take memory access into consideration.

Storage Optimization of Tree-based Decoding. LLM frameworks optimized for tree-based decoding [20, 45] focus on memory storage efficiency. vLLM [20] enhances GPU memory utilization, allowing sequences from the same parent to share KV cache storage. SGLang [45] supports dynamic KV cache management during multi-round interactions with LLMs, improving memory efficiency.

Discussion on Concurrent Works. Some concurrent works [43, 18, 3] also recognize the importance of IO during LLM inference. However, these works have at least one of these flaws: i) they [43, 18, 3] cannot be easily extended to situations where the decoding tree has more than two levels—they target single-context batch sampling scenarios, a special case of general tree-based decoding with a system prompt as prefix and unique suffixes in the first depth; ii) they [18, 3] do not consider the efficiency issues caused by the lengths of different nodes in the decoding tree. Details of comparison for DeFT and concurrent works are discussed in Appendix A.3.

3 DeFT

In this section, we first introduce background on LLM inference and then outline the system support for DeFT. We then present DeFT, including its algorithm and attention kernel design, which not only reduces memory access to the tree KV but also uses a fused kernel to eliminate memory access for partial results such as $\mathbf{Q}\mathbf{K}^{\top}$ and Softmax. Finally, we theoretically compare the IO complexity of DeFT with that of existing attention algorithms to justify its advantages.

3.1 Preliminary

LLM inference and its bottleneck.

LLM inference involves two stages: (1) prefill and (2) decoding. During the prefill stage, the prompt is tokenized and processed to initialize the LLM's state (its KV cache). The output of the prefill stage becomes the input of the decoding stage. The decoding stage is auto-regressive: each output token from the previous step serves as the input token for the next step. Because of this sequential process, LLM inference is memory-bound [32, 19, 5]: every forward pass requires transferring all model parameters and the KV cache from the slower but larger High-Bandwidth Memory (HBM) to the faster but much smaller shared memory of the GPU [17].¹
¹ A100's HBM has 1.5-2 TB/s bandwidth and 40-80 GB capacity; its shared memory has 19 TB/s bandwidth and 20 MB capacity.

Motivation for DeFT.

To improve efficiency, it is essential to boost the arithmetic intensity of the decoding process, i.e., the ratio of total floating-point operations (FLOPs) to total memory access. Parallel decoding frameworks [5, 27] pursue this goal by introducing more computation to generate more tokens per decoding step, while keeping the memory access of each step nearly unchanged.² Draft models [27] or fine-tuned heads [5] generate a sequence of tokens as candidates, which the LLM then verifies to obtain an acceptable continuation. This line of approach reduces the total number of decoding steps as well as the total amount of memory access.
² Medusa [5] introduces only negligible KV cache memory access for the token candidates in the tree.
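A back-of-the-envelope sketch of why decoding attention is memory-bound, and why letting several queries share one KV read raises arithmetic intensity. It assumes fp16 (2 bytes per element), counts one multiply-add (2 FLOPs) per query, KV token, and head dimension for each of $\mathbf{Q}\mathbf{K}^{\top}$ and the weighted sum over V, and ignores softmax FLOPs; `attention_intensity` is a hypothetical helper, not part of DeFT, and the A100 peak of 312 TFLOP/s (fp16 tensor cores) is used only as an approximate ridge point.

```python
def attention_intensity(n_ctx: int, d_head: int, n_queries: int = 1,
                        dtype_bytes: int = 2) -> float:
    """FLOPs per byte for one head when n_queries attend to the same n_ctx KV tokens."""
    # QK^T and P.V each cost one multiply-add (2 FLOPs) per query, token and dim.
    flops = 4 * n_queries * n_ctx * d_head
    # K and V are read once from HBM; each query vector is read once.
    kv_bytes = 2 * n_ctx * d_head * dtype_bytes
    q_bytes = n_queries * d_head * dtype_bytes
    return flops / (kv_bytes + q_bytes)

# Approximate A100 ridge point: fp16 tensor-core peak divided by HBM bandwidth.
A100_RIDGE = 312e12 / 2.0e12

single = attention_intensity(n_ctx=4096, d_head=128)                # about 1 FLOP/byte
shared = attention_intensity(n_ctx=4096, d_head=128, n_queries=32)  # about 32 FLOP/byte
print(f"single: {single:.2f}, 32 shared queries: {shared:.2f}, ridge: {A100_RIDGE:.0f}")
```

Both values sit far below the ridge point, so the kernel stays memory-bound; what sharing buys is more useful work per byte loaded, which is the lever DeFT pulls for tree-structured KV.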

Meanwhile, tree-based decoding, which leverages the decoding tree defined below, enables efficient parallel decoding. Tree attention is further introduced to reduce redundant KV storage, computation, and kernel launching overheads when calculating attention.

Definition 3.1 (Decoding tree).

A decoding tree $\mathcal{T}$ is a rooted tree where the root node corresponds to the prompt and each non-root node $u$ represents a sequence of generated tokens $\mathcal{S}_u$. For each node $u$, $\mathcal{B}_u$ is the path from the root node to $u$ (excluding $u$), and $P_{\mathcal{B}_u}$ is the concatenation, in sequential order, of the tokens of the nodes on path $\mathcal{B}_u$. For each token $n \in u$, $s_{u,n} \in \mathcal{S}_u$ represents the sequence from the first token of node $u$ to $n$ (including $n$). The last token of each leaf node is the input token for the next decoding iteration.
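A minimal Python sketch of Definition 3.1, assuming token ids are plain integers and each node stores its own token sequence $\mathcal{S}_u$. The names `Node`, `prefix_tokens`, `root_to_token`, and `leaves` are illustrative and not taken from DeFT's code.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass(eq=False)  # compare nodes by identity, not by (recursive) contents
class Node:
    tokens: List[int]                                       # S_u
    parent: Optional["Node"] = field(default=None, repr=False)
    children: List["Node"] = field(default_factory=list)

    def add_child(self, tokens: List[int]) -> "Node":
        child = Node(tokens, parent=self)
        self.children.append(child)
        return child

def prefix_tokens(u: Node) -> List[int]:
    """P_{B_u}: tokens on the path from the root to u, excluding u itself."""
    chain = []
    node = u.parent
    while node is not None:
        chain.append(node.tokens)
        node = node.parent
    return [t for seq in reversed(chain) for t in seq]

def root_to_token(u: Node, n: int) -> List[int]:
    """P_{root->n}: P_{B_u} followed by s_{u,n} (tokens 0..n of node u, inclusive)."""
    return prefix_tokens(u) + u.tokens[: n + 1]

def leaves(root: Node) -> List[Node]:
    """Leaf nodes, left to right; each leaf's last token is a next-step input."""
    if not root.children:
        return [root]
    return [leaf for child in root.children for leaf in leaves(child)]

# Prompt [1, 2, 3] with two branches; the first branch forks again.
root = Node([1, 2, 3])
a = root.add_child([4, 5])
b = root.add_child([6])
c = a.add_child([7])
print(root_to_token(c, 0))                          # [1, 2, 3, 4, 5, 7]
print([leaf.tokens[-1] for leaf in leaves(root)])   # [7, 6]
```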

Definition 3.2 (Tree-Attention).

For each token $n \in u$, where $u$ is any non-root node in the decoding tree $\mathcal{T}$, its tree attention is defined as the output of the original Transformer-based sequence attention ($\text{Attention}(\cdot)$) on $P_{\text{root}\rightarrow n}$, where $P_{\text{root}\rightarrow n}$ is the concatenation of $P_{\mathcal{B}_u}$ and $s_{u,n}$:

$$\text{Tree-Attention}(n) = \text{Attention}(P_{\text{root}\rightarrow n}) \qquad (1)$$
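The NumPy sketch below evaluates Eq. (1) literally for one head and one query, assuming the tree is given as a parent array (root id 0, `parent[root] == -1`) with per-node K/V arrays whose rows already include the current token's own key and value. It is a naive reference that re-reads every ancestor's KV for each query; DeFT targets the same result while loading shared KV only once. All names are illustrative.

```python
import numpy as np

def attention(q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Single-query scaled dot-product attention: softmax(q K^T / sqrt(d)) V."""
    scores = K @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())      # subtract max for stability
    return (weights / weights.sum()) @ V

def path_from_root(u: int, parent: list) -> list:
    """Node ids from the root down to u (inclusive); parent[root] == -1."""
    path = []
    while u != -1:
        path.append(u)
        u = parent[u]
    return path[::-1]

def tree_attention(q, leaf, parent, K_nodes, V_nodes):
    """Eq. (1): attention of the leaf's last token over P_{root->n}."""
    path = path_from_root(leaf, parent)
    K = np.concatenate([K_nodes[u] for u in path])
    V = np.concatenate([V_nodes[u] for u in path])
    return attention(q, K, V)

rng = np.random.default_rng(0)
d = 8
parent = [-1, 0, 0]                              # a root (prompt) with two leaves
K_nodes = [rng.standard_normal((n, d)) for n in (16, 4, 6)]
V_nodes = [rng.standard_normal((n, d)) for n in (16, 4, 6)]
q = rng.standard_normal(d)
out = tree_attention(q, 2, parent, K_nodes, V_nodes)
print(path_from_root(2, parent), out.shape)      # [0, 2] (8,)
```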

Existing tree attention solutions [5, 27] overlook the IO optimization enabled by the tree topology itself, which motivates DeFT. DeFT improves LLM efficiency from another perspective: it exploits prefix sharing in decoding trees to reduce redundant KV cache IO from HBM to on-chip shared memory, thereby improving the overall arithmetic intensity with less memory access and nearly the same FLOPs.

Figure 2: Overview of DeFT. SMEM denotes the shared memory of GPUs. The input metadata consists of 1) Query (tokens), 2) KV (the KV cache of the decoding tree), and 3) Tree Topo (the topology of the decoding tree, used to map Query to KV), which are prepared by the Branch Controller, KV cache Manager, and Sequence Tree Manager, respectively, in the system elaborated in Appendix A.1.

3.2 Overview of System Design for DeFT

We can separate the execution of attention algorithms into two main phases: (1) QKV Preparation Phase: group Query, Key, and Value (QKV) logically and map QKV groups to different streaming multiprocessors (SMs) of GPUs; (2) Attention Calculation Phase: load QKV groups to different SMs’ shared memory and apply attention algorithms to each group for final attention results.

Minimizing memory access between slow HBM and fast shared memory for memory-bound computations (e.g., attention) is crucial. DeFT aims to be a memory-efficient algorithm in both aforementioned phases to get attention for tree-based decoding. In detail, as shown in Figure 2:

  • In the QKV Preparation Phase, we introduce a KV-guided Grouping strategy with tree-topology awareness to minimize the IO of QKV.

  • During the Attention Calculation Phase, we propose the DeFT Attention Kernel³, which includes (1) a Tree-Topology-Aware Global Reduction strategy and (2) established techniques such as Kernel Fusion and Tiling to eliminate the IO of partial results (i.e., $\mathbf{Q}\mathbf{K}^{\top}$ and Softmax).
³ GPUs utilize a vast array of threads to execute operations known as kernels.

Apart from the efficient DeFT Attention Kernel, our system for DeFT has two other advantages: 1) efficient memory management of the KV cache in a tree structure, and 2) flexible control of the tree decoding process with arbitrary user-defined functions that decide when and how to branch or prune. Details of the key components and their coordination in the system are given in Appendix A.1.

3.3 An Efficient Attention Algorithm with IO-awareness for Tree-structured KV Cache

Figure 3: Comparison of memory access from HBM to shared memory for different attention algorithms in the QKV Preparation Phase; the amount of IO required for each QKV group is enclosed in red rectangles. (Left) From top to bottom: the notation, the composition of the input metadata, and, most importantly, details of the DeFT-Flatten algorithm: 1) the Depth-first Flatten strategy minimizes the query IO of each block obtained after splitting, since the queries corresponding to a child's KV are a subset of those of its parent's KV (e.g., $Q_1$ and $Q_2$ for $KV_0$ contain $Q_1$ for $KV_1$); 2) the Evenly Blockwise strategy ensures equal KV lengths in each QKV group, balancing the workloads of the streaming multiprocessors (SMs) in GPUs; 3) the Bitmask [27] is a set of 64-bit integers recording the causal information of tokens in the tree, and its IO overhead (e.g., two 64-bit integers in KV-BCM$_1$) is negligible compared with that of the dense causal mask [5]; 4) to accommodate DeFT-Flatten's KV-guided Tree Split, we adopt the KV-guided bit causal mask (KV-BCM) instead of the Q-guided one (Q-BCM) [27]. (Right) Different split and grouping strategies result in different memory access. Q-guided grouping (e.g., sequence-based attention [7, 45] and Tree Attention-SpecInfer [27]) causes significant KV cache redundancy, whereas KV-guided grouping (e.g., DeFT) causes only negligible additional query IO. The IO cost of BCM can be ignored, while that of DCM cannot. See Table 2 and Remark 3.3 for more details.

In this section, we delve into the details of the QKV Preparation Phase, which is a key design aspect of DeFT, and defer the discussion of the Attention Calculation Phase to Appendix A.4.

QKV Preparation Phase of DeFT.

In sequence-based decoding, the split strategy, namely splitting the input KV into blocks, is commonly deployed to generate enough QKV groups to fully utilize the GPU [7]. This technique is crucial when the parallelism (usually limited by the batch size [7]) is much smaller than the number of streaming multiprocessors (SMs) on the GPU (108 for an A100); otherwise, the operation utilizes only a small portion of the GPU. Similarly, in tree-based decoding, where a decoding tree consists of multiple nodes and each node is a sequence of tokens, the batch size of trees may also be too small to fully utilize the GPU when the number of tokens in the tree is large, due to memory capacity limits.
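To make the parallelism argument concrete, here is a hypothetical heuristic for choosing how many KV splits to launch so that the number of (batch, head, split) work items covers the SMs. It is not Flash-Decoding's or DeFT's actual rule; the 108-SM default matches the A100 figure above, and `min_block` is an assumed lower bound on split length.

```python
import math

def num_kv_splits(batch: int, n_heads: int, kv_len: int,
                  n_sms: int = 108, min_block: int = 64) -> int:
    """Hypothetical heuristic: enough KV splits to give every SM some work."""
    groups = batch * n_heads                 # parallelism without splitting
    if groups >= n_sms:
        return 1                             # the GPU is already saturated
    wanted = math.ceil(n_sms / groups)
    # Do not create splits shorter than min_block tokens.
    return max(1, min(wanted, kv_len // min_block))

print(num_kv_splits(batch=1, n_heads=32, kv_len=8192))   # 4 -> 128 work items
print(num_kv_splits(batch=8, n_heads=32, kv_len=8192))   # 1
```

A single tree with a few heads behaves like the `batch=1` case: without splitting, most SMs would sit idle.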

Unfortunately, splitting a tree is not as easy as splitting a sequence [7]: it may introduce significant IO during QKV grouping after the split, as shown in Figure 3 and discussed in Remark 3.3.

Table 2: Comparison of the grouping and split strategies of the baselines and DeFT. In the IO redundancy row, significant redundancy is marked in red and negligible redundancy in blue. Detailed IO complexity is given in Table 4.
Method | Sequence-based [7, 45] | Tree Attention-S [27] | Tree Attention-M [5] | DeFT-Node | DeFT-Flatten
Grouping indicator | Q-guided | Q-guided | tree-guided | KV-guided | KV-guided
Tree KV split granularity | by branch (query) | no split | no split | by tree node | by block
IO redundancy | KV | KV and BCM | DCM | Q | Q and BCM
Remark 3.3 (The effects of tree split and QKV grouping strategies in the QKV Preparation Phase).

In the QKV Preparation Phase, how the decoding tree is split and how QKV are logically grouped determine the memory access of QKV from HBM to shared memory during tree decoding, as shown in the right of Figure 3 and in Table 2.

  • Sequence-based decoding methods [7, 45] split the tree and group QKV based on Q without tree-topology awareness, which introduces redundant KV cache IO;

  • Tree Attention-Medusa [5] groups the QKV of the entire decoding tree together, using a tree-topology-aware causal mask for tree attention computation implemented with PyTorch primitives, which incurs additional IO for the causal mask;

  • Tree Attention-SpecInfer [27] groups each query with the KV of the entire tree with a causal mask for tree attention calculation, which has great redundancy in KV cache IO.

To bridge this gap, we propose KV-Guided Grouping Strategy with Tree Split, offering two levels of granularity: it splits the tree by sequence nodes or blocks of the same length, and then groups the KV of each node with all queries that share it based on tree topology. This grouping strategy, with KV as the indicator for grouping, eliminates redundant IO operations for KV with negligible query IO cost, as illustrated in the bottom right of Figure 3.
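A small sketch of DeFT-Node-style KV-guided grouping, assuming the tree is a parent array with root 0, `node_len[u]` is the number of KV tokens in node `u`, and there is one query per leaf. It pairs each node's KV with the leaves below it and counts how much KV (in tokens) and how many query vectors each strategy loads. The function names are illustrative.

```python
from collections import defaultdict

def leaf_groups(parent: list) -> dict:
    """Map each node u to the leaf queries whose root-to-leaf path contains u."""
    children = defaultdict(list)
    for u, p in enumerate(parent):
        if p != -1:
            children[p].append(u)

    def leaves_under(u):
        if not children[u]:
            return [u]
        return [leaf for c in children[u] for leaf in leaves_under(c)]

    return {u: leaves_under(u) for u in range(len(parent))}

def io_tokens(parent: list, node_len: list) -> tuple:
    """(Q-guided KV tokens, KV-guided KV tokens, KV-guided query loads)."""
    groups = leaf_groups(parent)
    q_guided_kv = 0
    for leaf in groups[0]:              # root is node 0, so groups[0] lists every leaf
        u = leaf
        while u != -1:                  # each query re-reads its whole path
            q_guided_kv += node_len[u]
            u = parent[u]
    kv_guided_kv = sum(node_len)        # every node's KV is read exactly once
    kv_guided_q = sum(len(qs) for qs in groups.values())  # queries reloaded per group
    return q_guided_kv, kv_guided_kv, kv_guided_q

parent = [-1, 0, 0, 1, 1]               # prompt -> {1 -> {3, 4}, 2}
node_len = [1000, 200, 300, 50, 60]
print(leaf_groups(parent)[1])           # [3, 4]
print(io_tokens(parent, node_len))      # (3810, 1610, 8)
```

KV loads drop from 3810 to 1610 tokens while query loads grow only from 3 to 8 single-token vectors, which is the trade-off described in Remark 3.4.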

Remark 3.4 (Properties of KV-Guided Grouping Strategy with Tree Split).

The additional IO cost of Q caused by splitting the tree KV in DeFT is negligible, because the KV length usually far exceeds the query length during tree decoding. This is primarily because the auto-regressive decoding pattern gives each query in the decoding stage a length of 1, so the maximum query length of a decoding tree is determined by the number of branches.

Remark 3.5 (The effects of different split granularities).

We provide two variants of DeFT with different split granularities in KV-Guided Tree Split.

  • DeFT-Node: split by node, which is simple and requires no causal mask. However, it may lead to unbalanced workloads across SMs. For example, node A could have a KV cache of 1000 tokens while node B has only 2 tokens. When nodes A and B are allocated to $SM_1$ and $SM_2$, respectively, $SM_2$ could finish much earlier and sit idle.

  • DeFT-Flatten: flatten the tree KV, then split it evenly into blocks. Equal KV cache lengths across QKV groups ensure balanced IO and computation workloads across SMs, with negligible IO cost for the Bit Causal Mask, as shown in the bottom right of Figure 3.
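The following sketch illustrates the three DeFT-Flatten ingredients from Figure 3 on a parent-array tree (root 0, one query per leaf): depth-first flattening, even blockwise splitting, and a per-block bit mask. The mask layout (one integer per KV token, with bit j set when the block's j-th query may attend to that token) is a hypothetical encoding chosen for readability; DeFT's actual KV-BCM packs causal bits into 64-bit integers and may be laid out differently.

```python
from collections import defaultdict

def flatten_and_split(parent: list, node_len: list, block_size: int) -> list:
    """Return [(start, group_queries, bit_masks)] for each fixed-size KV block."""
    children = defaultdict(list)
    for u, p in enumerate(parent):
        if p != -1:
            children[p].append(u)

    def leaves_under(u):
        if not children[u]:
            return [u]
        return [leaf for c in children[u] for leaf in leaves_under(c)]

    # 1) Depth-first (pre-order) flatten from root 0: record the owner node of
    #    every KV token, so each subtree occupies a contiguous range.
    owner, stack = [], [0]
    while stack:
        u = stack.pop()
        owner.extend([u] * node_len[u])
        stack.extend(reversed(children[u]))     # visit children left to right

    query_of = {leaf: i for i, leaf in enumerate(leaves_under(0))}
    visible = {u: [query_of[l] for l in leaves_under(u)] for u in range(len(parent))}

    # 2) Evenly split the flattened KV; 3) group each block with every query that
    #    needs any of its tokens and build the (hypothetical) per-token bit mask.
    blocks = []
    for start in range(0, len(owner), block_size):
        toks = owner[start:start + block_size]
        group = sorted({q for u in toks for q in visible[u]})
        bit = {q: j for j, q in enumerate(group)}
        masks = [sum(1 << bit[q] for q in visible[u]) for u in toks]
        blocks.append((start, group, masks))
    return blocks

blocks = flatten_and_split([-1, 0, 0, 1, 1], [4, 2, 3, 1, 2], block_size=4)
for blk in blocks:
    print(blk)
# (0, [0, 1, 2], [7, 7, 7, 7])
# (4, [0, 1], [3, 3, 1, 2])
# (8, [1, 2], [1, 2, 2, 2])
```

Splitting the same tree by node would give groups of 4, 2, 3, 1, and 2 KV tokens, whereas flattening gives three groups of exactly 4, which is the balance argument above.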

Attention Calculation Phase of DeFT.

In this phase, we design the DeFT Attention Kernel to load the QKV splits, logically grouped in the QKV Preparation Phase, in a memory-efficient way and then perform the attention calculation. The key techniques are as follows (details in Appendix A.4): 1) common Kernel Fusion and Tiling strategies avoid significant IO for partial results (i.e., $\mathbf{Q}\mathbf{K}^{\top}$ and Softmax), which Tree Attention-Medusa [5] lacks; 2) a novel Tree-Topology-Aware Global Reduction, inspired by Flash-Decoding [15], retrieves the final attention of each query from the partial attention results of the QKV groups according to the tree topology.
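The LogSumExp merge behind the global reduction can be checked numerically. This NumPy sketch (one head, one query, illustrative names) computes a partial attention output and its LogSumExp for each QKV group the query belongs to, which in a tree are the groups holding KV on its root-to-leaf path, then rescales and sums them. The result equals attention over the concatenated KV. DeFT performs this inside a Triton kernel on the GPU, so this is only a reference model of the arithmetic.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention of q over one KV chunk, plus that chunk's LogSumExp of scores."""
    s = K @ q / np.sqrt(q.shape[-1])
    m = s.max()
    w = np.exp(s - m)
    return (w @ V) / w.sum(), m + np.log(w.sum())

def global_reduce(partials):
    """Merge [(o_i, lse_i), ...] into the attention over the union of the chunks."""
    outs = np.stack([o for o, _ in partials])
    lses = np.array([lse for _, lse in partials])
    total = lses.max() + np.log(np.exp(lses - lses.max()).sum())
    return np.exp(lses - total) @ outs      # weight = chunk's share of softmax mass

rng = np.random.default_rng(1)
d = 8
# KV of the nodes on one query's root-to-leaf path (e.g. root, inner node, leaf).
chunks = [(rng.standard_normal((n, d)), rng.standard_normal((n, d))) for n in (16, 5, 3)]
q = rng.standard_normal(d)
merged = global_reduce([partial_attention(q, K, V) for K, V in chunks])
```

Tree-topology awareness then amounts to choosing, for each query, which groups' partials enter `global_reduce`.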

Implementation details.

We implement the DeFT attention kernel with OpenAI Triton [33], which lets us control memory access from global memory to shared memory, as well as the attention calculation, at thread-block granularity. Python-style pseudocode for the two phases of DeFT-Node and DeFT-Flatten is given in Appendix A.7 and Appendix A.8, respectively.

3.4 Analysis: IO Complexity of DeFT

This section analyzes the IO complexity of DeFT, showing a significant reduction in HBM accesses compared to existing attention algorithms. Note that it is non-trivial to summarize the IO cost of the entire tree decoding process, thus we only compare IOs based on the decoding tree snapshot in a single iteration.

Table 3: Notations.
Symbol | Description
$l_n$ | Number of leaf nodes in a decoding tree, i.e., the number of queries in this decoding iteration.
$N_i$ | Total token length from the root node to leaf node $i$.
$N_{tree}$ | Total token length of the entire tree.
$\#node$ | Total number of nodes in the entire tree.
$d_{head}$ | Head dimension of the LLM.
$s_c$ | Scale factor for scaled dot-product attention, typically $\sqrt{d_{head}}$.
$F_s$ | Shared factor of reusing prefixes in tree attention, i.e., the extent to which KV cache IO can be reduced: $F_s = (\sum_{i=1}^{l_n} N_i)/N_{tree}$.
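To get a feel for $F_s$ and for the KV cache column of Table 4, the sketch below evaluates both for a hypothetical tree, a 1000-token prompt with 8 branches of 100 tokens each, under an assumed Llama-2-7B-like configuration (32 heads, 32 layers, $d_{head}=128$, fp16). The $\mathcal{O}(1)$ unit follows the definition in Table 4's caption ($\#heads \times \#layer \times dtype\_size$); the helper names are illustrative.

```python
def shared_factor(path_lens: list, n_tree: int) -> float:
    """F_s = (sum_i N_i) / N_tree."""
    return sum(path_lens) / n_tree

def kv_io_bytes(n_tokens: int, d_head: int, n_heads: int, n_layers: int,
                dtype_bytes: int = 2) -> int:
    """KV cache IO of O(2 * d_head * n_tokens), converted to bytes."""
    unit = n_heads * n_layers * dtype_bytes     # Table 4's O(1)
    return 2 * d_head * n_tokens * unit         # K and V

prompt, branch, n_leaves = 1000, 100, 8
path_lens = [prompt + branch] * n_leaves        # N_i for every leaf
n_tree = prompt + branch * n_leaves             # N_tree
f_s = shared_factor(path_lens, n_tree)

naive = kv_io_bytes(sum(path_lens), 128, 32, 32)  # sequence-based: reads every path
deft = kv_io_bytes(n_tree, 128, 32, 32)           # tree-aware: reads each node once
print(f"F_s = {f_s:.2f}, naive = {naive / 1e9:.2f} GB, DeFT = {deft / 1e9:.2f} GB")
```

The ratio of the two KV IO figures is exactly $F_s$, since both scale linearly in the number of KV tokens read.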

Consider a decoding tree with the features outlined in Table 3; we summarize the corresponding IO breakdown in Table 4. Due to the lack of tree-topology awareness, sequence-based decoding methods, such as naive attention and Flash-Decoding, incur $F_s$ times more KV cache memory access than DeFT-Node/Flatten and Tree Attention-Medusa [5].

However, Tree Attention-Medusa incurs higher IO overheads for partial results such as $\mathbf{Q}\mathbf{K}^{\top}$ and Softmax due to the lack of tiling and kernel fusion.⁴ Moreover, it introduces a dense mask to record the causal information of tokens in the tree, with significant IO costs.
⁴ Note that $\mathbf{Q}\mathbf{K}^{\top}$, $\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, $\mathbf{M}+\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, and Softmax are each loaded and written, so their IO cost includes a round trip between HBM and shared memory, as shown in Figure 9.

When the number of leaf nodes/queries $l_n$ is sufficiently large, the IO cost of partial results can become comparable to that of the KV cache. For instance, in the Llama models [34, 35], where $d_{head}=128$, with $l_n=29$ the total IO cost of $\mathbf{Q}\mathbf{K}^{\top}$, $\mathbf{M}$, $\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, $\mathbf{M}+\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, and Softmax matches that of the KV cache.
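The $l_n=29$ figure can be reproduced with simple arithmetic under an assumption about Tree Attention-Medusa's partial results that this section does not spell out: the four $l_n \times N_{tree}$ tensors $\mathbf{Q}\mathbf{K}^{\top}$, $\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, $\mathbf{M}+\frac{\mathbf{Q}\mathbf{K}^{\top}}{s_c}$, and Softmax each cost a round trip (2 IOs per element, consistent with the footnote on tiling and kernel fusion), and the dense mask $\mathbf{M}$ is read once. Under that assumption, partial-result IO is $9\,l_n N_{tree}$ versus $2\,d_{head} N_{tree}$ for the tree-aware KV cache; the helper names are illustrative.

```python
import math

def partial_result_io(l_n: int, n_tree: int, round_trips: int = 4, mask_reads: int = 1) -> int:
    """Assumed Medusa-style partial-result IO, in Table 4's O(1) units."""
    return (2 * round_trips + mask_reads) * l_n * n_tree

def kv_cache_io(d_head: int, n_tree: int) -> int:
    """Tree-aware KV cache IO, in O(1) units: K and V of every tree token once."""
    return 2 * d_head * n_tree

def break_even_leaves(d_head: int, round_trips: int = 4, mask_reads: int = 1) -> int:
    """Smallest l_n at which partial-result IO reaches KV cache IO (N_tree cancels)."""
    return math.ceil(2 * d_head / (2 * round_trips + mask_reads))

print(break_even_leaves(128))   # 29
```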

Table 4: IO complexity breakdown for various methods. $\mathcal{O}(1)$ denotes the IO cost of a single element of the tensor across all layers and heads, which is equivalent to $\#heads \times \#layer \times dtype\_size$. The best among all methods in the table is in red, while the (potential) worst is in blue. Query IO is omitted, as it is $\mathcal{O}(k\,l_n d_{head})$ for all methods, where $k$ is the number of QKV groups: for DeFT-Node, $k=\#node$; for DeFT-Flatten, $k=N_{tree}/b_s$, where $b_s$ is the KV block size; for the others, $k=1$. M in Tree Attention-M is short for Medusa [5], and S in Tree Attention-S is short for SpecInfer [27].
| Method | KV cache | $\mathbf{Q}\mathbf{K}^\top$ | $\frac{\mathbf{Q}\mathbf{K}^\top}{s_c}$ | Mask ($\mathbf{M}$) | $\mathbf{M}+\frac{\mathbf{Q}\mathbf{K}^\top}{s_c}$ | Softmax |
|---|---|---|---|---|---|---|
| Naive Attention | $\mathcal{O}(2d_{head}\sum_{i=1}^{l_n}N_i)$ | $\mathcal{O}(2\sum_{i=1}^{l_n}N_i)$ | $\mathcal{O}(2\sum_{i=1}^{l_n}N_i)$ | 0 | 0 | $\mathcal{O}(2\sum_{i=1}^{l_n}N_i)$ |
| Flash-Decoding | $\mathcal{O}(2d_{head}\sum_{i=1}^{l_n}N_i)$ | 0 | 0 | 0 | 0 | 0 |
| Tree Attention-M | $\mathcal{O}(2d_{head}N_{tree})$ | $\mathcal{O}(2l_nN_{tree})$ | $\mathcal{O}(2l_nN_{tree})$ | $\mathcal{O}(l_nN_{tree})$ | $\mathcal{O}(2l_nN_{tree})$ | $\mathcal{O}(2l_nN_{tree})$ |
| Tree Attention-S | $\mathcal{O}(2d_{head}N_{tree}l_n)$ | 0 | 0 | $\mathcal{O}(l_nN_{tree}/64)$ | 0 | 0 |
| DeFT-Node | $\mathcal{O}(2d_{head}N_{tree})$ | 0 | 0 | 0 | 0 | 0 |
| DeFT-Flatten | $\mathcal{O}(2d_{head}N_{tree})$ | 0 | 0 | $\mathcal{O}(N_{tree})$ | 0 | 0 |
Remark 3.6 (KV IO in SpecInfer).

Like DeFT, SpecInfer [27] employs a fused kernel for tree attention. However, SpecInfer does not share KV cache IO among queries: each query loads the entire KV cache of the tree independently, which incurs significant KV cache IO, as shown in Table 4.
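The KV cache column of Table 4 can be evaluated for a concrete tree to see how large this gap is. The sketch below assumes a hypothetical two-level tree (a 4,000-token prefix shared by 20 branches of 100 tokens each, so $l_n=20$, $N_{tree}=6{,}000$ and $\sum_i N_i=82{,}000$) and counts IO in the $\mathcal{O}(1)$ units of Table 4, treating the constants as exact.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TwoLevelTree:
    prefix_len: int          # tokens in the shared prefix node (hypothetical shape)
    suffix_lens: List[int]   # tokens in each leaf branch, one branch per query

    @property
    def l_n(self) -> int:
        return len(self.suffix_lens)

    @property
    def n_tree(self) -> int:
        # Every token of the tree counted exactly once.
        return self.prefix_len + sum(self.suffix_lens)

    def seq_lens(self) -> List[int]:
        # N_i: length of the root-to-leaf sequence attended to by query i.
        return [self.prefix_len + s for s in self.suffix_lens]


def kv_io(method: str, tree: TwoLevelTree, d_head: int = 128) -> int:
    """KV cache column of Table 4 (in O(1) units)."""
    if method in ("Naive Attention", "Flash-Decoding"):
        return 2 * d_head * sum(tree.seq_lens())      # prefix re-read by every query
    if method == "Tree Attention-S":
        return 2 * d_head * tree.n_tree * tree.l_n    # whole tree re-read by every query
    if method in ("Tree Attention-M", "DeFT-Node", "DeFT-Flatten"):
        return 2 * d_head * tree.n_tree               # each KV entry read once
    raise ValueError(f"unknown method: {method}")


tree = TwoLevelTree(prefix_len=4000, suffix_lens=[100] * 20)
for method in ("Flash-Decoding", "Tree Attention-S", "DeFT-Flatten"):
    print(method, kv_io(method, tree))
saving = 1 - kv_io("DeFT-Flatten", tree) / kv_io("Flash-Decoding", tree)
print(f"DeFT-Flatten reads {saving:.1%} less KV cache than Flash-Decoding")
```

For this shape, Tree Attention-S reads even more KV cache than Flash-Decoding, because every query reloads the whole tree rather than only its own root-to-leaf path.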

Remark 3.7 (Causal mask IO).

DeFT-Node splits the decoding tree by nodes and therefore needs no causal mask. To balance computation across the SMs of a GPU, DeFT-Flatten instead splits the decoding tree evenly into blocks and, inspired by SpecInfer, uses a mask design with minimal IO cost. This design significantly reduces mask IO compared to the dense mask used in Medusa, as shown in Table 4.
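The flatten-and-mask idea can be sketched as follows. The layout is hypothetical and assumed only for illustration, not DeFT's actual kernel metadata: tree tokens are laid out node by node and cut into equal-size blocks of $b_s$ tokens, and each token carries one integer whose $i$-th bit says whether query (leaf) $i$ attends to it. One integer per token keeps mask IO proportional to $N_{tree}$, in line with the DeFT-Flatten entry in Table 4, whereas a dense mask needs $l_nN_{tree}$ entries.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical tree encoding: node id -> (parent id or None for the root, number of tokens).
Nodes = Dict[str, Tuple[Optional[str], int]]


def leaf_nodes(nodes: Nodes) -> List[str]:
    parents = {parent for parent, _ in nodes.values() if parent is not None}
    return [n for n in nodes if n not in parents]


def flatten_with_masks(nodes: Nodes, block_size: int) -> Tuple[List[str], List[List[int]]]:
    """Cut all tree tokens into equal-size blocks, tagging each token with a query bitmask."""
    queries = leaf_nodes(nodes)
    # Bit i of node_mask[n] is set when node n lies on the root-to-leaf path of query i.
    node_mask = {n: 0 for n in nodes}
    for i, leaf in enumerate(queries):
        node: Optional[str] = leaf
        while node is not None:
            node_mask[node] |= 1 << i
            node = nodes[node][0]
    # Lay tokens out node by node, one mask integer per token.
    token_masks: List[int] = []
    for n, (_, length) in nodes.items():
        token_masks.extend([node_mask[n]] * length)
    # Every block except possibly the last has exactly block_size tokens,
    # regardless of how unevenly the tokens are spread across nodes.
    blocks = [token_masks[i:i + block_size] for i in range(0, len(token_masks), block_size)]
    return queries, blocks


nodes: Nodes = {"prefix": (None, 5), "a": ("prefix", 3), "b": ("prefix", 2)}
queries, blocks = flatten_with_masks(nodes, block_size=4)
print(queries)  # ['a', 'b']
print(blocks)   # [[3, 3, 3, 3], [3, 1, 1, 1], [2, 2]]
```

Prefix tokens carry mask `0b11` because both queries attend to them, while each branch's tokens carry only its own query's bit.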

4 Experiments

In this section, to demonstrate the effectiveness of DeFT under different tree topologies, we conduct experiments on three types of tree-based decoding tasks: (1) few-shot prompting [25], a typical case of tree-structured interaction with two levels, a prefix and several suffixes; (2) multi-step reasoning [42, 41, 11], characterized by tree-structured past KV with parallel queries; (3) speculative decoding [5, 27], characterized by sequential past KV with tree-structured queries.

4.1 Experimental Setup

Baselines.

Table 5: Comparison of baselines and DeFT. The attention kernel of each baseline is implemented to fit its own memory management. For a fair comparison, we therefore implement DeFT-Node and DeFT-Flatten for both paged [20] and unpaged memory management.
| Method | Flash-Decoding [15] | Tree Attention-Medusa [5] | Radix Attention [45] | DeFT |
|---|---|---|---|---|
| Memory | unpaged | unpaged | paged | unpaged/paged |
| Implementation | Triton | PyTorch | Triton | Triton |
Table 6: Workload generation. ToT-BFS is short for tree-of-thoughts [42] with breadth-first search. See Table 10 for more details.
| Task | Prompt Dataset | Decoding Tree Source | Decoding Tree Collection Method | Stopping Criteria |
|---|---|---|---|---|
| Few-shot prompting | APPS [12] | - | - | 400 iterations |
| Multi-step reasoning | 4 tasks in [4] | ToT-BFS in [4] | Reconstruct from interaction records with GPT-3.5 in [4] | End of task |
| Speculative decoding | APPS [12] | Medusa [5] | Record token tree shape and accepted token length per step | ∼1000 steps (max length = 6000) |

We evaluate DeFT on an NVIDIA A100 (80GB) GPU with the Llama3-8B model [35], comparing it against the SOTA attention algorithms for sequence-based and tree-based decoding listed in Table 5. We do not include the tree attention operator of SpecInfer [27] among our baselines because its kernel supports at most 64 tokens in the token tree (the decoding tree excluding the past sequence KV), which makes it unsuitable for tree-based decoding with tree-structured KV (cf. Appendix A.2).

Workloads generation.

To ensure that all baselines run identical workloads, we reconstruct decoding trees from real multi-step reasoning and speculative decoding tasks, as shown in Table 6. For multi-step reasoning, we include four tasks from [4]: (1) sorting 128 numbers (Sorting), (2) document merging (Document), (3) keyword counting (Keyword), and (4) set intersection (Set). During decoding, the tree is forced to branch and be pruned at specific iterations so that it takes the same shape as the original source decoding tree. See Appendix A.5 for workload generation details and analysis.

Table 7: Average attention latency (seconds) of each tree and its influence on end-to-end latency. $b$ denotes the tree width and $t$ the token tree size (i.e., the number of tree-structured queries). "Attention Speedup over the best attention" is the attention speedup of DeFT-Flatten over the baseline with the lowest attention latency (Tree Attention-Medusa in most cases). "Attention Speedup over the best wall-clock" is the attention speedup of DeFT-Flatten over the baseline with the lowest end-to-end latency (Radix Attention). "Speedup over the best wall-clock" is the end-to-end speedup of DeFT-Flatten over that same baseline (Radix Attention). ⋆ means out of memory on an A100 80GB, while ♠ means not supported/implemented. See Table 11 for end-to-end latency details.
| Memory | Method | Few-shot (b=20) | Few-shot (b=30) | Few-shot (b=50) | Reasoning (Sorting) | Reasoning (Document) | Reasoning (Keyword) | Reasoning (Set) | Speculative (t=32) | Speculative (t=64) | Speculative (t=128) | Speculative (t=256) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unpaged | Flash-Decoding | 43.49 | 66.10 | 110.09 | 160.67 | 105.80 | 12.14 | 19.96 | 340.09 | 692.88 | ⋆ | ⋆ |
| Unpaged | Tree Attention-Medusa | 3.93 | 7.51 | 9.57 | 38.64 | 29.10 | 2.62 | 3.96 | 18.05 | 26.31 | 41.10 | 68.28 |
| Paged | Radix Attention | 5.99 | 7.30 | 9.96 | 39.37 | 24.69 | 3.11 | 5.13 | 32.60 | 54.57 | 109.39 | 212.29 |
| Paged | DeFT-Node | 10.51 | 11.41 | ♠ | 42.96 | 33.29 | 6.16 | 9.58 | 50.82 | ♠ | ♠ | ♠ |
| Paged | DeFT-Flatten | 3.47 | 4.07 | 5.87 | 28.41 | 21.45 | 2.57 | 3.83 | 12.68 | 18.18 | 29.97 | 55.58 |
| | Attention Speedup over the best attention | 1.13× | 1.63× | 1.70× | 1.36× | 1.15× | 1.02× | 1.03× | 1.42× | 1.45× | 1.37× | 1.22× |
| | Attention Speedup over the best wall-clock | 1.73× | 1.63× | 1.70× | 1.39× | 1.15× | 1.21× | 1.34× | 2.57× | 3.00× | 3.64× | 3.82× |
| | Speedup over the best wall-clock | 1.24× | 1.28× | 1.33× | 1.10× | 1.03× | 1.03× | 1.05× | 1.43× | 1.70× | 2.22× | 2.52× |
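The attention-speedup rows of Table 7 follow directly from the latency rows. The sketch below recomputes them for two columns, using values copied from the table and the selection rules stated in the caption: the "best attention" baseline is whichever of Flash-Decoding, Tree Attention-Medusa, and Radix Attention has the lowest attention latency, and the "best wall-clock" baseline is Radix Attention.

```python
# Attention latencies (seconds) copied from Table 7 for two of its columns.
LATENCY = {
    "Sorting": {"Flash-Decoding": 160.67, "Tree Attention-Medusa": 38.64,
                "Radix Attention": 39.37, "DeFT-Flatten": 28.41},
    "t=32": {"Flash-Decoding": 340.09, "Tree Attention-Medusa": 18.05,
             "Radix Attention": 32.60, "DeFT-Flatten": 12.68},
}
BASELINES = ("Flash-Decoding", "Tree Attention-Medusa", "Radix Attention")
BEST_WALL_CLOCK_BASELINE = "Radix Attention"  # per the Table 7 caption


def attention_speedups(column: str):
    """Return (speedup over best-attention baseline, speedup over best wall-clock baseline)."""
    row = LATENCY[column]
    deft = row["DeFT-Flatten"]
    best_attention = min(row[b] for b in BASELINES)
    return (round(best_attention / deft, 2),
            round(row[BEST_WALL_CLOCK_BASELINE] / deft, 2))


for column in LATENCY:
    print(column, attention_speedups(column))
```

Both pairs, (1.36, 1.39) and (1.42, 2.57), match the corresponding entries of Table 7. The best-attention baseline is not always Tree Attention-Medusa: in the Document column, Radix Attention (24.69 s) is faster.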

4.2 Analysis of Memory Management and Bottleneck

As shown in Table 5, the kernels of different attention algorithms are implemented for different memory management schemes. To compare their wall-clock speedups fairly, we first analyze the influence of memory management and identify the bottleneck of the system.

A trade-off between memory storage and memory operation.

For tree-based decoding, we can store the KV cache of each branch of the decoding tree as a separate sequence. This is straightforward but does not share storage for the prefix's KV cache. Given the limited capacity of GPU memory, ignoring the tree structure in this way (i.e., not sharing KV storage among branches) significantly restricts the number of tokens in the decoding tree. Although storing the KV cache per node of the decoding tree greatly improves storage efficiency, many existing attention kernels are designed for sequence-based decoding [6, 15, 7]. To use these kernels, the KV caches of different nodes must be concatenated and materialized into a single sequence tensor, incurring significant data movement costs [20].
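The storage/data-movement trade-off can be made concrete by counting tokens. The sketch below uses a hypothetical tree (a 1,000-token prompt with three 50-token branches) and compares per-branch storage, per-node storage, and the number of token KV entries that must be copied to materialize one contiguous sequence per branch for a sequence-based kernel. Counts are per layer and head, for a single decoding step.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical decoding tree: node -> (parent or None for the root, tokens in the node).
Tree = Dict[str, Tuple[Optional[str], int]]
TREE: Tree = {
    "prompt": (None, 1000),
    "s1": ("prompt", 50),
    "s2": ("prompt", 50),
    "s3": ("prompt", 50),
}


def leaves(tree: Tree) -> List[str]:
    parents = {p for p, _ in tree.values() if p is not None}
    return [n for n in tree if n not in parents]


def branch_tokens(tree: Tree, leaf: str) -> int:
    """Number of tokens on the root-to-leaf path ending at `leaf`."""
    total, node = 0, leaf
    while node is not None:
        parent, length = tree[node]
        total += length
        node = parent
    return total


def tokens_stored_per_branch(tree: Tree) -> int:
    # Each branch keeps its own copy of every ancestor's KV cache.
    return sum(branch_tokens(tree, leaf) for leaf in leaves(tree))


def tokens_stored_per_node(tree: Tree) -> int:
    # Each node's KV cache is stored once and shared by all of its descendants.
    return sum(length for _, length in tree.values())


def tokens_copied_to_materialize(tree: Tree) -> int:
    # A sequence-based kernel needs one contiguous tensor per branch, so node-wise
    # storage must be concatenated into such tensors before the attention call.
    return tokens_stored_per_branch(tree)


print(tokens_stored_per_branch(TREE))      # 3150
print(tokens_stored_per_node(TREE))        # 1150
print(tokens_copied_to_materialize(TREE))  # 3150
```

Node-wise storage saves about 63% of memory in this example, but with a sequence-based kernel each step still pays for copying 3,150 token entries.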

Figure 4: Latency breakdown for speculative decoding with a token tree of 32 queries, whose tree topology is from Medusa [5]. U means unpaged memory management.

The benefits of paged memory for tree-based decoding.

To improve the efficiency of KV cache memory management, paged memory [20, 45] is the current mainstream approach. KV cache tensors are stored in a non-contiguous, paged layout that enables token-level reuse. Besides higher storage efficiency, paged memory management offers an additional benefit for tree-based decoding: because non-contiguous storage in a memory pool is addressed by pointers, the tree-structured KV does not need to be materialized into a single tensor before the attention kernel runs. Instead, we only need to record the memory pool addresses of each token's KV cache.
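A minimal sketch of pointer-based access, assuming a toy pool with one token per page (a simplification chosen for clarity; the classes and method names are hypothetical). The KV cache of each tree node is written once into free slots of a shared pool, and a query's root-to-leaf KV is described by a list of slot indices that an attention kernel could gather from, so no concatenated copy is created.

```python
from collections import deque
from typing import Dict, List, Optional


class PagedKVPool:
    """Toy KV pool with one token per page; a slot stands in for one token's K/V."""

    def __init__(self, num_slots: int):
        self.free = deque(range(num_slots))
        self.slots: Dict[int, str] = {}

    def alloc(self, kv_items: List[str]) -> List[int]:
        addrs = [self.free.popleft() for _ in kv_items]
        for addr, kv in zip(addrs, kv_items):
            self.slots[addr] = kv
        return addrs


class TreeKVIndex:
    """Hypothetical index recording, for every tree node, the pool addresses of its tokens."""

    def __init__(self, pool: PagedKVPool):
        self.pool = pool
        self.parent: Dict[str, Optional[str]] = {}
        self.addrs: Dict[str, List[int]] = {}

    def add_node(self, node: str, parent: Optional[str], kv_items: List[str]) -> None:
        self.parent[node] = parent
        self.addrs[node] = self.pool.alloc(kv_items)

    def addresses_for(self, leaf: str) -> List[int]:
        """Root-to-leaf slot indices: only addresses are collected, no KV is copied."""
        chain: List[List[int]] = []
        node: Optional[str] = leaf
        while node is not None:
            chain.append(self.addrs[node])
            node = self.parent[node]
        return [addr for node_addrs in reversed(chain) for addr in node_addrs]


pool = PagedKVPool(num_slots=16)
index = TreeKVIndex(pool)
index.add_node("prefix", None, ["kv_p0", "kv_p1", "kv_p2"])
index.add_node("a", "prefix", ["kv_a0"])
index.add_node("b", "prefix", ["kv_b0", "kv_b1"])
print(index.addresses_for("a"))  # [0, 1, 2, 3]
print(index.addresses_for("b"))  # [0, 1, 2, 4, 5]
```

Slots 0-2 appear in both address lists, but the prefix KV is stored only once in the pool.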

Bottlenecks and trade-offs.

We run DeFT and each baseline with the KV cache memory management (unpaged or paged) that matches its design. We visualize the latency breakdown into (1) KV cache management, (2) attention, and (3) other operations (including MLP calculation) in Figure 13a. We observe that with unpaged KV cache management in tree-based decoding, the bottleneck (69.5-83.4% of latency) is the data movement required to materialize the KV cache. When paged memory management is used instead, attention becomes the new bottleneck (50.5-60.0%), especially when the token tree is large.

4.3 End-to-end Behaviors: Latency and IOs.

We evaluate DeFT's performance on various tree-based decoding tasks by measuring end-to-end latency (Table 11 in Appendix A.6), attention latency (Table 7), and IO (Table 12 in Appendix A.6). These measurements show both how DeFT optimizes tree attention and how this optimization translates into wall-clock speedups.

For few-shot prompting tasks, we used a prompt with 4k tokens and performed 400 decoding iterations, achieving a 1.33× end-to-end speedup thanks to 1.70× faster attention calculation and an approximately 90% reduction in IO.

For speculative decoding tasks, DeFT-Flatten achieved up to a 2.52× wall-clock speedup, driven by up to a 3.82× attention speedup, because the entire token tree (all queries) can share the KV cache IO of the long prefix.

Figure 5: Comparison of split strategies DeFT-Node and DeFT-Flatten in the sorting task. Speedup ratio is the ratio of the per-iteration latency of DeFT-Node to that of DeFT-Flatten. Tree Node Len std is the standard deviation of the tree node lengths in each iteration.

For multi-step reasoning tasks, although DeFT-Flatten achieves up to a 1.36× attention speedup, the end-to-end acceleration is less pronounced for two reasons: (1) the tree width is small (only 10), so the benefit of reusing KV cache IO is limited; (2) the total number of tokens in the tree is low, so attention accounts for only about 30% of the end-to-end latency (compared to approximately 50-80% in speculative decoding). Our few-shot prompting experiments show that increasing the tree width from 10 to 50 raises the end-to-end speedup over 100 iterations from 1.2× to 1.5× (see Appendix A.6).
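The gap between attention and end-to-end speedups is what Amdahl's law predicts when only part of each step is accelerated. The sketch below applies it as a rough model (our simplification, not the paper's analysis), using the approximate attention shares quoted above and the attention speedups over Radix Attention from Table 7: 1.39× for Sorting and 3.82× for speculative decoding with t=256. For the latter we assume the upper end (80%) of the quoted 50-80% range.

```python
def end_to_end_speedup(attn_fraction: float, attn_speedup: float) -> float:
    """Amdahl's law: only the attention share of per-step latency is accelerated.

    attn_fraction: share of baseline end-to-end latency spent in attention (0..1).
    attn_speedup: speedup of the attention kernel alone.
    """
    if not 0.0 <= attn_fraction <= 1.0:
        raise ValueError("attn_fraction must lie in [0, 1]")
    return 1.0 / ((1.0 - attn_fraction) + attn_fraction / attn_speedup)


# Multi-step reasoning (Sorting): attention is about 30% of the end-to-end time.
print(round(end_to_end_speedup(0.3, 1.39), 2))  # 1.09
# Speculative decoding (t=256): assumed upper end, 80%, of the quoted range.
print(round(end_to_end_speedup(0.8, 3.82), 2))  # 2.44
```

These rough estimates land near the measured 1.10× and 2.52× in Table 7, which supports the explanation that a small attention share caps the end-to-end gain.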

4.4 Ablation Study

The influence of split strategy in DeFT.

We visualize the per-iteration latency of DeFT-Node and DeFT-Flatten for a tree in the sorting task in Figure 5, since the size and topology of the decoding tree change at each iteration. This comparison highlights how sensitive the two split strategies are to changes in the tree. We observe a strong positive correlation between the ratio of the per-iteration latencies of DeFT-Node and DeFT-Flatten (Speedup Ratio) and the dispersion of tree node lengths. This correlation arises because the performance of DeFT-Flatten remains relatively stable, whereas that of DeFT-Node depends more strongly on the tree topology. Overall, DeFT-Flatten provides a stable speedup of approximately 1.75× over DeFT-Node.
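A toy scheduling model illustrates why the node-based split is sensitive to node-length dispersion. The sketch below assumes, hypothetically, that each QKV group costs time proportional to its number of KV tokens, that groups are assigned greedily to a fixed number of SMs, and that latency equals the load of the busiest SM. Per-group overheads, the global reduction, and real kernel behavior are ignored, so the ratios are illustrative rather than predictions of the measured 1.75×.

```python
import heapq
import statistics
from typing import List


def makespan(group_sizes: List[int], num_sms: int) -> int:
    """Greedy list scheduling: the largest remaining group goes to the least-loaded SM."""
    loads = [0] * num_sms
    heapq.heapify(loads)
    for size in sorted(group_sizes, reverse=True):
        heapq.heappush(loads, heapq.heappop(loads) + size)
    return max(loads)


def node_split(node_lens: List[int]) -> List[int]:
    return list(node_lens)  # one QKV group per tree node


def flatten_split(node_lens: List[int], block_size: int) -> List[int]:
    full, rest = divmod(sum(node_lens), block_size)
    return [block_size] * full + ([rest] if rest else [])  # equal-size blocks


def speedup_ratio(node_lens: List[int], num_sms: int = 8, block_size: int = 16) -> float:
    """Modeled latency of the node-based split divided by that of the flattened split."""
    return (makespan(node_split(node_lens), num_sms)
            / makespan(flatten_split(node_lens, block_size), num_sms))


balanced = [64] * 8          # 512 tokens, every node the same length
skewed = [400] + [16] * 7    # 512 tokens, one long node
for lens in (balanced, skewed):
    print(f"std={statistics.pstdev(lens):.1f}  ratio={speedup_ratio(lens):.2f}")
```

With equal node lengths the two splits tie; with one long node, the node-based split leaves a single SM processing 400 tokens while the flattened split still caps every SM at 64.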

5 Discussion and Limitations

Transitioning to complex tree-structured interactions demands efficient systems. DeFT optimizes memory access in tree-based decoding by splitting and grouping KV cache entries according to the tree topology, achieving up to 3.82× faster attention calculation. A limitation of DeFT is that a substantial performance gain requires either a relatively large token tree (e.g., few-shot prompting with a long prompt) or enough queries (e.g., in speculative decoding) to share the KV cache IO of prefixes. In future work, we will evaluate DeFT on tasks with larger token trees, such as multi-step reasoning in coding or document analysis, to demonstrate its effectiveness in more diverse scenarios.

References

  • [1] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
  • [2] Peter Anderson, Basura Fernando, Mark Johnson, and Stephen Gould. Guided open vocabulary image captioning with constrained beam search. In Martha Palmer, Rebecca Hwa, and Sebastian Riedel, editors, Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pages 936–945, Copenhagen, Denmark, September 2017. Association for Computational Linguistics.
  • [3] Ben Athiwaratkun, Sujan Kumar Gonugondla, Sanjay Krishna Gouda, Haifeng Qian, Hantian Ding, Qing Sun, Jun Wang, Jiacheng Guo, Liangfu Chen, Parminder Bhatia, et al. Bifurcated attention for single-context large-batch sampling. arXiv preprint arXiv:2403.08845, 2024.
  • [4] Maciej Besta, Nils Blach, Ales Kubicek, Robert Gerstenberger, Lukas Gianinazzi, Joanna Gajda, Tomasz Lehmann, Michal Podstawski, Hubert Niewiadomski, Piotr Nyczyk, et al. Graph of thoughts: Solving elaborate problems with large language models. arXiv preprint arXiv:2308.09687, 2023.
  • [5] Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads. arXiv preprint arXiv:2401.10774, 2024.
  • [6] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • [7] Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. Flash-Decoding for long-context inference, 2023. PyTorch Blog.
  • [8] Sumanth Dathathri, Andrea Madotto, Janice Lan, Jane Hung, Eric Frank, Piero Molino, Jason Yosinski, and Rosanne Liu. Plug and play language models: A simple approach to controlled text generation. In International Conference on Learning Representations, 2019.
  • [9] In Gim, Guojun Chen, Seung-seob Lee, Nikhil Sarda, Anurag Khandelwal, and Lin Zhong. Prompt cache: Modular attention reuse for low-latency inference. arXiv preprint arXiv:2311.04934, 2023.
  • [10] Alex Graves. Sequence transduction with recurrent neural networks. arXiv preprint arXiv:1211.3711, 2012.
  • [11] Shibo Hao, Yi Gu, Haodi Ma, Joshua Jiahua Hong, Zhen Wang, Daisy Zhe Wang, and Zhiting Hu. Reasoning with language model is planning with world model. arXiv preprint arXiv:2305.14992, 2023.
  • [12] Dan Hendrycks, Steven Basart, Saurav Kadavath, Mantas Mazeika, Akul Arora, Ethan Guo, Collin Burns, Samir Puranik, Horace He, Dawn Song, et al. Measuring coding challenge competence with APPS. arXiv preprint arXiv:2105.09938, 2021.
  • [13] Chris Hokamp and Qun Liu. Lexically constrained decoding for sequence generation using grid beam search. In Regina Barzilay and Min-Yen Kan, editors, Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1535–1546, Vancouver, Canada, July 2017. Association for Computational Linguistics.
  • [14] Ari Holtzman, Jan Buys, Maxwell Forbes, Antoine Bosselut, David Golub, and Yejin Choi. Learning to write with cooperative discriminators. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1638–1649, 2018.
  • [15] Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Hanyu Dong, and Yu Wang. FlashDecoding++: Faster large language model inference on GPUs. arXiv preprint arXiv:2311.01282, 2023.
  • [16] Hugging Face. Text Generation Inference. https://github.com/huggingface/text-generation-inference. Accessed: 2024-05.
  • [17] Zhe Jia and Peter Van Sandt. Dissecting the Ampere GPU architecture via microbenchmarking. In GPU Technology Conference, 2021.
  • [18] Jordan Juravsky, Bradley Brown, Ryan Ehrlich, Daniel Y Fu, Christopher Ré, and Azalia Mirhoseini. Hydragen: High-throughput LLM inference with shared prefixes. arXiv preprint arXiv:2402.05099, 2024.
  • [19] Sehoon Kim, Coleman Hooper, Amir Gholami, Zhen Dong, Xiuyu Li, Sheng Shen, Michael W Mahoney, and Kurt Keutzer. SqueezeLLM: Dense-and-sparse quantization. arXiv preprint arXiv:2306.07629, 2023.
  • [20] Woosuk Kwon, Zhuohan Li, Siyuan Zhuang, Ying Sheng, Lianmin Zheng, Cody Hao Yu, Joseph E Gonzalez, Hao Zhang, and Ion Stoica. Efficient memory management for large language model serving with PagedAttention. arXiv preprint arXiv:2309.06180, 2023.
  • [21] Jiacheng Liu, Andrew Cohen, Ramakanth Pasunuru, Yejin Choi, Hannaneh Hajishirzi, and Asli Celikyilmaz. Making PPO even better: Value-guided Monte-Carlo tree search decoding. arXiv preprint arXiv:2309.15028, 2023.
  • [22] Mingdao Liu, Aohan Zeng, Bowen Wang, Peng Zhang, Jie Tang, and Yuxiao Dong. APAR: LLMs can do auto-parallel auto-regressive decoding. arXiv preprint arXiv:2401.06761, 2024.
  • [23] Ximing Lu, Sean Welleck, Peter West, Liwei Jiang, Jungo Kasai, Daniel Khashabi, Ronan Le Bras, Lianhui Qin, Youngjae Yu, Rowan Zellers, et al. NeuroLogic A*esque decoding: Constrained text generation with lookahead heuristics. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 780–799, 2022.
  • [24] Ximing Lu, Peter West, Rowan Zellers, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. NeuroLogic decoding: (Un)supervised neural text generation with predicate logic constraints. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 4288–4299, 2021.
  • [25] Ben Mann, N Ryder, M Subbiah, J Kaplan, P Dhariwal, A Neelakantan, P Shyam, G Sastry, A Askell, S Agarwal, et al. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.
  • [26] Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harrison Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, ..., Andrew N. Carr, Jan Leike, Joshua Achiam, Vedant Misra, Evan Morikawa, Alec Radford, Matthew Knight, Miles Brundage, Mira Murati, Katie Mayer, Peter Welinder, Bob McGrew, Dario Amodei, Sam McCandlish, Ilya Sutskever, and Wojciech Zaremba, 2021.
  • [27] Xupeng Miao, Gabriele Oliaro, Zhihao Zhang, Xinhao Cheng, Zeyu Wang, Rae Ying Yee Wong, Zhuoming Chen, Daiyaan Arfeen, Reyna Abhyankar, and Zhihao Jia. SpecInfer: Accelerating generative LLM serving with speculative inference and token tree verification. arXiv preprint arXiv:2305.09781, 2023.
  • [28] Xuefei Ning, Zinan Lin, Zixuan Zhou, Huazhong Yang, and Yu Wang. Skeleton-of-thought: Large language models can do parallel decoding. arXiv preprint arXiv:2307.15337, 2023.
  • [29] NVIDIA. TensorRT-LLM. https://github.com/NVIDIA/TensorRT-LLM. Accessed: 2024-05.
  • [30] Matt Post and David Vilar. Fast lexically constrained decoding with dynamic beam allocation for neural machine translation. In Marilyn Walker, Heng Ji, and Amanda Stent, editors, Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), pages 1314–1324, New Orleans, Louisiana, June 2018. Association for Computational Linguistics.
  • [31] Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M Smith, et al. Recipes for building an open-domain chatbot. arXiv preprint arXiv:2004.13637, 2020.
  • [32] Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150, 2019.
  • [33] Philippe Tillet, Hsiang-Tsung Kung, and David Cox. Triton: an intermediate language and compiler for tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages, pages 10–19, 2019.
  • [34] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
  • [35] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • [36] Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins. Solving math word problems with process- and outcome-based feedback. arXiv preprint arXiv:2211.14275, 2022.
  • [37] Xuezhi Wang, Jason Wei, Dale Schuurmans, Quoc Le, Ed Chi, Sharan Narang, Aakanksha Chowdhery, and Denny Zhou. Self-consistency improves chain of thought reasoning in language models. arXiv preprint arXiv:2203.11171, 2022.
  • [38] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le, Denny Zhou, et al. Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35:24824–24837, 2022.
  • [39] Sean Welleck, Jiacheng Liu, Ximing Lu, Hannaneh Hajishirzi, and Yejin Choi. NaturalProver: Grounded mathematical proof generation with language models. Advances in Neural Information Processing Systems, 35:4913–4927, 2022.
  • [40] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, et al. HuggingFace's Transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • [41] Yuxi Xie, Kenji Kawaguchi, Yiran Zhao, James Xu Zhao, Min-Yen Kan, Junxian He, and Michael Xie. Self-evaluation guided beam search for reasoning. Advances in Neural Information Processing Systems, 36, 2024.
  • [42] Shunyu Yao, Dian Yu, Jeffrey Zhao, Izhak Shafran, Thomas L Griffiths, Yuan Cao, and Karthik Narasimhan. Tree of thoughts: Deliberate problem solving with large language models. arXiv preprint arXiv:2305.10601, 2023.
  • [43] Lu Ye, Ze Tao, Yong Huang, and Yang Li. ChunkAttention: Efficient self-attention with prefix-aware KV cache and two-phase partition. arXiv preprint arXiv:2402.15220, 2024.
  • [44] Yao Zhao, Zhitian Xie, Chenyi Zhuang, and Jinjie Gu. Lookahead: An inference acceleration framework for large language model with lossless generation accuracy. arXiv preprint arXiv:2312.12728, 2023.
  • [45] Lianmin Zheng, Liangsheng Yin, Zhiqiang Xie, Jeff Huang, Chuyue Sun, Cody Hao Yu, Shiyi Cao, Christos Kozyrakis, Ion Stoica, Joseph E Gonzalez, et al. Efficiently programming large language models using SGLang. arXiv preprint arXiv:2312.07104, 2023.

Appendix A Appendix

A.1 Components of System Support for DeFT

The left part of Figure 6 shows how the different components coordinate to enable efficient and flexible tree-based decoding. The functions of DeFT's system components are as follows:

  1. Branch Controller: It forces the tree decoding process to follow a user-defined function (e.g., branching into two children every 3 iterations, as in the example on the right of Figure 6). Tree-search-based algorithms can be applied here using the decoding tree's topology information.

  2. Sequence Tree Manager: It maintains the topology of the decoding tree based on the tree operations and tokens from the Branch Controller. Tree operations such as pruning and branching are executed by the Tree Handler in this component. The Branch Result Storage records the token generation results of all branches in the decoding tree and outputs them when decoding stops.

  3. KV cache Manager: It maintains the KV cache in a tree structure. It keeps a map between sequence IDs in the decoding tree and KV cache indices, which is updated according to KV operations from the Sequence Tree Manager (e.g., when a node is pruned from the decoding tree, its KV space is evicted by a Remove operation). We provide both paged [20] and unpaged memory management in this component to fit different attention kernels.

  4. Model Interface: It passes input metadata to the DeFT Attention kernel and the MLP module, then returns the logits and the memory pointers of the updated KV cache.
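The Branch Controller's user-defined function can be pictured as a policy that, at each iteration, decides how many children each current leaf should spawn. The interface below is hypothetical (the text does not specify the actual API) and omits token generation and pruning; it reproduces the example of branching every leaf into two children every 3 iterations.

```python
from typing import Callable, Dict, List

# Hypothetical policy signature: (iteration, current leaves) -> {leaf: number of children}.
BranchPolicy = Callable[[int, List[str]], Dict[str, int]]


def branch_every(k: int, children: int) -> BranchPolicy:
    """Branch every leaf into `children` new leaves on every k-th iteration."""
    def policy(iteration: int, leaves: List[str]) -> Dict[str, int]:
        if iteration % k == 0:
            return {leaf: children for leaf in leaves}
        return {}
    return policy


def run_branch_controller(policy: BranchPolicy, num_iters: int) -> Dict[str, List[str]]:
    """Apply the policy for num_iters iterations and return the tree as node -> children."""
    tree: Dict[str, List[str]] = {"root": []}
    leaves = ["root"]
    for iteration in range(1, num_iters + 1):
        # One token would be decoded for every leaf here (omitted in this sketch).
        plan = policy(iteration, leaves)
        next_leaves: List[str] = []
        for leaf in leaves:
            num_children = plan.get(leaf, 0)
            if num_children == 0:
                next_leaves.append(leaf)  # the leaf keeps growing as a sequence
                continue
            kids = [f"{leaf}.{j}" for j in range(num_children)]
            tree[leaf] = kids
            tree.update({kid: [] for kid in kids})
            next_leaves.extend(kids)
        leaves = next_leaves
    return tree


tree = run_branch_controller(branch_every(k=3, children=2), num_iters=6)
print(sorted(n for n, kids in tree.items() if not kids))
```

After six iterations the tree has branched twice, giving four leaves (`root.0.0`, `root.0.1`, `root.1.0`, `root.1.1`). A tree-search algorithm would replace `branch_every` with a policy that inspects the tree.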

Figure 6: Illustration of DeFT. (Left) System overview. (Right) The data flow using a decoding tree example.

The right part of Figure 6 further shows the key data flow of the system through a decoding tree example: input metadata is extracted by the three components described above and, after the QKV Preparation Phase discussed in Section 3.3, loaded from HBM to shared memory group by group. The QKV groups are then processed by the DeFT Attention Kernel in the Attention Calculation Phase of DeFT. See Appendix A.4 for details of the techniques in these two phases.

A.2 Discussion of Tree-based Decoding

(a) (Left) Sequence KV with queries in a tree for parallel decoding [27, 5], where a causal mask is applied to record the causal information among queries in a tree of tokens. (Right) Tree KV with parallel queries for shared prefixes in multi-step reasoning.
(b) Bit mask in SpecInfer [27], which records the causal relationships between query tokens in a tree structure. The decoding tree is shown in the left part of Figure 7a.
Figure 7: Discussion of tree-based decoding with tree queries [27] and tree KV.

Tree-based decoding can involve a tree-structured KV cache whose storage is aware of shared prefixes [45], tree-structured queries in parallel/speculative decoding [27, 5], or both, as shown in Figure 7. A general tree-based decoding scheme that handles both tree KV and tree queries can reduce the redundancy (e.g., IO, storage, and computation) of shared prefixes while increasing the number of tokens generated per decoding iteration.

Existing inference frameworks that focus on tree-based decoding efficiency [45, 9] primarily aim to: (1) reduce memory footprints [45] to enable larger batch sizes for higher throughput; (2) reuse the prompt cache [9] to avoid recomputing the KV cache, yielding a faster time-to-first-token (TTFT). However, their designs do not specifically target reducing the wall-clock time of the entire decoding process. We observe that the tree structure of LLM inference offers opportunities to speed up the decoding itself.

Analysis of speedup potential in tree-based decoding.

In tree-based decoding, the KV cache and queries can both be structured as a tree. Beyond storing the KV cache in a tree, we can load QKV with awareness of the tree topology during attention calculation to minimize the expensive IO between HBM and the on-chip shared memory of GPUs. We illustrate this with two case studies of complex scenarios with tree-structured interactions: (1) multi-step reasoning [42, 41]; (2) speculative decoding [5, 27].

Figure 8: Analysis for two case studies of tree-based decoding. (Left) Multi-step reasoning. (Right) Speculative decoding. Blue boxes denote past KV cache that can be shared in storage and memory access during tree attention calculation, while yellow boxes denote the KV cache of the generated context.

Case study 1: multi-step reasoning.

As shown in the left part of Figure 8, the process of multi-step reasoning [11, 42, 4] can be summarized in three phases: (1) Thought Generation: generate $k$ candidates for the next thought step based on a generation prompt $P_g$ and previous steps $S$; (2) Thought Evaluation: given a frontier of thoughts, an LLM acting as a state evaluator assesses the previous thoughts $S$ based on an evaluation prompt $P_e$ with respect to solving the problem. This assessment serves as a heuristic for the search algorithm, guiding which states to pursue further and in what order to explore them; (3) Tree Search-based Expansion: apply a search algorithm [23, 21, 41] to explore the search space, which determines the future tree topology. In both (1) and (2), the KV cache IO of $P_g$/$P_e$ and $S$ can be shared during tree attention calculation.

Case study 2: speculative decoding.

As shown in the right part of Figure 8, the process of speculative decoding [5, 27] can be summarized in two phases: (1) Token Tree Generation: multiple small draft models [27] or fine-tuned heads [5] generate multiple token sequences based on prompt $P$, which are then merged into a speculated token tree $T_t$; this step is very fast (e.g., 1% of the time overhead in SpecInfer [27]); (2) Token Verification: the tree-structured token candidates $T_t$ are verified against the LLM's output, and tree attention calculation is the bottleneck of this phase [27]. In (2), the KV cache IO for $P$ and $S$ can be shared during tree attention calculation.

Why are existing tree-attention algorithms not enough?

Existing tree-attention algorithms are either inefficient in memory access [5, 27] or unsuitable for general tree-based decoding with more than 64 tokens in the token tree [27].

  • In SpecInfer [27], as shown in Figure 7b, a bit mask records the causal information among the queries of a token tree. Each token $t_i$ in the queries has a 64-bit integer as its bit mask, where the $j$-th bit indicates the causal relationship between the query of $t_i$ and the KV cache of $t_j$. This mask design greatly reduces IO, but it limits the token tree to at most 64 tokens, which is impractical for scenarios with tree-structured KV cache. Moreover, it is not IO-aware with respect to the KV cache, because it loads the KV cache of the entire tree for each query.

  • Medusa [5] is suitable for general tree-based decoding, but it is not hardware-efficient because of the significant IO of a dense causal mask and of partial results during attention calculation (e.g., Softmax).
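To make the 64-token limit concrete, here is a minimal Python sketch of a SpecInfer-style bit causal mask. It assumes the token tree is given as a parent array in topological order; the helpers `build_bit_masks` and `can_attend` are hypothetical and not SpecInfer's actual API. Python integers are unbounded, so the 64-token cap is enforced explicitly to mirror a fixed 64-bit integer.

```python
def build_bit_masks(parents):
    """Return one integer mask per token-tree node.

    Bit j of masks[i] is set iff token j is token i itself or one of its
    ancestors, i.e. the query of t_i may attend to the KV of t_j.
    parents[i] is the parent index of token i (-1 for the root); parents are
    assumed to appear before their children.
    """
    if len(parents) > 64:
        # A single 64-bit integer cannot describe more than 64 tokens.
        raise ValueError("bit causal mask supports at most 64 tree tokens")
    masks = []
    for i, p in enumerate(parents):
        assert p < i, "parents must precede children"
        mask = 1 << i                  # a token always sees itself
        if p >= 0:
            mask |= masks[p]           # ...and everything its parent sees
        masks.append(mask)
    return masks


def can_attend(masks, i, j):
    """True if the query of token i is causally allowed to read the KV of token j."""
    return ((masks[i] >> j) & 1) == 1


# Token tree: 0 is the root, 1 and 2 are its children, 3 is a child of 1.
masks = build_bit_masks([-1, 0, 0, 1])
print([bin(m) for m in masks])                           # ['0b1', '0b11', '0b101', '0b1011']
print(can_attend(masks, 3, 1), can_attend(masks, 3, 2))  # True False
```

Each query needs only one 8-byte word of mask, which is why the IO is small; the price is the hard cap on tree size.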

A.3 Discussion of Concurrent Works

There are several concurrent works [3, 43, 18] on attention algorithm design for single-context large-batch sampling, where the goal is to generate multiple sequences from a single context (e.g., a system prompt or few-shot examples). This is a special case of tree-based decoding with a depth of 1. Their algorithms are designed around this property, so they cannot efficiently compute attention for trees with more than two levels (i.e., multiple levels of prefixes).

Insights and techniques in common.

Both the concurrent works and DeFT share the insight that memory access is the bottleneck of LLM inference, and both decompose attention across subsequences to reduce memory access for the prefix KV: (1) calculate attention $A_p$ and $A_s$ over the prefix and suffixes, respectively; (2) obtain the final attention by online softmax merging [6, 7] of $A_p$ and $A_s$. The correctness proof is as follows:

  • Suppose we have a key tensor $K \in R^{(l_{kv}, d)}$, a value tensor $V \in R^{(l_{kv}, d)}$, and a query tensor $Q \in R^{(l_q, d)}$. Consider the general case in which $K$ and $V$ are partitioned along the sequence (row) dimension into a prefix part and a suffix part: $K = K_p \parallel K_s$ and $V = V_p \parallel V_s$, where $\parallel$ denotes concatenation along the row axis.

  • We calculate the attention $A_p$ and $A_s$ over the prefix and suffixes, respectively:

    $$A_p = \langle Q, K_p, V_p \rangle, \quad A_s = \langle Q, K_s, V_s \rangle,$$

    where

    $$\langle q, k, v \rangle = \operatorname{Softmax}\left(\frac{qk^T}{\sqrt{d}}\right) v.$$
  • We use LogSumExp (LSE) as the weight for merging $A_p$ and $A_s$, defined as $\operatorname{LSE}(q, k) = \log\left(\sum \exp\left(\frac{qk^T}{\sqrt{d}}\right)\right)$, where the sum runs over the KV (column) dimension, giving one value per query row.

  • We then have

    $$\langle Q, K, V \rangle = \frac{A_p e^{\operatorname{LSE}(Q, K_p)} + A_s e^{\operatorname{LSE}(Q, K_s)}}{e^{\operatorname{LSE}(Q, K_p)} + e^{\operatorname{LSE}(Q, K_s)}}. \tag{2}$$
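Equation 2 can be checked numerically. The NumPy sketch below is illustrative only: it splits $K$ and $V$ into a prefix and a suffix, computes each partial attention together with its LSE, and merges them. As an assumption beyond the text, it subtracts the larger LSE before exponentiating, which leaves Equation 2 unchanged (numerator and denominator are scaled by the same factor) but avoids overflow.

```python
import numpy as np


def attention_with_lse(q, k, v):
    """Return <q, k, v> = Softmax(q k^T / sqrt(d)) v and LSE(q, k) per query row."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1, keepdims=True)        # row max for numerical stability
    p = np.exp(s - m)
    denom = p.sum(axis=-1, keepdims=True)
    return (p / denom) @ v, (m + np.log(denom))[:, 0]


def merge(a_p, lse_p, a_s, lse_s):
    """Equation 2, with the larger LSE factored out of numerator and denominator."""
    m = np.maximum(lse_p, lse_s)
    w_p = np.exp(lse_p - m)[:, None]
    w_s = np.exp(lse_s - m)[:, None]
    return (a_p * w_p + a_s * w_s) / (w_p + w_s)


rng = np.random.default_rng(0)
l_q, l_p, l_s, d = 3, 7, 5, 16               # queries, prefix length, suffix length, head dim
Q = rng.standard_normal((l_q, d))
K = rng.standard_normal((l_p + l_s, d))      # K = K_p || K_s along rows
V = rng.standard_normal((l_p + l_s, d))

A_full, lse_full = attention_with_lse(Q, K, V)
A_p, lse_p = attention_with_lse(Q, K[:l_p], V[:l_p])
A_s, lse_s = attention_with_lse(Q, K[l_p:], V[l_p:])
print(np.allclose(A_full, merge(A_p, lse_p, A_s, lse_s)))  # True
```

The LSE of the full KV is itself the log-sum of the two partial LSEs, which is what lets the merge be applied repeatedly to more than two parts.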
Table 8: Comparison among DeFT and concurrent works in single-context large-batch sampling scenarios [3, 43, 18]. More ⋆ means more balanced workloads after tree split, which also indicates how insensitive the acceleration is to the tree topology.

| Method | Chunk-Attention [43] | Hydragen [18] | Bifurcated-Attention [3] | DeFT-Node | DeFT-Flatten |
|---|---|---|---|---|---|
| IO-aware levels | 2 (depth ≤ 1) | 2 (depth ≤ 1) | 2 (depth ≤ 1) | all (every depth) | all (every depth) |
| Tree KV split granularity | by node first, then by block | by tree depth | by tree depth | by tree node | flatten tree, then by block |
| Load-balanced level | ⋆⋆ | ⋆ | ⋆ | ⋆ | ⋆⋆⋆ |
| Goal metrics | throughput | throughput | latency | latency | latency |

Comparison of differences.

Existing works on single-context large-batch sampling are not hardware-efficient for general tree-based decoding, for two reasons, as shown in Table 8:

  • They are designed for decoding trees with only two levels: prefixes at the root and suffixes at depth 1. For decoding trees with multiple levels of prefixes, their algorithms can only reduce the IO of the prompt at the root. However, in scenarios such as multi-step reasoning [42, 4, 11], non-root prefixes can also be very long (e.g., thousands of tokens), and the IO of their KV cache is not reused. DeFT reuses the KV IO of all non-leaf prefixes in a general decoding tree, providing greater acceleration potential.

  • They do not address the unbalanced workload problem in tree-based decoding. Nodes in a decoding tree can vary significantly in token length, so it is crucial to split the tree and group QKV in a way that balances the computation across QKV groups. Dividing by depth alone is insufficient.
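The imbalance can be illustrated with a toy count. The Python sketch below uses the number of KV tokens per QKV group as a rough proxy for work, and compares one group per node, one group per depth, and flattening the tree before cutting fixed-size blocks (the granularities listed in Table 8). The tree (a 4000-token prompt with ten 100-token thoughts), the block size of 500, and all function names are hypothetical.

```python
def split_by_node(node_lengths):
    # One KV group per tree node.
    return list(node_lengths)


def split_by_depth(node_lengths, node_depths):
    # One KV group per tree depth: all nodes at the same depth are merged.
    loads = {}
    for length, depth in zip(node_lengths, node_depths):
        loads[depth] = loads.get(depth, 0) + length
    return [loads[dep] for dep in sorted(loads)]


def split_flatten_blocks(node_lengths, block_size):
    # Concatenate every node's KV into one flat sequence, then cut fixed-size blocks.
    total = sum(node_lengths)
    return [min(block_size, total - start) for start in range(0, total, block_size)]


def imbalance(loads):
    # Largest group divided by the mean group (1.0 means perfectly balanced).
    return max(loads) / (sum(loads) / len(loads))


lengths = [4000] + [100] * 10     # root prompt + ten short thoughts
depths = [0] + [1] * 10
for name, loads in [
    ("by node", split_by_node(lengths)),
    ("by depth", split_by_depth(lengths, depths)),
    ("flatten + blocks", split_flatten_blocks(lengths, 500)),
]:
    print(f"{name:17s} groups={len(loads):2d} imbalance={imbalance(loads):.2f}")
```

In this toy tree, the slowest group is 8.8× the mean when splitting by node and 1.6× when splitting by depth, while flat blocks are perfectly even. Real QKV-group cost also depends on how many queries share each block, which this count ignores.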

A.4 Discussion of Techniques in Efficient Attention Algorithm Design

Figure 9: Operations of Tree Attention-Medusa [5]. No kernel fusion or tiling strategy is applied, which introduces significant IO of partial results such as $\mathbf{Q}\mathbf{K}^\top$, DCM, and Softmax between GPU global memory and on-chip shared memory.
Table 9: Technique list of DeFT. What we propose is in red. The details of the first four techniques are in Section 3.3, while the details of the remaining techniques are discussed in this section.

| Technique | Goal |
|---|---|
| KV-guided Grouping with Tree Split | High GPU utilization and minimal KV cache IO between HBM and shared memory. |
| DeFT-Node Tree Split | High GPU utilization and simple tree attention calculation. |
| DeFT-Flatten Tree Split | High GPU utilization and balanced attention calculation. |
| Bit Causal Mask [27] | Record causal information of tokens in the decoding tree with little IO cost. |
| Kernel Fusion [6, 7] | Reduce IO of partial results (e.g., $\mathbf{Q}\mathbf{K}^T$, mask $M$, and Softmax). |
| Tiling [6, 7] | Enable attention calculation within the limited size of GPU shared memory. |
| Tree-topology Aware Global Reduction | Obtain the correct tree attention of the entire decoding tree. |

In this subsection, we summarize and discuss common techniques in existing designs of efficient attention algorithms and kernels: (1) Kernel Fusion with Tiling strategy [6, 15, 27]; (2) Tree-topology Aware Causal Mask [27, 5]; (3) KV Split with Global Reduction [15]. We then explain the design details of the DeFT Attention Kernel, whose techniques are listed in Table 9.

Kernel Fusion is a common IO-reduction technique: if multiple operations are performed on the same input, it is more efficient to load the input from HBM once rather than once per operation. The same principle applies when transferring output from shared memory to HBM. To fuse all attention operations into one GPU kernel despite the limited size of shared memory, we further adopt the commonly used Tiling strategy [6, 7]: queries and KV cache within each QKV group are split into small blocks so that attention is computed within shared memory without materializing the attention matrix in HBM, and the softmax reduction is performed incrementally, following the formulation in Equation 2, to reconstruct the attention.
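As a rough illustration of tiling (not DeFT's CUDA kernel), the NumPy sketch below streams the KV cache of one QKV group in blocks and keeps only a running max, running sum, and unnormalized output per query, the same incremental softmax reduction that Equation 2 describes. Shared memory, HBM traffic, and kernel fusion themselves are not modeled; `tiled_attention` and `block_size` are illustrative names.

```python
import numpy as np


def tiled_attention(q, k, v, block_size):
    """Softmax(q k^T / sqrt(d)) v computed over K/V tiles with an online softmax.

    Only a (len(q), block_size) score tile exists at any time, mimicking how a
    fused kernel avoids materializing the full attention matrix.
    """
    d = q.shape[-1]
    m = np.full(q.shape[0], -np.inf)            # running max score per query row
    l = np.zeros(q.shape[0])                    # running sum of exp(score - m)
    acc = np.zeros((q.shape[0], v.shape[-1]))   # running unnormalized output
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = q @ kb.T / np.sqrt(d)               # scores for this tile only
        m_new = np.maximum(m, s.max(axis=-1))
        scale = np.exp(m - m_new)               # rescale earlier partial sums
        p = np.exp(s - m_new[:, None])
        l = l * scale + p.sum(axis=-1)
        acc = acc * scale[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]


def reference_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 32))
K = rng.standard_normal((10, 32))   # 10 KV tokens -> tiles of 4, 4, 2
V = rng.standard_normal((10, 32))
print(np.allclose(tiled_attention(Q, K, V, block_size=4), reference_attention(Q, K, V)))
```

On the first tile the running max is `-inf`, so `scale` is 0 and the empty accumulators are simply replaced.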

Remark A.1 (Importance of tiling and fused kernel during Attention Calculation Phase).

Methods in this phase can be roughly divided into two categories: (1) without tiling and kernel fusion: Tree Attention in Medusa [5], which introduces significant IO for partial results (i.e., $\mathbf{Q}\mathbf{K}^\top$ and Softmax), as shown in Figure 9; (2) with tiling and a fused kernel: Flash Decoding [7], Tree Attention in SpecInfer [27], and our DeFT.

Figure 10: Overview of the two stages in the DeFT Attention Kernel (DeFT-Node as an example). Stage 1: calculate partial attentions. Based on the QKV grouping results of the KV-Guided Grouping Strategy with Tree Split described above, each QKV group ($G_i$) is allocated to a thread block for Flash Attention [6] calculation with the common Kernel Fusion and Tiling strategy. Similar to Flash-Decoding [7], we obtain not only the partial attention ($PA_i$) but also the "LogSumExp" ($LSE_i$), which serves as a weight for the next stage's reduction. Stage 2: global reduction. Upon receiving $PA_i$ and $LSE_i$ for each QKV group $G_i$, DeFT performs a Tree-Topology-Aware Global Reduction ($DeFT\_reduction$). Guided by the tree topology among the sequence nodes of KV in the decoding tree, DeFT logically remaps the partial attentions and LogSumExps to obtain the correct final attention for each query after reduction. The decoding tree is the same as the one on the left of Figure 3. $SM_i$ denotes streaming multiprocessor $i$ of the GPU.

The Tree-topology Aware Causal Mask (Causal Mask for short) is introduced in speculative decoding works [27, 5] to facilitate the calculation of attention for all queries within a decoding tree using a single GPU kernel. It achieves this by recording the causal relationships among queries and KV cache in the decoding tree. As depicted in Figure 7, while originally designed for tree-based decoding with KV cache for a sequence of tokens and tree-structured queries, the Causal Mask can also be adapted to tree decoding with tree-structured KV cache and parallel queries—a configuration targeted by DeFT to enhance efficiency.

Remark A.2 (The effects of introducing a causal mask).

A causal mask introduces two kinds of redundancy:

  • Memory Access. Medusa [5] materializes a dense causal mask (DCM) in HBM to record the causal information between the $n_q$ query tokens and the $n_{kv}$ KV cache tokens, incurring a significant IO cost to load this $n_q \times n_{kv}$ mask into shared memory. SpecInfer [27] uses a 64-bit integer as a bit causal mask (BCM) to record the causal information among up to 64 tokens, which incurs minimal IO cost from HBM to shared memory but cannot handle decoding trees with more than 64 tokens. Details of the bit mask design in SpecInfer are discussed in Appendix A.2.

  • Computation. Besides the cost of generating the causal mask itself, there is redundant computation: many entries of $\mathbf{Q}\mathbf{K}^\top$ are masked out and never used. Both Medusa and SpecInfer have this issue.
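Both kinds of redundancy can be estimated with simple arithmetic. The Python sketch below assumes a few-shot style tree (one prompt, `width` branches, one query per branch) and a 1-byte-per-entry dense mask per head and layer; the prompt length of 4000 and suffix length of 200 are illustrative numbers, not measurements, and `causal_mask_overheads` is a hypothetical helper.

```python
def causal_mask_overheads(prefix_len, width, suffix_len, bytes_per_entry=1):
    """Idealized dense-causal-mask costs for one decoding iteration.

    Tree: one shared prompt of `prefix_len` tokens and `width` branches of
    `suffix_len` tokens each, with one query per branch. All queries are scored
    against the KV cache of the whole tree, as a single-kernel dense-mask
    approach does. `bytes_per_entry` (1 = boolean mask) is an assumption.
    """
    n_q = width
    n_kv = prefix_len + width * suffix_len
    mask_bytes = n_q * n_kv * bytes_per_entry
    used = n_q * (prefix_len + suffix_len)      # entries each query really needs
    masked_out_fraction = 1 - used / (n_q * n_kv)
    return mask_bytes, masked_out_fraction


for w in (10, 20, 50):
    mask_bytes, wasted = causal_mask_overheads(4000, w, 200)
    print(f"w={w:2d}: mask={mask_bytes / 1e3:.0f} KB, masked-out QK^T entries={wasted:.1%}")
```

Because $n_q \times n_{kv} = w(P + ws)$ for prompt length $P$ and suffix length $s$, the dense mask grows roughly quadratically in the width $w$, and so does the share of wasted $\mathbf{Q}\mathbf{K}^\top$ work.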

DeFT-Node (Appendix A.7) does not require a causal mask, so it has no IO or computation redundancy caused by masking. DeFT-Flatten (Appendix A.8) adopts a bit causal mask inspired by SpecInfer [27] to minimize the IO of the causal mask. Details of the bit mask design are shown on the left of Figure 3.

KV split was introduced to improve GPU utilization in sequence-based decoding [15], which is necessary when parallelism is limited by a small batch size in long-context scenarios. Flash-Decoding splits the long KV cache and groups QKV by query; these groups are then allocated to different streaming multiprocessors (SMs) of the GPU, which compute partial attention via Flash Attention [6].

(a) Left: Illustration of the DeFT-Node Attention Kernel with two stages. Right: Global reduction kernel called in DeFT stage 2, illustrated in Figure 11b. QKV groups $G_0$, $G_1$, and $G_2$ are the DeFT QKV groups from Figure 3.
(b) Stage 2 of DeFT: Global Reduction. Based on the tree topology in Figure 3, we group the LogSumExp and Partial Attention by query, then call the global reduction kernel on the right of Figure 11a to obtain the final attention.
Figure 11: Detailed attention operations of the DeFT kernel (DeFT-Node as an example), based on the same decoding tree as in Figure 3.

To obtain the accurate final attention, partial attentions from QKV groups with identical queries need to be grouped for Global Reduction.

Similarly, DeFT splits the decoding tree into different QKV groups for high GPU utilization; this is the KV-Guided Grouping Strategy with Tree Split that we propose in subsection 3.3, as illustrated in the bottom right part of Section 3. To obtain the correct tree attention, DeFT also requires a global reduction. However, the global reduction in Flash-Decoding is designed for sequence-based decoding and is unaware of the tree topology needed for global reduction in tree-based decoding. Therefore, we propose Tree-Topology-Aware Global Reduction, as shown in Figure 11b.

Based on the techniques mentioned above, we designed the DeFT Attention Kernel with two stages, as shown in Figure 10, to execute the attention operations after the QKV Preparation Phase of DeFT, which we elaborated on in Section 3.3. For more details on the DeFT Attention Kernel, see Figure 11. The attention operations of DeFT-Flatten are omitted because they are very similar to those of DeFT-Node, except for the usage of the bit causal mask for tree attention calculation.

A.5 Discussion of Workloads Generation

Figure 12: The detailed procedure of reconstructing tree templates for multi-step reasoning. (Left) Reconstructing reasoning trees from practical reasoning records, as outlined in [4], involves capturing: (1) the tree structure, characterized by depth $d$ and width $w$; (2) the token length of each thought; and (3) the best thought at each depth along with its score. For document merging, the tree depth is $d=3$ with a width of $w=10$ at each depth. For sorting 128 numbers, the depth is $d=10$ with the same width $w=10$. See Table 10 for the tree topologies of the other multi-step reasoning tasks. (Right) Using the thought information extracted on the left, we generate tree templates for decoding, consisting of branch records and prune records. These records guide the tree decoding process to produce decoding trees that faithfully replicate the structure of the tree-of-thoughts.

The rationality of workload settings.

To validate DeFT’s acceleration across various decoding tree topologies, we compiled decoding trees from real tasks, covering the following three aspects:

  • Few-shot prompting: This involves a two-level tree with a prompt prefix and multiple branches for suffix generation. As a case study, we fixed the prompt length at approximately 4000 tokens and varied the number of branches.

  • Multi-step reasoning [42, 11, 4]: We recorded the tree shapes, prompts, and lengths of all thoughts from real reasoning task interactions [4], using these as guidance for tree decoding to validate DeFT’s acceleration in thought generation of reasoning (the thought evaluation phase follows a similar pattern). See details of generation in Figure 12.

  • Speculative decoding [5, 27]: We used the token tree topology from Medusa [5] and recorded real interaction data with APPS [12] as the prompt dataset, including the length of accepted tokens at each step. This served as guidance to simulate the bottleneck of speculative decoding, namely the attention computation during the token verification phase.

Table 10: Details of generated workloads. For multi-step reasoning, we include 4 tasks from [4]: (1) sorting 128 numbers (sorting for short); (2) document merging (document for short); (3) keyword counting (keyword for short); (4) set intersection (set for short). $d$ and $w$ denote the depth and width of the tree, respectively. $t$ denotes the token tree size for speculative decoding, where the tree topology is from Medusa [5].

| Task | Tree Shape | Decoding Tree Source | Recorded Contents |
|---|---|---|---|
| Multi-step reasoning | sorting: $d=10$, $w=10$; document: $d=3$, $w=10$; keyword: $d=5$, $w=10$; set: $d=8$, $w=10$ | ToT-BFS in [4] | Prompt [4], tree shape, thought size, branch records, prune records |
| Few-shot prompting | $d=1$, $w=10, 20, 30$ | | |
| Speculative decoding | $t=32, 64, 128, 256$ | Medusa [5] | Prompt (APPS [12]), token tree shape, accepted token length per step |

The rationality of our experiment paradigm.

Our experimental paradigm involves two steps: first, obtaining decoding trees from real tree-based decoding tasks; second, replicating these decoding trees exactly within the same framework by constraining LLM inference to follow them, in order to investigate the impact of attention acceleration on wall-clock time. This paradigm has two advantages:

  • We can utilize decoding trees from real tasks as a benchmark within a unified system, enabling fair comparison of different attention algorithms in terms of wall-clock time performance. This comparison is possible despite the algorithms being based on distinct systems, such as variations in memory management implementations for their kernels.

  • We consider both the unique characteristics of tasks with diverse tree structures and the broader applicability of general tree-based decoding. See details of generated workloads for other multi-step reasoning tasks in Table 10.

A.6 Additional Results

End-to-end latency and IOs with breakdowns.

Details of the end-to-end latency and IO comparison between DeFT and the baselines are in Table 11 and Table 12, respectively. We also provide latency breakdowns of the multi-step reasoning tasks (Figure 13), where attention accounts for 27.7-37.6% of the overhead of Radix Attention with paged memory management. Unpaged memory introduces about 40-75.6% overhead in end-to-end latency, due to the materialization of QKVs for tree-based decoding with a sequence-based attention kernel [6, 7].

Table 11: Average end-to-end latency (seconds) per tree. $b$ denotes the tree width. $t$ denotes the token tree size (i.e., the number of tree-structured queries). Speedup Upper-bound (no attention) is the wall-clock speedup the best baseline (Radix Attention) could obtain if the attention calculation were removed. ⋆ means out of memory on an A100 80GB, while ♠ means not supported/implemented.

| Memory | Method | Few-shot b=20 | Few-shot b=30 | Few-shot b=50 | Sorting | Document | Keyword | Set | Spec. t=32 | Spec. t=64 | Spec. t=128 | Spec. t=256 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unpaged | Flash-Decoding | 78.96 | 131.19 | 191.09 | 429.65 | 241.20 | 32.75 | 51.76 | 574.50 | 1128.45 | ⋆ | ⋆ |
| Unpaged | Tree Attention-Medusa | 52.58 | 103.90 | 144.07 | 380.87 | 236.86 | 33.52 | 50.10 | 263.40 | 483.35 | 924.97 | 1881.51 |
| Paged | Radix Attention | 12.37 | 14.08 | 16.54 | 104.79 | 69.61 | 11.25 | 17.03 | 64.57 | 86.12 | 145.88 | 263.76 |
| Paged | DeFT-Node | 17.53 | 21.19 | ♠ | 114.06 | 81.87 | 15.20 | 22.55 | 84.72 | ♠ | ♠ | ♠ |
| Paged | DeFT-Flatten | 9.98 | 10.99 | 12.48 | 94.67 | 66.95 | 10.90 | 16.10 | 44.94 | 50.48 | 65.44 | 104.65 |
| | Speedup of DeFT-Flatten (over Radix Attention) | 1.24× | 1.28× | 1.33× | 1.10× | 1.03× | 1.03× | 1.05× | 1.43× | 1.70× | 2.22× | 2.52× |
| | Speedup Upper-bound (no attention) | 1.71× | 2.08× | 2.51× | 1.96× | 1.82× | 1.70× | 1.76× | 2.01× | 2.72× | 3.99× | 5.12× |
Table 12: Average end-to-end IO (TB). Each entry is Left/Right: (Left) KV cache IO; (Right) partial results IO, including $\mathbf{Q}\mathbf{K}^T$, $\mathbf{Q}\mathbf{K}^\top/s_c$, mask $M$, $\mathbf{M} + \mathbf{Q}\mathbf{K}^\top/s_c$, and Softmax. $b$ denotes the tree width. $t$ denotes the token tree size (i.e., the number of tree-structured queries). ⋆ means out of memory on an A100 80GB, while ♠ means not supported/implemented.

| Method | Few-shot b=20 | Few-shot b=30 | Few-shot b=50 | Sorting | Document | Keyword | Set | Spec. t=32 | Spec. t=64 | Spec. t=128 | Spec. t=256 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Flash-Decoding | 17.62/0.00 | 26.43/0.00 | 44.05/0.00 | 59.96/0.00 | 39.74/0.00 | 4.68/0.00 | 7.01/0.00 | 128.72/0.00 | 255.16/0.00 | ⋆ | ⋆ |
| Tree Attention-Medusa | 1.68/1.05 | 2.10/1.98 | 2.94/4.61 | 12.40/3.69 | 10.57/3.24 | 0.58/0.18 | 1.04/0.27 | 4.02/4.03 | 4.15/8.33 | 4.18/16.77 | 4.32/34.70 |
| Radix Attention | 17.62/0.00 | 26.43/0.00 | 44.05/0.00 | 59.96/0.00 | 39.74/0.00 | 4.68/0.00 | 7.01/0.00 | 131.45/0.00 | 256.79/0.00 | 522.05/0.00 | 1044.10/0.00 |
| DeFT-Node | 1.68/0.00 | 2.10/0.00 | ♠ | 12.40/0.00 | 10.57/0.00 | 0.58/0.00 | 1.04/0.00 | 4.05/0.00 | ♠ | ♠ | ♠ |
| DeFT-Flatten | 1.68/0.00 | 2.10/0.00 | 2.94/0.00 | 12.40/0.01 | 10.57/0.01 | 0.58/0.00 | 1.04/0.00 | 4.10/0.00 | 4.11/0.00 | 4.16/0.00 | 4.35/0.00 |
| IO reduction of DeFT-Flatten (%) | 90.47/100.00 | 92.1/100.00 | 93.33/100.00 | 79.32/99.73 | 73.40/99.70 | 87.61/100.00 | 85.16/100.00 | 96.88/100.00 | 98.40/100.00 | 99.20/100.00 | 99.58/100.00 |
(a) Latency breakdown for task sorting.
(b) Latency breakdown for task document.
(c) Latency breakdown for task set.
(d) Latency breakdown for task keyword.
Figure 13: Latency breakdown for 4 multi-step reasoning tasks [4].

The influence of width in decoding trees.

We observe that the effectiveness of attention speedup varies with the topology of the decoding tree. Consider the simplest tree structure: a prompt with several suffixes. Given a prompt that is not very short, one of the most important factors for speedup is how much of its KV cache IO can be reused. This is measured by the width of the tree, or more precisely, by the number of queries per iteration. Therefore, we fix the prompt length at 4000 tokens and vary the width of the decoding tree in few-shot prompting (which also indicates how many requests share the same prompt). As shown in Figure 14, we then compare the per-iteration latency over time of DeFT-Flatten against the best baseline in attention calculation, Tree Attention-Medusa [5] (Medusa-Attn in the figure), and the best baseline in wall-clock time, Radix Attention [45].

We have the following observations:

  1. When the tree width is 10, the attention overhead of DeFT-Flatten is nearly the same as that of Tree Attention-Medusa, because the IO overhead of the dense causal mask (DCM) is small compared to that of the KV cache; DeFT-Flatten is still 2× faster in attention latency than Radix Attention thanks to KV IO reuse.

  2. As the tree width increases, the attention overhead of Tree Attention-Medusa grows faster, because the size of the DCM depends directly on the tree width: a larger tree width makes the IO of the DCM grow rapidly.

  3. Since the tree topology consists of a fixed prefix with several suffixes, a larger tree width allows the KV cache IO of the prompt prefix to be reused by more queries. This leads to a larger end-to-end speedup over Radix Attention: 1.24× with a width of $w=20$ and 1.33× with a width of $w=50$.

  4. As iterations progress, the length of the suffixes gradually approaches that of the prefix, so the speedup of DeFT-Flatten over Radix Attention decreases.
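Observations 3 and 4 follow from a simple idealized count of KV-cache reads per iteration. The Python sketch below assumes a prompt of `prefix_len` tokens shared by `width` suffixes of `suffix_len` tokens each, one query per suffix, and ignores everything except KV-cache IO in attention, so its ratios are rough upper bounds on attention IO savings rather than predicted end-to-end speedups. `kv_io_reduction` is a hypothetical helper.

```python
def kv_io_reduction(prefix_len, width, suffix_len):
    """Idealized per-iteration KV-cache IO ratio: sequence-based vs. tree-aware.

    Sequence-based kernels re-read the shared prefix once per query, while a
    tree-aware kernel reads it once for all `width` queries.
    """
    sequence_based = width * (prefix_len + suffix_len)
    tree_aware = prefix_len + width * suffix_len
    return sequence_based / tree_aware


for w in (10, 20, 50):
    ratios = [kv_io_reduction(4000, w, s) for s in (0, 500, 2000, 4000)]
    print(f"w={w:2d}: " + ", ".join(f"{r:.1f}x" for r in ratios))
```

The ratio starts at the width itself when the suffixes are empty (more reuse for wider trees) and shrinks toward roughly 2× as the suffixes grow to the prefix length, matching the trend described above.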

(a) Tree width is 10.
(b) Tree width is 20.
(c) Tree width is 30.
(d) Tree width is 50.
Figure 14: Per-iteration latency for few-shot prompting tasks with different tree widths. e2e denotes the end-to-end result, while Attn denotes only the attention overhead.

A.7 DeFT-Node Algorithm

Algorithm 1 DeFT-Node Algorithm-Phase 1: QKV Preparation.
  Input: query $Q \in R^{(b_q, d)}$, Key cache list $KL = (K_0, \ldots, K_{N-1})$, Value cache list $VL = (V_0, \ldots, V_{N-1})$ for each sequence node in the tree, where $N$ is the total number of sequences in a tree, and Tree $T$ with its topology information.
  for each $q$ in $Q$ with its global index $idx$ do
     /* Get the KV indices of all prefixes of a query. */
     $QMapKV[idx]$ = GetPrefixKVIndices($q, KL, VL, T$)
  end for
  for each sequence's KV cache $K_i, V_i$ in $KL, VL$ with its KV index $i$ do
     /* Group each sequence's KV with all queries that share it. */
     $Q_i$ = GroupQueryToKV($Q, K_i, V_i, T$) $\in R^{(b_i, d)} \subset Q$
     $KVMapQ[i] = Q_i$
  end for
  Return QMapKV, KVMapQ
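For concreteness, a minimal Python sketch of Phase 1 on a toy tree follows. It assumes the tree is given as a parent array over sequence nodes and that each query extends one node (`query_nodes`); `prefix_kv_indices` and `qkv_preparation` are hypothetical stand-ins for GetPrefixKVIndices and GroupQueryToKV, and they work with node indices rather than the KV tensors themselves. Unlike Algorithm 1, the second loop builds KVMapQ by inverting QMapKV, which yields the same grouping.

```python
from collections import defaultdict


def prefix_kv_indices(node, parent):
    """Indices of the sequence nodes on the root-to-`node` path (inclusive)."""
    path = []
    while node != -1:
        path.append(node)
        node = parent[node]
    return path[::-1]


def qkv_preparation(query_nodes, parent):
    """Return (QMapKV, KVMapQ) for a tree given as a parent array.

    query_nodes[idx] is the sequence node that query idx extends.
    """
    # QMapKV: query index -> KV (node) indices of all of its prefixes.
    q_map_kv = {idx: prefix_kv_indices(node, parent)
                for idx, node in enumerate(query_nodes)}
    # KVMapQ: KV (node) index -> every query that shares that node's KV.
    kv_map_q = defaultdict(list)
    for idx, kv_ids in q_map_kv.items():
        for i in kv_ids:
            kv_map_q[i].append(idx)
    return q_map_kv, dict(kv_map_q)


# Node 0 is the prompt; nodes 1 and 2 are its children; nodes 3 and 4 extend node 1.
parent = [-1, 0, 0, 1, 1]
q_map_kv, kv_map_q = qkv_preparation(query_nodes=[2, 3, 4], parent=parent)
print(q_map_kv)   # {0: [0, 2], 1: [0, 1, 3], 2: [0, 1, 4]}
print(kv_map_q)   # {0: [0, 1, 2], 2: [0], 1: [1, 2], 3: [1], 4: [2]}
```

The prompt node 0 ends up in a single group with all three queries, so its KV is loaded once rather than three times.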

DeFT-Node has two phases: Phase 1 (QKV Preparation) and Phase 2 (Attention Calculation).

Phase 2 (Attention Calculation) of DeFT has two stages.

  1. Stage 1: Calculate Partial Attentions. We apply Flash Attention to every QKV group produced by Phase 1 (QKV Preparation) to obtain each group's partial attention and LogSumExp.

  2. Stage 2: Global Reduction. We remap the partial attentions and LogSumExps to their queries and compute each query's final attention through a global reduction similar to Flash-Decoding [7].
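To make Stage 1 concrete, here is a naive, non-tiled Python reference for what each QKV group produces. It is not a FlashAttention kernel (FlashAttention computes the same quantities tile by tile in on-chip memory). It only shows the two outputs per query row: the partial attention over the group's KV and the LogSumExp of the scaled scores. The $1/\sqrt{d}$ scaling is an assumption, since the algorithms here write $Q@K^T$ without a scale factor.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def partial_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Tuple[Matrix, List[float]]:
    """Attention of each query in one QKV group over that group's KV only.

    Returns the partial outputs (one row per query) and the row-wise LogSumExp
    of the scaled scores, i.e. the (o_i, lse_i) pair consumed by Stage 2.
    """
    scale = 1.0 / math.sqrt(len(queries[0]))  # assumed 1/sqrt(d) scaling
    outputs, lses = [], []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        row_max = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        lses.append(row_max + math.log(total))  # log(sum_t exp(s_t))
        outputs.append([sum(w * v[c] for w, v in zip(weights, values)) / total
                        for c in range(len(values[0]))])
    return outputs, lses


# One query over a group with two keys whose scores are both 0.
o, lse = partial_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 4.0]])
print(o, lse)  # [[1.0, 2.0]] [0.6931471805599453]
```

The partial output is normalized only within its own group; the LogSumExp is what lets Stage 2 rescale it against the other groups that share the same query.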

Algorithm 2 DeFT-Node Algorithm-Phase 2: Attention Calculation.
  Input: query $Q\in R^{(b_q,d)}$, key cache list $KL=(K_0,\ldots,K_{N-1})$ and value cache list $VL=(V_0,\ldots,V_{N-1})$ for the sequence nodes in the tree, where $N$ is the total number of sequences in the tree, and tree $T$ with its topology information; QKV group information $QMapKV$ and $KVMapQ$ from the QKV Preparation phase.
  for each $q$ in $Q$ with its global index $idx$ do
     /*Allocate storage for the LogSumExp of $Q@K^T$, grouped by query.*/
     $LogSumExp[idx]=\{\}$
     /*Allocate storage for the partial results of $SoftMax(Q@K^T)V$ for each query.*/
     $O[idx]=\{\}$
  end for
  /*Allocate space for the output after reduction.*/
  $FO=(0)_{b_q\times d}\in R^{(b_q,d)}$
  for each sequence’s KV cache $K_i, V_i\in R^{(b_{kv},d)}, R^{(b_{kv},d)}$ in $KL, VL$ with its KV index $i$ do
     # Unroll this loop across SMs (iterations run in parallel)
     $Q_i$ = $KVMapQ[i]\in R^{(b_i,d)}$
     /*Get the partial attention $o_i$ of each QKV group and the row-wise LogSumExp $lse_i$ of $Q@K^T$ for reduction.*/
     $o_i, lse_i$ = FlashAttention($Q_i, K_i, V_i$) $\in R^{(b_i,d)}, R^{b_i}$
     /*Map the partial results back to each query for reduction.*/
     for each query $q$ in $Q_i$ with its group index $gp\_idx$ and global index $idx$ in $Q$ do
        if $i\in QMapKV[idx]$ then
           $LogSumExp[idx].append(lse_i[gp\_idx])$
           $O[idx].append(o_i[gp\_idx])$
        end if
     end for
  end for
  for each $q$ in $Q$ with its global index $idx$ do
     # Unroll this loop across SMs (iterations run in parallel)
     if len($O[idx]$) == len($QMapKV[idx]$) then
        /*Global reduction after collecting all partial results from the QKV groups that contain $q$.*/
        $LSE_{cat}$ = CatTensor($LogSumExp[idx]$)
        $LSE_{max}$ = RowMax($LSE_{cat}$)
        $Mid\_L=0,\ Mid\_O=0^{(1,d)}$
        for each $lse_j$ in $LogSumExp[idx]$ do
           $new\_exp=e^{lse_j-LSE_{max}}$
           $Mid\_L=Mid\_L+new\_exp$
        end for
        for each $lse_j, o_j$ in $LogSumExp[idx], O[idx]$ do
           $new\_exp=e^{lse_j-LSE_{max}}$
           $Mid\_O=Mid\_O+new\_exp\cdot o_j/Mid\_L$
        end for
        $FO[idx]=Mid\_O$
     end if
  end for
  Return $FO$
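The reduction in the last loop of Algorithm 2 can be written for a single query as follows (a Python sketch on plain lists; `global_reduction` is a hypothetical name). Because each partial output $o_j$ is already normalized within its own QKV group, reweighting it by $e^{lse_j-LSE_{max}}/Mid\_L$ recovers the softmax over the union of all groups; subtracting $LSE_{max}$ only keeps the exponentials from overflowing.

```python
import math
from typing import List


def global_reduction(lse_list: List[float], o_list: List[List[float]]) -> List[float]:
    """Merge the partial attentions of one query, following Algorithm 2's last loop."""
    lse_max = max(lse_list)  # LSE_max = RowMax(LSE_cat)
    new_exps = [math.exp(lse - lse_max) for lse in lse_list]
    mid_l = sum(new_exps)  # Mid_L
    mid_o = [0.0] * len(o_list[0])  # Mid_O
    for new_exp, o in zip(new_exps, o_list):
        for c, value in enumerate(o):
            mid_o[c] += new_exp * value / mid_l
    return mid_o


# Group 0 carries three times the softmax mass of group 1 (lse = log 3 vs log 1).
print(global_reduction([math.log(3.0), 0.0], [[1.0, 0.0], [0.0, 1.0]]))  # approximately [0.75, 0.25]
```

Splitting a query's KV into groups, computing each group's $(o_j, lse_j)$, and merging them this way gives the same result as attending over the whole prefix at once, up to floating-point rounding.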

A.8 DeFT-Flatten Algorithm

The algorithm in Appendix A.7 (denoted DeFT-Node) adopts a simple node-granularity split strategy. However, when the token lengths of the nodes in a decoding tree are highly unbalanced, the QKV groups carry very different workloads, so the on-chip SMs of the GPU are unevenly loaded and the calculation becomes inefficient.

Therefore, we can split the decoding tree in a more balanced way: at subtree granularity. We present the DeFT-Flatten algorithm as follows; like DeFT-Node, it consists of two phases.
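One way to picture the subtree-granularity split is to flatten the node KV caches into one token sequence and cut it into blocks of at most $S_t$ tokens, giving $m=\lceil n_T/S_t\rceil$ blocks of nearly equal size. The Python sketch below assumes that node caches are stored contiguously in a fixed node order (e.g., DFS order) and records, for each block, which nodes it covers and at what offset. The function name, this layout, and the `(node, offset, length)` record are illustrative assumptions, not the paper's data format.

```python
import math
from typing import List, Tuple

# (node_id, offset inside the subtree block, number of tokens of that node in the block)
Segment = Tuple[int, int, int]


def slice_tree_kv(node_lengths: List[int], subtree_size: int) -> List[List[Segment]]:
    """Hypothetical sketch of Slice(): cut the flattened tree KV cache into
    m = ceil(n_T / S_t) blocks of at most `subtree_size` tokens each.

    Assumes node KV caches are stored contiguously in a fixed node order,
    so a long node may straddle several consecutive blocks.
    """
    n_tokens = sum(node_lengths)
    m = math.ceil(n_tokens / subtree_size)
    blocks: List[List[Segment]] = [[] for _ in range(m)]
    pos = 0  # position in the flattened token sequence
    for node, length in enumerate(node_lengths):
        remaining = length
        while remaining > 0:
            block_id, offset = divmod(pos, subtree_size)
            take = min(remaining, subtree_size - offset)
            blocks[block_id].append((node, offset, take))
            pos += take
            remaining -= take
    return blocks


print(slice_tree_kv([5, 2, 3], subtree_size=4))
# [[(0, 0, 4)], [(0, 0, 1), (1, 1, 2), (2, 3, 1)], [(2, 0, 2)]]
```

With node lengths 5, 2, and 3, node granularity would yield groups of 5, 2, and 3 tokens, whereas the blocks hold 4, 4, and 2 tokens. Nodes 0 and 2 each straddle a block boundary, which is why a per-block mask is needed.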

Algorithm 3 DeFT-Flatten Algorithm-Phase 1: QKV Preparation.
  Input: query $Q\in R^{(b_q,d)}$, key cache list $KL=(K_0,\ldots,K_{N-1})$ and value cache list $VL=(V_0,\ldots,V_{N-1})$ for the sequence nodes in the tree, where $N$ is the total number of sequences in the tree, and tree $T$ with its topology information; subtree size $S_t$, meaning that each subtree after tiling contains at most $S_t$ tokens.
  /*Evenly slice the tree KV cache (with $n_T$ tokens in the tree) blockwise into subtrees.*/
  SubInfo, KSub, VSub = Slice(KL, VL, $S_t$, $T$)
  /*Notes: (1) the number of subtrees is $m=\lceil n_T/S_t\rceil$;
  (2) the subtrees’ KV caches are $KSub=(Kb_0,\ldots,Kb_{m-1})$ and $VSub=(Vb_0,\ldots,Vb_{m-1})$;
  (3) the subtree information is $SubInfo=(Sb_0,\ldots,Sb_{m-1})$, where $Sb_i=(ofs_0,\ldots,ofs_{n_{b_i}-1})$ records the offset of each node in the KV cache of subtree $i$, and $n_{b_i}$ is the total number of nodes in subtree $i$.*/
  for each subtree’s KV cache $Kb_i, Vb_i$ in $KSub, VSub$ with its subtree ID $i$ do
     /*Group each subtree’s KV with all queries that share it.*/
     $Q_i$ = GroupQueryToKV($Q, Kb_i, Vb_i, T$) $\in R^{(b_i,d)}\subset Q$
     $KVMapQ[i]=Q_i$
     for each query $q$ in $Q_i$ with a global index $idx$ in $Q$ do
        $QMapKV[idx].append(i)$
     end for
     /*Add a causal mask, since different nodes in a subtree can be shared by different queries.*/
     $CausalMask[i]=GetBitMask(Q_i, Kb_i, Vb_i, T)=(CM_0,\ldots,CM_{n_{b_i}-1})$,
     where $n_{b_i}$ is the total number of nodes in the subtree and $CM_j$ is a 64-bit integer bitmask for node $j$.
     /*E.g., $100\ldots00$, with a 1 in bit 0, means that $Q_i[0]$ does not share the KV cache of node $j$ in the subtree.*/
  end for
  Return QMapKV, KVMapQ, CausalMask, SubInfo
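The comment in Algorithm 3 fixes the bitmask convention (bit $k$ set means that query $Q_i[k]$ does not share that node's KV cache), so GetBitMask can be sketched in Python as below. The inputs are assumptions for illustration: the IDs of the nodes stored in the subtree and, for each query in $Q_i$, the set of node IDs on its root-to-leaf path. The 64-query limit follows from the 64-bit mask width.

```python
from typing import List, Set


def get_bit_mask(block_nodes: List[int], query_prefixes: List[Set[int]]) -> List[int]:
    """Hypothetical sketch of GetBitMask for one subtree.

    block_nodes: IDs of the nodes stored in this subtree, in SubInfo order.
    query_prefixes[k]: node IDs on the root-to-leaf path of query Q_i[k].
    Returns one integer per node; bit k is 1 when Q_i[k] does NOT share that
    node's KV cache (the convention of the comment in Algorithm 3).
    """
    if len(query_prefixes) > 64:
        raise ValueError("a 64-bit mask covers at most 64 queries per group")
    masks = []
    for node in block_nodes:
        bits = 0
        for k, prefix in enumerate(query_prefixes):
            if node not in prefix:
                bits |= 1 << k  # mask this node out for query k
        masks.append(bits)
    return masks


# Subtree holding root 0 and its children 1 and 2;
# query 0 extends node 1 and query 1 extends node 2.
print(get_bit_mask([0, 1, 2], [{0, 1}, {0, 2}]))  # [0, 2, 1]
```

Storing one integer per node instead of a dense $(b_i, b_{kv})$ matrix keeps the mask small; the dense form only has to be rebuilt when the group's attention is computed.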
Algorithm 4 DeFT-Flatten Algorithm-Phase 2: Attention Calculation.
  Input: query $Q\in R^{(b_q,d)}$, key cache list in subtree granularity $KSub=(Kb_0,\ldots,Kb_{m-1})$ and value cache list in subtree granularity $VSub=(Vb_0,\ldots,Vb_{m-1})$ for the $m$ subtrees obtained by tiling tree $T$ based on its topology information; QKV group information $QMapKV$ and $KVMapQ$, causal mask $CausalMask$, and subtree information $SubInfo$ from the QKV Preparation phase.
  for each $q$ in $Q$ with its global index $idx$ do
     /*Allocate storage for the LogSumExp of $Q@K^T$, grouped by query.*/
     $LogSumExp[idx]=\{\}$
     /*Allocate storage for the partial results of $SoftMax(Q@K^T)V$ for each query.*/
     $O[idx]=\{\}$
  end for
  /*Allocate space for the output after reduction.*/
  $FO=(0)_{b_q\times d}\in R^{(b_q,d)}$
  for each subtree’s KV cache $Kb_i, Vb_i\in R^{(b_{kv},d)}, R^{(b_{kv},d)}$ in $KSub, VSub$ with subtree ID $i$ do
     # Unroll this loop across SMs (iterations run in parallel)
     $Q_i$ = $KVMapQ[i]\in R^{(b_i,d)}$
     /*Reconstruct the mask for attention calculation from $CausalMask$ and $SubInfo$.*/
     $bitmask=CausalMask[i]\in R^{n_{b_i}}$, where $n_{b_i}$ is the total number of nodes in subtree $i$.
     $SubOfst=SubInfo[i]\in R^{n_{b_i}}$
     $mask=ReconstructMask(bitmask, SubOfst)\in R^{(b_i,b_{kv})}$
     /*Get the partial attention $o_i$ of each QKV group and the row-wise LogSumExp $lse_i$ of $Q@K^T$ for reduction.*/
     $o_i, lse_i$ = FlashAttention($Q_i, Kb_i, Vb_i, mask$) $\in R^{(b_i,d)}, R^{b_i}$
     /*Map the partial results back to each query for reduction.*/
     for each query $q$ in $Q_i$ with its group index $gp\_idx$ and global index $idx$ in $Q$ do
        if $i\in QMapKV[idx]$ then
           $LogSumExp[idx].append(lse_i[gp\_idx])$
           $O[idx].append(o_i[gp\_idx])$
        end if
     end for
  end for
  for each $q$ in $Q$ with its global index $idx$ do
     # Unroll this loop across SMs (iterations run in parallel)
     if len($O[idx]$) == len($QMapKV[idx]$) then
        /*Global reduction after collecting all partial results from the QKV groups that contain $q$.*/
        $LSE_{cat}$ = CatTensor($LogSumExp[idx]$)
        $LSE_{max}$ = RowMax($LSE_{cat}$)
        $Mid\_L=0,\ Mid\_O=0^{(1,d)}$
        for each $lse_j$ in $LogSumExp[idx]$ do
           $new\_exp=e^{lse_j-LSE_{max}}$
           $Mid\_L=Mid\_L+new\_exp$
        end for
        for each $lse_j, o_j$ in $LogSumExp[idx], O[idx]$ do
           $new\_exp=e^{lse_j-LSE_{max}}$
           $Mid\_O=Mid\_O+new\_exp\cdot o_j/Mid\_L$
        end for
        $FO[idx]=Mid\_O$
     end if
  end for
  Return $FO$
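ReconstructMask in Algorithm 4 expands the compact per-node bitmasks into a dense $(b_i, b_{kv})$ mask. Below is a minimal Python sketch. It assumes each node's entry in $SubInfo$ is available as an `(offset, length)` pair within the subtree block (the algorithm only names offsets), and it returns `True` where attention is allowed. Both the function name and this output convention are illustrative.

```python
from typing import List, Tuple


def reconstruct_mask(bitmask: List[int], segments: List[Tuple[int, int]],
                     num_queries: int, block_len: int) -> List[List[bool]]:
    """Hypothetical sketch of ReconstructMask for one subtree.

    bitmask[j]: bit k is 1 when query k must not attend to node j.
    segments[j]: (offset, length) of node j's tokens inside the subtree KV block.
    Returns a num_queries x block_len matrix with True where attention is allowed;
    positions not covered by any node stay False.
    """
    allowed = [[False] * block_len for _ in range(num_queries)]
    for bits, (offset, length) in zip(bitmask, segments):
        for k in range(num_queries):
            if not (bits >> k) & 1:  # query k shares this node's KV cache
                for t in range(offset, offset + length):
                    allowed[k][t] = True
    return allowed


mask = reconstruct_mask([0, 2, 1], [(0, 2), (2, 1), (3, 1)], num_queries=2, block_len=4)
print(mask)  # [[True, True, True, False], [True, True, False, True]]
```

Here both queries see the two root tokens, while each query sees only the token of its own child node. The masked FlashAttention call then ignores the `False` positions.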