License: CC BY 4.0
arXiv:2403.14047v2 [cs.DC] 12 Apr 2024

Accelerating ViT Inference on FPGA through Static and Dynamic Pruning

Dhruv Parikh1, Shouyi Li1, Bingyi Zhang1, Rajgopal Kannan2, Carl Busart2, Viktor Prasanna1 1University of Southern California 2DEVCOM Army Research Office
1{dhruvash, liderric, bingyizh, prasanna}@usc.edu 2{rajgopal.kannan.civ, carl.e.busart.civ}@army.mil
Abstract

Vision Transformers (ViTs) have achieved state-of-the-art accuracy on various computer vision tasks. However, their high computational complexity prevents them from being applied to many real-world applications. Weight and token pruning are two well-known methods for reducing computational complexity. Weight pruning reduces the model size and associated computational demands, while token pruning further reduces the computation based on the input. Combining these two techniques should significantly reduce computation complexity and model size; however, naively integrating them results in irregular computation patterns, leading to significant accuracy drops and difficulties in hardware acceleration.

To address the above challenges, we propose a comprehensive algorithm-hardware codesign for accelerating ViT on FPGA through simultaneous pruning: combining static weight pruning and dynamic token pruning. For algorithm design, we systematically combine a hardware-aware structured block-pruning method for pruning model parameters and a dynamic token pruning method for removing unimportant token vectors. Moreover, we design a novel training algorithm to recover the model's accuracy. For hardware design, we develop a novel hardware accelerator for executing the pruned model. The proposed hardware design employs multi-level parallelism with a load-balancing strategy to efficiently handle the irregular computation patterns caused by the two pruning approaches. Moreover, we develop an efficient hardware mechanism for executing the on-the-fly token pruning. We apply our codesign approach to the widely used DeiT-Small model and implement the proposed accelerator on a state-of-the-art FPGA board. The evaluation results show that the proposed algorithm can reduce computation complexity by up to 3.4× with ≈3% accuracy drop and a model compression ratio of up to 1.6×. Compared with state-of-the-art implementations on CPU, GPU, and FPGA, our codesign on FPGA achieves an average latency reduction of 12.8×, 3.2×, and 0.7-2.1×, respectively.

Index Terms:
Vision transformer, model pruning, hardware acceleration

I Introduction

Vision Transformers (ViTs) [1] have demonstrated superior performance compared with Convolutional Neural Networks (CNNs) in various vision tasks [2, 3, 4, 5, 6, 7, 8]. The global self-attention in ViTs leads to a reduced local and image-specific inductive bias [1]; as a result, ViTs require larger datasets and larger model sizes [9] to outperform CNNs. The Multi-head Self-Attention (MSA) of ViTs allows them to generalize better than CNNs on larger datasets [10]. However, their computational cost is usually significantly higher than that of CNNs due to the MSA mechanism, which scales quadratically with the number of input tokens [11, 12]. These intensive computational requirements emphasize the need for efficient hardware acceleration.

In addressing the computational challenge, pruning has proven effective in reducing the computational cost of CNNs [13, 14, 15, 16]. However, pruning methods tailored to self-attention remain comparatively underexplored [17, 18, 19]. Many existing works on efficient ViTs treat block weight pruning and token pruning as two distinct strategies. Weight pruning, introduced in [20, 21, 22, 23, 17, 24, 18, 25, 26], reduces the model size by statically and selectively pruning model parameters, so that the network operates on sparse weight matrices and performs less computation. Token pruning removes tokens to reduce the computational complexity. The static approaches in [27, 28, 29, 30] drop tokens with a fixed ratio, often ignoring the redundancies between tokens; the dynamic token pruning studies in [31, 32, 33] do not fully explore the token redundancies from different attention heads and simply discard non-informative tokens. [34, 35, 36, 37] dynamically reduce the number of tokens in ViT encoders during inference based on the inherent image characteristics. Moreover, only a few of these studies provide efficient hardware implementations of their pruning algorithms. Weight pruning and token pruning each reduce the computational complexity independently, but the interaction between the two remains unexplored. A combined approach could bring further computational benefits. However, such integration poses two main challenges: (1) accuracy drop (algorithm level) and (2) increased irregularity of computation patterns (hardware level).

Many ViT acceleration works primarily focus on CPU and GPU platforms [28, 36, 32, 29, 30]. However, integrating block weight pruning and token pruning in ViTs effectively reduces the model size, making it possible to fit the compressed model on an FPGA. We use FPGAs to accelerate our pruned ViT models for three reasons: (1) With a customized data path and on-chip memory organization, FPGAs are better suited than CPUs and GPUs to maximizing computational efficiency. (2) CPUs and GPUs cannot effectively handle the token shuffling process of our dynamic token pruning; on FPGA, we design a specific kernel to handle token shuffling in the middle of model inference. (3) CPUs and GPUs need complicated processes to address workload imbalance, whereas on FPGA we can design customized hardware modules for balancing the workload.

In this paper, we propose an algorithm-hardware codesign for accelerating ViT inference. Different from existing ViT acceleration works, we utilize the combined power of static weight pruning and dynamic token pruning. We propose a simultaneous pruning algorithm to recover the accuracy loss caused by the two pruning approaches. Combining the two pruning approaches leads to more severe computational irregularity. Therefore, we develop a customized data path and memory organization on FPGA to execute the pruned model efficiently. While existing ViT accelerators on FPGA [35, 37] can handle the irregular patterns after pruning, they target either weight pruning or token pruning, but not both. Therefore, none of the existing FPGA ViT accelerators can support our simultaneous pruning approach. We summarize our main contributions below:

  • We propose an algorithm-hardware codesign for efficient ViT inference based on FPGA. The design combines parameter (static) and token (dynamic) pruning to reduce both the ViT model size and computational complexity.

  • For the algorithm design, we systematically combine static block weight pruning and dynamic token pruning to reduce the computation complexity of ViTs. We propose a novel training algorithm to recover the accuracy drop caused by the two pruning algorithms.

  • For the hardware design, we develop a novel hardware accelerator with multi-level parallelism and a load balancing strategy. This can efficiently deal with (1) load imbalance caused by the block pruning and (2) a changing number of tokens caused by token pruning. We also develop an efficient hardware mechanism for executing the on-the-fly token pruning algorithm.

  • We evaluate our codesign on DeiT models and deploy the proposed accelerator on a state-of-the-art FPGA board, the Xilinx Alveo U250. The evaluation results show that the proposed algorithm can reduce computation complexity by up to 3.4× with ≈3% accuracy drop and a model compression ratio of up to 1.6×. Compared with state-of-the-art implementations on CPU, GPU, and FPGA, our codesign on FPGA achieves average latency reductions of 12.8×, 3.2×, and 0.7-2.1×, respectively.

II Background and Related Work

II-A Vision Transformer

ViT [1] consists of a stack of transformer encoders. Each encoder has a multi-head self-attention (MSA) block and a multi-layer perceptron (MLP). The input image $\mathbf{x}\in\mathbb{R}^{H\times W\times C}$ is first partitioned into $N$ patches $\mathbf{x}_p\in\mathbb{R}^{N\times P^2C}$, where $(P,P)$ is the patch size; each patch is flattened into a vector of length $P^2C$. Next, a learnable linear mapping maps each patch to a token vector of length $D$. A special learnable class token $\mathbf{x}_{\text{CLS}}$ is prepended to the token vectors. Then, a positional embedding $\mathbf{E}_{\text{POS}}\in\mathbb{R}^{(N+1)\times D}$ is added to the input token matrix to produce $\mathbf{Z}_0\in\mathbb{R}^{(N+1)\times D}$, which is the input to the transformer encoder stack. For simplicity, we denote the number of input tokens to the encoder stack as $N$ instead of $N+1$ for the rest of the paper.
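The patch partitioning step can be illustrated with a minimal pure-Python sketch (ours, not the authors' implementation). It assumes the image is stored as nested lists indexed `[row][column][channel]`, flattens each patch in row-major order, and omits the learnable linear mapping to length $D$; the names `patchify` and `num_tokens` are hypothetical.

```python
def patchify(image, P):
    """Split an H x W x C image (nested lists) into N = (H/P)*(W/P) flattened patches.

    Each patch is flattened in (row, column, channel) order into a vector of
    length P*P*C, so the result has shape N x (P^2 C), like x_p in the text.
    """
    H, W, C = len(image), len(image[0]), len(image[0][0])
    assert H % P == 0 and W % P == 0, "H and W must be multiples of the patch size P"
    patches = []
    for r0 in range(0, H, P):          # top row of the current patch
        for c0 in range(0, W, P):      # left column of the current patch
            vec = [image[r][c][ch]
                   for r in range(r0, r0 + P)
                   for c in range(c0, c0 + P)
                   for ch in range(C)]
            patches.append(vec)
    return patches


def num_tokens(H, W, P):
    """Tokens entering the encoder stack: N image patches plus one class token."""
    return (H // P) * (W // P) + 1
```

For a 224×224 input with $P=16$, a common DeiT configuration, `num_tokens(224, 224, 16)` returns $14\cdot 14+1=197$.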

MSA. The input to the encoder, $\mathbf{Z}_{l-1}$, is layer normalized (LN) [38] and passed through a multi-head self-attention (MSA) layer [12]:

$$\mathbf{Z}'_l=\mathrm{MSA}(\mathrm{LN}(\mathbf{Z}_{l-1}))+\mathbf{Z}_{l-1} \qquad (1)$$

$\mathrm{MSA}(\cdot)$ is expressed as:

$$[\mathbf{Q},\mathbf{K},\mathbf{V}]=\mathbf{Z}\mathbf{U}_{qkv} \qquad (2)$$

where $\mathbf{U}_{qkv}=[\mathbf{W}_q,\mathbf{W}_k,\mathbf{W}_v]\in\mathbb{R}^{D\times 3D'}$, $\mathbf{Z}\in\mathbb{R}^{N\times D}$, and $\mathbf{Q},\mathbf{K},\mathbf{V}\in\mathbb{R}^{N\times D'}$ are the query, key, and value matrices, respectively. $D$ is the length of an input token and $D'$ is the hidden dimension. Then, the attention score matrix $\mathbf{A}$ is calculated as:

$$\mathbf{A}=\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{D'}}\right),\quad \mathbf{A}\in\mathbb{R}^{N\times N} \qquad (3)$$
$$\mathrm{SA}(\mathbf{Z})=\mathbf{A}\mathbf{V},\quad \mathrm{SA}(\mathbf{Z})\in\mathbb{R}^{N\times D'} \qquad (4)$$

$\mathrm{SA}(\cdot)$ denotes self-attention with a single head. $\mathrm{MSA}(\cdot)$ extends this notion of self-attention to several parallel heads, each with its own parameters:

$$\mathrm{MSA}(\mathbf{Z})=[\mathrm{SA}_1(\mathbf{Z})\;\;\mathrm{SA}_2(\mathbf{Z})\;\;\cdots\;\;\mathrm{SA}_H(\mathbf{Z})]\,\mathbf{W}_{\text{proj}} \qquad (5)$$

where $H$ denotes the total number of heads, and $\mathbf{W}_{\text{proj}}\in\mathbb{R}^{HD'\times D}$ projects the concatenated self-attention outputs of the individual heads back to the embedding dimension $D$.
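To make Equations 2-5 concrete, the following pure-Python sketch evaluates single-head and multi-head self-attention on small nested-list matrices. It is an illustration under our own assumptions, not the accelerator's dataflow: biases are omitted, each head receives its own $\mathbf{W}_q,\mathbf{W}_k,\mathbf{W}_v\in\mathbb{R}^{D\times D'}$ instead of a concatenated $\mathbf{U}_{qkv}$ (an equivalent formulation), and all helper names are hypothetical.

```python
import math


def matmul(A, B):
    """Dense product of nested lists: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(A):
    """Numerically stable softmax applied to each row."""
    out = []
    for row in A:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


def self_attention(Z, Wq, Wk, Wv):
    """Single-head SA (Eqs. 2-4). Returns (A, A V) for Z of shape N x D."""
    Q, K, V = matmul(Z, Wq), matmul(Z, Wk), matmul(Z, Wv)    # each N x D'
    d_prime = len(Wq[0])
    logits = matmul(Q, transpose(K))                          # N x N
    A = softmax_rows([[v / math.sqrt(d_prime) for v in row] for row in logits])
    return A, matmul(A, V)                                    # N x D'


def msa(Z, heads, W_proj):
    """MSA (Eq. 5): concatenate per-head outputs column-wise, then project.

    heads: list of (Wq, Wk, Wv) tuples, one per head; W_proj: (H*D') x D.
    """
    outs = [self_attention(Z, *h)[1] for h in heads]
    concat = [sum((o[i] for o in outs), []) for i in range(len(Z))]  # N x H*D'
    return matmul(concat, W_proj)                                    # N x D
```

Each row of the returned $\mathbf{A}$ sums to one, and the $N\times N$ shape of `logits` is where the quadratic dependence on the token count arises.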

MLP. The output of MSA, $\mathbf{Z}'_l$, is layer normalized and passed through a multi-layer perceptron (MLP):

$$\mathbf{Z}_l=\mathrm{MLP}(\mathrm{LN}(\mathbf{Z}'_l))+\mathbf{Z}'_l \qquad (6)$$

II-B Related Work

Weight pruning: Structured, hardware-friendly model parameter pruning, used in traditional CNNs [20, 21, 22], has become popular for ViTs. [17] introduces movement pruning, which prunes parameters using a pruning mask generated from learned scores. [23] proposes to prune parameters across all the encoders. Magnitude-based approaches, on the other hand, discard parameters with small magnitudes [24]. [18] partitions a parameter matrix into blocks and prunes the rows and columns in each block using their $l_2$ norms. [25] prunes entire attention heads within the MSA and neurons in each feed-forward linear layer (width pruning); [25] also removes entire encoders after a certain depth (depth pruning). [26] proposes a collaborative approach to ViT pruning that prunes heads and neurons, where neurons are pruned so as to reduce the length of each token. This method accounts for the collective pruning impact through an expensive approximation of the Hessian of the loss.

Token pruning: Token pruning approaches identify redundant tokens and drop them to reduce the computational cost associated with the number of tokens. Both [19] and [39] are notable for accelerating transformer models by leveraging the inherent sparsity in attention mechanisms; however, they do not use weight and token pruning simultaneously. [28] proposes a static approach to token dropping that ranks the importance of tokens by the attention score between the class token and each token, aggregated across heads. In principle, such static approaches do not need additional training, since the token-dropping module is not parameterized. In contrast, dynamic token pruning as employed in [36, 35, 34] adds model parameters that facilitate the selection of relevant (attentive) tokens. [36] and [35] insert token selector networks at various depths of the original transformer network; these are neural networks that output a binary decision mask indicating whether each token is retained or removed. [34], on the other hand, associates a learnable score with each token and prunes tokens whose scores fall below a threshold.

III Overview

III-A Problem Definition

Our objective is to accelerate ViT on FPGA through an algorithm-hardware codesign that 1) utilizes a novel combination of model weight pruning and token pruning to reduce computation complexity (algorithm design), and 2) provides an efficient accelerator that explicitly accounts for the distinct and irregular access patterns of the two pruning techniques when executing the combined pruned model (hardware design).

For the algorithm design, the input is a ViT model denoted as $\mathcal{M}(\cdot,h_{\text{structure}},\Theta)$, where $h_{\text{structure}}$ denotes the model hyperparameters, including the number of encoders, the number of heads, the dimensions of the linear layers, etc., and $\Theta$ denotes the trainable parameters, containing the weights and biases of the MSA and MLP. Our algorithm design prunes the input model $\mathcal{M}$ through (1) offline weight pruning that removes redundant parameters in $\Theta$, and (2) runtime token pruning that trims the number of tokens (part of $h_{\text{structure}}$) in the intermediate layers according to token importance. After pruning, we obtain the pruned model $\mathcal{M}'(\cdot,h'_{\text{structure}},\Theta')$, where $\Theta'$ denotes the model parameters after weight pruning and $h'_{\text{structure}}$ denotes the hyperparameters after token pruning.
Our algorithm design aims to reduce the number of parameters, reduce the computation complexity, and maintain accuracy.

For the hardware design, the hardware accelerator executes the pruned model $\mathcal{M}'(\cdot,h'_{\text{structure}},\Theta')$. For executing the model, the input is an image sample $\boldsymbol{x}$, and the accelerator executes $\mathcal{M}'(\boldsymbol{x},h'_{\text{structure}},\Theta')$ to obtain the result. The latency is the duration from the time the accelerator receives the input $\boldsymbol{x}$ to the time it obtains the result.

III-B Computational Complexity

The computational complexity of each operation within the MSA and MLP without pruning is summarized in Table I, where $B$ denotes the batch size.

TABLE I: Computational complexity within a single ViT encoder. (×n) indicates the number of instances inside a single encoder.

| Operation | Computational Complexity |
| LayerNorm (×2) | $BND$ |
| Residual Add (×2) | $BND$ |
| MSA (×1) | $4BHNDD' + 2BHN^2D'$ |
| MLP (×1) | $2BNDD_{\text{mlp}}$ |
| Total Complexity | $4BND + 4BHNDD' + 2BHN^2D' + 2BNDD_{\text{mlp}}$ |
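The entries of Table I can be checked with a short calculation. The sketch below simply sums the table's terms for one unpruned encoder; the DeiT-Small-like values used in the example ($D=384$, $H=6$, $D'=64$, $D_{\text{mlp}}=1536$, $N=197$, $B=1$) are our assumption for illustration, and the function name is hypothetical.

```python
def encoder_complexity(B, N, D, H, Dp, D_mlp):
    """Per-encoder operation counts following Table I (no pruning).

    Dp stands for D' (per-head hidden dimension).
    """
    terms = {
        "layernorm": 2 * B * N * D,                              # 2 instances x BND
        "residual": 2 * B * N * D,                               # 2 instances x BND
        "msa": 4 * B * H * N * D * Dp + 2 * B * H * N * N * Dp,  # projections + attention
        "mlp": 2 * B * N * D * D_mlp,                            # two linear layers
    }
    terms["total"] = sum(terms.values())                         # sums the four rows above
    return terms


t = encoder_complexity(B=1, N=197, D=384, H=6, Dp=64, D_mlp=1536)
print(t["total"])
```

With these values, the per-encoder total is 378,693,888 operations, of which the MLP term accounts for 232,390,656; multiplying by the 12 encoders of DeiT-Small gives roughly $4.5\times10^9$. Because every term grows with $N$ (the attention term quadratically), reducing $N$ through token pruning lowers every row of the table.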

III-C Overview of Algorithm-hardware Codesign

The overview of the proposed codesign is shown in Figure 1, which consists of algorithm design and hardware design.

Algorithm design: For algorithm design, we combine block-wise static weight pruning (Section IV-A) and dynamic input token pruning (Section IV-B). The inputs are the ViT model and the pruning hyperparameters (the weight pruning ratio and the token pruning ratio for each layer), which users specify manually. The proposed simultaneous pruning (training) algorithm (Section IV-C) then prunes the input model according to these hyperparameters and generates the pruned model. Finally, model optimization organizes the data blocks in the weight matrices into the data layout and format required for efficient execution on the proposed accelerator (Section V-A).

Hardware accelerator design: At runtime, when the host process receives an input image, it sends the image to the accelerator for inference. The proposed hardware architecture employs a multi-level parallelism strategy to efficiently handle the irregular computation patterns, and a Token Dropping Hardware Module performs efficient on-the-fly token dropping. See Section V-C for details.

Discussion on the challenges of hardware acceleration: The combination of the two pruning approaches leads to significant challenges for hardware acceleration: (1) After weight pruning, the weight matrices of MSA have an uneven number of data blocks across columns, and different layers can have different numbers of heads. Moreover, token pruning leads to a varying number of tokens across layers. These irregularities can cause runtime resource underutilization. We address this challenge through multi-level parallelism (Section V-C) with a load-balancing strategy (Section V-D). (2) Due to the block-wise weight pruning, both the token matrix and the weight matrices are partitioned into data blocks. However, the token dropping algorithm drops unimportant tokens in the intermediate layers, so the token matrix needs to be reordered and reconstructed on the fly based on the token importance scores. This involves sorting and data shuffling, which cannot be handled efficiently by CPUs or GPUs. We develop an efficient hardware mechanism in the Token Dropping Hardware Module to address this issue (Section V-C3).

Figure 1: Overview of the proposed algorithm-hardware codesign

IV Pruning Algorithm

Existing works only utilize either weight pruning or token pruning. In contrast, we systematically combine the two pruning approaches with a novel training algorithm for recovering accuracy. To this end, we first introduce static weight pruning (Section IV-A) and dynamic token pruning (Section IV-B) separately, and then introduce our Simultaneous Pruning algorithm to prune the input model.

IV-A Static Weight Pruning

The weights to be pruned are the weight matrices $\mathbf{W}_q$, $\mathbf{W}_k$, $\mathbf{W}_v$, and $\mathbf{W}_{\text{proj}}$ within MSA and the intermediate and output linear layers within MLP. Pruning is performed as follows:

IV-A1 Pruning of MSA

We use $\mathbf{W}_q,\mathbf{W}_k,\mathbf{W}_v\in\mathbb{R}^{D\times HD'}$ to denote the concatenation of the weight matrices of all the heads, i.e., $\mathbf{W}_p\in\mathbb{R}^{D\times HD'}$ where $p\in\{q,k,v\}$. The projection operation (Equation 5) projects the concatenated SA outputs of dimension $HD'$ back to dimension $D$ via $\mathbf{W}_{\text{proj}}\in\mathbb{R}^{HD'\times D}$.
To prune a weight matrix $\mathbf{W}\in\mathbb{R}^{M_1\times M_2}$, we define a parameterized score matrix $\mathbf{S}\in\mathbb{R}^{m\times n}$ such that $(m,n)=\left(\left\lceil\frac{M_1}{b}\right\rceil,\left\lceil\frac{M_2}{b}\right\rceil\right)$, where $(b,b)$ is the block size. $\mathbf{S}_{ij}$ denotes the importance score of the parameter block of size $(b,b)$ in the weight matrix $\mathbf{W}$ given by the slice $\mathbf{W}(ib:\alpha,\,jb:\beta)$, where $(\alpha,\beta)=(\min(ib+b,M_1),\min(jb+b,M_2))$. $\mathbf{S}$ is used to construct a mask $\mathbf{M}\in\mathbb{R}^{M_1\times M_2}$ via top-$k$ selection:

$$\mathbf{M}_{ij}^{\text{block}}(s_{ij})=\begin{cases}1 & \text{if } s_{ij}\in\text{top-}k\text{ of }\mathbf{S}\\ 0 & \text{otherwise}\end{cases} \qquad (7)$$

where $\mathbf{M}_{ij}^{\text{block}}(\cdot)$ is the block of size $(b,b)$ in $\mathbf{M}$ corresponding to $s_{ij}$. The masked weight is generated as $\mathbf{W}(\mathbf{M})=\mathbf{W}\odot\mathbf{M}$, where $\odot$ is the element-wise (Hadamard) product, and $\mathbf{W}(\mathbf{M})$ is used for the forward pass during training. Here, the top-$k$ blocks are the weight blocks that are retained. To compute the gradient of $\mathbf{S}$ during the backward pass, a straight-through estimator (STE) [40, 41, 42] is used, which bypasses the non-differentiable top-$k$ selection by ignoring the gradient of $\mathbf{M}$ with respect to $\mathbf{S}$. Additionally, $\mathbf{W}_p$ and $\mathbf{W}_{\text{proj}}$ are pruned with the same block pattern along their shared $HD'$ dimension, i.e., the columns of $\mathbf{W}_p$ and the rows of $\mathbf{W}_{\text{proj}}$ (denoted as the alternate pattern), as shown in Figure 2. For example, a head removed from $\mathbf{W}_p$ makes the corresponding head in $\mathbf{W}_{\text{proj}}$ redundant, and vice versa.
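The forward-pass part of block pruning (Equation 7 followed by the Hadamard product) can be sketched in pure Python as below. This is a hedged illustration under our own assumptions: scores are given rather than learned, ties in $\mathbf{S}$ are broken by row-major block order, and the STE-based backward pass is not shown because it requires an autograd framework. `block_mask` and `apply_mask` are hypothetical names.

```python
import math


def block_mask(S, M1, M2, b, k):
    """Expand a block score matrix S (m x n) into an M1 x M2 binary mask (Eq. 7).

    The k highest-scoring (b x b) blocks get mask 1; all other entries get 0.
    Edge blocks are truncated at the matrix boundary, matching the slice
    W(ib:alpha, jb:beta) with alpha = min(ib+b, M1), beta = min(jb+b, M2).
    """
    m, n = math.ceil(M1 / b), math.ceil(M2 / b)
    assert len(S) == m and all(len(row) == n for row in S)
    # Sort blocks by score (descending); the stable sort keeps row-major order on ties.
    ranked = sorted(((S[i][j], i, j) for i in range(m) for j in range(n)),
                    key=lambda t: t[0], reverse=True)
    keep = {(i, j) for _, i, j in ranked[:k]}
    M = [[0] * M2 for _ in range(M1)]
    for i, j in keep:
        for r in range(i * b, min(i * b + b, M1)):
            for c in range(j * b, min(j * b + b, M2)):
                M[r][c] = 1
    return M


def apply_mask(W, M):
    """Masked weight W(M) = W (Hadamard) M, computed element by element."""
    return [[w * mk for w, mk in zip(w_row, m_row)] for w_row, m_row in zip(W, M)]
```

For a $3\times4$ matrix with $b=2$, the score matrix is $2\times2$; keeping the two diagonal blocks yields a mask whose last block row is truncated to a single matrix row.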

Figure 2: Alternate pattern of block pruning for $\mathbf{W}_p$ and $\mathbf{W}_{\text{proj}}$ parameters.

MLP Pruning. The weight matrices in MLP are $\mathbf{W}_{\text{int}}\in\mathbb{R}^{D\times D_{\text{mlp}}}$ and $\mathbf{W}_{\text{out}}\in\mathbb{R}^{D_{\text{mlp}}\times D}$. The pruning of $\mathbf{W}_{\text{int}}$ and $\mathbf{W}_{\text{out}}$ follows the approach for pruning MSA. A key difference, however, is in how the score parameters are defined for $\mathbf{W}_{\text{int}}$ and $\mathbf{W}_{\text{out}}$. Specifically, we define the scores as $\mathbf{S}_{\text{linear}}\in\mathbb{R}^{D_{\text{mlp}}}$, where $\text{linear}\in\{\text{int},\text{out}\}$. The score vectors are defined to prune entire columns of $\mathbf{W}_{\text{int}}$ and entire rows of $\mathbf{W}_{\text{out}}$ (see Figure 3). Masks are generated column-wise/row-wise through top-$k$ selection.
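Because MLP pruning removes whole columns of $\mathbf{W}_{\text{int}}$ and whole rows of $\mathbf{W}_{\text{out}}$, the pruned layers can be stored as smaller dense matrices. The pure-Python sketch below illustrates this under two simplifying assumptions of ours: a single shared score vector selects the kept hidden units for both matrices (the text defines separate $\mathbf{S}_{\text{int}}$ and $\mathbf{S}_{\text{out}}$), and a bias-free ReLU MLP stands in for the actual MLP. Under these assumptions the compacted MLP produces the same output as the masked one, since a zeroed column of $\mathbf{W}_{\text{int}}$ yields a zero hidden activation. All names are hypothetical.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, returned in ascending index order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])


def prune_mlp(W_int, W_out, s, k):
    """Keep k of the D_mlp hidden units: columns of W_int (D x D_mlp) and the
    matching rows of W_out (D_mlp x D), selected by one shared score vector s."""
    keep = topk_indices(s, k)
    W_int_p = [[row[j] for j in keep] for row in W_int]   # D x k
    W_out_p = [W_out[j] for j in keep]                    # k x D
    return W_int_p, W_out_p


def mlp(x, W_int, W_out):
    """Bias-free two-layer MLP on one token x (length D), ReLU as the activation."""
    hidden = len(W_int[0])
    h = [max(0.0, sum(xi * W_int[i][j] for i, xi in enumerate(x))) for j in range(hidden)]
    return [sum(h[j] * W_out[j][c] for j in range(hidden)) for c in range(len(W_out[0]))]
```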
Because MSA parameters are naturally partitioned along heads, block pruning in MSA can remove entire heads. MLP parameters lack such a partitioning, so for MLP we instead remove entire columns/rows. For model training, we add a norm of the sigmoid of the scores to the training loss [23]:

$$\min_{\mathbf{W}} \mathcal{L} \;\rightarrow\; \min_{\mathbf{W},\mathbf{S}} \mathcal{L} + \lambda \,\|\sigma(\mathbf{S})\| \quad \text{where} \quad \|\mathbf{A}\| = \sum_{ij} A_{ij} \qquad (8)$$

where the added term penalizes the presence of model parameters (through their scores), thus driving the model toward sparsity. The strength of this penalty is controlled by the hyper-parameter $\lambda$.
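A small Python sketch of the regularizer in Equation 8, assuming the score parameters are flattened into plain lists (the names `regularized_loss` and `penalty_grad` are hypothetical). Because $\sigma(\cdot) > 0$, the entry-wise sum in Equation 8 equals the L1 norm of $\sigma(\mathbf{S})$, and its gradient with respect to every score is positive, so gradient descent keeps pushing scores down unless the task loss pushes back.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def regularized_loss(task_loss, score_vectors, lam):
    """Equation 8: task loss plus lam * sum of sigmoid(score) over all score entries."""
    penalty = sum(sigmoid(s) for scores in score_vectors for s in scores)
    return task_loss + lam * penalty


def penalty_grad(score, lam):
    """d/ds [lam * sigmoid(s)] = lam * sigmoid(s) * (1 - sigmoid(s)).

    Positive for lam > 0 (up to floating-point saturation for very large |s|).
    """
    s = sigmoid(score)
    return lam * s * (1.0 - s)
```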

Figure 3: Alternate column-wise/row-wise pruning for $\mathbf{W}_{\text{int}}$ and $\mathbf{W}_{\text{out}}$. Note that $D_{\text{mlp}}$ is much larger than $D$.

IV-B Dynamic Token Pruning

Dynamic token pruning prunes along the token dimension $N$. The redundancy along this dimension arises because several patches within an input image are inattentive, i.e., they contribute little to the model's final prediction [35, 34, 28, 36]. Since ViTs can inherently handle inputs with an arbitrary number of tokens (patches), we exploit this independence of the token dimension from the model parameter dimension(s) by dropping inattentive tokens. To classify tokens as attentive or inattentive, we use a non-parametric approach [28] based on the attention $\mathbf{A}$ already computed within MSA (Equation 3). Each head $h$ produces an attention score $\mathbf{A}_h$; we aggregate these scores across all heads as $\mathcal{S} = \frac{1}{H}\sum_{h} \mathbf{A}_h$, where $\mathcal{S} \in \mathbb{R}^{N}$ represents the importance score of each token. Given a keep rate $r_t$, the $\lceil (N-1) r_t \rceil$ tokens with the highest scores in $\mathcal{S}$ are retained. The remaining inattentive tokens are fused into a single token through a weighted aggregation, where each token is weighted by its score in $\mathcal{S}$.
Token dropping is performed by a token dropping module (TDM) inserted between the MSA and MLP modules (Figure 4); tokens are dropped dynamically during both training and inference.
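A minimal Python sketch of the TDM logic described above. It assumes that token 0 is the class token, which is always kept and excluded from the ranking (consistent with the $N-1$ in the keep count), that each $\mathbf{A}_h$ is given as an $N$-element score vector, and that the fusion weights are normalized to sum to one. The function name `token_dropping` is hypothetical.

```python
import math


def token_dropping(tokens, head_scores, keep_rate):
    """Sketch of a TDM: keep the most attentive tokens and fuse the rest into one token.

    tokens: list of N token vectors; tokens[0] is assumed to be the class token.
    head_scores: H lists of N attention scores, one list per head (assumed layout).
    """
    n, n_heads = len(tokens), len(head_scores)
    # Importance score S = (1/H) * sum_h A_h, one value per token.
    score = [sum(a[i] for a in head_scores) / n_heads for i in range(n)]
    n_keep = math.ceil((n - 1) * keep_rate)
    # Rank the N-1 non-class tokens by importance, highest first.
    ranked = sorted(range(1, n), key=lambda i: -score[i])
    kept = sorted(ranked[:n_keep])  # keep the original token order
    dropped = ranked[n_keep:]
    out = [tokens[0]] + [tokens[i] for i in kept]
    if dropped:
        # Score-weighted aggregation of the inattentive tokens (weights normalized to sum to 1).
        total = sum(score[i] for i in dropped) or 1.0  # guard against all-zero scores
        dim = len(tokens[0])
        fused = [sum(score[i] * tokens[i][d] for i in dropped) / total for d in range(dim)]
        out.append(fused)
    return out
```

Kept tokens stay in their original order and the fused token is appended at the end, so whenever at least one token is dropped the output has $1 + \lceil (N-1) r_t \rceil + 1$ tokens.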

Figure 4: TDM inserted between the MSA and MLP blocks inside an encoder. TDM updates the input to the MLP block, $\mathbf{Z}'_l$, as $\mathbf{Z}'_l \leftarrow \text{TDM}(\mathbf{Z}'_l)$.

IV-C Simultaneous Pruning

To recover the accuracy of the pruned model, we use knowledge distillation, a technique commonly used to transfer knowledge from a larger, already-trained teacher model to a smaller student model [43]. The class logits of the teacher and student networks are used to compute a distillation loss at a distillation temperature $T$:

$$\mathcal{L}_{\text{distill}} = T^{2}\, \text{KL}\big(\mathbf{p}_{\text{teacher}}(T) \,\|\, \mathbf{p}_{\text{student}}(T)\big) \qquad (9)$$

where $\text{KL}(\cdot)$ denotes the Kullback-Leibler (KL) divergence, and $\mathbf{p}(T)$ is the softmax probability vector computed at temperature $T$ from an input logit vector $\mathbf{l}_p$ as $\frac{\exp(\mathbf{l}_p / T)}{\sum_i \exp(\mathbf{l}_p(i) / T)}$. The final loss is a weighted sum of the task loss (Equation 8) and the distillation loss, with the weights acting as hyper-parameters. Algorithm 1 gives the simultaneous training algorithm used to train a sparse model on the sparse set of attentive input tokens. $\text{Encoder}_{\text{TDM}}^{\mathcal{M}^s, j}$ denotes the encoder at layer $j$ of ViT model $\mathcal{M}^s$ with the TDM included; similarly, $\text{Encoder}^{\mathcal{M}^s, j}$ denotes the encoder at layer $j$ without the TDM.
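Equation 9 can be sketched in plain Python as follows (hypothetical helper names `softmax_t` and `distill_loss`). The softmax subtracts the maximum scaled logit before exponentiating; this does not change $\mathbf{p}(T)$ mathematically but avoids overflow. The $T^2$ factor is the usual correction that keeps the gradient magnitude of the softened targets comparable across temperatures.

```python
import math


def softmax_t(logits, temperature):
    """p(T) = exp(l / T) / sum_i exp(l_i / T), computed with a max shift for stability."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def distill_loss(teacher_logits, student_logits, temperature):
    """Equation 9: T^2 * KL(p_teacher(T) || p_student(T))."""
    p_t = softmax_t(teacher_logits, temperature)
    p_s = softmax_t(student_logits, temperature)
    # Terms with p_t == 0 contribute nothing to the KL divergence.
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s) if pt > 0)
    return temperature ** 2 * kl
```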

Algorithm 1 Simultaneous Fine-Pruning
1:Input: student model $\mathcal{M}^s(\mathbf{x};\Theta)$; teacher model $\mathcal{M}^t(\mathbf{x})$; model pruning top-$k$ rate $r_b$; input token pruning keep rate $r_t$; set $\{\ell\}$ of encoder layers at which the TDM is used; dataset $\mathcal{D}$ for fine-pruning
2:Initialize the weight and score parameters $\{\mathbf{W}, \mathbf{S}\}$
3:for $i = 1 \ldots \text{epochs}$ do
4:    for all $\mathbf{x}$ in $\mathcal{D}$ do
5:         Compute masks $\{\mathbf{M}\}$ from scores $\{\mathbf{S}\}$ via $r_b$
6:         $\mathbf{W}(\mathbf{M}) \leftarrow \mathbf{W} \odot \mathbf{M}$ for all $\mathbf{W} \in \{\mathbf{W}\}$
7:         $\mathbf{y} \leftarrow \mathbf{x}$
8:         for encoders in $\mathcal{M}^s$ at layer $j$ from $1 \ldots L$ do
9:             if $j \in \{\ell\}$ then
10:                 $\mathbf{y} \leftarrow \text{Encoder}_{\text{TDM}}^{\mathcal{M}^s, j}(\mathbf{y})$
11:             else
12:                 $\mathbf{y} \leftarrow \text{Encoder}^{\mathcal{M}^s, j}(\mathbf{y})$
13:         Compute student logits $\mathbf{z}_s$ from $\mathbf{y}$ and the final classifier of $\mathcal{M}^s$
14:         Compute teacher logits $\mathbf{z}_t$ using $\mathcal{M}^t(\mathbf{x})$ and the final classifier of $\mathcal{M}^t$
15:         Compute $\mathcal{L}_{\text{distill}}$ from $\mathbf{z}_t$, $\mathbf{z}_s$ and Equation 9
16:         Compute $\mathcal{L}$ as in Equation 8
17:         $\mathcal{L}_{\text{net}} \leftarrow \lambda_{\text{distill}} \mathcal{L}_{\text{distill}} + \lambda_{\text{normal}} \mathcal{L}$
18:         Backpropagate $\mathcal{L}_{\text{net}}$ and compute gradients
19:         Update $\{\mathbf{W}, \mathbf{S}\}$
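The layer loop of Algorithm 1 (lines 7 to 12) and the loss combination of line 17 can be sketched as below. The encoders are passed in as plain Python callables so that the sketch stays framework-agnostic; `student_forward` and `net_loss` are hypothetical names, and masking, backpropagation and the parameter update (lines 5, 6, 18 and 19) are left to whatever training framework is used.

```python
def student_forward(x, encoders, encoders_tdm, tdm_layers):
    """Lines 7-12 of Algorithm 1: run layers j = 1..L, using the TDM variant when j is in {l}.

    encoders[j - 1] / encoders_tdm[j - 1]: callables for layer j without / with the TDM.
    """
    y = x
    for j in range(1, len(encoders) + 1):
        layer = encoders_tdm[j - 1] if j in tdm_layers else encoders[j - 1]
        y = layer(y)
    return y


def net_loss(l_distill, l_normal, lam_distill, lam_normal):
    """Line 17 of Algorithm 1: weighted sum of the distillation and task losses."""
    return lam_distill * l_distill + lam_normal * l_normal
```

For a quick check without a deep-learning framework, the encoders can be replaced by toy functions that only track the token count: the identity for plain layers and $n \mapsto 2 + \lceil (n-1) r_t \rceil$ for TDM layers (class token, kept tokens, fused token). With $N = 197$ and a hypothetical $r_t = 0.7$, a single TDM layer reduces 197 tokens to 140.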

IV-D Computational Complexity: Pruned Model

We analyze the computational complexity of the proposed pruned model. The complexity of an encoder is given in Table II. $\alpha$ is the average ratio of retained weight blocks to total weight blocks (retained and pruned) within a column of blocks in the parameter matrices $\mathbf{W}_p$ (computed after removing the heads that are pruned in their entirety). $\alpha'$ is defined similarly for the matrix $\mathbf{W}_{\text{proj}}$. $H_{\text{kept}}$ is the number of heads retained within MSA. $N_{\text{kept}}$ is the total number of tokens retained after token dropping ($\approx N r_t$). $\alpha^{\text{mlp}}$ is the ratio of retained neurons (the same for both $\mathbf{W}_{\text{int}}$ and $\mathbf{W}_{\text{out}}$). Note that $\alpha^{\text{mlp}} = r_b$.

TABLE II: Computational Complexity of Pruned Model

| Operation | Computational Complexity |
|---|---|
| LayerNorm 1 (×1) | $BND$ |
| LayerNorm 2 (×1) | $BN_{\text{kept}}D$ |
| Residual Add 1 (×1) | $BND$ |
| Residual Add 2 (×1) | $BN_{\text{kept}}D$ |
| MSA (×1) | $BH_{\text{kept}}ND'D(3\alpha + \alpha') + 2BH_{\text{kept}}N^{2}D'$ |
| TDM (×1) | $BN(H + N + D)$ |
| MLP (×1) | $2BN_{\text{kept}}DD_{\text{mlp}}\alpha^{\text{mlp}}$ |
| Total Complexity | $2BND + 2BN_{\text{kept}}D + BH_{\text{kept}}ND'D(3\alpha + \alpha') + 2BH_{\text{kept}}N^{2}D' + BN(H + N + D) + 2BN_{\text{kept}}DD_{\text{mlp}}\alpha^{\text{mlp}}$ |
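Table II can be turned into a small calculator for comparing configurations. The sketch below (hypothetical function `encoder_complexity`) follows the table term by term with $D' = D/H$; note that the MSA term uses $N$ rather than $N_{\text{kept}}$ because the TDM sits after MSA. The `use_tdm` flag is our addition so that the same function can also describe an encoder without token dropping.

```python
def encoder_complexity(B, N, D, H, D_mlp, H_kept, N_kept, alpha, alpha_p, alpha_mlp,
                       use_tdm=True):
    """Per-encoder complexity terms of the pruned model, following Table II."""
    d_head = D // H  # D' = D / H
    terms = {
        "layernorm1": B * N * D,
        "layernorm2": B * N_kept * D,
        "residual_add1": B * N * D,
        "residual_add2": B * N_kept * D,
        # MSA runs before the TDM, so it still processes all N tokens.
        "msa": (B * H_kept * N * d_head * D * (3 * alpha + alpha_p)
                + 2 * B * H_kept * N ** 2 * d_head),
        "tdm": B * N * (H + N + D) if use_tdm else 0,
        "mlp": 2 * B * N_kept * D * D_mlp * alpha_mlp,
    }
    return terms, sum(terms.values())
```

With every ratio set to 1, no token dropping, and $D_{\text{mlp}} = 4D$, the total reduces to $4BND + 12BND^2 + 2BN^2D$, which is a convenient sanity check against the unpruned encoder.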

Figure 5: Data layout of dense token matrix and sparse weight matrix.

V Hardware Design

In this section, we present our hardware design for accelerating the pruned model on FPGA. Section V-A introduces the data format and layout used to store the sparse (and dense) weight matrices and the input data; Section V-B introduces the main components of the proposed hardware architecture; and Section V-C describes the workflow for executing a pruned ViT encoder on the proposed hardware.

Figure 6: Overview of hardware architecture.
Figure 7: Task scheduling for executing a ViT encoder on the proposed architecture. $LN$ denotes LayerNorm.

V-A Data Format and Layout

Due to structured block pruning, all weight and feature matrices are partitioned into data blocks of the same size $b \times b$, and the data within each block are stored contiguously. Dense matrices are stored in block-wise row-major order, so that all data blocks in the same row are stored contiguously in memory. Weight matrices are stored in block-wise column-major order, so that all unpruned data blocks in the same column are stored contiguously in memory, as shown in Figure 5. For sparse weight matrices, only the unpruned data blocks are stored. Each column of a block-wise sparse weight matrix begins with a header that encodes the row indices of its retained blocks and the column length. For simplicity, in the rest of the paper, a row (column) of a matrix denotes a row (column) of data blocks.
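The storage format can be sketched in Python as below. The header layout used here, a `(length, row_indices)` tuple per block column, and the row-major ordering of elements inside each $b \times b$ block are illustrative assumptions; the text specifies what the header encodes but not its exact bit-level format. The function names are hypothetical.

```python
def pack_block_sparse_columns(w, b, block_mask):
    """Pack a weight matrix in block-wise column-major order, keeping only unpruned blocks.

    w: R x C matrix (list of lists) with R and C multiples of b.
    block_mask[r][c]: 1 if block (r, c) is retained, else 0.
    Returns one (header, blocks) entry per block column, with header = (length, row_indices).
    """
    n_br, n_bc = len(w) // b, len(w[0]) // b
    columns = []
    for c in range(n_bc):
        rows = [r for r in range(n_br) if block_mask[r][c]]
        blocks = [[row[c * b:(c + 1) * b] for row in w[r * b:(r + 1) * b]] for r in rows]
        columns.append(((len(rows), rows), blocks))
    return columns


def flatten_row_major_blocks(x, b):
    """Dense matrix -> flat list in block-wise row-major order (each b x b block contiguous)."""
    out = []
    for br in range(len(x) // b):
        for bc in range(len(x[0]) // b):
            for row in x[br * b:(br + 1) * b]:
                out.extend(row[bc * b:(bc + 1) * b])
    return out
```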

V-B Hardware Overview

As shown in Figure 6, the architecture comprises: (1) a Multi-level Parallelism Compute Array (MPCA), (2) an Element-wise Module (EM), and (3) a Token Dropping Hardware Module (TDHM). In addition, there are on-chip buffers: (1) a Global Feature Buffer (GFB) that stores the feature matrices, (2) Column Buffers (CB) that store the weight matrices, and (3) Result Buffers (RB) that store the results of the current layer.

In MPCA, the computation units are organized into multiple levels. An MPCA has $p_{\text{h}}$ parallel Computing Head Modules (CHMs). Each CHM has a 2-D array of Processing Elements (PEs) of size $p_{\text{t}} \times p_{\text{c}}$, and each PE has an array of computation units of size $p_{\text{pe}} \times p_{\text{pe}}$. Thus, $p_{\text{h}}$, $p_{\text{t}}$, $p_{\text{c}}$ and $p_{\text{pe}}^2$ are the computation parallelism in the head dimension, the input token dimension, the weight column dimension, and within a data block, respectively. The Element-wise Module (EM) performs element-wise GELU and exponentiation. The Token Dropping Hardware Module (TDHM) performs dynamic token dropping (Section IV-B).

We organize the computation units of MPCA into multiple levels for three reasons. (1) It enables massive data parallelism in MSA and MLP. (2) It enables data reuse within a CHM: the PEs in the same column of a CHM share the same weight block, while the PEs in the same row share the same input token block. This data-sharing strategy simplifies the computation despite the irregular data access pattern introduced by block-wise weight pruning. (3) By selecting a proper $p_{\text{c}}$ together with a load-balancing strategy, we can alleviate the load imbalance caused by block-wise weight pruning.
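The following cost model illustrates why $p_{\text{c}}$ matters; it is an assumption for illustration, not the paper's load-balancing strategy, which is not detailed here. It follows the mapping of Algorithm 2, in which PE column $n$ processes weight block-column $n + k p_{\text{c}}$ in iteration $k$, and assumes that a PE column's time is proportional to the number of retained blocks in its weight column and that all PE columns of a CHM advance to the next iteration together. It also reports the peak multiply-accumulates per cycle implied by the four parallelism parameters, assuming each computation unit performs one multiply-accumulate per cycle. All names are hypothetical.

```python
import math


def iteration_cycles(col_nnz_blocks, p_c):
    """Modeled cycles: in each iteration, p_c PE columns wait for the densest weight column."""
    total = 0
    for k in range(math.ceil(len(col_nnz_blocks) / p_c)):
        group = col_nnz_blocks[k * p_c:(k + 1) * p_c]
        total += max(group)
    return total


def balance_efficiency(col_nnz_blocks, p_c):
    """Useful block products divided by the slots p_c PE columns offer in the modeled cycles."""
    cycles = iteration_cycles(col_nnz_blocks, p_c)
    return sum(col_nnz_blocks) / (cycles * p_c)


def peak_macs_per_cycle(p_h, p_t, p_c, p_pe):
    """Peak parallelism of MPCA: p_h * p_t * p_c * p_pe^2 (one MAC per unit per cycle assumed)."""
    return p_h * p_t * p_c * p_pe ** 2
```

For example, with block-column occupancies `[4, 1, 3, 2]`, `p_c = 2` gives 7 modeled cycles for 10 blocks of work (about 71% utilization), while `p_c = 4` gives 4 cycles (62.5%): a wider PE array finishes sooner but wastes more slots waiting on the densest column.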

V-C Workflow

The proposed accelerator executes the input model layer by layer, with every layer executed on the same set of computational resources (e.g., MPCA). Each CHM uses Sparse Block-wise Matrix Multiplication (SBMM) to process sparse matrices block by block. For dense matrices, the computation switches to Dense Block-wise Matrix Multiplication (DBMM), or to Dense Head-wise Block Matrix Multiplication (DHBMM), which operates only on the matrix blocks associated with each head. Each layer is executed as follows:

V-C1 MSA Execution

MSA is divided into four stages (shown in Figure 7). Stage (i) computes $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ through $[\mathbf{Q}, \mathbf{K}, \mathbf{V}] = \mathbf{Z}[\mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v]$, where $\mathbf{Z}$ is the input token matrix and $[\;]$ denotes matrix concatenation. Let $\mathbf{Q}_h$, $\mathbf{K}_h$ and $\mathbf{V}_h$ denote the query, key and value matrices of head $h$ ($0 \le h < H$). Algorithm 2 shows how each matrix multiplication in MSA (e.g., the dense token matrix $\mathbf{Z}$ multiplied by the sparse weight matrices $\mathbf{W}_q, \mathbf{W}_k, \mathbf{W}_v$) is executed on MPCA, and Figure 8 gives an example. Each CHM computes $[\mathbf{Q}_h, \mathbf{K}_h, \mathbf{V}_h]$ for its corresponding head. Since every matrix is partitioned into blocks, the computation is executed block-wise. Within each CHM, the PEs in a column share the same column of weights, which is stored in the column buffer (CB), and the PEs in a row share the same row of $\mathbf{Z}$.
We use $\text{PE}(i, j, k)$ to denote the PE at location $(i, j)$ within the $k^{\text{th}}$ CHM. A column of PEs, $\text{PE}(:, j, k)$, shares the same weight column together with its header information (see Figure 5). Each PE in the $j^{\text{th}}$ column therefore uses the shared header indices to fetch the corresponding data blocks of the token matrix (from the local input buffer) and performs block-wise matrix multiplication. The partial results are accumulated in local result buffers.
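A sequential Python sketch of SBMM for one head (hypothetical helpers `matmul_acc` and `sbmm`) shows how the column header drives the computation: for each output block, only the input blocks whose indices appear in the weight column's header are fetched and multiplied, mirroring lines 19 to 21 of Algorithm 2. The two outer loops correspond to what the rows and columns of PEs do in parallel in hardware; here they simply run one after another.

```python
def matmul_acc(acc, a, w):
    """acc += a @ w for small square blocks stored as lists of lists."""
    b = len(a)
    for i in range(b):
        for j in range(b):
            acc[i][j] += sum(a[i][t] * w[t][j] for t in range(b))


def sbmm(x_blocks, w_cols, b):
    """Sparse block-wise matmul Y = X W for one head, executed sequentially.

    x_blocks[r][t]: (r, t) block of the dense token matrix X.
    w_cols[n]: list of (row_index, block) pairs for the retained blocks of weight
               block-column n; the row indices play the role of the column header.
    Returns y_blocks[r][n].
    """
    y = []
    for x_row in x_blocks:  # in hardware: one row of PEs shares this row of X
        y_row = []
        for col in w_cols:  # in hardware: one column of PEs shares this weight column
            acc = [[0] * b for _ in range(b)]
            for idx, w_blk in col:  # visit only the blocks listed in the header
                matmul_acc(acc, x_row[idx], w_blk)
            y_row.append(acc)
        y.append(y_row)
    return y
```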

Stage (ii) computes the attention scores $\mathbf{A}_h = \text{softmax}\big(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{D'}}\big)$ ($0 \le h < H$). $\mathbf{Q}_h \mathbf{K}_h^T$ is a dense block-wise matrix multiplication executed on MPCA (see Algorithm 2), with $\mathbf{Q}$ buffered in the GFB and $\mathbf{K}^T$ buffered in the CB. The output data blocks of $\mathbf{Q}_h \mathbf{K}_h^T$ are sent to the EM for element-wise scaling (by $1/\sqrt{D'}$) and exponentiation, producing $\exp\big(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{D'}}\big)$.
MPCA then computes the scaling factors of $\text{softmax}\big(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{D'}}\big)$. The rows of $\exp\big(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{D'}}\big)$ and their corresponding scaling factors are streamed from MPCA to the EM, which applies the scaling to obtain the attention scores $\mathbf{A}_h$.
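The three-step softmax can be checked with a short Python sketch (hypothetical function `softmax_split`). How MPCA produces the scaling factors is not spelled out here; the sketch assumes the row sums are obtained as a matrix product with an all-ones vector and that each factor is the reciprocal of its row sum.

```python
import math


def softmax_split(scores, scale):
    """Row-wise softmax(scores * scale), split into the EM / MPCA / EM steps.

    Following the described flow, no max-subtraction is applied, so very
    large scores could overflow math.exp.
    """
    # EM: element-wise scaling and exponentiation.
    e = [[math.exp(v * scale) for v in row] for row in scores]
    # MPCA (assumed): row sums as a matrix product with an all-ones vector.
    ones = [1.0] * len(e[0])
    row_sums = [sum(a * o for a, o in zip(row, ones)) for row in e]
    factors = [1.0 / s for s in row_sums]
    # EM: multiply each streamed row by its scaling factor.
    return [[v * f for v in row] for row, f in zip(e, factors)]
```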

Stage (iii) computes the self-attention $\mathbf{A}_h \mathbf{V}_h$, which is executed similarly to $\mathbf{Q}_h \mathbf{K}_h^T$. Stage (iv) computes the projection (Equation 5); because $\mathbf{W}_{\text{proj}}$ is a block-wise sparse matrix after pruning, it is executed similarly to the computation of $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ in stage (i).

V-C2 MLP Execution

Since the MLP weights are pruned as entire columns of $\mathbf{W}_{\text{int}}$ and entire rows of $\mathbf{W}_{\text{out}}$, the pruned matrices can be stored as smaller dense matrices, and the MLP layers can be mapped to dense block-wise matrix multiplication executed by MPCA (Algorithm 2), similar to the computation of $\mathbf{Q}_h\mathbf{K}_h^T$ in MSA. The GELU activation is computed by the EM.
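Why column/row pruning yields dense matrix multiplications can be illustrated with a short Python sketch (hypothetical helper names). It assumes biases are omitted and that $\mathbf{W}_{\text{int}}$ and $\mathbf{W}_{\text{out}}$ keep the same neuron indices. Since $\text{GELU}(0) = 0$, a neuron whose column of $\mathbf{W}_{\text{int}}$ is zeroed contributes nothing, so both pruned matrices can be compacted into smaller dense ones without changing the output.

```python
import math


def gelu(x):
    """Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matmul(m1, m2):
    return [[sum(m1[i][t] * m2[t][j] for t in range(len(m2))) for j in range(len(m2[0]))]
            for i in range(len(m1))]


def mlp(x, w_int, w_out):
    """Bias-free MLP: GELU(X W_int) W_out."""
    h = [[gelu(v) for v in row] for row in matmul(x, w_int)]
    return matmul(h, w_out)


def compact_mlp_weights(w_int, w_out, kept):
    """Drop pruned neurons: keep columns `kept` of W_int and the same rows of W_out."""
    w_int_c = [[row[c] for c in kept] for row in w_int]
    w_out_c = [w_out[r] for r in kept]
    return w_int_c, w_out_c
```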

Algorithm 2 Executing Sparse Block-wise Matrix Multiplication (SBMM) and Dense Block-wise Matrix Multiplication (DBMM) through multi-level parallelism of MPCA
1:Input: matrix $\mathbf{X} \in \mathbb{R}^{M_1 \times M_2}$; weight matrix $\mathbf{W} = [\mathbf{W}_0, \mathbf{W}_1, \ldots, \mathbf{W}_{H-1}] \in \mathbb{R}^{M_2 \times D}$, where each $\mathbf{W}_h \in \mathbb{R}^{M_2 \times D'}$ ($0 \le h < H$); $D = HD'$, where $H$ denotes the number of heads and $D'$ is the dimension per head; block size $b$
2:Output: matrix $\mathbf{Y} = \mathbf{X}\mathbf{W} \in \mathbb{R}^{M_1 \times D}$
3:// $\mathbf{X}$ and $\mathbf{Y}$ are stored in block-wise row-major order and $\mathbf{W}$ is stored in block-wise column-major order (Figure 5)
4:// $\mathbf{X}[i, j]$ denotes the $(i, j)^{\text{th}}$ block of size $b \times b$ in $\mathbf{X}$
5:// $\mathbf{W}_h[i, j]$ denotes the $(i, j)^{\text{th}}$ block of size $b \times b$ in $\mathbf{W}_h$
6:// $\mathbf{Y}_h$ is the output corresponding to $\mathbf{W}_h$
7:// To compute $H$ heads, $p_{\text{h}}$ CHMs need $\lceil H / p_{\text{h}} \rceil$ iterations
8:for $i = 0$ to $\lceil H / p_{\text{h}} \rceil - 1$ do
9:    for each $\text{CHM}_j$ with $j = 0$ to $p_{\text{h}} - 1$ Parallel do
10:         // $\text{CHM}_j$ computes $\mathbf{Y}_{j + i p_{\text{h}}}$
11:         // To compute the $\lceil D'/b \rceil$ column blocks of a $\mathbf{W}_h$, the $p_{\text{c}}$ columns of PEs in a CHM need $\lceil \lceil D'/b \rceil / p_{\text{c}} \rceil$ iterations
12:         for $k = 0$ to $\lceil \lceil D'/b \rceil / p_{\text{c}} \rceil - 1$ do
13:             // Load weights into CB
14:             // To compute the $\lceil M_1/b \rceil$ row blocks of $\mathbf{X}$, the $p_{\text{t}}$ rows of PEs in a CHM need $\lceil \lceil M_1/b \rceil / p_{\text{t}} \rceil$ iterations
15:             for $l = 0$ to $\lceil \lceil M_1/b \rceil / p_{\text{t}} \rceil - 1$ do
16:                 // Load data (a partition of $\mathbf{X}$) into GFB
17:                 for each $\text{PE}_j(m, n)$ in $\text{CHM}_j$ Parallel do
18:                     // $\text{PE}_j(m, n)$ computes output block $\mathbf{Y}_{j + i p_{\text{h}}}[m + l p_{\text{t}}, n + k p_{\text{c}}]$
19:                     if MPCA mode is SBMM then
20:                         Fetch $\mathbf{X}[m + l p_{\text{t}}, \text{idx}]$ from GFB for all idx in the header of $\mathbf{W}_{j + i p_{\text{h}}}[:, n + k p_{\text{c}}]$
21:                         Compute $\mathbf{Y}_{j + i p_{\text{h}}}[m + l p_{\text{t}}, n + k p_{\text{c}}]$ using the fetched input blocks from GFB
22:                     else
23:                         // MPCA mode is DBMM
24:                         Fetch 𝐗[m+lpt,:]𝐗𝑚𝑙subscript𝑝t:\mathbf{X}[m+lp_{\text{t}},:]bold_X [ italic_m + italic_l italic_p start_POSTSUBSCRIPT t end_POSTSUBSCRIPT , : ] from GFB
25:                         Compute 𝐘j+iph[m+lpt,n+kpc]subscript𝐘𝑗𝑖subscript𝑝h𝑚𝑙subscript𝑝t𝑛𝑘subscript𝑝c\mathbf{Y}_{j+ip_{\text{h}}}[m+lp_{\text{t}},n+kp_{\text{c}}]bold_Y start_POSTSUBSCRIPT italic_j + italic_i italic_p start_POSTSUBSCRIPT h end_POSTSUBSCRIPT end_POSTSUBSCRIPT [ italic_m + italic_l italic_p start_POSTSUBSCRIPT t end_POSTSUBSCRIPT , italic_n + italic_k italic_p start_POSTSUBSCRIPT c end_POSTSUBSCRIPT ] using                                all the input blocks from GFB                                                                 
Figure 8: Execution of Sparse Block-wise Matrix Multiplication (SBMM) on MSA. Note that $X[i,j]$ denotes a data block at the $i$-th row and $j$-th column. In each CHM, the PEs of the same column share the same column block of the weight matrix, and the PEs of the same row share the same row block of the token matrix.
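To make the loop nest of Algorithm 2 (lines 9-25) concrete, the following is a minimal Python/NumPy software model, not the authors' HLS implementation. It assumes that the outer loop over $i$ (in the part of Algorithm 2 not shown here) steps over groups of $p_h$ heads, that all matrix dimensions are multiples of $b$, and that the block "header" of each weight column block can be represented as a boolean block mask; `mpca_matmul` and its arguments are hypothetical names. In SBMM mode each PE multiplies only the input blocks listed in the header, while DBMM mode uses every block.

```python
import math
import numpy as np

def mpca_matmul(X, W_heads, masks, b, p_h, p_t, p_c, mode="SBMM"):
    """Software model of the Algorithm 2 loop nest (hypothetical, not the HLS design).
    X:       (M, K) token matrix.
    W_heads: list of H per-head weight matrices, each (K, Dh); pruned blocks are zero.
    masks:   list of H boolean (K//b, Dh//b) block masks standing in for the
             per-column-block 'header' of retained weight blocks.
    Assumes M, K and Dh are multiples of b. Returns the list of Y_h = X @ W_h."""
    M, K = X.shape
    H, Dh = len(W_heads), W_heads[0].shape[1]
    row_blocks, col_blocks, k_blocks = M // b, Dh // b, K // b
    Y = [np.zeros((M, Dh)) for _ in range(H)]
    for i in range(math.ceil(H / p_h)):                 # assumed outer loop over head groups
        for j in range(p_h):                             # CHM_j (parallel in hardware)
            h = j + i * p_h
            if h >= H:
                continue
            for k in range(math.ceil(col_blocks / p_c)):         # column-block iterations
                for l in range(math.ceil(row_blocks / p_t)):     # row-block iterations
                    for m in range(p_t):                 # PE rows (parallel in hardware)
                        for n in range(p_c):             # PE columns (parallel in hardware)
                            r, c = m + l * p_t, n + k * p_c
                            if r >= row_blocks or c >= col_blocks:
                                continue                 # this PE is idle
                            if mode == "SBMM":           # only blocks listed in the header
                                idxs = np.nonzero(masks[h][:, c])[0]
                            else:                        # DBMM: every input block
                                idxs = range(k_blocks)
                            acc = np.zeros((b, b))
                            for idx in idxs:
                                acc += (X[r*b:(r+1)*b, idx*b:(idx+1)*b]
                                        @ W_heads[h][idx*b:(idx+1)*b, c*b:(c+1)*b])
                            Y[h][r*b:(r+1)*b, c*b:(c+1)*b] = acc
    return Y

# Usage with small random block-pruned weights (illustrative sizes)
rng = np.random.default_rng(0)
b, M, K, Dh, H = 4, 12, 16, 8, 3
X = rng.standard_normal((M, K))
masks = [rng.random((K // b, Dh // b)) < 0.5 for _ in range(H)]
W = [rng.standard_normal((K, Dh)) * np.kron(mk.astype(float), np.ones((b, b))) for mk in masks]
Y = mpca_matmul(X, W, masks, b, p_h=2, p_t=2, p_c=1)
print(all(np.allclose(Y[h], X @ W[h]) for h in range(H)))  # True
```

Because pruned weight blocks are zero, SBMM and DBMM produce the same $\mathbf{Y}_h$; SBMM simply skips the zero blocks, which is where its cycle savings come from.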

V-C3 Dynamic Token Dropping

We design a token dropping hardware module (TDHM) that drops tokens on the fly and reorganizes the remaining tokens. Token pruning is based on the token importance scores $\mathcal{S}$ (see Section IV-B). The attention scores $\mathbf{A}_h$ of all heads are buffered in the TDHM as soon as they are computed during MSA execution. The scores $\mathcal{S} = \frac{1}{H}\sum_h \mathbf{A}_h$ are then computed by the EM. Next, a bitonic sorting network sorts $\mathcal{S}$ to obtain the indices of the top-$k$ tokens. Originally, each token has a row index in the input token matrix $\mathbf{Z}_{\text{in}}$, denoted as the old token index ($\text{id}_{\text{old}}$). After sorting by $\mathcal{S}$, each token is assigned a new token index ($\text{id}_{\text{new}}$), which is its row index in the output token matrix $\mathbf{Z}_{\text{out}}$. The sorting network therefore generates a tuple ($\text{id}_{\text{old}}$, $\text{id}_{\text{new}}$, flag) for each token, where flag indicates whether the token will be pruned.
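The sorting step can be modeled in software as follows. This is a hypothetical Python sketch, not the TDHM hardware: it implements the standard bitonic compare-and-swap network in descending order, pads the score vector with $-\infty$ to a power of two (an assumption about how a fixed-size network handles an arbitrary token count), and emits one ($\text{id}_{\text{old}}$, $\text{id}_{\text{new}}$, flag) tuple per token, with flag set for tokens outside the top-$k$. The scores and $k$ in the example are illustrative.

```python
def bitonic_sort_desc(pairs):
    """Sort (score, token_index) pairs by descending score using a bitonic
    compare-and-swap network. len(pairs) must be a power of two."""
    n = len(pairs)
    assert n > 0 and n & (n - 1) == 0, "bitonic networks need a power-of-two size"
    a = list(pairs)
    k = 2
    while k <= n:                  # length of the bitonic sequences being merged
        j = k // 2
        while j > 0:               # comparator distance within this merge stage
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    descending = (i & k) == 0
                    # Swap when the pair is out of order for this sub-sequence's direction
                    if (a[i][0] < a[partner][0]) == descending:
                        a[i], a[partner] = a[partner], a[i]
            j //= 2
        k *= 2
    return a


def token_index_tuples(scores, k):
    """Return (id_old, id_new, flag) for every token; flag=True means pruned."""
    size = 1
    while size < len(scores):
        size *= 2
    padded = [(s, i) for i, s in enumerate(scores)]
    padded += [(float("-inf"), -1)] * (size - len(scores))   # padding sorts last
    ranked = bitonic_sort_desc(padded)
    return [(old, new, new >= k) for new, (_, old) in enumerate(ranked) if old >= 0]


print(token_index_tuples([0.1, 0.5, 0.3, 0.9, 0.2], k=2))
# [(3, 0, False), (1, 1, False), (2, 2, True), (4, 3, True), (0, 4, True)]
```

A bitonic network suits hardware because its comparator pattern is fixed and data-independent, so every stage can be fully pipelined.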
To organize the output token matrix $\mathbf{Z}_{\text{out}}$ from the tokens stored in the Old Token Buffer, an index shuffle network routes each ($\text{id}_{\text{old}}$, $\text{id}_{\text{new}}$, flag) tuple to the Old Token Buffer, which fetches tokens according to $\text{id}_{\text{old}}$. The fetched tokens are then routed to the New Token Buffer according to $\text{id}_{\text{new}}$, producing a top-$k$ token matrix and a non-top-$k$ token matrix. Finally, the non-top-$k$ tokens are fused into a single token, which is merged with the top-$k$ tokens to produce the output of the TDHM.
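A minimal NumPy sketch of this reorganization step is given below. It is a software stand-in for the index shuffle network and the two token buffers, not the hardware design: `np.argsort` replaces the bitonic network, `tdhm_reorganize` is a hypothetical name, and the fusion rule (a score-weighted average of the non-top-$k$ tokens) is an assumption, since this section does not specify how the tokens are fused. Scores are assumed positive, as softmax attention scores are.

```python
import numpy as np

def tdhm_reorganize(z_in, scores, k):
    """Software model of the TDHM gather-and-fuse step (hypothetical).
    z_in:   (N, D) input token matrix, i.e. the Old Token Buffer contents.
    scores: (N,) positive importance scores S.
    Returns a (k + 1, D) matrix: top-k tokens in score order, then one fused token."""
    z_in = np.asarray(z_in, dtype=float)
    scores = np.asarray(scores, dtype=float)
    n = z_in.shape[0]
    order = np.argsort(-scores, kind="stable")   # stand-in for the bitonic network
    # (id_old, id_new, flag) tuples; flag=True marks a token that is pruned
    tuples = [(int(old), new, new >= k) for new, old in enumerate(order)]
    new_buf = np.empty_like(z_in)                # New Token Buffer
    for old, new, _ in tuples:
        new_buf[new] = z_in[old]                 # route row id_old to row id_new
    if k >= n:
        return new_buf                           # nothing to prune
    top_k, rest = new_buf[:k], new_buf[k:]
    w = scores[order[k:]]
    # Assumed fusion rule: score-weighted average of the non-top-k tokens
    fused = (w[:, None] * rest).sum(axis=0) / w.sum()
    return np.vstack([top_k, fused])

out = tdhm_reorganize([[1, 0], [0, 1], [2, 2], [4, 0]], [0.1, 0.4, 0.2, 0.3], k=2)
print(out)   # rows: [0, 1], [4, 0], then the fused token [5/3, 4/3]
```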

V-D Optimizations

V-D1 Load balancing across columns for SBMM

As the weight matrices are pruned in a block-wise fashion, different columns of a weight matrix can have different numbers of retained data blocks, which can lead to load imbalance. The multi-level parallelism in MPCA distributes the PEs across several heads (CHMs). The PEs within each CHM process the weight column blocks over multiple iterations, each of which handles different columns of the weight matrices (see Algorithm 2). This naturally reduces the impact of load imbalance caused by differing sparsity levels across the columns of the weight matrices. Because we restrict block-wise pruning to the MSA weights ($\mathbf{W}_q$, $\mathbf{W}_k$, $\mathbf{W}_v$, $\mathbf{W}_{\text{proj}}$), the impact of load imbalance is reduced further. Note that the performance gain from block-wise pruning the MSA parameters (which can remove entire heads) far outweighs the load imbalance that such pruning introduces. Prior methods [44] that balance block-wise pruning across columns are disadvantaged because they cannot remove entire heads. Moreover, prior to inference, we perform an offline workload assignment among the columns of the weight matrices so that the workloads are evenly distributed across the different columns of PEs within a CHM.
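This section does not specify the heuristic used for the offline workload assignment. As one plausible sketch, the Python code below applies the classic longest-processing-time-first (LPT) greedy rule: weight column blocks, weighted by their number of retained blocks, are assigned in decreasing order to the currently least-loaded PE column. The function name `assign_columns` and its inputs are hypothetical.

```python
import heapq

def assign_columns(blocks_per_col, p_c):
    """Greedy LPT assignment of weight column blocks to p_c PE columns.
    blocks_per_col[c] = number of retained blocks in weight column block c."""
    heap = [(0, pe) for pe in range(p_c)]          # (current load, PE column)
    heapq.heapify(heap)
    assignment = {pe: [] for pe in range(p_c)}
    # Largest workloads first, each to the currently least-loaded PE column
    for col in sorted(range(len(blocks_per_col)), key=lambda c: -blocks_per_col[c]):
        load, pe = heapq.heappop(heap)
        assignment[pe].append(col)
        heapq.heappush(heap, (load + blocks_per_col[col], pe))
    loads = [sum(blocks_per_col[c] for c in assignment[pe]) for pe in range(p_c)]
    return assignment, loads

assignment, loads = assign_columns([6, 1, 5, 2, 4, 3], p_c=2)
print(assignment, loads)   # {0: [0, 5, 3], 1: [2, 4, 1]} [11, 10]
```

With these illustrative block counts, LPT yields per-column loads of 11 and 10, whereas a naive split of columns 0, 2, 4 versus 1, 3, 5 would give 15 and 6.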

V-D2 Dealing with varying number of tokens, retained heads and block sizes

Dynamic token pruning leads to a varying number of tokens across layers. Moreover, block-wise weight pruning of the MSA parameters removes heads within an encoder. In general, the heads removed or retained vary from encoder to encoder, which can cause hardware underutilization at runtime. For example, if the number of row blocks of the token matrix satisfies $\frac{N}{b} < p_t$, then $p_t - \frac{N}{b}$ rows of PEs in a CHM will be idle. Because MPCA uses multi-level parallelism, we can alleviate this underutilization by selecting proper values of $p_t$ (parallelism in the token dimension) and $p_h$ (parallelism in the head dimension). We use $\frac{N_{\min}}{b}$ to denote the minimum number of row blocks over all intermediate token matrices and $H_{\min}$ to denote the minimum number of heads over all layers.
By setting $p_t \ll \frac{N_{\min}}{b}$, the PE utilization in a CHM for a layer with $N$ tokens, $\frac{N/(p_t b)}{\lceil N/(p_t b) \rceil}$, stays above $\frac{x}{x+1}$ with $x = \frac{N_{\min}}{p_t b}$, because $\lceil y \rceil < y + 1$ and $\frac{y}{y+1}$ grows with $y$. For example, if $6 \times p_t < \frac{N_{\min}}{b}$, the utilization exceeds $\frac{6}{7} > 85\%$. Similarly, we can set $p_h \ll H_{\min}$.
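The utilization argument can be checked numerically with the short Python sketch below. It counts utilization at block granularity, as in Algorithm 2 ($\lceil N/b \rceil$ row blocks processed $p_t$ at a time over $\lceil \lceil N/b \rceil / p_t \rceil$ iterations); the helper names and the values $b = 16$, $p_t = 2$, $N_{\min} = 200$ are illustrative assumptions, not measurements from this work.

```python
import math

def row_utilization(n_tokens, b, p_t):
    """Fraction of PE-row slots doing useful work across the row-block
    iterations of Algorithm 2: ceil(N/b) row blocks processed p_t at a time."""
    row_blocks = math.ceil(n_tokens / b)
    iterations = math.ceil(row_blocks / p_t)
    return row_blocks / (iterations * p_t)

def utilization_lower_bound(n_min, b, p_t):
    """x / (x + 1) with x = N_min / (p_t * b)."""
    x = n_min / (p_t * b)
    return x / (x + 1)

# Illustrative values: b = 16, p_t = 2, smallest token count N_min = 200
bound = utilization_lower_bound(200, 16, 2)
worst = min(row_utilization(n, 16, 2) for n in range(200, 1000))
print(round(bound, 3), round(worst, 3))   # 0.862 0.929
```

Conversely, when $p_t$ is close to $\frac{N}{b}$ the utilization can drop sharply: 13 row blocks on $p_t = 12$ PE rows keep only $13/24$ of the rows busy.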

V-E Resource and Performance Models

V-E1 Resource Consumption Model

We perform a theoretical analysis of the performance achieved by the codesign and of its hardware resource utilization. We denote the total computational resources utilized by the MPCA, EM, and TDHM as $R_{\text{MPCA}}$, $R_{\text{EM}}$, and $R_{\text{TDHM}}$, respectively. $R_{\text{MPCA}}$ is proportional to the total number of computation units, $p_t p_h p_c p_{\text{pe}}^2$. Compared to $R_{\text{MPCA}}$, $R_{\text{TDHM}}$ and $R_{\text{EM}}$ are negligible and are thus ignored in the analysis.
The total resources $R_{\text{Total}}$ (DSPs and LUTs) are $R_{\text{Total}} = (c_1 p_t p_h p_c p_{\text{pe}}^2,\; c_2 p_t p_h p_c p_{\text{pe}}^2)$, where $c_1$ and $c_2$ denote the number of DSPs and LUTs utilized by a single computation unit.
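As a quick illustration of the resource model, the Python sketch below evaluates the number of computation units $p_t p_h p_c p_{\text{pe}}^2$ for the hyperparameters chosen in Section VI ($p_h = 4$, $p_t = 12$, $p_c = 2$, $p_{\text{pe}} = 8$). The per-unit costs $c_1$ and $c_2$ are not given in this section, so the values passed in the example are placeholders, and the function names are hypothetical.

```python
def compute_units(p_t, p_h, p_c, p_pe):
    """Total computation units in the MPCA: p_t * p_h * p_c * p_pe^2."""
    return p_t * p_h * p_c * p_pe ** 2

def resource_estimate(p_t, p_h, p_c, p_pe, c1, c2):
    """R_Total = (DSPs, LUTs), ignoring the EM and TDHM as in the model.
    c1, c2: DSPs and LUTs per computation unit (placeholders below)."""
    units = compute_units(p_t, p_h, p_c, p_pe)
    return c1 * units, c2 * units

print(compute_units(12, 4, 2, 8))                     # 6144
print(resource_estimate(12, 4, 2, 8, c1=1, c2=100))   # (6144, 614400) with placeholder costs
```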
The sizes of the (global) feature buffer, the column buffer, and the (global) result buffer associated with the MPCA are $b^2 p_t \gamma$, $b^2 p_c \gamma$, and $b^2 p_t p_h p_c$, respectively. Here, $b$ is the block size and $\gamma$ is a constant equal to the (maximum) total number of row blocks required to compute a single output block. We match the buffer sizes across compute units to improve the dataflow performance of the accelerator.
This gives a total buffer size of $4 \times \max(b^2 p_t p_h p_c,\, b^2 p_t \gamma)$ for the EM module and $2 \times \max(b^2 p_t p_h p_c,\, b^2 p_t \gamma)$ for the TDHM module: the EM module requires buffers to store the input, the scaling factor, the addition factor, and the output, while the TDHM requires an input buffer and an output buffer.
The total size of the required buffers is $B_{\text{Total}} = b^2 p_t \gamma + b^2 p_c \gamma + b^2 p_t p_h p_c + 6 \times \max(b^2 p_t p_h p_c,\, b^2 p_t \gamma)$. $R_{\text{Total}}$ and $B_{\text{Total}}$ are estimates of the resource utilization. The main design parameters $p_t$, $p_h$, and $p_c$, together with $\gamma$, are set empirically according to the resources of the target FPGA platform (see Section VI for details).
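The buffer model can be written as a small function, shown below as a hypothetical Python sketch. Sizes are counted in data elements (int16 words in this design), and the value $\gamma = 24$ used in the example is an illustrative assumption, not a value reported in the paper.

```python
def buffer_sizes(b, p_t, p_h, p_c, gamma):
    """Buffer-size model of Section V-E1, in data elements."""
    gfb = b * b * p_t * gamma            # (global) feature buffer
    cb = b * b * p_c * gamma             # column buffer
    rb = b * b * p_t * p_h * p_c         # (global) result buffer
    unit = max(rb, gfb)                  # matched buffer size reused by EM and TDHM
    em, tdhm = 4 * unit, 2 * unit        # EM: input, scale, add, output; TDHM: in, out
    return {"GFB": gfb, "CB": cb, "RB": rb, "EM": em, "TDHM": tdhm,
            "total": gfb + cb + rb + em + tdhm}

sizes = buffer_sizes(b=16, p_t=12, p_h=4, p_c=2, gamma=24)   # gamma is illustrative
print(sizes["total"], sizes["total"] * 2 / 1024, "KiB at int16")   # 552960 1080.0 KiB at int16
```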

V-E2 Performance Model

Based on Algorithm 2, the number of cycles to perform SBMM, DBMM, or DHBMM is estimated in Table III. DHBMM is DBMM computed head-wise (as in stage (ii) of MSA execution). In Table III, the cycles for SBMM/DBMM are those required to multiply a matrix of dimension $(M_1, M_2)$ by a matrix of dimension $(M_2, D)$. $D'$ is the size per head, $b$ is the block size, and $\phi$ is the ratio of retained dense blocks to total blocks within a column of the matrix. For DBMM, $\phi = 1$; for SBMM, $\phi$ is assumed, for simplicity, to be the same in every column block. For DHBMM, $(M_1, M_2)$ and $(M_2, D)$ are the per-head left and right matrix sizes, and $H$ is the total number of heads. The cycle estimates in Table III can be used to compute the total cycles for the MSA and MLP blocks.

TABLE III: Execution cycles for SBMM/DBMM and DHBMM
Operation | Cycles
SBMM/DBMM | $\left\lceil \frac{\lceil M_1/b \rceil \lceil D'/b \rceil}{p_t p_c} \right\rceil \left\lceil \frac{\lceil D/D' \rceil}{p_h} \right\rceil \left\lceil \frac{M_2}{b} \right\rceil \left\lceil \frac{b}{p_{\text{pe}}} \right\rceil^2 b \phi$
DHBMM | $\left\lceil \frac{\lceil M_1/b \rceil \lceil D/b \rceil}{p_t p_c} \right\rceil \left\lceil \frac{H}{p_h} \right\rceil \left\lceil \frac{M_2}{b} \right\rceil \left\lceil \frac{b}{p_{\text{pe}}} \right\rceil^2 b$
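The Python sketch below transcribes the two rows of Table III directly; the function names are hypothetical, and the matrix sizes in the example are illustrative ($M_1 = 208$, i.e., 13 row blocks at $b = 16$, and a 64-wide head) rather than a reproduction of any measured layer. Dividing by the 300 MHz clock reported in Section VI converts cycles to time.

```python
import math

def sbmm_dbmm_cycles(M1, M2, D, D_head, b, p_t, p_c, p_h, p_pe, phi=1.0):
    """Table III, row 1: (M1 x M2) @ (M2 x D); phi = retained-block ratio (1 for DBMM)."""
    return (math.ceil(math.ceil(M1 / b) * math.ceil(D_head / b) / (p_t * p_c))
            * math.ceil(math.ceil(D / D_head) / p_h)
            * math.ceil(M2 / b)
            * math.ceil(b / p_pe) ** 2
            * b * phi)

def dhbmm_cycles(M1, M2, D, H, b, p_t, p_c, p_h, p_pe):
    """Table III, row 2: head-wise DBMM with per-head sizes (M1 x M2) @ (M2 x D)."""
    return (math.ceil(math.ceil(M1 / b) * math.ceil(D / b) / (p_t * p_c))
            * math.ceil(H / p_h)
            * math.ceil(M2 / b)
            * math.ceil(b / p_pe) ** 2
            * b)

hw = dict(b=16, p_t=12, p_c=2, p_h=4, p_pe=8)    # Section VI settings with b = 16
dense = sbmm_dbmm_cycles(M1=208, M2=384, D=384, D_head=64, **hw)            # DBMM
sparse = sbmm_dbmm_cycles(M1=208, M2=384, D=384, D_head=64, phi=0.5, **hw)  # SBMM
attn = dhbmm_cycles(M1=208, M2=64, D=208, H=6, **hw)
print(dense, sparse, attn)                        # 9216.0 4608.0 4096
print(f"{dense / 300e6 * 1e6:.2f} us at 300 MHz")  # 30.72 us at 300 MHz
```

Because $\phi$ enters linearly, halving the retained blocks halves the modeled SBMM cycles, while the ceiling terms capture the idle PEs when a dimension does not divide evenly by the parallelism.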

VI Implementation Details

Evaluated Model: We evaluate our approach on the widely used DeiT-Small [45] model, which has 12 layers with six heads each. The hidden dimension is $D = 384$, and the (base) model has 22M parameters.

Implementation details of weight pruning, token pruning, and simultaneous training: The DeiT-Small model is simultaneously pruned as per Algorithm 1. We train several variants of the model by varying the model pruning top-$k$ rate $r_b$, the token pruning keep rate $r_t$, and the block size $b$. Specifically, $r_b$ is varied over $\{0.5, 0.7\}$, $r_t$ over $\{0.5, 0.7, 0.9\}$, and $b$ over $\{16, 32\}$. A cubic sparsity scheduler, as in [17], is used to schedule $r_b$ from a full density of 1 to its final density (0.5 or 0.7) with a warm-up and a cool-down phase. The token dropping module (TDM) is inserted in the 3rd, 7th, and 10th encoder layers. All model variants are trained for a total of 30 epochs using the AdamW optimizer [46] with a learning rate of $2 \times 10^{-5}$, a weight decay of 0.01, and a batch size of 128 across 4 GPUs. To train all pruned variants as well as the baseline, we use the pre-trained DeiT-Small model available at [47] with the classification MLP head parameters re-initialized. A ViT base model is used as the teacher for knowledge distillation.
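The density schedule for $r_b$ can be sketched as below. This follows the commonly used cubic form $d_t = r_b + (1 - r_b)(1 - \text{progress})^3$, held constant before and after the decay; the exact parameterization of [17] (step granularity, warm-up and cool-down lengths) is not given here, so `cubic_density` and the step counts in the example are hypothetical.

```python
def cubic_density(step, total_steps, final_density, warmup_steps, cooldown_steps):
    """Density of retained weight blocks at a training step (hypothetical schedule):
    1.0 during warm-up, cubic decay to final_density, then held during cool-down."""
    decay_steps = total_steps - warmup_steps - cooldown_steps
    assert decay_steps > 0, "warm-up and cool-down must leave room for the decay phase"
    if step < warmup_steps:
        return 1.0
    if step >= warmup_steps + decay_steps:
        return final_density
    progress = (step - warmup_steps) / decay_steps
    return final_density + (1.0 - final_density) * (1.0 - progress) ** 3

# Illustrative: 30 epochs, 3-epoch warm-up, 6-epoch cool-down, final density r_b = 0.5
schedule = [cubic_density(e, 30, 0.5, 3, 6) for e in range(30)]
print([round(d, 3) for d in schedule])
```

With these illustrative settings the density stays at 1 through epoch 3, decays quickly at first and then more slowly, reaches 0.5 at epoch 24, and is held there for the remaining epochs, giving the network time to recover accuracy at the final density.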

Hardware implementation details: We implement our FPGA design on a state-of-the-art FPGA platform, the Xilinx Alveo U250, which consists of four Super Logic Regions (SLRs). We implement the proposed hardware design using Xilinx High-Level Synthesis (HLS). For the MPCA module, we empirically determine the hardware hyperparameters to be $p_h = 4$, $p_t = 12$, $p_c = 2$, and $p_{\text{pe}} = 8$ according to the hardware resources of the target FPGA board: (1) We set $p_h = 4$ because the Alveo U250 has four SLRs, and one CHM is placed in each SLR. (2) We set $p_c = 2$ because, in a CHM, the PEs of the same row load the same rows of tokens but different data blocks. The BRAM/URAM on the FPGA has two independent memory ports, which can support concurrent memory access from 2 columns of PEs ($p_c = 2$). (3) We set $p_{\text{pe}} = 8$ because the data block size $b$ is set to 16 or 32 for block-wise weight pruning. Using $p_{\text{pe}} = 8$ supports both block sizes without data padding while keeping a reasonable value for $p_t$. (4) For setting $p_t$: the PEs of the same column within a CHM share the same weight blocks.
The weight blocks are broadcast to every PE of the same column, which supports any value of $p_t$. We set $p_t$ according to the available resources of the target FPGA board after determining $p_h$, $p_c$, and $p_{\text{pe}}$. We use the int16 data format. We utilize the four DDR4 channels of the U250, which provide 77 GB/s of external memory bandwidth in total. We perform synthesis and place-and-route using Xilinx Vitis v2022.2, and we report the frequency and FPGA resource utilization after place-and-route. The achieved frequency is 300 MHz, and the resource utilization is shown in Table IV.
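The choice $p_{\text{pe}} = 8$ can be checked against the $\lceil b/p_{\text{pe}} \rceil^2$ factor in Table III: a $b \times b$ block is covered by $\lceil b/p_{\text{pe}} \rceil^2$ tiles of size $p_{\text{pe}} \times p_{\text{pe}}$, so any remainder becomes zero padding. The hypothetical helper below computes the fraction of useful work per block; the alternative values of $p_{\text{pe}}$ are illustrative only.

```python
import math

def pe_padding_efficiency(b, p_pe):
    """Fraction of PE work that is useful when a b x b block is split into
    ceil(b / p_pe)^2 tiles of size p_pe x p_pe (the remainder is zero padding)."""
    tiles_per_dim = math.ceil(b / p_pe)
    return (b / (tiles_per_dim * p_pe)) ** 2

for p_pe in (8, 12, 16):
    print(p_pe, [round(pe_padding_efficiency(b, p_pe), 3) for b in (16, 32)])
# 8 [1.0, 1.0]
# 12 [0.444, 0.79]
# 16 [1.0, 1.0]
```

Note that $p_{\text{pe}} = 16$ would also avoid padding; per the reasoning above, 8 is preferred because it leaves a reasonable value for $p_t$ under the same resource budget.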

TABLE IV: FPGA resource utilization

Design | LUTs | DSPs | URAMs | BRAMs
HeatViT [37] | 137.6K-161.4K | 1955-2066 | N/A | 338-528
Auto-ViT-Acc [48] | 120K-193K | 13-2066 | N/A | N/A
Our Work | 798K | 7088 | 1728 | 960

VII Experiments and Results

VII-A Baselines, Metrics, Datasets

Baselines: We compare our FPGA implementation with state-of-the-art CPU, GPU, and FPGA accelerators, including [37] and [49]. Table V shows the details of these platforms.

TABLE V: Specifications of platforms

 | CPU | GPU | HeatViT [37] | SPViT [35] | Our work
Platform | AMD EPYC 9654 | NVIDIA RTX 6000 Ada | Xilinx ZCU102 | Xilinx ZCU102 | Xilinx Alveo U250
Frequency | 2.4 GHz | 915 MHz | 150 MHz | 200 MHz | 300 MHz
Peak Performance (TFLOPS) | 3.69 | 91.06 | 0.37 | 0.54 | 1.8
On-chip Memory | 384 MB | 96 MB | 3.6 MB | 4 MB | 36 MB
Memory Bandwidth | 461 GB/s | 960 GB/s | 19.2 GB/s | 19.2 GB/s | 77 GB/s

Datasets: Following prior works [28][37], we evaluate our approach on the ImageNet dataset, which contains approximately 1.2 million images.

Performance Metrics: We use the following performance metrics: (1) Accuracy: Following prior works, we evaluate the accuracy of our pruned model on ImageNet. (2) Inference latency: Following prior works [37, 49, 35], we measure inference latency via hardware emulation using AMD-Xilinx Vitis, which accurately simulates the behavior of the FPGA DDR. The measured latency is end-to-end, from the time the input is available in DDR to the time the inference result is written back to DDR. (3) Throughput: the number of images processed per second. (4) Computation complexity (FLOPs): the number of floating-point operations. (5) Model size: the amount of memory (MB) required to store the model.

TABLE VI: The experimental results for different pruning settings

Setting | Block Size $b$ | Top-$k$ Rate $r_b$ | Token Keep Rate $r_t$ | Head Retained Ratio | Model Size | MACs | Training Epochs | Accuracy | FPGA Latency (ms) | FPGA Throughput (images/s)
(Baseline) | 16 | 1 | 1 | 1 | 22M | 4.27G | 30 | 79.59% | 3.19 | 313.00
(Baseline) | 32 | 1 | 1 | 1 | 22M | 4.27G | 30 | 79.59% | 3.55 | 281.43
(Pruned) | 16 | 0.5 | 0.5 | 0.91 | 14.29M | 1.32G | 30 | 66.86% | 0.868 | 1151.55
(Pruned) | 16 | 0.5 | 0.7 | 0.91 | 14.29M | 1.79G | 30 | 68.62% | 1.169 | 855.12
(Pruned) | 16 | 0.5 | 0.9 | 0.93 | 14.39M | 2.43G | 30 | 70.14% | 1.479 | 676.10
(Pruned) | 16 | 0.7 | 0.5 | 0.98 | 17.63M | 1.62G | 30 | 74.12% | 1.140 | 877.054
(Pruned) | 16 | 0.7 | 0.7 | 0.98 | 17.63M | 2.20G | 30 | 75.96% | 1.553 | 643.72
(Pruned) | 16 | 0.7 | 0.9 | 0.98 | 17.63M | 2.98G | 30 | 76.55% | 1.953 | 511.94
(Pruned) | 32 | 0.5 | 0.5 | 0.84 | 13.80M | 1.25G | 30 | 67.25% | 1.621 | 616.79
(Pruned) | 32 | 0.5 | 0.7 | 0.83 | 13.70M | 1.70G | 30 | 68.62% | 1.796 | 556.66
(Pruned) | 32 | 0.5 | 0.9 | 0.84 | 13.80M | 2.31G | 30 | 70.06% | 1.999 | 500.17
(Pruned) | 32 | 0.7 | 0.5 | 0.97 | 17.53M | 1.61G | 30 | 73.45% | 2.126 | 470.33
(Pruned) | 32 | 0.7 | 0.7 | 0.94 | 17.33M | 2.16G | 30 | 75.65% | 2.353 | 424.93
(Pruned) | 32 | 0.7 | 0.9 | 0.94 | 17.33M | 2.93G | 30 | 76.40% | 2.590 | 386.02

VII-B Evaluation for the Pruning Algorithm

Results in Table VI indicate that for extreme pruning settings ($r_b$ and $r_t$ both 0.5), the accuracy drop relative to the baseline ($\approx 12\%$) is significant. A major reason for this drop is that training was restricted to 30 epochs in our experiments despite the reduction in model and input density; with a lower top-$k$ rate $r_b$ and token keep rate $r_t$, the model requires more epochs to converge. Compared to the baseline DeiT-Small model, the proposed simultaneous pruning algorithm achieves a compression ratio of $1.24\times$ to $1.60\times$ and a reduction in computational cost of $1.43\times$ to $3.42\times$, with an accuracy drop as small as $\approx 3\%$. Whereas prior works focus on either reducing the model size [44] or reducing the computational complexity [37, 35], our simultaneous pruning algorithm targets both.
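The ranges quoted above can be recomputed from the pruned rows of Table VI with the short Python script below (values transcribed from the table; model size in millions of parameters, MACs in billions, accuracy in percent). The script also shows that the largest MAC reduction and the smallest accuracy drop come from different pruning settings.

```python
baseline_params, baseline_macs, baseline_acc = 22.0, 4.27, 79.59
# (model size [M], MACs [G], top-1 accuracy [%]) for the 12 pruned rows of Table VI
pruned = [(14.29, 1.32, 66.86), (14.29, 1.79, 68.62), (14.39, 2.43, 70.14),
          (17.63, 1.62, 74.12), (17.63, 2.20, 75.96), (17.63, 2.98, 76.55),
          (13.80, 1.25, 67.25), (13.70, 1.70, 68.62), (13.80, 2.31, 70.06),
          (17.53, 1.61, 73.45), (17.33, 2.16, 75.65), (17.33, 2.93, 76.40)]

compression = [baseline_params / p for p, _, _ in pruned]
mac_reduction = [baseline_macs / m for _, m, _ in pruned]
acc_drop = [baseline_acc - a for _, _, a in pruned]

print(f"compression: {min(compression):.3f}x - {max(compression):.3f}x")
print(f"MAC reduction: {min(mac_reduction):.3f}x - {max(mac_reduction):.3f}x")
print(f"accuracy drop: {min(acc_drop):.2f} - {max(acc_drop):.2f} points")
print("best-MACs row:", mac_reduction.index(max(mac_reduction)),
      "smallest-drop row:", acc_drop.index(min(acc_drop)))
```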

VII-C Evaluation on the FPGA accelerator

Figure 9: Comparison of latency under various pruning settings when batch size is 1 for all platforms. CPU, GPU, and FPGA execute the same model.
Figure 10: Comparison of throughput under various pruning settings when batch size is 8 for CPU and GPU and batch size is 1 for FPGA.

VII-C1 Cross platform comparison

We compare the latency and throughput for executing the pruned model with baseline CPU and GPU (Figure 9 and 10). The latency of our accelerator is measured when batch size is 1111, and the throughput is calculated by 1latency1latency\frac{1}{\text{latency}}divide start_ARG 1 end_ARG start_ARG latency end_ARG. For comparing the latency, we set the batch size as 1111 for CPU and GPU because a larger batch size will increase the latency for CPU and GPU. For throughput comparison, we set the batch size as 8888 for CPU and GPU, which can fully exploit their thread-level parallelism. On average, our FPGA accelerator achieves a latency reduction of 12.8×12.8\times12.8 × and 3.2×3.2\times3.2 ×, compared with CPU and GPU, respectively. The lower latency of our FPGA accelerator is due to the followings: (1) our MPCA module with load balance strategy fully exploits the computation parallelism within the pruned model. The FPGA accelerator achieves a higher speedup with higher pruning ratios (smaller rbsubscript𝑟𝑏r_{b}italic_r start_POSTSUBSCRIPT italic_b end_POSTSUBSCRIPT and rtsubscript𝑟𝑡r_{t}italic_r start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT). In contrast, the CPU and GPU cannot efficiently handle the computational irregularity caused by weight pruning and dynamic token dropping. (2) CPU and GPU have complex cache hierarchies, leading to higher memory access latency for executing ViT inference, leading to increased latency. On average, our FPGA accelerator achieves 3.6×3.6\times3.6 × and 0.45×0.45\times0.45 × throughput speedup compared with CPU and GPU, respectively. Our FPGA accelerator achieves a lower throughput (0.45×0.45\times0.45 ×) than GPU, because GPU has much higher peak performance (50×50\times50 ×) and eternal memory bandwidth. 
When the pruning ratios become high (e.g., $r_b = 0.5$ and $r_t = 0.5$), our throughput gets closer to that of the GPU, which indicates that our FPGA accelerator is more efficient when executing the ViT model with larger pruning ratios.
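As a rough aid for reading Figures 9 and 10, the Python sketch below computes throughput from latency and gives an idealized estimate of how the keep ratios $r_b$ and $r_t$ shrink the MAC count of one encoder layer. It is a back-of-envelope model under our own assumptions, not the paper's cost model: throughput is taken as batch size divided by latency; $r_b$ is interpreted as the fraction of linear-layer weights kept and $r_t$ as the fraction of tokens kept (consistent with the text's statement that smaller values mean more pruning); the layer cost is the textbook count for a ViT encoder layer with DeiT-Small's N = 197 tokens and D = 384; and both ratios are assumed to apply uniformly. All function names are hypothetical. The 0.868 ms input is the lower end of our latency range in Table VII.

```python
"""Back-of-envelope helpers for reading the throughput and pruning-ratio results.

Assumptions (ours, not the paper's cost model):
- throughput = batch_size / latency, i.e. 1 / latency at batch size 1;
- one ViT encoder layer with N tokens and hidden size D costs
  12*N*D^2 MACs in its linear layers (QKV, output projection, MLP with 4x expansion)
  plus 2*N^2*D MACs in attention (Q @ K^T and A @ V);
- r_b is the fraction of linear-layer weights kept and r_t the fraction of
  tokens kept, both applied uniformly to the layer.
"""


def throughput_per_second(latency_ms: float, batch_size: int = 1) -> float:
    """Images per second when one batch finishes in latency_ms milliseconds."""
    return batch_size / (latency_ms / 1000.0)


def layer_macs(n_tokens: float, dim: int, r_b: float = 1.0) -> float:
    """Idealized MAC count of one encoder layer."""
    linear = 12 * n_tokens * dim * dim * r_b  # scaled by tokens kept and weight keep ratio r_b
    attention = 2 * n_tokens * n_tokens * dim  # scaled quadratically by tokens kept; no weights
    return linear + attention


def idealized_mac_reduction(r_b: float, r_t: float,
                            n_tokens: int = 197, dim: int = 384) -> float:
    """Dense-layer MACs divided by pruned-layer MACs (DeiT-Small defaults)."""
    return layer_macs(n_tokens, dim) / layer_macs(r_t * n_tokens, dim, r_b)


if __name__ == "__main__":
    print(round(throughput_per_second(0.868)))          # 1152 images/s at 0.868 ms
    print(round(idealized_mac_reduction(0.5, 0.5), 2))  # 4.0
```

Under these assumptions, $r_b = r_t = 0.5$ gives exactly a 4× reduction because every term shrinks by a factor of four. Measured reductions depend on where token pruning is actually applied and which layers are block-pruned, so this number is best read as an idealized figure rather than a prediction.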

VII-C2 Comparison with state-of-the-art

We compare the proposed codesign with state-of-the-art ViT accelerators on FPGA [48, 37, 35], as shown in Table VII. Prior works use at most one pruning approach, and ViTAcc [48] and HeatViT [37] use int4 or int8 to represent the weights and activations. In contrast, our work is the first algorithm-hardware codesign to combine two pruning approaches. In terms of latency, our accelerator achieves a 6.2–18.5× latency reduction compared with the prior accelerators. Because the accelerators use different numbers of computation units, which directly determine their peak performance (shown in Table V), we further normalize each latency by the respective peak performance, $\text{Normalized Latency} = \text{Latency} \times \text{Peak Performance}$, to obtain a fair comparison. Our accelerator achieves a normalized speedup of 1.5–4.5× compared with SPViT [35] and 0.72–2.1× compared with HeatViT [37]; the speedup is higher when our accelerator executes the model with a higher pruning ratio. The achieved speedup is attributed to (1) the use of model pruning in addition to token pruning, which further reduces the computational complexity compared with [35, 37], and (2) our MPCA-based architecture, which efficiently exploits the block-wise data sparsity in the pruned model.
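The normalization above can be sketched in a few lines of Python. The latency and peak-performance values in the example are placeholders chosen only for illustration (the actual peak figures are in Table V), and GOPS as the unit of peak performance is an assumption; the helper names are hypothetical.

```python
def normalized_latency(latency_ms: float, peak_gops: float) -> float:
    """Latency x peak performance; lower is better for a given compute budget."""
    return latency_ms * peak_gops


def normalized_speedup(base_latency_ms: float, base_peak_gops: float,
                       our_latency_ms: float, our_peak_gops: float) -> float:
    """Ratio of normalized latencies; values above 1 favor our design."""
    return (normalized_latency(base_latency_ms, base_peak_gops)
            / normalized_latency(our_latency_ms, our_peak_gops))


if __name__ == "__main__":
    # Placeholder numbers for illustration only (not taken from Table V or VII).
    raw = 10.0 / 2.0
    norm = normalized_speedup(10.0, 100.0, 2.0, 250.0)
    print(raw, norm)  # 5.0 2.0
```

In this made-up example the raw latency ratio is 5×, but because the faster design also has 2.5× the peak performance, its normalized speedup is only 2×. This is why a normalized speedup below 1 (as in the lower end of the HeatViT comparison) can coexist with a large raw latency reduction.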

TABLE VII: Comparison with state-of-the-art ViT Accelerators

| | ViTAcc [48] | HeatViT [37] | SPViT [35] | Our Work |
|---|---|---|---|---|
| Platform | Xilinx ZCU102 | Xilinx ZCU102 | Xilinx ZCU102 | Xilinx Alveo U250 |
| Accuracy | 77.94% | 79.00% | 79.34% | 66.8%–76.5% |
| Quantization (bits) | int4-8 | int8 | int16 | int16 |
| Model Pruning | ✗ | ✗ | ✗ | ✓ |
| Token Pruning | ✗ | ✓ | ✓ | ✓ |
| Latency (ms) | 26 | 9.1–17.5 | 13.23 | 0.868–2.59 |
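Point (2) of the comparison relies on skipping computation for pruned weight blocks. The generic NumPy sketch below shows that idea in software: the weight matrix is split into square tiles, a boolean mask marks which tiles survived block pruning, and only the surviving tiles are multiplied. It illustrates block-sparse matrix multiplication under our own assumptions (square tiles, dense activations); it does not model the MPCA module or its load-balancing strategy, and the function name is hypothetical.

```python
import numpy as np


def block_sparse_matmul(x: np.ndarray, w: np.ndarray, block_mask: np.ndarray, block: int):
    """Compute x @ w, skipping weight tiles that block pruning removed.

    w is split into (block x block) tiles; block_mask[i, j] is False when
    tile (i, j) was pruned, i.e. is all zeros. Returns the product and the
    number of tile multiplications actually performed.
    """
    rows, cols = w.shape
    if rows % block or cols % block:
        raise ValueError("weight shape must be a multiple of the block size")
    out = np.zeros((x.shape[0], cols), dtype=np.result_type(x, w))
    tiles_done = 0
    for i in range(rows // block):
        x_tile = x[:, i * block:(i + 1) * block]  # activations that meet row-block i
        for j in range(cols // block):
            if not block_mask[i, j]:
                continue  # pruned tile: its MACs are skipped entirely
            w_tile = w[i * block:(i + 1) * block, j * block:(j + 1) * block]
            out[:, j * block:(j + 1) * block] += x_tile @ w_tile
            tiles_done += 1
    return out, tiles_done


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 8))                   # 3 tokens, 8 input features
    mask = np.array([[True, False], [False, True]])   # keep 2 of 4 tiles
    w = rng.standard_normal((8, 8)) * np.kron(mask.astype(float), np.ones((4, 4)))
    y, n = block_sparse_matmul(x, w, mask, block=4)
    print(n, np.allclose(y, x @ w))  # 2 True
```

Because pruned tiles contribute nothing to the product, skipping them leaves the result unchanged while cutting the number of tile multiplications in proportion to the fraction of tiles pruned. The irregular distribution of surviving tiles across rows and columns is what makes load balancing necessary when these tiles are spread over parallel hardware units.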

VIII Conclusion and Future Work

In this paper, we proposed an algorithm-hardware codesign that simultaneously utilizes static weight pruning and dynamic token pruning. It fills the gap left by prior works that use only one pruning approach, further reducing the computational complexity of ViT. The proposed hardware accelerator efficiently executes the pruned model through a novel hardware architecture design. In the future, we plan to develop a design automation framework that automatically generates optimized implementations of the pruned ViT model for a given target FPGA platform.

Acknowledgement

This work is supported by the DEVCOM Army Research Lab (ARL) under grant W911NF2220159, and the National Science Foundation (NSF) under grants CCF-1919289 and SaTC-2104264. Equipment and support by AMD AECG are greatly appreciated. Distribution Statement A: Approved for public release. Distribution is unlimited.

References

  • [1] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly et al., “An image is worth 16x16 words: Transformers for image recognition at scale,” arXiv preprint arXiv:2010.11929, 2020.
  • [2] K. Han, A. Xiao, E. Wu, J. Guo, C. Xu, and Y. Wang, “Transformer in transformer,” Advances in neural information processing systems, vol. 34, pp. 15 908–15 919, 2021.
  • [3] M. Chen, A. Radford, R. Child, J. Wu, H. Jun, D. Luan, and I. Sutskever, “Generative pretraining from pixels,” in International conference on machine learning.   PMLR, 2020, pp. 1691–1703.
  • [4] X. Chen, S. Xie, and K. He, “An empirical study of training self-supervised vision transformers,” in Proceedings of the IEEE/CVF international conference on computer vision, 2021, pp. 9640–9649.
  • [5] N. Carion, F. Massa, G. Synnaeve, N. Usunier, A. Kirillov, and S. Zagoruyko, “End-to-end object detection with transformers,” in European conference on computer vision.   Springer, 2020, pp. 213–229.
  • [6] H. Wang, Y. Zhu, H. Adam, A. Yuille, and L.-C. Chen, “Max-deeplab: End-to-end panoptic segmentation with mask transformers,” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, 2021, pp. 5463–5474.
  • [7] Y. Jiang, S. Chang, and Z. Wang, “Transgan: Two pure transformers can make one strong gan, and that can scale up,” Advances in Neural Information Processing Systems, vol. 34, pp. 14 745–14 758, 2021.
  • [8] A. Ramesh, M. Pavlov, G. Goh, S. Gray, C. Voss, A. Radford, M. Chen, and I. Sutskever, “Zero-shot text-to-image generation,” in International conference on machine learning. PMLR, 2021, pp. 8821–8831.
  • [9] S. Khan, M. Naseer, M. Hayat, S. W. Zamir, F. S. Khan, and M. Shah, “Transformers in vision: A survey,” ACM computing surveys (CSUR), vol. 54, no. 10s, pp. 1–41, 2022.
  • [10] N. Park and S. Kim, “How do vision transformers work?” arXiv preprint arXiv:2202.06709, 2022.
  • [11] W. Zhu, “Token propagation controller for efficient vision transformer,” arXiv preprint arXiv:2401.01470, 2024.
  • [12] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, Ł. Kaiser, and I. Polosukhin, “Attention is all you need,” Advances in neural information processing systems, vol. 30, 2017.
  • [13] X. Ma, G. Yuan, X. Shen, T. Chen, X. Chen, X. Chen, N. Liu, M. Qin, S. Liu, Z. Wang et al., “Sanity checks for lottery tickets: Does your winning ticket really win the jackpot?” Advances in Neural Information Processing Systems, vol. 34, pp. 12 749–12 760, 2021.
  • [14] J. Frankle and M. Carbin, “The lottery ticket hypothesis: Finding sparse, trainable neural networks,” arXiv preprint arXiv:1803.03635, 2018.
  • [15] N. Liu, G. Yuan, Z. Che, X. Shen, X. Ma, Q. Jin, J. Ren, J. Tang, S. Liu, and Y. Wang, “Lottery ticket preserves weight correlation: Is it desirable or not?” in International Conference on Machine Learning.   PMLR, 2021, pp. 7011–7020.
  • [16] T. Zhang, X. Ma, Z. Zhan, S. Zhou, M. Qin, F. Sun, Y.-K. Chen, C. Ding, M. Fardad, and Y. Wang, “A unified dnn weight compression framework using reweighted optimization methods,” arXiv preprint arXiv:2004.05531, 2020.
  • [17] V. Sanh, T. Wolf, and A. Rush, “Movement pruning: Adaptive sparsity by fine-tuning,” Advances in neural information processing systems, vol. 33, pp. 20 378–20 389, 2020.
  • [18] B. Li, Z. Kong, T. Zhang, J. Li, Z. Li, H. Liu, and C. Ding, “Efficient transformer-based large scale language representations using hardware-friendly block structured pruning,” arXiv preprint arXiv:2009.08065, 2020.
  • [19] H. Wang, Z. Zhang, and S. Han, “Spatten: Efficient sparse attention architecture with cascade token and head pruning,” in 2021 IEEE International Symposium on High-Performance Computer Architecture (HPCA).   IEEE, 2021, pp. 97–110.
  • [20] S. Anwar, K. Hwang, and W. Sung, “Structured pruning of deep convolutional neural networks,” ACM Journal on Emerging Technologies in Computing Systems (JETC), vol. 13, no. 3, pp. 1–18, 2017.
  • [21] P. Molchanov, S. Tyree, T. Karras, T. Aila, and J. Kautz, “Pruning convolutional neural networks for resource efficient inference,” arXiv preprint arXiv:1611.06440, 2016.
  • [22] Y. He and L. Xiao, “Structured pruning for deep convolutional neural networks: A survey,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2023.
  • [23] F. Lagunas, E. Charlaix, V. Sanh, and A. M. Rush, “Block pruning for faster transformers,” arXiv preprint arXiv:2109.04838, 2021.
  • [24] S. Han, J. Pool, J. Tran, and W. Dally, “Learning both weights and connections for efficient neural network,” Advances in neural information processing systems, vol. 28, 2015.
  • [25] F. Yu, K. Huang, M. Wang, Y. Cheng, W. Chu, and L. Cui, “Width & depth pruning for vision transformers,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, no. 3, 2022, pp. 3143–3151.
  • [26] C. Zheng, K. Zhang, Z. Yang, W. Tan, J. Xiao, Y. Ren, S. Pu et al., “Savit: Structure-aware vision transformer pruning via collaborative optimization,” Advances in Neural Information Processing Systems, vol. 35, pp. 9010–9023, 2022.
  • [27] T. Chen, Y. Cheng, Z. Gan, L. Yuan, L. Zhang, and Z. Wang, “Chasing sparsity in vision transformers: An end-to-end exploration,” Advances in Neural Information Processing Systems, vol. 34, pp. 19 974–19 988, 2021.
  • [28] Y. Liang, C. Ge, Z. Tong, Y. Song, J. Wang, and P. Xie, “Not all patches are what you need: Expediting vision transformers via token reorganizations,” arXiv preprint arXiv:2202.07800, 2022.
  • [29] M. Fayyaz, S. A. Koohpayegani, F. R. Jafari, S. Sengupta, H. R. V. Joze, E. Sommerlade, H. Pirsiavash, and J. Gall, “Adaptive token sampling for efficient vision transformers,” in European Conference on Computer Vision.   Springer, 2022, pp. 396–414.
  • [30] Y. Tang, K. Han, Y. Wang, C. Xu, J. Guo, C. Xu, and D. Tao, “Patch slimming for efficient vision transformers,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022, pp. 12 165–12 174.
  • [31] B. Pan, R. Panda, Y. Jiang, Z. Wang, R. Feris, and A. Oliva, “Ia-red²: Interpretability-aware redundancy reduction for vision transformers,” Advances in Neural Information Processing Systems, vol. 34, pp. 24 898–24 911, 2021.
  • [32] Y. Xu, Z. Zhang, M. Zhang, K. Sheng, K. Li, W. Dong, L. Zhang, C. Xu, and X. Sun, “Evo-vit: Slow-fast token evolution for dynamic vision transformer,” in Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, no. 3, 2022, pp. 2964–2972.
  • [33] S. Yu, T. Chen, J. Shen, H. Yuan, J. Tan, S. Yang, J. Liu, and Z. Wang, “Unified visual transformer compression,” arXiv preprint arXiv:2203.08243, 2022.
  • [34] S. Kim, S. Shen, D. Thorsley, A. Gholami, W. Kwon, J. Hassoun, and K. Keutzer, “Learned token pruning for transformers,” in Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 2022, pp. 784–794.
  • [35] Z. Kong, P. Dong, X. Ma, X. Meng, W. Niu, M. Sun, X. Shen, G. Yuan, B. Ren, H. Tang et al., “Spvit: Enabling faster vision transformers via latency-aware soft token pruning,” in European conference on computer vision.   Springer, 2022, pp. 620–640.
  • [36] Y. Rao, W. Zhao, B. Liu, J. Lu, J. Zhou, and C.-J. Hsieh, “Dynamicvit: Efficient vision transformers with dynamic token sparsification,” Advances in neural information processing systems, vol. 34, pp. 13 937–13 949, 2021.
  • [37] P. Dong, M. Sun, A. Lu, Y. Xie, K. Liu, Z. Kong, X. Meng, Z. Li, X. Lin, Z. Fang et al., “Heatvit: Hardware-efficient adaptive token pruning for vision transformers,” in 2023 IEEE International Symposium on High-Performance Computer Architecture (HPCA).   IEEE, 2023, pp. 442–455.
  • [38] J. L. Ba, J. R. Kiros, and G. E. Hinton, “Layer normalization,” arXiv preprint arXiv:1607.06450, 2016.
  • [39] L. Lu, Y. Jin, H. Bi, Z. Luo, P. Li, T. Wang, and Y. Liang, “Sanger: A co-design framework for enabling sparse attention using reconfigurable architecture,” in MICRO-54: 54th Annual IEEE/ACM International Symposium on Microarchitecture, 2021, pp. 977–991.
  • [40] Y. Bengio, “Estimating or propagating gradients through stochastic neurons,” arXiv preprint arXiv:1305.2982, 2013.
  • [41] V. Ramanujan, M. Wortsman, A. Kembhavi, A. Farhadi, and M. Rastegari, “What’s hidden in a randomly weighted neural network?” in Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, 2020, pp. 11 893–11 902.
  • [42] A. Mallya, D. Davis, and S. Lazebnik, “Piggyback: Adapting a single network to multiple tasks by learning to mask weights,” in Proceedings of the European conference on computer vision (ECCV), 2018, pp. 67–82.
  • [43] K. T. Chitty-Venkata, S. Mittal, M. Emani, V. Vishwanath, and A. K. Somani, “A survey of techniques for optimizing transformer inference,” Journal of Systems Architecture, p. 102990, 2023.
  • [44] H. Peng, S. Huang, T. Geng, A. Li, W. Jiang, H. Liu, S. Wang, and C. Ding, “Accelerating transformer-based deep learning models on fpgas using column balanced block pruning,” in 2021 22nd International Symposium on Quality Electronic Design (ISQED).   IEEE, 2021, pp. 142–148.
  • [45] H. Touvron, M. Cord, M. Douze, F. Massa, A. Sablayrolles, and H. Jégou, “Training data-efficient image transformers & distillation through attention,” in International conference on machine learning.   PMLR, 2021, pp. 10 347–10 357.
  • [46] I. Loshchilov and F. Hutter, “Decoupled weight decay regularization,” arXiv preprint arXiv:1711.05101, 2017.
  • [47] “HuggingFace Model Hub DeiT-Small Model,” https://huggingface.co/facebook/deit-small-distilled-patch16-224, accessed: 2024-01-15.
  • [48] Z. Li, M. Sun, A. Lu, H. Ma, G. Yuan, Y. Xie, H. Tang, Y. Li, M. Leeser, Z. Wang et al., “Auto-vit-acc: An fpga-aware automatic acceleration framework for vision transformer with mixed-scheme quantization,” in 2022 32nd International Conference on Field-Programmable Logic and Applications (FPL). IEEE, 2022, pp. 109–116.
  • [49] T. Wang, L. Gong, C. Wang, Y. Yang, Y. Gao, X. Zhou, and H. Chen, “Via: A novel vision-transformer accelerator based on fpga,” IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems, vol. 41, no. 11, pp. 4088–4099, 2022.