

License: arXiv.org perpetual non-exclusive license
arXiv:2207.10226v4 [cs.LG] 07 Apr 2024

Improving Privacy-Preserving Vertical Federated Learning by Efficient Communication with ADMM

Chulin Xie University of Illinois Urbana-Champaign
chulinx2@illinois.edu
   Pin-Yu Chen IBM Research
pin-yu.chen@ibm.com
   Qinbin Li UC Berkeley
qinbin@berkeley.edu
   Arash Nourian Amazon Web Services & UC Berkeley
nouriara@amazon.com
   Ce Zhang University of Chicago
cez@uchicago.edu
   Bo Li University of Chicago & UIUC
lbo@illinois.edu
Abstract

Federated learning (FL) enables distributed resource-constrained devices to jointly train shared models while keeping the training data local for privacy purposes. Vertical FL (VFL), in which each client holds a subset of the features, has recently attracted intensive research efforts. We identify the main challenge facing existing VFL frameworks: the server needs to communicate gradients with the clients at every training step, incurring high communication cost that leads to rapid consumption of privacy budgets. To address this challenge, we introduce a VFL framework with multiple heads (VIM), which takes the separate contribution of each client into account and enables an efficient decomposition of the VFL optimization objective into sub-objectives that the server and the clients can iteratively solve on their own. In particular, we propose an Alternating Direction Method of Multipliers (ADMM)-based method to solve our optimization problem, which allows clients to conduct multiple local updates before communication, thus reducing the communication cost and yielding better performance under differential privacy (DP). We provide a client-level DP mechanism for our framework to protect user privacy. Moreover, we show that a byproduct of VIM is that the weights of the learned heads reflect the importance of local clients. Extensive evaluations on four vertical FL datasets show that VIM achieves significantly higher performance and faster convergence than state-of-the-art methods. We also explicitly evaluate the importance of local clients and show that VIM enables functionalities such as client-level explanation and client denoising. We hope this work will shed light on a new way of effective VFL training and understanding. Our code is available at: https://github.com/AI-secure/VFL-ADMM

I Introduction

Federated learning (FL) has enabled large-scale training with data privacy guarantees on distributed data for different applications yang2019ffd ; brisimi2018federated ; hard2018federated ; yang2018applied ; yang2019federated . In general, FL can be categorized into Horizontal FL (HFL) mcmahan2016communication , where data samples are distributed across clients, and Vertical FL (VFL) yang2019federated , where the features of the samples are partitioned across clients and the labels are usually owned by the server (or by the active party in the two-party setting hardy2017private ). In particular, VFL allows clients with partial information about the same dataset to jointly train a model, which enables many real-world applications hu2019fdml ; yang2019federated ; hard2018federated . For instance, a patient may visit different types of healthcare providers, such as dental clinics and pharmacies, for different purposes; it is therefore important for these providers (i.e., VFL clients/data owners/organizations) to “share” their information about the same patient (i.e., partial features of the same sample) to better model the patient's health condition. In addition, multimodal data is now ubiquitous, yet each client is usually able to collect only one or a few data modalities due to resource limitations. VFL therefore provides an effective way for such clients to jointly train a model that leverages information from different data modalities.

Despite the importance and practicality of VFL, state-of-the-art (SOTA) VFL frameworks suffer from a notable weakness: since the clients own the local features and the server holds all the labels, the server needs to calculate the training loss based on the labels and then send gradients to the clients at every training step so that they can update their local models vepakomma2018split ; chen2020vafl ; kang2020fedmvt . This incurs high communication cost and can lead to rapid consumption of the privacy budget.

TABLE I: Comparison between our work and existing VFL studies.
VFL Setup Method Support DNN Support $N>2$ parties Labels only held by one party Support multiple local updates Privacy guarantee
w/ model splitting VAFL chen2020vafl , VFL-PBM tran2023privacy ×
Split Learning vepakomma2018split × ×
FedBCD liu2022fedbcd ×
CELU-VFL × ×
Flex-VFL castiglia2023flexible × ×
VIMADMM (Ours)
w/o model splitting Fu et al. fu2022usenix , FDML hu2019fdml × × ×
AdaVFL zhang2022adaptive , CAFE jin2021catastrophic × ×
Linear-ADMM hu2019learning ×
VIMADMM-J (Ours)

To solve the above challenges, in this work we propose an efficient VFL optimization framework with multiple heads (VIM), where each head corresponds to one local client. VIM takes the individual contribution of each client into consideration and facilitates a thorough decomposition of the VFL optimization problem into multiple subproblems that can be iteratively solved by the server and the clients. In particular, we propose an Alternating Direction Method of Multipliers (ADMM) boyd2011distributed -based method that splits the overall VIM optimization objective into smaller sub-objectives, so that the clients can conduct multiple local updates w.r.t. their local objectives at each communication round, coordinated through ADMM-related variables. This leads to faster model convergence and significantly reduces the communication cost, which is crucial for preserving privacy: due to the continuous transmission of sensitive local information, the privacy cost of clients grows with the number of communication rounds abadi2016deep ; brendan2018learning . We consider two typical VFL settings: with model splitting (i.e., clients host partial models) and without model splitting (i.e., clients hold the entire model). For the with-model-splitting setting, we propose an ADMM-based algorithm, VIMADMM, under the VIM framework. Compared to gradient-based methods, VIMADMM not only reduces the communication frequency but also reduces the dimensionality of the exchanged messages, since only ADMM-related variables are exchanged. We provide a convergence analysis for VIMADMM and prove that it converges to stationary points under mild assumptions. By modifying the communication strategies and update rules of the server and clients, we extend VIMADMM to the without-model-splitting setting and introduce VIMADMM-J.
Under both settings, to further protect the privacy of the local features held by clients, we introduce privacy mechanisms that clip and perturb local outputs to satisfy client-level differential privacy (DP) dwork2006our ; dwork2011firm ; dwork2014algorithmic ; mcmahan2016communication and prove their DP guarantees. Moreover, we offer a basic solution to separately protect the privacy of the labels owned by the server, leveraging the established label-DP mechanism ALIBI malek2021antipodes , which perturbs the labels. Finally, we show that a byproduct of VIM is that the weights of the learned heads reflect the importance of local clients, which enables functionalities such as client-level explanation, client denoising, and client summarization. Our main contributions are:

  • We propose an efficient and effective VFL optimization framework with multiple heads (VIM). To solve our optimization problem, we propose an ADMM-based method, VIMADMM, which reduces communication costs by allowing multiple local updates at each communication round.

  • We theoretically analyze the convergence of VIMADMM and prove that it can converge to stationary points.

  • We introduce the client-level DP mechanism for our VIM framework and prove its privacy guarantees.

  • We conduct extensive experiments on four diverse datasets (i.e., MNIST, CIFAR, NUS-WIDE, and ModelNet40) and show that the ADMM-based algorithms under VIM converge faster, achieve higher accuracy, and retain higher utility under client-level DP and label DP than four existing VFL frameworks.

  • We evaluate our client-level explanation under VIM based on the norms of the head weights, and demonstrate the functionalities it enables, such as client denoising and summarization.

II Related Work

Vertical Federated Learning

VFL has been well studied for simple models, including trees cheng2021secureboost ; wu2020privacy , kernel models gu2020federated , and linear and logistic regression hardy2017private ; yang2019parallel ; zhang2021secure ; feng2020multi ; hu2019learning ; liu2019communication . For instance, Hardy et al. hardy2017private propose secure logistic regression for two-party VFL with homomorphic encryption rouhani2018deepsecure ; gilad2016cryptonets and multiparty computation ben1988completeness ; bonawitz2017practical . However, these methods are constrained by the limited performance of simple models such as logistic regression. Subsequent research has extended VFL to Deep Neural Networks (DNNs), enabling VFL training with a larger number of clients and on large-scale models and datasets. For DNNs, there are two popular VFL settings: with model splitting vepakomma2018split ; kang2020fedmvt ; chen2020vafl and without model splitting hu2019fdml ; jin2021catastrophic .

In the with model splitting setting, Split Learning vepakomma2018split is the first related paradigm: each client trains a partial network up to a cut layer, and the server concatenates the local activations and trains the rest of the network. VAFL chen2020vafl is proposed for asynchronous VFL, where the server averages the local embeddings and sends gradients back to the clients to update their local models. However, such embedding averaging might lose the unique properties of each client. FedMVT kang2020fedmvt focuses on semi-supervised VFL with multi-view learning. C-VFL castiglia2022compressed proposes embedding compression techniques to improve communication efficiency. However, these methods vepakomma2018split ; chen2020vafl ; kang2020fedmvt ; castiglia2022compressed still require the server to send gradients (w.r.t. embeddings) to the clients at each training step, leading to high communication frequency and high communication cost before convergence. Recent research efforts have sought to reduce VFL communication frequency by allowing clients to make multiple local updates at each round. In particular, in FedBCD liu2022fedbcd , after obtaining gradients from the server, clients update their local models using the same stale gradients for multiple steps. Building upon this, CELU-VFL celuvfl2022 improves the performance of FedBCD by caching stale gradients from earlier rounds and reusing them to better estimate the model gradients at the current round. Nonetheless, it supports only two clients (parties A and B, with B holding the labels) and cannot be directly extended to the scenarios with more than two parties considered in our study (specifically, it lacks a design to aggregate information from more parties). Flex-VFL castiglia2023flexible , on the other hand, allows each party to perform a different number of local updates, constrained by a set timeout in every round. Yet, it assumes that clients possess copies of the labels and receive local embeddings from other clients, enabling them to compute local gradients independently for multi-step local updates. In contrast, we propose an ADMM-based framework that enables multiple local updates and assumes that only the server possesses the labels, which cannot be shared with other clients due to privacy restrictions fu2022usenix .
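The stale-gradient reuse attributed to FedBCD above can be illustrated with a toy NumPy sketch. We assume, purely for illustration, a one-layer client model $h = \tanh(X\theta)$; the function name `local_updates_with_stale_grad` is hypothetical and this is not FedBCD's actual implementation. The server's message $\partial L/\partial h$ is received once per round and held fixed, while the client's own Jacobian is recomputed at every local step.

```python
import numpy as np


def local_updates_with_stale_grad(X, theta, g_stale, lr, num_steps):
    """Run `num_steps` local steps on a toy client model h = tanh(X @ theta).

    X:       (b, d_k) local features of the current batch
    theta:   (d_k, d_f) local model parameters
    g_stale: (b, d_f) gradient dL/dh received once from the server
    """
    for _ in range(num_steps):
        pre = X @ theta
        # Chain rule: dL/dtheta = X^T (dL/dh * tanh'(pre)).
        # The local Jacobian is recomputed at every step, but dL/dh stays
        # fixed (stale) until the next communication round.
        grad_theta = X.T @ (g_stale * (1.0 - np.tanh(pre) ** 2))
        theta = theta - lr * grad_theta
    return theta


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))
theta0 = rng.normal(size=(4, 3))
g = rng.normal(size=(8, 3))  # stands in for the server's message
theta5 = local_updates_with_stale_grad(X, theta0, g, lr=0.01, num_steps=5)
```

Only one server-to-client message is needed per round regardless of `num_steps`; the price is that later steps follow an increasingly outdated gradient.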

For the VFL without model splitting setting, each client submits local logits to the server, which then averages the logits and sends the gradients w.r.t. the logits back to the clients, as detailed in Fu et al. fu2022usenix . Several other approaches assume that the server shares both the labels and the aggregated logits with the clients, enabling them to compute gradients locally hu2019fdml ; zhang2022adaptive . FDML hu2019fdml performs one step of local update at each round for asynchronous and distributed SGD. Considering that certain clients might have slower local computation, AdaVFL zhang2022adaptive optimizes the number of local updates for each client at each round to minimize the overall training time. Meanwhile, CAFE jin2021catastrophic directly applies FedAvg mcmahan2016communication from Horizontal FL to VFL, where all clients possess the labels and can exchange model parameters with each other for model aggregation. This deviates from the standard VFL setup, where only the server retains the labels and local models cannot be shared owing to privacy implications fu2022usenix .

Differentially Private VFL

Among existing VFL frameworks, VAFL chen2020vafl provides a Gaussian DP guarantee dong2019gaussian , and VFL-PBM tran2023privacy quantizes local embeddings into DP integer vectors. However, they do not calculate the exact privacy budget in their evaluations. FDML hu2019fdml evaluates its framework under different levels of empirical noise, yet without offering detailed DP mechanisms or DP guarantees. The ADMM-based linear VFL framework (abbreviated as Linear-ADMM) hu2019learning provides an $(\epsilon,\delta)$-DP guarantee for linear models by calculating the closed-form sensitivity of each sample and perturbing the linear model parameters; this is not directly applicable to DNNs, whose sensitivity is hard to estimate due to nonconvexity. Instead, we propose to perturb local outputs and provide a formal client-level $(\epsilon,\delta)$-DP guarantee in Section V.
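To make "perturbing local outputs" concrete, the following is a generic clip-then-add-Gaussian-noise sketch in NumPy. It assumes that each row of a client's output matrix is one sample's local embedding or logit vector and that clipping is per sample in L2 norm; `clip_and_perturb`, `clip_norm`, and `noise_multiplier` are hypothetical names. It is not claimed to be the exact mechanism of Section V, and calibrating the noise to a target client-level $(\epsilon,\delta)$ across rounds is not shown.

```python
import numpy as np


def clip_and_perturb(outputs, clip_norm, noise_multiplier, rng):
    """Clip each row of `outputs` (one sample's local output) to L2 norm
    <= clip_norm, then add Gaussian noise N(0, (noise_multiplier*clip_norm)^2)."""
    norms = np.linalg.norm(outputs, axis=1, keepdims=True)
    # Factor is 1 for rows already inside the ball, clip_norm / norm otherwise.
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    clipped = outputs * scale
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=outputs.shape)
    return clipped + noise


rng = np.random.default_rng(0)
H = np.array([[3.0, 4.0], [0.3, 0.4]])  # row norms 5.0 and 0.5
noisy = clip_and_perturb(H, clip_norm=1.0, noise_multiplier=1.0, rng=rng)
```

Clipping bounds how much any single row can change what the server sees, which is what lets the Gaussian noise scale be tied to `clip_norm`.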

We provide an overall comparison between our work and existing studies in Table I.

III VFL with Multiple Heads (VIM)

In this section, we start with the VFL background in Section III-A, then discuss the VFL with model splitting setting and introduce our framework VIM and the ADMM-based method VIMADMM in Section III-B. Finally, in Section III-C, we show that our ADMM-based method can be easily extended to the VFL without model splitting setting with slight modifications to the communication strategies and update rules, yielding VIMADMM-J.

III-A VFL Background

Typically in VFL, there are $M$ clients who hold different feature sets of the same training samples and jointly train machine learning models. We consider the classification task and denote by $d_c$ the number of classes. Suppose there is a training dataset $D=\{x_j,y_j\}_{j=1}^{N}$ containing $N$ samples; the server owns the labels $\{y_j\}_{j=1}^{N}$, and each client $k$ has a local feature set $X_k=\{x_j^k\}_{j=1}^{N}$, where the vector $x_j^k\in\mathbb{R}^{d^k}$ denotes the local (partial) features of sample $j$. The overall feature $x_j\in\mathbb{R}^{d}$ of sample $j$ is the concatenation of all local features $\{x_j^1,x_j^2,\dots,x_j^M\}$, with $d=\sum_{k=1}^{M}d^k$.
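A toy NumPy sketch of this vertical partition is shown below. It assumes, for simplicity, that each client's features form a contiguous block of columns of the full feature matrix; in real deployments the clients' records must first be aligned by sample ID. `vertical_split` is a hypothetical helper.

```python
import numpy as np


def vertical_split(X, dims):
    """Split the columns of X (N x d) into M local feature blocks.

    dims[k] is d^k, the number of features held by client k; sum(dims) must be d.
    """
    if sum(dims) != X.shape[1]:
        raise ValueError("sum of local dimensions must equal d")
    # Column offsets where one client's block ends and the next begins.
    offsets = np.cumsum(dims)[:-1]
    return np.split(X, offsets, axis=1)


X = np.arange(5 * 9, dtype=float).reshape(5, 9)  # N = 5 samples, d = 9
parts = vertical_split(X, [3, 2, 4])             # M = 3 clients
```

Concatenating `parts` along the feature axis recovers `X`, which mirrors $x_j = [x_j^1,\dots,x_j^M]$ with $d=\sum_k d^k$.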

Due to the privacy protection requirement of VFL, each client $k$ does not share its raw local feature set $X_k$ with other clients or the server. Instead, VFL consists of two steps: (1) local processing step: each client learns a local model that maps the local features to local outputs and sends them to the server; (2) server aggregation step: the server aggregates the local outputs from all clients to compute the final prediction for each sample as well as the corresponding losses. Depending on whether or not the server holds a model, there are two popular VFL settings fu2022usenix : VFL with model splitting chen2020vafl ; vepakomma2018split and VFL without model splitting hu2019fdml . (i) In the model splitting setting, each client trains a feature extractor as the local model that outputs local embeddings, and the server owns a model that predicts the final results based on the aggregated embeddings. (ii) In the VFL without model splitting setting, the clients host the entire model that outputs the local logits, and the server simply aggregates the logits without hosting any model.

In both settings, the local model is updated based on SGD with federated backward propagation fu2022usenix : a) the server first computes the gradients w.r.t. the local outputs (either embeddings or logits) of each client separately and sends these gradients back to the clients; b) each client calculates the gradients of its local outputs w.r.t. its local model parameters and updates the local model via the chain rule.
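Steps a) and b) can be sketched in NumPy for a deliberately simple case that we assume for illustration: linear client feature extractors $h^k = X_k\theta_k$, concatenated embeddings, a linear server model `W0`, and mean softmax cross-entropy. The function names `server_step` and `client_step` are hypothetical. Note that the server sends each client only the slice of $\partial L/\partial H$ that corresponds to that client's own embedding.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def server_step(H, W0, y):
    """Server: H (b, M*d_f) concatenated embeddings, W0 (M*d_f, d_c) linear
    server model, y (b,) integer labels. Returns mean cross-entropy and dL/dH."""
    b = H.shape[0]
    P = softmax(H @ W0)
    loss = -np.mean(np.log(P[np.arange(b), y]))
    G = P.copy()
    G[np.arange(b), y] -= 1.0
    G /= b                     # dL/dlogits of the mean loss
    return loss, G @ W0.T      # dL/dH


def client_step(X_k, g_k):
    """Client k with a linear extractor h^k = X_k @ theta_k:
    chain rule gives dL/dtheta_k = X_k^T @ dL/dh^k."""
    return X_k.T @ g_k


rng = np.random.default_rng(0)
b, d_f, d_c = 4, 3, 5
Xs = [rng.normal(size=(b, 2)), rng.normal(size=(b, 3))]         # M = 2 clients
thetas = [rng.normal(size=(2, d_f)), rng.normal(size=(3, d_f))]
W0 = rng.normal(size=(2 * d_f, d_c))
y = np.array([0, 1, 2, 3])

H = np.concatenate([X @ t for X, t in zip(Xs, thetas)], axis=1)  # clients -> server
loss, G = server_step(H, W0, y)
# Server -> client k: only the d_f columns of dL/dH belonging to h^k.
grads = [client_step(X, G[:, k * d_f:(k + 1) * d_f]) for k, X in enumerate(Xs)]
```

Each call to `server_step` corresponds to one round trip per training step, which is the communication pattern the rest of this section aims to reduce.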

III-B VFL with Model Splitting

Setup

Let $f$ parameterized by $\theta_k$ be the local model (i.e., feature extractor) of client $k$, which outputs a local embedding vector $h_j^k=f(x_j^k;\theta_k)\in\mathbb{R}^{d_f}$ for each local feature $x_j^k$. We denote the parameters of the server-side model as $\theta_0$. Overall, the clients and the server aim to collaboratively solve the Empirical Risk Minimization (ERM) objective:

$$
\min_{\{\theta_k\},\theta_0}\ \frac{1}{N}\sum_{j=1}^{N}\ell\big(\{h_j^1,\ldots,h_j^M\},y_j;\theta_0\big)+\sum_{k=1}^{M}\beta_k\mathcal{R}(\theta_k)+\beta\mathcal{R}(\theta_0) \tag{1}
$$

where $\ell$ is a loss function (e.g., cross-entropy loss with softmax), $\mathcal{R}$ is a regularizer on model parameters, $\beta_k\in\mathbb{R}$ is the regularization weight for client $k$, and $\beta$ is the weight for the server. The local embeddings of each sample $j$ can be either concatenated, $h_j=[h_j^1,\ldots,h_j^M]$, as in Split Learning vepakomma2018split , or averaged, $h_j=\sum_{k=1}^{M}\alpha_k h_j^k$ with aggregation weights $\alpha_k\in\mathbb{R}$, as in VAFL chen2020vafl . Then $h_j$ is used as the input to the server model $\theta_0$ to calculate the loss. For a more detailed description of the Split Learning training algorithm under VFL with model splitting, please refer to Algorithm 2 in Section -A1.
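The two aggregation choices differ in the shape of the server's input, as the following NumPy sketch shows. The helper names are hypothetical and the embeddings are toy values.

```python
import numpy as np


def aggregate_concat(embeddings):
    """Split Learning style: h_j = [h_j^1, ..., h_j^M]; output (b, M*d_f)."""
    return np.concatenate(embeddings, axis=1)


def aggregate_weighted_avg(embeddings, alphas):
    """VAFL style: h_j = sum_k alpha_k h_j^k; output (b, d_f)."""
    return sum(a * h for a, h in zip(alphas, embeddings))


embs = [np.ones((2, 3)), 2 * np.ones((2, 3))]  # M = 2 clients, b = 2, d_f = 3
h_cat = aggregate_concat(embs)
h_avg = aggregate_weighted_avg(embs, [0.5, 0.5])
```

Concatenation makes the server input grow with $M$, while averaging keeps it at $d_f$ but, as noted in Section II, may lose the unique properties of each client.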

However, as outlined in Section III-A, VFL methods such as Split Learning and VAFL are based on SGD and rely on the server model $\theta_0$ to compute the loss and gradients with the server's labels before the local models $\{\theta_k\}$ can be updated. Consequently, the server needs to send the gradients w.r.t. the embeddings back to the clients at every training step of the local models. Such (1) frequent communication and (2) high dimensionality of the gradients (i.e., $bd_f$ for $b$ samples) lead to high communication costs.
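As a worked example of point (2), take hypothetical values $b=256$ and $d_f=512$ (illustrative only, not taken from our experiments) with 32-bit floats. For comparison, we also evaluate the ADMM-related payload $(2b+d_f)d_c$ stated for VIMADMM in Section III-B, assuming a hypothetical $d_c=10$ classes:

$$
b\,d_f = 256 \times 512 = 131{,}072\ \text{floats} = 512\ \text{KiB per client per step},
$$

$$
(2b+d_f)\,d_c = (2\times 256 + 512)\times 10 = 10{,}240\ \text{floats} = 40\ \text{KiB}.
$$

Because $d_c \ll d_f, b$ in such settings, the second quantity is much smaller, and VIMADMM additionally exchanges it less often since clients perform multiple local updates per round.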

VIM Formulation

To address these challenges, we propose the VIM framework, in which the server learns a model with multiple heads corresponding to the multiple local clients. It takes the separate contribution of each client into account and facilitates the breakdown of the VFL optimization into several sub-problems that the clients and the server solve independently via ADMM without communicating gradients, as we elaborate later. Specifically, the server's model $\theta_0$ consists of $M$ heads $W_1,W_2,\dots,W_M$, where $W_k\in\mathbb{R}^{d_f\times d_c}$, $k\in[M]$. For simplicity, we consider each $W_k$ to be a linear head here; our formulation can be easily extended to non-linear heads by viewing each $W_k$ as a non-linear model (see the end of Section III-B for more details). This choice is motivated by recent studies in representation learning showing that a linear classifier on top of embedding representations suffices to accurately predict the labels radford2021learning ; khosla2020supervised , given a local feature extractor expressive enough to capture the essential information in the raw feature sets.
For sample $j$, the server's model outputs the prediction $\hat{y}_j=\sum_{k=1}^{M}h_j^kW_k$, yielding our VIM objective:

$$
\begin{aligned}
\min_{\{W_k\},\{\theta_k\}}\ \mathcal{L}_{\texttt{VIM}}(\{W_k\},\{\theta_k\}) &:= \frac{1}{N}\sum_{j=1}^{N}\ell\Big(\sum_{k=1}^{M}f(x_j^k;\theta_k)W_k,\ y_j\Big) \\
&\quad+\sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k)+\sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k)
\end{aligned}
\tag{2}
$$
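The multi-head prediction $\hat{y}_j=\sum_{k}h_j^kW_k$ can be sketched in NumPy as follows; `vim_predict` is a hypothetical name and the shapes are toy values. The sketch also checks a simple identity: with linear heads, the prediction equals a single linear layer applied to the concatenated embeddings with the heads stacked vertically.

```python
import numpy as np


def vim_predict(embeddings, heads):
    """VIM prediction y_hat_j = sum_k h_j^k W_k.

    embeddings[k]: (b, d_f) local embeddings h^k from client k
    heads[k]:      (d_f, d_c) linear head W_k kept by the server for client k
    """
    return sum(h @ W for h, W in zip(embeddings, heads))


rng = np.random.default_rng(1)
M, b, d_f, d_c = 3, 4, 6, 5
embs = [rng.normal(size=(b, d_f)) for _ in range(M)]
heads = [rng.normal(size=(d_f, d_c)) for _ in range(M)]
y_hat = vim_predict(embs, heads)

# Same logits as one linear layer on concatenated embeddings with stacked heads.
y_cat = np.concatenate(embs, axis=1) @ np.vstack(heads)
```

In this reading, what the multi-head form adds is that each block $W_k$ is explicitly tied to client $k$, which the ADMM decomposition and the head-based client explanation make use of.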
VIMADMM

Based on the VIM formulation, we propose an ADMM-based method that reduces the communication frequency by allowing the clients to perform multiple local updates w.r.t. their ADMM objectives at each round, and reduces the dimensionality by only exchanging ADMM-related variables (i.e., $(2b+d_f)d_c$ values for $b$ samples, where $d_c\ll d_f, b$ in most VFL settings today chen2020vafl ; hu2019fdml ). Specifically, we note that Eq. (2) can be viewed as a sharing problem boyd2011distributed , in which each client adjusts its variables to minimize the shared cost term $\ell(\sum_{k=1}^{M}h_j^kW_k, y_j)$ as well as its individual cost $\mathcal{R}_k(\theta_k)+\mathcal{R}_k(W_k)$. Moreover, the multiple heads in VIM enable the application of ADMM via a special decomposition into simpler sub-problems that can be solved in a distributed manner. We begin by rewriting Eq. (2) as an equivalent constrained optimization problem by introducing auxiliary variables $z_1,z_2,\dots,z_N\in\mathbb{R}^{d_c}$:

$$
\begin{aligned}
\min_{\{W_k\},\{\theta_k\},\{z_j\}}\ & \frac{1}{N}\sum_{j=1}^{N}\ell(z_j,y_j)+\sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k)+\sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k) \\
\text{s.t.}\ & \sum_{k=1}^{M}f(x_j^k;\theta_k)W_k-z_j=0,\quad \forall j\in[N].
\end{aligned}
\tag{3}
$$

Notably, each constraint enforces consensus between the server's output $\sum_{k=1}^{M} h_j^k W_k$ and the auxiliary variable $z_j$ for each sample $j$. The augmented Lagrangian, which adds a quadratic penalty term to the Lagrangian of Eq. 3, is given by:

$$
\begin{aligned}
\min_{\{W_k\},\{\theta_k\},\{z_j\},\{\lambda_j\}}\ & \mathcal{L}_{\mathrm{ADMM}}(\{W_k\},\{\theta_k\},\{z_j\},\{\lambda_j\}) \\
:= {} & \frac{1}{N}\sum_{j=1}^{N}\ell(z_j, y_j) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k) \\
& + \frac{1}{N}\sum_{j=1}^{N}\lambda_j^{\top}\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k)W_k - z_j\Big) \\
& + \frac{\rho}{2N}\sum_{j=1}^{N}\Big\|\sum_{k=1}^{M} f(x_j^k;\theta_k)W_k - z_j\Big\|_F^2,
\end{aligned} \tag{4}
$$

where $\lambda_j \in \mathbb{R}^{d_c}$ is the dual variable for sample $j$, and $\rho \in \mathbb{R}^{+}$ is a constant penalty factor. Recall that $\hat{y}_j = \sum_{k=1}^{M} f(x_j^k;\theta_k)W_k$ is the server output (i.e., prediction) for sample $x_j$. VIMADMM essentially aims to minimize both the loss between $z_j$ and the ground-truth label $y_j$ and the difference between $z_j$ and $\hat{y}_j$ during training. Specifically, in the ADMM loss (Eq. 4), $\ell(z_j, y_j)$ is the loss between $z_j$ and $y_j$, while $\hat{y}_j - z_j = \sum_{k=1}^{M} f(x_j^k;\theta_k)W_k - z_j$ appears in the linear constraint term and the quadratic penalty term. The auxiliary variables $\{z_j\}$ and dual variables $\{\lambda_j\}$ are used to facilitate the training of the server heads $\{W_k\}$ and local models $\{\theta_k\}$.
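To make the terms of Eq. 4 concrete, the following NumPy sketch evaluates the augmented Lagrangian on a set of samples. It assumes, for illustration only, that $\ell$ is softmax cross-entropy, that $\mathcal{R}_k(W_k) = \|W_k\|_F^2$, and that the embeddings $h_j^k = f(x_j^k;\theta_k)$ are already computed, so the $\theta_k$ regularizer is dropped; the function names are hypothetical.

```python
import numpy as np


def cross_entropy(z, y):
    """Softmax cross-entropy between a logit vector z and an integer label y."""
    m = z.max()                                   # shift for numerical stability
    return m + np.log(np.exp(z - m).sum()) - z[y]


def admm_lagrangian(H_list, W_list, Z, Lam, y, rho, betas):
    """Augmented Lagrangian of Eq. 4 with R_k(W) = ||W||_F^2 and the theta terms dropped.

    H_list[k]: (N, d_f) embeddings f(x_j^k; theta_k) of client k, one row per sample
    W_list[k]: (d_f, d_c) server head W_k
    Z, Lam:    (N, d_c) auxiliary variables z_j and dual variables lambda_j
    y:         (N,) integer labels
    """
    N = Z.shape[0]
    Y_hat = sum(H @ W for H, W in zip(H_list, W_list))  # server outputs y_hat_j
    gap = Y_hat - Z                                     # consensus residual y_hat_j - z_j
    data = np.mean([cross_entropy(Z[j], y[j]) for j in range(N)])
    reg = sum(beta * np.sum(W ** 2) for beta, W in zip(betas, W_list))
    linear = np.sum(Lam * gap) / N                      # (1/N) sum_j lambda_j^T (y_hat_j - z_j)
    quad = rho / (2 * N) * np.sum(gap ** 2)             # (rho/2N) sum_j ||y_hat_j - z_j||^2
    return data + reg + linear + quad
```

When every constraint holds ($z_j = \hat{y}_j$), the linear and quadratic terms vanish and the value reduces to the regularized loss of Eq. 3.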

To solve Eq. 4, we follow standard ADMM boyd2011distributed and update the primal variables $\{W_k\}$, $\{\theta_k\}$, $\{z_j\}$ and the dual variables $\{\lambda_j\}$ alternately, which decomposes the problem in Eq. 4 into four sets of sub-problems over $\{W_k\}$, $\{\theta_k\}$, $\{z_j\}$, and $\{\lambda_j\}$; the parameters within each sub-problem can be solved in parallel. In practice, we adopt the following strategy for the alternating updates between the server and the clients: (i) update $\{z_j\}$, $\{\lambda_j\}$, and $\{W_k\}$ on the server side, and (ii) update $\{\theta_k\}$ on the client side in parallel. Moreover, we consider the realistic setting of stochastic ADMM with mini-batches. Concretely, at communication round $t$, the server samples a set of data indices $B(t)$ with batch size $|B(t)| = b$. We then describe the key steps of VIMADMM as follows:

(1) Communication from client to server. Each client $k$ sends a batch of embeddings $\{{h_j^k}^{(t)}\}_{j\in B(t)}$ to the server, where ${h_j^k}^{(t)} = f(x_j^k;\theta_k^{(t)})$, $\forall j\in B(t)$.

(2) Server updates auxiliary variables $\{z_j\}$. After receiving the local embeddings from all clients, the server updates the auxiliary variable for each sample $j\in B(t)$ as:

$$
z_j^{(t)} = \operatorname*{argmin}_{z_j}\ \ell(z_j, y_j) - {\lambda_j^{(t-1)}}^{\top} z_j + \frac{\rho}{2}\Big\|\sum_{k=1}^{M} {h_j^k}^{(t)} W_k^{(t)} - z_j\Big\|_F^2. \tag{5}
$$

Since the optimization problem in Eq. 5 is convex and differentiable with respect to $z_j$, we use the L-BFGS-B algorithm zhu1997algorithm to solve the minimization problem.
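A minimal sketch of the per-sample update of Eq. 5 with SciPy's L-BFGS-B solver, assuming $\ell$ is softmax cross-entropy over $d_c$ logits and that $\hat{y}_j = \sum_k {h_j^k}^{(t)} W_k^{(t)}$ has already been computed; `update_z` is a hypothetical name. Supplying the analytic gradient $\mathrm{softmax}(z) - e_{y_j} - \lambda_j^{(t-1)} - \rho(\hat{y}_j - z)$ avoids finite differences.

```python
import numpy as np
from scipy.optimize import minimize


def update_z(y_hat_j, lam_j, y_j, rho):
    """Eq. 5 for one sample, assuming ell = softmax cross-entropy:
    argmin_z  CE(z, y_j) - lam_j^T z + (rho / 2) * ||y_hat_j - z||^2.
    """
    onehot = np.zeros_like(y_hat_j)
    onehot[y_j] = 1.0

    def objective(z):
        m = z.max()
        e = np.exp(z - m)
        p = e / e.sum()                               # softmax(z)
        ce = m + np.log(e.sum()) - z[y_j]             # log-sum-exp form of CE
        diff = y_hat_j - z
        value = ce - lam_j @ z + 0.5 * rho * diff @ diff
        grad = (p - onehot) - lam_j - rho * diff      # analytic gradient w.r.t. z
        return value, grad

    # Warm-start from the current server output; jac=True tells SciPy that
    # `objective` returns both the value and the gradient.
    result = minimize(objective, x0=y_hat_j.copy(), jac=True, method="L-BFGS-B")
    return result.x
```

Warm-starting at $\hat{y}_j$ is a design choice: near consensus the optimum stays close to the server output, so few iterations are typically needed.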

(3) Server updates dual variables $\{\lambda_j\}$. The server updates the dual variable for each sample $j\in B(t)$:

$$
\lambda_j^{(t)} = \lambda_j^{(t-1)} + \rho\Big(\sum_{k=1}^{M} {h_j^k}^{(t)} W_k^{(t)} - z_j^{(t)}\Big). \tag{6}
$$
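The dual step of Eq. 6 is a single vectorized update over the batch. The sketch below assumes row-stacked arrays (one row per $j \in B(t)$) and a hypothetical function name. A standard ADMM property gives a cheap sanity check: if Eq. 5 is solved exactly, its optimality condition $\nabla_z \ell(z_j^{(t)}, y_j) = \lambda_j^{(t-1)} + \rho(\hat{y}_j - z_j^{(t)})$ implies that the updated dual equals $\nabla_z \ell(z_j^{(t)}, y_j)$.

```python
import numpy as np


def update_dual(Lam_prev, H_list, W_list, Z, rho):
    """Eq. 6 for a mini-batch: lambda_j <- lambda_j + rho * (sum_k h_j^k W_k - z_j).

    Every array has one row per sample j in B(t);
    H_list[k] is (b, d_f) and W_list[k] is (d_f, d_c).
    """
    Y_hat = sum(H @ W for H, W in zip(H_list, W_list))  # server outputs y_hat_j
    return Lam_prev + rho * (Y_hat - Z)
```

Once consensus $z_j = \hat{y}_j$ is reached, the dual variables stop changing.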

(4) Server updates the heads $\{W_k\}$. Each head $W_k$, $\forall k\in[M]$, of the server is then updated as:

$$
\begin{aligned}
W_k^{(t+1)} = \operatorname*{argmin}_{W_k}\ & \beta_k\mathcal{R}_k(W_k) + \frac{1}{b}\sum_{j\in B(t)} {\lambda_j^{(t)}}^{\top} {h_j^k}^{(t)} W_k \\
& + \sum_{j\in B(t)} \frac{\rho}{2b}\Big\|\sum_{i\in[M],\, i\neq k} {h_j^i}^{(t)} W_i^{(t)} + {h_j^k}^{(t)} W_k - z_j^{(t)}\Big\|_F^2.
\end{aligned} \tag{7}
$$

For a squared $\ell_2$ regularizer $\mathcal{R}$, we solve for $W_k^{(t+1)}$ inexactly to save computation, taking one step of SGD on the objective of Eq. 7.
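With $\mathcal{R}_k(W_k) = \|W_k\|_F^2$, the gradient of Eq. 7 is $2\beta_k W_k + \frac{1}{b} H_k^{\top}\Lambda + \frac{\rho}{b} H_k^{\top}\big(\sum_{i\neq k} H_i W_i^{(t)} + H_k W_k - Z\big)$, where $H_k$, $\Lambda$, and $Z$ stack ${h_j^k}^{(t)}$, ${\lambda_j^{(t)}}^{\top}$, and ${z_j^{(t)}}^{\top}$ row-wise over $j \in B(t)$. A minimal NumPy sketch of one such step follows; the row-stacking convention, the learning rate `lr`, and the function names are illustrative assumptions.

```python
import numpy as np


def _others(k, H_list, W_list):
    """sum_{i != k} h_j^i W_i for every row j (0.0 when there is only one client)."""
    return sum(H @ W for i, (H, W) in enumerate(zip(H_list, W_list)) if i != k)


def head_objective(W_k, k, H_list, W_list, Z, Lam, rho, beta_k):
    """Objective of Eq. 7, assuming R_k(W) = ||W||_F^2; rows index j in B(t)."""
    b = Z.shape[0]
    H = H_list[k]
    r = _others(k, H_list, W_list) + H @ W_k - Z
    return (beta_k * np.sum(W_k ** 2) + np.sum(Lam * (H @ W_k)) / b
            + rho / (2 * b) * np.sum(r ** 2))


def head_sgd_step(k, H_list, W_list, Z, Lam, rho, beta_k, lr):
    """One inexact gradient step on Eq. 7, starting from the current head W_k^(t)."""
    b = Z.shape[0]
    H, W_k = H_list[k], W_list[k]
    r = _others(k, H_list, W_list) + H @ W_k - Z   # consensus residual with W_k plugged in
    grad = 2 * beta_k * W_k + H.T @ Lam / b + rho / b * H.T @ r
    return W_k - lr * grad
```

Because the objective is a convex quadratic in $W_k$, any step size below $2/L$ (with $L$ the Lipschitz constant of the gradient) decreases it, which is what makes a single cheap step a reasonable inexact update.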

(5) Communication from server to client. After the updates in Eq. 7, we define a residual variable ${s_j^k}^{(t)}$ for each sample $j\in B(t)$ of the $k$-th client, which provides supervision for updating the local model:

$$
{s_j^k}^{(t)} \triangleq z_j^{(t)} - \sum_{i\in[M],\, i\neq k} {h_j^i}^{(t)} W_i^{(t+1)} \tag{8}
$$

The server sends the dual variables $\{\lambda_j^{(t)}\}_{j\in B(t)}$ and the residual variables $\{{s_j^k}^{(t)}\}_{j\in B(t)}$ of all samples in the batch, as well as the corresponding head $W_k^{(t+1)}$, to each client $k$.
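The residuals of Eq. 8 for all $M$ clients can be formed without re-summing the other $M-1$ clients for each $k$: subtract the full sum once and add client $k$'s own term back. A minimal NumPy sketch, assuming row-stacked arrays over $j \in B(t)$ and a hypothetical function name:

```python
import numpy as np


def residuals(Z, H_list, W_new_list):
    """Eq. 8 for every client: s^k = Z - sum_{i != k} H^i W_i^(t+1) (rows index j in B(t)).

    Subtracting the full sum once and adding client k's own term back needs
    only M matrix products instead of M * (M - 1).
    """
    total = sum(H @ W for H, W in zip(H_list, W_new_list))
    return [Z - total + H @ W for H, W in zip(H_list, W_new_list)]
```

Client $k$ then needs only its own $s^k$, the duals, and $W_k^{(t+1)}$ to evaluate its local objective in Eq. 9.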

(6) Client updates local model parameters $\theta_k$. Finally, every client $k$ locally updates its model parameters $\theta_k$ as follows:

$$
\begin{aligned}
\theta_k^{(t+1)} = \operatorname*{argmin}_{\theta_k}\ & \beta_k\mathcal{R}_k(\theta_k) + \frac{1}{b}\sum_{j\in B(t)} {\lambda_j^{(t)}}^{\top} f(x_j^k;\theta_k)\, W_k^{(t+1)} \\
& + \frac{\rho}{2b}\sum_{j\in B(t)}\Big\|{s_j^k}^{(t)} - f(x_j^k;\theta_k)\, W_k^{(t+1)}\Big\|_F^2.
\end{aligned} \tag{9}
$$

Due to the nonconvexity of the loss function of DNNs, we use $\tau$ local steps of SGD to update the local model at each round with the objective of Eq. 9. We note that the multiple local updates of Eq. 9 enabled by ADMM lead to better local models at each communication round than gradient-based methods, so VIMADMM requires fewer communication rounds to converge, as we will show in Section VI-A. These six steps of VIMADMM are summarized in Algorithm 1.
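To illustrate the $\tau$ local steps, the sketch below replaces the DNN with a hypothetical linear local model $f(x;\theta) = x\theta$ so that the gradient of Eq. 9 has a closed form; with a real network the same objective could be differentiated by an autograd framework instead. It also assumes $\mathcal{R}_k(\theta_k) = \|\theta_k\|_F^2$ and takes full gradient steps on the fixed mini-batch $B(t)$; all names are illustrative.

```python
import numpy as np


def client_objective(theta, X, W_k, S_k, Lam, rho, beta_k):
    """Eq. 9 for a hypothetical linear local model f(x; theta) = x @ theta, R_k = ||.||_F^2."""
    b = X.shape[0]
    P = X @ theta @ W_k                          # f(x_j^k; theta) W_k^(t+1), one row per j
    return (beta_k * np.sum(theta ** 2) + np.sum(Lam * P) / b
            + rho / (2 * b) * np.sum((S_k - P) ** 2))


def client_local_update(theta, X, W_k, S_k, Lam, rho, beta_k, lr, tau):
    """tau gradient steps on Eq. 9; W_k, S_k, and Lam stay fixed for the whole round."""
    b = X.shape[0]
    for _ in range(tau):
        P = X @ theta @ W_k
        grad = (2 * beta_k * theta
                + X.T @ Lam @ W_k.T / b                 # from the linear dual term
                + rho / b * X.T @ (P - S_k) @ W_k.T)    # from the quadratic penalty
        theta = theta - lr * grad
    return theta
```

Since $W_k^{(t+1)}$, ${s_j^k}^{(t)}$, and $\lambda_j^{(t)}$ are fixed during the round, the client can take all $\tau$ steps without further communication.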

Note that the ADMM auxiliary variables $\{z_j\}$ and dual variables $\{\lambda_j\}$ are only used during the training phase to help update the server heads and local models. Therefore, in the test phase, for any sample $x$, the server directly uses the trained heads to make the prediction $\hat{y} = \sum_{k=1}^{M} h^k W_k$.

Algorithm 1 VIMADMM (with user-level differential privacy)
1:  Input: number of communication rounds $T$, number of clients $M$, number of training samples $N$, batch size $b$, input features $\{\{x_j^1\}_{j=1}^N, \{x_j^2\}_{j=1}^N, \ldots, \{x_j^M\}_{j=1}^N\}$, labels $\{y_j\}_{j=1}^N$, local models $\{\theta_k\}_{k=1}^M$; linear heads $\{W_k\}_{k=1}^M$; auxiliary variables $\{z_j\}_{j=1}^N$; dual variables $\{\lambda_j\}_{j=1}^N$; noise parameter $\sigma$, clipping constant $C$
2:  for communication round $t\in[T]$ do
3:     Server samples a set of data indices $B(t)$ with $|B(t)| = b$
4:     for client $k\in[M]$ do
5:        generates a local training batch $\{x_j^k\}_{j\in B(t)}$
6:        computes local embeddings $\{{h_j^k}^{(t)} \leftarrow f(x_j^k;\theta_k)\}_{j\in B(t)}$
7:        clips and perturbs local embedding matrix
8:        $\{{h_j^k}^{(t)}\}_{j\in B(t)} \leftarrow \mathtt{Clip}\big(\{{h_j^k}^{(t)}\}_{j\in B(t)}, C\big) + \mathcal{N}(0, \sigma^2 C^2)$
9:        sends local embeddings $\{{h_j^k}^{(t)}\}_{j\in B(t)}$ to the server
10:     Server updates auxiliary variables $\{z_j^{(t)}\}_{j\in B(t)}$ by Eq. 5
11:     Server updates dual variables $\{\lambda_j^{(t)}\}_{j\in B(t)}$ by Eq. 6
12:     Server updates linear heads $\{W_k^{(t+1)}\}_{k\in[M]}$ by Eq. 7
13:     Server computes residual variables $\{{s_j^k}^{(t)}\}_{j\in B(t),\, k\in[M]}$ by Eq. 8
14:     Server sends $\{\lambda_j^{(t)}\}_{j\in B(t)}$, $\{{s_j^k}^{(t)}\}_{j\in B(t)}$, and the corresponding $W_k^{(t+1)}$ to each client $k$, $\forall k\in[M]$
15:     for client $k\in[M]$ do
16:        updates local model θk(t+1)superscriptsubscript𝜃𝑘𝑡1\theta_{k}^{(t+1)}italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t + 1 ) end_POSTSUPERSCRIPT for τ𝜏\tauitalic_τ steps by Eq. III-B via SGD
Extending VIMADMM to multiple non-linear heads

The server can learn a non-linear transformation from the collected embeddings to the auxiliary variables $\{z_j\}$ by employing multiple non-linear heads. To achieve this, we rewrite every $f(x_j^k,\theta_k)W_k$ from Eq. III-B to Eq. III-B in the more general form $g(f(x_j^k,\theta_k),W_k)$, where $g$ can be a non-linear function parameterized by $W_k$. Consequently, the prediction for each sample $j$ becomes $\hat{y}_j=\sum_{k=1}^{M} g(f(x_j^k,\theta_k),W_k)$.
In this setting, VIMADMM still minimizes the loss between $z_j$ and the ground-truth label $y_j$, as well as the difference between $z_j$ and $\hat{y}_j$, during training (Eq. III-B).
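As a minimal illustration of the generalized prediction $\hat{y}_j=\sum_{k=1}^{M} g(f(x_j^k,\theta_k),W_k)$, the NumPy sketch below contrasts linear heads with non-linear heads. The concrete forms are assumptions for illustration only: `local_embedding` (a ReLU layer standing in for $f$), `mlp_head` (a hypothetical two-layer tanh MLP standing in for $g$), and all dimensions are hypothetical and not taken from the paper's implementation.

```python
import numpy as np


def local_embedding(x, theta):
    """Hypothetical client model f(x; theta): one linear layer followed by ReLU."""
    return np.maximum(x @ theta, 0.0)


def linear_head(h, W):
    """Linear head: f(x; theta) W."""
    return h @ W


def mlp_head(h, W):
    """Hypothetical non-linear head g(h, W_k): a two-layer tanh MLP with W = (W1, W2)."""
    W1, W2 = W
    return np.tanh(h @ W1) @ W2


def predict(xs, thetas, heads, head_fn):
    """y_hat_j = sum_k g(f(x_j^k; theta_k), W_k), computed for a batch of samples."""
    return sum(head_fn(local_embedding(x, th), W)
               for x, th, W in zip(xs, thetas, heads))


rng = np.random.default_rng(0)
n, d_emb, d_hidden, d_c = 32, 8, 16, 10
feat_dims = [4, 5, 6]  # M = 3 clients, each holding a vertical slice of the features
xs = [rng.normal(size=(n, d)) for d in feat_dims]
thetas = [rng.normal(size=(d, d_emb)) for d in feat_dims]
lin_heads = [rng.normal(size=(d_emb, d_c)) for _ in feat_dims]
mlp_heads = [(rng.normal(size=(d_emb, d_hidden)), rng.normal(size=(d_hidden, d_c)))
             for _ in feat_dims]

y_hat_lin = predict(xs, thetas, lin_heads, linear_head)
y_hat_mlp = predict(xs, thetas, mlp_heads, mlp_head)
print(y_hat_lin.shape, y_hat_mlp.shape)  # (32, 10) (32, 10)
```

With `linear_head`, the sum reduces to the linear-head prediction $\sum_k f(x_j^k,\theta_k)W_k$; swapping in `mlp_head` only changes how each head maps an embedding to the $d_c$-dimensional output, while the server still sums the per-client contributions.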

III-C VFL without Model Splitting

Setup

Recall the VFL without model splitting setting described in § III-A. Let $g$ parameterized by $\psi_k$ be the local model (i.e., whole model) of client $k$, which outputs local logits $o_j^k=g(x_j^k;\psi_k)\in\mathbb{R}^{d_c}$ for each local feature $x_j^k$. The clients and the server aim to jointly solve the problem

$$\min_{\{\psi_k\}_{k=1}^{M}} \frac{1}{N}\sum_{j=1}^{N} \ell\left(\{o_j^1,\ldots,o_j^M\}, y_j\right) + \sum_{k=1}^{M} \beta_k \mathcal{R}_k(\psi_k) \tag{10}$$
VIMADMM-J

In existing VFL frameworks, the server aggregates the local logits into the final prediction $\sum_{k=1}^{M} o_j^k$, but these methods also suffer from high communication cost because the server must send the gradients w.r.t. the local logits to each client at every training step of the local models fu2022usenix . To address this problem within our VIM framework, we adapt VIMADMM to the setting without model splitting and propose VIMADMM-J, in which each feature extractor $\theta_k$ and each head $W_k$ are held by the corresponding client $k$ and are always updated locally. The corresponding Algorithm 3 and a detailed description are given in Appendix -A.
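To make the per-step communication of such logit-summing baselines concrete, the sketch below assumes a softmax cross-entropy loss on $\sum_{k=1}^{M} o_j^k$ (an assumption; the loss $\ell$ in Eq. 10 is generic). Because the prediction is a plain sum, the gradient with respect to every client's logits $o_j^k$ equals the gradient with respect to the summed logits, and the server has to return it to all $M$ clients after every local training step. This illustrates the baseline only, not VIMADMM-J, and `server_step` is a hypothetical name.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def server_step(local_logits, y):
    """One training step of a logit-summing VFL baseline (illustrative sketch).

    local_logits: list of M arrays of shape (n, d_c), one per client (o_j^k).
    y: integer labels of shape (n,).
    Returns the mean cross-entropy loss and the gradients w.r.t. each o_j^k,
    which the server must send back to every client at every step.
    """
    n = y.shape[0]
    p = softmax(sum(local_logits))              # prediction from sum_k o_j^k
    loss = -np.log(p[np.arange(n), y]).mean()
    grad = p.copy()
    grad[np.arange(n), y] -= 1.0
    grad /= n                                   # d(mean CE) / d(summed logits)
    # The prediction is a plain sum, so d loss / d o^k is the same for every client k.
    return loss, [grad.copy() for _ in local_logits]


rng = np.random.default_rng(1)
logits = [rng.normal(size=(4, 3)) for _ in range(2)]  # M = 2 clients, 4 samples, d_c = 3
y = np.array([0, 2, 1, 2])
loss, grads = server_step(logits, y)
```

Since these gradients change whenever any local model changes, a baseline of this kind needs one round trip per local step, which is the cost VIMADMM-J is designed to avoid.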

IV Convergence Analysis for VIMADMM

In this section, we provide a convergence guarantee for VIMADMM, which is non-trivial because of the alternating optimization over four sets of parameters $\{W_k\},\{\theta_k\},\{z_j\},\{\lambda_j\}$. To convey the salient ideas of the analysis, we consider the full-batch setting, i.e., $B(t)=[N]$, and use the exact minimization solutions during training (Eqs. 5, 6, and III-B), following hong2016convergence .

We present our main results below and defer formal proofs to Section -B due to space constraints.

Theorem 1.

Assume that $\mathcal{L}_{\texttt{VIM}}$ is bounded from below, that is, $\underline{e}:=\min_{\{\theta_k\},\{W_k\}}\mathcal{L}_{\texttt{VIM}}(\{\theta_k\},\{W_k\})>-\infty$. Assume that $\ell(z;\cdot)$ is $L$-Lipschitz smooth w.r.t. $z$ and that the $\mathcal{L}_{\mathrm{ADMM}}$ loss is strongly convex w.r.t. $\{z_j\}$, $\{W_k\}$, and $\{\theta_k\}$ with constants $\mu_z$, $\mu_W$, and $\mu_\theta$, respectively.
Assume that the norm of $W_k$ is bounded, $\|W_k\|\le\sigma_W$, and that the local model $f(\cdot;\theta)$ has bounded gradient $\|\nabla f(\cdot;\theta)\|\le L_\theta$ and bounded output norm $\|f(\cdot;\theta)\|\le\sigma_\theta$. If Algorithm 1 is run and there exists a $\rho$ satisfying $\max\{L,\frac{2L^2}{\mu_z}\}<\rho<\min\{\frac{\mu_\theta}{L_\theta^2\sigma_W^2},\frac{\mu_W}{\sigma_\theta^2}\}$, then the following hold:
(A) The $\mathcal{L}_{\mathrm{ADMM}}$ loss is monotonically decreasing and lower-bounded:

$$\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) < \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \tag{11}$$
$$\lim_{t\to\infty}\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \ge \underline{e} \tag{12}$$

(B) Let $(\{W_k^*\},\{\theta_k^*\},\{z_j^*\},\{\lambda_j^*\})$ denote any limit point of the sequence $(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\})$ generated by Algorithm 1. Then it is stationary:

$$\begin{aligned}
& z_j^* \in \arg\min_{z_j} \ell(z_j;y_j) + {\lambda_j^*}^{\top}\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^*)W_k^* - z_j\Big) \text{ and}\\
& \sum_{k=1}^{M} f(x_j^k;\theta_k^*)W_k^* = z_j^*,\ \forall j\in[N], \text{ and}\\
& \beta_k\nabla\mathcal{R}_k(W_k^*) + \frac{1}{N}\sum_{j=1}^{N}\lambda_j^{*\top} f(x_j^k;\theta_k^*) = 0 \text{ and}\\
& \beta_k\nabla\mathcal{R}_k(\theta_k^*) + \frac{1}{N}\sum_{j=1}^{N}\lambda_j^{*\top}\nabla f(x_j^k;\theta_k^*)W_k^* = 0,\ \forall k\in[M].
\end{aligned} \tag{13}$$
Proof Sketch.

We obtain Theorem 1 by decomposing the change in $\mathcal{L}_{\mathrm{ADMM}}$ at each round $t$ into the alternating updates of its four components: $\{\lambda_j^{(t+1)}\}$, $\{z_j^{(t+1)}\}$, $\{W_k^{(t+1)}\}$, and $\{\theta_k^{(t+1)}\}$. Using our assumptions and the optimality of the minimizers, we show that the combined loss decreases at each round. Next, to derive Eq. 12, we leverage the $L$-Lipschitz smoothness of $\ell$, the condition $\rho\ge L$, the lower bound of $\mathcal{L}_{\texttt{VIM}}$, and the fact that the quadratic penalty term in $\mathcal{L}_{\mathrm{ADMM}}$ is non-negative. Finally, by letting $t\to\infty$ and examining the optimality conditions of the minimizers, we obtain the stationarity conditions in Theorem 1 (B). ∎

Remark. Theorem 1 (A) shows that VIMADMM converges in the sense that its loss decreases monotonically and converges, and (B) establishes that any limit point is a stationary solution to problem III-B. These guarantees rely on several assumptions that are common in ADMM analyses of alternating optimization over multiple sets of variables hong2016convergence . Specifically, we follow Hong et al. hong2016convergence in assuming (strong) convexity, Lipschitz smoothness, and a lower-bounded loss. Because analyzing the local model is challenging given the complexity of DNNs, we introduce an additional assumption that bounds the norms of the gradient and the output of the local models, which may be reasonable when model training is stable. Similarly, we assume a bounded norm for the server model. With these assumptions, we aim to offer a clearer understanding of the convergence behavior of VIMADMM.
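The condition on $\rho$ in Theorem 1 is easy to check numerically once the constants are known or estimated. The helper below is a hypothetical utility, and the constants in its usage are illustrative values, not estimates from the paper; it returns the open interval of admissible $\rho$, or `None` when the interval is empty.

```python
def admissible_rho_interval(L, mu_z, mu_W, mu_theta, L_theta, sigma_W, sigma_theta):
    """Open interval of penalty parameters rho allowed by Theorem 1.

    Lower end: max{L, 2 L^2 / mu_z}.
    Upper end: min{mu_theta / (L_theta^2 sigma_W^2), mu_W / sigma_theta^2}.
    Returns (lower, upper) if lower < upper, otherwise None (no valid rho).
    """
    lower = max(L, 2 * L ** 2 / mu_z)
    upper = min(mu_theta / (L_theta ** 2 * sigma_W ** 2), mu_W / sigma_theta ** 2)
    return (lower, upper) if lower < upper else None


print(admissible_rho_interval(L=1.0, mu_z=4.0, mu_W=10.0, mu_theta=6.0,
                              L_theta=1.0, sigma_W=1.0, sigma_theta=2.0))  # (1.0, 2.5)
```

In this example any $\rho\in(1, 2.5)$ is admissible, whereas with $L=2$ and $\mu_z=1$ the lower end becomes $\max\{2, 8\}=8$, which exceeds the upper end, so no $\rho$ satisfies the condition.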

V Client-level Differentially Private VIM

Although raw features and local models stay local in VFL, sharing model outputs such as local embeddings or predictions during training may still leak sensitive client information mahendran2015understanding ; papernot2018sok . We therefore aim to further protect the privacy of the local feature set $X_k$ of each client $k$ against potential adversaries, including an honest-but-curious server and clients as well as external attackers.

Threat Model

We consider two types of potential adversaries based on their capabilities: (1) Honest-but-curious server and clients: they follow the VFL protocol correctly but may try to infer private client information from the messages exchanged between the clients and the server tran2023privacy . (2) External attackers: they are not directly involved in the VFL process but may observe the server's predictions and the information communicated during training, and try to extract private client information. In terms of attack scenarios, these attackers may conduct membership inference attacks shokri2017membership to determine whether the data of a specific VFL client was included in training. Our goal is to protect the local data of each client so that an attacker cannot make significant inferences about any single client's data. Next, we provide privacy-preserving mechanisms that satisfy client-level differential privacy (DP) guarantees.

Client-level DP

We begin with the definition of $(\epsilon,\delta)$-DP, which bounds how much a randomized algorithm's output distribution can change when its input changes.

Definition 1 ($(\epsilon,\delta)$-DP dwork2014algorithmic ).

A randomized algorithm $\mathcal{M}:\mathcal{X}^n\mapsto\Theta$ is $(\epsilon,\delta)$-DP if for every pair of neighboring datasets $X,X'\in\mathcal{X}^n$ (i.e., differing only by one sample) and every possible (measurable) output set $E\subseteq\Theta$, the following inequality holds: $\Pr[\mathcal{M}(X)\in E]\le e^{\epsilon}\Pr[\mathcal{M}(X')\in E]+\delta$.
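For a mechanism with finitely many outputs, Definition 1 can be checked directly for a given pair of neighbors: the worst-case event is $E^*=\{o:\Pr[\mathcal{M}(X)=o]>e^{\epsilon}\Pr[\mathcal{M}(X')=o]\}$, so the inequality holds for every $E$ exactly when $\sum_o \max(0,\Pr[\mathcal{M}(X)=o]-e^{\epsilon}\Pr[\mathcal{M}(X')=o])\le\delta$, checked in both directions. The sketch below uses one-bit randomized response as a toy mechanism; it illustrates the definition and is not part of VIM.

```python
import math


def hockey_stick(p, q, eps):
    """sup_E (P(E) - e^eps Q(E)) for distributions p, q on a finite output space."""
    return sum(max(0.0, pi - math.exp(eps) * qi) for pi, qi in zip(p, q))


def satisfies_dp(p, q, eps, delta):
    """Definition 1 for one neighboring pair (X, X'), checked in both directions."""
    return hockey_stick(p, q, eps) <= delta and hockey_stick(q, p, eps) <= delta


# Toy mechanism: randomized response on one private bit, reporting the true bit w.p. 3/4.
p = [0.75, 0.25]  # output distribution over {0, 1} when the bit is 0 (dataset X)
q = [0.25, 0.75]  # output distribution when the bit is 1 (neighboring dataset X')

print(satisfies_dp(p, q, eps=1.1, delta=0.0))  # True, since e^1.1 > 3
print(satisfies_dp(p, q, eps=1.0, delta=0.0))  # False, since e^1.0 < 3
print(satisfies_dp(p, q, eps=1.0, delta=0.1))  # True: excess 0.75 - e/4 ≈ 0.070 <= 0.1
```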

Next, we introduce client-level $(\epsilon,\delta)$-DP mcmahan2018learning , which guarantees that the algorithm's output distribution does not change much when the data of one client is changed.

Definition 2 (Client-level $(\epsilon,\delta)$-DP mcmahan2018learning ).

Let $X$ and $X'$ be adjacent datasets if they differ in all samples associated with a single client. [Footnote 2: We consider the “zero-out” notion of neighboring datasets, following ponomareva2023dp : datasets are adjacent if any one client's local data is replaced with special “zero” data (exactly zero for numeric data).] The mechanism $\mathcal{M}$ satisfies client-level $(\epsilon,\delta)$-DP if it meets Definition 1 with $X$ and $X'$ as adjacent datasets.

Remark. (1) Client-level DP protects the privacy of all local samples of each client mcmahan2018learning . In client-level DP, neighboring datasets are client-adjacent datasets, denoted by $X=\{X_1,\ldots,X_k,\ldots,X_M\}$ and $X'=\{X_1,\ldots,X_k',\ldots,X_M\}$ for some client $k$: the algorithm's output should not change significantly if a single client's entire dataset is changed. (2) User-level DP is another prevalent privacy notion in the FL literature, and its definition depends on how “user” is interpreted. If a “user” denotes a client/data owner in FL, then user-level DP coincides with client-level DP geyer2017differentially ; mcmahan2018learning ; agarwal2018cpsgd .
Alternatively, a “user” in VFL may refer to an individual whose partial features are contributed by different clients: the $M$ VFL clients hold disjoint partial features $\{x_j^1,x_j^2,\ldots,x_j^M\}$ about the same user $j$ cohendifferentially ; ranbaduge2022differentially . For example, different healthcare providers (VFL clients such as dental clinics and pharmacies) can hold different features about the same patient (user). In such cases, neighboring datasets are defined as those differing by all local samples associated with one user across all VFL client datasets. In this work, we focus on client-level DP due to its widespread adoption in FL mcmahan2018learning .
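The client-adjacent neighbors of Definition 2, under the “zero-out” notion of adjacency, can be sketched as follows for a vertically partitioned feature matrix. The helper names `split_vertically` and `zero_out_neighbor` are hypothetical. A client-level neighbor replaces one client's entire column block $X_k$ with zeros and leaves the other clients untouched; a user-level neighbor in the second sense above would instead change one row $j$ across all blocks.

```python
import numpy as np


def split_vertically(X, feat_dims):
    """Partition the columns of a full feature matrix into client blocks X_1, ..., X_M."""
    bounds = np.cumsum([0] + list(feat_dims))
    return [X[:, bounds[i]:bounds[i + 1]] for i in range(len(feat_dims))]


def zero_out_neighbor(blocks, k):
    """Client-adjacent dataset X': client k's whole local dataset X_k is replaced
    by zeros ("zero-out" adjacency); all other clients' data are unchanged."""
    neighbor = [b.copy() for b in blocks]
    neighbor[k] = np.zeros_like(blocks[k])
    return neighbor


X = np.arange(24, dtype=float).reshape(4, 6)  # N = 4 samples, 6 features in total
blocks = split_vertically(X, [2, 1, 3])       # M = 3 clients holding 2, 1, 3 features
X_prime = zero_out_neighbor(blocks, k=1)      # neighbor differing only in client 1
print([b.shape for b in blocks])              # [(4, 2), (4, 1), (4, 3)]
```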

Since the only information shared by the clients is their local outputs, denoted as $\mathcal{A}_k$ for the $k$-th client, we apply the following DP mechanisms to perturb the local outputs of each client $k$ at every round $t$: (1) clip the whole local output matrix (either embeddings $\mathcal{A}_k^{(t)}=\{{h_j^k}^{(t)}\}_{j\in B(t)}$ or logits $\mathcal{A}_k^{(t)}=\{{o_j^k}^{(t)}\}_{j\in B(t)}$) with threshold $C$ so that the $\ell_2$-sensitivity of each client is upper bounded by $C$.
That is, $\mathtt{Clip}(\mathcal{A}_k,C)=\mathcal{A}_k\cdot\min\left(1,\frac{C}{\|\mathcal{A}_k\|_F}\right)$, where $\|\cdot\|_F$ is the Frobenius norm. [Footnote 3: The Frobenius norm of an $m\times n$ matrix $A$ is $\|A\|_F=\sqrt{\sum_{i=1}^{m}\sum_{j=1}^{n}|a_{ij}|^2}$.] (2) Then we add scalar Gaussian noise independently to each entry of the matrix.
The noise is sampled from $\mathcal{N}(0,\sigma^2C^2)$, so its scale is proportional to $C$, and it randomizes the local output matrix of each client: $\mathcal{A}_k\leftarrow\mathtt{Clip}(\mathcal{A}_k,C)+\mathcal{N}(0,\sigma^2C^2)$. With this modification to Algorithms 1 and 3, we now state their privacy guarantee in Theorem 2.
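The two-step perturbation above can be sketched in NumPy as follows, assuming each client's local output for a round is a dense matrix with one row per sample in $B(t)$. The function names are hypothetical, and this illustrates the described mechanism rather than reproducing the authors' released code.

```python
import numpy as np


def clip_frobenius(A, C):
    """Clip(A, C) = A * min(1, C / ||A||_F), so the clipped matrix has norm <= C."""
    norm = np.linalg.norm(A)  # for a 2-D array, the default norm is the Frobenius norm
    return A if norm == 0 else A * min(1.0, C / norm)


def perturb_local_outputs(A, C, sigma, rng):
    """A_k <- Clip(A_k, C) + N(0, sigma^2 C^2), with noise drawn i.i.d. per entry."""
    return clip_frobenius(A, C) + rng.normal(loc=0.0, scale=sigma * C, size=A.shape)


rng = np.random.default_rng(0)
A_k = 5.0 * rng.normal(size=(64, 16))  # e.g., embeddings h_j^k for a batch of 64 samples
noisy = perturb_local_outputs(A_k, C=1.0, sigma=1.0, rng=rng)
```

Because the threshold $C$ also sets the noise standard deviation $\sigma C$, a larger $C$ preserves more of the signal but adds proportionally more noise, while the privacy cost is governed by $\sigma$ and the number of rounds.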

Theorem 2.

Given a total of $M$ clients, $T$ communication rounds, clipping threshold $C$, and noise level $\sigma$, the DP versions of Algorithms 1 and 3 satisfy client-level $\left(\frac{T\alpha}{2\sigma^2} + \log\frac{\alpha-1}{\alpha} - \frac{\log\delta + \log\alpha}{\alpha-1}, \delta\right)$-DP for any $\alpha > 1$ and $0 < \delta < 1$.

Proof Sketch. We derive the privacy guarantee using Rényi Differential Privacy (RDP) mironov2017renyi as a bridge. We first leverage the RDP guarantee of the Gaussian mechanism mironov2017renyi to analyze the privacy cost of one communication round under local output perturbation. Then we use the RDP composition property mironov2017renyi to accumulate the privacy costs over $T$ communication rounds. Finally, we convert the client-level RDP guarantee into a client-level DP guarantee balle2020hypothesis. Detailed proofs are deferred to Section -C.
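Because the bound in Theorem 2 holds for every $\alpha > 1$, it can be evaluated on a grid of orders and the smallest value kept. The sketch below does exactly that: per round the Gaussian mechanism contributes $\alpha / (2\sigma^2)$ in RDP, composition multiplies by $T$, and the final line applies the RDP-to-DP conversion term by term. The function name and the default $\alpha$ grid are illustrative assumptions.

```python
import math
from typing import Iterable, Optional, Tuple


def client_level_epsilon(T: int, sigma: float, delta: float,
                         alphas: Optional[Iterable[float]] = None) -> Tuple[float, float]:
    """Evaluate the Theorem 2 bound on a grid of RDP orders and return (best epsilon, its alpha)."""
    if alphas is None:
        # Hypothetical default grid of orders alpha > 1.
        alphas = [1.0 + i / 10.0 for i in range(1, 100)] + [float(a) for a in range(11, 257)]
    best_eps, best_alpha = math.inf, float("nan")
    for a in alphas:
        rdp = T * a / (2.0 * sigma ** 2)  # Gaussian-mechanism RDP at order a, composed over T rounds
        eps = rdp + math.log((a - 1.0) / a) - (math.log(delta) + math.log(a)) / (a - 1.0)
        if eps < best_eps:
            best_eps, best_alpha = eps, a
    return best_eps, best_alpha


# Usage (illustrative values, not the paper's settings):
eps, alpha = client_level_epsilon(T=100, sigma=5.0, delta=1e-5)
print(f"epsilon = {eps:.3f} at alpha = {alpha}")
```

Since $\epsilon$ grows linearly in $T$ through the RDP term, fewer communication rounds directly translate into a smaller privacy cost.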

Remark. Since the DP mechanisms (i.e., clipping and noise addition) are applied locally to each client's outputs (i.e., the embedding or logits matrix), these local outputs satisfy client-level local DP, protecting against privacy attacks from other clients, the server, or external attackers. That is, by observing the local output matrix of one client, other parties cannot determine whether that client's actual training data was used. The concatenated output matrix from all clients satisfies the same client-level DP guarantee by the parallel composition property of DP mcsherry2009privacy, owing to the non-overlapping nature of local data among clients.

Note that the aforementioned DP mechanisms do not protect the privacy of the labels held by the server. Therefore, we separately use the state-of-the-art label DP mechanism malek2021antipodes to protect the server's label privacy via label perturbation, and empirically evaluate our method under label DP in Section VI-B2.

VI Experiments

We conduct extensive experiments on four VFL datasets. We show that our proposed framework VIM achieves significantly faster convergence and higher accuracy than SOTA (Section VI-A), maintains higher utility under client-level DP and label DP (Section VI-B), and enables client-level explainability (Section VI-C).

VI-1 Data and Models

We consider classification tasks on four datasets: MNIST lecun-mnisthandwrittendigit-2010 , CIFAR cifar , multi-modality dataset NUS-WIDE with image and textual features chua2009nus , and multi-view dataset ModelNet40 su2018deeper .

  • MNIST lecun-mnisthandwrittendigit-2010 contains images with handwritten digits. We create the VFL scenario by splitting the input features evenly by rows for 14 clients. We use a fully connected model of two linear layers with ReLU activations as the local model.

  • CIFAR cifar contains colour images. We split each image into patches for 9 clients. We use a standard CNN architecture from Opacus, a PyTorch library for differential privacy (footnote 4: https://github.com/pytorch/opacus), as the local model.

  • NUS-WIDE chua2009nus is a multi-modality dataset with 634 low-level image features and 1000 textual tag features. We distribute image features to 2 clients (300 dim and 334 dim), and text features to 2 clients (500 dim and 500 dim). We use a fully connected model of two linear layers with ReLU activations as the local model.

  • ModelNet40 su2018deeper is a multi-view image dataset, containing the shaded images from 12 views for the same objects. We use 4 views and distribute them to 4 clients respectively. We use ResNet-18 he2016deep as the local model.
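The vertical partitioning described above can be sketched for the two cases whose split is fully specified: MNIST rows divided evenly among 14 clients (2 rows each), and NUS-WIDE features divided into blocks of 300, 334, 500, and 500 dimensions. The helper names are hypothetical, and we assume that the 634 image features precede the 1000 text features in each NUS-WIDE vector. The CIFAR patch layout and the ModelNet40 view selection are not sketched, since their exact details are not given here.

```python
from typing import List, Sequence


def split_rows(image: List[List[float]], num_clients: int) -> List[List[List[float]]]:
    """Split an H x W image into num_clients horizontal strips of equal height."""
    height = len(image)
    if height % num_clients != 0:
        raise ValueError("image height must be divisible by num_clients")
    step = height // num_clients
    return [image[i * step:(i + 1) * step] for i in range(num_clients)]


def split_columns(features: Sequence[float], dims: Sequence[int]) -> List[List[float]]:
    """Split one flat feature vector into consecutive blocks of the given sizes."""
    if sum(dims) != len(features):
        raise ValueError("dims must sum to the number of features")
    blocks: List[List[float]] = []
    start = 0
    for d in dims:
        blocks.append(list(features[start:start + d]))
        start += d
    return blocks


# MNIST: 28 rows over 14 clients -> 2 rows (2 x 28 pixels) per client.
mnist_image = [[0.0] * 28 for _ in range(28)]
print([len(strip) for strip in split_rows(mnist_image, 14)])

# NUS-WIDE: 634 image + 1000 text features -> blocks of 300, 334, 500, 500.
nus_sample = [0.0] * 1634
print([len(b) for b in split_columns(nus_sample, [300, 334, 500, 500])])
```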

We split each dataset into the train, validation, and test sets. See Table II for more details about the number of samples and the number of classes for each dataset.

TABLE II: Dataset description.
Dataset | # features | # classes ($d_c$) | # clients ($M$) | # train | # validation | # test
MNIST | 28 × 28 | 10 | 14 | 54000 | 6000 | 10000
CIFAR | 32 × 32 × 3 | 10 | 9 | 45000 | 5000 | 10000
NUS-WIDE | 1634 | 5 | 4 | 54000 | 6000 | 10000
ModelNet40 | 224 × 224 × 3 × 12 | 40 | 4 | 8877 | 966 | 2468
[Figure 1 panels: MNIST, CIFAR, NUS-WIDE, ModelNet40; row 1: w/ model splitting; row 2: w/o model splitting]
Figure 1: Test accuracy of VFL methods under the w/ model splitting (first row) and w/o model splitting (second row) settings on four datasets. Our methods (VIMADMM and VIMADMM-J) outperform the baselines due to the multiple local updates enabled by ADMM ($\tau > 1$). Compared with FedBCD under different numbers of local steps $\tau$, VIMADMM also achieves faster convergence and higher accuracy, which shows that the strategic use of ADMM-related variables in VIMADMM is more effective for local updates than the stale partial gradients in FedBCD.

To prevent over-fitting (due to the potential over-parameterization from the large number of model parameters of all clients and the server, viewed together as one global model), we adopt standard stopping criteria, i.e., we stop training when the model converges or when the validation accuracy starts to drop by more than 2%. More details about the setups and hyperparameters are in Section -D.
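The stopping rule can be sketched as below. The text does not pin down how the 2% drop or convergence is measured, so this version assumes the drop is in percentage points relative to the best validation accuracy seen so far, and treats "converged" as no improvement within a hypothetical `patience` window of rounds.

```python
from typing import Sequence


def should_stop(val_acc_history: Sequence[float], max_drop: float = 2.0,
                patience: int = 10, min_delta: float = 0.0) -> bool:
    """Return True once training should stop (accuracies given in percent).

    Assumed rule: stop if the latest accuracy is more than `max_drop` points below
    the best so far, or if the last `patience` rounds did not improve on the
    earlier best by more than `min_delta`.
    """
    if not val_acc_history:
        return False
    best = max(val_acc_history)
    if best - val_acc_history[-1] > max_drop:
        return True  # validation accuracy dropped too far: likely over-fitting
    if len(val_acc_history) > patience:
        best_before = max(val_acc_history[:-patience])
        recent_best = max(val_acc_history[-patience:])
        if recent_best - best_before <= min_delta:
            return True  # no recent improvement: treat as converged
    return False


print(should_stop([88.0, 90.5, 91.0, 88.7]))  # True (2.3-point drop from the best)
```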

VI-2 Baselines

We (1) compare VIMADMM with VAFL chen2020vafl, Split Learning vepakomma2018split, and FedBCD liu2022fedbcd under the w/ model splitting setting; and (2) compare VIMADMM-J with FDML hu2019fdml under the w/o model splitting setting. In particular, in VAFL, the server aggregates local embeddings via a linear combination with learnable aggregation weights and then uses the aggregated embeddings as input to the server model. Both Split Learning and FedBCD use the concatenated local embeddings as the server model input. Notably, in VAFL and Split Learning, the clients perform only one step of local update based on the partial gradients from the server, whereas FedBCD reuses the same (stale) partial gradients for $\tau$ local updates. In FDML, the server averages the local logits and sends the aggregated logits back to the clients at each communication round. The clients, who own copies of the labels, can then calculate their local gradients and execute one step of local update. Our empirical findings suggest that our ADMM-based methods outperform these methods owing to the multiple local updates that use ADMM-related variables.

For fair comparisons, we use the same local models for all methods. Under the w/ model splitting setting, owing to the strong feature extraction power of the local DNN models, we use a linear model as the server model by default. Additionally, we evaluate all methods with a non-linear server model, as detailed in Section VI-A2.

We further compare the utility of various VFL methods under differential privacy. Existing VFL frameworks (see Table I) focus on sample-level DP chen2020vafl ; hu2019fdml ; hu2019learning ; tran2023privacy ; cohendifferentially ; ranbaduge2022differentially, where neighboring datasets are defined as those differing by a single sample in a client's local dataset. In particular, VAFL chen2020vafl adds random noise to the output of each local embedding convolutional layer; VFL-PBM tran2023privacy quantizes local embeddings into differentially private integer vectors; and FDML hu2019fdml and Linear-ADMM hu2019learning add noise to local outputs. However, these methods lack exact privacy budget accounting and report only empirical utility under different noise levels. Additionally, ranbaduge2022differentially perturbs local model weights to satisfy DP, which requires bounding the sensitivity of each layer's weights in the local model. To enable a fair comparison of VFL methods under DP guarantees, we evaluate all methods with our proposed DP mechanisms, which perturb local outputs to satisfy a client-level $(\epsilon, \delta)$-DP guarantee. Notably, by definition, a mechanism satisfying client-level $(\epsilon, \delta)$-DP also satisfies sample-level $(\epsilon, \delta)$-DP. Since client-level DP offers stronger privacy protection, it has been widely adopted in FL geyer2017differentially ; mcmahan2018learning ; agarwal2018cpsgd ; bhowmick2018protection ; xie2023unraveling. Furthermore, we evaluate all VFL methods with the label DP mechanism ALIBI malek2021antipodes to separately satisfy an $\epsilon$-label DP guarantee. We report results averaged over three runs with different random seeds.

VI-A Evaluation on Vanilla VFL

In this section, we evaluate the ADMM-based methods and baselines in terms of convergence rate, model performance, and communication cost. We also show the generality of VIMADMM with non-linear server heads and study the performance of VIMADMM under different values of the ADMM penalty factor $\rho$.

VI-A1 Convergence rates and model performance

Figure 1 shows the convergence rates of all methods, where both VIM algorithms consistently outperform the baselines. Concretely, (1) our ADMM-based methods converge faster and achieve higher accuracy than the gradient-based baselines, especially on CIFAR. This is because the multiple local updates enabled by ADMM yield higher-quality local models at each round, thereby speeding up convergence. (2) VIMADMM outperforms FedBCD under various numbers of local steps. This superiority can be attributed to VIMADMM's use of ADMM-related variables for its $\tau$ local updates, which is more effective than the stale partial gradients used by FedBCD. (3) When the number of local steps $\tau$ is larger, the ADMM-based methods converge faster, as the local models are trained better with more local updates per round.

[Figure 2 panels: MNIST, CIFAR, NUS-WIDE, ModelNet40]
Figure 2: Performance comparison when the server has a non-linear MLP model. The ADMM-based method still outperforms the baselines under general architectures with a non-linear server model.

Moreover, we empirically compare VIMADMM with Linear-ADMM hu2019learning. Although both VFL methods are rooted in ADMM, we propose a new VFL optimization objective and an algorithm with multiple heads that enable the ADMM decomposition for practical DNN training under model splitting. The results in Table III show that VIMADMM consistently outperforms Linear-ADMM on MNIST and NUS-WIDE. We expect the limitations of the logistic regression model in Linear-ADMM, compared with the DNNs enabled by VIMADMM, to be even more evident on more complex datasets such as CIFAR and ModelNet40.

TABLE III: Performance comparison between VIMADMM and Linear-ADMM hu2019learning. VIMADMM achieves higher accuracy.
Method | MNIST | NUS-WIDE
VIMADMM | 97.13 | 88.51
Linear-ADMM | 91.65 | 84.63
TABLE IV: Comparison of communication costs (in MB). VIMADMM requires lower communication costs per round than the baselines under the w/ model splitting setting. ADMM-based methods require lower communication costs to reach the same target accuracy.
VFL setup | Method | Per round: each client to server | Per round: server to each client | Per round: total | To reach target: MNIST ($\geq 96.0\%$) | CIFAR ($\geq 65.0\%$) | NUS-WIDE ($\geq 85.0\%$) | ModelNet40 ($\geq 89.0\%$)
w/ model splitting | VAFL | 0.23 | 0.23 | 0.46 | 4520.12 | 5381.40 | 397.37 | 134.96
w/ model splitting | Split Learning | 0.23 | 0.23 | 0.46 | 1738.51 | 4082.44 | 198.69 | 84.35
w/ model splitting | FedBCD | 0.23 | 0.23 | 0.46 | 4867.82 | 2597.92 | 397.37 | 118.09
w/ model splitting | VIMADMM | 0.23 | 0.08 | 0.31 | 233.36 | 124.54 | 66.67 | 11.32
w/o model splitting | FDML | 0.039 | 0.039 | 0.078 | 405.13 | 617.76 | 33.07 | 89.13
w/o model splitting | VIMADMM-J | 0.039 | 0.078 | 0.117 | 86.81 | 46.33 | 24.8 | 8.42

VI-A2 Non-linear server heads

To demonstrate the generality and applicability of VIMADMM, we evaluate it with a non-linear server model. Specifically, the head consists of multiple fully connected layers accompanied by Dropout layers with a 0.25 dropout rate and ReLU activation functions. For a fair comparison, we also use the MLP server model architecture for the other baseline methods. We use a 3-layer MLP for NUS-WIDE and a 2-layer MLP for the other datasets. The results in Figure 2 show that our method still outperforms the baselines under general architectures with a non-linear server model.

VI-A3 Communication costs

In Table IV, we report the size of the parameters communicated between the clients and the server to evaluate communication cost. We use batch size 1024 and local embedding size 60 for all datasets. The overall embedding size scales with the number of clients. From Table IV, we observe that (1) in each round, all methods under the w/ model splitting setting send the same number of parameters from each client to the server (i.e., 0.23 MB for a batch of embeddings), while VIMADMM sends fewer parameters from the server to each client (i.e., 0.08 MB in total for a batch of dual variables and residual variables, together with one corresponding linear head) than VAFL, Split Learning, and FedBCD (i.e., 0.23 MB for a batch of gradients w.r.t. embeddings). (2) With fewer parameters communicated per round and faster convergence (i.e., fewer communication rounds to reach a target accuracy), VIMADMM requires significantly lower communication costs than the baselines. For example, to reach 65.0% accuracy on CIFAR, VAFL needs 5381.40 MB while VIMADMM requires only 124.54 MB, about 43× lower. Here we use $\tau = 20, 30, 20, 5$ for MNIST, CIFAR, NUS-WIDE, and ModelNet40, respectively. (3) Under the w/o model splitting setting, VIMADMM-J incurs lower communication costs than FDML to reach the same accuracy, due to faster convergence with multiple local updates. (4) The communication cost under the w/o model splitting setting is generally lower than under the w/ model splitting setting because the local logits have a lower dimension than the local embeddings, i.e., $d_c < d_f$.
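The per-round figures can be checked with simple arithmetic. Assuming float32 values (4 bytes each) and MB meaning $2^{20}$ bytes (both assumptions, not stated in the text), one client's batch of 1024 embeddings of size 60, and a batch of 1024 ten-class logits (as for MNIST or CIFAR), come out at the 0.23 MB and 0.039 MB listed in Table IV; the last line reproduces the roughly 43× saving on CIFAR. The helper name is hypothetical.

```python
def payload_mb(batch_size: int, dim: int, bytes_per_value: int = 4) -> float:
    """Size in MB (2**20 bytes) of a batch_size x dim matrix, assuming float32 entries."""
    return batch_size * dim * bytes_per_value / 2 ** 20


embeddings_mb = payload_mb(1024, 60)  # one client's batch of local embeddings
logits_mb = payload_mb(1024, 10)      # one client's batch of 10-class local logits
print(round(embeddings_mb, 2), round(logits_mb, 3))  # 0.23 0.039
print(round(5381.40 / 124.54, 1))                     # 43.2 (VAFL vs. VIMADMM on CIFAR)
```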

VI-A4 Effect of penalty factor $\rho$

ADMM-based methods introduce one additional hyperparameter: the penalty factor $\rho$. Here we study the test accuracy of VIMADMM under different values of $\rho$. The results in Figure 5 of Section -D show that VIMADMM is not sensitive to $\rho$ on the four datasets; we suggest that practitioners choose $\rho$ from the range 0.5 to 2, within which the test accuracy does not change significantly.

VI-A5 Evaluation on long-tail datasets

Long-tail datasets are characterized by significant class imbalance, where minority classes have far fewer samples than majority ones. This horizontal (sample-wise) imbalance is distinct from the challenge addressed by VFL, where the same sample (whether it belongs to a majority or minority class) is vertically split across multiple clients. We compare the VIMADMM model, which consists of $M$ local models followed by a server model, with a reference model in a centralized setting that has the same size as one local model coupled with a server model. The results in Table V demonstrate that VIMADMM remains effective on challenging long-tail training datasets, yielding results comparable to those of the centralized reference model. We defer further discussion and detailed experimental setups to Section -D.
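For readers who want to build a long-tail variant of a dataset, one common recipe keeps an exponentially decaying number of samples per class. It is used here only as an illustrative assumption (the paper's exact setup is in Section -D), and the helper names and the imbalance ratio are hypothetical. Because each VFL sample is split vertically, the same retained index set would be applied to every client's feature slice.

```python
import random
from typing import Dict, List, Sequence


def long_tail_counts(n_max: int, num_classes: int, imbalance_ratio: float) -> List[int]:
    """Exponential profile: n_c = n_max * imbalance_ratio ** (-c / (num_classes - 1))."""
    return [int(n_max * imbalance_ratio ** (-c / (num_classes - 1))) for c in range(num_classes)]


def subsample_long_tail(labels: Sequence[int], counts: Sequence[int], seed: int = 0) -> List[int]:
    """Sorted sample indices keeping at most counts[c] randomly chosen samples of each class c."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    kept: List[int] = []
    for c, idxs in by_class.items():
        rng.shuffle(idxs)
        kept.extend(idxs[:counts[c]])
    return sorted(kept)


# Head class keeps 5000 samples, tail class keeps 5000 / 100 = 50.
print(long_tail_counts(5000, 10, 100.0))
```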

TABLE V: Accuracy and fairness (standard deviation of class-wise accuracy, in parentheses) on balanced and long-tail data.
Method | balanced MNIST | long-tail MNIST | balanced CIFAR | long-tail CIFAR
VIMADMM | 97.13 (0.76) | 95.69 (1.58) | 75.25 (9.17) | 62.81 (15.27)
Reference model in centralized setting | 98.19 (0.45) | 95.02 (2.70) | 77.61 (9.20) | 66.11 (15.29)
TABLE VI: Utility of VFL methods under user-level DP. ADMM-based methods maintain higher utility. Each cell lists accuracy (%) at $\epsilon = \infty$ / $\epsilon = 8$ / $\epsilon = 1$.
VFL setup | Method | MNIST | CIFAR | NUS-WIDE | ModelNet40
w/ model splitting | VAFL | 96.86 / 22.29 / 11.31 | 66.39 / 16.82 / 14.91 | 87.81 / 38.27 / 38.19 | 90.07 / 4.66 / 4.29
w/ model splitting | Split Learning | 96.92 / 56.53 / 16.77 | 68.32 / 21.09 / 15.8 | 88.25 / 38.29 / 33.05 | 89.98 / 18.19 / 6.28
w/ model splitting | FedBCD | 96.59 / 66.07 / 65.05 | 71.2 / 70.67 / 55.42 | 87.59 / 42.95 / 41.02 | 89.87 / 88.3 / 87.02
w/ model splitting | VIMADMM | 97.13 / 92.35 / 92.09 | 75.25 / 73.83 / 61.65 | 88.51 / 83.77 / 83.51 | 91.32 / 91.29 / 91.18
w/o model splitting | FDML | 97.06 / 92.02 / 85.01 | 66.8 / 41.07 / 35.25 | 87.67 / 79.58 / 67.38 | 89.86 / 54.7 / 43.4
w/o model splitting | VIMADMM-J | 97.37 / 92.71 / 92.33 | 74.48 / 72.36 / 58.64 | 88.46 / 84.94 / 84.88 | 91.13 / 90.13 / 89.37
TABLE VII: Utility of VFL methods under label-level DP. ADMM-based methods maintain higher utility. Each cell lists accuracy (%) at $\epsilon = \infty$ / $\epsilon = 2.8$ / $\epsilon = 1.4$.
VFL setup | Method | MNIST | CIFAR | NUS-WIDE | ModelNet40
w/ model splitting | VAFL | 96.86 / 94.27 / 51.68 | 66.39 / 54.6 / 38.44 | 87.81 / 85.77 / 60.41 | 90.07 / 45.26 / 2.59
w/ model splitting | Split Learning | 96.92 / 94.93 / 91.75 | 68.32 / 57.12 / 49.71 | 88.25 / 85.86 / 82.3 | 89.98 / 65.68 / 33.79
w/ model splitting | FedBCD | 96.59 / 94.47 / 87.95 | 71.2 / 61.05 / 46.14 | 87.59 / 85.62 / 64.01 | 89.87 / 65.92 / 43.15
w/ model splitting | VIMADMM | 97.13 / 95.48 / 92.8 | 75.25 / 65.07 / 52.97 | 88.51 / 86.62 / 82.43 | 91.32 / 76.70 / 46.39
w/o model splitting | FDML | 97.06 / 94.97 / 91.87 | 66.8 / 58.78 / 49.83 | 87.67 / 85.79 / 82.37 | 89.86 / 64.99 / 29.74
w/o model splitting | VIMADMM-J | 97.37 / 95.80 / 93.25 | 74.48 / 64.04 / 53.49 | 88.46 / 86.74 / 82.71 | 91.13 / 77.15 / 45.22

VI-A6 Fairness implication

A common fairness definition is to enforce accuracy parity between protected groups zafar2017fairness. Here we study the fairness implications of VIMADMM for accuracy parity at both the class and client levels. (1) For class-level accuracy parity, a fair model should achieve equal accuracy on each class tarzanagh2023fairness ; xu2021robust, meaning that the model's accuracy is statistically independent of the ground-truth label. We use the standard deviation of class-wise accuracy xu2021robust to evaluate fairness, where a lower value indicates higher fairness. The results in Table V show that VIMADMM achieves fairness comparable to, and in some cases better than, the reference model in a centralized setting on MNIST and CIFAR with both balanced and long-tail distributions. (2) Client-level accuracy parity, which measures the uniformity of performance across clients, is a prevalent fairness criterion in FL li2021ditto ; Li2020Fair. Notably, in VFL, all clients share the same prediction for each sample, to which each of them contributes partial features. Consequently, all clients inherently achieve the same accuracy, so client-level accuracy parity is satisfied by the nature of VFL.
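The class-level fairness metric is the spread of per-class accuracies. A minimal sketch follows with hypothetical helper names; it assumes the population standard deviation (`statistics.pstdev`), since the text does not say whether the population or sample version is used.

```python
import statistics
from typing import Dict, Sequence


def class_wise_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[int, float]:
    """Accuracy (%) computed separately for each ground-truth class."""
    correct: Dict[int, int] = {}
    total: Dict[int, int] = {}
    for t, p in zip(y_true, y_pred):
        total[t] = total.get(t, 0) + 1
        correct[t] = correct.get(t, 0) + int(t == p)
    return {c: 100.0 * correct[c] / total[c] for c in total}


def accuracy_parity_std(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Standard deviation of class-wise accuracy; lower means fairer."""
    return statistics.pstdev(list(class_wise_accuracy(y_true, y_pred).values()))


# Class 0 is always right (100%), class 1 only half the time (50%).
print(accuracy_parity_std([0, 0, 1, 1], [0, 0, 1, 0]))  # 25.0
```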

VI-B Evaluation on Differentially Private VFL

We evaluate the utility of ADMM-based methods and baselines under client-level DP and label DP, which protect the privacy of local features and server labels, respectively.

VI-B1 Utility under client-level DP (privacy of client data)

We report the utility under $\epsilon = 8$ and $\epsilon = 1$ client-level DP. To ensure a fair comparison, we perform a grid search over the combination of hyperparameters, including the noise scale $\sigma$, clipping threshold $C$, and learning rate $\eta$, for all methods (details are deferred to Section -D). Table VI shows that (1) the accuracy of the ADMM-based methods under DP remains close to the non-private accuracy ($\epsilon = \infty$) on MNIST, NUS-WIDE, and ModelNet40. Nevertheless, there is a noticeable decrease of 13.6% for VIMADMM on CIFAR when $\epsilon = 1$, which underscores the inherent privacy-utility trade-off for algorithms with formal DP guarantees abadi2016deep. (2) Our ADMM-based methods reach significantly higher utility than the gradient-based methods, especially under small $\epsilon$. We attribute this to the fact that the ADMM-based methods converge in fewer communication rounds than the gradient-based methods, which is also evident in the non-DP setting shown in Figure 1. Fast convergence is critical for DP, since the privacy budget $\epsilon$ is consumed quickly as the number of communication rounds increases. The fast convergence and high utility of VIMADMM under DP compared with the baselines can be interpreted through two lenses. First, multiple local updates lead to a more effectively trained local model at each round. As a consequence, both FedBCD and VIMADMM achieve a markedly better DP-utility trade-off than VAFL and Split Learning, as shown in Table VI. Furthermore, we explicitly investigate the influence of $\tau$ on the utility of VIMADMM under $\epsilon = 1$ in Table VIII. The results show that choosing $\tau > 1$ yields substantially higher accuracy than $\tau = 1$ (e.g., a 14.68% improvement on CIFAR).
Second, the ADMM update mechanism allows clients to independently update their local models w.r.t. the ADMM sub-objective (Eq. III-B). Note that during this local forward/backward computation based on Eq. III-B, clients do not add noise locally, since the local models always remain in their possession and are never shared. Clients only need to perturb the local embeddings that are sent to the server (i.e., output perturbation). Consequently, even though the server uses these perturbed embeddings to derive the ADMM-related variables, the clients re-calculate clean embeddings in the forward pass of Eq. III-B, based on the received ADMM-related variables, for local model updates. This update mechanism may facilitate convergence under DP. In contrast, gradient-based methods rely solely on partial gradients derived from the perturbed embeddings for local updates, leading to compromised utility.
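To see why the number of rounds matters, one can ask how much noise Theorem 2 requires for a fixed budget as $T$ grows. The sketch below bisects over $\sigma$ for a target $\epsilon$; the helper names, the search bounds, and the $\alpha$ grid are hypothetical choices, and the paper itself tunes $\sigma$ by grid search rather than by this procedure.

```python
import math
from typing import Sequence

# Hypothetical grid of RDP orders alpha > 1; the Theorem 2 bound holds for each of them.
ALPHAS = [1.0 + i / 10.0 for i in range(1, 100)] + [float(a) for a in range(11, 257)]


def theorem2_epsilon(T: int, sigma: float, delta: float,
                     alphas: Sequence[float] = ALPHAS) -> float:
    """Smallest Theorem 2 epsilon over the alpha grid."""
    return min(T * a / (2.0 * sigma ** 2) + math.log((a - 1.0) / a)
               - (math.log(delta) + math.log(a)) / (a - 1.0) for a in alphas)


def calibrate_sigma(target_eps: float, T: int, delta: float,
                    lo: float = 1e-2, hi: float = 1e3, iters: int = 100) -> float:
    """Bisection for a sigma near the smallest one whose epsilon is <= target_eps.

    Valid because epsilon decreases monotonically as sigma grows.
    """
    if theorem2_epsilon(T, hi, delta) > target_eps:
        raise ValueError("target_eps is not reachable with sigma <= hi")
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if theorem2_epsilon(T, mid, delta) <= target_eps:
            hi = mid  # mid is feasible: search lower
        else:
            lo = mid  # mid is infeasible: search higher
    return hi


# More rounds under the same budget require more noise per round.
for rounds in (50, 100, 200, 400):
    print(rounds, round(calibrate_sigma(1.0, rounds, 1e-5), 2))
```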

[Figure 3 panels: MNIST, CIFAR, NUS-WIDE, ModelNet40; rows: input features, clean, noisy test client, denoising]
Figure 3: Client-level explainability of VIM. Row 1 visualizes the input features. Row 2 shows the weight norms of the linear heads. Row 3 shows the test accuracy when each client's test input features are perturbed (the red line denotes the clean test accuracy). Row 4 shows the weight norms of the linear heads when only one client is noisy.
TABLE VIII: A larger number of local steps $\tau$ leads to better utility of VIMADMM under client-level DP with $\epsilon = 1$ (accuracy, %).
MNIST $\tau$ | MNIST ($\epsilon = 1$) | CIFAR $\tau$ | CIFAR ($\epsilon = 1$) | NUS-WIDE $\tau$ | NUS-WIDE ($\epsilon = 1$)
1 | 90.63 | 1 | 48.19 | 1 | 79.38
5 | 90.84 | 10 | 61.65 | 3 | 82.58
20 | 92.09 | 30 | 62.87 | 10 | 83.51

Methods w/o model splitting (FDML, VIMADMM-J) generally perform better under DP than methods w/ model splitting. This is mainly because the logits have a smaller dimension than the embeddings, so the total amount of noise added to the logit outputs is smaller than that added to the embedding outputs; thus, VFL methods w/o model splitting retain higher utility under DP.
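A back-of-the-envelope calculation illustrates this dimension argument. Clipping bounds the Frobenius norm of each output matrix by $C$ in both settings, while i.i.d. Gaussian noise on a $B \times d$ matrix has an expected Frobenius norm of roughly $\sigma C \sqrt{B d}$. Under the illustrative assumptions of equal $\sigma$ and $C$, batch size 1024, $d_f = 60$, and $d_c = 10$, the noise on embeddings is about $\sqrt{6} \approx 2.45$ times larger relative to the same signal bound. In the experiments $\sigma$ and $C$ are tuned per method, so this is only a qualitative illustration; the helper name is hypothetical.

```python
import math


def expected_noise_norm(batch_size: int, dim: int, sigma: float, clip_c: float) -> float:
    """Approximate E||N||_F for a batch_size x dim matrix of i.i.d. N(0, sigma^2 C^2) entries.

    Uses the large-n approximation E||z|| ~ sigma * C * sqrt(n), with n = batch_size * dim.
    """
    return sigma * clip_c * math.sqrt(batch_size * dim)


# Both matrices are clipped to Frobenius norm <= C, so comparing noise norms
# compares noise relative to the same signal bound (same sigma and C assumed).
emb_noise = expected_noise_norm(1024, 60, sigma=1.0, clip_c=1.0)    # embeddings, d_f = 60
logit_noise = expected_noise_norm(1024, 10, sigma=1.0, clip_c=1.0)  # logits, d_c = 10
print(emb_noise / logit_noise)  # sqrt(6), about 2.449
```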

Additionally, the utility of client-level DP VFL is not directly comparable to that of sample-level DP in centralized ML abadi2016deep or client-level DP in standard (horizontal) FL mcmahan2018learning due to the unique properties of VFL. For instance, (1) the DP-perturbed information in VFL (e.g., a batch of local embeddings or local logits) can have a smaller dimension than in centralized learning or FL (e.g., gradients or model updates of a large model), which could lead to higher utility under DP noise. (2) The private local training set of each VFL client has a smaller raw feature dimension (i.e., $1/M$ of the features if they are divided evenly among $M$ clients) than the entire dataset in the centralized setting (or a local dataset in horizontal FL), and it does not contain the labels, which leads to a different notion of dataset in the DP definition. In our work, we follow existing privacy notions in VFL to protect each user's local training set chen2020vafl ; hu2019fdml with the proposed client-level DP mechanisms.

VI-B2 Utility under label DP (privacy of server labels)

To protect the privacy of the labels held by the server with a formal privacy guarantee, we use the state-of-the-art label DP mechanism ALIBI malek2021antipodes, which was originally proposed for centralized learning. We evaluate all methods under label DP with privacy budgets $\epsilon = 2.8$ and $\epsilon = 1.4$, obtained by adding Laplacian noise with noise parameter $\lambda_{\mathrm{Lap}} = 1$ and $\lambda_{\mathrm{Lap}} = 2$, respectively, to the labels once before VFL training; we then train on the randomized labels following ALIBI. In particular, ALIBI post-processes the model predictions through Bayesian inference to improve model utility under noisy labels malek2021antipodes. The results in Table VII show that the ADMM-based methods retain higher utility than the gradient-based methods under label DP. This could be due to two reasons: (1) the additional variables introduced by ADMM (i.e., the auxiliary variables $\{z_j\}$ and dual variables $\{\lambda_j\}$) are dynamically adjusted during training, which might contribute to a more robust optimization ding2019differentially of the VFL models (i.e., $\{W_k\}, \{\theta_k\}$) against label noise; and (2) multiple local updates in each round could result in improved local models. As shown in Table IX, more local steps $\tau$ can significantly enhance the utility of VIMADMM under label DP with $\epsilon = 1.4$ (about 10% and 13% improvement on CIFAR and NUS-WIDE, respectively).
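The Laplace-noise-plus-Bayesian-inference idea behind ALIBI can be sketched as follows: each one-hot label receives i.i.d. Laplace noise with scale $\lambda$, and a posterior over the true class is computed from the Laplace likelihood $\propto \exp(-\|\tilde{y} - e_c\|_1 / \lambda)$ times a prior (uniform by default; a model's predicted class probabilities could be passed instead). This is a simplified illustration with hypothetical helper names, not the reference ALIBI implementation, and it does not derive how $\lambda$ maps to $\epsilon$.

```python
import math
import random
from typing import List, Optional, Sequence


def noisy_one_hot(label: int, num_classes: int, lam: float, rng: random.Random) -> List[float]:
    """One-hot encoding of `label` plus i.i.d. Laplace(0, lam) noise on every coordinate."""
    noisy = []
    for c in range(num_classes):
        # Laplace(0, lam) sampled as the difference of two exponentials with mean lam.
        noise = rng.expovariate(1.0 / lam) - rng.expovariate(1.0 / lam)
        noisy.append((1.0 if c == label else 0.0) + noise)
    return noisy


def label_posterior(noisy: Sequence[float], lam: float,
                    prior: Optional[Sequence[float]] = None) -> List[float]:
    """p(y = c | noisy) proportional to prior[c] * exp(-||noisy - e_c||_1 / lam); prior must be > 0."""
    k = len(noisy)
    prior = list(prior) if prior is not None else [1.0 / k] * k
    log_w = []
    for c in range(k):
        l1 = sum(abs(v - (1.0 if j == c else 0.0)) for j, v in enumerate(noisy))
        log_w.append(math.log(prior[c]) - l1 / lam)
    m = max(log_w)  # subtract the max log-weight for numerical stability
    w = [math.exp(x - m) for x in log_w]
    total = sum(w)
    return [x / total for x in w]


rng = random.Random(0)
y_tilde = noisy_one_hot(label=2, num_classes=5, lam=1.0, rng=rng)
print([round(p, 3) for p in label_posterior(y_tilde, lam=1.0)])
```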

TABLE IX: A larger number of local steps $\tau$ leads to better utility of VIMADMM under label-level DP with $\epsilon = 1.4$ (accuracy, %).
MNIST $\tau$ | MNIST ($\epsilon = 1.4$) | CIFAR $\tau$ | CIFAR ($\epsilon = 1.4$) | NUS-WIDE $\tau$ | NUS-WIDE ($\epsilon = 1.4$)
1 | 92.28 | 1 | 46.08 | 1 | 69.41
5 | 92.51 | 10 | 52.97 | 3 | 81.43
20 | 92.8 | 30 | 56.13 | 10 | 82.43

VI-C Client-level Explainability of VIM

[Figure 4 panels: MNIST, CIFAR, NUS-WIDE, ModelNet40; rows: important client, unimportant client]
Figure 4: T-SNE visualizations of local embeddings from an important client and an unimportant client for VIMADMM.

In this section, we first visualize the local embeddings of the clients, which are diverse because the clients have distinct input features. This diversity also justifies the multi-head design of VIM, which can reweight the embeddings based on their importance. Then, we show that the weight norms of the learned linear heads indeed reflect the importance of local clients, which enables functionalities such as test-time noise validation, client denoising, and client summarization.

VI-C1 T-SNE of Local Embeddings

In row 1 of Figure 3, we show the raw features of different clients on the four datasets. The quality of features can vary among clients. For instance, in MNIST, since the digit always occupies the center, clients holding black background pixels might not provide useful information for the classification task and are thus less important. The T-SNE tsne2008 visualizations in Figure 4 reveal that important clients learn better local embeddings than unimportant clients on MNIST, CIFAR, and NUS-WIDE. Specifically, on NUS-WIDE, client #3 produces linearly separable local embeddings (left), which are better than client #4's embeddings (right), where different classes overlap. For ModelNet40, since clients with multi-view data are of similar importance, their local embeddings are similar and linearly separable. Inspecting these local embeddings confirms that the distinct input features of each client lead to varied local embeddings. Consequently, we employ multiple heads as the server model, which allows us to account for the diverse feature quality across clients and to reweight the local embeddings accordingly.

VI-C2 Client Importance

Given a trained VIMADMM model, we plot the weight norm of each client's corresponding linear head in Figure 3, row 2. Comparing it with row 1, we find that clients with important local features indeed have high weights (footnote 5: here, the weights of a client refer to the weights of the client's corresponding linear head owned by the server). For example, clients #6, #7, and #8 in MNIST, which hold the middle rows of the images containing the center of the digits, have high weights, while clients #1 and #14, which hold black background pixels, have low weights. A similar phenomenon is observed on CIFAR for client #5 (center) and client #1 (corner). On CIFAR, clients #8 and #9 also have high weights because objects in CIFAR also appear in the bottom-right region of the images. On ModelNet40, clients have complementary views of the same objects, so their features have similar importance, leading to similar weight norms. Based on these observations, we conclude that the weights of the linear heads reflect the importance of local clients. Applying this principle to NUS-WIDE, the high weight norm of client #3 suggests that its features (the first 500 dimensions of the textual features) are more important than the other multi-modal features.
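Ranking clients by the norms of their linear heads takes only a few lines. The sketch below assumes the Frobenius norm of each head's weight matrix $W_k$ (the text says only "weight norm") and uses hypothetical helper names and toy head values.

```python
import math
from typing import List

Matrix = List[List[float]]  # one client's linear head W_k as a list of rows


def head_weight_norms(heads: List[Matrix]) -> List[float]:
    """Frobenius norm of each client's linear head (one head per client)."""
    return [math.sqrt(sum(w * w for row in W for w in row)) for W in heads]


def rank_clients(heads: List[Matrix]) -> List[int]:
    """1-based client ids sorted from the highest to the lowest head weight norm."""
    norms = head_weight_norms(heads)
    return [k + 1 for k in sorted(range(len(norms)), key=lambda k: norms[k], reverse=True)]


toy_heads = [[[0.1, 0.0]], [[3.0, 4.0]], [[1.0, 0.0]]]  # hypothetical heads for 3 clients
print(rank_clients(toy_heads))  # [2, 3, 1]
```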

VI-C3 Client Importance Validation via Noisy Test Client

Given a trained VIMADMM model, we add Gaussian noise to the test local features to verify the client-level importance indicated by the linear heads. Each time, we perturb the features of only one client and keep the other clients' features unchanged. The results in Figure 3, row 3, show that perturbing a client with high weights degrades the test accuracy more, which verifies that clients with higher weights are more important.
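The noisy-test-client check is a loop that perturbs one client's test features at a time. In this sketch, `predict` is a hypothetical callable standing in for the trained VFL model (it maps the list of per-client feature matrices to predicted labels), and the noise standard deviation is a free parameter. The toy model below looks only at client 1's features, so only that client's noise should hurt accuracy.

```python
import random
from typing import Callable, List, Sequence

ClientFeatures = List[List[float]]  # one feature row per test sample for a single client


def noisy_client_accuracy(predict: Callable[[List[ClientFeatures]], List[int]],
                          client_feats: List[ClientFeatures], labels: Sequence[int],
                          noise_std: float, seed: int = 0) -> List[float]:
    """Test accuracy (%) when only client k's features are perturbed, for each client k."""
    rng = random.Random(seed)
    accuracies = []
    for k in range(len(client_feats)):
        perturbed = list(client_feats)  # shallow copy; other clients' features stay unchanged
        perturbed[k] = [[x + rng.gauss(0.0, noise_std) for x in row] for row in client_feats[k]]
        preds = predict(perturbed)
        correct = sum(int(p == y) for p, y in zip(preds, labels))
        accuracies.append(100.0 * correct / len(labels))
    return accuracies


def toy_predict(feats: List[ClientFeatures]) -> List[int]:
    """Hypothetical model that thresholds the first feature of client 1 only."""
    return [1 if row[0] > 0 else 0 for row in feats[0]]


labels = [i % 2 for i in range(200)]
client0 = [[5.0] if y == 1 else [-5.0] for y in labels]
client1 = [[0.0] for _ in labels]
print(noisy_client_accuracy(toy_predict, [client0, client1], labels, noise_std=100.0))
```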

VI-C4 Client Denoising

We study the denoising ability of VIM under training-time noisy clients. We construct one noisy client (i.e., client #7, #5, #2, #3 for MNIST, CIFAR, NUS-WIDE, ModelNet40 respectively) by adding Gaussian noise to its local features and re-train the VIMADMM model. The obtained weights norm in Figure 3 row 4 shows that VIMADMM can automatically detect the noisy client and lower its weights (compared to the clean one in row 2). Table XIII in Appendix -D shows that VIMADMM outperforms baselines with faster convergence and higher accuracy under noisy clients.

VI-C5 Client Summarization

Regarding client summarization, (1) we first rank clients by the weight norms of their heads (Figure 3 row 2), then select the u% most "important" clients to re-train the VIMADMM model. We find that its performance is close to that of the model trained with all clients. Table X shows that training with the 50% most important clients reduces test accuracy by 0.55, 4.97, and 1.22 percentage points on MNIST, CIFAR, and NUS-WIDE, respectively, and training with the 20% most important clients reduces it by less than 10 points on all three datasets. (2) We also select the u% least important clients to re-train the model, and find that its performance is significantly lower than that of the model trained with the most important clients, which indicates the effectiveness of VIM for client selection. (3) For the multi-view dataset ModelNet40, the test accuracies of models trained with 12, 8, and 4 clients are similar (91.04%, 90.69%, and 90.64%), suggesting that a few views already provide sufficient training information and that clients holding multi-view data are of similar importance, which is also reflected by our linear head weights.
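Selecting the u% most (or least) important clients from their head norms can be sketched as below. The text does not say how a fractional client count is rounded (e.g., 20% of 14 MNIST clients is 2.8), so rounding to the nearest integer with a minimum of one client is an assumption here; the function name and example norms are hypothetical.

```python
import numpy as np

def select_clients(norms, ratio, important=True):
    """Return 0-indexed IDs of the `ratio` fraction of clients with the
    largest (important=True) or smallest (important=False) head norms."""
    norms = np.asarray(norms, dtype=float)
    n_keep = max(1, int(round(ratio * len(norms))))  # rounding rule is an assumption
    order = np.argsort(norms)                        # ascending by norm
    if important:
        order = order[::-1]                          # descending by norm
    return order[:n_keep].tolist()

norms = [0.2, 1.5, 0.9, 1.1]   # hypothetical head norms for 4 clients
print(select_clients(norms, 0.5, important=True))   # [1, 3]
print(select_clients(norms, 0.5, important=False))  # [0, 2]
```

The selected IDs would then define the reduced client set used to re-train the model.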

TABLE X: Functionality of client summarization enabled by VIMADMM (test accuracy, %).
Client ratio   Clients selected   MNIST   CIFAR   NUS-WIDE
100%           all                97.13   75.25   88.51
50%            important          96.58   70.28   87.29
50%            unimportant        78.11   62.67   75.80
20%            important          88.72   66.06   80.28
20%            unimportant        29.11   54.99   59.34

VII Conclusions

We propose a VFL framework with multiple linear heads (VIM) and an ADMM-based method (VIMADMM) for communication-efficient training. We provide a convergence guarantee for VIMADMM. We also introduce a user-level differential privacy mechanism for VIM and prove its privacy guarantee. Extensive experiments verify the superior performance of our algorithms in both vanilla VFL and DP VFL, and show that VIM enables client-level explainability.

Acknowledgement

The authors thank Yunhui Long, Linyi Li, Yangjun Ruan, Weixin Chen, and the anonymous reviewers for their valuable feedback and suggestions.

This work is partially supported by the National Science Foundation under grant No. 1910100, No. 2046726, No. 2229876, DARPA GARD, the National Aeronautics and Space Administration (NASA) under grant No. 80NSSC20M0229, Alfred P. Sloan Fellowship, the Amazon research award, and the eBay research grant.

References

  • (1) Martin Abadi, Andy Chu, Ian Goodfellow, H Brendan McMahan, Ilya Mironov, Kunal Talwar, and Li Zhang. Deep learning with differential privacy. In Proceedings of the 2016 ACM SIGSAC conference on computer and communications security, pages 308–318, 2016.
  • (2) Naman Agarwal, Ananda Theertha Suresh, Felix Yu, Sanjiv Kumar, and H Brendan McMahan. cpSGD: Communication-efficient and differentially-private distributed SGD. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pages 7575–7586, 2018.
  • (3) Borja Balle, Gilles Barthe, Marco Gaboardi, Justin Hsu, and Tetsuya Sato. Hypothesis testing interpretations and Rényi differential privacy. In International Conference on Artificial Intelligence and Statistics, pages 2496–2506. PMLR, 2020.
  • (4) Michael Ben-Or, Shafi Goldwasser, and Avi Wigderson. Completeness theorems for non-cryptographic fault-tolerant distributed computation. In Proceedings of the twentieth annual ACM symposium on Theory of computing, pages 1–10, 1988.
  • (5) Abhishek Bhowmick, John Duchi, Julien Freudiger, Gaurav Kapoor, and Ryan Rogers. Protection against reconstruction and its applications in private federated learning. arXiv preprint arXiv:1812.00984, 2018.
  • (6) Keith Bonawitz, Vladimir Ivanov, Ben Kreuter, Antonio Marcedone, H Brendan McMahan, Sarvar Patel, Daniel Ramage, Aaron Segal, and Karn Seth. Practical secure aggregation for privacy-preserving machine learning. In Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, pages 1175–1191, 2017.
  • (7) Stephen Boyd, Neal Parikh, and Eric Chu. Distributed optimization and statistical learning via the alternating direction method of multipliers. Now Publishers Inc, 2011.
  • (8) Theodora S Brisimi, Ruidi Chen, Theofanie Mela, Alex Olshevsky, Ioannis Ch Paschalidis, and Wei Shi. Federated learning of predictive models from federated electronic health records. International journal of medical informatics, 112:59–67, 2018.
  • (9) Adam Byerly, Tatiana Kalganova, and Ian Dear. No routing needed between capsules. Neurocomputing, 463:545–553, 2021.
  • (10) Timothy Castiglia, Shiqiang Wang, and Stacy Patterson. Flexible vertical federated learning with heterogeneous parties. IEEE Transactions on Neural Networks and Learning Systems, 2023.
  • (11) Timothy J Castiglia, Anirban Das, Shiqiang Wang, and Stacy Patterson. Compressed-VFL: Communication-efficient learning with vertically partitioned data. In International Conference on Machine Learning, pages 2738–2766. PMLR, 2022.
  • (12) Tianyi Chen, Xiao Jin, Yuejiao Sun, and Wotao Yin. VAFL: A method of vertical asynchronous federated learning. arXiv preprint arXiv:2007.06081, 2020.
  • (13) Kewei Cheng, Tao Fan, Yilun Jin, Yang Liu, Tianjian Chen, Dimitrios Papadopoulos, and Qiang Yang. SecureBoost: A lossless federated learning framework. IEEE Intelligent Systems, 36(6):87–98, 2021.
  • (14) Tat-Seng Chua, Jinhui Tang, Richang Hong, Haojie Li, Zhiping Luo, and Yantao Zheng. NUS-WIDE: A real-world web image database from National University of Singapore. In Proceedings of the ACM International Conference on Image and Video Retrieval, pages 1–9, 2009.
  • (15) Vincent Cohen-Addad, Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Differentially private vertical federated learning primitives.
  • (16) Jiahao Ding, Xinyue Zhang, Mingsong Chen, Kaiping Xue, Chi Zhang, and Miao Pan. Differentially private robust ADMM for distributed machine learning. In 2019 IEEE International Conference on Big Data (Big Data), pages 1302–1311. IEEE, 2019.
  • (17) Jinshuo Dong, Aaron Roth, and Weijie J Su. Gaussian differential privacy. arXiv preprint arXiv:1905.02383, 2019.
  • (18) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
  • (19) Cynthia Dwork. A firm foundation for private data analysis. Communications of the ACM, 54(1):86–95, 2011.
  • (20) Cynthia Dwork, Krishnaram Kenthapadi, Frank McSherry, Ilya Mironov, and Moni Naor. Our data, ourselves: Privacy via distributed noise generation. In Advances in Cryptology – EUROCRYPT, 2006.
  • (21) Cynthia Dwork and Aaron Roth. The algorithmic foundations of differential privacy. Foundations and Trends in Theoretical Computer Science, 9(3-4):211–407, 2014.
  • (22) Anis Elgabli, Jihong Park, Sabbir Ahmed, and Mehdi Bennis. L-FGADMM: Layer-wise federated group ADMM for communication efficient decentralized deep learning. In 2020 IEEE Wireless Communications and Networking Conference (WCNC), pages 1–6. IEEE, 2020.
  • (23) Anis Elgabli, Jihong Park, Amrit S Bedi, Mehdi Bennis, and Vaneet Aggarwal. GADMM: Fast and communication efficient framework for distributed machine learning. J. Mach. Learn. Res., 21(76):1–39, 2020.
  • (24) Siwei Feng and Han Yu. Multi-participant multi-class vertical federated learning. arXiv preprint arXiv:2001.11154, 2020.
  • (25) Chong Fu, Xuhong Zhang, Shouling Ji, Jinyin Chen, Jingzheng Wu, Shanqing Guo, Jun Zhou, Alex X Liu, and Ting Wang. Label inference attacks against vertical federated learning. In 31st USENIX Security Symposium (USENIX Security 22), Boston, MA, August 2022. USENIX Association.
  • (26) Fangcheng Fu, Xupeng Miao, Jiawei Jiang, Huanran Xue, and Bin Cui. Towards communication-efficient vertical federated learning training via cache-enabled local updates. Proc. VLDB Endow., 15(10):2111–2120, June 2022.
  • (27) Robin C Geyer, Tassilo Klein, and Moin Nabi. Differentially private federated learning: A client level perspective. arXiv preprint arXiv:1712.07557, 2017.
  • (28) Ran Gilad-Bachrach, Nathan Dowlin, Kim Laine, Kristin Lauter, Michael Naehrig, and John Wernsing. CryptoNets: Applying neural networks to encrypted data with high throughput and accuracy. In International Conference on Machine Learning, pages 201–210. PMLR, 2016.
  • (29) Prashant Gohel, Priyanka Singh, and Manoranjan Mohanty. Explainable AI: Current status and future directions. arXiv preprint arXiv:2107.07045, 2021.
  • (30) Bin Gu, Zhiyuan Dang, Xiang Li, and Heng Huang. Federated doubly stochastic kernel learning for vertically partitioned data. In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 2483–2493, 2020.
  • (31) Andrew Hard, Kanishka Rao, Rajiv Mathews, Françoise Beaufays, Sean Augenstein, Hubert Eichner, Chloé Kiddon, and Daniel Ramage. Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604, 2018.
  • (32) Stephen Hardy, Wilko Henecka, Hamish Ivey-Law, Richard Nock, Giorgio Patrini, Guillaume Smith, and Brian Thorne. Private federated learning on vertically partitioned data via entity resolution and additively homomorphic encryption. arXiv preprint arXiv:1711.10677, 2017.
  • (33) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 770–778, 2016.
  • (34) Mingyi Hong, Zhi-Quan Luo, and Meisam Razaviyayn. Convergence analysis of alternating direction method of multipliers for a family of nonconvex problems. SIAM Journal on Optimization, 26(1):337–364, 2016.
  • (35) Yaochen Hu, Peng Liu, Linglong Kong, and Di Niu. Learning privately over distributed features: An ADMM sharing approach. arXiv preprint arXiv:1907.07735, 2019.
  • (36) Yaochen Hu, Di Niu, Jianming Yang, and Shengping Zhou. FDML: A collaborative machine learning framework for distributed features. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 2232–2240, 2019.
  • (37) Yuzheng Hu, Fan Wu, Qinbin Li, Yunhui Long, Gonzalo Munilla Garrido, Chang Ge, Bolin Ding, David Forsyth, Bo Li, and Dawn Song. SoK: Privacy-preserving data synthesis. In IEEE Symposium on Security and Privacy (S&P), 2024.
  • (38) Zonghao Huang, Rui Hu, Yuanxiong Guo, Eric Chan-Tin, and Yanmin Gong. DP-ADMM: ADMM-based distributed learning with differential privacy. IEEE Transactions on Information Forensics and Security, 15:1002–1012, 2019.
  • (39) Xiao Jin, Pin-Yu Chen, Chia-Yi Hsu, Chia-Mu Yu, and Tianyi Chen. Catastrophic data leakage in vertical federated learning. Advances in Neural Information Processing Systems, 34, 2021.
  • (40) Yan Kang, Yang Liu, and Tianjian Chen. FedMVT: Semi-supervised vertical federated learning with multiview training. arXiv preprint arXiv:2008.10838, 2020.
  • (41) Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. Supervised contrastive learning. Advances in Neural Information Processing Systems, 33:18661–18673, 2020.
  • (42) Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.
  • (43) Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010.
  • (44) Tian Li, Shengyuan Hu, Ahmad Beirami, and Virginia Smith. Ditto: Fair and robust federated learning through personalization. In International Conference on Machine Learning, pages 6357–6368. PMLR, 2021.
  • (45) Tian Li, Maziar Sanjabi, Ahmad Beirami, and Virginia Smith. Fair resource allocation in federated learning. In International Conference on Learning Representations, 2020.
  • (46) Yang Liu, Yan Kang, Xinwei Zhang, Liping Li, Yong Cheng, Tianjian Chen, Mingyi Hong, and Qiang Yang. A communication efficient collaborative learning framework for distributed features. arXiv preprint arXiv:1912.11187, 2019.
  • (47) Yang Liu, Zhihao Yi, and Tianjian Chen. Backdoor attacks and defenses in feature-partitioned collaborative learning. arXiv preprint arXiv:2007.03608, 2020.
  • (48) Yang Liu, Xinwei Zhang, Yan Kang, Liping Li, Tianjian Chen, Mingyi Hong, and Qiang Yang. FedBCD: A communication-efficient collaborative learning framework for distributed features. IEEE Transactions on Signal Processing, 70:4277–4290, 2022.
  • (49) Scott M Lundberg and Su-In Lee. A unified approach to interpreting model predictions. Advances in neural information processing systems, 30, 2017.
  • (50) Aravindh Mahendran and Andrea Vedaldi. Understanding deep image representations by inverting them. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 5188–5196, 2015.
  • (51) Mani Malek Esmaeili, Ilya Mironov, Karthik Prasad, Igor Shilov, and Florian Tramer. Antipodes of label differential privacy: PATE and ALIBI. Advances in Neural Information Processing Systems, 34:6934–6945, 2021.
  • (52) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-Efficient Learning of Deep Networks from Decentralized Data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, volume 54 of Proceedings of Machine Learning Research, pages 1273–1282. PMLR, 20–22 Apr 2017.
  • (53) H. Brendan McMahan, Daniel Ramage, Kunal Talwar, and Li Zhang. Learning differentially private recurrent language models. In International Conference on Learning Representations, 2018.
  • (54) H Brendan McMahan, Daniel Ramage, Kunal Talwar, and Li Zhang. Learning differentially private recurrent language models. In International Conference on Learning Representations, 2018.
  • (55) Frank D McSherry. Privacy integrated queries: an extensible platform for privacy-preserving data analysis. In Proceedings of the 2009 ACM SIGMOD International Conference on Management of data, pages 19–30, 2009.
  • (56) Ilya Mironov. Rényi differential privacy. In 2017 IEEE 30th computer security foundations symposium (CSF), pages 263–275. IEEE, 2017.
  • (57) Nicolas Papernot, Patrick McDaniel, Arunesh Sinha, and Michael P Wellman. SoK: Security and privacy in machine learning. In 2018 IEEE European Symposium on Security and Privacy (EuroS&P), pages 399–414. IEEE, 2018.
  • (58) Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems 32, pages 8024–8035. Curran Associates, Inc., 2019.
  • (59) Natalia Ponomareva, Hussein Hazimeh, Alex Kurakin, Zheng Xu, Carson Denison, H Brendan McMahan, Sergei Vassilvitskii, Steve Chien, and Abhradeep Guha Thakurta. How to DP-fy ML: A practical guide to machine learning with differential privacy. Journal of Artificial Intelligence Research, 77:1113–1201, 2023.
  • (60) Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning, pages 8748–8763. PMLR, 2021.
  • (61) Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434, 2015.
  • (62) Thilina Ranbaduge and Ming Ding. Differentially private vertical federated learning. arXiv preprint arXiv:2211.06782, 2022.
  • (63) Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. "Why should I trust you?" Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 1135–1144, 2016.
  • (64) Bita Darvish Rouhani, M Sadegh Riazi, and Farinaz Koushanfar. DeepSecure: Scalable provably-secure deep learning. In Proceedings of the 55th Annual Design Automation Conference, pages 1–6, 2018.
  • (65) Reza Shokri, Marco Stronati, Congzheng Song, and Vitaly Shmatikov. Membership inference attacks against machine learning models. In 2017 IEEE symposium on security and privacy (SP), pages 3–18. IEEE, 2017.
  • (66) Jong-Chyi Su, Matheus Gadelha, Rui Wang, and Subhransu Maji. A deeper look at 3D shape classifiers. In Second Workshop on 3D Reconstruction Meets Semantics, ECCV, 2018.
  • (67) Davoud Ataee Tarzanagh, Bojian Hou, Boning Tong, Qi Long, and Li Shen. Fairness-aware class imbalanced learning on multiple subgroups. In Uncertainty in Artificial Intelligence, pages 2123–2133. PMLR, 2023.
  • (68) Linh Tran, Timothy Castiglia, Stacy Patterson, and Ana Milanova. Privacy tradeoffs in vertical federated learning. In Federated Learning Systems (FLSys) Workshop @ MLSys 2023, 2023.
  • (69) Laurens van der Maaten and Geoffrey Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9(86):2579–2605, 2008.
  • (70) Praneeth Vepakomma, Otkrist Gupta, Tristan Swedish, and Ramesh Raskar. Split learning for health: Distributed deep learning without sharing raw patient data. arXiv preprint arXiv:1812.00564, 2018.
  • (71) Rahul Vigneswaran, Marc T Law, Vineeth N Balasubramanian, and Makarand Tapaswi. Feature generation for long-tail classification. In Proceedings of the twelfth Indian conference on computer vision, graphics and image processing, pages 1–9, 2021.
  • (72) Boxin Wang, Fan Wu, Yunhui Long, Luka Rimanic, Ce Zhang, and Bo Li. DataLens: Scalable privacy preserving training via gradient compression and aggregation. In Proceedings of the 2021 ACM SIGSAC Conference on Computer and Communications Security, pages 2146–2168, 2021.
  • (73) Wenju Wang, Yu Cai, and Tao Wang. Multi-view dual attention network for 3D object recognition. Neural Computing and Applications, 34(4):3201–3212, 2022.
  • (74) Yuncheng Wu, Shaofeng Cai, Xiaokui Xiao, Gang Chen, and Beng Chin Ooi. Privacy preserving vertical federated learning for tree-based models. Proceedings of the VLDB Endowment, 13(12):2090–2103, 2020.
  • (75) Chulin Xie, Yunhui Long, Pin-Yu Chen, Qinbin Li, Sanmi Koyejo, and Bo Li. Unraveling the connections between privacy and certified robustness in federated learning against poisoning attacks. In Proceedings of the 2023 ACM SIGSAC Conference on Computer and Communications Security, pages 1511–1525, 2023.
  • (76) Han Xu, Xiaorui Liu, Yaxin Li, Anil Jain, and Jiliang Tang. To be robust or to be fair: Towards fairness in adversarial training. In International conference on machine learning, pages 11492–11501. PMLR, 2021.
  • (77) Qiang Yang, Yang Liu, Tianjian Chen, and Yongxin Tong. Federated machine learning: Concept and applications. ACM Transactions on Intelligent Systems and Technology (TIST), 10(2):12, 2019.
  • (78) Shengwen Yang, Bing Ren, Xuhui Zhou, and Liping Liu. Parallel distributed logistic regression for vertical federated learning without third-party coordinator. arXiv preprint arXiv:1911.09824, 2019.
  • (79) Timothy Yang, Galen Andrew, Hubert Eichner, Haicheng Sun, Wei Li, Nicholas Kong, Daniel Ramage, and Françoise Beaufays. Applied federated learning: Improving google keyboard query suggestions. arXiv preprint arXiv:1812.02903, 2018.
  • (80) Wensi Yang, Yuhang Zhang, Kejiang Ye, Li Li, and Cheng-Zhong Xu. FFD: A federated learning based method for credit card fraud detection. In International Conference on Big Data, pages 18–32. Springer, 2019.
  • (81) Sheng Yue, Ju Ren, Jiang Xin, Sen Lin, and Junshan Zhang. Inexact-ADMM based federated meta-learning for fast and continual edge learning. In Proceedings of the Twenty-second International Symposium on Theory, Algorithmic Foundations, and Protocol Design for Mobile Networks and Mobile Computing, pages 91–100, 2021.
  • (82) Muhammad Bilal Zafar, Isabel Valera, Manuel Gomez Rodriguez, and Krishna P Gummadi. Fairness beyond disparate treatment & disparate impact: Learning classification without disparate mistreatment. In Proceedings of the 26th international conference on world wide web, pages 1171–1180, 2017.
  • (83) Jie Zhang, Song Guo, Zhihao Qu, Deze Zeng, Haozhao Wang, Qifeng Liu, and Albert Y Zomaya. Adaptive vertical federated learning on unbalanced features. IEEE Transactions on Parallel and Distributed Systems, 33(12):4006–4018, 2022.
  • (84) Qingsong Zhang, Bin Gu, Cheng Deng, and Heng Huang. Secure bilevel asynchronous vertical federated learning with backward updating. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, pages 10896–10904, 2021.
  • (85) Ciyou Zhu, Richard H Byrd, Peihuang Lu, and Jorge Nocedal. Algorithm 778: L-BFGS-B: Fortran subroutines for large-scale bound-constrained optimization. ACM Transactions on Mathematical Software (TOMS), 23(4):550–560, 1997.

The Appendix is organized as follows:

  • Appendix -A provides algorithm details for Split Learning [70] (Algorithm 2) and VIMADMM-J (Algorithm 3);

  • Appendix -B provides the proofs for convergence guarantees in Theorem 1;

  • Appendix -C provides the proofs for privacy guarantee in Theorem 2;

  • Appendix -D provides more details on experimental setups and the additional experimental results;

  • Appendix -E provides additional discussion on ADMM and VFL.

-A Algorithm Details

-A1 Split Learning [70]

At each communication round $t$, the server samples a set of data indices $B(t)$ with batch size $|B(t)|=b$. We then describe the key steps of Split Learning (Algorithm 2) as follows:

(1) Communication from client to server. Each client $k$ sends a batch of embeddings $\{{h_j^k}^{(t)}\}_{j\in B(t)}$ to the server, where ${h_j^k}^{(t)} = f(x_j^k;\theta_k^{(t)}), \forall j \in B(t)$.

(2) Server updates the server model $\theta_0$. According to the VFL objective in Eq. 1, the server model is updated as:

$$\theta_0^{(t+1)} \leftarrow \theta_0^{(t)} - \eta \nabla_{\theta_0^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t)}) \tag{14}$$

where $\eta$ is the server learning rate, and

$$\nabla_{\theta_0^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t)}) = \nabla_{\theta_0^{(t)}} \left( \frac{1}{N} \sum_{j=1}^{N} \ell\big([{h_j^1}^{(t)}, \ldots, {h_j^M}^{(t)}], y_j; \theta_0^{(t)}\big) + \beta \mathcal{R}(\theta_0^{(t)}) \right). \tag{15}$$

Here $[{h_j^1}^{(t)}, \ldots, {h_j^M}^{(t)}]$ denotes the concatenated local embeddings.

(3) Communication from server to client. The server computes the gradient with respect to each local embedding, $\nabla_{{h_j^k}^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t+1)})$, according to the VFL objective in Eq. 1, where

$$\nabla_{{h_j^k}^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t+1)}) = \nabla_{{h_j^k}^{(t)}} \ell\big([{h_j^1}^{(t)}, \ldots, {h_j^M}^{(t)}], y_j; \theta_0^{(t+1)}\big), \quad \forall j \in B(t), k \in [M]. \tag{16}$$

The server sends the gradients $\{\nabla_{{h_j^k}^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t+1)})\}_{j\in B(t)}$ to each client $k$, $\forall k \in [M]$.
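The embedding gradients sent by the server are all a client needs: by the chain rule (the data term of Eq. 18), the client can form its parameter gradient without seeing the server model or the labels. The following check illustrates this under toy assumptions that are not from the paper: a linear client model h = x θ_k, a linear server model on the concatenated embeddings, squared loss, no regularizer, and a batch average over B(t). It compares the chain-rule gradient with central finite differences of the end-to-end loss; all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
b, d, e = 8, 3, 2                      # batch size, client feature dim, embedding dim
x_k = rng.normal(size=(b, d))          # client k's batch of local features
theta_k = rng.normal(size=(d, e))      # client k's model, assumed linear: h = x @ theta_k
h_other = rng.normal(size=(b, e))      # other clients' embeddings, held fixed
theta0 = rng.normal(size=(2 * e, 1))   # server model on the concatenated embeddings
y = rng.normal(size=(b, 1))            # labels (seen only by the server)

def batch_loss(theta):
    """End-to-end loss as a function of client k's parameters (for reference only)."""
    H = np.concatenate([x_k @ theta, h_other], axis=1)
    return 0.5 * np.mean(np.sum((H @ theta0 - y) ** 2, axis=1))

# Server side (Eq. 16): per-sample gradient of the loss w.r.t. client k's embedding.
H = np.concatenate([x_k @ theta_k, h_other], axis=1)
grad_h_k = ((H @ theta0 - y) @ theta0.T)[:, :e]   # all the server sends to client k

# Client side (data term of Eq. 18): chain rule, averaged over the batch.
grad_chain = x_k.T @ grad_h_k / b

# Reference: central finite differences of the end-to-end loss.
eps = 1e-6
grad_fd = np.zeros_like(theta_k)
for i in range(d):
    for j in range(e):
        step = np.zeros_like(theta_k)
        step[i, j] = eps
        grad_fd[i, j] = (batch_loss(theta_k + step) - batch_loss(theta_k - step)) / (2 * eps)

print(np.max(np.abs(grad_chain - grad_fd)))
```

The printed maximum discrepancy is at the level of floating-point rounding, confirming that the two computations agree.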

(4) Client updates local model parameters $\theta_k$. Finally, every client $k$ locally updates its model parameters $\theta_k$ according to the VFL objective in Eq. 1 as follows:

$$\theta_k^{(t+1)} \leftarrow \theta_k^{(t)} - \eta^k \nabla_{\theta_k^{(t)}} \mathcal{L}_{\mathrm{VFL}}(\theta_0^{(t+1)}), \quad \forall k \in [M] \tag{17}$$

where $\eta^k$ is the local learning rate for client $k$, and

$$\nabla_{\theta_{k}^{(t)}}\mathcal{L}_{\mathrm{VFL}}(\theta_{0}^{(t+1)})=\frac{1}{N}\sum_{j=1}^{N}\nabla_{\theta_{k}^{(t)}}{h_{j}^{k}}^{(t)}\,\nabla_{{h_{j}^{k}}^{(t)}}\mathcal{L}_{\mathrm{VFL}}(\theta_{0}^{(t+1)})+\beta\nabla_{\theta_{k}^{(t)}}\mathcal{R}(\theta_{k}^{(t)}) \tag{18}$$
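To make the chain rule in Eq. 18 concrete, the sketch below checks it numerically on a toy problem. All modeling choices are illustrative assumptions rather than the paper's setup: two clients with linear embeddings $h_j^k = x_j^k\theta_k$, a linear server model on the concatenated embeddings, a squared per-sample loss, and $\mathcal{R}(\theta_k)=\frac{1}{2}\|\theta_k\|^2$. The names `objective`, `grad_eq18`, and `grad_fd` are hypothetical.

```python
import numpy as np

# Hypothetical toy setup (not from the paper): M = 2 clients with linear
# embeddings h_j^k = x_j^k @ theta_k, a linear server head theta0 on the
# concatenated embeddings, squared loss, and R(theta_k) = 0.5 * ||theta_k||^2.
rng = np.random.default_rng(0)
N, d1, d2, e, beta = 8, 3, 4, 2, 0.1
x1, x2 = rng.normal(size=(N, d1)), rng.normal(size=(N, d2))
y = rng.normal(size=N)
theta1, theta2 = rng.normal(size=(d1, e)), rng.normal(size=(d2, e))
theta0 = rng.normal(size=2 * e)

def objective(t1):
    """L_VFL as a function of client 1's parameters (theta0, theta2 fixed)."""
    h = np.concatenate([x1 @ t1, x2 @ theta2], axis=1)   # (N, 2e)
    per_sample = 0.5 * (h @ theta0 - y) ** 2
    return per_sample.mean() + beta * 0.5 * np.sum(t1 ** 2)

def grad_eq18(t1):
    """Eq. 18: average of (dh_j/dtheta_1)^T (dl_j/dh_j) plus beta * grad R."""
    h = np.concatenate([x1 @ t1, x2 @ theta2], axis=1)
    resid = h @ theta0 - y                              # (N,)
    grad_h1 = resid[:, None] * theta0[None, :e]         # what the server sends client 1
    # For h_j = x_j @ t1 the Jacobian-vector product is the outer product x_j grad_h_j.
    return x1.T @ grad_h1 / N + beta * t1

def grad_fd(t1, eps=1e-6):
    """Central finite differences, used only to check grad_eq18."""
    g = np.zeros_like(t1)
    for idx in np.ndindex(*t1.shape):
        step = np.zeros_like(t1)
        step[idx] = eps
        g[idx] = (objective(t1 + step) - objective(t1 - step)) / (2 * eps)
    return g

print(np.max(np.abs(grad_eq18(theta1) - grad_fd(theta1))))
```

Because this toy objective is quadratic in $\theta_1$, central differences are exact up to rounding, so the printed gap should be close to floating-point precision.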

These four steps of Split Learning are summarized in Algorithm 2.

1:  Input: number of communication rounds $T$, number of clients $M$, number of training samples $N$, batch size $b$, input features $\{\{x_{j}^{1}\}_{j=1}^{N},\{x_{j}^{2}\}_{j=1}^{N},\ldots,\{x_{j}^{M}\}_{j=1}^{N}\}$, labels $\{y_{j}\}_{j=1}^{N}$, local models $\{\theta_{k}\}_{k=1}^{M}$; linear heads $\{W_{k}\}_{k=1}^{M}$; server learning rate $\eta$; client learning rates $\{\eta^{k}\}_{k=1}^{M}$;
2:  for communication round $t\in[T]$ do
3:     Server samples a set of data indices $B(t)$ with $|B(t)|=b$
4:     for client $k\in[M]$ do
5:        generates a local training batch $\{x_{j}^{k}\}_{j\in B(t)}$
6:        computes local embeddings ${h_{j}^{k}}^{(t)}\leftarrow f(x_{j}^{k};\theta_{k}),\ \forall j\in B(t)$
7:        sends local embeddings $\{{h_{j}^{k}}^{(t)}\}_{j\in B(t)}$ to the server
8:     Server updates server model $\theta_{0}^{(t+1)}$ by Eq. 14
9:     Server computes gradients w.r.t. embeddings $\nabla_{{h_{j}^{k}}^{(t)}}\mathcal{L}_{\mathrm{VFL}}(\theta_{0}^{(t+1)})$ by Eq. 16, $\forall j\in B(t)$
10:     Server sends gradients $\{\nabla_{{h_{j}^{k}}^{(t)}}\mathcal{L}_{\mathrm{VFL}}(\theta_{0}^{(t+1)})\}_{j\in B(t)}$ to each client $k$, $\forall k\in[M]$
11:     for client $k\in[M]$ do
12:        updates local model $\theta_{k}^{(t+1)}$ by Eq. 17
Algorithm 2 Split Learning [70]
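As a rough illustration of the communication pattern in Algorithm 2, the sketch below runs the loop on a synthetic regression task. It assumes linear client embeddings, a linear server model, a squared loss, and $\beta=0$; none of these choices come from the paper. The counter `messages` makes explicit that every training step costs one upload and one download per client.

```python
import numpy as np

# Hypothetical toy instance of Algorithm 2 (not the paper's models): two clients
# with linear embeddings h = x @ theta_k, a linear server model theta0 on the
# concatenated embeddings, squared loss, and no regularizer (beta = 0).
rng = np.random.default_rng(1)
N, dims, e, b, T = 64, [3, 5], 2, 16, 200
eta, eta_k = 0.05, [0.05, 0.05]
xs = [rng.normal(size=(N, d)) for d in dims]
y = xs[0][:, 0] - xs[1][:, 1]                           # synthetic labels
thetas = [0.3 * rng.normal(size=(d, e)) for d in dims]
theta0 = 0.3 * rng.normal(size=len(dims) * e)

def full_loss():
    h = np.concatenate([x @ th for x, th in zip(xs, thetas)], axis=1)
    return 0.5 * np.mean((h @ theta0 - y) ** 2)

messages, loss_before = 0, full_loss()
for t in range(T):
    B = rng.choice(N, size=b, replace=False)            # line 3: sample B(t)
    hs = [x[B] @ th for x, th in zip(xs, thetas)]       # lines 5-6: embeddings
    messages += len(hs)                                 # line 7: clients -> server
    h = np.concatenate(hs, axis=1)
    theta0 = theta0 - eta * h.T @ (h @ theta0 - y[B]) / b   # line 8: server step
    resid = h @ theta0 - y[B]                           # re-evaluated at theta0^(t+1)
    grad_h = resid[:, None] * theta0[None, :]           # line 9: d loss_j / d h_j
    messages += len(hs)                                 # line 10: server -> clients
    for k in range(len(xs)):                            # lines 11-12: client steps
        g_k = grad_h[:, k * e:(k + 1) * e]
        thetas[k] = thetas[k] - eta_k[k] * xs[k][B].T @ g_k / b

print(f"loss {loss_before:.3f} -> {full_loss():.3f}, messages = {messages}")
```

The message count grows as $2MT$ regardless of how much local computation each client could afford, which is the per-step communication cost discussed in the main text.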

-A2 VIMADMM-J

At each communication round $t$, the server samples a set of data indices $B(t)$ with batch size $|B(t)|=b$. The key steps of VIMADMM-J (Algorithm 3) are as follows:

(1) Communication from client to server. Each client $k$ sends a batch of local logits $\{{o_{j}^{k}}^{(t)}\}_{j\in B(t)}$ to the server, where ${o_{j}^{k}}^{(t)}=f(x_{j}^{k};\theta_{k}^{(t)})W_{k}^{(t)},\ \forall j\in B(t)$.

(2) Server updates auxiliary variables $\{z_{j}\}$. After receiving the local logits from all clients, the server updates the auxiliary variable for each sample $j$ as:

$$z_{j}^{(t)}=\underset{z_{j}}{\operatorname{argmin}}\ \ell(z_{j},y_{j})-{\lambda_{j}^{(t-1)}}^{\top}z_{j}+\frac{\rho}{2}\left\|\sum_{k=1}^{M}{o_{j}^{k}}^{(t)}-z_{j}\right\|^{2},\quad\forall j\in B(t) \tag{19}$$

Since the optimization problem in Eq. 19 is convex and differentiable with respect to $z_{j}$, we use the L-BFGS-B algorithm [85] to solve it.
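A minimal sketch of the per-sample $z_j$ update follows, assuming $\ell$ is softmax cross-entropy. Under that assumption the objective of Eq. 19 is $\rho$-strongly convex with a $(0.5+\rho)$-Lipschitz gradient, so plain gradient descent with step size $1/(0.5+\rho)$ converges; it is used here only as a dependency-free stand-in for L-BFGS-B. The function name `z_update` is hypothetical.

```python
import numpy as np

def softmax(v):
    p = np.exp(v - v.max())            # shift for numerical stability
    return p / p.sum()

def z_update(o_sum, y, lam_prev, rho, steps=500):
    """Approximately solve Eq. 19 for one sample j.

    o_sum:    sum_k o_j^k, the aggregated logits received from all clients
    y:        integer class label y_j
    lam_prev: dual variable lambda_j^(t-1)
    Assumes l(z, y) is softmax cross-entropy (Hessian eigenvalues <= 0.5),
    so a fixed step of 1 / (0.5 + rho) is safe for gradient descent.
    """
    z = o_sum.copy()                   # warm start at the aggregated logits
    step = 1.0 / (0.5 + rho)
    onehot = np.eye(o_sum.size)[y]
    for _ in range(steps):
        grad = softmax(z) - onehot - lam_prev - rho * (o_sum - z)
        z -= step * grad
    return z

o_sum = np.array([2.0, -1.0, 0.5])
lam_prev = np.array([0.1, 0.0, -0.1])
z = z_update(o_sum, y=0, lam_prev=lam_prev, rho=1.0)
print(z)
```

Warm-starting at the aggregated logits is a convenience choice: when $\lambda_j$ is small and $\rho$ is large, the minimizer lies close to $\sum_k o_j^k$.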

(3) Server updates dual variables $\{\lambda_{j}\}$. After the update in Eq. 19, the server updates the dual variable for each sample $j$ as:

$$\lambda^{(t)}_{j}=\lambda^{(t-1)}_{j}+\rho\left(\sum_{k=1}^{M}{o_{j}^{k}}^{(t)}-z^{(t)}_{j}\right),\quad\forall j\in B(t) \tag{20}$$
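Because $\lambda_j$ is stored for every training sample but only the rows indexed by $B(t)$ change in round $t$, Eq. 20 can be applied as a masked, vectorized update. The sketch below assumes that per-sample variables are kept as $(N, C)$ arrays and that the received logits arrive as one $(b, C)$ matrix per client; `dual_update` is a hypothetical name.

```python
import numpy as np

def dual_update(lam, z, o_list, batch, rho):
    """Eq. 20 for every j in B(t).

    lam, z: (N, C) arrays holding lambda_j and z_j^(t) for all N samples
    o_list: list of M arrays of shape (b, C), the logits o_j^k for j in batch
    batch:  integer indices B(t) of length b
    Rows outside the batch are left unchanged; the input lam is not modified.
    """
    lam = lam.copy()
    o_sum = np.sum(o_list, axis=0)                 # sum_k o_j^k, shape (b, C)
    lam[batch] = lam[batch] + rho * (o_sum - z[batch])
    return lam

rng = np.random.default_rng(0)
N, C, M = 5, 3, 2
lam, z = rng.normal(size=(N, C)), rng.normal(size=(N, C))
batch = np.array([0, 3])
o_list = [rng.normal(size=(2, C)) for _ in range(M)]
new_lam = dual_update(lam, z, o_list, batch, rho=0.5)
```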

(4) Communication from server to client. After the update in Eq. 20, we define a residual variable ${s_{j}^{k}}^{(t)}$ for each sample $j$ of the $k$-th client, which provides the supervision for updating the local model:

$${s_{j}^{k}}^{(t)}\triangleq z_{j}^{(t)}-\sum_{i\in[M],\,i\neq k}{o_{j}^{i}}^{(t)} \tag{21}$$

The server sends the dual variables $\{\lambda^{(t)}_{j}\}_{j\in B(t)}$ and the residual variables $\{{s_{j}^{k}}^{(t)}\}_{j\in B(t)}$ of all samples to each client $k$.
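Computing Eq. 21 naively requires a sum over $M-1$ clients for each of the $M$ clients. A simple way to vectorize it, sketched below under the assumption that the received logits are stacked into an $(M, b, C)$ array, is to form the full sum once and subtract each client's own contribution from it. The function `residuals` is hypothetical.

```python
import numpy as np

def residuals(z_batch, o_list):
    """Eq. 21 for all clients at once.

    z_batch: (b, C) auxiliary variables z_j^(t) for j in B(t)
    o_list:  list of M arrays of shape (b, C), logits o_j^k
    Returns an (M, b, C) array whose k-th slice is sent to client k.
    """
    o = np.stack(o_list)                          # (M, b, C)
    others = o.sum(axis=0, keepdims=True) - o     # sum over i != k, for each k
    return z_batch[None] - others

rng = np.random.default_rng(3)
o_list = [rng.normal(size=(4, 3)) for _ in range(3)]
z_batch = rng.normal(size=(4, 3))
s = residuals(z_batch, o_list)
```

This costs one sum and $M$ subtractions instead of $M(M-1)$ additions.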

(5) Client updates linear head $W_{k}$ and local model $\theta_{k}$ alternately. The linear head of each client is locally updated as:

$$W_{k}^{(t+1)}=\underset{W_{k}}{\operatorname{argmin}}\ \beta\mathcal{R}(W_{k})+\frac{1}{b}\sum_{j\in B(t)}{\lambda_{j}^{(t)}}^{\top}f(x_{j}^{k};\theta_{k}^{(t)})W_{k}+\sum_{j\in B(t)}\frac{\rho}{2b}\left\|{s_{j}^{k}}^{(t)}-f(x_{j}^{k};\theta_{k}^{(t)})W_{k}\right\|_{F}^{2},\quad\forall k\in[M] \tag{22}$$

Each client updates the local model parameters $\theta_{k}$ as follows:

$$\theta_{k}^{(t+1)}=\underset{\theta_{k}}{\operatorname{argmin}}\ \beta\mathcal{R}(\theta_{k})+\frac{1}{b}\sum_{j\in B(t)}{\lambda_{j}^{(t)}}^{\top}f(x_{j}^{k};\theta_{k})W_{k}^{(t+1)}+\sum_{j\in B(t)}\frac{\rho}{2b}\left\|{s_{j}^{k}}^{(t)}-f(x_{j}^{k};\theta_{k})W_{k}^{(t+1)}\right\|_{F}^{2} \tag{23}$$

Because the loss of a DNN is nonconvex, we do not solve Eq. 22 and Eq. 23 exactly; instead, at each round we run $\tau$ local SGD steps that update $W_{k}$ and $\theta_{k}$ alternately on these objectives. Specifically, at each local step, we first update $W_{k}$ and then update $\theta_{k}$.
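The alternating local update can be sketched as follows. To keep the gradients explicit, the sketch assumes a linear feature extractor $f(x;\theta_k)=x\theta_k$ and $\mathcal{R}(\cdot)=\frac{1}{2}\|\cdot\|_F^2$, and it takes gradient steps on the fixed batch $B(t)$; DNN clients would instead obtain the same gradients by backpropagation. The names `client_local_update` and `local_objective` are hypothetical.

```python
import numpy as np

def client_local_update(X, theta, W, lam, s, rho, beta, lr, tau):
    """tau alternating gradient steps: W_k on Eq. 22, then theta_k on Eq. 23.

    Assumes f(x; theta) = x @ theta and R(.) = 0.5 * ||.||_F^2.
    X: (b, d) local batch; lam, s: (b, C) duals and residuals from the server.
    """
    b = X.shape[0]
    for _ in range(tau):
        F = X @ theta                               # features f(x_j^k; theta_k)
        G = (lam + rho * (F @ W - s)) / b           # d objective / d logits
        W = W - lr * (F.T @ G + beta * W)           # head step (Eq. 22)
        G = (lam + rho * (F @ W - s)) / b           # refresh with the new head
        theta = theta - lr * (X.T @ (G @ W.T) + beta * theta)  # model step (Eq. 23)
    return theta, W

def local_objective(X, theta, W, lam, s, rho, beta):
    """Eq. 22/23 objective with both regularizers, used only for monitoring."""
    b = X.shape[0]
    O = X @ theta @ W
    return (0.5 * beta * (np.sum(theta ** 2) + np.sum(W ** 2))
            + np.sum(lam * O) / b + 0.5 * rho * np.sum((s - O) ** 2) / b)

rng = np.random.default_rng(4)
b, d, e, C = 16, 4, 3, 2
X = rng.normal(size=(b, d))
theta, W = 0.5 * rng.normal(size=(d, e)), 0.5 * rng.normal(size=(e, C))
lam, s = 0.1 * rng.normal(size=(b, C)), rng.normal(size=(b, C))
before = local_objective(X, theta, W, lam, s, rho=1.0, beta=0.01)
theta_new, W_new = client_local_update(X, theta, W, lam, s, rho=1.0, beta=0.01,
                                       lr=0.05, tau=10)
after = local_objective(X, theta_new, W_new, lam, s, rho=1.0, beta=0.01)
print(before, after)
```

With a small enough learning rate, each sub-step is a descent step on one block of `local_objective`, so the printed value should decrease; note that no communication happens inside the `tau` loop.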

These five steps of VIMADMM-J are summarized in Algorithm 3.

Algorithm 3 VIMADMM-J (with differential privacy)
1:  Input: number of communication rounds $T$, number of clients $M$, number of training samples $N$, batch size $b$, input features $\{\{x_{j}^{1}\}_{j=1}^{N},\{x_{j}^{2}\}_{j=1}^{N},\ldots,\{x_{j}^{M}\}_{j=1}^{N}\}$, labels $\{y_{j}\}_{j=1}^{N}$, local models $\{\theta_{k}\}_{k=1}^{M}$; linear heads $\{W_{k}\}_{k=1}^{M}$; auxiliary variables $\{z_{j}\}_{j=1}^{N}$; dual variables $\{\lambda_{j}\}_{j=1}^{N}$; noise parameter $\sigma$, clipping constant $C$
2:  for communication round $t\in[T]$ do
3:     Server samples a set of data indices $B(t)$ with $|B(t)|=b$
4:     for client $k\in[M]$ do
5:        generates a local training batch $\{x_{j}^{k}\}_{j\in B(t)}$
6:        computes local logits ${o_{j}^{k}}^{(t)}=f(x_{j}^{k};\theta_{k}^{(t)})W_{k}^{(t)},\ \forall j\in B(t)$
7:        clips and perturbs the local logit matrix $\{{o_{j}^{k}}^{(t)}\}_{j\in B(t)}\leftarrow\mathtt{Clip}\left(\{{o_{j}^{k}}^{(t)}\}_{j\in B(t)},C\right)+\mathcal{N}\left(0,\sigma^{2}C^{2}\right)$
8:        sends local logits $\{{o_{j}^{k}}^{(t)}\}_{j\in B(t)}$ to the server
9:     Server updates auxiliary variables $z_{j}^{(t)}$ by Eq. 19, $\forall j\in B(t)$
10:     Server updates dual variables $\lambda^{(t)}_{j}$ by Eq. 20, $\forall j\in B(t)$
11:     Server computes residual variables ${s_{j}^{k}}^{(t)}$ by Eq. 21, $\forall j\in B(t),k\in[M]$
12:     Server sends $\{\lambda^{(t)}_{j}\}_{j\in B(t)}$, $\{{s_{j}^{k}}^{(t)}\}_{j\in B(t)}$ to each client $k$, $\forall k\in[M]$
13:     for client $k\in[M]$ do
14:        for local step $e\in[\tau]$ do
15:           updates local linear head $W_{k}^{(t+1)}$ by Eq. 22 with SGD
16:           updates local model $\theta_{k}^{(t+1)}$ by Eq. 23 with SGD
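Line 7 of Algorithm 3 is the step that provides the differential privacy guarantee. A minimal sketch follows; it assumes that $\mathtt{Clip}(\cdot,C)$ rescales the whole $b\times C_{\text{out}}$ logit matrix ($C_{\text{out}}$ being the number of classes) to Frobenius norm at most $C$, and that the Gaussian noise is added independently to every entry. The clipping granularity in the released code may differ, and `clip_and_perturb` is a hypothetical name.

```python
import numpy as np

def clip_and_perturb(logits, clip_c, sigma, rng):
    """Sketch of Algorithm 3, line 7, under the assumptions stated above.

    logits: (b, C_out) matrix {o_j^k}_{j in B(t)} of one client
    clip_c: clipping constant C; sigma: noise multiplier
    """
    norm = np.linalg.norm(logits)                       # Frobenius norm for 2-D input
    clipped = logits * min(1.0, clip_c / max(norm, 1e-12))
    noise = rng.normal(0.0, sigma * clip_c, size=logits.shape)  # N(0, sigma^2 C^2)
    return clipped + noise

rng = np.random.default_rng(5)
big = 10.0 * rng.normal(size=(8, 3))
released = clip_and_perturb(big, clip_c=1.0, sigma=1.0, rng=rng)
```

Clipping bounds the sensitivity of the released matrix to $C$, which is why the noise scale is expressed as $\sigma C$.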

-B Convergence Guarantees

-B1 Additional Notations and Supporting Lemmas

To facilitate the theoretical analysis, we denote the objective functions in Eq. 5 and the two client subproblems of Section III-B as

$$\begin{aligned}
h(z_{j}) &= \ell(z_{j})-{\lambda_{j}^{(t)}}^{\top}z_{j}+\frac{\rho}{2}\left\|\sum_{k=1}^{M}f(x_{j}^{k};\theta_{k}^{(t+1)})W_{k}^{(t+1)}-z_{j}\right\|_{F}^{2} \qquad (24)\\
g_{k}(W_{k}) &= \beta_{k}\mathcal{R}_{k}(W_{k})+\frac{1}{N}\sum_{j\in[N]}{\lambda_{j}^{(t)}}^{\top}f(x_{j}^{k};\theta_{k}^{(t)})W_{k}\\
&\quad+\frac{\rho}{2N}\sum_{j\in[N]}\left\|\sum_{i\in[M],\,i\neq k}f(x_{j}^{i};\theta_{i}^{(t)})W_{i}^{(t)}+f(x_{j}^{k};\theta_{k}^{(t)})W_{k}-z_{j}^{(t)}\right\|_{F}^{2}\\
q_{k}(\theta_{k}) &= \beta_{k}\mathcal{R}_{k}(\theta_{k})+\frac{1}{N}\sum_{j\in[N]}{\lambda_{j}^{(t)}}^{\top}f(x_{j}^{k};\theta_{k})W_{k}^{(t+1)}\\
&\quad+\frac{\rho}{2N}\sum_{j\in[N]}\left\|\sum_{i\in[M],\,i\neq k}f(x_{j}^{i};\theta_{i}^{(t)})W_{i}^{(t+1)}+f(x_{j}^{k};\theta_{k})W_{k}^{(t+1)}-z_{j}^{(t)}\right\|_{F}^{2}
\end{aligned}$$

Before presenting the main proofs, we introduce the following supporting lemmas.

Lemma 1.
$$\nabla\ell(z_j^{(t)}) = \lambda_j^{(t)} \tag{25}$$
Proof.

By the optimality condition of $z_j^{(t)}$ in Eq. 5, we have

$$\nabla\ell(z_j^{(t)}) - \lambda_j^{(t-1)} - \rho\left(\sum_{k=1}^{M} f(x_j^k;\theta_k^{(t)})\,W_k^{(t)} - z_j^{(t)}\right) = 0, \quad \forall j\in B(t). \tag{26}$$

Substituting the dual update of Eq. 6, $\lambda_j^{(t)} = \lambda_j^{(t-1)} + \rho\left(\sum_{k=1}^{M} f(x_j^k;\theta_k^{(t)})\,W_k^{(t)} - z_j^{(t)}\right)$, into Eq. 26 gives $\nabla\ell(z_j^{(t)}) = \lambda_j^{(t)}$. ∎
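Lemma 1 can be sanity-checked numerically. The sketch below assumes the squared loss $\ell(z)=\tfrac12\|z-y\|^2$ (an illustrative choice, for which the optimality condition in Eq. 26 has a closed-form solution) and writes `s` for $\sum_{k=1}^M f(x_j^k;\theta_k^{(t)})W_k^{(t)}$; the function names are hypothetical.

```python
import numpy as np

# Illustrative assumption: squared loss l(z) = 0.5 * ||z - y||^2, so grad l(z) = z - y.

def z_update(s, y, lam_prev, rho):
    """Closed-form solution of Eq. 26 for the squared loss:
    (z - y) - lam_prev - rho * (s - z) = 0  =>  z = (y + lam_prev + rho * s) / (1 + rho)."""
    return (y + lam_prev + rho * s) / (1.0 + rho)


def dual_update(lam_prev, s, z, rho):
    """Dual ascent step of Eq. 6: lambda^(t) = lambda^(t-1) + rho * (s - z)."""
    return lam_prev + rho * (s - z)


rng = np.random.default_rng(0)
y, s, lam_prev = rng.normal(size=(3, 4))   # three random vectors of length 4
rho = 2.0
z = z_update(s, y, lam_prev, rho)
lam = dual_update(lam_prev, s, z, rho)
print(np.allclose(z - y, lam))             # grad l(z^(t)) == lambda^(t); prints True
```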

Lemma 2.
$$\|\lambda_j^{(t)} - \lambda_j^{(t-1)}\| \leq L\,\|z_j^{(t)} - z_j^{(t-1)}\| \tag{27}$$
Proof.

By Assumption 1 and Lemma 1, we have

$$\|\lambda_j^{(t)} - \lambda_j^{(t-1)}\| = \|\nabla\ell(z_j^{(t)}) - \nabla\ell(z_j^{(t-1)})\| \leq L\,\|z_j^{(t)} - z_j^{(t-1)}\|. \tag{28}$$
∎
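To make the constant $L$ in Lemma 2 concrete: for the softmax cross-entropy loss (an example choice, not necessarily the loss used in the experiments), the gradient with respect to the logits is $\mathrm{softmax}(z)-e_y$, which is known to be $\tfrac12$-Lipschitz (Böhning's bound on the Hessian). The sketch checks the inequality of Lemma 2 on random logit pairs, using Lemma 1 to identify $\lambda_j$ with $\nabla\ell(z_j)$; the function names are hypothetical.

```python
import numpy as np

def softmax_ce_grad(z, y):
    """Gradient of l(z) = logsumexp(z) - z[y] with respect to the logits z."""
    p = np.exp(z - z.max())   # shift for numerical stability
    p /= p.sum()
    p[y] -= 1.0               # softmax(z) - e_y
    return p


L = 0.5  # Lipschitz constant of the softmax cross-entropy gradient


def lemma2_holds(z_new, z_old, y):
    """Check ||lambda_new - lambda_old|| <= L ||z_new - z_old|| with lambda = grad l(z) (Lemma 1)."""
    lam_new, lam_old = softmax_ce_grad(z_new, y), softmax_ce_grad(z_old, y)
    return np.linalg.norm(lam_new - lam_old) <= L * np.linalg.norm(z_new - z_old) + 1e-12


rng = np.random.default_rng(0)
print(all(lemma2_holds(3 * rng.normal(size=5), 3 * rng.normal(size=5), y=2)
          for _ in range(1000)))   # prints True
```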

Lemma 3.

[35, Lemma 3]

\begin{align*}
&\left(\left\|\sum_{m=1}^{M} x_m^{t+1} - z\right\|^2 - \left\|\sum_{m=1}^{M} x_m^{t} - z\right\|^2\right) \tag{29}\\
&\quad\leq \sum_{m=1}^{M}\left(\left\|\sum_{\substack{k=1\\ k\neq m}}^{M} x_k^{t} + x_m^{t+1} - z\right\|^2 - \left\|\sum_{m=1}^{M} x_m^{t} - z\right\|^2\right) + \sum_{m=1}^{M}\left\|x_m^{t+1} - x_m^{t}\right\|^2 \tag{32}
\end{align*}

B2 Proofs for Theorem 1

We restate the assumptions of Theorem 1 here:

Assumption 1.

$\ell(z;\cdot)$ is $L$-Lipschitz smooth w.r.t. $z$.

Assumption 2.

$\mathcal{L}_{\mathrm{ADMM}}$ is strongly convex w.r.t. $z$, $W$, and $\theta$ with constants $\mu_z$, $\mu_W$, and $\mu_\theta$, respectively.

Assumption 3.

The norm of $W_k$ is bounded, i.e., $\|W_k\|\leq\sigma_W$. The local model $f(\cdot;\theta)$ has a bounded gradient, $\|\nabla f(\cdot;\theta)\|\leq L_\theta$, and a bounded output norm, $\|f(\cdot;\theta)\|\leq\sigma_\theta$.

Assumption 4.

The original objective function $\mathcal{L}_{\texttt{VIM}}$ is bounded from below over $\Theta$ and $\mathcal{W}$, that is, $\underline{e} := \min_{\{\theta_k\}\in\Theta,\,\{W_k\}\in\mathcal{W}} \mathcal{L}_{\texttt{VIM}}(\{W_k\}_{k=1}^{M}, \{\theta_k\}_{k=1}^{M}) > -\infty$.

B3 Proofs for Theorem 1 Part (A)

We decompose Part (A) of Theorem 1 into the following two lemmas and prove them one by one.

Lemma 4.

Let Assumptions 1 to 3 hold, and let the penalty parameter $\rho$ satisfy

$$\max\left\{L, \frac{2L^2}{\mu_z}\right\} < \rho < \min\left\{\frac{\mu_\theta}{L_\theta^2\sigma_W^2}, \frac{\mu_W}{\sigma_\theta^2}\right\} \tag{33}$$

then $\mathcal{L}_{\mathrm{ADMM}}$ is monotonically decreasing:

$$\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) < 0. \tag{34}$$
Lemma 5.

Let Assumptions 1 to 4 hold. Then the following limit exists, and $\mathcal{L}_{\mathrm{ADMM}}$ is lower bounded by $\underline{e}$ defined in Assumption 4:

$$\lim_{t\rightarrow\infty}\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \geq \underline{e}. \tag{35}$$
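The admissible range for $\rho$ in Eq. 33 can be computed directly from the constants in Assumptions 1 to 3. The sketch below does this for made-up constant values (in practice $L$, $\mu_z$, $\mu_W$, $\mu_\theta$, $L_\theta$, $\sigma_W$, and $\sigma_\theta$ are problem-dependent and often unknown); the function name `rho_interval` is hypothetical. If the returned lower bound is not below the upper bound, no $\rho$ satisfies the condition of Lemma 4.

```python
def rho_interval(L, mu_z, mu_theta, L_theta, sigma_W, mu_W, sigma_theta):
    """Open interval (lower, upper) for the penalty rho required by Lemma 4 (Eq. 33)."""
    lower = max(L, 2 * L ** 2 / mu_z)
    upper = min(mu_theta / (L_theta ** 2 * sigma_W ** 2), mu_W / sigma_theta ** 2)
    return lower, upper


# Illustrative constants only; they are not taken from the paper.
lower, upper = rho_interval(L=0.5, mu_z=2.0, mu_theta=4.0, L_theta=1.0,
                            sigma_W=1.0, mu_W=3.0, sigma_theta=1.0)
print(lower, upper, lower < upper)   # prints 0.5 3.0 True
```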

We first prove the monotonic decrease of $\mathcal{L}_{\mathrm{ADMM}}$ stated in Lemma 4.

Proof of Lemma 4.
\begin{align*}
&\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \tag{36}\\
&= \underbrace{\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t)}\})}_{T_1}\\
&\quad + \underbrace{\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})}_{T_2}\\
&\quad + \underbrace{\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})}_{T_3}\\
&\quad + \underbrace{\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})}_{T_4}
\end{align*}

Recall the ADMM objective function:

\begin{align*}
\mathcal{L}_{\mathrm{ADMM}} &= \frac{1}{N}\sum_{j=1}^{N}\ell(z_j, y_j) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k)\\
&\quad + \frac{1}{N}\sum_{j=1}^{N}\lambda_j^{\top}\left(\sum_{k=1}^{M} f(x_j^k;\theta_k)\,W_k - z_j\right) + \frac{\rho}{2N}\sum_{j=1}^{N}\left\|\sum_{k=1}^{M} f(x_j^k;\theta_k)\,W_k - z_j\right\|_F^2
\end{align*}
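For reference, the objective above can be evaluated with a few array operations. The sketch assumes, for illustration only, linear local models $f(x;\theta)=x\theta$, the squared loss $\ell(z,y)=\tfrac12\|z-y\|^2$, and $\mathcal{R}_k(\cdot)=\tfrac12\|\cdot\|_F^2$; row $j$ of each array corresponds to sample $j$, and `admm_objective` is a hypothetical name.

```python
import numpy as np

def admm_objective(Xs, thetas, Ws, z, lam, y, betas, rho):
    """L_ADMM for M clients; Xs[k]: (N, d_k), thetas[k]: (d_k, h), Ws[k]: (h, C); z, lam, y: (N, C)."""
    n = z.shape[0]
    # sum_k f(x_j^k; theta_k) W_k for all samples j at once (linear f assumed)
    s = sum(X @ th @ W for X, th, W in zip(Xs, thetas, Ws))
    loss = 0.5 * np.sum((z - y) ** 2) / n                    # (1/N) sum_j l(z_j, y_j), squared loss assumed
    reg = sum(b * 0.5 * (np.sum(th ** 2) + np.sum(W ** 2))   # beta_k [R_k(theta_k) + R_k(W_k)]
              for b, th, W in zip(betas, thetas, Ws))
    gap = s - z                                              # constraint residual per sample
    return loss + reg + np.sum(lam * gap) / n + 0.5 * rho * np.sum(gap ** 2) / n
```

Setting $z_j$ equal to the aggregated head outputs makes both constraint terms vanish, which gives a convenient consistency check on an implementation.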

Then we have

\begin{align*}
T_1 &= \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t)}\})\\
&= \frac{1}{N}\sum_{j=1}^{N}\left(\lambda_j^{(t+1)} - \lambda_j^{(t)}\right)^{\top}\left(\sum_{k=1}^{M} f(x_j^k;\theta_k^{(t+1)})\,W_k^{(t+1)} - z_j^{(t+1)}\right)\\
&\overset{(a)}{=} \frac{1}{N}\sum_{j=1}^{N}\frac{1}{\rho}\left\|\lambda_j^{(t+1)} - \lambda_j^{(t)}\right\|^2\\
&\overset{(b)}{\leq} \sum_{j=1}^{N}\frac{L^2}{\rho N}\left\|z_j^{(t+1)} - z_j^{(t)}\right\|^2
\end{align*}

where (a) follows from Eq. 6 applied at iteration $t+1$, i.e., $\frac{1}{\rho}\left(\lambda_j^{(t+1)} - \lambda_j^{(t)}\right) = \sum_{k=1}^{M} f(x_j^k;\theta_k^{(t+1)})\,W_k^{(t+1)} - z_j^{(t+1)}$, and (b) follows from Lemma 2.
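Step (a) is an exact identity once the dual update has been applied at iteration $t+1$; in particular $T_1\geq 0$, so the dual step can only increase $\mathcal{L}_{\mathrm{ADMM}}$, and step (b) bounds this increase by the change in $z$. A small numerical check follows, with random arrays standing in for $\sum_k f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}$ and $z_j^{(t+1)}$ (the name `t1_term` is hypothetical).

```python
import numpy as np

def t1_term(lam_new, lam_old, s_new, z_new):
    """(1/N) sum_j (lambda_j^(t+1) - lambda_j^(t))^T (s_j^(t+1) - z_j^(t+1)); rows index j."""
    n = lam_new.shape[0]
    return np.sum((lam_new - lam_old) * (s_new - z_new)) / n


rng = np.random.default_rng(0)
N, C, rho = 5, 3, 1.5
lam_old, s_new, z_new = rng.normal(size=(3, N, C))
lam_new = lam_old + rho * (s_new - z_new)            # Eq. 6 at iteration t+1
rhs = np.sum((lam_new - lam_old) ** 2) / (rho * N)   # right-hand side of step (a)
print(np.isclose(t1_term(lam_new, lam_old, s_new, z_new), rhs), rhs >= 0)   # prints True True
```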

$$
\begin{aligned}
T_2 &= \mathcal{L}_{\mathrm{ADMM}}\left(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t)}\}\right)-\mathcal{L}_{\mathrm{ADMM}}\left(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}\right)\\
&= \frac{1}{N}\sum_{j=1}^{N}\left(\ell(z_j^{(t+1)})-\ell(z_j^{(t)})\right)-\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(z_j^{(t+1)}-z_j^{(t)}\right)\\
&\quad+\frac{\rho}{2N}\sum_{j=1}^{N}\left(\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t+1)}\right\|_F^2-\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2\right)\\
&= \frac{1}{N}\sum_{j=1}^{N}\left(h(z_j^{(t+1)})-h(z_j^{(t)})\right)\\
&\overset{(a)}{\leq} \frac{1}{N}\sum_{j=1}^{N}\left(\left\langle\nabla h(z_j^{(t+1)}),\,z_j^{(t+1)}-z_j^{(t)}\right\rangle-\frac{\mu_z}{2}\left\|z_j^{(t+1)}-z_j^{(t)}\right\|^2\right)\\
&\overset{(b)}{=} -\frac{\mu_z}{2N}\sum_{j=1}^{N}\left\|z_j^{(t+1)}-z_j^{(t)}\right\|^2
\end{aligned}
$$

where (a) follows from the strong convexity of $h$ assumed in 2, and (b) follows from the optimality of the $z$ update in Eq. 5, which gives $\nabla h(z_j^{(t+1)})=0$.
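Steps (a) and (b) can be checked on a toy problem. In this sketch a generic strongly convex quadratic $h(z)=\frac{1}{2}z^\top A z-b^\top z$ stands in for the per-sample $z$-subproblem objective (an assumption; the paper's $h$ is not reproduced here). Its modulus $\mu_z$ is the smallest eigenvalue of $A$, and the exact minimizer plays the role of $z_j^{(t+1)}$, so $\nabla h(z_j^{(t+1)})=0$ as in step (b). The dimension and the random data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
c = 3                                   # dimension of z_j (hypothetical)

# Hypothetical strongly convex quadratic h(z) = 0.5 z^T A z - b^T z, A symmetric positive definite.
B = rng.normal(size=(c, c))
A = B @ B.T + 0.5 * np.eye(c)
b = rng.normal(size=c)
mu_z = np.linalg.eigvalsh(A).min()      # strong convexity modulus of h


def h(z):
    return 0.5 * z @ A @ z - b @ z


z_t = rng.normal(size=c)                # previous iterate z_j^{(t)}
z_next = np.linalg.solve(A, b)          # exact minimizer: grad h(z_next) = A z_next - b = 0

decrease = h(z_next) - h(z_t)                       # left-hand side of step (a)
bound = -0.5 * mu_z * np.sum((z_next - z_t) ** 2)   # right-hand side after step (b)
print(decrease <= bound + 1e-12)        # True
```

For a quadratic the decrease equals $-\frac{1}{2}(z^{(t)}-z^{(t+1)})^\top A\,(z^{(t)}-z^{(t+1)})$, so the bound holds because $A \succeq \mu_z I$.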

$$
\begin{aligned}
T_3 &= \mathcal{L}_{\mathrm{ADMM}}\left(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}\right)-\mathcal{L}_{\mathrm{ADMM}}\left(\{W_k^{(t+1)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}\right)\\
&= \sum_{k=1}^{M}\beta_k\left(\mathcal{R}_k(\theta_k^{(t+1)})-\mathcal{R}_k(\theta_k^{(t)})\right)+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(\sum_{k=1}^{M}\left(f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}\right)\right)\\
&\quad+\frac{\rho}{2N}\sum_{j=1}^{N}\left(\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2-\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2\right)\\
&\overset{(a)}{\leq} \sum_{k=1}^{M}\beta_k\left(\mathcal{R}_k(\theta_k^{(t+1)})-\mathcal{R}_k(\theta_k^{(t)})\right)+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(\sum_{k=1}^{M}\left(f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}\right)\right)\\
&\quad+\frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left(\left\|\sum_{i\in[M],\,i\neq k}f(x_j^i;\theta_i^{(t)})W_i^{(t+1)}+f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2-\left\|\sum_{i=1}^{M}f(x_j^i;\theta_i^{(t)})W_i^{(t+1)}-z_j^{(t)}\right\|_F^2\right)\\
&\quad+\frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}\right\|_F^2\\
&= \sum_{k=1}^{M}\left(q_k(\theta_k^{(t+1)})-q_k(\theta_k^{(t)})\right)+\frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}\right\|_F^2\\
&\overset{(b)}{\leq} \sum_{k=1}^{M}\left(\left\langle\nabla q_k(\theta_k^{(t+1)}),\,\theta_k^{(t+1)}-\theta_k^{(t)}\right\rangle-\frac{\mu_\theta}{2}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2\right)\\
&\quad+\frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}\right\|_F^2\\
&\overset{(c)}{\leq} -\sum_{k=1}^{M}\frac{\mu_\theta}{2}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2+\frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t+1)})-f(x_j^k;\theta_k^{(t)})\right\|^2\left\|W_k^{(t+1)}\right\|^2\\
&\overset{(d)}{\leq} -\sum_{k=1}^{M}\frac{\mu_\theta}{2}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2+\frac{\rho L_\theta^2}{2}\sum_{k=1}^{M}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2\left\|W_k^{(t+1)}\right\|^2\\
&\overset{(e)}{\leq} \sum_{k=1}^{M}\frac{-\mu_\theta+\rho L_\theta^2\sigma_W^2}{2}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2
\end{aligned}
$$

where (a) is due to Lemma 3; (b) is due to the strong convexity of $q_k$ assumed in 2; (c) is due to the optimality of $\theta_k^{(t+1)}$ for the local subproblem in Section III-B, which makes the inner-product term vanish, together with the submultiplicativity bound $\|aW\|\le\|a\|\,\|W\|$; (d) is due to the Lipschitz continuity of the local model $f(\cdot;\theta)$ in $\theta$ with constant $L_\theta$, assumed in 3 (a bounded gradient implies Lipschitz continuity); and (e) is due to the upper bound $\|W_k^{(t+1)}\|\le\sigma_W$ on the linear weights in 3.
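The bounds in steps (c) through (e) can be sanity-checked with a linear local model $f(x;\theta)=x\theta$, for which $\theta\mapsto x\theta$ is Lipschitz with constant $\|x\|$, so one may take $L_\theta=\max_{j,k}\|x_j^k\|$. The linear model, the sizes, the random data, and the value of $\mu_\theta$ are hypothetical, and $\sigma_W$ is taken as the largest Frobenius norm of the heads. The last lines compute the condition implied by (e): the coefficient $(-\mu_\theta+\rho L_\theta^2\sigma_W^2)/2$ is negative exactly when $\rho<\mu_\theta/(L_\theta^2\sigma_W^2)$.

```python
import numpy as np

rng = np.random.default_rng(2)
N, M, d_in, d, c = 6, 3, 5, 4, 2     # samples, clients, input/embedding/output dims (hypothetical)
rho, mu_theta = 0.5, 2.0             # hypothetical penalty and strong convexity modulus

# Hypothetical linear local models f(x; theta) = x @ theta.
X = rng.normal(size=(M, N, d_in))                            # X[k, j] = x_j^k
theta_t = rng.normal(size=(M, d_in, d))                      # theta_k^{(t)}
theta_next = theta_t + 0.1 * rng.normal(size=(M, d_in, d))   # theta_k^{(t+1)}
W = rng.normal(size=(M, d, c))                               # W_k^{(t+1)}

L_theta = np.linalg.norm(X, axis=2).max()         # Lipschitz constant of theta -> x @ theta
sigma_W = np.linalg.norm(W, axis=(1, 2)).max()    # bound with ||W_k||_F <= sigma_W
dtheta_sq = np.sum((theta_next - theta_t) ** 2, axis=(1, 2))   # ||theta_k^{(t+1)} - theta_k^{(t)}||^2
W_sq = np.sum(W ** 2, axis=(1, 2))                             # ||W_k^{(t+1)}||^2

# Term before (c): (rho / 2N) sum_j sum_k ||(f(x_j^k; theta_next) - f(x_j^k; theta_t)) W_k||^2
diff = np.einsum("kji,kie->kje", X, theta_next - theta_t)
gap = rho / (2 * N) * np.sum(np.einsum("kje,kec->kjc", diff, W) ** 2)

bound_d = rho * L_theta**2 / 2 * np.sum(dtheta_sq * W_sq)       # after (c) and (d)
bound_e = rho * L_theta**2 * sigma_W**2 / 2 * np.sum(dtheta_sq)  # after (e)
print(gap <= bound_d + 1e-9, bound_d <= bound_e + 1e-9)          # True True

# Step (e) yields descent in theta only if the coefficient is negative.
rho_max = mu_theta / (L_theta**2 * sigma_W**2)
coeff = (-mu_theta + rho * L_theta**2 * sigma_W**2) / 2
print(f"rho_max = {rho_max:.4f}, coefficient = {coeff:.4f}, descent: {rho < rho_max}")
```

The final print reports whether the chosen $\rho$ satisfies the descent condition for this toy instance; the bound itself holds for any $\rho$.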

$$
\begin{aligned}
T_4 &= \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \\
&= \sum_{k=1}^{M}\beta_k\left(\mathcal{R}_k(W_k^{(t+1)})-\mathcal{R}_k(W_k^{(t)})\right) + \frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})\left(W_k^{(t+1)}-W_k^{(t)}\right)\right) \\
&\quad + \frac{\rho}{2N}\sum_{j=1}^{N}\left(\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2 - \left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right\|_F^2\right) \\
&\overset{(a)}{\leq} \sum_{k=1}^{M}\beta_k\left(\mathcal{R}_k(W_k^{(t+1)})-\mathcal{R}_k(W_k^{(t)})\right) + \frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})\left(W_k^{(t+1)}-W_k^{(t)}\right)\right) \\
&\quad + \frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left(\left\|\sum_{i\in[M],\, i\neq k}f(x_j^i;\theta_i^{(t)})W_i^{(t)} + f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-z_j^{(t)}\right\|_F^2 - \left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right\|_F^2\right) \\
&\quad + \frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t)}\right\|_F^2 \\
&= \sum_{k=1}^{M}\left(g_k(W_k^{(t+1)})-g_k(W_k^{(t)})\right) + \frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t)}\right\|_F^2 \\
&\overset{(b)}{\leq} \sum_{k=1}^{M}\left(\left\langle\nabla g_k(W_k^{(t+1)}),\,W_k^{(t+1)}-W_k^{(t)}\right\rangle - \frac{\mu_W}{2}\left\|W_k^{(t+1)}-W_k^{(t)}\right\|^2\right) \\
&\quad + \frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-f(x_j^k;\theta_k^{(t)})W_k^{(t)}\right\|_F^2 \\
&\overset{(c)}{\leq} \sum_{k=1}^{M}-\frac{\mu_W}{2}\left\|W_k^{(t+1)}-W_k^{(t)}\right\|^2 + \frac{\rho}{2N}\sum_{j=1}^{N}\sum_{k=1}^{M}\left\|W_k^{(t+1)}-W_k^{(t)}\right\|^2\left\|f(x_j^k;\theta_k^{(t)})\right\|^2 \\
&\overset{(d)}{\leq} \sum_{k=1}^{M}\frac{-\mu_W+\rho\sigma_\theta^2}{2}\left\|W_k^{(t+1)}-W_k^{(t)}\right\|^2
\end{aligned}
$$

where (a) follows from Lemma 3, (b) from the strong convexity of $g_k$ (see 2), (c) from the optimality of the $W_k$ update in Section III-B, and (d) from the upper bound on the model outputs in 3.
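As a hedged sanity check of the $T_4$ bound (not the authors' code), the sketch below builds a hypothetical instance with scalar heads $w_k$, scalar outputs $f(x_j^k;\theta_k)$ and $\mathcal{R}_k(w)=w^2/2$, so that $\mu_W=\beta$ and $\sigma_\theta^2=\max_{j,k} f(x_j^k;\theta_k)^2$. It assumes, as the split in step (a) suggests, that every head is updated from the same previous iterate with the other heads frozen, and it solves each one-dimensional subproblem $g_k$ exactly. We fix $M=2$, where $\|d_1+d_2\|^2\le 2(\|d_1\|^2+\|d_2\|^2)$ makes the split in (a) elementary; the function name and random data are illustrative only.

```python
import random


def t4_toy_check(n=20, beta=0.5, rho=1.0, seed=0):
    """Toy check of the T4 bound for M = 2 clients (hypothetical data).

    Heads W_k are scalars w_k, outputs f(x_j^k; theta_k) are scalars f[j][k],
    and R_k(w) = w^2 / 2, so beta * R_k is beta-strongly convex (mu_W = beta).
    Returns (T4, right-hand side of the bound with sigma_theta^2 = max f^2).
    """
    rng = random.Random(seed)
    f = [[rng.uniform(-1.0, 1.0) for _ in range(2)] for _ in range(n)]
    z = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    lam = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    w_old = [rng.uniform(-2.0, 2.0) for _ in range(2)]

    def admm_head_terms(w):
        # Terms of L_ADMM that depend on the heads (theta, z, lambda held fixed).
        reg = sum(beta * wk * wk / 2.0 for wk in w)
        coupling = 0.0
        for j in range(n):
            r = f[j][0] * w[0] + f[j][1] * w[1] - z[j]
            coupling += lam[j] * r + rho / 2.0 * r * r
        return reg + coupling / n

    # Parallel update: head k minimizes its own quadratic g_k with the other
    # head frozen at its old value; the 1-D minimizer is -b / a.
    w_new = []
    for k in range(2):
        o = 1 - k
        a = beta + rho / n * sum(f[j][k] ** 2 for j in range(n))
        b = sum(lam[j] * f[j][k] + rho * f[j][k] * (f[j][o] * w_old[o] - z[j])
                for j in range(n)) / n
        w_new.append(-b / a)

    t4 = admm_head_terms(w_new) - admm_head_terms(w_old)
    sigma_sq = max(f[j][k] ** 2 for j in range(n) for k in range(2))
    bound = sum((-beta + rho * sigma_sq) / 2.0 * (w_new[k] - w_old[k]) ** 2
                for k in range(2))
    return t4, bound
```

Calling `t4_toy_check(seed=s)` for several seeds and comparing the two returned numbers shows how much slack the bound leaves on this toy problem.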

Combining the above bounds on $T_1$, $T_2$, $T_3$, and $T_4$ and recalling the condition on $\rho$ in Eq. 33, we have

$$
\begin{aligned}
&\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\}) - \mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \\
&= T_1+T_2+T_3+T_4 \\
&\leq \frac{1}{N}\sum_{j=1}^{N}\left(-\frac{\mu_z}{2}+\frac{L^2}{\rho}\right)\left\|z_j^{(t+1)}-z_j^{(t)}\right\|^2 + \sum_{k=1}^{M}\frac{-\mu_\theta+\rho L_\theta^2\sigma_W^2}{2}\left\|\theta_k^{(t+1)}-\theta_k^{(t)}\right\|^2 + \sum_{k=1}^{M}\frac{-\mu_W+\rho\sigma_\theta^2}{2}\left\|W_k^{(t+1)}-W_k^{(t)}\right\|^2 \\
&< 0
\end{aligned}
$$

This completes the proof. ∎
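The three coefficients in the final bound are negative exactly when $2L^2/\mu_z<\rho<\min\{\mu_\theta/(L_\theta^2\sigma_W^2),\ \mu_W/\sigma_\theta^2\}$, provided the moduli $\mu_z$, $\mu_\theta$, $\mu_W$ do not themselves depend on $\rho$. That independence is an assumption here, and Eq. 33 is not reproduced in this section, so it may state the condition differently. The hypothetical helpers below evaluate the coefficients and return that open interval, or `None` when it is empty.

```python
def descent_coefficients(rho, mu_z, L, mu_theta, L_theta, sigma_W, mu_W, sigma_theta):
    """Coefficients of ||dz||^2, ||dtheta||^2 and ||dW||^2 in the combined bound.

    Hypothetical helper: assumes mu_z, mu_theta, mu_W are constants independent of rho.
    """
    c_z = -mu_z / 2.0 + L ** 2 / rho
    c_theta = (-mu_theta + rho * L_theta ** 2 * sigma_W ** 2) / 2.0
    c_W = (-mu_W + rho * sigma_theta ** 2) / 2.0
    return c_z, c_theta, c_W


def admissible_rho_interval(mu_z, L, mu_theta, L_theta, sigma_W, mu_W, sigma_theta):
    """Open interval (lo, hi) of rho making all three coefficients negative, else None."""
    lo = 2.0 * L ** 2 / mu_z                      # from -mu_z/2 + L^2/rho < 0
    hi = min(mu_theta / (L_theta ** 2 * sigma_W ** 2),  # from the theta coefficient
             mu_W / sigma_theta ** 2)                   # from the W coefficient
    return (lo, hi) if lo < hi else None
```

An empty interval signals that, under these constants, the sufficient-descent argument above does not go through for any $\rho$.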

Next, we prove that $\mathcal{L}_{\mathrm{ADMM}}$ is lower bounded, as stated in Lemma 5.
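Step (b) of the proof appears to rest on the descent lemma: if $\ell$ is $L$-smooth and $\rho\ge L$, then $\ell(z)+\nabla\ell(z)^\top(u-z)+\frac{\rho}{2}\|u-z\|^2\ge\ell(u)$, applied with $z=z_j^{(t)}$ and $u=\sum_k f(x_j^k;\theta_k^{(t)})W_k^{(t)}$. As a hedged illustration, the sketch uses the binary logistic loss (our choice, not the paper's), whose second derivative is at most $1/4$, so $L=1/4$; all function names are hypothetical.

```python
import math
import random


def logistic_loss(z, y):
    """Binary logistic loss log(1 + exp(-y z)) for label y in {-1, +1}."""
    return math.log1p(math.exp(-y * z))


def logistic_grad(z, y):
    """Derivative in z: -y / (1 + exp(y z)); the second derivative is <= 1/4."""
    return -y / (1.0 + math.exp(y * z))


def descent_gap(z, u, y, rho):
    """Quadratic upper model built at z, evaluated at u, minus the true loss at u.

    For an L-smooth loss this is >= 0 whenever rho >= L (here L = 1/4).
    """
    model = logistic_loss(z, y) + logistic_grad(z, y) * (u - z) + rho / 2.0 * (u - z) ** 2
    return model - logistic_loss(u, y)


def worst_gap(rho, trials=10000, seed=0):
    # Smallest gap over random score pairs; a negative value means the bound failed.
    rng = random.Random(seed)
    return min(
        descent_gap(rng.uniform(-5, 5), rng.uniform(-5, 5), rng.choice((-1, 1)), rho)
        for _ in range(trials)
    )
```

With `rho = 0` the gap is negative for every `u != z`, because the logistic loss is strictly convex; the quadratic term with $\rho\ge L$ is what turns the linearization into an upper bound.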

Proof of Lemma 5.
$$
\begin{aligned}
&\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\}) \\
&= \frac{1}{N}\sum_{j=1}^{N}\ell(z_j^{(t)},y_j) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k^{(t)}) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k^{(t)}) \\
&\quad + \frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\left(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right) + \frac{\rho}{2N}\sum_{j=1}^{N}\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right\|_F^2 \\
&\overset{(a)}{=} \frac{1}{N}\sum_{j=1}^{N}\ell(z_j^{(t)},y_j) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k^{(t)}) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k^{(t)}) \\
&\quad + \frac{1}{N}\sum_{j=1}^{N}{\nabla\ell(z_j^{(t)})}^{\top}\left(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right) + \frac{\rho}{2N}\sum_{j=1}^{N}\left\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\right\|_F^2 \\
&\overset{(b)}{\geq} \frac{1}{N}\sum_{j=1}^{N}\ell\left(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)},\,y_j\right) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(\theta_k^{(t)}) + \sum_{k=1}^{M}\beta_k\mathcal{R}_k(W_k^{(t)})
\end{aligned}
$$
end_POSTSUPERSCRIPT )
+ρL2Nj=1Nk=1Mf(xjk;θk(t))Wk(t)zj(t)F2𝜌𝐿2𝑁superscriptsubscript𝑗1𝑁superscriptsubscriptnormsuperscriptsubscript𝑘1𝑀𝑓superscriptsubscript𝑥𝑗𝑘superscriptsubscript𝜃𝑘𝑡superscriptsubscript𝑊𝑘𝑡superscriptsubscript𝑧𝑗𝑡𝐹2\displaystyle+\frac{\rho-L}{2N}\sum_{j=1}^{N}\left\|\sum_{k=1}^{M}f(x_{j}^{k};% \theta_{k}^{(t)})W_{k}^{(t)}-z_{j}^{(t)}\right\|_{F}^{2}+ divide start_ARG italic_ρ - italic_L end_ARG start_ARG 2 italic_N end_ARG ∑ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT ∥ ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_f ( italic_x start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT ; italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ) italic_W start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT - italic_z start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ∥ start_POSTSUBSCRIPT italic_F end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT
(c)𝑐\displaystyle\overset{(c)}{\geq}start_OVERACCENT ( italic_c ) end_OVERACCENT start_ARG ≥ end_ARG 1Nj=1N(k=1Mf(xjk;θk(t))Wk(t),yj)+k=1Mβkk(θk(t))+k=1Mβkk(Wk(t))1𝑁superscriptsubscript𝑗1𝑁superscriptsubscript𝑘1𝑀𝑓superscriptsubscript𝑥𝑗𝑘superscriptsubscript𝜃𝑘𝑡superscriptsubscript𝑊𝑘𝑡subscript𝑦𝑗superscriptsubscript𝑘1𝑀subscript𝛽𝑘subscript𝑘superscriptsubscript𝜃𝑘𝑡superscriptsubscript𝑘1𝑀subscript𝛽𝑘subscript𝑘superscriptsubscript𝑊𝑘𝑡\displaystyle\frac{1}{N}\sum_{j=1}^{N}\ell\left(\sum_{k=1}^{M}f(x_{j}^{k};% \theta_{k}^{(t)})W_{k}^{(t)},y_{j}\right)+\sum_{k=1}^{M}\beta_{k}\mathcal{R}_{% k}(\theta_{k}^{(t)})+\sum_{k=1}^{M}\beta_{k}\mathcal{R}_{k}(W_{k}^{(t)})divide start_ARG 1 end_ARG start_ARG italic_N end_ARG ∑ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT roman_ℓ ( ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_f ( italic_x start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT ; italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ) italic_W start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT , italic_y start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT ) + ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_β start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT caligraphic_R start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT ) + ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_M end_POSTSUPERSCRIPT italic_β start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT caligraphic_R start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ( italic_W start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) 
end_POSTSUPERSCRIPT )
=\displaystyle== 𝚅𝙸𝙼({Wk(t)},{θk(t)})subscript𝚅𝙸𝙼superscriptsubscript𝑊𝑘𝑡superscriptsubscript𝜃𝑘𝑡\displaystyle\mathcal{L}_{\mathrm{\texttt{VIM}}}(\{W_{k}^{(t)}\},\{\theta_{k}^% {(t)}\})caligraphic_L start_POSTSUBSCRIPT VIM end_POSTSUBSCRIPT ( { italic_W start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT } , { italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT } )
\displaystyle\geq e¯¯𝑒\displaystyle\underline{e}under¯ start_ARG italic_e end_ARG

where (a) follows from Lemma 1; (b) follows from the Lipschitz continuity of the gradient of $\ell$ in Assumption 1, which gives

$$\ell\Big(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)},y_j\Big)-\ell(z_j^{(t)},y_j)-\frac{L}{2}\Big\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\Big\|^2\leq\Big\langle\nabla\ell(z_j^{(t)}),\ \sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\Big\rangle;$$

and (c) holds because $\rho\geq L$ by the condition in Eq. 33.

This shows that $\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})$ is bounded below, which completes the proof. ∎
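Steps (b) and (c) can be checked numerically for a concrete loss. The sketch below is illustrative only and assumes a scalar binary logistic loss $\ell(z,y)=\log(1+e^{-yz})$ with $y\in\{-1,+1\}$, whose gradient is Lipschitz with $L=1/4$; the loss, the sampling ranges, and the helper names are hypothetical and not taken from the paper's code. It substitutes $\lambda_j=\nabla\ell(z_j)$ as in step (a) and confirms that the per-sample augmented term dominates $\ell(u)$ once $\rho\ge L$, whereas with $\rho=0$ only the tangent line remains, which lies below $\ell(u)$ by convexity.

```python
import numpy as np

L_SMOOTH = 0.25  # sup_z l''(z) = sigmoid(yz) * sigmoid(-yz) <= 1/4 (assumed logistic loss)

def loss(z, y):
    """Binary logistic loss l(z, y) = log(1 + exp(-y z)), y in {-1, +1}."""
    return np.logaddexp(0.0, -y * z)

def grad(z, y):
    """dl/dz = -y / (1 + exp(y z))."""
    return -y / (1.0 + np.exp(y * z))

def augmented_term(u, z, y, rho):
    """Per-sample part of L_ADMM with lambda = grad l(z) substituted, as in step (a)."""
    return loss(z, y) + grad(z, y) * (u - z) + 0.5 * rho * (u - z) ** 2

rng = np.random.default_rng(0)
n = 10_000
u = rng.uniform(-5.0, 5.0, size=n)   # stands in for sum_k f(x_j^k; theta_k) W_k
z = rng.uniform(-5.0, 5.0, size=n)   # stands in for z_j
y = rng.choice([-1.0, 1.0], size=n)

# Step (b): l(u) - l(z) - grad l(z) (u - z) <= L/2 (u - z)^2
descent_gap = loss(u, y) - loss(z, y) - grad(z, y) * (u - z)
step_b_holds = np.all(descent_gap <= 0.5 * L_SMOOTH * (u - z) ** 2 + 1e-12)

# Step (c): with rho >= L the augmented term upper-bounds l(u)
step_c_holds = np.all(augmented_term(u, z, y, rho=L_SMOOTH) >= loss(u, y) - 1e-12)

# With rho = 0 only the tangent line is left, which never exceeds l(u)
tangent_below = np.all(augmented_term(u, z, y, rho=0.0) <= loss(u, y) + 1e-12)
print(step_b_holds, step_c_holds, tangent_below)
```

The $\rho=0$ case illustrates why the penalty weight matters: the linear term alone under-estimates the loss, and it is the quadratic penalty with $\rho\ge L$ that restores a lower bound on $\mathcal{L}_{\mathrm{ADMM}}$.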

Proof of Theorem 1 (A).

Combining Lemma 4 and Lemma 5, $\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})$ is monotonically decreasing and bounded below, and hence convergent. This completes the proof. ∎
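The mechanism behind Theorem 1 (A) can be watched on a small instance. The sketch below is a toy, not the paper's Algorithm 1 or its released code: it assumes identity feature maps $f(x;\theta)=x$ (so there is no $\theta_k$), squared loss $\ell(z,y)=\frac12(z-y)^2$ (so $L=1$), $\beta_k\mathcal{R}_k(W_k)=\frac{\beta}{2}\|W_k\|^2$, full-batch updates, exact minimization over each head $W_k$ in turn, a closed-form $z$ step, and the dual step $\lambda\leftarrow\lambda+\rho(\cdot)$; the function names and the synthetic data are hypothetical. With $\rho=2\ge L$, the recorded values of $\mathcal{L}_{\mathrm{ADMM}}$ are expected to be non-increasing and non-negative, and the multipliers satisfy $\lambda_j=\nabla\ell(z_j)=z_j-y_j$ after every iteration, mirroring the substitution used in step (a).

```python
import numpy as np

def aug_lagrangian(X_blocks, W, z, lam, y, rho, beta):
    """L_ADMM for the toy problem (squared loss, linear heads, ridge on each W_k)."""
    N = y.shape[0]
    u = sum(Xk @ Wk for Xk, Wk in zip(X_blocks, W))  # sum_k f(x_j^k) W_k for all j
    r = u - z                                         # consensus residual
    return (0.5 * np.sum((z - y) ** 2) / N
            + 0.5 * beta * sum(np.sum(Wk ** 2) for Wk in W)
            + lam @ r / N
            + 0.5 * rho * np.sum(r ** 2) / N)

def admm_toy(X_blocks, y, rho=2.0, beta=0.1, iters=300):
    N = y.shape[0]
    W = [np.zeros(Xk.shape[1]) for Xk in X_blocks]
    z = np.zeros(N)
    lam = np.zeros(N)
    history, residuals = [], []
    for _ in range(iters):
        # Exact minimization over each head W_k in turn (one Gauss-Seidel sweep).
        for k, Xk in enumerate(X_blocks):
            others = sum(X_blocks[i] @ W[i] for i in range(len(W)) if i != k)
            A = beta * np.eye(Xk.shape[1]) + (rho / N) * Xk.T @ Xk
            b = -(Xk.T @ lam) / N - (rho / N) * Xk.T @ (others - z)
            W[k] = np.linalg.solve(A, b)
        u = sum(Xk @ Wk for Xk, Wk in zip(X_blocks, W))
        # z step: closed-form argmin of 0.5 (z - y)^2 + lam (u - z) + rho/2 (u - z)^2.
        z = (y + lam + rho * u) / (1.0 + rho)
        # Dual step on the constraint u = z.
        lam = lam + rho * (u - z)
        history.append(aug_lagrangian(X_blocks, W, z, lam, y, rho, beta))
        residuals.append(np.linalg.norm(u - z))
    return W, z, lam, history, residuals

rng = np.random.default_rng(0)
N = 40
X_blocks = [rng.normal(size=(N, 2)) for _ in range(3)]   # 3 hypothetical clients
y = np.hstack(X_blocks) @ rng.normal(size=6) + 0.1 * rng.normal(size=N)
W, z, lam, hist, res = admm_toy(X_blocks, y)
```

In this toy the $z$ step lowers $\mathcal{L}_{\mathrm{ADMM}}$ by at least $\frac{1+\rho}{2N}\|\Delta z\|^2$, while the dual step raises it by $\frac{1}{\rho N}\|\Delta\lambda\|^2=\frac{1}{\rho N}\|\Delta z\|^2$, so a large enough $\rho$ yields monotone descent, in the spirit of Lemma 4.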

-B4 Proofs for Theorem 1 Part (B)

Proof of Theorem 1 (B).

Lemma 4 implies that

$$
\begin{aligned}
&\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t+1)}\},\{\theta_k^{(t+1)}\},\{z_j^{(t+1)}\},\{\lambda_j^{(t+1)}\})-\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})\\
&\quad\leq\frac{1}{N}\sum_{j=1}^{N}\Big(-\frac{\mu_z}{2}+\frac{L^2}{\rho}\Big)\|z_j^{(t+1)}-z_j^{(t)}\|^2+\sum_{k=1}^{M}\frac{-\mu_\theta+\rho L_\theta^2\sigma_W^2}{2}\|\theta_k^{(t+1)}-\theta_k^{(t)}\|^2\\
&\qquad+\sum_{k=1}^{M}\frac{-\mu_W+\rho\sigma_\theta^2}{2}\|W_k^{(t+1)}-W_k^{(t)}\|^2
\end{aligned}
$$
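The sign of each coefficient in the bound above determines which values of $\rho$ make Lemma 4 a descent statement. The sketch below reads these sign conditions directly off the displayed bound and intersects them with $\rho\ge L$ from Lemma 5; it does not reproduce Eq. 33 itself, and it treats $L,\mu_z,\mu_\theta,L_\theta,\sigma_W,\mu_W,\sigma_\theta$ as fixed positive constants independent of $\rho$, which is an assumption. The function names and the numerical constants are hypothetical.

```python
def rho_interval(L, mu_z, mu_theta, L_theta, sigma_W, mu_W, sigma_theta):
    """Interval (lo, hi) of rho on which every Lemma 4 coefficient is negative,
    intersected with rho >= L (used in Lemma 5). Empty when lo >= hi."""
    lo = max(L, 2.0 * L ** 2 / mu_z)                    # need -mu_z/2 + L^2/rho < 0
    hi = min(mu_theta / (L_theta ** 2 * sigma_W ** 2),  # need -mu_theta + rho L_theta^2 sigma_W^2 < 0
             mu_W / sigma_theta ** 2)                   # need -mu_W + rho sigma_theta^2 < 0
    return lo, hi

def lemma4_coefficients(rho, L, mu_z, mu_theta, L_theta, sigma_W, mu_W, sigma_theta):
    """Coefficients of the z, theta and W terms in the Lemma 4 bound (1/N factor dropped)."""
    c_z = -mu_z / 2.0 + L ** 2 / rho
    c_theta = (-mu_theta + rho * L_theta ** 2 * sigma_W ** 2) / 2.0
    c_W = (-mu_W + rho * sigma_theta ** 2) / 2.0
    return c_z, c_theta, c_W

# Hypothetical constants, chosen only to illustrate the computation.
consts = dict(L=1.0, mu_z=8.0, mu_theta=10.0, L_theta=1.0,
              sigma_W=1.0, mu_W=5.0, sigma_theta=1.0)
lo, hi = rho_interval(**consts)
rho = 0.5 * (lo + hi)
coeffs = lemma4_coefficients(rho, **consts)
print(lo, hi, coeffs)
```

With these constants the admissible range is $1\le\rho<5$ and all three coefficients are negative at its midpoint. The interval becomes empty when the strong-convexity constants are small relative to the Lipschitz constants, which is why the theorem needs a condition such as Eq. 33.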

Since $\mathcal{L}_{\mathrm{ADMM}}$ is monotonically decreasing and bounded below (Lemma 5), and the coefficients above are negative under the bounds on $\rho$ in Eq. 33, we have, for all $j\in[N]$ and $k\in[M]$,

$$\lim_{t\to\infty}\|z_j^{(t+1)}-z_j^{(t)}\|^2=0,\quad\lim_{t\to\infty}\|\theta_k^{(t+1)}-\theta_k^{(t)}\|^2=0,\quad\lim_{t\to\infty}\|W_k^{(t+1)}-W_k^{(t)}\|^2=0.\tag{37}$$
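The passage from Lemma 4 to Eq. 37 is a telescoping argument. As a sketch, assume the bounds on $\rho$ in Eq. 33 make the following constants strictly positive (the displayed bound needs this to be informative):

$$c_z:=\frac{\mu_z}{2}-\frac{L^2}{\rho},\qquad c_\theta:=\frac{\mu_\theta-\rho L_\theta^2\sigma_W^2}{2},\qquad c_W:=\frac{\mu_W-\rho\sigma_\theta^2}{2}.$$

Writing $\mathcal{L}^{(t)}$ for $\mathcal{L}_{\mathrm{ADMM}}(\{W_k^{(t)}\},\{\theta_k^{(t)}\},\{z_j^{(t)}\},\{\lambda_j^{(t)}\})$ and summing the Lemma 4 bound over $t=0,\dots,T-1$ gives

$$\sum_{t=0}^{T-1}\Big(\frac{c_z}{N}\sum_{j=1}^{N}\|z_j^{(t+1)}-z_j^{(t)}\|^2+c_\theta\sum_{k=1}^{M}\|\theta_k^{(t+1)}-\theta_k^{(t)}\|^2+c_W\sum_{k=1}^{M}\|W_k^{(t+1)}-W_k^{(t)}\|^2\Big)\leq\mathcal{L}^{(0)}-\mathcal{L}^{(T)}\leq\mathcal{L}^{(0)}-\underline{e},$$

where the last inequality uses the lower bound $\underline{e}$ from Lemma 5. The right-hand side does not depend on $T$, so each series of non-negative terms converges and its terms must tend to zero, which is the content of Eq. 37.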

By Lemma 2, $\|\lambda_j^{(t+1)}-\lambda_j^{(t)}\|\leq L\|z_j^{(t+1)}-z_j^{(t)}\|$, so we further obtain

$$\lim_{t\to\infty}\|\lambda_j^{(t+1)}-\lambda_j^{(t)}\|^2=0,\quad\forall j\in[N]\tag{38}$$

In light of the dual update step of Algorithm 1, Eq. 38 implies that

$$\lim_{t\to\infty}\Big\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t+1)}\Big\|^2=0,\quad\forall j\in[N]\tag{39}$$
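The link between Eq. 38 and Eq. 39 can be written out. Assuming the dual step of Algorithm 1 has the standard form below (it is the update obtained by combining Eq. 45 with the substitution $\lambda_j^{(t)}=\nabla\ell(z_j^{(t)})$ used in step (a); the algorithm itself is not reproduced in this section), for every sample $j$ updated at iteration $t+1$:

$$\lambda_j^{(t+1)}=\lambda_j^{(t)}+\rho\Big(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t+1)}\Big)$$

$$\Longrightarrow\quad\Big\|\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}-z_j^{(t+1)}\Big\|=\frac{1}{\rho}\|\lambda_j^{(t+1)}-\lambda_j^{(t)}\|\leq\frac{L}{\rho}\|z_j^{(t+1)}-z_j^{(t)}\|,$$

where the inequality is Lemma 2. The primal residual is therefore controlled by the dual change in Eq. 38 and, through Lemma 2, by the change in $z_j$ in Eq. 37.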

Denote the limit points by $W_k^*,\theta_k^*,z_j^*,\lambda_j^*$, i.e., $W_k^{(t+1)}\to W_k^*$, $\theta_k^{(t+1)}\to\theta_k^*$, $z_j^{(t+1)}\to z_j^*$, and $\lambda_j^{(t+1)}\to\lambda_j^*$.

Based on Eq. 39, we have

$$\sum_{k=1}^{M}f(x_j^k;\theta_k^*)W_k^*=z_j^*\tag{40}$$

Next, we examine the optimality condition for the $\{W_k^{(t+1)}\}$ subproblems at iteration $t+1$:

$$
\begin{aligned}
0={}&\beta_k\nabla\mathcal{R}_k(W_k^{(t+1)})+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}f(x_j^k;\theta_k^{(t)})\\
&+\sum_{j=1}^{N}\frac{\rho}{N}\Big(\sum_{i\in[M],\,i\neq k}f(x_j^i;\theta_i^{(t)})W_i^{(t)}+f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-z_j^{(t)}\Big)f(x_j^k;\theta_k^{(t)})
\end{aligned}\tag{41}
$$

Letting $t\to\infty$ and using Eq. 37 and Eq. 39, we obtain

$$0=\beta_k\nabla\mathcal{R}_k(W_k^*)+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^*}^{\top}f(x_j^k;\theta_k^*)\tag{42}$$

Similarly, the optimality condition for the $\{\theta_k^{(t+1)}\}$ subproblems at iteration $t+1$ gives

$$
\begin{aligned}
0={}&\beta_k\nabla\mathcal{R}_k(\theta_k^{(t+1)})+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^{(t)}}^{\top}\nabla f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}\\
&+\sum_{j=1}^{N}\frac{\rho}{N}\Big(\sum_{i\in[M],\,i\neq k}f(x_j^i;\theta_i^{(t)})W_i^{(t+1)}+f(x_j^k;\theta_k^{(t)})W_k^{(t+1)}-z_j^{(t)}\Big)\nabla f(x_j^k;\theta_k^{(t+1)})W_k^{(t+1)}
\end{aligned}\tag{43}
$$

Letting $t\to\infty$ and using Eq. 37 and Eq. 39, we obtain

$$0=\beta_k\nabla\mathcal{R}_k(\theta_k^*)+\frac{1}{N}\sum_{j=1}^{N}{\lambda_j^*}^{\top}\nabla f(x_j^k;\theta_k^*)W_k^*\tag{44}$$
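As a sanity check on what Eq. 40 and Eq. 42 say at a limit point, consider a hypothetical special case that is not the paper's setting: identity feature maps $f(x;\theta)=x$ (so there is no $\theta_k$ and Eq. 44 is vacuous), squared loss $\ell(z,y)=\frac12(z-y)^2$, $\mathcal{R}_k(W)=\frac12\|W\|^2$, and a shared $\beta_k=\beta$. The sketch below computes the ridge-regression minimizer of the corresponding VIM objective in closed form, sets $z_j^*$ by Eq. 40 and $\lambda_j^*=\nabla\ell(z_j^*)$, and checks that the Eq. 42 stationarity residual of every head vanishes to numerical precision. The data and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
N, beta = 50, 0.1
X_blocks = [rng.normal(size=(N, 3)) for _ in range(2)]   # two hypothetical clients, 3 features each
X = np.hstack(X_blocks)
y = X @ rng.normal(size=6) + 0.1 * rng.normal(size=N)

# Minimizer of (1/N) sum_j 0.5 (x_j^T w - y_j)^2 + beta/2 ||w||^2 (ridge regression).
w_star = np.linalg.solve(X.T @ X / N + beta * np.eye(X.shape[1]), X.T @ y / N)
W_star = np.split(w_star, [3])                                # head k = slice of w for client k
z_star = sum(Xk @ Wk for Xk, Wk in zip(X_blocks, W_star))     # Eq. 40: z_j* = sum_k f(x_j^k) W_k*
lam_star = z_star - y                                         # lambda_j* = grad l(z_j*) for squared loss

# Eq. 42 with grad R_k(W) = W: beta W_k* + (1/N) sum_j lambda_j* x_j^k should be zero.
stationarity = [beta * Wk + Xk.T @ lam_star / N for Xk, Wk in zip(X_blocks, W_star)]
print([float(np.max(np.abs(s))) for s in stationarity])
```

In this special case Eq. 42 is just the blockwise form of the ridge normal equations, so a limit point satisfying Eq. 40 and Eq. 42 is a stationary point of the VIM objective.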

Based on the optimality condition for the $\{z_j\}$ subproblems at iteration $t$, we have

$$0=\nabla\ell(z_j^{(t)})-\lambda_j^{(t-1)}-\rho\Big(\sum_{k=1}^{M}f(x_j^k;\theta_k^{(t)})W_k^{(t)}-z_j^{(t)}\Big),\quad\forall j\in B(t)\tag{45}$$

By the strong convexity with respect to $z_j$ in 2, there exists a subgradient $\eta^{(t)} \in \partial \ell(z_j^{(t)})$ such that

$$\Big\langle z - z_j^{(t)},\ \eta^{(t)} - \Big(\lambda_j^{(t-1)} + \rho\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^{(t)})\, W_k^{(t)} - z_j^{(t)}\Big)\Big)\Big\rangle \ge 0, \quad \forall z \tag{46}$$

It implies that

$$\ell(z; y_j) - \ell(z_j^{(t)}; y_j) + \Big\langle z - z_j^{(t)},\ -\Big(\lambda_j^{(t-1)} + \rho\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^{(t)})\, W_k^{(t)} - z_j^{(t)}\Big)\Big)\Big\rangle \ge 0, \quad \forall z \tag{47}$$

Taking the limit of Eq. 47 and using Eqs. 37, 38, and 39, we have

$$\ell(z; y_j) - \ell(z_j^*; y_j) + \langle z - z_j^*,\ -\lambda_j^* \rangle \ge 0, \quad \forall z \tag{48}$$

That is:

$$\ell(z; y_j) + {\lambda_j^*}^\top\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^*)\, W_k^* - z\Big) \ge \ell(z_j^*; y_j) + {\lambda_j^*}^\top\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^*)\, W_k^* - z_j^*\Big), \quad \forall z$$

It implies that

$$z_j^* \in \arg\min_{z_j}\ \ell(z_j; y_j) + {\lambda_j^*}^\top\Big(\sum_{k=1}^{M} f(x_j^k;\theta_k^*)\, W_k^* - z_j\Big)$$

This completes the proof.
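As a concrete check of the $z_j$-subproblem optimality condition in Eq. 45, the Python sketch below assumes a squared loss $\ell(z; y) = \tfrac{1}{2}\|z - y\|^2$, chosen only because it gives a closed-form update; this loss is our assumption and not necessarily the one used by VIM. Writing $s_j = \sum_{k=1}^{M} f(x_j^k;\theta_k) W_k$, the objective $\ell(z_j; y_j) + \lambda_j^\top(s_j - z_j) + \frac{\rho}{2}\|s_j - z_j\|^2$ has Eq. 45 as its stationarity condition, and its minimizer is $z_j = (y_j + \lambda_j + \rho s_j)/(1+\rho)$. The function names are hypothetical.

```python
import numpy as np


def z_update_squared_loss(s, y, lam, rho):
    """Minimizer over z of 0.5*||z - y||^2 + lam^T (s - z) + (rho/2)*||s - z||^2.

    s: sum of client head outputs, sum_k f(x_j^k; theta_k) W_k
    y: target under the assumed squared loss
    lam: dual variable lambda_j; rho: ADMM penalty factor
    """
    return (y + lam + rho * s) / (1.0 + rho)


def eq45_residual(z, s, y, lam, rho):
    """Right-hand side of Eq. 45 with grad ell(z) = z - y; it vanishes at the minimizer."""
    return (z - y) - lam - rho * (s - z)


def subproblem_objective(z, s, y, lam, rho):
    return 0.5 * np.sum((z - y) ** 2) + lam @ (s - z) + 0.5 * rho * np.sum((s - z) ** 2)


rng = np.random.default_rng(0)
d, rho = 10, 1.0
s, y, lam = rng.normal(size=d), rng.normal(size=d), rng.normal(size=d)
z = z_update_squared_loss(s, y, lam, rho)
print(np.max(np.abs(eq45_residual(z, s, y, lam, rho))))  # rounding error only
```

For a general (e.g., cross-entropy) loss there is no closed form, and the $z_j$-update would instead be solved numerically; the stationarity condition being checked is the same.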

-C Privacy Guarantees

-C1 Preliminaries

We use Rényi Differential Privacy (RDP) for the privacy analysis because it supports a tighter composition of the privacy budget [56] than the moments accountant technique [1] for Differential Privacy (DP).

We start by introducing the definition of RDP, a generalization of DP based on the $\alpha$-Rényi divergence between the output distributions of a mechanism on two neighboring datasets. Our notion of neighboring datasets follows the client-level differentially private FL framework [54] (Definition 2): two datasets are neighboring if they differ in all samples associated with a single client, i.e., one client is present in or absent from the VFL global dataset.

Definition 3.

(Rényi Differential Privacy [56]) We say that a mechanism $\mathcal{M}$ is $(\alpha,\epsilon)$-RDP with order $\alpha \in (1,\infty)$ if for all neighboring datasets $D, D'$:

$$D_\alpha\big(\mathcal{M}(D)\,\|\,\mathcal{M}(D')\big) := \frac{1}{\alpha-1}\log \mathbb{E}_{\theta\sim\mathcal{M}(D')}\left[\left(\frac{\mathcal{M}(D)(\theta)}{\mathcal{M}(D')(\theta)}\right)^{\alpha}\right] \le \epsilon \tag{49}$$

An RDP guarantee can be converted into a DP guarantee as follows:

Theorem 3.

(RDP to $(\epsilon,\delta)$-DP Conversion [3])⁶ If $f$ is an $(\alpha,\epsilon)$-RDP mechanism, it also satisfies $\left(\epsilon + \log\frac{\alpha-1}{\alpha} - \frac{\log\delta + \log\alpha}{\alpha-1},\ \delta\right)$-differential privacy for any $0<\delta<1$.

⁶ This theorem is tighter than the original RDP paper [56], and it is adopted in the official implementation of the PyTorch Opacus library.

Here, we highlight three key properties that are relevant to our analyses.

Theorem 4.

(RDP Composition [56]) Let $f: \mathcal{D} \mapsto \mathcal{R}_1$ be $(\alpha,\epsilon_1)$-RDP and $g: \mathcal{R}_1 \times \mathcal{D} \mapsto \mathcal{R}_2$ be $(\alpha,\epsilon_2)$-RDP. Then the mechanism defined as $(X,Y)$, where $X \sim f(D)$ and $Y \sim g(X,D)$, satisfies $(\alpha, \epsilon_1+\epsilon_2)$-RDP.

Theorem 5.

(RDP Guarantee for Gaussian Mechanism [56]) If $f$ is a real-valued function, the Gaussian Mechanism for approximating $f$ is defined as $\mathbf{G}_\sigma f(D) = f(D) + \mathcal{N}(0,\sigma^2)$. If $f$ has $\ell_2$ sensitivity 1, then the Gaussian Mechanism $\mathbf{G}_\sigma f$ satisfies $(\alpha, \alpha/(2\sigma^2))$-RDP.
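Definition 3 and Theorem 5 can be checked against each other numerically. The Python sketch below, our own illustration rather than part of the analysis, evaluates the Rényi divergence in Eq. 49 with a Riemann sum on a fine uniform grid for two Gaussians with standard deviation $\sigma$ whose means differ by 1 (the worst case for a real-valued query with $\ell_2$ sensitivity 1), and compares it with the closed form $\alpha/(2\sigma^2)$ from Theorem 5. The function names and grid settings are assumptions.

```python
import numpy as np


def gaussian_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


def renyi_divergence_gaussians(mu_p, mu_q, sigma, alpha, n=200_001):
    """Numerically evaluate Eq. 49 for P = N(mu_p, sigma^2) and Q = N(mu_q, sigma^2):
    D_alpha(P || Q) = 1/(alpha - 1) * log E_{x~Q}[(p(x)/q(x))^alpha]
                    = 1/(alpha - 1) * log integral of p(x)^alpha * q(x)^(1 - alpha) dx.
    """
    # The integrand is Gaussian-shaped around alpha*mu_p + (1 - alpha)*mu_q.
    center = alpha * mu_p + (1.0 - alpha) * mu_q
    x = np.linspace(center - 40.0 * sigma, center + 40.0 * sigma, n)
    dx = x[1] - x[0]
    log_f = alpha * gaussian_logpdf(x, mu_p, sigma) + (1.0 - alpha) * gaussian_logpdf(x, mu_q, sigma)
    m = log_f.max()  # log-sum-exp for numerical stability
    log_integral = m + np.log(np.sum(np.exp(log_f - m)) * dx)
    return log_integral / (alpha - 1.0)


sigma, alpha = 2.0, 8.0
# Outputs of a sensitivity-1 query on neighboring datasets differ by at most 1.
numeric = renyi_divergence_gaussians(mu_p=1.0, mu_q=0.0, sigma=sigma, alpha=alpha)
closed_form = alpha / (2.0 * sigma ** 2)  # Theorem 5: 8 / 8 = 1.0
print(numeric, closed_form)
```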

-C2 Proof of Theorem 2

We aim to protect the local training data of each client with a client-level $(\epsilon,\delta)$-DP guarantee (Definition 2) during VFL training. Let $X$ be the VFL global dataset, i.e., the union of the local feature sets $X_1, \ldots, X_M$ from all $M$ clients. We denote the output of client $k$ as a matrix $\mathcal{A}_k$, where each row is the embedding or logit of one local training sample. Without loss of generality, we consider the embedding matrix $\mathcal{A}_k = [h_1^k, \ldots, h_N^k]^\top$ as the local output; our analysis applies directly to the logit matrix $\mathcal{A}_k = [o_1^k, \ldots, o_N^k]^\top$. The local outputs from all clients can be concatenated into a global embedding matrix $\mathcal{A}$:

$$\mathcal{A} = [\mathcal{A}_1, \mathcal{A}_2, \ldots, \mathcal{A}_M] \tag{50}$$

For our algorithms (Algorithm 1 and Algorithm 3), which sample a mini-batch of data with indices $B(t)$ at each round $t$ for the clients to compute their embeddings, we view the corresponding $\mathcal{A}_k$ of each client $k$ as:

$$\mathcal{A}_k^{(t)}[j] = {h_j^k}^{(t)} \quad \text{if } j \in B(t), \tag{51}$$
$$\mathcal{A}_k^{(t)}[j] = 0 \quad \text{if } j \notin B(t). \tag{52}$$

That is, for the privacy analysis, we fill the rows of the output matrix that correspond to non-sampled indices (i.e., $j \notin B(t)$) with zeros.
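A minimal NumPy sketch of Eqs. 51-52, assuming the batch embeddings arrive as one array ordered like the index set $B(t)$; the function and variable names are hypothetical. It also shows why zero-filling is harmless for the analysis: zero rows do not change the Frobenius norm, so clipping and sensitivity depend only on the sampled rows.

```python
import numpy as np


def batch_output_matrix(h_batch, batch_idx, n_total):
    """Build A_k^(t) as in Eqs. 51-52.

    h_batch: (|B(t)|, d_f) embeddings h_j^k of the sampled samples, ordered like batch_idx
    batch_idx: the distinct sample indices B(t)
    n_total: N, the number of samples in the VFL global dataset
    """
    A = np.zeros((n_total, h_batch.shape[1]), dtype=h_batch.dtype)
    A[batch_idx] = h_batch  # rows j in B(t) hold h_j^k; all other rows stay zero
    return A


rng = np.random.default_rng(0)
N, d_f = 8, 3
B_t = np.array([1, 4, 6])
h = rng.normal(size=(len(B_t), d_f))
A_t = batch_output_matrix(h, B_t, N)
# Zero rows leave the Frobenius norm unchanged.
print(np.isclose(np.linalg.norm(A_t), np.linalg.norm(h)))  # True
```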

We first analyze the privacy cost of one communication round (omitting the superscript $t$) and then accumulate the privacy costs over $T$ rounds via the RDP composition theorem.

We define a function $\mathcal{H}$ that outputs a global embedding matrix consisting of the clipped local embedding matrices for the VFL global dataset $X$:

$$\mathcal{H}(X) = [\hat{\mathcal{A}}_1, \ldots, \hat{\mathcal{A}}_M], \quad \text{where } \hat{\mathcal{A}}_k = \mathtt{Clip}(\mathcal{A}_k, C),\ \forall k \in [M]. \tag{53}$$
Lemma 6.

For any neighboring datasets $X, X'$ differing in all samples associated with a single client, the $\ell_2$ sensitivity of $\mathcal{H}$ is $C$.

Proof.

Without loss of generality, let the neighboring dataset $X'$ differ from $X$ in the first client, i.e., $X' = \{X_1', X_2, \ldots, X_M\}$. Therefore, the global embedding matrices $\mathcal{H}(X)$ and $\mathcal{H}(X')$ differ only in the clipped local embedding matrix of the first client ($\hat{\mathcal{A}}_1$ vs. $\hat{\mathcal{A}}_1'$). The $\ell_2$ sensitivity of $\mathcal{H}$ is then bounded as follows:

$$\max_{X,X'} \|\mathcal{H}(X) - \mathcal{H}(X')\|_2 = \sqrt{\|\hat{\mathcal{A}}_1 - \hat{\mathcal{A}}_1'\|_F^2} \le C. \tag{54}$$ ∎
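The sensitivity bound in Lemma 6 can be illustrated with a small NumPy sketch. Since $\mathtt{Clip}$ is not defined in this section, we assume it rescales a matrix so that its Frobenius norm is at most $C$, and we model the absent client in $X'$ as contributing an all-zero output matrix (in line with the zero-filling convention above, but still an assumption). Under these assumptions, $\mathcal{H}(X)$ and $\mathcal{H}(X')$ differ by exactly the clipped block of client 1. All names are hypothetical.

```python
import numpy as np


def clip_frobenius(A, C):
    """Assumed form of Clip(A, C): rescale A so that its Frobenius norm is at most C."""
    norm = np.linalg.norm(A)  # Frobenius norm for a 2-D array
    return A * min(1.0, C / norm) if norm > 0 else A


def H(local_outputs, C):
    """Eq. 53: concatenate the clipped local matrices [A_1_hat, ..., A_M_hat] column-wise."""
    return np.concatenate([clip_frobenius(A, C) for A in local_outputs], axis=1)


rng = np.random.default_rng(0)
N, d_f, M, C = 16, 4, 3, 0.5
outputs_X = [rng.normal(size=(N, d_f)) for _ in range(M)]
# Neighboring X': client 1 is absent, modeled (assumption) as an all-zero output matrix.
outputs_X_prime = [np.zeros((N, d_f))] + outputs_X[1:]
gap = np.linalg.norm(H(outputs_X, C) - H(outputs_X_prime, C))
print(gap <= C + 1e-12)  # True: the change is the clipped block of client 1
```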

Then, we define our Gaussian mechanism $\mathbf{G}_{\sigma C}\mathcal{H}$, which outputs a global matrix consisting of noise-perturbed local embedding matrices for the VFL global dataset $X$:

$$\mathbf{G}_{\sigma C}\mathcal{H}(X) = [\widetilde{\mathcal{A}}_1, \ldots, \widetilde{\mathcal{A}}_M], \quad \text{where } \widetilde{\mathcal{A}}_k = \mathtt{Clip}(\mathcal{A}_k, C) + \mathcal{N}(0, \sigma^2 C^2),\ \forall k \in [M]. \tag{55}$$
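Eq. 55 can be sketched as follows, again assuming Frobenius-norm clipping for $\mathtt{Clip}$ and element-wise i.i.d. noise; the function names and the sizes in the usage lines are hypothetical. The point to note is that the noise standard deviation is $\sigma C$, so the noise scales with the clipping threshold.

```python
import numpy as np


def clip_frobenius(A, C):
    """Assumed form of Clip(A, C): rescale A so that its Frobenius norm is at most C."""
    norm = np.linalg.norm(A)
    return A * min(1.0, C / norm) if norm > 0 else A


def gaussian_mechanism(local_outputs, C, sigma, rng):
    """Eq. 55: client k releases Clip(A_k, C) plus i.i.d. N(0, sigma^2 C^2) noise per entry.

    The scale passed to rng.normal is the standard deviation sigma * C (not sigma^2 * C^2).
    """
    return [clip_frobenius(A, C) + rng.normal(0.0, sigma * C, size=A.shape)
            for A in local_outputs]


rng = np.random.default_rng(0)
# Hypothetical sizes: 4 clients, a batch of 1024 samples, 60-dimensional embeddings.
outputs = [rng.normal(size=(1024, 60)) for _ in range(4)]
noisy = gaussian_mechanism(outputs, C=0.1, sigma=5.0, rng=rng)
```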
Lemma 7.

Given the function $\mathcal{H}$ with $\ell_2$ sensitivity $C$ and Gaussian noise with standard deviation $\sigma C$ (i.e., variance $\sigma^2 C^2$), the Gaussian mechanism $\mathbf{G}_{\sigma C}\mathcal{H}$ satisfies client-level $(\alpha, \alpha/(2\sigma^2))$-RDP.

Proof.

By Lemma 6, the $\ell_2$ sensitivity of $\mathcal{H}$ is $C$. The standard deviation of the Gaussian noise added to the embeddings is $\sigma C$, i.e., proportional to the clipping constant $C$. Hence $\mathbf{G}_{\sigma C}\mathcal{H}(X) = C\,\mathbf{G}_\sigma(\mathcal{H}/C)(X)$, where $\mathcal{H}/C$ has $\ell_2$ sensitivity 1 and rescaling the output by $C$ is post-processing. Combining this with Theorem 5 shows that $\mathbf{G}_{\sigma C}\mathcal{H}$ guarantees client-level $(\alpha, \alpha/(2\sigma^2))$-RDP. ∎

We note that the training process on the server does not access the raw data $X_k$, so it does not increase the privacy budget, and by the post-processing property of RDP the whole algorithm satisfies RDP in each round. For algorithms with $T$ communication rounds, we use the RDP Composition theorem (Theorem 4) to accumulate the privacy budget over $T$ rounds and then convert the RDP guarantee into a DP guarantee (Theorem 3).

Finally, we recall Theorem 2 and provide its formal proof.

Proof.

At each communication round, by Lemma 7, $\mathbf{G}_{\sigma C}\mathcal{H}$ satisfies client-level $(\alpha, \frac{\alpha}{2\sigma^2})$-RDP. By the post-processing property of RDP, after server training, our DP algorithms (i.e., the DP versions of Algorithms 1 and 3) with one round still satisfy client-level $(\alpha, \epsilon'(\alpha))$-RDP with $\epsilon'(\alpha) = \frac{\alpha}{2\sigma^2}$. By the RDP Composition theorem (Theorem 4), our DP algorithms with $T$ communication rounds satisfy client-level $(\alpha, \frac{T\alpha}{2\sigma^2})$-RDP. Finally, by the connection between RDP and DP in Theorem 3, our DP algorithms with $T$ communication rounds also satisfy client-level $\left(\frac{T\alpha}{2\sigma^2} + \log\frac{\alpha-1}{\alpha} - \frac{\log\delta + \log\alpha}{\alpha-1},\ \delta\right)$-DP. ∎
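The bound in Theorem 2 holds for every order $\alpha > 1$, so an accountant can report the smallest value over a set of orders. The Python sketch below evaluates the final expression of the proof over a grid of orders and returns the minimum. It does not call Opacus (which the experiments use to compute $\epsilon$); the order grid and the function name are our own choices, so its numbers need not match the reported budgets exactly.

```python
import math


def vfl_client_level_epsilon(sigma, T, delta, orders=None):
    """Epsilon from Theorem 2: T rounds of an (alpha, alpha/(2 sigma^2))-RDP mechanism,
    composed (Theorem 4) and converted to (eps, delta)-DP (Theorem 3), minimized over alpha."""
    if orders is None:
        # Hypothetical grid of orders; any finite set of alpha > 1 yields a valid bound.
        orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
    best_eps, best_alpha = math.inf, None
    for alpha in orders:
        rdp = T * alpha / (2.0 * sigma ** 2)  # Lemma 7 + composition over T rounds
        eps = (rdp + math.log((alpha - 1) / alpha)
               - (math.log(delta) + math.log(alpha)) / (alpha - 1))  # Theorem 3
        if eps < best_eps:
            best_eps, best_alpha = eps, alpha
    return best_eps, best_alpha


eps, best_alpha = vfl_client_level_epsilon(sigma=5.0, T=100, delta=1e-5)
print(f"epsilon = {eps:.3f} at alpha = {best_alpha}")
```

Larger $\sigma$ lowers $\epsilon$ and more rounds $T$ raise it, which is why reducing the number of communication rounds directly saves privacy budget.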

-D Experimental Details and Additional Results

-D1 Platform

We simulate the vertical federated learning setup (1 server and $M$ clients) on a Linux machine with an AMD Ryzen Threadripper 3990X 64-Core CPU and 4 NVIDIA GeForce RTX 3090 GPUs. The algorithms are implemented in PyTorch [58]. Please see the submitted code for full details. We run each experiment 3 times with different random seeds.

-D2 Hyperparameters

We detail our hyperparameter tuning protocol and the hyperparameter values here. For all VFL training experiments, we use the SGD optimizer with learning rate $\eta$ for the server's model, and the SGD optimizer with momentum 0.9 and learning rate $\eta$ for each client $k$'s local model. The regularization weight $\beta$ is set to 0.005. The embedding dimension $d_f$ is set to 60, and the batch size $b$ is set to 1024 for all datasets.

Vanilla VFL training

For Vanilla VFL training experiments, we tune the learning rate separately for each method by grid search over $\{0.05, 0.1, 0.3, 0.5, 0.8\}$ on MNIST, $\{0.003, 0.005, 0.008, 0.01, 0.05, 0.1\}$ on CIFAR, $\{0.05, 0.1, 0.5\}$ on NUS-WIDE, and $\{0.0005, 0.005, 0.01, 0.05, 0.1\}$ on ModelNet40. For ADMM-based methods, we tune the penalty factor $\rho$ over the grid $\{0.5, 1, 2\}$ on all datasets.
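The vanilla VFL search space above can be written down as a small configuration. This Python sketch is hypothetical: the dictionary layout and function name are ours, we assume the learning-rate and $\rho$ grids are searched jointly for ADMM-based methods, and the training and evaluation of each configuration are not shown.

```python
from itertools import product

# Learning-rate grids for Vanilla VFL, as listed above.
LR_GRID = {
    "MNIST": [0.05, 0.1, 0.3, 0.5, 0.8],
    "CIFAR": [0.003, 0.005, 0.008, 0.01, 0.05, 0.1],
    "NUS-WIDE": [0.05, 0.1, 0.5],
    "ModelNet40": [0.0005, 0.005, 0.01, 0.05, 0.1],
}
RHO_GRID = [0.5, 1, 2]  # penalty factor, ADMM-based methods only


def vanilla_configs(dataset, admm_based):
    """Enumerate (lr, rho) pairs; rho is None for methods without an ADMM penalty."""
    rhos = RHO_GRID if admm_based else [None]
    return [{"lr": lr, "rho": rho} for lr, rho in product(LR_GRID[dataset], rhos)]


print(len(vanilla_configs("CIFAR", admm_based=True)))  # 18
```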

Differentially private VFL training

We use the PyTorch differential privacy library Opacus⁷ to calculate the privacy budget $\epsilon$. In all experiments, $\delta = 10^{-5}$. For each privacy budget $\epsilon$, we perform a grid search over combinations of hyperparameters (noise scale $\sigma$, clipping threshold $C$, and learning rate $\eta$) for all methods for a fair comparison. The noise scale $\sigma$ is tuned from $\{2, 3, 5, 8, 10, 30, 50, 70\}$ on all datasets. For MNIST, $C$ is tuned from $\{0.0005, 0.001, 0.005, 0.01, 0.1, 1\}$ and $\eta$ from $\{0.05, 0.3, 0.5, 1\}$; for CIFAR, $C$ from $\{0.01, 0.05, 0.1, 0.5, 1\}$ and $\eta$ from $\{0.005, 0.01, 0.05, 0.1, 0.5, 1\}$; for NUS-WIDE, $C$ from $\{0.001, 0.005, 0.01, 0.05, 0.1\}$ and $\eta$ from $\{0.05, 0.1, 0.3, 0.5, 1\}$; for ModelNet40, $C$ from $\{0.01, 0.05, 0.1, 0.5, 1\}$ and $\eta$ from $\{0.05, 0.1, 0.5\}$. We use the best hyperparameters to start 3 runs with different random seeds and report the average results for each method.

⁷ https://github.com/pytorch/opacus
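Likewise, the DP search space can be enumerated per dataset; for example, MNIST has $8 \times 6 \times 4 = 192$ combinations of $(\sigma, C, \eta)$ for each privacy budget. The sketch below only enumerates configurations (the names are ours); the privacy accounting with Opacus and the training runs are not shown.

```python
from itertools import product

SIGMA_GRID = [2, 3, 5, 8, 10, 30, 50, 70]  # shared by all datasets
DP_GRIDS = {  # dataset -> (clipping thresholds C, learning rates eta)
    "MNIST": ([0.0005, 0.001, 0.005, 0.01, 0.1, 1], [0.05, 0.3, 0.5, 1]),
    "CIFAR": ([0.01, 0.05, 0.1, 0.5, 1], [0.005, 0.01, 0.05, 0.1, 0.5, 1]),
    "NUS-WIDE": ([0.001, 0.005, 0.01, 0.05, 0.1], [0.05, 0.1, 0.3, 0.5, 1]),
    "ModelNet40": ([0.01, 0.05, 0.1, 0.5, 1], [0.05, 0.1, 0.5]),
}


def dp_configs(dataset):
    """All (sigma, C, eta) combinations searched for one privacy budget on `dataset`."""
    clips, lrs = DP_GRIDS[dataset]
    return [{"sigma": s, "C": c, "lr": lr} for s, c, lr in product(SIGMA_GRID, clips, lrs)]


for name in DP_GRIDS:
    print(name, len(dp_configs(name)))
```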

Client-level explainability

In the client-importance validation experiments with a noisy test client, we perturb, one client at a time, the features of all test samples held by that client by adding Gaussian noise sampled from $\mathcal{N}(0, \bar{\sigma}^2)$. To make the difference in test accuracy between important and unimportant clients observable, we set $\bar{\sigma}$ to 10 for MNIST, 1 for CIFAR and NUS-WIDE, and 3 for ModelNet40.

In the client denoising experiments, we construct one noisy client (client 7, 5, 2, and 3 for MNIST, CIFAR, NUS-WIDE, and ModelNet40, respectively) by adding Gaussian noise sampled from $\mathcal{N}(0, \widetilde{\sigma}^2)$ to all of its training and test samples. We set $\widetilde{\sigma}$ to 1 for MNIST, NUS-WIDE, and ModelNet40, and to 3 for CIFAR.
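Both noise-injection protocols above perturb the feature columns held by a single client. The following NumPy sketch assumes the global feature matrix is split column-wise among clients via a list of column indices; the function name and the 4-client split in the usage lines are hypothetical. For client-importance validation it is applied to test features with noise level $\bar{\sigma}$, one client at a time; for client denoising it is applied to both the training and test features of one fixed client with $\widetilde{\sigma}$.

```python
import numpy as np


def perturb_client_features(X, client_columns, k, noise_std, rng):
    """Return a copy of X in which only client k's feature columns get N(0, noise_std^2) noise."""
    X_noisy = np.array(X, dtype=float, copy=True)
    cols = client_columns[k]
    X_noisy[:, cols] += rng.normal(0.0, noise_std, size=(X_noisy.shape[0], len(cols)))
    return X_noisy


rng = np.random.default_rng(0)
X_test = rng.random((5, 8))
client_columns = [[0, 1], [2, 3], [4, 5], [6, 7]]  # hypothetical vertical split into 4 clients

# Client-importance validation: perturb one client's test features at a time.
noisy_views = [perturb_client_features(X_test, client_columns, k, 1.0, rng) for k in range(4)]
```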

-D3 Additional Evaluation Results

Effect of $\rho$

Figure 5 reports the test accuracy of VIMADMM with different penalty factors $\rho$, showing that VIMADMM is not sensitive to $\rho$ on the four datasets.






[Figure 5 panels, left to right: MNIST, CIFAR, NUS-WIDE, ModelNet40]
Figure 5: Performance of VIMADMM with different penalty factors $\rho$ on four datasets. VIMADMM is not sensitive to $\rho$ from 0.5 to 2.
Results for a large number of clients.

We evaluate the baselines and our methods with 100 clients on MNIST by allowing clients to hold overlapping features. Specifically, we divide the features into 100 overlapping subsets, one per client, so that each client holds 14 pixels. The results in Table XI show that the VIM methods (i.e., VIMADMM and VIMADMM-J) still achieve higher accuracy than the baselines in both the w/ and w/o model splitting settings.

TABLE XI: Performance of Vanilla VFL when $M = 100$ on MNIST
Setting | Method | Accuracy
W/ model splitting | VAFL | 95.38
W/ model splitting | Split Learning | 95.45
W/ model splitting | VIMADMM | 95.77
W/o model splitting | FDML | 95.85
W/o model splitting | VIMADMM-J | 95.96
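The text does not specify how the 100 overlapping 14-pixel subsets are laid out. One hypothetical construction, sketched below in NumPy, gives each client a contiguous window of the flattened $28 \times 28$ image, with evenly spaced start positions so that the windows overlap and jointly cover all 784 pixels; this layout and the function name are our assumptions, not necessarily what the experiments used.

```python
import numpy as np


def overlapping_pixel_split(n_pixels=784, n_clients=100, pixels_per_client=14):
    """Hypothetical overlapping split: client i gets a contiguous window of
    `pixels_per_client` flattened-pixel indices; window starts are evenly spaced
    (gaps of at most 8 < 14 here), so consecutive windows overlap and cover every pixel."""
    starts = np.round(np.linspace(0, n_pixels - pixels_per_client, n_clients)).astype(int)
    return [np.arange(s, s + pixels_per_client) for s in starts]


subsets = overlapping_pixel_split()
print(len(subsets), sum(len(s) for s in subsets))  # 100 1400
```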
Results for client subsampling under client-level DP

We extend our study under client-level DP by incorporating a client subsampling mechanism into the VIMADMM framework to save privacy budget. Specifically, in each communication round, the server receives local embeddings from only $p\%$ of the clients, i.e., a participation rate of $p\%$. To handle the missing local embeddings, the server uses the historical local embeddings previously uploaded by the non-participating clients in place of their absent embeddings.

We calculate the privacy budget $\epsilon_i$ of each client $i$ following Theorem 2, where $T$ now denotes the number of communication rounds in which client $i$ uploads local embeddings, rather than the total number of communication rounds as in our original algorithm. Because the local data of different clients do not overlap, the concatenated output matrix from all clients satisfies a $\max_i \epsilon_i$ client-level DP guarantee by DP parallel composition.
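The per-client accounting above can be sketched as follows. This Python sketch assumes clients are chosen uniformly at random without replacement in each round (the selection rule is not specified here) and records only when each client last uploaded, standing in for the server's cache of historical embeddings; the embedding computation itself is omitted and all names are hypothetical. Because the bound in Theorem 2 increases with $T$ for fixed $\sigma$ and $\delta$, $\max_i \epsilon_i$ is the budget of the client with the most uploads, $\max_i T_i$.

```python
import numpy as np


def simulate_participation(n_clients, n_rounds, p, rng):
    """Client subsampling with a server-side cache of historical embeddings.

    Each round, a random p-fraction of clients uploads fresh embeddings; for every other
    client the server reuses its cached, older upload. Returns T_i (uploads per client)
    and the round of each client's cached upload (-1 = no upload yet).
    """
    n_active = max(1, int(round(p * n_clients)))
    uploads = np.zeros(n_clients, dtype=int)
    cache_round = np.full(n_clients, -1)
    for t in range(n_rounds):
        active = rng.choice(n_clients, size=n_active, replace=False)
        uploads[active] += 1     # only these clients spend privacy budget this round
        cache_round[active] = t  # cache refreshed for active clients only
    return uploads, cache_round


rng = np.random.default_rng(0)
T_i, cache_round = simulate_participation(n_clients=8, n_rounds=200, p=0.25, rng=rng)
T_for_budget = T_i.max()  # plug into Theorem 2 to obtain max_i eps_i
```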

As shown in Table XII, there is a utility loss when p% = 25% or 50% compared with p% = 100%, under both DP and non-DP settings on MNIST and CIFAR. This gap indicates that aggregating the local outputs of all clients during training is necessary to achieve optimal utility in vertical federated learning.

TABLE XII: The utility (test accuracy, %) of VIMADMM with different client subsampling ratios under client-level DP.
Privacy budget | MNIST 25% | MNIST 50% | MNIST 100% | CIFAR 25% | CIFAR 50% | CIFAR 100%
ϵ = ∞ | 94.20 | 94.52 | 97.13 | 65.78 | 66.10 | 75.25
ϵ = 8 | 85.98 | 87.22 | 92.35 | 62.01 | 62.86 | 73.83
ϵ = 1 | 85.51 | 86.62 | 91.09 | 46.53 | 46.69 | 61.65
Additional results on client denoising

Table XIII presents the test accuracy of VAFL, Split Learning, and VIMADMM at different epochs (communication rounds) on four datasets with one noisy client. Note that each epoch consists of N/b communication rounds. Under this noisy training scenario, VIMADMM consistently outperforms Split Learning and VAFL, converging faster and reaching higher test accuracy, which indicates the effectiveness of VIM's multiple linear heads for client denoising.

TABLE XIII: Test accuracy (%) under one noisy client whose local training and test features are perturbed by Gaussian noise. Column headers give the epoch, with the corresponding communication round in parentheses.
MNIST | 2 (106) | 5 (265) | 10 (530)
VAFL | 91.07 ± 0.17 | 94.36 ± 0.16 | 95.59 ± 0.11
Split Learning | 95.04 ± 0.14 | 96.01 ± 0.03 | 96.43 ± 0.08
VIMADMM | 96.22 ± 0.07 | 96.60 ± 0.04 | 96.82 ± 0.07
CIFAR | 2 (88) | 5 (220) | 10 (440)
VAFL | 28.83 ± 1.04 | 38.77 ± 0.39 | 46.98 ± 0.70
Split Learning | 42.75 ± 0.13 | 50.06 ± 0.18 | 55.53 ± 0.37
VIMADMM | 67.08 ± 0.43 | 70.70 ± 0.34 | 71.76 ± 0.14
NUS-WIDE | 2 (106) | 5 (265) | 10 (530)
VAFL | 51.88 ± 0.72 | 77.68 ± 0.74 | 85.31 ± 0.15
Split Learning | 85.35 ± 0.24 | 86.42 ± 0.24 | 87.14 ± 0.29
VIMADMM | 86.38 ± 0.20 | 87.00 ± 0.27 | 87.18 ± 0.14
ModelNet40 | 2 (18) | 5 (45) | 10 (90)
VAFL | 43.23 ± 3.07 | 80.13 ± 1.10 | 89.56 ± 0.41
Split Learning | 77.94 ± 1.00 | 88.74 ± 0.07 | 89.69 ± 0.42
VIMADMM | 90.05 ± 0.38 | 90.71 ± 0.31 | 90.59 ± 0.05
Reference accuracy for SOTA model and reference model in the centralized setting

We include the comparisons in Table XIV, which shows the reference accuracy of SoTA models and of a simple reference model in the centralized setting on four datasets. Specifically,

  • SoTA models: These models may use different model architectures and training methods from ours. They serve as a virtual upper bound, i.e., the highest accuracy reported on each dataset.

  • Reference model in the centralized setting: This reference model has the same model size as one local model followed by a server model.

For instance, on the MNIST dataset, the latest SoTA method in the centralized setting achieves an accuracy of 99.87%, while a basic reference model (one local model followed by a server model) reaches 98.19%. In comparison, our VFL model (M local models followed by a server model) achieves a comparable accuracy of 97.13%.

Furthermore, on NUS-WIDE and ModelNet40, our VFL model even surpasses the reference model in the centralized setting. We attribute this to the significantly larger number of model parameters in VFL. For example, on ModelNet40 we use four ResNet-18 feature extractors as the local models of four clients, which together provide a richer representation of the data than the single local model of the reference.

TABLE XIV: Accuracy for SOTA model and reference model in the centralized setting.
Model | MNIST | CIFAR | NUS-WIDE (5 classes) | ModelNet40 (2D multi-views)
SOTA method in centralized setting (virtual upper bound) | 99.87 [9] | 99.50 [18] | 88.7 [47] | 96.6 [73]
Reference model in centralized setting (i.e., one local model followed by server model) | 98.19 | 77.61 | 87.71 | 88.96
VIMADMM model (i.e., M local models followed by server model) | 97.13 | 75.25 | 88.51 | 91.32
VIMADMM on larger models

VIMADMM scales well to large models such as ResNet-18, as shown in the experiments on ModelNet40. With larger models as the clients' feature extractors, VIMADMM can produce higher-quality local embeddings, which are also crucial for learning accurate linear heads on the server side. We also report the results of VIMADMM on CIFAR with CNN, ResNet-18, and ResNet-34 local models, which are 75.25%, 81.35%, and 82.58%, respectively. These results show that larger models lead to higher accuracy for VIMADMM, supporting its scalability.

VIMADMM on non-image tasks

VIMADMM can be adapted to non-image tasks, such as datasets with both text and image modalities. For example, on the NUS-WIDE dataset, which includes both text and image features as local datasets, VIMADMM achieves state-of-the-art results, as shown in Figure 1. This adaptability comes from VIMADMM's flexible design: clients handle heterogeneous input data types with different feature extractors (i.e., local models), and the server aggregates the resulting heterogeneous local embeddings via multiple linear heads. We believe these results underscore VIMADMM's potential in a broader range of applications beyond image tasks.

VIMADMM under long-tail datasets

Long-tail datasets are characterized by significant class imbalance, where minority classes have far fewer samples than majority classes. This horizontal (sample-level) imbalance is distinct from the challenge addressed by vertical federated learning, where each sample (whether it belongs to a majority or a minority class) is vertically split across multiple clients. Consequently, in vertical federated learning, every minority-class sample is still present at all clients, each holding part of its features. We conduct additional experiments on long-tail data.

We create long-tail training datasets following [71] with an imbalance factor of 10 (i.e., the ratio of samples in the head to tail class). Specifically, for MNIST, this resulted in class sample sizes of [6000, 4645, 3596, 2784, 2156, 1669, 1292, 1000, 774, 600] for 10 classes. For CIFAR, the class sample sizes are [5000, 3871, 2997, 2320, 1796, 1391, 1077, 834, 645, 500] across the 10 classes.
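The listed class sizes are consistent with the exponential profile commonly used to build long-tail benchmarks, n_i = ⌊n_max · IF^(−i/(C−1))⌋ for class i = 0, …, C−1 with imbalance factor IF. We assume, rather than confirm, that this is the construction in [71]; the sketch below reproduces both lists above.

```python
import math


def long_tail_counts(n_max: int, num_classes: int = 10, imbalance_factor: float = 10.0):
    """Exponentially decaying class sizes from n_max (head) to n_max / imbalance_factor (tail)."""
    return [math.floor(n_max * (1.0 / imbalance_factor) ** (i / (num_classes - 1)))
            for i in range(num_classes)]


print(long_tail_counts(6000))  # MNIST, head class of 6000 samples
print(long_tail_counts(5000))  # CIFAR, head class of 5000 samples
```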

We compare the VIMADMM model, which consists of M local models followed by a server model, with a reference model in the centralized setting that has the same model size as one local model followed by a server model.

The results in Table V show that VIMADMM remains effective on challenging long-tail training datasets, yielding results comparable to those of the reference model in the centralized setting. Moreover, the long-tail version of MNIST does not significantly reduce VIMADMM's accuracy compared with the original MNIST dataset.

-E Discussion

DP generative model in VFL

An alternative way to achieve DP in VFL is to locally train a DP generative model for each client, and then send the DP generative models to the server, which takes only one communication round. However, we identified several key challenges that make it less suitable for our VFL context:

  • Mismatch of Synthetic Partial Features Across Clients. In VFL, M clients hold different subsets of features for the same training samples (denoted as x_j^1, x_j^2, …, x_j^M for sample j). A generative model, due to its stochastic nature, generates synthetic partial features that do not correspond to any specific training sample (e.g., the partial features generated by a local generative model follow the local data distribution but do not correspond to a particular original sample j). Without this correspondence, the server cannot concatenate the synthetic partial features from different clients into a coherent "global" dataset for training, which limits the use of generative models in VFL.

  • Quality Concerns Due to Partial Features. Given that each client only has partial features (see Figure 3 row 1 for the visualization of raw local features), a generative model trained locally (without FL) might yield lower-quality data. This is particularly problematic when the partial features are not informative (e.g., clients holding only background pixels in image datasets like MNIST/CIFAR). The state-of-the-art accuracy for synthetic data in centralized learning with sample-level DP on MNIST is around 97.6% under ϵ = 1 and 98.2% under ϵ = 10 [37]. Since partial features in VFL are less informative than full features in centralized learning, a DP generative model in VFL would likely lead to lower accuracy. In contrast, our VIMADMM DP learning algorithm already achieves promising accuracy close to the state of the art: 91.35% under ϵ = 1 and 92.35% under ϵ = 8 in VFL under client-level DP, which is a stricter privacy notion than sample-level DP.

  • Scalability Issues of DP Generative Model with High-Dimensional Data. DP generative models often struggle with high-dimensional datasets [37, 72]. For instance, their performance on datasets like CIFAR10 is limited [72], which poses a challenge for more complex datasets like ModelNet40. Additionally, generating multi-modal data (e.g., the text and image features in NUS-WIDE) remains an unresolved challenge. Because VIMADMM does not train generative models, high-dimensional data do not pose the same challenge to it.

  • Communication Overhead with Generative Model. The model size of a standard DCGAN [61] implemented in PyTorch (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html) is 13.65 MB. Since it takes only one communication round, the communication cost for each client would be 13.65 MB. In comparison, VIMADMM has similar communication costs on certain datasets. For example, on the ModelNet40 dataset, VIMADMM achieves an accuracy of 89% with a total communication cost of 11.32 MB (see Figure 1 and Table IV for details). This efficiency stems from transmitting only local embeddings and ADMM-related variables, which are collectively smaller than the parameters of a deep neural network such as a DCGAN generator. The gap widens further with generative models larger than DCGAN, which is a common direction in current generative AI. Moreover, high-quality local embeddings (e.g., from a pretrained ResNet-18 as the local model) and multiple local updates per communication round (enabled by ADMM) significantly aid convergence. Consequently, only a small number of communication rounds (approximately 10) is required to reach an accuracy of 89% on ModelNet40.

    We remain open to future research exploring the feasibility and optimization of the local training of DP generative models in VFL settings.
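We do not know exactly how the 13.65 MB DCGAN figure quoted above was measured. One accounting that reproduces it, assuming it is the float32 state_dict size in MiB of the tutorial's default generator (nz = 100, ngf = 64, nc = 3; five 4×4 ConvTranspose2d layers without bias, BatchNorm2d after the first four), counts the trainable parameters plus the BatchNorm running statistics:

```python
def dcgan_generator_state_size(nz=100, ngf=64, nc=3, kernel=4):
    """Float32 size of a tutorial-style DCGAN generator (weights + BatchNorm buffers)."""
    # (in_channels, out_channels) of the five ConvTranspose2d layers, all with bias=False.
    convs = [(nz, ngf * 8), (ngf * 8, ngf * 4), (ngf * 4, ngf * 2), (ngf * 2, ngf), (ngf, nc)]
    conv_weights = sum(c_in * c_out * kernel * kernel for c_in, c_out in convs)
    bn_channels = [ngf * 8, ngf * 4, ngf * 2, ngf]   # BatchNorm2d after the first four layers
    trainable = conv_weights + 2 * sum(bn_channels)  # plus BN affine weight and bias
    float_buffers = 2 * sum(bn_channels)             # BN running_mean and running_var
    int64_buffers = len(bn_channels)                 # BN num_batches_tracked counters
    size_bytes = 4 * (trainable + float_buffers) + 8 * int64_buffers
    return trainable, size_bytes / 2**20             # size in MiB


params, mib = dcgan_generator_state_size()
print(params, round(mib, 2))  # 3576704 13.65
```

Under this accounting, nearly all of the payload is convolution weights, which is why transmitting embeddings and ADMM variables can be cheaper than shipping even a modest generator once.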

Linear head and client importance

In Section VI-C, we use the norm of the weights in each linear head, learned by the server from local embeddings, to determine the importance of the corresponding client (and its local features). This approach is in line with existing methods such as LIME [63], SHAP [49], and others [29] that use model weights to determine feature importance. Following this line of work [63, 49, 29], we assume feature independence to simplify the interpretation of weights as feature importance.
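A minimal sketch of such an importance score, assuming each server head W_k is stored as a (d_k × C) NumPy array. The use of the Frobenius norm and the normalization to sum to one across clients are our assumptions, not necessarily the exact choices in Section VI-C.

```python
import numpy as np


def client_importance(heads):
    """Normalized Frobenius norms of the per-client linear heads W_k (shape d_k x C)."""
    norms = np.array([np.linalg.norm(W) for W in heads])  # Frobenius norm for 2-D arrays
    total = norms.sum()
    if total == 0:
        raise ValueError("all heads are zero; importance is undefined")
    return norms / total


heads = [np.ones((3, 2)), np.zeros((5, 2)), 2 * np.ones((3, 2))]
print(client_importance(heads))  # the client with an all-zero head gets importance 0
```

A client whose head has shrunk toward zero (e.g., a noisy client whose embeddings the server learned to ignore) receives a low score, which matches the denoising behavior discussed earlier.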

Challenges of ADMM algorithm design in VFL

There are two key challenges in designing an ADMM algorithm for distributed optimization in VFL:

  • How to ensure consensus among clients and formulate it as a constrained optimization problem (e.g., from the original objective to its constrained form in Section III-B).

  • How to decompose the optimization problem into small sub-problems that can be solved in parallel by ADMM.

For the first challenge, although ADMM can flexibly introduce auxiliary variables to formulate a constrained optimization problem in HFL, doing so raises new challenges in VFL. For example, the ADMM-based methods in HFL [23, 22, 38, 81] usually use the global model as the auxiliary variable and enforce consistency between the global model and each local model. However, VFL communicates embeddings, and it is not feasible to force the local embeddings of different clients to be identical, as they provide unique information from different aspects. Therefore, in this paper, we introduce an auxiliary variable z_j for each sample j and construct the constraint z_j = ∑_{k=1}^{M} h_j^k W_k between z_j and the server's output (i.e., the logits), which enables the optimization of each W_k by ADMM.
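To make the role of z_j concrete, the sketch below runs scaled-form ADMM on a toy version of the server-side problem. It assumes frozen local embeddings h_j^k, replaces cross-entropy with the squared loss ½‖z_j − y_j‖², and keeps the constraint z_j = ∑_k h_j^k W_k. Up to terms constant in the primal variables, the augmented Lagrangian with scaled duals u_j is ∑_j ℓ(z_j, y_j) + (ρ/2) ∑_j ‖∑_k h_j^k W_k − z_j + u_j‖². For simplicity, the W-step solves all heads jointly by least squares. This illustrates the ADMM splitting under these assumptions; it is not VIMADMM's update rules, which also train the local models.

```python
import numpy as np


def admm_linear_heads(H, Y, rho=1.0, iters=200):
    """Toy scaled-form ADMM for min 0.5*||Z - Y||^2 s.t. Z = sum_k H[k] @ W[k].

    H: list of M arrays; H[k] (n x d_k) holds the frozen local embeddings h_j^k.
    Y: (n x C) targets standing in for labels (the squared loss is an assumption).
    Returns the per-client heads W_k, the auxiliary variables Z, and ||S - Z||.
    """
    if iters < 1:
        raise ValueError("iters must be positive")
    n, C = Y.shape
    H_cat = np.hstack(H)                              # (n, sum_k d_k)
    splits = np.cumsum([h.shape[1] for h in H])[:-1]
    Z = np.zeros((n, C))                              # auxiliary variables z_j (rows)
    U = np.zeros((n, C))                              # scaled dual variables u_j
    for _ in range(iters):
        # W-step: argmin_W (rho/2)||H_cat W - Z + U||^2, solved jointly for all heads.
        W_cat, *_ = np.linalg.lstsq(H_cat, Z - U, rcond=None)
        S = H_cat @ W_cat                             # server output sum_k h_j^k W_k
        # z-step: closed form of argmin_Z 0.5||Z - Y||^2 + (rho/2)||S - Z + U||^2.
        Z = (Y + rho * (S + U)) / (1.0 + rho)
        # Dual step on the constraint residual S - Z.
        U = U + S - Z
    W = np.split(W_cat, splits, axis=0)               # per-client heads W_k
    return W, Z, float(np.linalg.norm(S - Z))


rng = np.random.default_rng(0)
H = [rng.normal(size=(50, 3)), rng.normal(size=(50, 2))]  # two clients' embeddings
W_true = rng.normal(size=(5, 4))
Y = np.hstack(H) @ W_true
W, Z, residual = admm_linear_heads(H, Y)
print([w.shape for w in W], residual < 1e-8)
```

The consensus is enforced on the per-sample logits rather than on the embeddings, so each client's embedding can stay distinct while the heads W_k are still coupled through z_j.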

For the second challenge, we propose a bi-level optimization of the server's model and the clients' models to train DNNs for VFL with model splitting, whereas the existing ADMM-based method in VFL [35] only considers logistic regression with linear client-side models and does not apply to DNNs. Our initial attempt decomposed the optimization of the server's linear heads with ADMM while still updating the local models with SGD via the chain rule, which showed little advantage over pure SGD-based methods. We then decomposed the optimization of both the server's linear heads and the local models with ADMM, leading to our current algorithm, VIMADMM, which enables multiple local updates for clients in each communication round and achieves significantly better performance, as shown in Sec. VI-A.

Limitations

Directly deploying VFL algorithms without stopping criteria or regularization techniques may lead to over-fitting, as with many other algorithms. In our experiments, we find that over-fitting is a common problem of VFL algorithms because the whole VFL system contains a large number of model parameters from all clients. Compared with centralized learning or horizontal FL, the prediction for one data sample in VFL involves M times as many model parameters, corresponding to the M partitions of input features. To prevent over-fitting, we use regularizers to constrain model complexity and adopt standard stopping criteria, i.e., we stop training when the model converges or when the validation accuracy drops by more than 2%.
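A minimal sketch of the two stopping rules, under our own assumptions: the 2% drop is measured in absolute percentage points from the best validation accuracy so far, and "converged" means the accuracy changed by less than `tol` at each of the last `patience` evaluations. The function name and the `patience` and `tol` defaults are hypothetical.

```python
from typing import Sequence


def should_stop(val_accs: Sequence[float], max_drop: float = 2.0,
                patience: int = 3, tol: float = 1e-3) -> bool:
    """Hypothetical stopping rule: accuracy drop > max_drop points, or a converged plateau."""
    if not val_accs:
        return False
    # Rule 1: validation accuracy (in %) fell more than max_drop points below its best value.
    if max(val_accs) - val_accs[-1] > max_drop:
        return True
    # Rule 2: accuracy changed by less than tol at each of the last `patience` evaluations.
    recent = list(val_accs[-(patience + 1):])
    if len(recent) == patience + 1:
        return all(abs(b - a) < tol for a, b in zip(recent, recent[1:]))
    return False


print(should_stop([90.0, 92.0, 91.5]))  # False: 0.5-point drop, too few points for a plateau
print(should_stop([90.0, 92.0, 89.9]))  # True: 2.1-point drop from the best
```

Measuring the drop from the best value so far, rather than from the previous epoch, avoids missing a slow decline spread over several evaluations.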