
FuseMax: Leveraging Extended Einsums to Optimize Attention Accelerator Design

Nandeeka Nayak, Xinrui Wu∗∗, Toluwanimi O. Odemuyiwa∗∗∗, Michael Pellauer, Joel S. Emer†‡, Christopher W. Fletcher
University of California, Berkeley, ∗∗Tsinghua University, ∗∗∗University of California, Davis,
NVIDIA, Massachusetts Institute of Technology
{nandeeka, cwfletcher}@berkeley.edu, xr-wu20@mails.tsinghua.edu.cn,
todemuyiwa@ucdavis.edu, mpellauer@nvidia.com, jsemer@mit.edu
Abstract

Attention for transformers is a critical workload that has recently received significant ‘attention’ as a target for custom acceleration. Yet, while prior work succeeds in reducing attention’s memory-bandwidth requirements, it creates load imbalance between attention operators (resulting in severe compute under-utilization) and requires on-chip memory that scales with sequence length (which is expected to grow over time).

This paper ameliorates these issues, enabling attention with nearly 100% compute utilization, no off-chip memory traffic bottlenecks, and on-chip buffer size requirements that are independent of sequence length. The main conceptual contribution is to use a recently proposed abstraction—the cascade of Einsums—to describe, formalize and taxonomize the space of attention algorithms that appear in the literature. In particular, we show how Einsum cascades can be used to infer non-trivial lower bounds on the number of passes a kernel must take through its input data, which has implications for either required on-chip buffer capacity or memory traffic. We show how this notion can be used to meaningfully divide the space of attention algorithms into several categories and use these categories to inform our design process.

Based on the above characterization, we propose FuseMax, a novel mapping of attention onto a spatial array-style architecture. On attention, in an iso-area comparison, FuseMax achieves an average 6.7× speedup over the prior state-of-the-art FLAT [27] while using 79% of the energy. Similarly, on the full end-to-end transformer inference, FuseMax achieves an average 5.3× speedup over FLAT using 83% of the energy.

I Introduction

Over the past few years, transformers [48] have emerged as the model architecture of choice for a wide range of machine learning applications, from natural language processing [18, 29, 45, 46] to computer vision [19, 33] to speech recognition [5, 25]. This rise has been accompanied by a corresponding wave of proposals for accelerating transformers in both software [13, 15, 16] and hardware [27, 57].

Fortunately, many of the layers (projections, fully connected layers, etc.) used by transformers look very similar to those of prior generations of machine learning models. Their resource-intensive tensor products can be described and evaluated with existing tensor algebra accelerator modeling tools [28, 35, 40], and many of the other layers (e.g., layer normalization) have negligible impact on performance and can be safely ignored.

However, attention [48]—usually described as a matrix multiplication, a softmax, and then another matrix multiplication—does not fit into either of these boxes. For example, the softmax is both memory intensive (featuring low algorithmic reuse) and compute intensive (featuring exponentiation and division). Furthermore, attention’s characteristics preclude many “free lunches” often used to improve efficiency for other DNN models. For example, because all tensors are a function of the model inputs, there is no opportunity to amortize memory access costs with an increased batch size. Additionally, since none of the operands can be computed before the inputs are given, compression/strength reduction techniques (e.g., quantization [55, 22], sparsity [49, 34, 44], etc.) must be applied dynamically, leading to more complicated algorithms and hardware designs.

To illustrate the difficulty in accelerating attention, consider the state-of-the-art accelerator for attention: FLAT [27]. FLAT uses fusion to reduce attention memory bandwidth bottlenecks on a spatial architecture (e.g., a TPU [26]). Specifically, FLAT maps attention’s matrix multiplications to the 2D spatial array and softmax operations to a separate 1D array. While FLAT’s design does make attention compute bound, it becomes bottlenecked by compute in the 1D array (the softmax), causing severe underutilization of the 2D array. While one could add PEs to the 1D array, this results in commensurate area costs.

Making matters worse, FLAT requires that the entire vector over which the softmax is performed be buffered on chip. The size of this vector is proportional to the sequence length, which is growing rapidly with time (e.g., Google reports 10-million-token sequences in research, which would require hundreds of megabytes to buffer [1]). When the vector/sequence length grows beyond the available buffer capacity, FLAT is forced to spill, which contributes significantly to attention energy consumption and can even make attention memory-bandwidth bound.

This paper. We address the above challenges by proposing a novel spatial architecture – FuseMax – to accelerate attention, with particular emphasis on removing bottlenecks imposed by the softmax. Our architecture addresses all of the aforementioned issues associated with FLAT. Namely:

  • FuseMax is compute bound, but provides almost 100% utilization of both the 2D and 1D arrays throughout the attention operation, without adding additional PEs to the 1D array.

  • FuseMax’s on-chip memory requirements are invariant to sequence length and require no extra spills to memory regardless of sequence length.

The technical core of the paper is three parts.

First, Section III demonstrates a novel analysis on kernels that uses the recently proposed cascade of Einsums abstraction [35]. In a nutshell, an Einsum defines an iteration space over tensors and what computation is done on and between tensors at each point in the iteration space. A cascade of Einsums is a sequence of dependent Einsums that can be used to describe and specify a larger kernel.

While prior work [35, 38] provides a precise definition for Einsums, a major contribution in our work is to show how this definition can be leveraged to inform accelerator design. Specifically, we recognize that the cascade makes explicit precisely what dependencies there are between Einsums. We show how this can be used to make non-trivial deductions about a kernel’s allowed fusion granularity and algorithmic minimum per-tensor live footprint. The relationship between the live footprint and the buffer capacity, in turn, has implications for the required data movement.

In more detail, this analysis provides insight into the number of passes an algorithm performs, i.e., the number of times a given element of an input must be revisited after visiting every other element of the input. Normally, one strives to choose a dataflow that exploits maximal reuse in a given element (or tile of elements) to avoid having to reload it frequently. However, some algorithms preclude this strategy. In this work, we describe how to count the number of passes a cascade requires and present two methods for reducing the number of passes. In general, fewer passes are preferable, although, interestingly, we find that decreasing the number of passes can increase the required compute. Given that an Einsum cascade is mapping/scheduling agnostic, this analysis provides insight given any possible scheduling of the cascade onto hardware.

Next, Section IV applies the cascade of Einsums abstraction to describe and formalize the attention kernel. Using the notion of passes introduced in Section III, we taxonomize the space of numerically stable attention proposals that appear in the literature. For example, in a naïve implementation of attention, one must traverse the entire softmax input to build the softmax denominator and only after that can one revisit and scale each input (softmax numerator) by the denominator. We show how transforming the attention cascade reduces the number of passes required. Because this analysis is performed on the cascade of Einsums, our lower bounds on passes hold for all mapping choices, including application of fusion. For example, despite using fusion, FLAT employs a 3-pass cascade and its reliance on large on-chip buffering is a symptom of trying to avoid three passes-worth of DRAM traffic.

Additionally, we find that expressing attention as a cascade of Einsums reveals that optimizations that were previously conflated can actually be applied separately. We specifically call out one that is used by 1-pass algorithms to eliminate the need for a second pass after the final softmax denominator has been calculated. We recognize that this optimization has the added benefit of decreasing the required divisions, which is useful not only for 1-pass cascades but can be applied to 2- and 3-pass cascades as well.

Finally, in the last part of the technical core (Section V), we use the insights from Section IV as a starting point to develop a novel mapping for attention that can be lowered to a spatial architecture. We call our architecture FuseMax. FuseMax adopts the 1-pass attention cascade used in FlashAttention-2 [15]. However, despite using the cascade from FlashAttention-2, mapping this cascade to a spatial architecture is non-trivial. In particular, FlashAttention-2 maps the cascade onto a GPU, an architecture that features homogeneous PEs, each with relatively large per-PE storage, and expensive inter-PE communication. Spatial architectures feature opposite characteristics: heterogeneous PEs, each with smaller per-PE storage, and cheap (but restricted) inter-PE communication. Specifically, the networks that connect the PEs within the 2D array allow efficient communication primarily between neighbors. We overcome these differences and demonstrate a novel mapping for the 1-pass cascade that achieves high utilization for entire transformer layers. Our architecture requires only minimal changes to a standard spatial architecture and is performance/energy robust to long sequence lengths (e.g., 1M tokens and beyond).

To summarize, we make the following contributions:

  • We show how cascades of Einsums can be used to inform accelerator design, both in terms of reasoning about compute requirements and per-tensor live footprints. We formalize lower bounds on the number of passes a cascade imposes given any possible mapping of the cascade onto hardware.

  • We use cascades of Einsums, and the observation about pass lower bounds, to provide a taxonomy and precise specification of numerically stable attention algorithms in the literature. Orthogonally, we show how previously-entangled attention optimizations can be applied across attention algorithms.

  • We propose a novel mapping (dataflow) for attention for a spatial architecture—which we call FuseMax—that achieves high utilization for both 2D and 1D array PEs, and has memory traffic requirements that are independent of sequence length.

  • We evaluate FuseMax on BERT [18], TrXL [14], T5 [46], and XLM [29] and demonstrate a 6.7× speedup on attention with 79% of the energy and a 5.3× speedup on the full end-to-end inference with 83% of the energy relative to FLAT.

II Background

In this section, we describe the concepts and terminology used in the remainder of the paper.

II-A Tensors

This paper focuses on algebraic computations on tensors, where a tensor is a multidimensional array. A tensor’s rank refers to a specific dimension of the tensor, while the tensor’s shape is the set of valid coordinates for each of the tensor’s ranks. We use the notation $N$-tensor to denote a tensor with $N$ ranks, where a 0-tensor is a scalar, a 1-tensor is a vector, a 2-tensor is a matrix, etc.

We adopt the format-agnostic fibertree abstraction of tensors, where a tensor is represented as a tree of fibers, as detailed in prior work [53, 47, 37, 24, 52, 42, 35, 50], using the specific version described in Nayak et al. [35, Section 2.1]. In this abstraction, a fiber consists of the set of coordinates for a given rank with common coordinates for all higher-level ranks. Each coordinate is coupled with a payload. The payload may contain a reference to a fiber in the next lower rank, or to a leaf data value.
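As a rough illustration of this abstraction, a fibertree can be sketched in Python as nested dictionaries, where each dictionary is a fiber mapping coordinates to payloads. This is a simplification for intuition only, not the representation used in the cited work; the helper name `to_fibertree` and the choice to omit zero-valued leaves and empty fibers are assumptions of this sketch.

```python
def to_fibertree(dense):
    """Build a nested-dict fibertree from a dense nested list.

    Each dict is a fiber: keys are coordinates, values are payloads.
    A payload is either a fiber of the next lower rank (a dict) or a
    leaf data value. Zero leaves and empty fibers are omitted, which
    is an assumption of this sketch.
    """
    fiber = {}
    for coord, item in enumerate(dense):
        if isinstance(item, list):
            payload = to_fibertree(item)
            if payload:
                fiber[coord] = payload
        elif item != 0:
            fiber[coord] = item
    return fiber


# A 2-tensor with ranks (K, M): the top-level fiber holds K coordinates,
# and each payload is an M fiber sharing that K coordinate.
A = [[0, 2, 0],
     [0, 0, 0],
     [5, 0, 7]]
tree = to_fibertree(A)
print(tree)
```

Here the M fiber `{0: 5, 2: 7}` contains exactly the coordinates of rank M that share the higher-level coordinate `k = 2`.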

II-B Traditional Einsums

An Einsum expression defines a computation on a set of tensor operands using an iteration space that specifies the set of points where the computations are performed [35, 38]. For example, we describe matrix-matrix multiplication (GEMM) computation with the following Einsum:

$$Z_{m,n} = A_{k,m} \times B_{k,n} \tag{1}$$

where $A$ and $B$ are input 2-tensors of shape $K \times M$ and $K \times N$, respectively. $Z$ is an output 2-tensor with shape $M \times N$. Throughout this paper, the shape of a rank is also the name of that rank (e.g., rank $K$ in $A$ has a shape of $K$).

The iteration space of this Einsum is $[0,K) \times [0,M) \times [0,N)$. Execution of this Einsum must: (1) walk every $(k,m,n)$ point in the iteration space; and, at each point (2) project into the data space of all input tensors, (3) multiply the corresponding data values, and (4) place the result at the corresponding data point in $Z$. If a value already exists at an $(m,n)$ point in $Z$ (due to computation at a previous $(k,m,n)$ point), reduce the two values together using addition. Note that the Einsum specifies what to compute; it does not indicate the order in which one walks the iteration space. These aspects are left to mapping [10, 40, 35].
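The four execution steps can be sketched as a plain Python loop nest over dense nested lists. The function name `gemm_einsum` and the particular loop order are assumptions of this sketch; the Einsum itself leaves the traversal order to the mapping, so any other order over the same iteration space would produce the same result.

```python
def gemm_einsum(A, B):
    """Evaluate Z[m][n] = A[k][m] * B[k][n], reducing over k with addition."""
    K, M, N = len(A), len(A[0]), len(B[0])
    Z = [[0] * N for _ in range(M)]
    # (1) Walk every (k, m, n) point. This loop order is one arbitrary
    # choice; the Einsum does not prescribe it.
    for k in range(K):
        for m in range(M):
            for n in range(N):
                # (2) project into A and B, (3) multiply, and
                # (4) reduce into the (m, n) point of Z using addition.
                Z[m][n] += A[k][m] * B[k][n]
    return Z


A = [[1, 2], [3, 4], [5, 6]]                     # shape K x M = 3 x 2
B = [[1, 0, 2, 1], [0, 1, 1, 0], [2, 1, 0, 1]]   # shape K x N = 3 x 4
Z = gemm_einsum(A, B)
print(Z)
```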

II-C Extended Einsums

Traditional Einsums suffice to express standard algebraic computations, including those supported in Basic Linear Algebra Subprograms (BLAS) [30, 20] and tensor network contractions [2]. However, they cannot handle more complex computations. The recently proposed Extended General Einsums notation (EDGE) [38] extends Einsums to handle graph algorithm computations. We find this abstraction also useful for expressing complex tensor algebra computations and use its notation throughout the paper. We now briefly summarize the portions of EDGE that we leverage.

II-C1   User-Defined Computations

EDGE separates computations into three “actions”: map ($\bigwedge$), reduce ($\bigvee$), and populate ($=$) [38, Section 5]. Map specifies the pair-wise computation between the shared ranks of two tensors, reduce specifies the computation for the reduction step of an Einsum, and default populate ($=$) places a computed value from the right-hand side (RHS) of the Einsum to its location on the left-hand side (LHS).

Each map and reduce action contains two operations: merge and compute. Compute defines the operation to apply between two data values, and can be any user-defined function. Merge specifies which regions of the iteration space to touch; execution will not need to access the data space corresponding to culled points. Together, merge and compute precisely define the computations in an Einsum. Common merge operations include intersection ($\cap$), which touches points with non-zero values in both operands; and union ($\cup$), which touches points where at least one of the operands is non-zero.

The full EDGE specification for GEMM is then:

$$Z_{m,n} = A_{k,m} \cdot B_{k,n} :: \bigwedge_{k} \times(\cap) \; \bigvee_{k} +(\cup), \tag{2}$$

where $\bigwedge_{k}$ specifies a map action between $A$ and $B$ on the $k$ rank and the intersection merge operator ($\cap$) culls $k$ points where at least one operand is zero. The compute operator ($\times$) multiplies the data values of coordinates surviving intersection. The reduce action ($\bigvee_{k}$) on the $k$ rank gathers all non-empty points in the $k$ rank and reduces them using addition ($+$).

In this work, we use three user-defined computations:

  1. Maximum ($\max(\cup)$) takes the maximum of two values. Suppose we have the following expression: $Z_m = A_m \cdot B_m :: \bigwedge_{m} \max(\cup)$. The union merge operator ($\cup$) filters out any $m$ coordinates where both operands contain $0$ (and places $0$ in the output). The $\max$ compute operator then returns the maximum of the two operands.

  2. Divide ($\div(\leftarrow)$) divides two data values. Given the following expression, $Z_m = A_m \cdot B_m :: \bigwedge_{m} \div(\leftarrow)$, the merge operator ($\leftarrow$) only touches $m$ points where there is a non-zero value in the $B$ operand (see [38, Appendix]), and the compute operator divides the data value in $A$ by the data value in $B$.

  3. Exponentiation: we follow the example in EDGE [38, Section 7.4]. The expression $Z_m = e^{A_m}$, where $e$ is Euler’s number, applies the exponential function to every element in $A$. The exponent can also be an Einsum expression: $Z_m = e^{A_m \cdot B_m} :: \bigwedge_{m} \times(\cap)$.
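To make the split between merge and compute concrete, the following sketch applies a map action to two sparse 1-tensors stored as `{coordinate: value}` dictionaries, with absent coordinates treated as zero. It is an illustration of the semantics described above rather than EDGE's implementation; the names `map_action`, `intersect`, `union`, and `take_right` (standing in for the left-arrow merge) are hypothetical.

```python
def map_action(a, b, merge, compute):
    """Apply a map action to two sparse 1-tensors stored as {coord: value}.

    `merge` selects which coordinates to touch; `compute` is applied at
    each touched coordinate. Absent coordinates read as zero, and zero
    results are not stored.
    """
    out = {}
    for m in sorted(merge(a, b)):
        value = compute(a.get(m, 0), b.get(m, 0))
        if value != 0:
            out[m] = value
    return out


def intersect(a, b):
    # Touch points where both operands are non-zero.
    return a.keys() & b.keys()


def union(a, b):
    # Touch points where at least one operand is non-zero.
    return a.keys() | b.keys()


def take_right(a, b):
    # Stand-in for the left-arrow merge: touch points where B is non-zero.
    return b.keys()


A = {0: 3.0, 2: -1.0, 3: 8.0}
B = {1: 4.0, 2: 2.0, 3: 2.0}

product = map_action(A, B, intersect, lambda x, y: x * y)
maximum = map_action(A, B, union, max)
quotient = map_action(A, B, take_right, lambda x, y: x / y)
print(product, maximum, quotient)
```

Because the divide merge only touches coordinates present in `B`, the sketch never divides by zero: coordinate 0 (present only in `A`) is culled before the compute step.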

In addition to map and reduce, EDGE enables the expression of user-defined unary operations on tensors. For example, we can express the application of the non-linear sigmoid function ($\sigma$) on each element of a tensor $A$ as $Z_m = \sigma(A_m)$.

II-C2   Shorthand Notation

Throughout this paper, we take advantage of EDGE’s shorthand notation [38, Section 6] in the following ways:

  • We drop all reduce actions that consist of add and union in the compute and merge operator, respectively ($\bigvee +(\cup)$). Thus, $Z_m = A_{k,m} :: \bigvee_{k} +(\cup)$ becomes $Z_m = A_{k,m}$.

  • We express all map actions using infix notation; that is, $A_{k,m} \cdot B_{k,n} :: \bigwedge_{k} \times(\cap)$ becomes $A_{k,m} \times B_{k,n}$.

  • When $\max$ is part of a map action ($A_m \cdot B_m :: \bigwedge_{m} \max(\cup)$), we replace it with the following shorthand: $\max(A_m, B_m)$.

  • When $\div$ is part of a map action ($A_m \cdot B_m :: \bigwedge_{m} \div(\leftarrow)$), we replace it with the following: $A_m / B_m$.

II-C3   Filtering Rank Expressions

EDGE also enables expressing Einsums that touch only a subset of the data space of their constituent tensors. For example, we may express prefix-sum of a tensor $A_k$ with the following Einsum:

$$S_{i+1} = A_{k : k \leq i}$$

For each coordinate $i$, $S_{i+1}$ is built by reducing together the subset of $A$ whose coordinates are $\leq i$. Note that this definition of prefix-sum computes the entire sum for a given $i$ without iteratively reusing the previous sum.
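A direct, deliberately non-iterative reading of this filtered Einsum can be sketched as follows. The function name `prefix_sum_filtered` is hypothetical, and leaving `S[0]` at zero is an assumption of the sketch, since the Einsum only defines the entries from `S[1]` onward.

```python
def prefix_sum_filtered(A):
    """S[i+1] = sum of A[k] over all k <= i, recomputed from scratch per i."""
    K = len(A)
    S = [0] * (K + 1)  # S[0] is not produced by the Einsum; assumed zero
    for i in range(K):
        # The filter k <= i restricts which points of A are reduced.
        S[i + 1] = sum(A[k] for k in range(K) if k <= i)
    return S


S = prefix_sum_filtered([3, 1, 4, 1, 5])
print(S)
```

Each output entry re-reads a growing prefix of `A`, which is the redundant work that an iterative formulation avoids.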

II-C4   Expressing Iterative Computations

EDGE expresses recursion and iteration through generative/iterative ranks. We use the term standard ranks to differentiate non-iterative ranks from iterative ranks. We can express the iterative prefix-sum as follows:

$$S_{i+1} = S_i + A_i \tag{3}$$
$$\diamond : i \equiv K \tag{4}$$

Here, $S$ is the iterative tensor that changes on each iteration, with the iterative rank, $i$, ranging from $0$ to $K$. Equation 4 indicates the stopping condition for the iterative expression (when $i$ is equal to $K$).
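The iterative expression and its stopping condition can be sketched as a simple loop. The function name `prefix_sum_iterative` and the initialization of `S[0]` to zero are assumptions of this sketch; the text above does not state the initial value.

```python
def prefix_sum_iterative(A):
    """Evaluate S[i+1] = S[i] + A[i], stopping when i equals K."""
    K = len(A)
    S = [0] * (K + 1)  # S[0] = 0 is an assumed initialization
    i = 0
    while i != K:      # stopping condition: i == K
        S[i + 1] = S[i] + A[i]
        i += 1
    return S


S_iter = prefix_sum_iterative([3, 1, 4, 1, 5])
print(S_iter)
```

Each step reuses the previous partial sum, so every element of `A` is read exactly once.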

II-C5   Cascades of Einsums

TeAAL [35] introduces the concept of cascades of Einsums, which expresses directed acyclic graphs (DAG) of Einsum expressions as a sequence of sub-Einsums. One can view the unrolled iterative expression in Equation 3 as a cascade:

$$\begin{aligned}
S_1 &= S_0 + A_0 \\
S_2 &= S_1 + A_1 \\
&\;\;\vdots \\
S_K &= S_{K-1} + A_{K-1}
\end{aligned}$$

Finally, we use the EDGE Initialization label to specify computations that initialize tensors, which occur once. We use the EDGE Extended Einsum(s) label to specify the computation that occurs on each iteration of a cascade of Einsums [38] (see Einsum Cascade 5).

II-D Mapping

An Einsum specifies the computation, while a mapping indicates how computation occurs in space and time on an accelerator [10, 40]. Mapping specifications include aspects such as loop order, partitioning, and work scheduling (sequential vs. parallel operations) [35]. Throughout this paper, some mapping choices like partitioning are expressed directly in the cascade of Einsums (e.g., ranks $M1, M0$ result from partitioning the $M$ rank in Einsum Cascade 5).

To understand how mapping interacts with iterative ranks and Einsum cascades, we introduce the concept of an iteration space fibertree, or is-fibertree. The is-fibertree is a special tree where each fiber belongs to a rank in the iteration space of the Einsum.

II-E Tensor Algebra Accelerators

In recent years, the popularity of domain-specific tensor algebra accelerators has increased. A typical accelerator based on a spatial architecture consists of off-chip main memory, an on-chip shared global buffer, various scratchpads, and a 1D and/or 2D processing engine (PE) array where each PE contains compute units [57, 27, 26, 37, 10]. This design minimizes memory transfer latency while maximizing compute utilization [12, 8, 10, 26, 9]. Various tools enable the quick modeling and design space exploration of tensor algebra accelerators, including Timeloop [40] and Accelergy [51], GAMMA [56], and DOSA [23].

III Passes Performed by a Cascade of Einsums

Our first contribution is to demonstrate a novel analysis that can be applied using a cascade of Einsums. The key insight is that cascades of Einsums provide a precise description of the iteration space for each Einsum and the data space for each constituent tensor, enabling us to derive the algorithmic minimum live footprint for each tensor, with implications for the allowed fusion schedules and required buffer capacity/memory traffic. Because this analysis relies only on the cascade of Einsums, it holds for any choice of mapping.

III-A Calculating the Number of Passes

We will apply our analysis to attention in Section IV. To illustrate ideas, we first start with a simple pedagogical example, shown in Cascade 1.

Einsum Cascade 1: An example 2-pass cascade.

$$Y = A_k \times B_k \tag{5}$$
$$Z = Y \times A_k \tag{6}$$

Equation 5 performs a dot product between $A_k$ and $B_k$, and Equation 6 multiplies the first equation’s result $Y$ by $A_k$ again to produce $Z$. If we want to minimize data traffic of $A_k$, we need to choose a dataflow for each Einsum that keeps $A_k$ stationary and fuses the two Einsums together. In other words, the dataflow must finish using the first element of $A_k$ before moving onto the next. However, such a dataflow does not exist for this cascade. Any implementation must visit every element of $A_k$ to compute $Y$ before it can revisit any element of $A_k$ to compute $Z$.

We define a pass that a cascade performs over a particular fiber of a particular rank and tensor to be a traversal of every element of that fiber. Each time an element must be revisited after visiting every other element of that fiber, there is an additional pass. For example, Cascade 1 performs two passes over the $K$ rank of $A_k$.

Since an Einsum’s iteration space can also be represented as a fibertree (i.e., an is-fibertree; see Section II-D), we extend our definition of an iteration space for a cascade of Einsums by considering its iteration space to be the sequence of the is-fibertrees for each Einsum. Now, in a scenario where fibers for a particular rank exist in multiple is-fibertrees; in each, they project to the same tensor; and there is a dependency such that all of the elements of the earlier is-fibertree’s fiber must be read before any element can be read again by the later is-fibertree (for all mappings of the cascade), we refer to that read-read sequence as creating an additional pass. When there is a sequence of $N$ such read-read dependencies, we say the cascade is an $(N+1)$-pass cascade. For our example, Cascade 1 requires two passes of the $K$ rank.
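The two passes of Cascade 1 can be observed with a small sketch that logs every read of A. It shows one possible schedule only and does not by itself prove the lower bound, which follows from the dependency argument above. The class `TracedVector` and the function `cascade_1` are hypothetical names introduced for illustration.

```python
class TracedVector:
    """A 1-tensor that logs the coordinate of every element read."""

    def __init__(self, values):
        self.values = list(values)
        self.reads = []

    def __len__(self):
        return len(self.values)

    def __getitem__(self, k):
        self.reads.append(k)
        return self.values[k]


def cascade_1(A, B):
    K = len(A)
    # Y = A_k x B_k: reducing over k needs every element of A.
    Y = 0
    for k in range(K):
        Y += A[k] * B[k]
    # Z = Y x A_k: needs the final Y, so A can only be revisited now.
    Z = 0
    for k in range(K):
        Z += Y * A[k]
    return Y, Z


A = TracedVector([1, 2, 3, 4])
B = [2, 0, 1, 3]
Y, Z = cascade_1(A, B)
print(Y, Z)
print(A.reads)  # two full traversals of the K fiber of A
```

In the read log, no coordinate appears a second time until all coordinates have appeared once, which is the read-read pattern that defines an additional pass.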

III-B Implications of the Number of Passes

The number of passes a cascade performs is relevant because it restricts possible fusion schedules. Einsums within a pass can be fused at will, producing and consuming a tile of the intermediate at a time. Einsums in different passes cannot be fused. Revisiting Cascade 1, Equations 5 and 6 cannot be fused on the $K$ rank. Any implementation must visit all elements of the $K$ fiber of $A$ to produce $Y$ before it can visit any of the elements of that fiber to produce $Z$.

This analysis also provides a non-trivial lower bound on the tensors’ live footprints. For example, the algorithmic minimum live footprint for tensor $A$ is $K$. In other words, an architecture must either have enough buffer space to hold an entire $K$ fiber of $A$ or spill and reload that fiber, incurring memory traffic proportional to the shape of $K$. We note that this analysis is mapping independent. There is no dataflow for this cascade that enables a smaller live footprint.

III-C Reducing the Number of Passes via Reassociation

Given the restrictions that multi-pass cascades place on the allowed dataflows and tensor live footprints, it can be beneficial to manipulate the cascade to reduce the number of passes required. Crucially, these manipulations are functionally equivalent and only change how $Z$ is computed. In this section, we will present two methods for doing so, though we leave a full analysis of the space of pass-reduction approaches to future work.

III-C1   Deferring the Multiplication by $Y$

First, we recognize that, by the distributive property, Equation 6 can be factored to perform the reduction of $A_k$ first, before multiplying the result by $Y$. Doing so, we get the following cascade:

Einsum Cascade 2: A reassociation of Cascade 1 that defers the $Y \times$ to compute $Z$ with 1-pass of the $K$ rank.

$$Y = A_k \times B_k \tag{7}$$
$$X = A_k \tag{8}$$
$$Z = Y \times X \tag{9}$$

Now, because there is no read-after-write dependency between Equations 7 and 8, both Einsums can be included in the same pass. In fact, because Equation 8 reduces away the $K$ rank, Cascade 2 is a 1-pass cascade with respect to this rank. This reassociation actually provides a second benefit over Cascade 1: Equation 9 now only requires one multiplication (as opposed to $K$ multiplications in Equation 6).
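As a rough illustration, the sketch below contrasts the two cascades in plain Python. It assumes, from the description above, that Cascade 1 computes $Y$ by reducing $A_k \times B_k$ and then $Z$ by reducing $Y \times A_k$. The function names are hypothetical and do not come from the paper.

```python
def cascade1_two_pass(a, b):
    """Assumed form of Cascade 1: Y = A_k * B_k, then Z = Y * A_k (both reduce over k)."""
    y = 0.0
    for ak, bk in zip(a, b):   # pass 1 over the K fiber of A
        y += ak * bk
    z = 0.0
    for ak in a:               # pass 2 re-reads the same K fiber (K multiplications)
        z += y * ak
    return z


def cascade2_one_pass(a, b):
    """Cascade 2: Y and X are built in the same pass; Z needs one multiplication."""
    y = 0.0
    x = 0.0
    for ak, bk in zip(a, b):   # single pass over the K fiber of A
        y += ak * bk           # Equation 7
        x += ak                # Equation 8
    return y * x               # Equation 9
```

Both functions return the same value for the same inputs (up to floating-point rounding), but the second reads each element of `a` only once, so only the running scalars `y` and `x` need to stay live.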

III-C2   Iteratively Constructing $Y$ and $Z$

Einsum Cascade 3: A reassociation of Cascade 1 that iteratively constructs $Y$ and $Z$ with 1-pass of the $K$ rank.

Initialization:

\begin{align}
RY_{i:i=0} &= 0 \tag{10} \\
RZ_{i:i=0} &= 0 \tag{11}
\end{align}

Extended Einsums:

\begin{align}
RY_{i+1} &= RY_i + A_i \times B_i \tag{12} \\
RZ_{i+1} &= RZ_i \times \frac{RY_{i+1}}{RY_i} + RY_{i+1} \times A_i \tag{13} \\
Z &= RZ_K \tag{14} \\
\diamond &: i \equiv K \tag{15}
\end{align}

Alternatively, we can iteratively construct $Y$ and $Z$ as we perform the pass through $A_k$. To do so, we will take a similar approach to the prefix-sum (see Sections II-C3-II-C4) and build intermediate $Y$s and $Z$s.

\begin{align}
RY_{i+1} &= A_{k:k \leq i} \times B_{k:k \leq i} \tag{16} \\
RZ_{i+1} &= RY_{i+1} \times A_{k:k \leq i} \tag{17}
\end{align}

Just like with the prefix-sum, this version requires a lot of extra compute, but, because $Y = RY_K$ and therefore $Z = RZ_K$, the final result is the same.

We remove this extra work by making the $I$ ranks of $RY_{i+1}$ and $RZ_{i+1}$ iterative. This is shown in Cascade 3. Iterative $RY_{i+1}$ (Equation 12) looks very similar to the iterative prefix-sum. However, computing $RZ_{i+1}$ is a little more complicated. We start by introducing one more intermediate $S_i$, which is the prefix-sum for $A_k$:

\begin{equation}
S_i = A_{k:k \leq i-1} \tag{18}
\end{equation}

Now, we can combine Equations 17 and 18 to write $RZ_i$ in terms of this prefix-sum:

\begin{equation}
RZ_i = RY_i \times S_i \tag{19}
\end{equation}

Dividing both sides by $RY_i$, we derive an alternate definition for $S_i$:

\begin{equation*}
S_i = \frac{RZ_i}{RY_i}
\end{equation*}

Because $S_i$ is a prefix-sum, $S_{i+1} = S_i + A_i$, so $S_{i+1}$ can also be written using this alternative definition:

\begin{equation}
S_{i+1} = \frac{RZ_i}{RY_i} + A_i \tag{20}
\end{equation}

We can combine Equations 19 and 20 to compute $RZ_{i+1}$ in terms of $RZ_i$ (i.e., iteratively):

\begin{equation*}
RZ_{i+1} = RY_{i+1} \times \left(\frac{RZ_i}{RY_i} + A_i\right)
\end{equation*}

Distributing $RY_{i+1}$ and performing some reassociation, we get Equation 13.

Cascade 3 is also a 1-pass cascade, performing one pass of the $K$ rank of $A_k$ (indexed with the variable $i$) and iteratively building $RY_{i+1}$ and $RZ_{i+1}$. Unfortunately, unlike Cascade 2, Cascade 3 does require extra compute over the original Cascade 1. However, memory bandwidth-limited workloads can afford to trade off extra compute for reduced memory traffic, and Cascade 3 may still provide benefit.
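The recurrence in Equations 12 and 13 can be checked numerically with a short sketch. The Python below is illustrative only. The function names are hypothetical, and Cascade 1 is assumed to be $Y = A_k \times B_k$ followed by $Z = Y \times A_k$ (inferred from the surrounding text). The guard for $RY_i = 0$ is an added assumption that sidesteps dividing by the initial value from Equation 10. The sketch further assumes the running sum $RY_i$ stays nonzero after the first step.

```python
def cascade1_reference(a, b):
    """Assumed Cascade 1: two passes over the K fiber of A."""
    y = sum(ak * bk for ak, bk in zip(a, b))
    return sum(y * ak for ak in a)


def cascade3_iterative(a, b):
    """Cascade 3: build RY and RZ together in a single pass over A and B."""
    ry = 0.0   # Equation 10
    rz = 0.0   # Equation 11
    for ai, bi in zip(a, b):
        ry_next = ry + ai * bi                     # Equation 12
        # Assumption: when RY_i is 0 (the initial step), RZ_i is also 0,
        # so the rescaled term is taken to be 0 instead of dividing by 0.
        carried = rz * ry_next / ry if ry != 0.0 else 0.0
        rz = carried + ry_next * ai                # Equation 13
        ry = ry_next
    return rz                                      # Equation 14
```

Each iteration performs an extra division and multiplication compared with Cascade 1, which is the additional compute the text refers to.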

IV Taxonomizing Attention as Einsum Cascades

Our second contribution is to apply the cascade of Einsums abstraction and the notion of passes to transformer models to describe, taxonomize, and highlight trade-offs in the space of attention implementations. This section first looks at the transformer model as a whole, identifying attention as an important kernel (Section IV-A). We then give an overview of attention and a “straightforward” (but inefficient) algorithm for softmax by writing them as cascades of Einsums (Sections IV-B-IV-C). Finally, we describe how optimizations to softmax can be described by modifying the cascades and provide a taxonomy of the space using the number of passes required by each cascade (Sections IV-D-IV-E).

IV-A Transformers

Figure 1: Overview of transformer encoder inference. (a) Encoder architecture. (b) Required compute.

Transformer models generally follow the architecture defined in [48]. In this work, which addresses the impact of long sequence lengths during self-attention, we focus on the encoder architecture. Figure 1(a) gives an overview. The transformer first projects the input (by multiplying it by weight tensors) to form a query, key, and value. Self-attention is made up of three operations: a matrix multiplication of the query and key, a softmax on the result, and another matrix multiplication, which combines the softmax output with the value. The attention output is then deprojected (again, multiplying by a weight tensor), normalized, passed through a two-layer feed-forward neural network (FFN), and normalized once more.

As the sequence length grows, the relative importance of the different operations changes. Figure 1(b) shows that at shorter sequence lengths, the weight-times-activation “linear” layers are a larger fraction of the total required compute, while at long sequence lengths, the attention dominates. In all cases, the additional non-linearities (e.g., the normalization, the ReLU between the FFN layers, etc.) have negligible impact. In the next section, we focus on describing attention more precisely, and use our analysis to understand prior work on efficient implementations.

IV-B Redefining Attention’s “Matrix Multiplications”

In the original transformer paper [48], the kernel was described with the following equation:

\begin{equation}
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V \tag{21}
\end{equation}

However, this equation says almost nothing about what the inputs $Q$, $K$, and $V$ look like or what iteration space needs to be traversed. We clarify these points by rewriting Equation 21 as a cascade of Einsums, with the exception of the softmax, whose cascade we will explore in Section IV-C:

\begin{align}
QK_{m,p} &= \frac{1}{\sqrt{E}} \times Q_{e,p} \times K_{e,m} \tag{22} \\
A_{m,p} &= \text{softmax}(QK_{m,p}) \tag{23} \\
AV_{f,p} &= A_{m,p} \times V_{f,m} \tag{24}
\end{align}

Here, Equations 22¹ and 24 look like matrix multiplications. Taking Equation 24 as an example, for each point in the iteration space $F \times M \times P$, we perform a multiplication using elements from two 2-tensors ($A_{m,p}$ and $V_{f,m}$) to produce a 2-tensor output ($AV_{f,p}$), which requires reducing across the inputs’ shared rank $M$.

¹ In Equation 22, we also substitute $E$ for $d_k$ following the notation defined in Section II-B, where the shape of a rank is also its rank name.

Equations 22-24 can be modified to refer to the full batched, multi-head self attention [48] by adding $B$ and $H$ ranks to all tensors. This changes the characteristics of the kernel. Adding the $B$ and $H$ ranks means that Equations 22 and 24 behave like many independent matrix multiplications instead of one monolithic matrix multiplication. The challenges with attention, described in Section I, follow clearly from this modification. Because all tensors contain a $B$ rank, the matrix multiplications are all unique to the specific batch’s inputs. Therefore, none of these tensors can be computed before the inputs are given, and there is no data sharing between the different elements in the batch. To simplify notation, we assume the presence of the $B$ and $H$ ranks but omit writing them throughout the rest of the paper.
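To make the iteration spaces concrete, the following sketch writes Equations 22 and 24 as explicit loop nests in plain Python, omitting the $B$ and $H$ ranks as the text does. It is an illustration under assumptions: tensors are stored as nested lists indexed in the same order as their subscripts (`q[e][p]`, `k[e][m]`, `a[m][p]`, `v[f][m]`), and the function names are hypothetical.

```python
import math


def qk_einsum(q, k):
    """Equation 22: QK[m][p] = (1 / sqrt(E)) * sum over e of Q[e][p] * K[e][m]."""
    E, P, M = len(q), len(q[0]), len(k[0])
    scale = 1.0 / math.sqrt(E)
    qk = [[0.0] * P for _ in range(M)]
    for m in range(M):
        for p in range(P):
            for e in range(E):          # E is the reduced rank
                qk[m][p] += scale * q[e][p] * k[e][m]
    return qk


def av_einsum(a, v):
    """Equation 24: AV[f][p] = sum over m of A[m][p] * V[f][m]."""
    M, P, F = len(a), len(a[0]), len(v)
    av = [[0.0] * P for _ in range(F)]
    for f in range(F):
        for m in range(M):              # M is the shared rank that is reduced away
            for p in range(P):
                av[f][p] += a[m][p] * v[f][m]
    return av
```

The loop bounds spell out the iteration space ($E \times M \times P$ for Equation 22 and $F \times M \times P$ for Equation 24), and the rank that is absent from the output is the one being reduced.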

IV-C Softmax as a Cascade of Einsums

We now apply the same precise notation to the softmax. A softmax [6] over a 1-tensor is traditionally expressed with the following equation:

\begin{equation}
A_m = \frac{e^{I_m}}{\sum_k e^{I_k}} \tag{25}
\end{equation}

In the context of attention, this operation becomes two dimensional and can be expressed using the following cascade with input $QK_{m,p}$:

\begin{align}
SN_{m,p} &= e^{QK_{m,p}} \tag{26} \\
SD_p &= SN_{m,p} \tag{27} \\
A_{m,p} &= SN_{m,p} / SD_p \tag{28}
\end{align}

For each point in the iteration space ($m$, $p$), we exponentiate $QK_{m,p}$ to generate the softmax numerator ($SN_{m,p}$ in Equation 26), reduce $SN_{m,p}$ with addition to produce the softmax denominator ($SD_p$ in Equation 27), and finally, divide the numerator and denominator to produce the final result ($A_{m,p}$ in Equation 28).

IV-C1   Improving Numerical Stability

Because $e^{QK_{m,p}}$ can easily become extremely large, the above formulation suffers from overflow. Therefore, practical implementations [3, 41] often prefer the numerically stable variant that replaces Equation 26 with:

\begin{align}
GM_p &= QK_{m,p} :: \bigvee_m \text{max}(\cup) \tag{29} \\
SN_{m,p} &= e^{QK_{m,p} - GM_p} \tag{30}
\end{align}

and drop the $\frac{1}{\sqrt{E}}$ term when computing $QK_{m,p}$.² To compute the global maximum³ $GM_p$, we reduce $QK_{m,p}$ with the operator max (instead of $+$). Notice that subtracting $GM_p$ from $QK_{m,p}$ in the exponent is equivalent to dividing by $e^{GM_p}$, and because the $\frac{1}{e^{GM_p}}$ term appears in both the numerator ($SN_{m,p}$ via Equation 30) and denominator ($SD_p$ via Equation 27), the result ($A_{m,p}$) stays the same. This construction improves numerical stability by bounding the values of the softmax numerator $SN_{m,p}$ to the range $(0, 1]$.

² The $\frac{1}{\sqrt{E}}$ term was introduced to bound the magnitude of $SN_{m,p}$ [48]. Because the numerically stable softmax variant already accomplishes this, the scaling is often omitted [16, 15, 13].
³ “Global” here refers to over the $M$ fiber.
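The effect of subtracting the maximum can be seen in a small sketch over a single $M$ fiber (one fixed $p$). This is a minimal illustration in plain Python with hypothetical function names, not the implementation used by the cited frameworks.

```python
import math


def softmax_naive(qk_fiber):
    """Equations 26-28 applied to one M fiber."""
    sn = [math.exp(x) for x in qk_fiber]        # Equation 26
    sd = sum(sn)                                # Equation 27
    return [n / sd for n in sn]                 # Equation 28


def softmax_stable(qk_fiber):
    """Equations 29-30 followed by Equations 27-28 on one M fiber."""
    gm = max(qk_fiber)                          # Equation 29: global max of the fiber
    sn = [math.exp(x - gm) for x in qk_fiber]   # Equation 30: exponents are <= 0
    sd = sum(sn)                                # Equation 27
    return [n / sd for n in sn]                 # Equation 28
```

With CPython's `math.exp`, an input such as `1000.0` raises `OverflowError` in the naive version, while the stable version only ever exponentiates values that are at most zero. For inputs that do not overflow, the two functions agree up to rounding.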

IV-D Optimizing Softmax Compute

We now describe an optimization to attention that reduces compute requirements, specifically division. This optimization was used in FlashAttention-2 [15]. We point out that it can be applied more broadly, i.e., to any cascade we discuss in Section IV-E. Equation 28 requires $M \times P$ divisions. While this is the best we can do for an independent softmax, we note that attention does not use the softmax in isolation [48]. Instead, it subsequently multiplies the result, $A_{m,p}$, and another tensor, $V_{f,m}$, per Equation 24, reproduced here:

\begin{equation*}
AV_{f,p} = A_{m,p} \times V_{f,m}
\end{equation*}

To optimize the full attention cascade, we can refactor Equations 28 and 24 by first combining $SN_{m,p}$ and $V_{f,m}$ and reducing across the $M$ rank (Equation 31), and then performing the division (Equation 32), as follows:

\begin{align}
SNV_{f,p} &= SN_{m,p} \times V_{f,m} \tag{31} \\
AV_{f,p} &= SNV_{f,p} / SD_p \tag{32}
\end{align}

This reassociation does $F \times P$ divisions instead of $M \times P$ divisions. Since $M$ is the sequence length and $F$ is an embedding dimension (i.e., $M \gg F$), this reassociation reduces the required divisions (by a factor of $\frac{M}{F}$).
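A small sketch can make the division counts concrete. The Python below is illustrative, with hypothetical function names and nested lists indexed as `sn[m][p]`, `sd[p]`, and `v[f][m]`. Each function returns the result together with the number of divisions it performed.

```python
def attention_divide_first(sn, sd, v):
    """Equation 28 then Equation 24: normalize every A[m][p], then multiply by V."""
    M, P, F = len(sn), len(sn[0]), len(v)
    a = [[sn[m][p] / sd[p] for p in range(P)] for m in range(M)]       # M * P divisions
    av = [[sum(a[m][p] * v[f][m] for m in range(M)) for p in range(P)]
          for f in range(F)]
    return av, M * P


def attention_divide_last(sn, sd, v):
    """Equations 31-32: reduce SN * V over M first, then divide once per output."""
    M, P, F = len(sn), len(sn[0]), len(v)
    snv = [[sum(sn[m][p] * v[f][m] for m in range(M)) for p in range(P)]
           for f in range(F)]                                           # Equation 31
    av = [[snv[f][p] / sd[p] for p in range(P)] for f in range(F)]      # F * P divisions
    return av, F * P
```

For example, with $M = 4$, $P = 1$, and $F = 2$, the first version performs 4 divisions and the second performs 2, while both produce the same output up to rounding.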

IV-E Optimizing Softmax Live Footprint and Memory Traffic

TABLE I: Classifying prior attention algorithms.

| 3-pass | 2-pass | 1-pass |
|---|---|---|
| PyTorch [41] | Tileflow [57] | FlashAttention [16] |
| TensorFlow [3] | Choi et al. [13] | FlashAttention-2 [15] |
| FLAT [27] | | |
| E.T. [7] | | |

We can also apply the analysis described in Section III to the efficient attention literature. We find that existing approaches to attention can be classified as either 3-pass, 2-pass, or 1-pass cascades, where an $N$-pass cascade performs $N$ passes of a given $M$ fiber. See Table I. Next, we describe the key ideas of each.

IV-E1   3-Pass Attention Cascades

The 3-pass cascade is the straightforward, numerically stable cascade that we already discussed in Section IV-C1, namely Equations 29-30 followed by Equations 27-28, reproduced in Cascade 4 for clarity.

Einsum Cascade 4: The 3-pass attention cascade.

\begin{align}
GM_p &= QK_{m,p} :: \bigvee_m \text{max}(\cup) && \text{/* Pass 1 */} \tag{33} \\
SN_{m,p} &= e^{QK_{m,p} - GM_p} && \text{/* Pass 2 */} \tag{34} \\
SD_p &= SN_{m,p} \tag{35} \\
A_{m,p} &= SN_{m,p} / SD_p && \text{/* Pass 3 */} \tag{36}
\end{align}

In Pass 1, we compute $GM_p$; in Pass 2, we compute $SN_{m,p}$ and $SD_p$; and in Pass 3, we compute $A_{m,p}$. Notice that we must finish an entire $M$ fiber of Equation 33 (reading an entire $M$ fiber of $QK_{m,p}$) before $GM_p$ is ready to start Equation 34 (where we must read the same $M$ fiber of $QK_{m,p}$ again). Similarly, we must finish an entire $M$ fiber of Equation 35 (reading an entire $M$ fiber of $SN_{m,p}$) before $SD_p$ is ready to start Equation 36 (where we must read the same $M$ fiber of $SN_{m,p}$ again). Regardless of the mapping (including fusion), this cascade must perform three passes, since they are a consequence of the dependencies between Einsums.
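The three passes can be written out as three separate loops over one $M$ fiber. This is a minimal Python sketch with a hypothetical function name. Its only purpose is to show where each loop has to wait for a value that is final only after the previous loop has finished.

```python
import math


def softmax_three_pass(qk_fiber):
    """Numerically stable softmax over one M fiber, written as three explicit passes."""
    # Pass 1 (Equation 33): read the whole fiber to find the global max.
    gm = -math.inf
    for x in qk_fiber:
        gm = max(gm, x)

    # Pass 2 (Equations 34-35): re-read the fiber; gm is only final after pass 1.
    sn = []
    sd = 0.0
    for x in qk_fiber:
        n = math.exp(x - gm)
        sn.append(n)
        sd += n

    # Pass 3 (Equation 36): read the numerators again; sd is only final after pass 2.
    return [n / sd for n in sn]
```

Passes 2 and 3 each need a buffer (or a reload) of $M$ values, which is the live-footprint cost that the 2-pass and 1-pass cascades try to avoid.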

IV-E2   2-Pass Attention Cascades

We now briefly summarize the 2-pass cascade, deferring details due to space. Rather than computing the global max and then starting the softmax (as in the 3-pass cascade), the 2-pass cascade first partitions the input, computes a per-partition local max and applies it to form a variant of $SN_{m,p}$ whose elements are adjusted by the local max and likewise partitioned. Analogously, each partition gets a local denominator (also adjusted by the same local max). While this is occurring, it builds the global max from the local max values. Next, in a second pass, it uses the global max to correct the per-partition numerators and denominators and compute the softmax output.
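Since the details are deferred, the following Python sketch is only one plausible reading of the summary above, not the algorithm of any cited work. The function name and the `block` parameter (the partition size) are hypothetical, and the sketch handles a single $M$ fiber.

```python
import math


def softmax_two_pass(qk_fiber, block):
    """Sketch of a 2-pass softmax over one M fiber, partitioned into blocks."""
    local_max, local_num, local_den = [], [], []
    global_max = -math.inf

    # Pass 1: per-partition local max, local numerators, and local denominators.
    for start in range(0, len(qk_fiber), block):
        part = qk_fiber[start:start + block]
        lm = max(part)
        ln = [math.exp(x - lm) for x in part]
        local_max.append(lm)
        local_num.append(ln)
        local_den.append(sum(ln))
        global_max = max(global_max, lm)   # built alongside the local values

    # Correct the per-partition denominators with the global max
    # (one value per partition, so this does not re-read the fiber).
    den = sum(ld * math.exp(lm - global_max)
              for lm, ld in zip(local_max, local_den))

    # Pass 2: correct the per-partition numerators and divide.
    out = []
    for lm, ln in zip(local_max, local_num):
        corr = math.exp(lm - global_max)
        out.extend(n * corr / den for n in ln)
    return out
```

The correction factor $e^{LM - GM}$ converts a value adjusted by the local max into one adjusted by the global max, so the output matches the 3-pass result up to rounding.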

IV-E3   1-Pass Attention Cascades

Prior work proposes multiple different 1-pass cascades [16, 15] that take advantage of the reassociations presented in Section III-C, but the main ideas are the same. First, modify the cascade to compute the softmax numerator-times-$V$ and then perform the division (as described in Section IV-D). This reassociation combines the second and third passes of Cascade 4 (see Section III-C1). To ensure numerical stability, we cannot use this strategy to combine the first and second passes, so we instead use the iterative approach (see Section III-C2). Rather than using the per-partition local max to compute the local numerator and denominator, instead keep a running max that represents the max value seen so far. Each time a new running max is computed, adjust previous results (e.g., numerator-times-$V$, denominator, etc.) with this max.
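The running-max idea can be sketched for a single query position as follows. This Python is an illustration of the ideas in the paragraph above under stated assumptions, not a transcription of any cited kernel: the function name and `block` parameter are hypothetical, `v` is indexed as `v[f][m]`, and all inputs are assumed to be finite.

```python
import math


def attention_one_pass(qk_fiber, v, block):
    """Return AV[f] for one query position, reading the M fiber once."""
    F = len(v)
    rm = -math.inf        # running max seen so far
    rd = 0.0              # running denominator
    rnv = [0.0] * F       # running numerator-times-V

    for start in range(0, len(qk_fiber), block):
        part = qk_fiber[start:start + block]
        new_rm = max(rm, max(part))
        # Factor that re-bases earlier results onto the new running max;
        # math.exp(-inf) is 0.0, so the first block contributes no stale term.
        corr = math.exp(rm - new_rm)
        local_num = [math.exp(x - new_rm) for x in part]
        rd = rd * corr + sum(local_num)
        for f in range(F):
            local_nv = sum(n * v[f][start + j] for j, n in enumerate(local_num))
            rnv[f] = rnv[f] * corr + local_nv
        rm = new_rm

    # Deferred division (Section IV-D): F divisions per query position.
    return [x / rd for x in rnv]
```

Only the running scalars and the $F$-sized accumulator persist across blocks, so the buffer requirement does not grow with the sequence length. The price is the extra exponentiation and multiplications needed for the correction in every block.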

Next we describe FlashAttention-2’s 1-pass cascade (Cascade 5) because we use it to build FuseMax. Note, despite the evidently increased compute relative to the 3-pass cascade, we will carefully design a mapping in Section V to hide these overheads on a spatial architecture.

Einsum Cascade 5: A 1-pass attention cascade. Note that $M1$ is used as a standard rank (e.g., to access $BQK_{m1,m0,p}$) and as an iterative rank (e.g., to access $RM_{m1,p}$). Therefore, the stopping condition for all iterative ranks is $m1 = M1 + 1$ (Equation 53).

Initialization:

\begin{align}
BQK_{m1,m0,p} &= QK_{m1 \times M0 + m0,p} \tag{37} \\
BV_{f,m1,m0} &= V_{f,m1 \times M0 + m0} \tag{38} \\
RM_{m1:m1=0,p} &= -\infty \tag{39} \\
RD_{m1:m1=0,p} &= 0 \tag{40} \\
RNV_{f,m1:m1=0,p} &= 0 \tag{41}
\end{align}

Extended Einsums:

\begin{align}
LM_{m1,p} &= BQK_{m1,m0,p} :: \bigvee_{m0} \max(\cup) \tag{42} \\
RM_{m1+1,p} &= \max(RM_{m1,p}, LM_{m1,p}) \tag{43} \\
SLN_{m1,m0,p} &= e^{BQK_{m1,m0,p} - RM_{m1+1,p}} \tag{44} \\
SLD_{m1,p} &= SLN_{m1,m0,p} \tag{45} \\
SLNV_{f,m1,p} &= SLN_{m1,m0,p} \times BV_{f,m1,m0} \tag{46} \\
PRM_{m1,p} &= e^{RM_{m1,p} - RM_{m1+1,p}} \tag{47} \\
SPD_{m1,p} &= RD_{m1,p} \times PRM_{m1,p} \tag{48} \\
RD_{m1+1,p} &= SLD_{m1,p} + SPD_{m1,p} \tag{49} \\
SPNV_{f,m1,p} &= RNV_{f,m1,p} \times PRM_{m1,p} \tag{50} \\
RNV_{f,m1+1,p} &= SLNV_{f,m1,p} + SPNV_{f,m1,p} \tag{51} \\
AV_{f,p} &= RNV_{f,M1,p} / RD_{M1,p} \tag{52} \\
\diamond &: m1 \equiv M1 + 1 \tag{53}
\end{align}

We start by expressing the partitioning of both inputs, $QK_{m,p}$ and $V_{f,m}$, into $M1$ chunks of $M0$ elements each (Equations 37-38). This allows us to perform operations like maximum on individual $M0$ fibers, rather than on the whole tensor (Equation 42). The problem is, of course, that the local maximum is not necessarily the same for all $M0$ fibers and so will not cancel as cleanly as the global maximum does.

We resolve this by using the running maximum $RM_{m1,p}$, i.e., the global maximum of all inputs seen so far, rather than the local maximum. We recognize that $M1$ can also serve as an iterative rank, and iteratively build up $RM_{m1,p}$. After initializing $RM_{0,p}$ to $-\infty$ (Equation 39), we compute a new running maximum $RM_{m1+1,p}$ using the running maximum computed in the previous iteration $RM_{m1,p}$ and the new local maximum $LM_{m1,p}$ (Equation 43).

We can now use the running maximum to compute a local numerator $SLN_{m1,m0,p}$ (Equation 44), a local denominator $SLD_{m1,p}$ (Equation 45), and even the dot product result $SLNV_{f,m1,p}$ (Equation 46) using the partitioned $BV_{f,m1,m0}$ (Equation 38).

Now consider the softmax denominator. Eventually, we would like to reduce $SLD_{m1,p}$ into a 0-tensor, but because its values may have been computed with different maximums, we cannot simply use addition. Instead, by introducing a new running denominator $RD_{m1,p}$ with iterative rank $M1$, we can correct the old denominator $RD_{m1,p}$ to the new running maximum $RM_{m1+1,p}$ and then perform the addition. We must initialize the running denominator at the start of the computation to 0 (Equation 40). Then, at each point $m1$, the correction factor $PRM_{m1,p}$ (Equation 47) allows us to correct the previous running denominator $RD_{m1,p}$ with the new maximum (Equation 48). In other words, $RD_{m1,p}$ is downscaled by $e^{RM_{m1,p}}$. $SPD_{m1,p}$ “switches” the downscaling factor on $RD_{m1,p}$ to $e^{RM_{m1+1,p}}$ by multiplying $RD_{m1,p}$ by $\frac{e^{RM_{m1,p}}}{e^{RM_{m1+1,p}}}$ (that is, by $PRM_{m1,p}$). Once $SLD_{m1,p}$ and $SPD_{m1,p}$ have the same maximum, they can be combined to produce the new running denominator $RD_{m1+1,p}$ (Equation 49). We can do the same to compute the running numerator-times-$V$ (Equations 41 and 50-51).
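The rescaling step can be checked with a short derivation. It assumes, as an inductive hypothesis that the text implies but does not write out, that $RD_{m1,p}$ already equals the sum of exponentials of all inputs in chunks before $m1$, each downscaled by $e^{RM_{m1,p}}$. The primed index $m1'$ is introduced here only for illustration. Substituting Equations 44, 45, 47, and 48 into Equation 49 gives:

```latex
\begin{align*}
RD_{m1+1,p}
  &= \sum_{m0} e^{BQK_{m1,m0,p} - RM_{m1+1,p}}
     + RD_{m1,p} \times e^{RM_{m1,p} - RM_{m1+1,p}} \\
  &= \sum_{m0} e^{BQK_{m1,m0,p} - RM_{m1+1,p}}
     + \Big( \sum_{m1' < m1} \sum_{m0} e^{BQK_{m1',m0,p} - RM_{m1,p}} \Big)
       \times e^{RM_{m1,p} - RM_{m1+1,p}} \\
  &= \sum_{m1' \le m1} \sum_{m0} e^{BQK_{m1',m0,p} - RM_{m1+1,p}}
\end{align*}
```

So the hypothesis carries over to $m1+1$ with the new running maximum. The base case holds because $RD_{0,p} = 0$ is an empty sum, and the factor $e^{-\infty - RM_{1,p}} = 0$ leaves it at zero.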

Finally, $AV_{f,p}$ can be computed by dividing the final numerator-times-$V$ by the final denominator. By construction, at this point, $RNV_{f,M1,p}$ and $RD_{M1,p}$ are both downscaled by the same maximum $RM_{M1,p}$ (conveniently, also the global maximum) and can be correctly combined.
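The cascade described above can be sketched in plain Python to check that a single traversal of the $M$ rank reproduces ordinary softmax attention. This is a functional sketch only, not the FuseMax mapping. The function names, the nested-list layout (`QK[m][p]`, `V[f][m]`), the chunk layout $m = m1 \times M0 + m0$, and the requirement that $M0$ divides $M$ are all assumptions made for illustration. Inputs are assumed finite.

```python
import math


def one_pass_attention(QK, V, M0):
    """Sketch of the 1-pass cascade. QK[m][p] and V[f][m] are nested lists; returns AV[f][p]."""
    M, P, F = len(QK), len(QK[0]), len(V)
    assert M % M0 == 0, "this sketch assumes M0 divides M"
    M1 = M // M0
    AV = [[0.0] * P for _ in range(F)]
    for p in range(P):
        RM = -math.inf   # Eq. 39: running maximum
        RD = 0.0         # Eq. 40: running denominator
        RNV = [0.0] * F  # Eq. 41: running numerator-times-V
        for m1 in range(M1):
            # Chunk m1 covers m = m1*M0 + m0 (assumed layout of BQK and BV).
            chunk = range(m1 * M0, (m1 + 1) * M0)
            LM = max(QK[m][p] for m in chunk)                    # Eq. 42
            RM_next = max(RM, LM)                                # Eq. 43
            SLN = [math.exp(QK[m][p] - RM_next) for m in chunk]  # Eq. 44
            SLD = sum(SLN)                                       # Eq. 45
            SLNV = [sum(s * V[f][m] for s, m in zip(SLN, chunk))
                    for f in range(F)]                           # Eq. 46
            PRM = math.exp(RM - RM_next)  # Eq. 47 (0.0 while RM is -inf)
            SPD = RD * PRM                                       # Eq. 48
            RD = SLD + SPD                                       # Eq. 49
            SPNV = [r * PRM for r in RNV]                        # Eq. 50
            RNV = [a + b for a, b in zip(SLNV, SPNV)]            # Eq. 51
            RM = RM_next
        for f in range(F):
            AV[f][p] = RNV[f] / RD                               # Eq. 52
    return AV


def reference_attention(QK, V):
    """Multi-pass reference: softmax over m for each p, then multiply by V."""
    M, P, F = len(QK), len(QK[0]), len(V)
    AV = [[0.0] * P for _ in range(F)]
    for p in range(P):
        gm = max(QK[m][p] for m in range(M))
        sn = [math.exp(QK[m][p] - gm) for m in range(M)]
        sd = sum(sn)
        for f in range(F):
            AV[f][p] = sum(sn[m] * V[f][m] for m in range(M)) / sd
    return AV
```

Note that only `RM`, `RD`, and `RNV` survive from one chunk to the next, so the state carried across iterations does not grow with $M$. Only one division per output element is performed, at the very end.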

V Mapping Attention Onto A Spatial Array

Based on the framework from Section IV, we now describe FuseMax, an efficient mapping of an attention algorithm (specifically the 1-pass cascade in Cascade 5) to a spatial array-style architecture.

The goal when mapping a cascade onto hardware is to fully utilize all available compute units. In our evaluation of prior work (Figure 6 and Section VI-B), we observe that at short sequence lengths, the 2D PE array is under-utilized because it must wait for the 1D PE array to compute the softmax. At longer sequence lengths, both arrays are under-utilized since the workload becomes memory-bandwidth limited.

FuseMax’s mapping addresses these issues to achieve full utilization on both the 1D and 2D PE arrays. First, we decrease the compute performed by the 1D array by (1) applying the division reduction optimization (Section IV-D) and (2) sharing the other operations (sum/max/exp) between the 1D and 2D arrays. Similarly, we ensure that the workload is never memory-bandwidth limited by deeply fusing all Einsums in the cascade to restrict the live footprint to only what can be buffered on-chip. No matter the sequence length, our dataflow is never forced to spill any of its intermediates off-chip.

Architecture. We assume a standard spatial array-style architecture for our mapping. See Figure 2. We set parameters to match the cloud configuration in prior work [27].

Figure 2: Spatial array architecture assumed for FuseMax.

Figure 3 shows the evolution of the 2D PE array architecture, from a fixed-dataflow multiply-accumulate TPU PE (Figure 3(a)) to a flexible-dataflow multiply-accumulate PE (Figure 3(b)) to a FuseMax PE (Figure 3(c)). Note that although both the 1D and 2D PE arrays in FuseMax perform exponentiation, we implement exponentiation with 6 sequential multiply-accumulate operations [36, 49] and therefore do not require a dedicated exponentiation unit.
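To see why six multiply-accumulates can suffice, consider evaluating a degree-6 polynomial with Horner's rule, which takes exactly six MAC steps. The sketch below uses Taylor coefficients purely for illustration. The text does not say which polynomial, coefficients, or range reduction the cited designs use, so `COEFFS`, `mac`, and `exp_6mac` are hypothetical names and choices.

```python
import math

# Taylor coefficients 1/k! for k = 0..6. This is an illustrative choice only;
# a hardware design would likely use tuned coefficients plus range reduction.
COEFFS = [1.0 / math.factorial(k) for k in range(7)]


def mac(acc, x, c):
    """One multiply-accumulate step: acc * x + c."""
    return acc * x + c


def exp_6mac(x):
    """Approximate e**x near 0 with a degree-6 polynomial using 6 MACs (Horner's rule)."""
    acc = COEFFS[6]
    for c in reversed(COEFFS[:6]):  # c5, c4, ..., c0: six MAC steps
        acc = mac(acc, x, c)
    return acc
```

On the interval $[-1, 0]$ this particular polynomial stays within about $2 \times 10^{-4}$ of `math.exp`. The exponents in the cascade are always non-positive but can be far from zero, so a practical implementation would need some form of range reduction that is not shown here.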

(a) TPU [26] PE
(b) FLAT [27] PE
(c) FuseMax PE
Figure 3: 2D PE architecture evolution

Fusion and Partitioning. Prior attention accelerators [27, 57] explore fusing many of attention’s loop nests together. However, because these accelerators all use multi-pass cascades, the algorithmic minimum live footprint of some tensors (e.g., $QK_{m,p}$) is $O(M)$, meaning that for long sequence lengths, intermediates cannot be buffered on chip.

FuseMax leverages fusion in conjunction with the 1-pass cascade to eliminate the memory traffic of these tensors, regardless of the sequence length. Specifically, we partition on both $M$ and $P$ (forming $M1, M0$ and $P2, P1, P0$), and maximally fuse all levels in the attention loopnest as shown in Mapping 1. That is, all Einsums in Cascade 5 are fused except for the last (which is fused to the rest only on $P2$).

Mapping 1: The FuseMax mapping as a loopnest. We partition on both $M$ and $P$ and map the innermost ranks $M0$ and $P0$ to the spatial array PEs. ComputeRNVTile forms a tile of $QK_{m,p}$ (i.e., $BQK_{m1,:,p2,p1,:}$) and then performs Equations 37-51 from Cascade 5. ComputeAVTile performs Equation 52. Note that each equation (Einsum) represents a loopnest: by writing all equations in ComputeRNVTile under a single loopnest, we mean that we are maximally fusing those loopnests. Outer loops over $B$ and $H$ (if performing batched multihead attention) are not shown.
for p2 ...:
  for m1 ...:
    for p1 ...:
      parallel_for p0 ...:
        parallel_for m0 ...:
          (RNV[:, m1 + 1, p2, p1, p0],
           RD[m1 + 1, p2, p1, p0]) =
              ComputeRNVTile(
                Q[:, p2, p1, p0],
                K[:, m1, m0], V[:, m1, m0])
  for p1 ...:
    parallel_for p0 ...:
      AV[:, p2, p1, p0] =
        ComputeAVTile(
          RNV[:, m1 + 1, p2, p1, p0],
          RD[m1 + 1, p2, p1, p0])
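
The loop order of Mapping 1 can be sketched as a small Python generator that serializes the `parallel_for` loops and flattens the partitioned coordinates back to $(m, p)$. The flattening formulas $m = m1 \times M0 + m0$ and $p = (p2 \times P1 + p1) \times P0 + p0$, the function name, and the event tuples are assumptions for illustration. The sketch only tracks which points are touched and when, not the arithmetic.

```python
def fusemax_loop_order(M1, M0, P2, P1, P0):
    """Yield (phase, m, p) events in the order implied by Mapping 1."""
    for p2 in range(P2):
        for m1 in range(M1):
            for p1 in range(P1):
                # p0 and m0 are parallel_for in the mapping (spatial array PEs);
                # this sketch visits them serially.
                for p0 in range(P0):
                    for m0 in range(M0):
                        m = m1 * M0 + m0              # assumed flattening of M
                        p = (p2 * P1 + p1) * P0 + p0  # assumed flattening of P
                        yield ("RNV", m, p)
        # The final division is fused with the rest only on p2, so it runs
        # after the whole m1 loop for this p2 has finished.
        for p1 in range(P1):
            for p0 in range(P0):
                p = (p2 * P1 + p1) * P0 + p0
                yield ("AV", None, p)
```

Under these assumptions, every $(m, p)$ point is visited once by the fused tile computation, and each $p$ is finalized only after all of its $m$ values have been consumed.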
    

Parallelization and Spatial Reduction. While prior work implementing attention in hardware [27, 57] does utilize the 2D spatial array for the tensor products, it fails to do so for the softmax, choosing instead to use the 1D array. However, because there are far fewer total PEs in the 1D array than the 2D array, the softmax becomes a bottleneck. FuseMax improves utilization of the 2D spatial array by using it for both the tensor products and the exponentiation operator in the softmax. FuseMax parallelizes across the $M0$ and $P0$ ranks throughout the attention kernel (see Mapping 1). We set $M0 \times P0 = \#\,\text{2D array PEs}$. The large spatial reductions required when parallelizing across the $M0$ rank are easily handled by the low-cost inter-PE communication network.

Figure 4: FuseMax pipelining at a glance. Each tensor name (e.g., $SLNV$) corresponds to the Einsum used to compute that tensor (see Cascade 5). $a$, $b$, $c$ and $d$ denote tile-relative coordinates where $a < b < c < d$. If Epoch $i$ produces tiles with coordinates $a, b, c, d$, Epoch $i+1$ produces tiles with coordinates $a+1, b+1, c+1, d+1$, and so on. ‘$A|B$’ denotes ‘computing tile $A$ is interleaved with computing tile $B$.’ ‘$A \rightarrow B$’ denotes ‘computing tile $A$ is done before computing tile $B$.’ Computing $AV_{f,p}$ is not shown. The green and blue time periods making up an epoch take almost the same number of cycles.

Pipelining. The dependencies between different Einsums in our cascade necessitate fine-grain pipeline parallelism to achieve high utilization of both the 1D and 2D spatial arrays. Figure 4 shows the waterfall diagram for FuseMax in the steady state. Time is broken into epochs. Each epoch performs the same set of tile-granular operations at specific tile-relative coordinates (given by $a, b, c, d$ in the figure). Across all epochs, the kernel evaluates all tiles and each Einsum in Cascade 5 is mapped to either the 2D or 1D array for all epochs (as shown in the figure).

A major design consideration when pipelining the mapping is how to overcome the latency of fills and drains to/from the spatial array. Consider a tile of $QK_{m,p}$ of shape $M0 \times P0$. Per Equation 22, the iteration space to evaluate this tile is $E \times M0 \times P0$, which becomes $E$ cycles on the spatial array. For the networks we evaluate, $E = 64$ or $128$. Assume $E = 64$. This means, assuming an output stationary dataflow, that while each PE performs 64 MACCs, it takes $\sim 256$ cycles to both fill and drain the spatial array. Without careful interleaving, this combination of parameters causes low utilization because, for example, the running max $RM_{m1+1,p1,:}$ cannot be computed until a tile of $QK_{m1,:,p1,:}$ is completed and spatially reduced (drained) to form the local max $LM_{m1,p1,:}$ (Equations 42-43).
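As a rough back-of-the-envelope illustration of the problem, suppose that nothing overlapped with the fill and drain, and read the $\sim 256$ cycles as the combined fill-and-drain cost per tile. Both are simplifying assumptions made here, not measurements from the evaluation, and $T_{\text{fill+drain}}$ is a symbol introduced only for this estimate. The 2D array would then do useful work for at most:

```latex
\[
\text{utilization} \lesssim \frac{E}{E + T_{\text{fill+drain}}}
  = \frac{64}{64 + 256} = 20\%
\]
```

Under the same assumptions, $E = 128$ would give $128 / (128 + 256) \approx 33\%$. Either way, most cycles would be spent waiting unless other work is scheduled into the fill and drain periods.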

We address the above issues with two levels of interleaving. First, we interleave the construction of dependent tiles across epochs. This is reminiscent of software pipelining. For example, in Figure 4 the $d$-th tile of $BQK$ and $LM$ are completed in Epoch $i$ (as they correspond to a fill followed by a drain and can be easily pipelined). The $RM$ (which has to wait for the drain) for tile $d$ takes place in a later epoch. Instead, Epoch $i$ computes an earlier tile’s running maximum $RM[c]$.

Figure 5: Initial pipeline fill ($t = 0$ to $t = 2$) and steady-state ($t = 3$ and $t = 4$) for the intra-epoch interleaving of $SLNV|BQK$ and $SPNV|RNV$ to maximize 2D and 1D PE utilization, respectively, on a toy 2x2 array. Each color indicates a tensor and each number indicates a point in that tensor (e.g., the point $BV_0$ moves from the top left PE at $t = 1$ to the top right PE at $t = 2$). To reason about signal timing, we use input (but not output) latches for data in each PE, so moving data appears on output wires. Some stationary tensors (e.g., $BQK$) and Einsums (e.g., $SLD$) are omitted for clarity.

Second, we interleave the construction of certain tiles within an epoch at a fine (e.g., cycle-by-cycle) granularity. See the notation ‘$A|B$’ in Figure 4. This is to ensure high utilization of both the 2D and 1D PE arrays at all times. To make this clearer, Figure 5 shows the start-up and steady-state interleaving of $SLNV$ and $BQK$ in the 2D array and $SPNV$ and $RNV$ in the 1D array. In each cycle, a given PE in the 2D array computes a value for either $BQK$ or $SLNV$, and this alternates cycle by cycle. Each neighbor-neighbor link in the array is active in every cycle, carrying data for one of the two operation types. By interleaving $SLNV$ with $BQK$, the 1D PEs can concurrently compute $SPNV$ and $RNV$.

Putting everything together, as Section VI-B will show, the above enables high utilization of all 2D and 1D array PEs.

VI Evaluation

(a) 1D PE array utilization
(b) 2D PE array utilization
Figure 6: Utilization of the different PE arrays on the unfused baseline, FLAT, and FuseMax.
Figure 7: Speedup of attention for FLAT and FuseMax over an unfused baseline.
Figure 8: Energy consumption of attention for FLAT and FuseMax over an unfused baseline.
Figure 9: Speedup of transformer inference on FLAT and FuseMax over an unfused baseline.
Figure 10: Energy consumption of transformer inference on FLAT and FuseMax over an unfused baseline.

In this section, we demonstrate how the FuseMax dataflow achieves improvements in both performance and energy relative to the state of the art, for both attention and end-to-end transformer inference.

VI-A Experimental Set-Up

First, we present the experimental set-up details common to all following subsections.

Workloads. We evaluate all accelerators and configurations using the same transformer models used by FLAT [27]: BERT-Base [18] (BERT), TrXL-wt103 [14] (TrXL), T5-small [46] (T5), and XLM [29]. We omit FlauBERT [31] because it uses the same hyperparameters as TrXL. We also note that though T5 is an encoder-decoder model, we only evaluate the encoder in this work. Following FLAT, we use a batch size $B = 64$ for all evaluations.

Modeling with Timeloop and Accelergy. We perform our evaluation using two tools for tensor algebra accelerator modeling and design space exploration: Timeloop [40] and Accelergy [51]. We use these tools to build models of the accelerator architectures at a 45nm technology node and evaluate each Einsum individually. Results from individual Einsums are combined using heuristics presented in prior work for evaluating full cascades [35]. Together, these tools allow us to evaluate execution time, energy, and area for all our designs. We perform floating-point division using the design in Xia et al. [54], scaled down to a 45nm technology node [51].

Unfused Baseline. We build the unfused baseline by combining the costs of three phases: $QK$ (Equation 22), the 3-pass softmax (Cascade 4), and $AV$ (Equation 24). Because this baseline is unfused, each phase can be scheduled independently, but the phases proceed sequentially and require outputs to be written to memory between phases. We use Timeloop to search for efficient mappings to perform $QK$ and $AV$. Additionally, we model the softmax for the unfused baseline by allowing the accelerator to load the $M$ fibers of the input on-chip one-by-one (spilling if there is not enough space) before performing the compute. We model the memory traffic, compute, and energy required to perform all Einsums required for attention.
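For intuition about why the softmax phase needs the whole $M$ fiber available (or must spill it), the following is a generic numerically stable softmax over one fiber, written so that its three traversals are explicit. It is a sketch of the general technique and may differ in detail from Cascade 4. For example, it recomputes the exponentials in the third traversal rather than storing them. The function name is hypothetical.

```python
import math


def three_pass_softmax(fiber):
    """Numerically stable softmax over one M fiber, with three explicit traversals."""
    # Pass 1: global maximum of the fiber.
    gm = -math.inf
    for v in fiber:
        gm = max(gm, v)
    # Pass 2: denominator, with every term downscaled by the global maximum.
    sd = 0.0
    for v in fiber:
        sd += math.exp(v - gm)
    # Pass 3: normalize (recomputing the numerators in this sketch).
    return [math.exp(v - gm) / sd for v in fiber]
```

Each traversal depends on a value that is only known after the previous traversal has seen every element, which is why the fiber must either stay buffered on-chip or be re-read from memory.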

FLAT Baseline. Our main baseline is the state-of-the-art attention accelerator FLAT [27]. Though we started with the FLAT authors’ original code, we found and corrected a number of bugs. Through private correspondence with the FLAT authors, we verified the bugs were indeed bugs. We also discovered a couple of larger conceptual errors, which the authors told us to avoid by restricting FLAT to only search through configurations without these issues.

Beyond correcting the FLAT codebase, we created and validated a Timeloop model that reproduces the FLAT authors’ (corrected) code to within $1\%$ error. However, the FLAT codebase does not model the cost to perform the softmax. Specifically, their model ignores the cost of data transfers (between any levels of the memory hierarchy) and uses $2^{30}$ 1D PEs. When comparing FuseMax and FLAT in this work, we augment our Timeloop model to model softmax correctly per the 3-pass cascade implicitly assumed by FLAT.

Hardware parameters. Figure 2 shows the selected hardware parameters. We chose the PE array dimension to match FLAT’s cloud accelerator and the global buffer capacity by normalizing the area. Also following FLAT, we use a 940 MHz frequency. We use Accelergy to model the area of both designs and find that FuseMax is 17% smaller.

VI-B Evaluating Attention

We now evaluate FuseMax to demonstrate the benefits it provides on the attention kernel by comparing it to the two baselines.

Utilization. Figure 6(a) shows the utilization of the 1D PE array when performing attention. We see that, because fused dataflows (FLAT / FuseMax) do not have to wait for the whole $QK$ Einsum to complete to begin the softmax, they achieve high utilization. While FLAT’s utilization drops for sequence lengths $\geq 256\text{K}$ (it becomes memory bandwidth limited because it must spill the $QK$ and $A$ tensors to memory), FuseMax achieves full utilization for all sequence lengths.

Similarly, Figure 6(b) shows the utilization of the 2D PE array. Because of the large amount of compute required for the softmax, both baselines achieve very poor utilization of this array. On the other hand, at long sequence lengths, FuseMax achieves almost 100% utilization. We observe that both baselines do achieve slightly higher utilization on XLM, which can be attributed to the higher intensity caused by a larger embedding dimension ($E$/$F$).

Speedup. Figure 7 shows that FuseMax achieves an average speedup of $10\times$ over the unfused baseline and $6.7\times$ over FLAT. We note FuseMax achieves lower speedup on XLM only because the baselines are able to achieve higher utilization of the 2D array on this transformer (Figure 6(b)).

Energy. Figure 8 shows that FuseMax uses $77\%$ of the energy of the unfused baseline and $79\%$ of the energy of FLAT.⁴ The energy use of the unfused baseline and FLAT is dominated by the DRAM access energy, the global buffer access energy, and the $QK$ and $AV$ (Equations 22 and 24) compute energy. FuseMax achieves its energy savings by significantly reducing the DRAM access energy.

⁴ FLAT reports larger energy savings over the unfused baseline because it only reports energy associated with DRAM traffic during the tensor products.

VI-C Evaluating Transformer Inference

To evaluate the benefits of FuseMax on end-to-end transformer inference, we include the other required linear layers (Section IV-A). We use Timeloop to search for optimal mappings for these linear layers and use the same mappings for all three accelerator configurations. The attention modeling remains the same as Section VI-B.

Speedup. Figure 9 shows the performance improvement achieved by FuseMax. Across the sequence lengths tested, FuseMax achieves an average speedup of $7.6\times$ over the unfused baseline and $5.3\times$ over FLAT. As discussed in Section IV-A, as sequence length grows, attention becomes a larger fraction of the total required compute. Therefore, at 1M tokens, FuseMax achieves an average $10\times$ speedup over the unfused baseline and $7.5\times$ speedup over FLAT.
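The trend can be illustrated with Amdahl's law, under the simplifying assumption that the linear layers take the same time in every configuration (they use the same mappings) and only the attention time shrinks. The function below is a rough model for intuition, not the method used to produce Figure 9. Its name and parameters are hypothetical, and `attention_fraction` denotes the fraction of baseline runtime, which need not equal the fraction of compute.

```python
def end_to_end_speedup(attention_fraction, attention_speedup):
    """Amdahl's-law estimate of end-to-end speedup when only attention is accelerated.

    attention_fraction: share of baseline runtime spent in attention, in [0, 1].
    attention_speedup:  speedup achieved on the attention portion alone (> 0).
    """
    if not 0.0 <= attention_fraction <= 1.0:
        raise ValueError("attention_fraction must lie in [0, 1]")
    if attention_speedup <= 0.0:
        raise ValueError("attention_speedup must be positive")
    remaining = (1.0 - attention_fraction) + attention_fraction / attention_speedup
    return 1.0 / remaining
```

In this model the end-to-end speedup rises monotonically toward the attention-only speedup as `attention_fraction` approaches 1, which is consistent with the larger gains reported at longer sequence lengths.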

Energy. Figure 10 shows the energy reduction achieved by FuseMax. Here, we see similar results: as attention becomes a larger fraction of the kernel, the energy reduction increases. FuseMax uses $82\%$ of the unfused baseline’s energy and $83\%$ of FLAT’s energy during end-to-end inference.

VII Related Work

Spatial architectures have been applied successfully to a variety of domains in academia [10, 11, 43, 39] and industry [26, 4]. Beyond FLAT [27] (discussed in the main body of the paper), TileFlow [57] is a framework for modeling and searching for efficient fused dataflows (including for attention) on spatial architectures. Though TileFlow does explore a broader space of dataflows than FLAT, even implementing the 2-pass softmax cascade (Section IV-E2), its dataflows remain softmax-compute limited.

Quantization and sparsity have also been successfully applied to reduce the transformer inference compute and live footprint. We view these schemes as complementary to our work. GPTQ [21], AWQ [32], and LLM.int8() [17] quantize model weights to 4 or 8 bits without significant accuracy degradation. Outlier-aware quantization schemes like GOBO [55] and OliVe [22] quantize both weights and activations to a low-bit precision on specific hardware designs. SpAtten [49] prunes entire tokens and heads, while Sanger [34] and DOTA [44] use quantized or low-rank projected $Q$ and $K$ tensors to estimate which values of $QK$ and $A$ can be safely pruned. All of these algorithms are expressible as cascades of Einsums, and therefore, may be combined with FuseMax to improve performance and energy efficiency, though we leave their specification and implementation to future work.

VIII Conclusion

This paper advanced the state of the art in spatial accelerator design for transformer inference. To do so, we expressed attention and its variants as cascades of Einsums. We used these cascades to reason about attention’s characteristics, independent of its mapping/scheduling. Using these principles, we proposed FuseMax, an accelerator that uses deep fusion and fine-grain pipelining to map attention onto a spatial architecture. FuseMax achieves $\sim 100\%$ utilization of both PE arrays, demonstrating $6.7\times$ speedup over the prior state-of-the-art (FLAT) using $79\%$ of the energy on attention and $5.3\times$ speedup over FLAT using $83\%$ of the energy on end-to-end inference.

Our work shows that cascades of Einsums provide a powerful abstraction for representing and analyzing domain-specific kernels. Future work may explore their application to other attention variants (e.g., those exploiting quantization and sparsity) or even other domains (e.g., fully homomorphic encryption, scientific computing, relational algebra, etc.). Doing so enables mapping-agnostic analysis and may elucidate previously undiscovered cascades and schedules for these algorithms.

References

  • [1] “Our next-generation model: Gemini 1.5,” https://blog.google/technology/ai/google-gemini-next-generation-model-february-2024/#context-window.
  • [2] “Tensor network contractions,” ser. Lecture Notes in Physics, vol. 964.   Springer Cham, 2020.
  • [3] M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G. S. Corrado, A. Davis, J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Jozefowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah, M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasudevan, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, and X. Zheng, “TensorFlow: Large-scale machine learning on heterogeneous systems,” 2015, software available from tensorflow.org. [Online]. Available: https://www.tensorflow.org/
  • [4] AWS. (2024) Trainium architecture. [Accessed April 16, 2024]. [Online]. Available: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trainium.html
  • [5] A. Baevski, H. Zhou, A. Mohamed, and M. Auli, “wav2vec 2.0: a framework for self-supervised learning of speech representations,” in Proceedings of the 34th International Conference on Neural Information Processing Systems, ser. NIPS ’20.   Red Hook, NY, USA: Curran Associates Inc., 2020.
  • [6] J. S. Bridle, “Probabilistic interpretation of feedforward classification network outputs, with relationships to statistical pattern recognition,” in NATO Neurocomputing, 1989. [Online]. Available: https://api.semanticscholar.org/CorpusID:59636530
  • [7] S. Chen, S. Huang, S. Pandey, B. Li, G. R. Gao, L. Zheng, C. Ding, and H. Liu, “E.t.: Re-thinking self-attention for transformer models on gpus,” in SC21: International Conference for High Performance Computing, Networking, Storage and Analysis, 2021, pp. 1–14.
  • [8] T. Chen, Z. Du, N. Sun, J. Wang, C. Wu, Y. Chen, and O. Temam, “Diannao: A small-footprint high-throughput accelerator for ubiquitous machine-learning,” ACM Sigplan Notices, vol. 49, no. 4, pp. 269–284, 2014.
  • [9] Y. Chen, Y. Xie, L. Song, F. Chen, and T. Tang, “A survey of accelerator architectures for deep neural networks,” Engineering, vol. 6, no. 3, pp. 264–274, 2020.
  • [10] Y.-H. Chen, J. Emer, and V. Sze, “Eyeriss: A spatial architecture for energy-efficient dataflow for convolutional neural networks,” in ISCA’16.
  • [11] Y.-H. Chen, J. Emer, and V. Sze, “Eyeriss v2: A flexible and high-performance accelerator for emerging deep neural networks,” 2018.
  • [12] Y. Chen, T. Luo, S. Liu, S. Zhang, L. He, J. Wang, L. Li, T. Chen, Z. Xu, N. Sun et al., “Dadiannao: A machine-learning supercomputer,” in MICRO’14.
  • [13] J. Choi, H. Li, B. Kim, S. Hwang, and J. H. Ahn, “Accelerating transformer networks through recomposing softmax layers,” in 2022 IEEE International Symposium on Workload Characterization (IISWC), 2022, pp. 92–103.
  • [14] A. Conneau and G. Lample, “Cross-lingual language model pretraining,” in Advances in Neural Information Processing Systems, H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, Eds., vol. 32. Curran Associates, Inc., 2019. [Online]. Available: https://proceedings.neurips.cc/paper_files/paper/2019/file/c04c19c2c2474dbf5f7ac4372c5b9af1-Paper.pdf
  • [15] T. Dao, “Flashattention-2: Faster attention with better parallelism and work partitioning,” 2023.
  • [16] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Ré, “Flashattention: Fast and memory-efficient exact attention with io-awareness,” 2022.
  • [17] T. Dettmers, M. Lewis, Y. Belkada, and L. Zettlemoyer, “Llm.int8(): 8-bit matrix multiplication for transformers at scale,” ArXiv, vol. abs/2208.07339, 2022. [Online]. Available: https://api.semanticscholar.org/CorpusID:251564521
  • [18] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “Bert: Pre-training of deep bidirectional transformers for language understanding,” in North American Chapter of the Association for Computational Linguistics, 2019. [Online]. Available: https://api.semanticscholar.org/CorpusID:52967399
  • [19] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly, J. Uszkoreit, and N. Houlsby, “An image is worth 16x16 words: Transformers for image recognition at scale,” in 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021.   OpenReview.net, 2021. [Online]. Available: https://openreview.net/forum?id=YicbFdNTTy
  • [20] I. S. Duff, M. A. Heroux, and R. Pozo, “An overview of the sparse basic linear algebra subprograms: The new standard from the BLAS technical forum,” ACM Trans. Math. Softw., vol. 28, no. 2, pp. 239–267, 2002. [Online]. Available: https://doi.org/10.1145/567806.567810
  • [21] E. Frantar, S. Ashkboos, T. Hoefler, and D. Alistarh, “GPTQ: Accurate post-training compression for generative pretrained transformers,” arXiv preprint arXiv:2210.17323, 2022.
  • [22] C. Guo, J. Tang, W. Hu, J. Leng, C. Zhang, F. Yang, Y. Liu, M. Guo, and Y. Zhu, “Olive: Accelerating large language models via hardware-friendly outlier-victim pair quantization,” in Proceedings of the 50th Annual International Symposium on Computer Architecture, ser. ISCA ’23.   ACM, Jun. 2023. [Online]. Available: http://dx.doi.org/10.1145/3579371.3589038
  • [23] C. Hong, Q. Huang, G. Dinh, M. Subedar, and Y. S. Shao, “DOSA: Differentiable model-based one-loop search for DNN accelerators,” in 56th Annual IEEE/ACM International Symposium on Microarchitecture, ser. MICRO ’23.   IEEE, Oct. 2023, pp. 209–224.
  • [24] O. Hsu, M. Strange, R. Sharma, J. Won, K. Olukotun, J. S. Emer, M. A. Horowitz, and F. Kjølstad, “The sparse abstract machine,” in ASPLOS’23, 2023.
  • [25] W.-N. Hsu, B. Bolte, Y.-H. H. Tsai, K. Lakhotia, R. Salakhutdinov, and A. Mohamed, “Hubert: Self-supervised speech representation learning by masked prediction of hidden units,” IEEE/ACM Trans. Audio, Speech and Lang. Proc., vol. 29, pp. 3451–3460, Oct. 2021. [Online]. Available: https://doi.org/10.1109/TASLP.2021.3122291
  • [26] N. P. Jouppi, C. Young, N. Patil, D. Patterson, G. Agrawal, R. Bajwa, S. Bates, S. Bhatia, N. Boden, A. Borchers, R. Boyle, P.-l. Cantin, C. Chao, C. Clark, J. Coriell, M. Daley, M. Dau, J. Dean, B. Gelb, T. V. Ghaemmaghami, R. Gottipati, W. Gulland, R. Hagmann, C. R. Ho, D. Hogberg, J. Hu, R. Hundt, D. Hurt, J. Ibarz, A. Jaffey, A. Jaworski, A. Kaplan, H. Khaitan, D. Killebrew, A. Koch, N. Kumar, S. Lacy, J. Laudon, J. Law, D. Le, C. Leary, Z. Liu, K. Lucke, A. Lundin, G. MacKean, A. Maggiore, M. Mahony, K. Miller, R. Nagarajan, R. Narayanaswami, R. Ni, K. Nix, T. Norrie, M. Omernick, N. Penukonda, A. Phelps, J. Ross, M. Ross, A. Salek, E. Samadiani, C. Severn, G. Sizikov, M. Snelham, J. Souter, D. Steinberg, A. Swing, M. Tan, G. Thorson, B. Tian, H. Toma, E. Tuttle, V. Vasudevan, R. Walter, W. Wang, E. Wilcox, and D. H. Yoon, “In-datacenter performance analysis of a tensor processing unit,” in ISCA ’17.
  • [27] S.-C. Kao, S. Subramanian, G. Agrawal, A. Yazdanbakhsh, and T. Krishna, “Flat: An optimized dataflow for mitigating attention bottlenecks,” ser. ASPLOS 2023. New York, NY, USA: Association for Computing Machinery, 2023, pp. 295–310. [Online]. Available: https://doi.org/10.1145/3575693.3575747
  • [28] H. Kwon, P. Chatarasi, M. Pellauer, A. Parashar, V. Sarkar, and T. Krishna, “Understanding reuse, performance, and hardware cost of DNN dataflow: A data-centric approach,” in Proceedings of the 52nd Annual IEEE/ACM International Symposium on Microarchitecture, MICRO.   ACM, 2019, pp. 754–768.
  • [29] G. Lample and A. Conneau, “Cross-lingual language model pretraining,” ArXiv, vol. abs/1901.07291, 2019. [Online]. Available: https://api.semanticscholar.org/CorpusID:58981712
  • [30] C. L. Lawson, R. J. Hanson, D. R. Kincaid, and F. T. Krogh, “Basic linear algebra subprograms for fortran usage,” ACM Trans. Math. Softw., vol. 5, no. 3, pp. 308–323, 1979. [Online]. Available: https://doi.org/10.1145/355841.355847
  • [31] H. Le, L. Vial, J. Frej, V. Segonne, M. Coavoux, B. Lecouteux, A. Allauzen, B. Crabbé, L. Besacier, and D. Schwab, “Flaubert: Unsupervised language model pre-training for french,” CoRR, vol. abs/1912.05372, 2019. [Online]. Available: http://arxiv.org/abs/1912.05372
  • [32] J. Lin, J. Tang, H. Tang, S. Yang, W.-M. Chen, W.-C. Wang, G. Xiao, X. Dang, C. Gan, and S. Han, “Awq: Activation-aware weight quantization for llm compression and acceleration,” in MLSys, 2024.
  • [33] Z. Liu, Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, S. Lin, and B. Guo, “Swin transformer: Hierarchical vision transformer using shifted windows,” in 2021 IEEE/CVF International Conference on Computer Vision (ICCV), 2021, pp. 9992–10002.
  • [34] L. Lu, Y. Jin, H. Bi, Z. Luo, P. Li, T. Wang, and Y. Liang, “Sanger: A co-design framework for enabling sparse attention using reconfigurable architecture,” MICRO-54: 54th Annual IEEE/ACM International Symposium on Microarchitecture, 2021. [Online]. Available: https://api.semanticscholar.org/CorpusID:239012114
  • [35] N. Nayak, T. O. Odemuyiwa, S. Ugare, C. Fletcher, M. Pellauer, and J. Emer, “Teaal: A declarative framework for modeling sparse tensor accelerators,” in Proceedings of the 56th Annual IEEE/ACM International Symposium on Microarchitecture, ser. MICRO ’23. New York, NY, USA: Association for Computing Machinery, 2023, pp. 1255–1270. [Online]. Available: https://doi.org/10.1145/3613424.3623791
  • [36] P. Nilsson, A. U. R. Shaik, R. Gangarajaiah, and E. Hertz, “Hardware implementation of the exponential function using taylor series,” in 2014 NORCHIP.   IEEE, oct 2014, pp. 1–4. [Online]. Available: https://doi.org/10.1109/NORCHIP.2014.7004740
  • [37] T. O. Odemuyiwa, H. Asghari-Moghaddam, M. Pellauer, K. Hegde, P.-A. Tsai, N. Crago, A. Jaleel, J. D. Owens, E. Solomonik, J. Emer, and C. Fletcher, “Accelerating sparse data orchestration via dynamic reflexive tiling,” in Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, ser. ASPLOS ’23, vol. 3, Mar. 2023, pp. 18–32.
  • [38] T. O. Odemuyiwa, J. S. Emer, and J. D. Owens, “The EDGE language: Extended general einsums for graph algorithms,” CoRR, vol. abs/2404.11591, 2024. [Online]. Available: https://doi.org/10.48550/arXiv.2404.11591
  • [39] A. Parashar, M. Pellauer, M. Adler, B. Ahsan, N. Crago, D. Lustig, V. Pavlov, A. Zhai, M. Gambhir, A. Jaleel, R. Allmon, R. Rayess, S. Maresh, and J. Emer, “Efficient spatial processing element control via triggered instructions,” IEEE Micro, vol. 34, no. 3, pp. 120–137, 2014.
  • [40] A. Parashar, P. Raina, Y. S. Shao, Y.-H. Chen, V. A. Ying, A. Mukkara, R. Venkatesan, B. Khailany, S. W. Keckler, and J. Emer, “Timeloop: A systematic approach to dnn accelerator evaluation,” in 2019 IEEE International Symposium on Performance Analysis of Systems and Software (ISPASS), 2019, pp. 304–315.
  • [41] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Köpf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala, PyTorch: an imperative style, high-performance deep learning library.   Red Hook, NY, USA: Curran Associates Inc., 2019.
  • [42] M. Pellauer, J. Clemons, V. Balaji, N. C. Crago, A. Jaleel, D. Lee, M. O’Connor, A. Parashar, S. Treichler, P. Tsai, S. W. Keckler, and J. S. Emer, “Symphony: Orchestrating sparse and dense tensors with hierarchical heterogeneous processing,” ACM Transactions on Computing Systems, vol. 41, pp. 4:1–4:30, 2023. [Online]. Available: https://doi.org/10.1145/3630007
  • [43] R. Prabhakar, Y. Zhang, D. Koeplinger, M. Feldman, T. Zhao, S. Hadjis, A. Pedram, C. Kozyrakis, and K. Olukotun, “Plasticine: A reconfigurable architecture for parallel patterns,” SIGARCH Comput. Archit. News, vol. 45, no. 2, pp. 389–402, Jun. 2017. [Online]. Available: http://doi.acm.org/10.1145/3140659.3080256
  • [44] Z. Qu, L. Liu, F. Tu, Z. Chen, Y. Ding, and Y. Xie, “Dota: detect and omit weak attentions for scalable transformer acceleration,” in Proceedings of the 27th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, ser. ASPLOS ’22. New York, NY, USA: Association for Computing Machinery, 2022, pp. 14–26. [Online]. Available: https://doi.org/10.1145/3503222.3507738
  • [45] A. Radford and K. Narasimhan, “Improving language understanding by generative pre-training,” 2018. [Online]. Available: https://api.semanticscholar.org/CorpusID:49313245
  • [46] C. Raffel, N. Shazeer, A. Roberts, K. Lee, S. Narang, M. Matena, Y. Zhou, W. Li, and P. J. Liu, “Exploring the limits of transfer learning with a unified text-to-text transformer,” vol. 21, no. 1, jan 2020.
  • [47] V. Sze, Y. Chen, T. Yang, and J. S. Emer, Efficient Processing of Deep Neural Networks, ser. Synthesis Lectures on Computer Architecture.   Springer, 2020.
  • [48] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” in Advances in Neural Information Processing Systems, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, Eds., vol. 30.   Curran Associates, Inc., 2017.
  • [49] H. Wang, Z. Zhang, and S. Han, “Spatten: Efficient sparse attention architecture with cascade token and head pruning,” in 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA).   IEEE, Feb. 2021. [Online]. Available: http://dx.doi.org/10.1109/HPCA51647.2021.00018
  • [50] J. Won, C. Hong, C. Mendis, J. Emer, and S. Amarasinghe, “Unified convolution framework: A compiler-based approach to support sparse convolutions,” in MLSys’23, 2023.
  • [51] Y. N. Wu, J. S. Emer, and V. Sze, “Accelergy: An architecture-level energy estimation methodology for accelerator designs,” in ICCAD’19, 2019.
  • [52] Y. N. Wu, P. Tsai, S. Muralidharan, A. Parashar, V. Sze, and J. S. Emer, “HighLight: Efficient and flexible DNN acceleration with hierarchical structured sparsity,” in IEEE/ACM International Symposium on Microarchitecture, ser. MICRO.   ACM, Oct. 2023, pp. 1106–1120. [Online]. Available: https://doi.org/10.1145/3613424.3623786
  • [53] Y. N. Wu, P.-A. Tsai, A. Parashar, V. Sze, and J. S. Emer, “Sparseloop: An analytical approach to sparse tensor accelerator modeling,” in 55th IEEE/ACM International Symposium on Microarchitecture (MICRO).   IEEE, Oct. 2022, pp. 1377–1395. [Online]. Available: https://doi.org/10.1109/MICRO56248.2022.00096
  • [54] J. Xia, W. Fu, M. Liu, and M. Wang, “Low-latency bit-accurate architecture for configurable precision floating-point division,” Applied Sciences, vol. 11, no. 11, 2021. [Online]. Available: https://www.mdpi.com/2076-3417/11/11/4988
  • [55] A. H. Zadeh, I. Edo, O. M. Awad, and A. Moshovos, “Gobo: Quantizing attention-based nlp models for low latency and energy efficient inference,” in 2020 53rd Annual IEEE/ACM International Symposium on Microarchitecture (MICRO).   IEEE, Oct. 2020. [Online]. Available: http://dx.doi.org/10.1109/MICRO50266.2020.00071
  • [56] G. Zhang, N. Attaluri, J. S. Emer, and D. Sanchez, “Gamma: Leveraging gustavson’s algorithm to accelerate sparse matrix multiplication,” in ASPLOS’21.
  • [57] S. Zheng, S. Chen, S. Gao, L. Jia, G. Sun, R. Wang, and Y. Liang, “Tileflow: A framework for modeling fusion dataflow via tree-based analysis,” in Proceedings of the 56th Annual IEEE/ACM International Symposium on Microarchitecture, ser. MICRO ’23. New York, NY, USA: Association for Computing Machinery, 2023, pp. 1271–1288. [Online]. Available: https://doi.org/10.1145/3613424.3623792