License: CC BY 4.0
arXiv:2312.17244v2 [cs.LG] 20 Mar 2024

The LLM Surgeon

Tycho F.A. van der Ouderaa¹, Markus Nagel², Mart van Baalen²,
Yuki M. Asano³, Tijmen Blankevoort²
¹ Imperial College London, ² Qualcomm AI Research, ³ QUVA Lab, University of Amsterdam
Work done while doing an internship at Qualcomm AI Research. Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.
Abstract

State-of-the-art language models are becoming increasingly large in an effort to achieve the highest performance on large corpora of available textual data. However, the sheer size of the Transformer architectures makes it difficult to deploy models within computational, environmental or device-specific constraints. We explore data-driven compression of existing pretrained models as an alternative to training smaller models from scratch. To do so, we scale Kronecker-factored curvature approximations of the target loss landscape to large language models. In doing so, we can compute both the dynamic allocation of structures that can be removed as well as updates of remaining weights that account for the removal. We provide a general framework for unstructured, semi-structured and structured pruning and improve upon weight updates to capture more correlations between weights, while remaining computationally efficient. Experimentally, our method can prune rows and columns from a range of OPT models and Llamav2-7B by 20%-30%, with a negligible loss in performance, and achieve state-of-the-art results in unstructured and semi-structured pruning of large language models.
Code is available at: https://github.com/Qualcomm-AI-research/llm-surgeon.

[Figure panels: Structured compression (rows and columns); Unstructured compression (matrix elements)]

Figure 1: LLM Surgeon allows interpolation of model size between existing pretrained models.

1 Introduction

Recent advancements in language modeling (Vaswani et al., 2017) allow fitting large language models (LLMs) with millions or even billions of parameters (such as OPT (Zhang et al., 2022) and Llama 2 (Touvron et al., 2023)) on big text corpora achieving high performance. Unfortunately, the size of these LLMs often makes it hard to deploy them within practical constraints. Cloud-based deployment can get very expensive for larger models, and efficient devices such as phones are frequently limited in the memory size to host a model.

A body of literature extending back to the late 1980s, e.g., Optimal Brain Damage (OBD, LeCun et al. (1989)) and Optimal Brain Surgeon (OBS, Hassibi & Stork (1992)), phrases pruning as a constrained optimization problem to reduce a model’s footprint and runtime requirements. The Hessian required for this approach grows with the square of the number of parameters, and can only be computed in practice for unrealistically small networks. To overcome this issue, Eigendamage (Wang et al., 2019) introduces a Kronecker factorization of a blockwise-diagonal approximation of the Hessian. Recent works, such as Optimal Brain Compression (Frantar & Alistarh, 2022) and SparseGPT (Frantar & Alistarh, 2023), demonstrate practical post-training pruning of LLMs, but only consider the loss curvature of a pruned layer’s squared output reconstruction error, ignoring gradients that relate local removal costs to the target loss. As a result, their approximation to the target loss landscape is inaccurate, leading to a significant performance degradation for pruned LLMs. Further, these methods do not readily extend to structured pruning.

This work introduces LLM Surgeon, a general framework for unstructured, semi-structured and structured pruning of LLMs. At paper submission, we deemed this the first method to successfully perform structured pruning of LLMs. Concurrent work by Ashkboos et al. (2024) also considers structured pruning of LLMs but ignores gradient information, resulting in lower final performance. The superior performance of LLM Surgeon is achieved by scaling up the block-diagonal Kronecker-factorized approximations to the empirical Fisher from Eigendamage to LLMs. We further expand upon the work by deriving OBS-like weight pruning costs and updates for structured pruning of multiple rows and columns, and provide a general framework that also incorporates semi-structured and unstructured pruning. Instead of treating individual weight updates independently, we strive to consider as many correlations between weights as practically possible and derive joint weight updates for pruning multiple weights (or multiple sets of structured weights) at once. Unlike prior work in LLM pruning, LLM Surgeon prunes in multiple shots, updating weights and curvature estimates between shots. We use global thresholding for unstructured, semi-structured and structured pruning, i.e., instead of pruning layers by a fixed amount, more sensitive layers are pruned less than those that are more robust. Lastly, we propose to mitigate first-order gradients that may not be zero by using optional low-rank first-order updates between shots. A key advantage of LLM Surgeon is that it allows trading off additional compute during compression for better accuracy by increasing the number of correlations and/or shots. Our method gives the first practically usable results for structured pruning of LLMs: they can be pruned by up to 30% with minor performance degradation. Furthermore, we achieve state-of-the-art results in unstructured and semi-structured LLM pruning.

2 Background and related work

Neural network pruning aims to remove parameters from a model while minimizing negative impact on final performance. More formally, we denote the $P$ model parameters as vector $\bm{\theta}^* = \text{vec}(\bm{W}^*_1, \bm{W}^*_2, \ldots, \bm{W}^*_L) \in \mathbb{R}^P$, by flattening the $L$ weight matrices of attention and fully-connected blocks, with already fitted $\bm{\theta}^* \approx \operatorname*{arg\,min}_{\bm{\theta}} \mathcal{L}(\bm{\theta})$ to data $\mathcal{D}$ to minimise a negative log-likelihood loss $\mathcal{L}(\bm{\theta}) = -\log p(\mathcal{D} \mid \bm{\theta})$. To compress the model, we are looking for a pruned vector $\hat{\bm{\theta}}$:

$$\hat{\bm{\theta}} = \operatorname*{arg\,min}\nolimits_{\bm{\theta}} \mathcal{L}(\bm{\theta}) \text{ s.t. pruning constraints based on } \bm{\theta}^* \tag{1}$$

where chosen constraints determine the structure of compressed weights $\hat{\bm{\theta}}$. In unstructured pruning, a fraction of total weight elements is set to zero. In semi-structured pruning of M:N we have that M weights of every N consecutive weights are zero (Zhou et al., 2021; Hubara et al., 2021). And in structured pruning (Louizos et al., 2017), entire rows and columns are set to zero. Structured pruning leads to the most immediate gains in memory and computing, as it directly reduces the dimensions of matrices that need to be represented explicitly, but it is regarded as the most difficult pruning type to perform without losing performance. Maintaining high performance is often easier in the other schemes but requires specialised arithmetic exploiting the sparsity structure to benefit at deployment. We consider all pruning types above, with a focus on structured pruning for LLMs.
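To make the three constraint types concrete, the following is a minimal NumPy sketch. It uses weight magnitude as a stand-in selection criterion, which is an assumption made purely for illustration and is not the curvature-based criterion developed in this paper; the function names are hypothetical.

```python
import numpy as np

def unstructured_mask(W, sparsity):
    """Keep-mask that zeroes the `sparsity` fraction of entries with smallest |w|."""
    k = int(round(sparsity * W.size))
    mask = np.ones(W.size, dtype=bool)
    mask[np.argsort(np.abs(W).ravel())[:k]] = False
    return mask.reshape(W.shape)

def m_n_mask(W, m, n):
    """Keep-mask that zeroes the m smallest |w| in every n consecutive weights of a row."""
    R, C = W.shape
    assert C % n == 0, "row length must be divisible by n"
    groups = np.abs(W).reshape(R, C // n, n)
    drop = np.argsort(groups, axis=-1)[..., :m]      # indices to zero in each group
    mask = np.ones_like(groups, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return mask.reshape(R, C)

def structured_prune(W, drop_rows, drop_cols):
    """Remove whole rows and columns, so the stored matrix itself shrinks."""
    W = np.delete(W, drop_rows, axis=0)
    return np.delete(W, drop_cols, axis=1)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
mask_unstructured = unstructured_mask(W, sparsity=0.5)
mask_2_4 = m_n_mask(W, m=2, n=4)
W_small = structured_prune(W, drop_rows=[0], drop_cols=[1, 2])
```

Only the structured variant changes the shape of the matrix; the two mask-based variants keep the dense shape and rely on sparse arithmetic to realise any speed-up.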

Typically, eq. 1 cannot be solved directly, as the space of possible pruning configurations exceeds what can be evaluated in practice. To illustrate, a search over all possible unstructured pruning masks of a 125 million parameter LLM would require $2^P = 2^{125\text{m}} \approx 10^{37628749}$ evaluations. The idea, therefore, is to find $\hat{\bm{\theta}}$ using a surrogate of the loss landscape $q$ that is easier to work with:

$$\mathcal{L}(\bm{\theta}) = -\log p(\mathcal{D} \mid \bm{\theta}) \approx -\log q(\bm{\theta}) \tag{2}$$

If one chooses a particular Gaussian form for our surrogate $q$, then solutions for unstructured, semi-structured, and structured pruning constraints can be derived in closed-form (appendix A).
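The number of unstructured pruning masks quoted earlier in this section can be checked with a few lines of arithmetic, since $\log_{10} 2^P = P \log_{10} 2$. The variable names below are illustrative.

```python
import math

num_params = 125_000_000                      # P for a 125-million-parameter model
log10_masks = num_params * math.log10(2)      # log10(2**P) = P * log10(2)
exponent = math.floor(log10_masks)            # 2**P is roughly 10**exponent
print(f"2^P is about 10^{exponent}")
```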

2.1 Taylor expansion

How do we obtain a good surrogate of the loss $q$? One of the easiest approaches is to locally expand the log loss through a second-order Taylor expansion around the pretrained weights $\bm{\theta}^*$, yielding:

$$-\log q(\bm{\theta}) \approx -\log p(\mathcal{D} \mid \bm{\theta}^*) + (\bm{\theta} - \bm{\theta}^*)^T \nabla \mathcal{L}(\bm{\theta}^*) + \frac{1}{2} (\bm{\theta} - \bm{\theta}^*)^T \bm{H}_{\bm{\theta}^*} (\bm{\theta} - \bm{\theta}^*) \tag{3}$$

where $[\nabla \mathcal{L}(\bm{\theta}^*)]_i = \frac{\partial}{\partial \theta_i} \mathcal{L}(\bm{\theta}^*)$ denotes the Jacobian and $[\bm{H}_{\bm{\theta}}]_{ij} = \frac{\partial^2}{\partial \theta_i \partial \theta_j} \mathcal{L}(\bm{\theta})$ denotes the Hessian. The first-order term vanishes at the optimum, where $\nabla \mathcal{L}(\bm{\theta}^*) = \bm{0}$. Note that in practice the first-order term may not vanish. While we follow this assumption initially, we consider interleaved first-order corrections to mitigate the issue in section 3.6. The quadratic expansion of eq. 3 forms the basis of the optimal brain damage (LeCun et al., 1989) and optimal brain surgeon (Hassibi & Stork, 1992) pruning methods.
Note that from a probabilistic perspective, a quadratic approximation of the log likelihood implies a Gaussian approximation of the likelihood, as also observed by Wang et al. (2019) and illustrated in fig. 2. This is well-known (Bishop & Nasrabadi, 2006; MacKay, 2003) as the Laplace approximation $q(\bm{\theta}) = \mathcal{N}(\bm{\theta} \mid \bm{\theta}^* - \bm{H}_{\bm{\theta}^*}^{-1} \nabla \mathcal{L}(\bm{\theta}^*), \bm{H}_{\bm{\theta}^*}^{-1})$, where, once the gradient vanishes, the pretrained weights are the mean and the local inverse Hessian is the covariance matrix capturing correlations between weights.

Figure 2: Pruning as equality constrained optimization of quadratic approximation of the loss landscape (left), or equivalently, maximising the likelihood under a Laplace approximation (right).
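As a short worked step, and assuming that the gradient at $\bm{\theta}^*$ vanishes and that $\bm{H}_{\bm{\theta}^*}$ is positive definite (both are assumptions), exponentiating the negated quadratic surrogate gives the Gaussian form directly:

```latex
\begin{align*}
-\log q(\bm{\theta}) &= \mathcal{L}(\bm{\theta}^*) + \tfrac{1}{2} (\bm{\theta} - \bm{\theta}^*)^T \bm{H}_{\bm{\theta}^*} (\bm{\theta} - \bm{\theta}^*) \\
\Rightarrow \quad q(\bm{\theta}) &\propto \exp\!\left( -\tfrac{1}{2} (\bm{\theta} - \bm{\theta}^*)^T \bm{H}_{\bm{\theta}^*} (\bm{\theta} - \bm{\theta}^*) \right)
\propto \mathcal{N}\!\left( \bm{\theta} \mid \bm{\theta}^*, \bm{H}_{\bm{\theta}^*}^{-1} \right)
\end{align*}
```

The constant $\mathcal{L}(\bm{\theta}^*)$ is absorbed into the normalisation, which is why only the curvature matters when comparing candidate pruned weights under $q$.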

2.2 Block Fisher Information Matrix

For a network trained with negative log-likelihood loss, the Hessian is identical to the Fisher matrix:

$$\bm{H}_{\bm{\theta}} = \bm{F}_{\bm{\theta}} = \sum\nolimits_{n=1}^{N} \mathbb{E}_{y \sim p_{\bm{\theta}}(y|x_n)} \left[ \nabla_{\bm{\theta}} \log p_{\bm{\theta}}(y|x_n) \nabla_{\bm{\theta}} \log p_{\bm{\theta}}(y|x_n)^T \right] \tag{4}$$

which has the benefit of always being positive semi-definite, with the inverse thus forming a proper covariance matrix for $q$, and can be approximated with Monte Carlo samples of $p_{\bm{\theta}}(y|x_n)$. For most LLMs, this would be treating the softmax output of the network as categorical distribution $p_{\bm{\theta}}(y|x_n)$, and sampling from that. In practice, we use the ‘empirical Fisher’ replacing the expectation over $y$ with target data $y_n$ (Kunstner et al., 2019). The full (empirical) Fisher $\bm{F}_{\bm{\theta}} \in \mathbb{R}^{P \times P}$ scales quadratically in the number of parameters $P$.
To overcome this, the Fisher is often written in terms of layer-wise blocks $\bm{F}_{lk} = \sum_{n=1}^{N} \mathbb{E}\left[ \text{vec}(\nabla_{\bm{W}_l} \log p_{\bm{\theta}}(y|x_n)) \text{vec}(\nabla_{\bm{W}_k} \log p_{\bm{\theta}}(y|x_n))^T \right]$, and approximated by treating layers independently (Martens & Grosse, 2015; Botev et al., 2017):

$$\bm{F}_{\bm{\theta}} = \text{diag}(\bm{F}_{11}, \bm{F}_{22}, \ldots, \bm{F}_{LL}), \qquad \bm{F}_l = \sum\nolimits_{n=1}^{N} \mathbb{E}\Big[ \underbrace{(\bm{g}_{l,n} \bm{g}_{l,n}^T) \otimes (\bm{a}_{l,n} \bm{a}_{l,n}^T)}_{RC \times RC} \Big] \tag{5}$$

where $\otimes$ denotes the Kronecker product and $\text{vec}(\cdot)$ the matrix vectorisation operation. Because we disregard cross-layer interactions we write $\bm{F}_l$ instead of $\bm{F}_{ll}$ for Fisher blocks associated with the weight matrix $\bm{W}_l \in \mathbb{R}^{R \times C}$ producing outputs $\bm{y}_{l,n} = \bm{W}_l \bm{a}_{l,n} \in \mathbb{R}^R$ from inputs $\bm{a}_{l,n} \in \mathbb{R}^C$, for each layer $l$ and datapoint $n$.
Consequently, we can compute Fisher blocks from input activations $\bm{a}_{l,n} \in \mathbb{R}^C$ of forward-passed data $x_n$ and output gradients $\bm{g}_{l,n} = \nabla_{\bm{y}_{l,n}} \mathcal{L} \in \mathbb{R}^R$ from backpropagation.
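The Kronecker structure of a single-sample Fisher block can be verified numerically. The sketch below assumes a row-major vectorisation of the weight gradient (an assumption about the vec convention; with a column-major convention the two Kronecker factors swap places) and uses random stand-ins for the activation and the output gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 3, 4
a = rng.normal(size=C)      # stand-in for the input activation a_{l,n}
g = rng.normal(size=R)      # stand-in for the output gradient g_{l,n}

# For y = W a, the gradient of the loss w.r.t. W is the outer product g a^T.
grad_W = np.outer(g, a)                 # shape (R, C)
v = grad_W.reshape(-1)                  # row-major vec, length R*C

block_from_grad = np.outer(v, v)                            # vec(grad) vec(grad)^T
block_from_kron = np.kron(np.outer(g, g), np.outer(a, a))   # (g g^T) kron (a a^T)

print(np.allclose(block_from_grad, block_from_kron))
```

Both constructions give the same $RC \times RC$ matrix, so a Fisher block never requires per-weight gradients explicitly: activations and output gradients are enough.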

2.3 Pruning as constrained optimization

Optimal brain surgery relies on removing and adapting weights such that the loss is least negatively affected, thus it behooves us to write the problem as a constrained optimization problem. From the Gaussian approximation discussed in section 2.1 obtained by quadratically expanding the log likelihood loss $-\log p \approx \frac{1}{2} \bm{\theta}^T \bm{F} \bm{\theta}$, the optimal update $\Delta\bm{\theta} = \hat{\bm{\theta}} - \bm{\theta}$ (and thus also $\hat{\bm{\theta}} = \bm{\theta} + \Delta\bm{\theta}$) becomes the following equality constrained quadratic optimization problem (Hassibi & Stork, 1992):

$$\begin{aligned}
\operatorname*{arg\,min}_{\Delta\bm{\theta}} \quad & \frac{1}{2} \Delta\bm{\theta}^T \bm{F} \Delta\bm{\theta} \\
\text{s.t.} \quad & \bm{e}_k^T \Delta\bm{\theta} + \bm{e}_k^T \bm{\theta} = 0, \quad \forall k \in \mathcal{K}
\end{aligned} \tag{6}$$

where $\bm{F}$ is positive semi-definite and $\mathcal{K}$ is the set of $K$ indices that are pruned (i.e., set to zero).

General solution

We denote $\bm{E}_K = \begin{bmatrix} \bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_K \end{bmatrix}^T \in \{0, 1\}^{K \times P}$ as a matrix of which the row vectors are canonical basis vectors $\bm{e}_k \in \mathbb{R}^P$ that select the elements to be pruned. One of the most standard approaches to solve eq. 6 is using Lagrange multipliers, which results in a general closed-form solution for the expected increase in loss $\mathcal{L}$ and optimal weight update $\Delta\bm{\theta}$:

\begin{align}
\mathcal{L} &= \frac{1}{2} (\bm{E}_K \bm{\theta})^T \left( \bm{E}_K \bm{F}^{-1} \bm{E}_K^T \right)^{-1} \bm{E}_K \bm{\theta} \tag{7} \\
\Delta\bm{\theta} &= -\bm{F}^{-1} \bm{E}_K^T \left( \bm{E}_K \bm{F}^{-1} \bm{E}_K^T \right)^{-1} \bm{E}_K \bm{\theta} \tag{8}
\end{align}

which we use to derive unstructured, semi-structured, and structured pruning costs and weight updates for modern Fisher approximations (see sections A.2, A.3 and A.4). The same general form of eqs. 7 and 8 appears in prior LLM pruning work (Kurtic et al., 2022), but only for much simpler layer-wise pruning and no structured pruning.
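The general solution can be exercised on a small dense problem. The sketch below is an illustration under stated assumptions, not the scaled method of section 3: it forms $\bm{F}^{-1}$ explicitly, which is only feasible for tiny $P$, uses a random symmetric positive definite matrix as a stand-in for the Fisher, and the function name `obs_update` is hypothetical. Indexing with `prune_idx` plays the role of multiplying by $\bm{E}_K$.

```python
import numpy as np

def obs_update(theta, F, prune_idx):
    """Closed-form pruning cost (eq. 7) and weight update (eq. 8) for a dense F."""
    F_inv = np.linalg.inv(F)
    S = F_inv[np.ix_(prune_idx, prune_idx)]       # E_K F^{-1} E_K^T
    lam = np.linalg.solve(S, theta[prune_idx])    # (E_K F^{-1} E_K^T)^{-1} E_K theta
    cost = 0.5 * theta[prune_idx] @ lam           # eq. 7
    delta = -F_inv[:, prune_idx] @ lam            # eq. 8
    return cost, delta

rng = np.random.default_rng(0)
P = 6
B = rng.normal(size=(P, P))
F = B @ B.T + P * np.eye(P)       # symmetric positive definite stand-in for the Fisher
theta = rng.normal(size=P)
prune_idx = np.array([1, 4])      # the index set K

cost, delta = obs_update(theta, F, prune_idx)
theta_hat = theta + delta
```

Two properties follow from the derivation and can be checked on the result: the pruned entries of `theta_hat` are zero, so the constraint of eq. 6 holds, and the reported cost equals the quadratic objective $\frac{1}{2} \Delta\bm{\theta}^T \bm{F} \Delta\bm{\theta}$ evaluated at the update.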

3 LLM Surgeon

This section describes the components of our method, LLM Surgeon, summarised in algorithm 1.

Algorithm 1 LLM Surgeon (structured)
Input: initial weights $\bm{\theta}^0$, target size $\alpha$, and data $\mathcal{D}$
For shot $t$ in [1, 2, …, $T$]
    Compute: approximate curvature $\bm{G}, \bm{A}$ from data $\mathcal{D}$ ▷ section 3.1
    Compute: costs per row/column $\mathcal{L}_r, \mathcal{L}_c$ from $\bm{G}, \bm{A}$ ▷ section 3.2
    Compute: threshold $\tau$ using $\mathcal{L}_r$ and $\mathcal{L}_c$ given target size $\alpha_t$ ▷ section 3.3
    Select: rows and columns to remove $\bm{E}_R$, $\bm{E}_C$ based on $\tau$ ▷ section 3.3
    Compute: weight update $\Delta\bm{\theta}^{t-1}$ based on $\bm{E}_R, \bm{E}_C$ and $\bm{G}, \bm{A}$ ▷ section 3.4
    Update: remaining weights $\bm{\theta}^t \leftarrow \bm{\theta}^{t-1} + \Delta\bm{\theta}^{t-1}$ ▷ section 3.5
    Optionally: $\bm{\theta}^t \leftarrow \text{low-rank update}(\bm{\theta}^t)$ ▷ section 3.6
Output: compressed weights $\hat{\bm{\theta}} = \bm{\theta}^T$
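The multi-shot control flow of the algorithm can be illustrated on a toy problem. The sketch below is not the structured algorithm itself: it prunes individual entries of a flat weight vector, uses a dense curvature matrix instead of the factors $\bm{G}, \bm{A}$, assumes a linear schedule for the per-shot target $\alpha_t$, and omits the optional low-rank update. All function and parameter names are hypothetical.

```python
import numpy as np

def prune_shot(theta, F, active, n_remove):
    """One shot: select the n_remove cheapest active weights, then update jointly."""
    F_inv = np.linalg.inv(F[np.ix_(active, active)])   # curvature of remaining weights
    w = theta[active]
    costs = 0.5 * w**2 / np.diag(F_inv)                # eq. 7 for a single pruned index
    local = np.sort(np.argsort(costs)[:n_remove])      # positions within `active`
    lam = np.linalg.solve(F_inv[np.ix_(local, local)], w[local])
    w_new = w - F_inv[:, local] @ lam                  # joint update, eq. 8
    w_new[local] = 0.0                                 # clear round-off on pruned entries
    theta = theta.copy()
    theta[active] = w_new
    return theta, np.delete(active, local)

def multi_shot_prune(theta0, curvature_fn, target_sparsity, num_shots):
    """Prune towards target_sparsity in num_shots steps, re-estimating curvature each shot."""
    theta, active = theta0.copy(), np.arange(theta0.size)
    for t in range(1, num_shots + 1):
        alpha_t = target_sparsity * t / num_shots      # assumed linear schedule
        already_pruned = theta0.size - active.size
        n_remove = int(round(alpha_t * theta0.size)) - already_pruned
        if n_remove <= 0:
            continue
        theta, active = prune_shot(theta, curvature_fn(theta), active, n_remove)
    return theta

rng = np.random.default_rng(0)
B = rng.normal(size=(8, 8))
F = B @ B.T + 8 * np.eye(8)                            # fixed curvature of a toy quadratic loss
theta0 = rng.normal(size=8)
theta_hat = multi_shot_prune(theta0, lambda th: F, target_sparsity=0.5, num_shots=2)
```

Because costs are ranked across the whole vector rather than per block, the selection step acts as a global threshold. In a real model `curvature_fn` would re-estimate curvature from data at the current weights instead of returning a fixed matrix.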

3.1 Estimating loss landscape curvature

Accurate pruning relies on approximating the local curvature accurately while overcoming the memory cost associated with storing the true curvature. Specifically, even with the block-wise approximation of eq. 5, $\bm{F} \in \mathbb{R}^{RC \times RC}$ requires summing $N$ large $RC \times RC$ matrices, too large to practically fit in memory. Instead, we adapt the KFAC approximation (Martens & Grosse, 2015) that assumes independence of activations and derivatives, approximating an expectation of Kronecker products as a Kronecker product of two expectations $\mathbb{E}[\bm{g}_{l,n} \bm{g}_{l,n}^T \otimes \bm{a}_{l,n} \bm{a}_{l,n}^T] \approx \mathbb{E}[\bm{g}_{l,n} \bm{g}_{l,n}^T] \otimes \mathbb{E}[\bm{a}_{l,n} \bm{a}_{l,n}^T]$, allowing layer-wise Fisher blocks to be approximated as $\bm{F}_l \approx \widetilde{\bm{F}}_l$, where

$$\widetilde{\bm{F}}_l = \bm{G}_l \otimes \bm{A}_l, \quad \text{with } \bm{G}_l = \frac{1}{\sqrt{N}} \sum_{n=1}^{N} \bm{g}_{l,n}\bm{g}_{l,n}^T \text{ and } \bm{A}_l = \frac{1}{\sqrt{N}} \sum_{n=1}^{N} \bm{a}_{l,n}\bm{a}_{l,n}^T \tag{9}$$

constructed from activations $\bm{a}_{l,n} \in \mathbb{R}^{C}$ from forward passes and gradients $\bm{g}_{l,n} \in \mathbb{R}^{R}$ from backward passes (Eschenhagen et al., 2024). The approximation originates from the optimization literature, but has recently gained popularity for other problems that require curvature approximations (Immer et al., 2022; van der Ouderaa et al., 2023), including structured pruning in Wang et al. (2019).

An additional advantage of approximating Fisher blocks as Kronecker products is that the inverse becomes particularly easy to compute: $\widetilde{\bm{F}}^{-1} = \bm{G}^{-1} \otimes \bm{A}^{-1}$, which only requires inverting the factors. This allows us to never explicitly construct the large $RC \times RC$ matrices $\widetilde{\bm{F}}$ and $\widetilde{\bm{F}}^{-1}$ in memory, and instead work directly with the much smaller matrices $\bm{G} \in \mathbb{R}^{R \times R}$ and $\bm{A} \in \mathbb{R}^{C \times C}$.
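As a rough illustration of why the factored form is convenient, the following NumPy sketch builds the two factors of eq. 9 from random stand-in data and applies the inverse Fisher approximation to a weight matrix using only the small factors. The sizes, the random inputs, and the row-major flattening convention for $\bm{\theta}$ are assumptions made for this example, not details taken from the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
N, R, C = 64, 3, 4  # hypothetical sizes: samples, rows, columns

g = rng.standard_normal((N, R))  # stand-ins for backward-pass gradients g_{l,n}
a = rng.standard_normal((N, C))  # stand-ins for forward-pass activations a_{l,n}

# Kronecker factors as in eq. 9 (each factor carries a 1/sqrt(N) scaling).
G = g.T @ g / np.sqrt(N)  # shape (R, R)
A = a.T @ a / np.sqrt(N)  # shape (C, C)

# Only the small factors are inverted.
G_inv = np.linalg.inv(G)
A_inv = np.linalg.inv(A)

# Apply the inverse Fisher to a weight matrix without forming an RC x RC matrix.
# Assumption: theta is the row-major flattening of W, so that
# kron(G^-1, A^-1) @ theta corresponds to G^-1 @ W @ A^-1 (A is symmetric).
W = rng.standard_normal((R, C))
fast = G_inv @ W @ A_inv

# Dense reference, feasible only at toy scale, used here as a sanity check.
F_tilde = np.kron(G, A)
dense = np.linalg.solve(F_tilde, W.reshape(-1)).reshape(R, C)
max_abs_diff = np.abs(fast - dense).max()
```

On this toy problem the two results agree to numerical precision, while the factored path only ever stores $R \times R$ and $C \times C$ matrices.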

3.2 Computing costs in final loss

The number of possible combinations in which weights can be removed grows (supra-)exponentially in parameter count, making it infeasible to estimate a separate cost $\mathcal{L}$ for each such removal. A common strategy, therefore, is to treat weights independently when computing removal costs $\mathcal{L}$. We also follow this strategy, but note that this does not necessarily imply that we have to make the same strong independence assumption for the weight updates $\Delta\bm{\theta}$ after selecting weights to be removed. Unlike most prior work, we present correlated weight updates in section 3.4 that take into account off-diagonal elements of the Fisher approximation.

For semi-structured and unstructured pruning we use independent costs for individual weight elements $k \in [1, RC]$, and for structured pruning we use independent costs for all rows $r \in [1, R]$ and columns $c \in [1, C]$. We find that we can derive the appropriate costs from the general cost formula eq. 7 by letting $\bm{E} = \bm{e}_k \in \mathbb{R}^{RC}$, where the single one-hot element at index $k$ of canonical basis vector $\bm{e}_k$ selects the weight to remove. For structured pruning, we similarly select rows $r$ and columns $c$ by setting $\bm{E} = \bm{e}_r^T \otimes \bm{I} \in \mathbb{R}^{C \times RC}$ or $\bm{E} = \bm{I} \otimes \bm{e}_c^T \in \mathbb{R}^{R \times RC}$ with $\bm{e}_r \in \mathbb{R}^{R}$, $\bm{e}_c \in \mathbb{R}^{C}$. Plugging into eq. 7, we find:

$$\mathcal{L}_k = \frac{1}{2} \frac{(\bm{\theta}_k)^2}{[\bm{G}^{-1} \otimes \bm{A}^{-1}]_{kk}}, \quad \mathcal{L}_r = \frac{1}{2} \frac{\bm{\theta}_r^T \bm{A} \bm{\theta}_r}{[\bm{G}^{-1}]_{rr}}, \quad \mathcal{L}_c = \frac{1}{2} \frac{\bm{\theta}_c^T \bm{G} \bm{\theta}_c}{[\bm{A}^{-1}]_{cc}} \tag{10}$$

Full derivations can be found in sections A.2 and A.3. The costs for single elements $\mathcal{L}_k$ are equivalent to those found in optimal brain surgeon (Hassibi & Stork, 1992), and $\mathcal{L}_r$ and $\mathcal{L}_c$ closely resemble the structured brain surgeon of Wang et al. (2019), but in our case derived for matrix rows and columns (see section A.3). Given curvature estimates, the costs for removing each weight, or each row and column, can be computed in parallel. In addition, we derive costs for the more general sum of Kronecker factors approximation $\widetilde{\bm{F}} \approx \bm{G}_1 \otimes \bm{A}_1 + \bm{G}_2 \otimes \bm{A}_2$ in appendix I through an eigendecomposition.
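The three costs of eq. 10 can be evaluated for all weights, rows and columns at once from the factors alone. The sketch below is one possible NumPy formulation; the function name `pruning_costs` and the row-major layout of the weight matrix are assumptions for illustration, and the paper's actual implementation may differ.

```python
import numpy as np

def pruning_costs(W, G, A):
    """Costs of eq. 10 for all weights, rows and columns of W (shape R x C)."""
    g_inv_diag = np.diag(np.linalg.inv(G))  # [G^-1]_rr for every row r
    a_inv_diag = np.diag(np.linalg.inv(A))  # [A^-1]_cc for every column c
    # The diagonal of kron(G^-1, A^-1), viewed as R x C, is an outer product.
    diag_f_inv = np.outer(g_inv_diag, a_inv_diag)
    cost_k = 0.5 * W**2 / diag_f_inv
    # theta_r^T A theta_r for every row r at once.
    cost_r = 0.5 * np.einsum("rc,cd,rd->r", W, A, W) / g_inv_diag
    # theta_c^T G theta_c for every column c at once.
    cost_c = 0.5 * np.einsum("rc,rs,sc->c", W, G, W) / a_inv_diag
    return cost_k, cost_r, cost_c
```

No $RC \times RC$ matrix is formed: the element-wise costs only need the diagonals of the two inverse factors.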

3.3 Dynamic weight allocation with global threshold

Figure 3: General framework for structured, semi-structured and unstructured compression.

Unlike prior works that compress layer-by-layer (Frantar & Alistarh, 2023), we use a global threshold $\tau$ enabling a dynamic allocation of sparsity levels across layers, pruning most where it hurts the least. Our method can compress a model to a specifically chosen target size $\alpha$, defined as the fraction of weights that should remain, i.e. stay non-zero after compression. In structured, semi-structured, and unstructured pruning alike (fig. 3), we select for removal the weights that inflict the least possible costs $\mathcal{L}$, computed according to section 3.2, until the target size $\alpha$ is reached. For unstructured pruning, this is as simple as sorting the costs $\mathcal{L}_k$ of all weights in the network and setting a global threshold $\tau$ such that a fraction $1{-}\alpha$ of weights falls within the threshold $\mathcal{L}_k \leq \tau$ and is removed. For M:N semi-structured pruning, we sort the costs within each block of N consecutive weights and select the M weights with lowest cost. In case of a multi-shot schedule (see section 3.5) we also sum the M lowest costs in each block to find a cost per block, sort costs per block across the entire network and, similar to the unstructured case, set a global threshold $\tau$ such that a fraction $1{-}\alpha$ of weights falls within the threshold. Lastly, for structured pruning, we perform a sorting appropriately weighted by the number of elements that make up a row or column and set the global threshold $\tau$ such that a fraction $1{-}\alpha$ of all weights falls within the threshold. Then we remove all rows and columns that fall within the threshold $\mathcal{L}_r, \mathcal{L}_c \leq \tau$.
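For the unstructured case, the global threshold can be sketched in a few lines. The helper `global_prune_masks` below is hypothetical and only covers unstructured pruning; it assumes per-layer cost arrays are already available and that costs are distinct (ties at the threshold would remove slightly more weights than requested).

```python
import numpy as np

def global_prune_masks(layer_costs, alpha):
    """Boolean keep-masks such that roughly a fraction alpha of all weights remain."""
    all_costs = np.concatenate([c.ravel() for c in layer_costs])
    n_remove = int(round((1.0 - alpha) * all_costs.size))
    if n_remove == 0:
        return [np.ones(c.shape, dtype=bool) for c in layer_costs]
    # tau is the n_remove-th smallest cost across the whole network.
    tau = np.partition(all_costs, n_remove - 1)[n_remove - 1]
    # Weights with cost <= tau are removed; the rest are kept.
    return [c > tau for c in layer_costs]

# Toy example with two layers of four weights each and target size alpha = 0.5.
costs = [np.array([[5.0, 9.0], [8.0, 6.0]]), np.array([[0.1, 0.2, 0.3, 10.0]])]
masks = global_prune_masks(costs, alpha=0.5)
kept_per_layer = [int(m.sum()) for m in masks]
```

In this toy example the second layer holds most of the cheap weights, so it is pruned more heavily than the first, which is the dynamic allocation across layers that a single global threshold provides.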

3.4 Correlated weight updates

Like most other pruning methods, we prune multiple weights at once (Frantar & Alistarh, 2023; Wang et al., 2019). To arrive at pruning costs and weight updates for pruning multiple weights, it is common to compute costs and updates for individual weights (or sets of weights) independently and add them together to arrive at a joint pruning cost. In LLM Surgeon, we argue that it is better to consider weight updates jointly instead of independently. After selecting the set of weights for pruning, we can often afford to compute a single correlated weight update associated with the joint removal of multiple weights, instead of naively summing weight updates associated with individual removals. We derive such correlated weight updates below. Note that, for the expected cost computation, we do assume that the row, column or weight costs are independent, as the number of possible combinations of weights to prune grows too large to compute within reasonable time.

Fast unstructured / semi-structured correlated weight updates

Mathematically, we represent pruned weights as $\bm{E}_K = \begin{bmatrix} \bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_K \end{bmatrix}^T \in \mathbb{R}^{K \times RC}$, where $\bm{e}_k \in \mathbb{R}^{RC}$ are one-hot canonical basis vectors selecting the weights for removal. As each element $k$ has a unique associated row $r$ and column $c$ index, we can consequently also use canonical basis vectors for these respective rows $\bm{E}_R \in \mathbb{R}^{K \times R}$ and columns $\bm{E}_C \in \mathbb{R}^{K \times C}$ (i.e., $[\bm{E}_R]_i \otimes [\bm{E}_C]_i = [\bm{E}_K]_i$ is satisfied for all $i$).

We derive unstructured weight updates in section A.2 by considering the eigendecompositions $\bm{G} = \bm{K}_1 \bm{S}_1 \bm{K}_1^T$ and $\bm{A} = \bm{K}_2 \bm{S}_2 \bm{K}_2^T$ of the factors of the Fisher approximation $\bm{F} \approx \bm{G} \otimes \bm{A}$, which from eq. 8 yields:

$$\Delta\bm{W} = \bm{G}^{-1} \Big( \bm{K}_1 \Big( \underbrace{\overline{\bm{K}}_1^T \, \overline{\bm{W}}^{-1} \, \overline{\bm{K}}_2 \oslash \bm{S}}_{K \times K} \Big)^{-1} \bm{K}_2 \Big) \bm{A}^{-1} \tag{11}$$

where $\oslash$ is element-wise division, and for brevity we use bar notation $\overline{\bm{K}}_1 = \bm{E}_R \bm{K}_1$, $\overline{\bm{K}}_2 = \bm{E}_C \bm{K}_2$, $\overline{\bm{\theta}} = \bm{E}_K \bm{\theta}$, and $\bm{S} = \text{diag}(\bm{S}_1)\,\text{diag}(\bm{S}_2)^T \in \mathbb{R}^{R \times C}$, where $\text{diag}(\cdot)$ vectorises matrix diagonals.

Programmatically, we always avoid explicitly representing the large matrices $\widetilde{\bm{F}}$ and $\widetilde{\bm{F}}^{-1}$ in memory, but rather compute relevant quantities from their factors. Likewise, we never represent the sparse matrices $\bm{E}_K$, $\bm{E}_R$ or $\bm{E}_C$ in memory, but instead work with lists of indices of the one-hot elements directly. For example, we can cheaply construct $\overline{\bm{K}}_1 = \bm{E}_R \bm{K}_1 \in \mathbb{R}^{K \times R}$ and $\overline{\bm{K}}_2 = \bm{E}_C \bm{K}_2 \in \mathbb{R}^{K \times C}$ by copying row vectors, and the vector $\overline{\bm{\theta}} = \bm{E}_K \bm{\theta} = \bm{E}_R \bm{W} \bm{E}_C^T \in \mathbb{R}^{K}$ by indexing all pruned weights.
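The index-based shortcut described above can be illustrated as follows. This is a small sketch with made-up sizes and indices; the matrices `K1` and `K2` are random placeholders rather than real eigenvectors, and a row-major flattening $k = rC + c$ is assumed so that the row and column selectors combine into the element selector.

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 3, 4
W = rng.standard_normal((R, C))
K1 = rng.standard_normal((R, R))  # placeholder for the eigenvectors of G
K2 = rng.standard_normal((C, C))  # placeholder for the eigenvectors of A

flat_idx = np.array([1, 6, 11])      # hypothetical flat indices k of pruned weights
rows, cols = np.divmod(flat_idx, C)  # row-major layout: k = r * C + c

# Index-based construction: no one-hot matrix is ever materialised.
K1_bar = K1[rows]          # same as E_R @ K1, shape (K, R)
K2_bar = K2[cols]          # same as E_C @ K2, shape (K, C)
theta_bar = W[rows, cols]  # same as E_K @ theta, shape (K,)

# Dense one-hot selectors, built here only to verify the shortcut at toy scale.
E_K = np.eye(R * C)[flat_idx]
E_R = np.eye(R)[rows]
E_C = np.eye(C)[cols]
```

Fancy indexing copies the selected rows directly, so the cost scales with the number of pruned weights $K$ rather than with the size of the one-hot matrices.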

Maximum number of correlated weights

The main computational bottleneck is the $K \times K$ matrix inverse in eq. 11. To control compression speed, we can split pruned weights into disjoint subsets $K = K_1 \cup K_2 \cup \ldots$, such that the size of each subset $K_i$ does not exceed the set maximum number of correlated weights, $|K_i| \leq m$, and sum the associated independent updates. Using less correlation by setting a lower $m$ trades compression quality for speed.
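A simple way to picture the effect of $m$ is to chunk the list of pruned indices and compare a rough cubic cost model for the matrix inverses. The helper below is hypothetical; in particular, splitting into contiguous chunks is an assumption of this sketch, since the text does not specify how subsets are formed.

```python
import numpy as np

def split_into_subsets(pruned_idx, m):
    """Split pruned weight indices into disjoint subsets of at most m elements."""
    pruned_idx = np.asarray(pruned_idx)
    return [pruned_idx[i:i + m] for i in range(0, len(pruned_idx), m)]

pruned = np.arange(10)  # ten hypothetical pruned weight indices
subsets = split_into_subsets(pruned, m=4)

# Rough cost model: inverting a dense n x n matrix costs on the order of n^3.
joint_cost = len(pruned) ** 3
split_cost = sum(len(s) ** 3 for s in subsets)
```

With these toy numbers the three smaller inverses are several times cheaper than one joint inverse, at the price of ignoring correlations between weights that land in different subsets.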

Fast structured correlated weight updates

Unlike the general case which requires inverting a $K \times K$ matrix for $K$ correlated weights, we find that weight updates with the Kronecker-factored Fisher approximation $\widetilde{\bm{F}} = \bm{G} \otimes \bm{A}$ only require inverting an $R' \times R'$ matrix when removing $R'$ rows or a $C' \times C'$ matrix when removing $C'$ columns. The updates are much cheaper than we would have expected based on the effective number of weights in those rows and columns, which would imply inverting $R'C \times R'C$ or $RC' \times RC'$ matrices. In practice, this leads to a significant speed-up for structured pruning and to weight updates that take into account correlations between rows or columns. When removing the $R'$ rows $r_1, r_2, \ldots, r_{R'}$ or the $C'$ columns $c_1, c_2, \ldots, c_{C'}$, with $1 < R' < R$ and $1 < C' < C$, we denote the one-hot vectors selecting all rows and columns to be removed respectively as $\bm{E}_{R'} = \begin{bmatrix} \bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_{R'} \end{bmatrix}^T \in \mathbb{R}^{R' \times R}$ and $\bm{E}_{C'} = \begin{bmatrix} \bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_{C'} \end{bmatrix}^T \in \mathbb{R}^{C' \times C}$. We find the weight updates associated with removing the $R'$ rows or the $C'$ columns by setting $\bm{E}_K = \bm{E}_{R'} \otimes \bm{I}$ or $\bm{E}_K = \bm{I} \otimes \bm{E}_{C'}$, respectively:

$$\begin{aligned} \text{remove multiple $R'$ rows: } \quad & \Delta\bm{W} = -\bm{G}^{-1} \bm{E}_{R'}^T (\bm{E}_{R'} \bm{G}^{-1} \bm{E}_{R'}^T)^{-1} \overline{\bm{W}} \\ \text{remove multiple $C'$ columns: } \quad & \Delta\bm{W} = -\overline{\bm{W}} (\bm{E}_{C'} \bm{A}^{-1} \bm{E}_{C'}^T)^{-1} (\bm{E}_{C'} \bm{A}^{-1}) \end{aligned} \tag{12}$$

where $\overline{\bm{W}}$ contains the rows $\bm{E}_{R'}\bm{W}$ or the columns $\bm{W}\bm{E}_{C'}^T$ selected for removal.

From here, it is clear that the special case of removing a single row $r$ or column $c$ under the Kronecker approximation involves inverting a $1 \times 1$ matrix, and thus only requires scalar division:

$$\text{remove single row $r$: } \Delta\bm{\theta} = -\frac{\bm{G}^{-1}\bm{e}_r \otimes \bm{\theta}_r}{[\bm{G}^{-1}]_{rr}}, \quad \text{or single column $c$: } \Delta\bm{\theta} = -\frac{\bm{\theta}_c \otimes \bm{A}^{-1}\bm{e}_c}{[\bm{A}^{-1}]_{cc}} \tag{13}$$

in accordance with the independent structured updates in Wang et al. (2019) for convolutional filters. We have thus extended existing structured weight updates to rows and columns, and derived update rules that also consider correlations between structured groups (in our case, the rows and columns).
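As a hedged illustration of eq. 13, the sketch below applies the single-row and single-column updates in matrix form with NumPy. It assumes that $\bm{\theta}$ is the row-major flattening of $\bm{W}\in\mathbb{R}^{R\times C}$, so that the Kronecker products in eq. 13 reshape to outer products. The function names `remove_row` and `remove_column`, the helper `random_spd`, and the random test matrices are illustrative assumptions, not part of the original method description.

```python
import numpy as np

def remove_row(W, G_inv, r):
    """Weight matrix after removing row r, given the inverse of G (R x R)."""
    # Outer product of column r of G^-1 with row r of W, scaled by a scalar.
    delta = -np.outer(G_inv[:, r], W[r, :]) / G_inv[r, r]
    return W + delta

def remove_column(W, A_inv, c):
    """Weight matrix after removing column c, given the inverse of A (C x C)."""
    delta = -np.outer(W[:, c], A_inv[:, c]) / A_inv[c, c]
    return W + delta

def random_spd(n, rng):
    """Hypothetical stand-in for a curvature factor: a random SPD matrix."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

rng = np.random.default_rng(0)
R, C = 4, 3
W = rng.standard_normal((R, C))
G = random_spd(R, rng)
A = random_spd(C, rng)

W_row = remove_row(W, np.linalg.inv(G), r=1)     # row 1 becomes zero
W_col = remove_column(W, np.linalg.inv(A), c=2)  # column 2 becomes zero
```

In this sketch the only division is by the scalar $[\bm{G}^{-1}]_{rr}$ or $[\bm{A}^{-1}]_{cc}$, while the remaining rows (or columns) receive a compensating update.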

3.5 Multi shot pruning schedule

To improve the performance-to-sparsity ratio, we propose pruning in multiple shots. We theoretically justify this multi-shot approach by noting that the surrogate loss landscape $q$ relies on a Taylor expansion (eq. 3) that only holds locally and thus becomes unreliable for larger jumps $\Delta\bm{\theta}$ in parameter space. We mitigate this by pruning in multiple $T>1$ shots, $t\in[1,2,\ldots,T]$, each resulting in a smaller weight update $\Delta\bm{\theta}$ after which the curvature of the loss surface can be re-estimated. When pruning to target size $\alpha$, i.e. removing $1-\alpha$ of total weights, we choose a schedule $\alpha_t$ starting at $\alpha_0=1$ and ending at $\alpha_T=\alpha$, such that after $T$ shots, exactly an $\alpha$ fraction of the total weights remains. Empirically, we find that a linear schedule for $\alpha_t$, as formulated in section 4, monotonically improves pruning performance with more shots, and that higher sparsity levels typically require more shots (see section F.1). Multi-shot pruning allows one to spend (linearly in $T$) more computation to improve the final compression performance.

3.6 Interleaved low-rank first-order corrections

We propose optional interleaved low-rank first-order corrections to further improve compression performance. So far, we assumed parameters are in a local optimum when finding a closed-form solution to the constrained quadratic problem. In practice, however, this assumption likely does not hold since (i) the neural network may not be optimised to the minimum, (ii) a different loss may be used for compression than used for training, or (iii) we prune in multiple shots (section 3.5), inevitably causing weights to diverge from the optimum. To mitigate this, we consider first-order corrections by interleaving pruning shots with low-rank adaptations of weights $\bm{W}_l+\bm{U}\bm{V}$ (LoRA; Hu et al., 2021), commonly used in LLM finetuning. We always absorb updates after each shot, so that the next loss estimate $q$ is closer to the optimum and underlying assumptions are likely to hold more closely. By absorbing LoRA updates between shots, the sum of low-rank updates can have a higher rank than individual updates.
That is, we have $\text{rank}(\bm{U}^{1}\bm{V}^{1}+\bm{U}^{2}\bm{V}^{2}+\ldots+\bm{U}^{T}\bm{V}^{T})\geq\text{rank}(\bm{U}^{t}\bm{V}^{t})$ for the updates $\bm{U}^{t}\bm{V}^{t}$ at any shot $t$, with equality only arising if updates lie exactly in the same subspace, which is unlikely to ever occur in practice. This insight could also be used during regular LoRA finetuning and may therefore be useful outside the context of model compression to allow more expressive low-rank model adaptation, at negligible cost.
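The rank argument can be checked numerically. The following minimal sketch, assuming random Gaussian factors as stand-ins for trained LoRA updates (the sizes `d`, `r` and `T` are hypothetical), absorbs one rank-`r` update per shot and compares the rank of the accumulated sum with the rank of each individual update.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, T = 16, 2, 3  # hypothetical layer width, LoRA rank, number of shots

absorbed = np.zeros((d, d))
individual_ranks = []
for t in range(T):
    # Random factors stand in for the LoRA update learned at shot t.
    U = rng.standard_normal((d, r))
    V = rng.standard_normal((r, d))
    individual_ranks.append(int(np.linalg.matrix_rank(U @ V)))
    absorbed += U @ V  # absorb the low-rank update into the weights

total_rank = int(np.linalg.matrix_rank(absorbed))
print(individual_ranks, total_rank)
```

For generic factors the accumulated rank grows with the number of shots, up to at most `T * r`, while each individual update stays at rank `r`.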

4 Results

Table 1: Structured compression of large language models on wikitext-2 data. All values are test performance (PPL).

| Method | Target size | OPT (125m) | OPT (1.3b) | OPT (2.7b) | OPT (6.7b) | Llama-v2 (7b) |
|---|---|---|---|---|---|---|
| Baseline | 100% | 27.65 | 14.62 | 12.47 | 10.86 | 5.12 |
| Magnitude, $\bm{I}\otimes\bm{I}$ | 90% | 767.2 | 894.4 | 1229 | 3464 | 36746 |
| | 80% | 4685 | (1278) | 2788 | 16747 | 347960 |
| | 70% | 17970 | (3098) | 9255 | 17312 | 41373 |
| L-OBD, $\text{diag}(\bm{I}\otimes\bm{A})$, multi shot | 90% | 33.3 | 20.76 | 17.69 | 27.20 | 14259 |
| | 80% | 94.14 | 1392 | 3236 | 7570 | 15630 |
| | 70% | 545.6 | 2147 | 7233 | 7628 | 21386 |
| K-OBD, $\text{diag}(\bm{G}\otimes\bm{A})$, multi shot | 90% | 27.97 | 14.68 | 11.96 | 10.53 | 5.48 |
| | 80% | 29.89 | 15.63 | 12.47 | 11.28 | 9.14 |
| | 70% | 36.54 | 18.29 | 14.53 | 13.03 | 15.43 |
| | 60% | 47.54 | 24.65 | 18.09 | 16.21 | 28.03 |
| | 50% | 75.95 | 37.68 | 26.68 | 25.54 | 46.64 |
| LLM Surgeon (ours), $\bm{G}\otimes\bm{A}$, within row/col cor. $\Delta$ | 90% | 28.29 | 14.73 | 12.00 | 10.82 | 5.43 |
| | 80% | 29.37 | 15.27 | 12.37 | 11.22 | 7.29 |
| | 70% | 32.46 | 16.60 | 13.16 | 11.83 | 10.85 |
| | 60% | 39.82 | 19.40 | 14.79 | 12.94 | 16.67 |
| | 50% | 51.48 | 23.81 | 18.01 | 15.38 | 25.62 |
| LLM Surgeon (ours), $\bm{G}\otimes\bm{A}$, full cor. $\Delta$ | 90% | 28.01 | 14.70 | 12.02 | 10.77 | 5.25 |
| | 80% | 28.73 | 15.12 | 12.27 | 11.02 | 6.18 |
| | 70% | 31.82 | 16.24 | 12.92 | 11.64 | 7.83 |
| | 60% | 38.47 | 18.45 | 14.23 | 12.58 | 10.39 |
| | 50% | 49.78 | 22.95 | 17.15 | 14.90 | 15.38 |

We compare compression performance of LLM Surgeon on language modeling tasks on OPT (Zhang et al., 2022) and Llama-v2 (Touvron et al., 2023) model families, using data from the wikitext-2 dataset (section B.2). For compression, we use 128 sequences with a sequence length of 2048 tokens from the training data set and evaluate test perplexity (PPL) on the standard test split. In our experiments, we use a linear sparsity schedule $\alpha_t=1-t\left(\frac{1-\alpha}{T}\right)$ at each shot $t$ before reaching the final sparsity $\alpha$. We use 40 shots at $\alpha=0.5$ sparsity and report intermediate compression rates, effectively using $T=8$ shots for $\alpha=0.9$, $T=16$ for $\alpha=0.8$, $T=24$ for $\alpha=0.7$, and $T=32$ for $\alpha=0.6$. We compare against magnitude pruning, L-OBD, SparseGPT and K-OBD baselines. The K-OBD and LLM Surgeon use the multi-shot procedure of section 3.5 using $T=40$ shots for structured pruning and $T=5$ shots for semi-structured and unstructured pruning. Further details are found in appendix B.
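The linear schedule and the shot counts quoted above can be reproduced with a few lines of plain Python. This is a minimal sketch; the function name `linear_schedule` is an illustrative assumption and not taken from the authors' code.

```python
def linear_schedule(alpha, T):
    """Return [alpha_0, ..., alpha_T] with alpha_t = 1 - t * (1 - alpha) / T."""
    return [1.0 - t * (1.0 - alpha) / T for t in range(T + 1)]

schedule = linear_schedule(alpha=0.5, T=40)

# Shots at which the intermediate target sizes reported in the tables are reached.
for target in (0.9, 0.8, 0.7, 0.6, 0.5):
    shot = next(t for t, a in enumerate(schedule) if abs(a - target) < 1e-9)
    print(f"target size {target:.0%} reached after {shot} shots")
```

With 40 shots down to a target size of 0.5, each shot removes 1.25% of the weights, so the intermediate target sizes are reached after 8, 16, 24 and 32 shots.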

4.1 Structured Compression

Structured compression of rows and columns enables direct savings in memory and compute through a straight reduction of matrix dimensions in the model. For LLM Surgeon, we consider in section 3.4 weight updates with different levels of correlations: limited to correlations within rows and columns, and correlations both within and between rows and columns. We further compare against magnitude pruning, which only uses weight magnitudes, L-OBD, which only uses activations, and K-OBD, which also uses Kronecker-factored curvature but assumes full independence and thus only prunes without updating remaining weights. We report results in table 1, and observe that more correlations result in better performance, with the largest improvements for the Llama-v2 model family.

While a 50% structured compression is not better than a smaller model of similar size, LLM Surgeon allows us to reduce model size by up to 30% with minimal loss, without training a smaller model from scratch (fig. 1). In our structured compression experiments, our proposed LLM Surgeon method outperforms all baselines and achieves the best performance for each compression target size.

4.2 Interleaved low-rank updates

Table 2: Structured compression of OPT-125m on wikitext-2 using interleaved LoRA updates.

| Method | Target size | without LoRA | with LoRA |
|---|---|---|---|
| Pretrained | 100% | 27.65 | 23.35 |
| LLM Surgeon (ours), $\bm{G}\otimes\bm{A}$, full cor. $\Delta$ | 90% | 28.01 | 24.16 |
| | 80% | 28.73 | 25.25 |
| | 70% | 31.82 | 28.86 |
| | 60% | 38.47 | 31.26 |
| | 50% | 49.78 | 36.50 |

Additionally, we assess compression performance in conjunction with the proposed first-order corrections using the interleaved low-rank adaptation described in section 3.6. We find that LoRA improves compression performance in the smallest 125m model, but not in larger models. We hypothesise that larger models are more prone to overfitting on the relatively few batches of wikitext-2 data used to compress the model. Nevertheless, we conclude that interleaved LoRA can be useful in some cases, and recommend first using the proposed method without interleaved updates and, if enough data is available for compression, optionally using it if it improves performance.

4.3 Semi-structured Compression

For 2:4 semi-structured pruning, we compare LLM Surgeon with magnitude pruning, which only uses weight magnitudes; single-shot L-OBD, which only uses activations; single-shot K-OBD, which also uses Kronecker-factored curvature but assumes full independence and thus only prunes without updating remaining weights; and the recent state-of-the-art SparseGPT (Frantar & Alistarh, 2023). We report test performance after 50% (2:4) semi-structured compression on wikitext-2 data in table 3. We empirically find that considering more weight correlations results in improved final performance after compression. Our proposed LLM Surgeon is competitive with prior work, outperforming all baselines in terms of test set perplexity (PPL).
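To make the 2:4 constraint concrete, the sketch below builds a mask for the simplest baseline, magnitude pruning: in every group of four consecutive weights, the two with the largest magnitude are kept. Grouping along the last axis and the function name `magnitude_2_4_mask` are assumptions made for illustration; the curvature-based methods compared here choose which two weights to keep using their own cost estimates instead of raw magnitudes.

```python
import numpy as np

def magnitude_2_4_mask(W):
    """Boolean keep-mask with exactly 2 kept weights per group of 4 along the last axis."""
    R, C = W.shape
    assert C % 4 == 0, "number of columns must be divisible by 4"
    groups = np.abs(W).reshape(R, C // 4, 4)
    # Indices of the two largest magnitudes within each group of four.
    keep = np.argsort(groups, axis=-1)[..., 2:]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return mask.reshape(R, C)

W = np.array([[0.1, -2.0, 0.3, 1.5, -0.2, 0.05, 0.9, -1.1]])
mask = magnitude_2_4_mask(W)
W_pruned = W * mask  # 50% of the weights are set to zero
```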

Table 3: Semi-structured 2:4 compression for large language models on wikitext-2 data. All values are test performance (PPL).

| Method | $\bm{F}\approx$ | Target size | OPT (125m) | OPT (1.3b) | OPT (2.7b) | OPT (6.7b) |
|---|---|---|---|---|---|---|
| Baseline | | 100% | 27.65 | 14.62 | 12.47 | 10.86 |
| Magnitude | $\bm{I}\otimes\bm{I}$ | 50% | 342.04 | 379.57 | 1106.01 | 187.29 |
| L-OBD | $\text{diag}(\bm{I}\otimes\bm{A})$ | 50% | 87.26 | 44.92 | 41.40 | 27.36 |
| K-OBD | $\text{diag}(\bm{G}\otimes\bm{A})$ | 50% | 68.74 | 27.22 | 20.23 | 15.55 |
| SparseGPT | $\bm{I}\otimes\bm{A}$ | 50% | 45.51 | 29.44 | 14.92 | 13.01 |
| LLM Surgeon (ours) | $\bm{G}\otimes\bm{A}$ | 50% | 44.64 | 25.10 | 14.64 | 12.10 |

4.4 Unstructured Compression

For unstructured pruning, we repeat the same experiments as in the structured pruning case described in section 4.1. In table 4, we report final test performance in terms of perplexity (PPL) on wikitext-2 after compressing LLMs of different sizes from the OPT and Llama-v2 families. Overall, we find that methods with more accurate approximations of the curvature landscape and that account for more correlations perform better. The proposed LLM Surgeon outperforms all baselines, reaching the highest test performance across target sizes.

Table 4: Unstructured compression of large language models on wikitext-2 data. All values are test performance (PPL).

| Method | Target size | OPT (125m) | OPT (1.3b) | OPT (2.7b) | OPT (6.7b) | Llama-v2 (7b) |
|---|---|---|---|---|---|---|
| Baseline | 100% | 27.65 | 14.62 | 12.47 | 10.86 | 5.12 |
| Magnitude, $\bm{I}\otimes\bm{I}$ | 90% | 27.62 | 14.69 | 12.60 | 10.88 | 5.18 |
| | 80% | 28.53 | 15.68 | 13.18 | 11.26 | 5.37 |
| | 70% | 52.88 | 140.2 | 15.22 | 12.22 | 6.03 |
| L-OBD, $\text{diag}(\bm{I}\otimes\bm{A})$, single shot | 90% | 29.70 | 16.24 | 14.44 | 13.43 | 6.09 |
| | 80% | 32.18 | 21.92 | 23.35 | 39.85 | 116.2 |
| | 70% | 49.08 | 204.7 | 274.8 | 810.4 | 6549 |
| K-OBD, $\bm{G}\otimes\bm{A}$, single shot | 90% | 27.64 | 14.62 | 12.09 | 36.89 | 5.13 |
| | 80% | 27.62 | 14.37 | 130220 | 39928 | 5.19 |
| | 70% | 27.92 | 220.1 | 23097 | 19506 | 5.60 |
| | 60% | 29.24 | 13783 | 10331 | 33896 | 9.20 |
| | 50% | 34.43 | 7311 | 10495 | 91506 | 118.6 |
| SparseGPT, $\bm{I}\otimes\bm{A}$ | 90% | 27.93 | 14.69 | 12.00 | 10.86 | 5.49 |
| | 80% | 28.18 | 15.07 | 12.05 | 10.86 | 5.58 |
| | 70% | 28.93 | 22.77 | 12.17 | 10.89 | 5.71 |
| | 60% | 30.20 | 25.07 | 12.37 | 10.98 | 5.94 |
| | 50% | 33.17 | 26.77 | 12.88 | 11.92 | 6.51 |
| LLM Surgeon (ours), $\bm{G}_1\otimes\bm{A}_1$, full cor. $\Delta$, multi shot | 90% | 27.69 | 14.62 | 12.01 | 10.86 | 5.13 |
| | 80% | 27.83 | 14.66 | 12.14 | 10.87 | 5.20 |
| | 70% | 28.35 | 14.81 | 12.25 | 10.82 | 5.36 |
| | 60% | 28.98 | 14.91 | 12.28 | 10.83 | 5.66 |
| | 50% | 30.30 | 15.47 | 12.68 | 10.97 | 6.08 |

4.5 Learned sparsity structure

The proposed method can dynamically allocate sparsity across layers through global thresholds described in section 3.3. In Fig. 4, we compare total allocated sparsity levels per layer depth and per layer type after compressing a pretrained OPT-125m model. We find that the LLM Surgeon prunes relatively more in the first layer and less in middle layers. Further, we observe that a larger portion of weights is removed in fully-connected blocks than in attention blocks, but deviations are smaller than with other methods. Dynamic allocation allows for most pruning where it hurts least.
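The effect of a single global threshold can be sketched as follows. This is only an illustration under simplifying assumptions: every prunable structure is assumed to carry one scalar removal cost and to contain the same number of weights, and the layer names, cost values and function name `allocate_by_global_threshold` are hypothetical. The exact thresholding rule of section 3.3 may differ.

```python
import numpy as np

def allocate_by_global_threshold(costs, target_size):
    """Return per-layer boolean removal masks from one threshold shared by all layers.

    costs: dict mapping layer name -> 1-D sequence of removal costs, one per structure.
    """
    all_costs = np.concatenate([np.asarray(c, dtype=float) for c in costs.values()])
    n_remove = int(round((1.0 - target_size) * all_costs.size))
    if n_remove == 0:
        return {name: np.zeros(len(c), dtype=bool) for name, c in costs.items()}
    # The n_remove-th smallest cost over all layers acts as the global threshold
    # (ties at the threshold are all removed in this sketch).
    threshold = np.sort(all_costs)[n_remove - 1]
    return {name: np.asarray(c, dtype=float) <= threshold for name, c in costs.items()}

costs = {
    "attention": [0.9, 0.8, 0.7, 0.6],
    "fully_connected": [0.1, 0.2, 0.95, 0.3],
}
masks = allocate_by_global_threshold(costs, target_size=0.5)
sparsity = {name: float(m.mean()) for name, m in masks.items()}
print(sparsity)
```

Because the threshold is shared, layers whose structures are cheap to remove end up sparser than layers whose structures are costly, even though the overall target size is met.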

Figure 4: Sparsity levels obtained with structured pruning on OPT-125m by layer depth and type.

5 Conclusion

In this work, we have introduced the LLM Surgeon algorithm for unstructured, semi-structured and structured compression of neural networks. The work builds upon classic neural network compression approaches originating from the early 1990s that aim to find optimal pruning by expanding the curvature of the loss landscape. The method utilises modern Fisher approximations to scale accurate pruning to the realm of large language models (LLMs) with billions of parameters, while remaining practical in both memory and compute. Unlike most prior work on data-based LLM compression, we not only use weight magnitude and activations from forward passes, but also use gradient information from backward passes to relate weight removal costs to the true final objective. We improve upon prior work through more accurate approximations to the loss landscape curvature and by considering more weight correlations to update remaining weights. Increasing the number of correlations and using multiple shots allows us to trade off additional compute for better accuracy. Lastly, LLM Surgeon gives the first practically usable results for structured pruning of LLMs and achieves state-of-the-art results in unstructured and semi-structured large language model pruning.

References

  • Ashkboos et al. (2024) Saleh Ashkboos, Maximilian L Croci, Marcelo Gennari do Nascimento, Torsten Hoefler, and James Hensman. Slicegpt: Compress large language models by deleting rows and columns. arXiv preprint arXiv:2401.15024, 2024.
  • Bishop & Nasrabadi (2006) Christopher M Bishop and Nasser M Nasrabadi. Pattern recognition and machine learning, volume 4. Springer, 2006.
  • Botev et al. (2017) Aleksandar Botev, Hippolyt Ritter, and David Barber. Practical gauss-newton optimisation for deep learning. In International Conference on Machine Learning, pp. 557–565. PMLR, 2017.
  • Eschenhagen et al. (2024) Runa Eschenhagen, Alexander Immer, Richard Turner, Frank Schneider, and Philipp Hennig. Kronecker-factored approximate curvature for modern neural network architectures. Advances in Neural Information Processing Systems, 36, 2024.
  • Frantar & Alistarh (2022) Elias Frantar and Dan Alistarh. Optimal brain compression: A framework for accurate post-training quantization and pruning. Advances in Neural Information Processing Systems, 35:4475–4488, 2022.
  • Frantar & Alistarh (2023) Elias Frantar and Dan Alistarh. Sparsegpt: Massive language models can be accurately pruned in one-shot. 2023.
  • Golub & Van Loan (2013) Gene H Golub and Charles F Van Loan. Matrix computations. JHU press, 2013.
  • Hassibi & Stork (1992) Babak Hassibi and David Stork. Second order derivatives for network pruning: Optimal brain surgeon. Advances in neural information processing systems, 5, 1992.
  • Hu et al. (2021) Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
  • Hubara et al. (2021) Itay Hubara, Yury Nahshan, Yair Hanani, Ron Banner, and Daniel Soudry. Accurate post training quantization with small calibration sets. In International Conference on Machine Learning, pp. 4466–4475. PMLR, 2021.
  • Immer et al. (2022) Alexander Immer, Tycho van der Ouderaa, Gunnar Rätsch, Vincent Fortuin, and Mark van der Wilk. Invariance learning in deep neural networks with differentiable laplace approximations. Advances in Neural Information Processing Systems, 35:12449–12463, 2022.
  • Koroko et al. (2022) Abdoulaye Koroko, Ani Anciaux-Sedrakian, Ibtihel Ben Gharbia, Valérie Garès, Mounir Haddou, and Quang Huy Tran. Efficient approximations of the fisher matrix in neural networks using kronecker product singular value decomposition. arXiv preprint arXiv:2201.10285, 2022.
  • Kunstner et al. (2019) Frederik Kunstner, Philipp Hennig, and Lukas Balles. Limitations of the empirical fisher approximation for natural gradient descent. Advances in neural information processing systems, 32, 2019.
  • Kurtic et al. (2022) Eldar Kurtic, Daniel Campos, Tuan Nguyen, Elias Frantar, Mark Kurtz, Benjamin Fineran, Michael Goin, and Dan Alistarh. The optimal bert surgeon: Scalable and accurate second-order pruning for large language models. arXiv preprint arXiv:2203.07259, 2022.
  • LeCun et al. (1989) Yann LeCun, John Denker, and Sara Solla. Optimal brain damage. Advances in neural information processing systems, 2, 1989.
  • Louizos et al. (2017) Christos Louizos, Max Welling, and Diederik P Kingma. Learning sparse neural networks through $l_0$ regularization. arXiv preprint arXiv:1712.01312, 2017.
  • MacKay (2003) David JC MacKay. Information theory, inference and learning algorithms. Cambridge university press, 2003.
  • Martens & Grosse (2015) James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. In International conference on machine learning, pp. 2408–2417. PMLR, 2015.
  • Merity et al. (2016) Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.
  • Sun et al. (2023) Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.
  • Touvron et al. (2023) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • van der Ouderaa et al. (2023) Tycho van der Ouderaa, Alexander Immer, and Mark van der Wilk. Learning layer-wise equivariances automatically using gradients. Advances in Neural Information Processing Systems, 36, 2023.
  • Vaswani et al. (2017) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • Wang et al. (2019) Chaoqi Wang, Roger Grosse, Sanja Fidler, and Guodong Zhang. Eigendamage: Structured pruning in the kronecker-factored eigenbasis. In International conference on machine learning, pp. 6566–6575. PMLR, 2019.
  • Wikipedia (2004) Wikipedia. Wikipedia. PediaPress, 2004.
  • Wolf et al. (2019) Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, et al. Huggingface’s transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • Zhang et al. (2022) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
  • Zhou et al. (2021) Aojun Zhou, Yukun Ma, Junnan Zhu, Jianbo Liu, Zhijie Zhang, Kun Yuan, Wenxiu Sun, and Hongsheng Li. Learning n: m fine-grained structured sparse neural networks from scratch. arXiv preprint arXiv:2102.04010, 2021.

Appendix A Derivations for pruning

Given that we use a Gaussian approximation of our loss $p\approx q=\mathcal{N}$ through a quadratic approximation of our log likelihood $-\log p\approx\frac{1}{2}(\bm{\theta}^{*})^{T}\bm{F}\bm{\theta}^{*}$, the optimal compression becomes the solution to the following constrained optimization problem:

\[
\begin{aligned}
\operatorname*{arg\,min}_{\Delta\bm{\theta}^{*}}\ & \frac{1}{2}(\Delta\bm{\theta}^{*})^{T}\bm{F}\,\Delta\bm{\theta}^{*} \\
\text{s.t. } & \bm{e}_{k}^{T}\Delta\bm{\theta}^{*}+\bm{e}_{k}^{T}\bm{\theta}^{*}=0,\ \forall k\in\mathcal{Q}
\end{aligned}
\tag{14}
\]

where $\mathcal{Q}$ is the set of $Q$ indices that are pruned.

A.1 General solution

Following Kurtic et al. (2022), we denote pruned elements as $\bm{E}_{K}=\begin{bmatrix}\bm{e}_{q_{1}}&\bm{e}_{q_{2}}&\ldots\end{bmatrix}^{T}\in[0,1]^{|Q|\times P}$ and use the fact that solving eq. 6 through use of Lagrange multipliers gives the general closed-form solution for cost $\mathcal{L}$ and weight update $\Delta\bm{\theta}$:

\[
\mathcal{L}=\frac{1}{2}(\bm{E}_{K}\bm{\theta}^{*})^{T}\left(\bm{E}_{K}\bm{F}^{-1}\bm{E}_{K}^{T}\right)^{-1}\bm{E}_{K}\bm{\theta}^{*}
\tag{15}
\]
\[
\Delta\bm{\theta}^{*}=-\bm{F}^{-1}\bm{E}_{K}^{T}\left(\bm{E}_{K}\bm{F}^{-1}\bm{E}_{K}^{T}\right)^{-1}\bm{E}_{K}\bm{\theta}^{*}
\tag{16}
\]
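The general solution can be sanity-checked numerically. The sketch below, with a random symmetric positive-definite matrix as a hypothetical stand-in for the Fisher and an illustrative function name `prune_update`, computes the cost and the update for a set of pruned indices. The update carries a minus sign so that the constraint $\bm{E}_K(\bm{\theta}^{*}+\Delta\bm{\theta}^{*})=0$ holds.

```python
import numpy as np

def prune_update(theta, F_inv, pruned):
    """Cost and weight update for zeroing the entries of theta listed in `pruned`."""
    E = np.eye(theta.size)[pruned]       # |Q| x P selection matrix E_K
    S = E @ F_inv @ E.T                  # |Q| x |Q| block of the inverse Fisher
    lam = np.linalg.solve(S, E @ theta)  # (E F^-1 E^T)^-1 E theta
    cost = 0.5 * (E @ theta) @ lam
    delta = -F_inv @ E.T @ lam           # minus sign drives pruned entries to zero
    return cost, delta

rng = np.random.default_rng(0)
P = 6
M = rng.standard_normal((P, P))
F = M @ M.T + P * np.eye(P)              # hypothetical SPD stand-in for the Fisher
theta = rng.standard_normal(P)

cost, delta = prune_update(theta, np.linalg.inv(F), pruned=[1, 4])
```

Only a $|Q|\times|Q|$ system is solved here; the returned cost coincides with the quadratic objective $\frac{1}{2}\Delta\bm{\theta}^{T}\bm{F}\Delta\bm{\theta}$ evaluated at the update.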

A.2 Removing a single element

Optimal brain surgeon (OBS)

To remove a single element with index $k$, we simply set $\bm{E}_{K}=\bm{e}_{k}^{T}$:

\[
\begin{aligned}
\mathcal{L}&=\frac{1}{2}(\bm{E}_{K}\bm{\theta}^{*})^{T}\left(\bm{E}_{K}\bm{F}^{-1}\bm{E}_{K}^{T}\right)^{-1}\bm{E}_{K}\bm{\theta}^{*}\\
&=\frac{1}{2}\bm{\theta}_{k}^{T}\frac{1}{[\bm{F}^{-1}]_{kk}}\bm{\theta}_{k}\\
&=\frac{1}{2}\frac{(\bm{\theta}_{k})^{2}}{[\bm{F}^{-1}]_{kk}}
\end{aligned}
\text{, }
\begin{aligned}
\Delta\bm{\theta}&=-\bm{F}^{-1}\bm{E}_{K}^{T}\left(\bm{E}_{K}\bm{F}^{-1}\bm{E}_{K}^{T}\right)^{-1}\bm{E}_{K}\bm{\theta}^{*}\\
&=-\bm{F}^{-1}\bm{e}_{k}\left(\bm{e}_{k}^{T}\bm{F}^{-1}\bm{e}_{k}\right)^{-1}\bm{e}_{k}^{T}\bm{\theta}^{*}\\
&=-\frac{\bm{\theta}_{k}}{[\bm{F}^{-1}]_{kk}}\bm{F}^{-1}\bm{e}_{k}
\end{aligned}
\tag{17}
\]

which exactly correspond to the loss and updates of optimal brain surgeon (Hassibi & Stork, 1992).

Optimal brain damage (OBD)

We may also consider that elements are independent and the Fisher is diagonal. After noting that this implies that diagonal elements of the inverse Fisher are scalar inverses of elements in the Fisher, $[\bm{F}^{-1}]_{kk}=\frac{1}{[\bm{F}]_{kk}}$, the formulas simplify to:

\[
\mathcal{L}=\frac{1}{2}[\bm{F}]_{kk}(\bm{\theta}_{k})^{2}
\text{, }
\Delta\bm{\theta}=-\bm{\theta}_{k}\bm{e}_{k}
\tag{18}
\]

which exactly corresponds to loss and updates of optimal brain damage (LeCun et al., 1989).
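The simplification can be spelled out in one step. For a diagonal Fisher we have $\bm{F}^{-1}\bm{e}_k = \bm{e}_k / [\bm{F}]_{kk}$ and $[\bm{F}^{-1}]_{kk} = 1/[\bm{F}]_{kk}$, so substituting into the single-weight OBS expressions of eq. 17 and carrying its factor $\frac{1}{2}$ through gives:

$$\begin{aligned}
\mathcal{L} &= \frac{1}{2}\frac{(\bm{\theta}_k)^2}{[\bm{F}^{-1}]_{kk}} = \frac{1}{2}[\bm{F}]_{kk}(\bm{\theta}_k)^2, \\
\Delta\bm{\theta} &= -\frac{\bm{\theta}_k}{[\bm{F}^{-1}]_{kk}}\bm{F}^{-1}\bm{e}_k = -\bm{\theta}_k[\bm{F}]_{kk}\frac{\bm{e}_k}{[\bm{F}]_{kk}} = -\bm{\theta}_k\bm{e}_k
\end{aligned}$$

That is, under the diagonal assumption the update only zeroes weight $k$ and leaves all other weights untouched.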

Vectorised

For implementation purposes, it might be convenient to have a vectorised notation $\mathcal{L}_{\bm{\theta}} \in \mathbb{R}^{RC}$ or $\mathcal{L}_{\bm{W}} \in \mathbb{R}^{R \times C}$ to calculate all expected losses in parallel:

$$\begin{aligned}
\text{For OBD:}\quad \mathcal{L}_{\bm{\theta}} &= \frac{1}{2}\bm{\theta}^* \odot \bm{\theta}^* \odot \text{diag}(\bm{F}), & \mathcal{L}_{\bm{W}} &= \frac{1}{2}\bm{W}^* \odot \bm{W}^* \odot \text{mat}(\text{diag}(\bm{F})) \\
\text{For OBS:}\quad \mathcal{L}_{\bm{\theta}} &= \frac{1}{2}\bm{\theta}^* \odot \bm{\theta}^* \oslash \text{diag}(\bm{F}^{-1}), & \mathcal{L}_{\bm{W}} &= \frac{1}{2}\bm{W}^* \odot \bm{W}^* \oslash \text{mat}(\text{diag}(\bm{F}^{-1}))
\end{aligned}\tag{19}$$
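The vectorised losses map directly onto elementwise array operations. The following is a minimal sketch, assuming NumPy, a dense random SPD stand-in for the Fisher, and a row-major flattening of $\bm{W}$ into $\bm{\theta}$ (an assumed convention). All variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 3, 4
M = rng.standard_normal((R * C, R * C))
F = M @ M.T + R * C * np.eye(R * C)    # assumed SPD stand-in for the Fisher
F_inv = np.linalg.inv(F)
W = rng.standard_normal((R, C))
theta = W.reshape(-1)                   # row-major vec(W), assumed convention

# Expected loss of removing each weight individually, all at once
loss_obd = 0.5 * theta * theta * np.diag(F)        # elementwise product with diag(F)
loss_obs = 0.5 * theta * theta / np.diag(F_inv)    # elementwise division by diag(F^{-1})

# mat(.) reshapes the length-RC diagonal back to the R x C layout of W
loss_obd_W = 0.5 * W * W * np.diag(F).reshape(R, C)
loss_obs_W = 0.5 * W * W / np.diag(F_inv).reshape(R, C)
```

For a positive definite $\bm{F}$, $1/[\bm{F}^{-1}]_{kk} \leq [\bm{F}]_{kk}$, so each OBS loss is at most the corresponding OBD loss.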

A.3 Removing a single row or column

Structured OBS

If we consider the approximation $\bm{F} \approx \bm{G} \otimes \bm{A}$ with known inverse $(\bm{G} \otimes \bm{A})^{-1} = \bm{G}^{-1} \otimes \bm{A}^{-1}$, then to remove a row at index $r \in [0, R]$, we must take into account correlations within elements of that row. That is, we write matrix $\bm{E}_K = (\bm{e}_r^T \otimes \bm{I})$ containing one-hot row-vectors for all elements in row $r$. Plugging into the general solution eq. 7, we find:

$$\begin{aligned}
\mathcal{L} &= \frac{1}{2}(\bm{E}_K\bm{\theta}^*)^T\left(\bm{E}_K\bm{F}^{-1}\bm{E}_K^T\right)^{-1}\bm{E}_K\bm{\theta}^* \\
&= \frac{1}{2}\left((\bm{e}_r^T \otimes \bm{I})\bm{\theta}^*\right)^T\left((\bm{e}_r^T \otimes \bm{I})(\bm{G} \otimes \bm{A})^{-1}(\bm{e}_r^T \otimes \bm{I})^T\right)^{-1}(\bm{e}_r^T \otimes \bm{I})\bm{\theta}^* \\
&= \frac{1}{2}\bm{\theta}_r^T\left(\bm{e}_r^T\bm{G}^{-1}\bm{e}_r \otimes \bm{I}\bm{A}^{-1}\bm{I}\right)^{-1}\bm{\theta}_r \\
&= \frac{1}{2}\bm{\theta}_r^T\left([\bm{G}^{-1}]_{rr} \otimes \bm{A}^{-1}\right)^{-1}\bm{\theta}_r \\
&= \frac{1}{2}\frac{\bm{\theta}_r^T\bm{A}\bm{\theta}_r}{[\bm{G}^{-1}]_{rr}}
\end{aligned}\tag{20}$$

where we write $\bm{\theta}_r = \bm{e}_r^T\bm{W}^* \in \mathbb{R}^C$ for the $r$'th row-vector in $\bm{W}$. Similarly, we obtain the associated weight update:

$$\begin{aligned}
\Delta\bm{\theta} &= -\bm{F}^{-1}\bm{E}_K^T\left(\bm{E}_K\bm{F}^{-1}\bm{E}_K^T\right)^{-1}\bm{E}_K\bm{\theta}^* \\
&= -(\bm{G} \otimes \bm{A})^{-1}(\bm{e}_r^T \otimes \bm{I})^T\left((\bm{e}_r^T \otimes \bm{I})(\bm{G} \otimes \bm{A})^{-1}(\bm{e}_r^T \otimes \bm{I})^T\right)^{-1}(\bm{e}_r^T \otimes \bm{I})\bm{\theta}^* \\
&= -\left(\bm{G}^{-1} \otimes \bm{A}^{-1}\right)(\bm{e}_r \otimes \bm{I})\left(\bm{e}_r^T\bm{G}^{-1}\bm{e}_r \otimes \bm{A}^{-1}\right)^{-1}\bm{\theta}_r \\
&= -\frac{1}{[\bm{G}^{-1}]_{rr}}\left(\bm{G}^{-1}\bm{e}_r \otimes \bm{A}^{-1}\bm{I}\bm{A}\bm{I}\right)\bm{\theta}_r \\
&= -\frac{\bm{G}^{-1}\bm{e}_r \otimes \bm{\theta}_r}{[\bm{G}^{-1}]_{rr}}
\end{aligned}\tag{21}$$

arriving at a similar structured pruning update as derived in (Wang et al., 2019) for convolutional filters. We can equivalently derive the expected loss and update for columns by considering $\bm{E}_K = (\bm{I} \otimes \bm{e}_c^T)$. If we do so, we find the structured updates for a row $r$ or column $c$:

$$\begin{aligned}
\text{Remove row } r\text{:}\quad \mathcal{L} &= \frac{1}{2}\frac{\bm{\theta}_r^T\bm{A}\bm{\theta}_r}{[\bm{G}^{-1}]_{rr}}, & \Delta\bm{\theta} &= -\frac{\bm{G}^{-1}\bm{e}_r \otimes \bm{\theta}_r}{[\bm{G}^{-1}]_{rr}} \\
\text{Remove column } c\text{:}\quad \mathcal{L} &= \frac{1}{2}\frac{\bm{\theta}_c^T\bm{G}\bm{\theta}_c}{[\bm{A}^{-1}]_{cc}}, & \Delta\bm{\theta} &= -\frac{\bm{\theta}_c \otimes \bm{A}^{-1}\bm{e}_c}{[\bm{A}^{-1}]_{cc}}
\end{aligned}\tag{22}$$
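The closed-form row expressions can be checked against the general solution by materialising the full Kronecker product on a small example. The following is a minimal sketch, assuming NumPy, random SPD stand-ins for $\bm{G}$ and $\bm{A}$, and a row-major flattening of $\bm{W}$ (so that $\bm{F} = \bm{G} \otimes \bm{A}$ acts on $\bm{\theta}$ consistently). The helper `random_spd` and all values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 3, 4

def random_spd(n):
    """Random symmetric positive definite matrix (illustrative stand-in)."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

G, A = random_spd(R), random_spd(C)
G_inv = np.linalg.inv(G)
W = rng.standard_normal((R, C))
theta = W.reshape(-1)          # row-major vec(W), assumed convention
r = 1                          # row to remove (0-based in code)

# Closed form for removing row r
theta_r = W[r]
loss = 0.5 * theta_r @ A @ theta_r / G_inv[r, r]
delta_W = -np.outer(G_inv[:, r], theta_r) / G_inv[r, r]   # G^{-1} e_r (x) theta_r as a matrix

# Reference: general solution with the full RC x RC inverse
F_inv = np.kron(G_inv, np.linalg.inv(A))
E = np.kron(np.eye(R)[r:r + 1], np.eye(C))   # (e_r^T kron I), shape C x RC
S = np.linalg.inv(E @ F_inv @ E.T)
loss_ref = 0.5 * (E @ theta) @ S @ (E @ theta)
delta_ref = -F_inv @ E.T @ S @ (E @ theta)
```

The closed form only touches column $r$ of $\bm{G}^{-1}$ and one $C \times C$ product, whereas the reference path builds an $RC \times RC$ matrix and is only feasible for toy sizes.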
Structured OBD

We may also assume that, when removing a row $r$, the individual elements within the row are also independent, which would imply $[\bm{A}]_{ii} = \frac{1}{[\bm{A}^{-1}]_{ii}}$. Similarly, $[\bm{G}]_{ii} = \frac{1}{[\bm{G}^{-1}]_{ii}}$ when removing a column $c$. Consequently, we can simplify to:

$$\begin{aligned}
\text{Remove row } r\text{:}\quad \mathcal{L} &= \frac{1}{2}\bm{G}_{rr}\bm{\theta}_r^T\bm{A}\bm{\theta}_r, & \Delta\bm{\theta} &= -\bm{e}_r\bm{\theta}_r^T \\
\text{Remove column } c\text{:}\quad \mathcal{L} &= \frac{1}{2}\bm{A}_{cc}\bm{\theta}_c^T\bm{G}\bm{\theta}_c, & \Delta\bm{\theta} &= -\bm{\theta}_c\bm{e}_c^T
\end{aligned}\tag{23}$$

which has a similar form to the structured OBD losses and updates derived in (Wang et al., 2019) for convolutional filters. The derivations differ slightly in that we start from the general solution eq. 8, circumventing the need to rederive Lagrange multipliers for each possible structure.
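Because the structured OBD losses involve no matrix inverses, they can be evaluated for every row and column in one pass. The following is a minimal sketch, assuming NumPy and random SPD stand-ins for $\bm{G}$ and $\bm{A}$. The helper `random_spd` and the choice to prune the cheapest row are illustrative assumptions, not the paper's pruning schedule.

```python
import numpy as np

rng = np.random.default_rng(0)
R, C = 3, 4

def random_spd(n):
    """Random symmetric positive definite matrix (illustrative stand-in)."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

G, A = random_spd(R), random_spd(C)
W = rng.standard_normal((R, C))

# One expected loss per row: 0.5 * G_rr * theta_r^T A theta_r
row_losses = 0.5 * np.diag(G) * np.einsum("rc,cd,rd->r", W, A, W)
# One expected loss per column: 0.5 * A_cc * theta_c^T G theta_c
col_losses = 0.5 * np.diag(A) * np.einsum("rc,rs,sc->c", W, G, W)

# The OBD update only zeroes the removed structure; other weights stay fixed.
r = int(np.argmin(row_losses))
W_pruned = W.copy()
W_pruned[r] = 0.0
```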

A.4 Pruning multiple (correlated) rows and columns

Let us consider the removal of $R'$ rows with indices $r_1, r_2, \ldots, r_{R'}$ or $C'$ columns with indices $c_1, c_2, \ldots, c_{C'}$, with $1 < R' < R$ and $1 < C' < C$. We denote the matrices containing one-hot vectors selecting all rows and columns to be removed, respectively, as:

$$\bm{E}_{R'} = \begin{bmatrix}\bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_{R'}\end{bmatrix}^T \in \mathbb{R}^{R' \times R}, \qquad \bm{E}_{C'} = \begin{bmatrix}\bm{e}_1 & \bm{e}_2 & \ldots & \bm{e}_{C'}\end{bmatrix}^T \in \mathbb{R}^{C' \times C} \tag{24}$$

Then, the matrix $\bm{E}_K$ containing one-hot row vectors selecting all elements to be removed can be written as:

$$\begin{aligned}
\text{Multiple rows:}\quad \bm{E}_K &= (\bm{E}_{R'} \otimes \bm{I}_C) \in \mathbb{R}^{Q \times RC}, \quad (\text{with } Q = R'C) \\
\text{Multiple columns:}\quad \bm{E}_K &= (\bm{I}_R \otimes \bm{E}_{C'}) \in \mathbb{R}^{Q \times RC}, \quad (\text{with } Q = RC')
\end{aligned}\tag{25}$$

To simultaneously remove rows and columns, we can stack the matrices with duplicate row vectors removed:

$$\text{Multiple rows and columns:}\quad \bm{E}_K = \begin{bmatrix}\bm{E}_{R'} \otimes \bm{I}_C \\ \bm{I}_R \otimes \bm{E}_{C'}\end{bmatrix} \in \mathbb{R}^{Q \times RC} \text{ with duplicate rows removed} \tag{26}$$

The removal of duplicate rows is required because of the $R'C'$ elements that overlap between the selected rows and columns, after which the total number of rows becomes $Q = R'C + C'R - R'C'$. We used appropriately sized identity matrices $\bm{I}_R \in \mathbb{R}^{R \times R}$ and $\bm{I}_C \in \mathbb{R}^{C \times C}$. For brevity, we write the vector or matrix of pruned weights $\overline{\bm{\theta}} := \bm{E}_K\bm{\theta} \in \mathbb{R}^Q$.
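The construction of the selection matrices and the count $Q$ can be illustrated on a toy example. The following is a minimal sketch, assuming NumPy, 0-based indices, and a row-major flattening of $\bm{W}$. The particular index lists `rows` and `cols` are hypothetical.

```python
import numpy as np

R, C = 4, 5
rows = [0, 2]        # hypothetical rows to remove (R' = 2)
cols = [1, 3, 4]     # hypothetical columns to remove (C' = 3)

E_R = np.eye(R)[rows]                 # R' x R one-hot row selectors
E_C = np.eye(C)[cols]                 # C' x C one-hot column selectors

E_rows = np.kron(E_R, np.eye(C))      # selects R'C elements
E_cols = np.kron(np.eye(R), E_C)      # selects RC' elements

# Stack and drop duplicated one-hot vectors (np.unique also reorders them,
# which only permutes the selected elements).
E_K = np.unique(np.vstack([E_rows, E_cols]), axis=0)
Q = E_K.shape[0]

# Mask of selected entries in the R x C layout of W
mask = E_K.sum(axis=0).reshape(R, C)
```

Each entry of `mask` is 0 or 1 after deduplication, and the selected entries are exactly those lying in a removed row or a removed column.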

First, we derive the removal of $R'$ rows by defining the removal matrix as $\bm{E}_K = \bm{E}_{R'} \otimes \bm{I}$ and defining $\overline{\bm{W}} := \bm{E}_{R'}\bm{W} \in \mathbb{R}^{R' \times C}$. The complete weight update for the removal of multiple rows becomes:

$$\begin{aligned}
\Delta\bm{\theta} &= -\bm{F}^{-1}\bm{E}_K^T\left(\bm{E}_K\bm{F}^{-1}\bm{E}_K^T\right)^{-1}\bm{E}_K\bm{\theta}^* \\
&= -(\bm{G} \otimes \bm{A})^{-1}(\bm{E}_{R'} \otimes \bm{I})^T\left((\bm{E}_{R'} \otimes \bm{I})(\bm{G} \otimes \bm{A})^{-1}(\bm{E}_{R'} \otimes \bm{I})^T\right)^{-1}(\bm{E}_{R'} \otimes \bm{I})\bm{\theta}^* \\
&= -(\bm{G}^{-1}\bm{E}_{R'}^T \otimes \bm{A}^{-1})\left(\bm{E}_{R'}\bm{G}^{-1}\bm{E}_{R'}^T \otimes \bm{A}^{-1}\right)^{-1}\overline{\bm{\theta}}^* \\
&= -(\bm{G}^{-1}\bm{E}_{R'}^T \otimes \bm{A}^{-1})\left((\bm{E}_{R'}\bm{G}^{-1}\bm{E}_{R'}^T)^{-1} \otimes \bm{A}\right)\overline{\bm{\theta}}^*
\end{aligned}$$
Δ𝑾Δ𝑾\displaystyle\Delta{\bm{W}}roman_Δ bold_italic_W =𝑮1𝑬RT((𝑬R𝑮1𝑬RT)1 𝑾𝑨)𝑨1absentsuperscript𝑮1superscriptsubscript𝑬superscript𝑅𝑇superscriptsubscript𝑬superscript𝑅superscript𝑮1superscriptsubscript𝑬superscript𝑅𝑇1 𝑾𝑨superscript𝑨1\displaystyle=-{\bm{G}}^{-1}{\bm{E}}_{R^{\prime}}^{T}\left(({\bm{E}}_{R^{% \prime}}{\bm{G}}^{-1}{\bm{E}}_{R^{\prime}}^{T})^{-1}\hbox{\vbox{\hrule height=% 0.5pt\kern 2.15277pt\hbox{\kern-1.00006pt${\bm{W}}$\kern-1.00006pt}}}{\bm{A}}% \right){\bm{A}}^{-1}= - bold_italic_G start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T end_POSTSUPERSCRIPT ( ( bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT bold_italic_G start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT roman_W bold_italic_A ) bold_italic_A start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT
=𝑮1𝑬RT(𝑬R𝑮1𝑬RT)1 𝑾absentsuperscript𝑮1superscriptsubscript𝑬superscript𝑅𝑇superscriptsubscript𝑬superscript𝑅superscript𝑮1superscriptsubscript𝑬superscript𝑅𝑇1 𝑾\displaystyle=-{\bm{G}}^{-1}{\bm{E}}_{R^{\prime}}^{T}({\bm{E}}_{R^{\prime}}{% \bm{G}}^{-1}{\bm{E}}_{R^{\prime}}^{T})^{-1}\hbox{\vbox{\hrule height=0.5pt% \kern 2.15277pt\hbox{\kern-1.00006pt${\bm{W}}$\kern-1.00006pt}}}= - bold_italic_G start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T end_POSTSUPERSCRIPT ( bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT bold_italic_G start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT bold_italic_E start_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_T end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT roman_W (27)
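The closed form in Eq. 27 can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation: it assumes that $\overline{{\bm{W}}} = {\bm{E}}_{R'}{\bm{W}}$ holds the rows being removed and represents the selection matrix ${\bm{E}}_{R'}$ by a list of row indices. The function name `row_removal_update`, the toy sizes and the random stand-in for ${\bm{G}}$ are hypothetical.

```python
import numpy as np

def row_removal_update(W, G, rows):
    """Return dW = -G^{-1} E^T (E G^{-1} E^T)^{-1} (E W) for the selected rows."""
    G_inv = np.linalg.inv(G)
    W_bar = W[rows, :]                    # E W: the rows to be removed, shape (R', C)
    S = G_inv[np.ix_(rows, rows)]         # E G^{-1} E^T, shape (R', R')
    return -G_inv[:, rows] @ np.linalg.solve(S, W_bar)

# Toy problem (all sizes and values are made up for illustration).
rng = np.random.default_rng(0)
R, C = 5, 4
W = rng.standard_normal((R, C))
M = rng.standard_normal((R, R))
G = M @ M.T + R * np.eye(R)               # symmetric positive definite stand-in for G
rows = [1, 3]

dW = row_removal_update(W, G, rows)
W_pruned = W + dW                         # rows 1 and 3 are now numerically zero
```

Only an $R' \times R'$ system is solved here, and ${\bm{A}}$ never appears, which mirrors the cancellation of ${\bm{A}}$ in the last step of the derivation.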

Similarly, we derive the removal of $C'$ columns by defining the removal matrix as ${\bm{E}}_{K} = {\bm{I}} \otimes {\bm{E}}_{C'}$ and defining $\overline{{\bm{W}}} := {\bm{W}}{\bm{E}}_{C'}^{T} \in \mathbb{R}^{R \times C'}$. The complete weight update for multiple column removal becomes:

\begin{align}
\Delta{\bm{\theta}} &= -{\bm{F}}^{-1}{\bm{E}}_{K}^{T}\left({\bm{E}}_{K}{\bm{F}}^{-1}{\bm{E}}_{K}^{T}\right)^{-1}{\bm{E}}_{K}{\bm{\theta}}^{*} \notag\\
&= -({\bm{G}}\otimes{\bm{A}})^{-1}({\bm{I}}\otimes{\bm{E}}_{C'})^{T}\left(({\bm{I}}\otimes{\bm{E}}_{C'})({\bm{G}}\otimes{\bm{A}})^{-1}({\bm{I}}\otimes{\bm{E}}_{C'})^{T}\right)^{-1}({\bm{I}}\otimes{\bm{E}}_{C'}){\bm{\theta}}^{*} \notag\\
&= -({\bm{G}}^{-1}\otimes{\bm{A}}^{-1}{\bm{E}}_{C'}^{T})\left({\bm{G}}^{-1}\otimes{\bm{E}}_{C'}{\bm{A}}^{-1}{\bm{E}}_{C'}^{T}\right)^{-1}\overline{{\bm{\theta}}^{*}} \notag\\
\Delta{\bm{W}} &= -{\bm{G}}^{-1}{\bm{G}}\,\overline{{\bm{W}}}({\bm{E}}_{C'}{\bm{A}}^{-1}{\bm{E}}_{C'}^{T})^{-1}({\bm{A}}^{-1}{\bm{E}}_{C'}^{T})^{T} \notag\\
&= -\overline{{\bm{W}}}({\bm{E}}_{C'}{\bm{A}}^{-1}{\bm{E}}_{C'}^{T})^{-1}({\bm{A}}^{-1}{\bm{E}}_{C'}^{T})^{T} \tag{28}
\end{align}
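The column update in Eq. 28 can be sketched in the same way. This is a minimal NumPy illustration under stated assumptions, not the authors' code: $\overline{{\bm{W}}}$ is taken to hold the columns being removed, ${\bm{E}}_{C'}$ is represented by a list of column indices, and the last factor is applied as ${\bm{E}}_{C'}{\bm{A}}^{-1}$ (the transpose of ${\bm{A}}^{-1}{\bm{E}}_{C'}^{T}$ for symmetric ${\bm{A}}$) so that the matrix shapes conform. The name `column_removal_update` and the toy data are hypothetical.

```python
import numpy as np

def column_removal_update(W, A, cols):
    """Return dW = -W_bar (E A^{-1} E^T)^{-1} (E A^{-1}) for the selected columns."""
    A_inv = np.linalg.inv(A)
    W_bar = W[:, cols]                    # the columns to be removed, shape (R, C')
    S = A_inv[np.ix_(cols, cols)]         # E A^{-1} E^T, shape (C', C')
    return -W_bar @ np.linalg.solve(S, A_inv[cols, :])

# Toy problem (all sizes and values are made up for illustration).
rng = np.random.default_rng(1)
R, C = 4, 6
W = rng.standard_normal((R, C))
N = rng.standard_normal((C, C))
A = N @ N.T + C * np.eye(C)               # symmetric positive definite stand-in for A
cols = [0, 2, 5]

dW = column_removal_update(W, A, cols)
W_pruned = W + dW                         # columns 0, 2 and 5 are now numerically zero
```

As in the row case, the other Kronecker factor drops out: ${\bm{G}}$ is not needed to compute the column update.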

Appendix B Experimental details.

B.1 Models

OPT models

From the OPT model family (Zhang et al., 2022), we consider models with the following numbers of parameters: 125 million (125m), 1.3 billion (1.3b), 2.7 billion (2.7b) and 6.7 billion (6.7b). We omit the 350 million parameter model due to its different layer norm. We obtain the standard pre-trained checkpoints using Huggingface (Wolf et al., 2019) and use these as the baseline and initialisation for compression.

Llama-v2 models

From the Llama-v2 model family (Touvron et al., 2023), we consider a model with 7 billion (7b) parameters and a model with 13 billion (13b) parameters. We obtain the standard pre-trained checkpoints using Huggingface (Wolf et al., 2019) and use these as the baseline and initialisation for compression.

B.2 Datasets

English / Wikitext-2

The majority of the results are obtained on the Wikitext-2 dataset containing parsed subsets of the English Wikipedia (Merity et al., 2016; Wikipedia, 2004), using the default training and test sets. For fitting, we use 128 batches of 2048 characters and for testing we use the standard test set containing 4358 characters.

French / Wikipedia

For the French data experiments, we use a subset of the French Wikipedia (Wikipedia, 2004). For fitting, we use 128 batches of 2048 characters and for testing we use a randomly selected test set containing 1067888 characters.

German / Wikipedia

For the German data experiments, we use a subset of the German Wikipedia (Wikipedia, 2004). For fitting, we use 128 batches of 2048 characters and for testing we use a randomly selected test set containing 1112372 characters.

Italian / Wikipedia

For the Italian data experiments, we use a subset of the Italian Wikipedia (Wikipedia, 2004). For fitting, we use 128 batches of 2048 characters and for testing we use a randomly selected test set containing 633177 characters.

B.3 Mask equivalence

We compare the pruning masks of two models ${\bm{\theta}}_{A}$ and ${\bm{\theta}}_{B}$ obtained by two compression methods $A$ and $B$. We always consider the case of 50% pruning, and define the mask equivalence as the fraction of weights that are set to zero in both models:

\begin{align}
\text{mask equivalence} = \sum_{i=1}^{P} \frac{\mathbf{1}([{\bm{\theta}}_{A}]_{i} = 0 \text{ and } [{\bm{\theta}}_{B}]_{i} = 0)}{P}, \tag{29}
\end{align}

where $\mathbf{1}$ denotes an indicator function that returns 1 if both weights $[{\bm{\theta}}_{A}]_{i}$ and $[{\bm{\theta}}_{B}]_{i}$ are zero, and returns 0 otherwise.
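As an illustration of Eq. 29, the following minimal NumPy sketch counts the weights that are exactly zero in both models and divides by the total number of weights $P$. The function name `mask_equivalence` and the two toy weight vectors are hypothetical and only serve as an example.

```python
import numpy as np

def mask_equivalence(theta_a, theta_b):
    """Fraction of all P weights that are exactly zero in both models (Eq. 29)."""
    theta_a = np.asarray(theta_a).ravel()
    theta_b = np.asarray(theta_b).ravel()
    if theta_a.shape != theta_b.shape:
        raise ValueError("both models must have the same number of weights")
    both_zero = (theta_a == 0) & (theta_b == 0)   # indicator from Eq. 29
    return float(both_zero.mean())                # sum of indicators divided by P

# Hypothetical pruned weight vectors with P = 6.
theta_a = np.array([0.0, 0.3, 0.0, -1.2, 0.0, 0.7])
theta_b = np.array([0.0, 0.0, 0.5, -0.9, 0.0, 0.0])
score = mask_equivalence(theta_a, theta_b)        # positions 0 and 4 are zero in both
```

Read literally, Eq. 29 normalises by all $P$ weights, so two identical 50% masks would score 0.5 under this sketch.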

B.4 SparseGPT and evaluation of baselines

For the SparseGPT baseline, we used the official SparseGPT code repository (Frantar & Alistarh, 2023), which allows for training and evaluation on wikitext-2. The obtained results may differ from those reported in the original paper, as the C4 dataset was used there.

In this work, models were trained with the same 128 batches of the wikitext-2 training set as available in the SparseGPT codebase and are evaluated on the wikitext-2 test set using the exact same evaluation procedure.

Appendix C Technical details

C.1 Pseudocodes

Algorithm 2 LLM Surgeon (structured)
Input: target size $\alpha$
Input: initial weights ${\bm{\theta}}^{0}$
For shot $t$ in [1, 2, …, $T$]
    Compute: approximate curvature ${\bm{G}}_{1}, {\bm{A}}_{1}$ from data (optionally also ${\bm{G}}_{2}, {\bm{A}}_{2}$) ▷ section 3.1
    Compute: costs per row/column $\mathcal{L}_{r}, \mathcal{L}_{c}$ from ${\bm{G}}_{1}, {\bm{A}}_{1}, ({\bm{G}}_{2}, {\bm{A}}_{2})$ ▷ section 3.2
    Compute: threshold $\tau$ using $\mathcal{L}_{r}$ and $\mathcal{L}_{c}$ given target size $\alpha$ ▷ section 3.3
    Select: rows and columns to remove ${\bm{E}}_{R}$, ${\bm{E}}_{C}$ based on $\tau$ ▷ section 3.3
    Compute: weight update $\Delta{\bm{\theta}}^{t-1}$ based on ${\bm{E}}_{R}, {\bm{E}}_{C}$ and ${\bm{G}}_{1}, {\bm{A}}_{1}, ({\bm{G}}_{2}, {\bm{A}}_{2})$ ▷ section 3.4
    Update: remaining weights ${\bm{\theta}}^{t} \leftarrow {\bm{\theta}}^{t-1} + \Delta{\bm{\theta}}^{t-1}$ ▷ section 3.5
    Optionally: ${\bm{\theta}}^{t} \leftarrow \text{low-rank update}({\bm{\theta}}^{t})$
Output: compressed weights $\hat{{\bm{\theta}}} = {\bm{\theta}}^{T}$

Algorithm 3 LLM Surgeon (semi-structured / unstructured)
Input: target size $\alpha$
Input: initial weights ${\bm{\theta}}^{0}$
For shot $t$ in [1, 2, …, $T$]
    Compute: approximate curvature ${\bm{G}}_{1}, {\bm{A}}_{1}$ from data (optionally also ${\bm{G}}_{2}, {\bm{A}}_{2}$) ▷ section 3.1
    Compute: costs per element $\mathcal{L}_{k}$ from ${\bm{G}}_{1}, {\bm{A}}_{1}, ({\bm{G}}_{2}, {\bm{A}}_{2})$ ▷ section 3.2
    Compute: threshold $\tau$ from $\mathcal{L}_{k}$ and target size $\alpha_{t}$ (unstructured/semi-structured) ▷ section 3.3
    Select: elements to remove ${\bm{E}}_{K}$ based on $\tau$ (unstructured/semi-structured) ▷ section 3.3
    Compute: weight update $\Delta{\bm{\theta}}^{t-1}$ based on ${\bm{E}}_{K}$ and ${\bm{G}}_{1}, {\bm{A}}_{1}, ({\bm{G}}_{2}, {\bm{A}}_{2})$ ▷ section 3.4
    Update: remaining weights ${\bm{\theta}}^{t} \leftarrow {\bm{\theta}}^{t-1} + \Delta{\bm{\theta}}^{t-1}$ ▷ section 3.5
    Optionally: ${\bm{\theta}}^{t} \leftarrow \text{low-rank update}({\bm{\theta}}^{t})$
Output: compressed weights $\hat{{\bm{\theta}}} = {\bm{\theta}}^{T}$
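
The "compute threshold, then select" steps of Algorithm 3 can be illustrated for the unstructured case with a minimal NumPy sketch. It assumes that a lower cost means a weight is cheaper to remove, that the target size is the fraction of weights kept, and that ties and already-pruned weights can be ignored. The function name `select_elements` and the toy costs are hypothetical, and the selection rule of section 3.3 may differ in its details.

```python
import numpy as np

def select_elements(costs, target_size):
    """Return (tau, remove_mask) so that about (1 - target_size) of the weights are removed."""
    flat = costs.ravel()
    n_remove = int(round((1.0 - target_size) * flat.size))
    if n_remove == 0:
        return -np.inf, np.zeros(costs.shape, dtype=bool)
    # tau is the n_remove-th smallest cost; everything at or below it is removed.
    tau = np.partition(flat, n_remove - 1)[n_remove - 1]
    return tau, costs <= tau

# Hypothetical per-element costs for a 2 x 3 weight matrix.
costs = np.array([[0.9, 0.1, 0.4],
                  [0.2, 0.8, 0.3]])
tau, remove = select_elements(costs, target_size=0.5)
```

A single global threshold lets the number of removed weights vary across rows and layers, which is one simple way to obtain a dynamic allocation from per-element costs.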

C.2 Dampening

In practice, we dampen the ${\bm{G}}$ and ${\bm{A}}$ matrices by adding a diagonal term, ${\bm{G}} + \lambda_{G}{\bm{I}}$ and ${\bm{A}} + \lambda_{A}{\bm{I}}$. In our experiments, we found that values in the range [0.01, 0.1] multiplied by mean diagonal terms generally work well. We follow (Frantar & Alistarh, 2023) and always use $\lambda_{A} = 0.01\,\text{diag}({\bm{A}})$ to be consistent with prior work and allow for a fair comparison with baselines. Further, we use $\lambda_{G} = 0.1\,\text{diag}({\bm{G}})$ for structured experiments and $\lambda_{G} = 0.01\,\text{diag}({\bm{G}})$ in semi-structured and unstructured experiments.
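
The dampening step can be sketched in a few lines of NumPy. This is a minimal sketch under one assumption: that $\text{diag}(\cdot)$ in the expressions for $\lambda_{A}$ and $\lambda_{G}$ stands for the mean of the diagonal entries, as suggested by the phrase "multiplied by mean diagonal terms". The helper name `dampen` and the toy data are hypothetical.

```python
import numpy as np

def dampen(M, rel_strength):
    """Return M + lambda * I with lambda = rel_strength * mean(diag(M))."""
    lam = rel_strength * np.mean(np.diag(M))
    return M + lam * np.eye(M.shape[0])

# Toy curvature factor: only 3 samples for 8 features, so A is rank deficient.
rng = np.random.default_rng(0)
X = rng.standard_normal((3, 8))
A = X.T @ X / X.shape[0]

A_damped = dampen(A, 0.01)                # relative strength 0.01, as used for A
```

Scaling the added term by the mean diagonal keeps the amount of dampening relative to the magnitude of the matrix, and it makes a rank-deficient factor such as the toy `A` above invertible.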

Appendix D Downstream task performance

We also evaluate our method on downstream tasks, as perplexity metrics do not necessarily correlate with downstream performance. Further, we repeat this experiment using the C4 dataset as reference data for compression, as this is used in prior work (Frantar & Alistarh, 2023) and can be regarded as a more general reference dataset. In tables 5 and 6 we report 0-shot test performance of structured pruning for LLM Surgeon and the K-OBD baseline.

Table 5: Downstream task performance using Wikitext-2 for pruning.

| Structured pruning (with wikitext-2) | Model size | wikitext word ppl | boolq | piqa | hellaswag | winogrande | arc_easy | arc_challenge | openbookqa | copa | lambada_openai | wsc273 | AVERAGE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Dense baseline | 100% | 9.24 | 77.74 | 79.11 | 75.99 | 69.14 | 74.58 | 46.25 | 44.20 | 86.00 | 73.92 | 85.71 | 71.26 |
| LLM Surgeon (ours) | 90% | 9.63 | 76.21 | 78.56 | 75.39 | 67.64 | 74.12 | 46.50 | 43.60 | 85.00 | 72.64 | 84.98 | 70.46 |
| LLM Surgeon (ours) | 80% | 12.16 | 72.97 | 77.09 | 71.30 | 66.30 | 71.36 | 41.89 | 41.80 | 87.00 | 56.43 | 80.22 | 66.66 |
| LLM Surgeon (ours) | 70% | 16.91 | 61.25 | 73.56 | 60.72 | 61.09 | 63.09 | 36.69 | 38.80 | 81.00 | 28.33 | 76.56 | 58.11 |
| LLM Surgeon (ours) | 60% | 25.15 | 44.98 | 69.26 | 48.04 | 54.38 | 52.31 | 30.29 | 36.80 | 78.00 | 11.72 | 68.50 | 49.43 |
| LLM Surgeon (ours) | 50% | 43.68 | 39.60 | 64.36 | 40.29 | 52.57 | 44.91 | 26.28 | 30.80 | 74.00 | 6.52 | 61.54 | 44.09 |
| K-OBD | 90% | 9.89 | 76.67 | 78.02 | 74.80 | 68.11 | 75.17 | 46.33 | 44.60 | 86.00 | 72.71 | 82.78 | 70.52 |
| K-OBD | 80% | 17.62 | 74.34 | 75.24 | 67.85 | 64.64 | 63.80 | 40.27 | 41.60 | 83.00 | 30.23 | 82.42 | 62.34 |
| K-OBD | 70% | 32.72 | 65.29 | 71.82 | 53.07 | 56.83 | 51.05 | 33.11 | 37.80 | 79.00 | 12.21 | 70.70 | 53.09 |
| K-OBD | 60% | 68.63 | 60.80 | 65.67 | 43.99 | 53.20 | 41.79 | 28.50 | 34.00 | 75.00 | 7.04 | 60.44 | 47.04 |
| K-OBD | 50% | 136.33 | 61.56 | 60.66 | 36.84 | 53.04 | 36.11 | 26.71 | 33.00 | 72.00 | 4.70 | 61.17 | 44.58 |

Table 6: Downstream task performance using C4 for pruning.

| Structured pruning (with C4) | Model size | wikitext word ppl | boolq | piqa | hellaswag | winogrande | arc_easy | arc_challenge | openbookqa | copa | lambada_openai | wsc273 | AVERAGE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Dense baseline | 100% | 9.24 | 77.74 | 79.11 | 75.99 | 69.14 | 74.58 | 46.25 | 44.20 | 86.00 | 73.92 | 85.71 | 71.26 |
| LLM Surgeon (ours) | 90% | 9.90 | 77.03 | 78.45 | 74.95 | 68.27 | 73.19 | 45.99 | 44.60 | 84.00 | 72.81 | 82.78 | 70.21 |
| LLM Surgeon (ours) | 80% | 14.42 | 75.60 | 76.82 | 69.71 | 63.85 | 70.29 | 41.30 | 42.80 | 87.00 | 45.53 | 82.42 | 65.53 |
| LLM Surgeon (ours) | 70% | 25.16 | 66.39 | 72.85 | 58.11 | 56.83 | 62.16 | 34.47 | 38.40 | 80.00 | 22.69 | 69.96 | 56.19 |
| LLM Surgeon (ours) | 60% | 45.35 | 62.48 | 68.93 | 48.10 | 55.64 | 51.56 | 27.99 | 35.20 | 70.00 | 12.56 | 61.54 | 49.40 |
| LLM Surgeon (ours) | 50% | 77.30 | 62.60 | 65.02 | 41.70 | 54.22 | 42.55 | 24.23 | 31.20 | 71.00 | 7.26 | 60.44 | 46.02 |
| K-OBD | 90% | 10.59 | 75.47 | 78.18 | 73.61 | 66.46 | 72.52 | 44.37 | 43.60 | 87.00 | 71.22 | 82.42 | 69.48 |
| K-OBD | 80% | 20.12 | 73.36 | 75.14 | 66.11 | 62.43 | 62.84 | 38.23 | 41.00 | 86.00 | 21.50 | 78.39 | 60.50 |
| K-OBD | 70% | 56.92 | 63.30 | 68.44 | 52.31 | 55.64 | 46.72 | 31.31 | 34.60 | 77.00 | 5.69 | 68.86 | 50.39 |
| K-OBD | 60% | 112.85 | 62.23 | 64.47 | 46.36 | 52.17 | 40.53 | 29.52 | 32.40 | 72.00 | 2.91 | 63.00 | 46.56 |
| K-OBD | 50% | 272.16 | 62.42 | 61.70 | 38.47 | 50.43 | 33.29 | 26.96 | 31.80 | 65.00 | 0.91 | 59.34 | 43.03 |

We find that our method not only performs well in terms of test perplexity but also correlates well with downstream performance, outperforming the baselines on these downstream tasks.

Appendix E Additional experiments on Llama-v2 13B.

To assess performance on larger 13B parameter models, we also report structured compression on the Llama-v2 13B model and evaluate downstream task performance. Test perplexities (lower is better) can be found in table 7 below:

Table 7: Pruning Llama-v2 13B model.

| Method | Baseline (dense, 100%) | 90% | 80% | 70% | 60% | 50% |
|---|---|---|---|---|---|---|
| K-OBD | 4.547 | 4.908 | 6.294 | 10.08 | 13.06 | 16.06 |
| LLM Surgeon | 4.547 | 4.692 | 5.286 | 6.207 | 7.245 | 9.428 |

Results on downstream benchmarks (higher is better) are given in table 8 below.

Table 8: Downstream task performance after pruning large Llama-v2 13B model.

| Llama-v2 13B | Model size | wikitext word ppl | boolq | piqa | hellaswag | winogrande | arc_easy | arc_challenge | openbookqa | copa | lambada_openai | wsc273 | AVERAGE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Dense baseline | 100% | 8.23 | 80.52% | 80.52% | 79.38% | 72.14% | 77.53% | 49.23% | 45.20% | 90.00% | 76.77% | 89.38% | 74.07% |
| LLM Surgeon (ours) | 90% | 8.57 | 81.07% | 79.87% | 79.24% | 72.38% | 76.30% | 49.91% | 47.20% | 92.00% | 75.47% | 89.38% | 74.28% |
| LLM Surgeon (ours) | 80% | 10.08 | 80.86% | 79.00% | 77.09% | 70.56% | 75.93% | 46.76% | 46.80% | 90.00% | 67.79% | 86.45% | 72.12% |
| LLM Surgeon (ours) | 70% | 12.74 | 74.50% | 76.50% | 71.52% | 68.67% | 69.74% | 40.27% | 45.00% | 91.00% | 54.40% | 83.52% | 67.51% |
| LLM Surgeon (ours) | 60% | 16.00 | 64.62% | 73.01% | 65.04% | 65.75% | 63.80% | 37.12% | 39.60% | 90.00% | 44.50% | 81.32% | 62.48% |
| LLM Surgeon (ours) | 50% | 23.75 | 65.66% | 68.77% | 56.19% | 63.22% | 56.19% | 31.83% | 36.60% | 85.00% | 35.16% | 77.29% | 57.59% |
| K-OBD | 90% | 8.79 | 81.31% | 79.76% | 79.12% | 72.22% | 76.94% | 47.95% | 47.80% | 91.00% | 75.26% | 88.64% | 74.00% |
| K-OBD | 80% | 11.79 | 80.80% | 79.16% | 76.80% | 70.56% | 73.74% | 46.93% | 48.60% | 88.00% | 58.99% | 87.55% | 71.11% |
| K-OBD | 70% | 20.00 | 66.76% | 74.43% | 64.18% | 64.96% | 56.23% | 36.01% | 39.00% | 88.00% | 38.54% | 79.49% | 60.76% |
| K-OBD | 60% | 27.74 | 55.66% | 70.24% | 55.52% | 60.46% | 49.62% | 32.68% | 35.80% | 80.00% | 30.06% | 73.63% | 54.37% |
| K-OBD | 50% | 37.38 | 59.79% | 66.54% | 48.39% | 57.46% | 46.59% | 30.72% | 34.00% | 77.00% | 24.61% | 69.96% | 51.50% |

We find that LLM Surgeon also outperforms baselines on the Llama-v2 13B model. We stress that these results are obtained with structured pruning of rows and columns, which is regarded as the hardest and most constrained pruning structure. Yet, we can compress Llama-v2 13B by 20% with a drop of less than 2 percentage points in average downstream task performance (74.07% to 72.12% in table 8). It also significantly outperforms the baseline for all compression rates, both in terms of test perplexity and downstream task performance.

Appendix F Ablations

F.1 Shots

Table 9: Ablation of shot counts $T$ for structured LLM Surgeon compressing OPT-1.3b model.

| Target size | Shots $T$ | wikitext-2 PPL | Shots $T$ | wikitext-2 PPL | Shots $T$ | wikitext-2 PPL |
|---|---|---|---|---|---|---|
| 90% | 6 | 14.70 | 8 | 14.70 | 10 | 14.72 |
| 80% | 12 | 15.14 | 16 | 15.12 | 20 | 15.08 |
| 70% | 18 | 16.21 | 24 | 16.24 | 30 | 16.23 |
| 60% | 24 | 18.53 | 32 | 18.45 | 40 | 18.49 |
| 50% | 30 | 23.32 | 40 | 22.95 | 50 | 22.68 |

F.2 Task-specific compression

Table 10: Cross-task performance and mask equivalences of 50% compressed OPT-125m model using structured LLM Surgeon on language subsets.

| target | evaluation dataset: EN | FR | DE | IT | mask equivalence (%): EN | FR | DE | IT |
|---|---|---|---|---|---|---|---|---|
| Pretrained | 27.66 | 22.54 | 24.32 | 27.66 | | | | |
| EN | 47.46 | 172.9 | 181.1 | 169.1 | 1.00 | 0.74 | 0.70 | 0.72 |
| FR | 113.4 | 28.44 | 35.02 | 34.90 | 0.74 | 1.00 | 0.87 | 0.90 |
| DE | 142.1 | 35.15 | 27.49 | 38.49 | 0.70 | 0.87 | 1.00 | 0.87 |
| IT | 123.7 | 31.85 | 33.78 | 30.58 | 0.72 | 0.90 | 0.87 | 1.00 |

LLM Surgeon uses data to find a compressed model that has the least negative impact on final test performance. In this section, we explore the extent to which the method can use data to compress specifically for the task at hand. We do so by comparing test performance and equivalences between the resulting pruning masks for different languages: English (EN/wikitext-2), French (FR), German (DE) and Italian (IT). We consider 50% unstructured compression using LLM Surgeon with correlated weight updates. For each compressed model, we compare performance on all languages and compare the equivalences between the resulting pruning masks (details in section B.3), and report results in table 10. Like other methods that use data for compression (Hassibi & Stork, 1992; Frantar & Alistarh, 2023; Wang et al., 2019), we expect to see some correlation between the data used for training and data with good test performance, which is reflected in both test performance and masks. It is important to note that the final performance after compression will depend on the quality of the dataset used for compression. Further, the experiment demonstrates that the method can be used for task-specific compression tailored towards the data used for compression and generalises to high test performance on the associated test data.

Appendix G On fair comparison

All results in this work (including those for SparseGPT) were obtained by training on Wikitext-2 for fair comparison. To do so, we used the same dataloader and evaluation script as the official SparseGPT repo and reran all SparseGPT results to be trained on Wikitext-2. In some cases, this resulted in better scores for the SparseGPT baseline compared to the C4-trained results reported in the original SparseGPT paper. Yet, we find that our method using improved curvature estimates still outperformed the baselines in terms of final test performance.

Appendix H Computational performance

We report computational cost in terms of pruning time in table 11 and GPU memory in table 12.

Table 11: Time performance.

| Method | Network | Runtime | Test PPL 90% | Test PPL 80% | Test PPL 70% | Test PPL 60% | Test PPL 50% |
|---|---|---|---|---|---|---|---|
| Unstructured baseline (SparseGPT) | Llama-v2 7B | <5m | 5.49 | 5.58 | 5.71 | 5.94 | 6.51 |
| Unstructured LLM Surgeon (ours) | Llama-v2 7B | 2d8h16m | 5.13 | 5.20 | 5.36 | 5.66 | 6.08 |
| Structured baseline (K-OBD) | Llama-v2 7B | 16h58m | 5.48 | 9.14 | 15.43 | 28.03 | 46.64 |
| Structured LLM Surgeon (ours) | Llama-v2 7B | 17h08m | 5.25 | 6.18 | 7.83 | 10.39 | 15.38 |
| Structured baseline (K-OBD) | Llama-v2 13B | 1d6h5m | 4.908 | 6.294 | 10.08 | 13.06 | 16.06 |
| Structured LLM Surgeon (ours) | Llama-v2 13B | 1d9h26m | 4.692 | 5.286 | 6.207 | 7.245 | 9.428 |

Our method is most efficient for structured pruning, but it must be noted that engineering efforts may further improve speed for unstructured pruning. The focus of the paper is structured pruning, on which we achieve state-of-the-art compression rates. Importantly, compression of LLMs only needs to happen once, after which a pruned model can be deployed arbitrarily many times without further cost. This motivates our method, which takes longer to run but reaches better final test performance.

Table 12: Memory performance.

| Network | SparseGPT (baseline) | Unstructured LLM-Surgeon (ours) |
|---|---|---|
| Llama-7B | <5m / 1 GPU (32GB) | 2d8h16m / 4xH100 80 GB |

| Network | K-OBD (baseline) | Structured LLM-Surgeon (ours) |
|---|---|---|
| Llama-7B | 16h58m / 4xH100 80 GB | 17h08m / 4xH100 80 GB |
| Llama-13B | 1d6h5m / 8xH100 80 GB | 1d9h26m / 8xH100 80 GB |

We argue that differences in the performance and the runtime of pruning methods can largely be attributed to underlying assumptions on correlations between weights. Notably, algorithms that consider few correlations, sometimes to the extent of completely disregarding all gradient information, can result in very fast pruning algorithms for unstructured and semi-structured pruning but are often not flexible enough to perform structured pruning of rows and columns. Examples of such lightweight algorithms for LLMs are the method of Sun et al. (2023) and SparseGPT (Frantar & Alistarh, 2023), as can also be observed from table 11. Our approach makes weaker assumptions on the curvature of the loss and as a result outperforms all baselines on unstructured, semi-structured and structured pruning. Further, the improved curvature estimate also enables dynamic allocation of weight removal and improved correlated weight updates.

In practice, we always recommend using our method for structured pruning. For unstructured and semi-structured pruning, we note an important trade-off between the desired final test accuracy and the available computational budget. Here, our proposed method can achieve the highest final model performance but requires more computational resources and takes longer to run. It should be noted that pruning only needs to happen once, after which a model can be deployed arbitrarily many times. Depending on the available computational resources, this can justify spending additional pruning time, even if this time is much higher than that of other algorithms in relative terms. In absolute terms, the use of multiple large GPUs is common practice in the field of large language models, and many more GPUs are typically used to train and deploy large language models. Moreover, the curvature approximation is naturally amenable to data parallelism in case further speed-ups or larger models are required. We hope this provides context and emphasises the trade-off between performance and compute in practice.

Appendix I Extending curvature estimates

Instead of using a single Kronecker product, we might consider improving the approximation through a sum of multiple Kronecker factors:

$${\bm{F}} \approx \widetilde{{\bm{F}}} = {\bm{G}}_1 \otimes {\bm{A}}_1 + {\bm{G}}_2 \otimes {\bm{A}}_2 \tag{30}$$

This appendix deals with the question of how one may computationally find such approximations and how to utilise them in the neural network pruning framework.

I.1 Nearest Kronecker product or sum of Kronecker products

Instead of assuming independence of activations and derivatives as in section 3.1, following the classic KFAC of Martens & Grosse (2015), we might want to find the nearest Kronecker product approximation ${\bm{F}} \approx \widetilde{{\bm{G}}} \otimes \widetilde{{\bm{A}}}$ that is closest to the Fisher in terms of the Frobenius norm:

$$\widetilde{{\bm{G}}}_l, \widetilde{{\bm{A}}}_l = \operatorname*{arg\,min}_{{\bm{G}}_l, {\bm{A}}_l} \|{\bm{F}}_l - {\bm{G}}_l \otimes {\bm{A}}_l\|_F \tag{31}$$

Finding the nearest Kronecker product, and by extension the nearest sum of Kronecker products, can be rephrased as the classic eigenvalue problem of finding the nearest rank-1 matrix to a rearranged matrix $\mathcal{R}({\bm{F}})$ (Golub & Van Loan, 2013):

$$\|{\bm{F}} - \widetilde{{\bm{G}}} \otimes \widetilde{{\bm{A}}}\|_F \quad\equiv\quad \|\mathcal{R}({\bm{F}}) - \text{vec}(\widetilde{{\bm{G}}})\,\text{vec}(\widetilde{{\bm{A}}})^T\|_F \tag{32}$$
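The equivalence in eq. 32 can be checked numerically. The following is a minimal NumPy sketch, not the authors' code: the helper names `rearrange` and `nearest_kron_svd`, the row-major `vec` convention, the toy sizes and the use of a full SVD (rather than power iterations) are all assumptions made for illustration.

```python
import numpy as np

def rearrange(F, R, C):
    """Rearrangement R(F): maps an (R*C, R*C) matrix to an (R*R, C*C) matrix
    such that rearrange(np.kron(G, A)) == np.outer(G.ravel(), A.ravel())."""
    return F.reshape(R, C, R, C).transpose(0, 2, 1, 3).reshape(R * R, C * C)

def nearest_kron_svd(F, R, C):
    """Nearest Kronecker product in Frobenius norm, read off the leading
    singular pair of the rearranged matrix (hypothetical helper)."""
    U, s, Vt = np.linalg.svd(rearrange(F, R, C), full_matrices=False)
    G = np.sqrt(s[0]) * U[:, 0].reshape(R, R)
    A = np.sqrt(s[0]) * Vt[0].reshape(C, C)
    return G, A

rng = np.random.default_rng(0)
R, C, N = 3, 4, 50                    # toy sizes, chosen for illustration
g = rng.normal(size=(N, R))           # stand-ins for output gradients g_n
a = rng.normal(size=(N, C))           # stand-ins for input activations a_n

# Empirical Fisher block F = 1/N sum_n (g_n g_n^T) kron (a_n a_n^T)
F = sum(np.kron(np.outer(gn, gn), np.outer(an, an)) for gn, an in zip(g, a)) / N

G, A = nearest_kron_svd(F, R, C)
err_kron = np.linalg.norm(F - np.kron(G, A))
err_rank1 = np.linalg.norm(rearrange(F, R, C) - np.outer(G.ravel(), A.ravel()))

# KFAC-style factors under the independence assumption, for comparison
G_iad, A_iad = g.T @ g / N, a.T @ a / N
err_iad = np.linalg.norm(F - np.kron(G_iad, A_iad))
```

Because the rearrangement only permutes entries, `err_kron` and `err_rank1` coincide, and the nearest Kronecker product is never worse in Frobenius norm than the independence-based factors.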
Power method and deflation

After considering the reshaping, we can use power iterations to solve for and find the nearest Kronecker factors ${\bm{G}}_1, {\bm{A}}_1 = \text{solve}({\bm{F}})$.

$$\text{Find with power iterations:}\quad \widetilde{{\bm{G}}}_1, \widetilde{{\bm{A}}}_1 = \text{solve}({\bm{F}}) = \operatorname*{arg\,min}_{{\bm{G}}, {\bm{A}}} \|{\bm{F}} - {\bm{G}} \otimes {\bm{A}}\|_F$$

$$\text{Deflation:}\quad \widetilde{{\bm{G}}}_r, \widetilde{{\bm{A}}}_r = \text{solve}\Big({\bm{F}} - \sum\nolimits_{r'=1}^{r-1} \big(\widetilde{{\bm{G}}}_{r'} \otimes \widetilde{{\bm{A}}}_{r'}\big)\Big)$$
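The deflation step can be sketched as follows. This is an illustrative NumPy version that works on a small dense matrix, which would not be feasible at LLM scale; the function names `solve` and `kron_sum`, the iteration count and the toy data are assumptions, and the singular value is taken as the norm of the unnormalised iterate.

```python
import numpy as np

def rearrange(F, R, C):
    """Rearrangement R(F) with row-major vec (same convention throughout)."""
    return F.reshape(R, C, R, C).transpose(0, 2, 1, 3).reshape(R * R, C * C)

def solve(F, R, C, iters=200):
    """Power iterations for the leading singular pair of rearrange(F)."""
    M = rearrange(F, R, C)
    a = np.ones(C * C)                     # initialise with a vector of ones
    for _ in range(iters):
        g = M @ a
        g /= np.linalg.norm(g)
        a = M.T @ g
        sigma = np.linalg.norm(a)          # singular value estimate
        a /= sigma
    return np.sqrt(sigma) * g.reshape(R, R), np.sqrt(sigma) * a.reshape(C, C)

def kron_sum(F, R, C, rank):
    """Deflation: repeatedly fit one Kronecker product to the residual."""
    residual, factors = F.copy(), []
    for _ in range(rank):
        G, A = solve(residual, R, C)
        factors.append((G, A))
        residual = residual - np.kron(G, A)
    return factors

rng = np.random.default_rng(0)
R, C, N = 3, 4, 50
g = rng.normal(size=(N, R))
a = rng.normal(size=(N, C))
F = sum(np.kron(np.outer(gn, gn), np.outer(an, an)) for gn, an in zip(g, a)) / N

errors = []
for rank in (1, 2, 3):
    approx = sum(np.kron(G, A) for G, A in kron_sum(F, R, C, rank))
    errors.append(np.linalg.norm(F - approx))
```

Each deflation step removes the component of the residual along the current left vector, so the Frobenius error in `errors` cannot increase as more Kronecker terms are added.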

A more extensive description of the power method $\text{solve}(\cdot)$ can be found in algorithm 4. At the start of the algorithm, we initialise the power iterations with a vector of ones, $\mathbf{1} = \begin{bmatrix}1 & 1 & \ldots & 1\end{bmatrix}$. After each shot, we can initialise the vector with the final estimate found during the previous shot.

Algorithm 4 Kronecker power method. Finds the factors $\widetilde{{\bm{G}}}, \widetilde{{\bm{A}}}$ of the nearest Kronecker product, minimising $\|{\bm{F}} - \widetilde{{\bm{G}}} \otimes \widetilde{{\bm{A}}}\|_F$.
Initialise $\widetilde{{\bm{g}}}^0 = \mathbf{1}, \widetilde{{\bm{a}}}^0 = \mathbf{1}$ (or using estimates of previous shot).
Set iterations $I$ (or $I = 1$ if using estimates from previous shot)
Output: $\widetilde{{\bm{G}}}, \widetilde{{\bm{A}}}$
for iteration $i$ in $[1, 2, \ldots, I]$ do
    Compute: $\widetilde{{\bm{g}}}^i = \frac{\mathcal{R}(\widetilde{{\bm{F}}})\widetilde{{\bm{a}}}^{i-1}}{\|\mathcal{R}(\widetilde{{\bm{F}}})\widetilde{{\bm{a}}}^{i-1}\|_2}$, with $\mathcal{R}(\widetilde{{\bm{F}}})\widetilde{{\bm{a}}}^{i-1} = \frac{1}{N}\sum_{n=1}^{N} {\bm{a}}_n^T \widetilde{{\bm{A}}}^{i-1} {\bm{a}}_n \,\text{vec}({\bm{g}}_n {\bm{g}}_n^T)$
    Compute: $\widetilde{{\bm{a}}}^i = \frac{\mathcal{R}(\widetilde{{\bm{F}}})^T\widetilde{{\bm{g}}}^i}{\|\mathcal{R}(\widetilde{{\bm{F}}})^T\widetilde{{\bm{g}}}^i\|_2}$, with $\mathcal{R}(\widetilde{{\bm{F}}})^T\widetilde{{\bm{g}}}^i = \frac{1}{N}\sum_{n=1}^{N} {\bm{g}}_n^T \widetilde{{\bm{G}}}^i {\bm{g}}_n \,\text{vec}({\bm{a}}_n {\bm{a}}_n^T)$
    Compute: $\sigma^i = \|\widetilde{{\bm{a}}}^i\|_2$
end for
Return: $\widetilde{{\bm{G}}} = \sqrt{\sigma^i}\,\text{mat}(\widetilde{{\bm{g}}}), \widetilde{{\bm{A}}} = \sqrt{\sigma^i}\,\text{mat}(\widetilde{{\bm{a}}})$.
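A matrix-free version of algorithm 4 can be sketched directly from per-sample activations and gradients, without ever forming ${\bm{F}}$. The sketch below is an assumption-laden illustration rather than the authors' implementation: the function name `kron_power_method`, the batch layout (`g` of shape `(N, R)`, `a` of shape `(N, C)`) and the iteration count are hypothetical, and $\sigma$ is taken as the norm of the iterate before normalisation so that it estimates the leading singular value.

```python
import numpy as np

def kron_power_method(g, a, iters=100):
    """Kronecker power method on samples g_n (N, R) and a_n (N, C).
    Works on mat(g~) and mat(a~) directly instead of their vectorisations."""
    N, R = g.shape
    C = a.shape[1]
    A_t = np.ones((C, C))                          # mat(a~^0) = all ones
    for _ in range(iters):
        # R(F) a~ = 1/N sum_n (a_n^T A~ a_n) vec(g_n g_n^T)
        w = np.einsum("nc,cd,nd->n", a, A_t, a)
        G_t = (g * w[:, None]).T @ g / N
        G_t /= np.linalg.norm(G_t)                 # Frobenius norm = 2-norm of vec
        # R(F)^T g~ = 1/N sum_n (g_n^T G~ g_n) vec(a_n a_n^T)
        v = np.einsum("nr,rs,ns->n", g, G_t, g)
        A_t = (a * v[:, None]).T @ a / N
        sigma = np.linalg.norm(A_t)                # norm before normalisation
        A_t /= sigma
    return np.sqrt(sigma) * G_t, np.sqrt(sigma) * A_t

rng = np.random.default_rng(0)
R, C, N = 3, 4, 50                                 # toy sizes for illustration
g = rng.normal(size=(N, R))
a = rng.normal(size=(N, C))
G, A = kron_power_method(g, a)
```

Only the `(R, R)` and `(C, C)` iterates and the current batch are held in memory, which is the property that makes the method usable when the full Fisher block is too large to store.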
Figure 5: Example illustration of nearest Kronecker factor approximations $\widetilde{{\bm{F}}} \approx \sum_{r=1}^{R_K} {\bm{G}}_r \otimes {\bm{A}}_r$, compared to classical KFAC with the IAD assumption. Larger $R_K$ yields better approximations to the true Fisher ${\bm{F}}$, as measured by the root mean squared error (RMSE).

I.2 Extended curvature approximations

For classic KFAC with IAD or $R_K = 1$ nearest Kronecker approximations of the form $\widetilde{{\bm{F}}} = {\bm{G}} \otimes {\bm{A}}$, the inverse simply becomes $({\bm{G}} \otimes {\bm{A}})^{-1} = {\bm{G}}^{-1} \otimes {\bm{A}}^{-1}$. Unfortunately, we cannot use this well-known inverse identity for a sum of Kronecker factors, which is why we fall back on eigendecompositions ${\bm{G}} = {\bm{E}}_1 {\bm{S}}_1 {\bm{E}}_1^T$ and ${\bm{A}} = {\bm{E}}_2 {\bm{S}}_2 {\bm{E}}_2^T$, allowing us to decompose the Fisher into:

$$\widetilde{{\bm{F}}} = {\bm{K}}_1 {\bm{S}}_1 {\bm{K}}_1^T \otimes {\bm{K}}_2 {\bm{S}}_2 {\bm{K}}_2^T = ({\bm{K}}_1 \otimes {\bm{K}}_2)({\bm{I}} \otimes {\bm{I}} + {\bm{S}}_1 \otimes {\bm{S}}_2)({\bm{K}}_1^T \otimes {\bm{K}}_2^T) \tag{33}$$

where the specific ${\bm{K}}_1$ and ${\bm{K}}_2$ can be found in App. B of Martens & Grosse (2015), which we closely followed in our derivations. Because ${\bm{K}}_1$ and ${\bm{K}}_2$ are orthogonal and ${\bm{S}}_1$ and ${\bm{S}}_2$ diagonal, the inverse Fisher becomes:

$$\widetilde{{\bm{F}}}^{-1} = ({\bm{K}}_1 \otimes {\bm{K}}_2)({\bm{I}} \otimes {\bm{I}} + {\bm{S}}_1 \otimes {\bm{S}}_2)^{-1}({\bm{K}}_1^T \otimes {\bm{K}}_2^T) \tag{34}$$
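As a numerical illustration of this kind of factorised inverse, the sketch below inverts a sum of two Kronecker products and compares against a dense inverse. The construction used here (whitening by the first pair of factors and then eigendecomposing the whitened second pair) is our assumption about one way to obtain suitable ${\bm{K}}_1, {\bm{K}}_2, {\bm{S}}_1, {\bm{S}}_2$; it is not taken from this text, and the helper names and toy matrices are hypothetical. Under this assumed construction the factors satisfy ${\bm{K}}_1^T {\bm{G}}_1 {\bm{K}}_1 = {\bm{I}}$ rather than being orthogonal themselves.

```python
import numpy as np

def inv_sqrt_spd(M):
    """Inverse matrix square root of a symmetric positive definite matrix."""
    w, V = np.linalg.eigh(M)
    return (V / np.sqrt(w)) @ V.T

def kron_sum_inverse_factors(G1, A1, G2, A2):
    """Return K1, K2, s1, s2 with (assumed construction)
    (G1 kron A1 + G2 kron A2)^{-1}
        = (K1 kron K2) diag(1 / (1 + s1 kron s2)) (K1 kron K2)^T."""
    G1_is, A1_is = inv_sqrt_spd(G1), inv_sqrt_spd(A1)
    s1, E1 = np.linalg.eigh(G1_is @ G2 @ G1_is)   # whitened second G factor
    s2, E2 = np.linalg.eigh(A1_is @ A2 @ A1_is)   # whitened second A factor
    return G1_is @ E1, A1_is @ E2, s1, s2

rng = np.random.default_rng(0)

def random_spd(n):
    X = rng.normal(size=(n, n))
    return X @ X.T + n * np.eye(n)

G1, G2 = random_spd(3), random_spd(3)
A1, A2 = random_spd(4), random_spd(4)

K1, K2, s1, s2 = kron_sum_inverse_factors(G1, A1, G2, A2)
K = np.kron(K1, K2)
# The middle factor is diagonal, so its inverse is an element-wise division.
F_inv = (K / (1.0 + np.kron(s1, s2))) @ K.T
F_dense = np.kron(G1, A1) + np.kron(G2, A2)
```

The dense Kronecker products are formed here only to check the result on a toy problem; in practice only the small factors would be stored.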

In the context of neural network training, the problem gets slightly harder since we want to incrementally construct estimates $\widetilde{{\bm{G}}}_i$ and $\widetilde{{\bm{A}}}_i$ from individual samples ${\bm{a}}_{l,n}, {\bm{g}}_{l,n}$ that make up ${\bm{F}}$, without having to simultaneously store more than a single sample or batch of input activations ${\bm{a}}_{l,n}$ or output gradients ${\bm{g}}_{l,n}$ in memory. Although this online Kronecker-product principal component analysis problem largely remains an open research problem, our approach closely follows the recent work by Koroko et al. (2022), which uses similar approximations in the context of optimisation. A sum of multiple $R_K > 1$ Kronecker factors will yield closer approximations, but memory requirements also increase linearly with $R_K$, and computing the inverse ${\bm{F}}^{-1}$ becomes considerably more difficult.

Formulas to compute cost and weight updates.

For a sum of Kronecker factors, we find that the constrained optimisation solutions for the costs $\Delta\mathcal{L}$ (eq. 7) and weight updates $\Delta{\bm{\theta}}$ (eq. 8) become the following inner product and matrix-vector product:

$$\mathcal{L}_k = \frac{1}{2}\langle \overline{{\bm{\theta}}^*}, {\bm{U}}\,\overline{{\bm{\theta}}^*} \rangle = \frac{1}{2}(\overline{{\bm{\theta}}^*})^T {\bm{U}} (\overline{{\bm{\theta}}^*}) \in \mathbb{R} \tag{35}$$

$$\Delta{\bm{\theta}} = \widetilde{{\bm{F}}}^{-1}{\bm{E}}_K^T{\bm{u}} = {\bm{K}}_1\left(\overline{{\bm{K}}}_1^T {\bm{U}}\,\overline{{\bm{K}}}_2 \oslash \left[\mathbf{1}\mathbf{1}^T + {\bm{s}}_1{\bm{s}}_2^T\right]\right){\bm{K}}_2^T \in \mathbb{R}^{RC} \tag{36}$$

with, at its core, a matrix ${\bm{U}} = [{\bm{E}}_K {\bm{F}}^{-1} {\bm{E}}_K^T]^{-1}$ that captures correlations between weights:

$${\bm{U}} = \left[{\bm{E}}_K \Big({\bm{K}}_1 \otimes {\bm{K}}_2\Big)\Big({\bm{I}} \otimes {\bm{I}} + {\bm{S}}_1 \otimes {\bm{S}}_2\Big)^{-1}\Big({\bm{K}}_1^T \otimes {\bm{K}}_2^T\Big){\bm{E}}_K^T\right]^{-1} \tag{37}$$

where $({\bm{I}} \otimes {\bm{I}} + {\bm{S}}_1 \otimes {\bm{S}}_2)$ is diagonal and the inverse can thus be computed element-wise. The remaining inverse is of size $K \times K$, for $K$ correlated weights.
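The element-wise inverse and the small $K \times K$ solve can be illustrated with a short NumPy sketch. It assumes orthogonal ${\bm{K}}_1, {\bm{K}}_2$ and non-negative ${\bm{s}}_1, {\bm{s}}_2$ as in eq. 34, a row-major `vec`, and the usual Optimal Brain Surgeon sign convention $\Delta{\bm{\theta}} = -\widetilde{{\bm{F}}}^{-1}{\bm{E}}_K^T{\bm{U}}\,\overline{{\bm{\theta}}^*}$; the function names, toy sizes and pruned indices are hypothetical and not taken from the paper's code.

```python
import numpy as np

def apply_inverse(V, K1, K2, s1, s2):
    """mat(F^{-1} vec(V)) without forming F^{-1}; the diagonal middle factor
    of eq. 34 becomes an element-wise division by 1 1^T + s1 s2^T."""
    return K1 @ ((K1.T @ V @ K2) / (1.0 + np.outer(s1, s2))) @ K2.T

def prune_correlated(W, idx, K1, K2, s1, s2):
    """Cost and correlated update for removing the weights at flat indices idx."""
    R, C = W.shape
    cols = []
    for q in idx:                                   # columns of F^{-1} E_K^T
        e = np.zeros(R * C)
        e[q] = 1.0
        cols.append(apply_inverse(e.reshape(R, C), K1, K2, s1, s2).ravel())
    Finv_cols = np.stack(cols, axis=1)              # shape (R*C, K)
    U = np.linalg.inv(Finv_cols[idx, :])            # K x K matrix of eq. 37
    theta_K = W.ravel()[idx]
    cost = 0.5 * theta_K @ U @ theta_K              # eq. 35
    delta = -(Finv_cols @ (U @ theta_K)).reshape(R, C)  # assumed sign convention
    return cost, delta

rng = np.random.default_rng(0)
R, C = 3, 4                                         # toy layer size
K1, _ = np.linalg.qr(rng.normal(size=(R, R)))       # random orthogonal factors
K2, _ = np.linalg.qr(rng.normal(size=(C, C)))
s1 = rng.uniform(0.1, 2.0, size=R)
s2 = rng.uniform(0.1, 2.0, size=C)
W = rng.normal(size=(R, C))
idx = [0, 5, 7]                                     # K = 3 weights to remove

cost, delta = prune_correlated(W, idx, K1, K2, s1, s2)
W_pruned = W + delta
```

After the update, the entries of `W_pruned` at `idx` are zero up to floating-point error, while the remaining weights are adjusted to compensate for their removal.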

Note on sum of Kronecker factors

Experimentally, we did not find a benefit in performance when using a sum of two nearest Kronecker factor approximations, or found it too slow. Therefore, we focus in the main text on LLM Surgeon with the fast single Kronecker product KFAC approximation of the loss landscape curvature. Nevertheless, we chose to include this appendix as we believe it could prove useful in other contexts or inspire future work that aims to further improve the quality of curvature approximations.

Appendix J Code