
Model-Based Deep Learning


This article reviews leading strategies for designing systems whose operation combines
domain knowledge and data via model-based deep learning in a tutorial fashion.
By NIR SHLEZINGER, Member IEEE, JAY WHANG, YONINA C. ELDAR, Fellow IEEE,
and ALEXANDROS G. DIMAKIS, Fellow IEEE

Manuscript received 3 November 2021; revised 7 September 2022; accepted 13 February 2023. Date of publication 1 March 2023; date of current version 17 May 2023. (Corresponding author: Nir Shlezinger.)
Nir Shlezinger is with the School of Electrical and Computer Engineering (ECE), Ben-Gurion University of the Negev, Be’er Sheva 8410501, Israel (e-mail: nirshl@bgu.ac.il).
Jay Whang is with the Department of Computer Science (CS), The University of Texas at Austin, Austin, TX 78712 USA (e-mail: jaywhang@cs.utexas.edu).
Yonina C. Eldar is with the Faculty of Mathematics and Computer Science (CS), Weizmann Institute of Science, Rehovot 7632706, Israel (e-mail: yonina@weizmann.ac.il).
Alexandros G. Dimakis is with the Department of Electrical and Computer Engineering (ECE), The University of Texas at Austin, Austin, TX 78712 USA (e-mail: dimakis@austin.utexas.edu).
Digital Object Identifier 10.1109/JPROC.2023.3247480

ABSTRACT | Signal processing, communications, and control have traditionally relied on classical statistical modeling techniques. Such model-based methods utilize mathematical formulations that represent the underlying physics, prior information, and additional domain knowledge. Simple classical models are useful but sensitive to inaccuracies and may lead to poor performance when real systems display complex or dynamic behavior. On the other hand, purely data-driven approaches that are model-agnostic are becoming increasingly popular as datasets become abundant and the power of modern deep learning pipelines increases. Deep neural networks (DNNs) use generic architectures that learn to operate from data and demonstrate excellent performance, especially for supervised problems. However, DNNs typically require massive amounts of data and immense computational resources, limiting their applicability for some scenarios. In this article, we present the leading approaches for studying and designing model-based deep learning systems. These are methods that combine principled mathematical models with data-driven systems to benefit from the advantages of both approaches. Such model-based deep learning methods exploit both partial domain knowledge, via mathematical structures designed for specific problems, and learning from limited data. Among the applications detailed in our examples for model-based deep learning are compressed sensing, digital communications, and tracking in state-space models. Our aim is to facilitate the design and study of future systems at the intersection of signal processing and machine learning that incorporate the advantages of both domains.

KEYWORDS | Deep learning; model-based machine learning; signal processing.

I. INTRODUCTION

Traditional signal processing is dominated by algorithms that are based on simple mathematical models that are hand-designed from domain knowledge. Such knowledge can come from statistical models based on measurements and an understanding of the underlying physics or from the fixed deterministic representation of the particular problem at hand. These domain-knowledge-based processing algorithms, which we refer to henceforth as model-based methods, carry out inference based on knowledge of the underlying model relating the observations at hand and the desired information. Model-based methods do not rely on data to learn their mapping, though data are often used to estimate a small number of parameters. Fundamental techniques, such as the Kalman filter and message-passing algorithms, belong to the class of model-based methods. Classical statistical models rely on simplifying assumptions (e.g., linear systems, and Gaussian and independent noises) that make inference tractable, understandable, and computationally efficient. On the other hand, simple models frequently fail to represent nuances of high-dimensional complex data and dynamic variations.
The incredible success of deep learning, e.g., on vision [1], [2], and challenging games, such as Go [3] and Starcraft [4], has initiated a general data-driven mindset. It is currently prevalent to replace simple principled models with purely data-driven pipelines, trained with massive labeled datasets. In particular, deep neural networks (DNNs) can be trained in a supervised way end-to-end to map inputs to predictions. The benefits of data-driven methods over model-based approaches are twofold: First, purely data-driven techniques do not rely on analytical approximations and, thus, can operate in scenarios where analytical models are not known. Second, for complex

systems, data-driven algorithms are able to recover features from observed data that are needed to carry out inference [5]. This is sometimes difficult to achieve analytically, even when complex models are perfectly known.
Fig. 1. Division of model-based deep learning techniques into categories and subcategories.

The computational burden of training and utilizing highly parametrized DNNs, as well as the fact that massive datasets are typically required to train such DNNs to learn a desirable mapping, may constitute major drawbacks in various signal processing, communications, and control applications. This is particularly relevant for hardware-limited devices, such as mobile phones, unmanned aerial vehicles, and Internet-of-Things (IoT) systems, which are often limited in their ability to utilize highly parametrized DNNs [6] and require adapting to dynamic conditions. Furthermore, DNNs are commonly utilized as black boxes; understanding how their predictions are obtained and characterizing confidence intervals tends to be quite challenging. As a result, deep learning does not yet offer the interpretability, flexibility, versatility, and reliability of model-based methods [7].

The limitations associated with model-based methods and black-box deep learning systems gave rise to a multitude of techniques for combining signal processing and machine learning to benefit from both approaches. These methods are application-driven and are, thus, designed and studied in light of a specific task. For example, the combination of DNNs and model-based compressed sensing (CS) recovery algorithms was shown to facilitate sparse recovery [8], [9] and enable CS beyond the domain of sparse signals [10], [11]; deep learning was used to empower regularized optimization methods [12], [13], while model-based optimization contributed to the design of DNNs for such tasks [14]. Digital communication receivers used DNNs to learn to carry out and enhance symbol detection and decoding algorithms in a data-driven manner [15], [16], [17], while symbol recovery methods enabled the design of model-aware deep receivers [18], [19], [20], [21]. The proliferation of hybrid model-based/data-driven systems, each designed for a unique task, motivates establishing a concrete systematic framework for combining domain knowledge in the form of model-based methods and deep learning, which is the focus of this article.

In this article, we review leading strategies for designing systems whose operation combines domain knowledge and data via model-based deep learning in a tutorial fashion. To that aim, we present a unified framework for studying hybrid model-based/data-driven systems, without focusing on a specific application, while being geared toward families of problems typically studied in the signal processing literature. The proposed framework divides systems combining model-based signal processing and deep learning into two main strategies: The first category includes DNNs whose architecture is specialized to the specific problem using model-based methods, referred to here as model-aided networks. The second one, which we refer to as DNN-aided inference, consists of techniques in which inference is carried out by a model-based algorithm whose operation is augmented with deep learning tools. This integration of model-agnostic deep learning tools allows one to use model-based inference algorithms while having access only to partial domain knowledge. Based on this division, we provide concrete guidelines for studying, designing, and comparing model-based deep learning systems. An illustration of the proposed division into categories and subcategories is depicted in Fig. 1.

We begin by discussing the high-level concepts of model-based, data-driven, and hybrid schemes. Since we focus on DNNs as the current leading data-driven technique, we briefly review basic concepts in deep learning, ensuring that the tutorial is accessible to readers without a background in deep learning. We then elaborate on the fundamental strategies for combining model-based methods with deep learning. For each such strategy, we present a few concrete implementation approaches in a systematic manner, including established approaches, such as deep unfolding, which was originally proposed in 2010 by Gregor and LeCun [8], as well as recently proposed model-based deep learning paradigms, such as DNN-aided inference [22] and neural augmentation [23]. For each approach, we formulate system design guidelines for a given problem, provide detailed examples from the recent literature, and discuss its properties and use-cases. Each of our detailed examples focuses on a different application in signal processing, communications, and control, demonstrating the breadth and the wide variety of applications that can benefit from such hybrid designs. We conclude this article with a summary and a qualitative comparison of model-based deep learning approaches, along with a description of some future research topics and challenges. We aim to encourage future researchers and practitioners with a signal processing background to study and design model-based deep learning.

This overview article focuses on strategies for designing architectures whose operation combines deep learning with model-based methods, as illustrated in Fig. 1. These strategies can also be integrated into existing mechanisms for incorporating model-based domain knowledge in the selection of the task for which data-driven systems are


applied, as well as in the generation and manipulation of the data. An example of a family of such mechanisms for using model-based knowledge in the selection of the application and the data is the learning-to-optimize framework, which is the focus of growing attention in the context of wireless networks’ design [24], [25], [26]; this framework advocates the usage of pretrained DNNs for realizing fast solvers for complex optimization problems, which rely on objectives and constraints formulated based on domain knowledge, along with the usage of model-based generated data for off-line training. An additional related family is that of channel autoencoders, which integrates mathematical modeling of random communication channels as layers of deep autoencoders to design channel codes [27], [28] and compression mechanisms [29].

The rest of this article is organized as follows. Section II discusses the concepts of model-based methods compared to data-driven schemes and how they give rise to the model-based deep learning paradigm. Section III reviews some basics of deep learning. The main strategies for designing model-based deep learning systems, i.e., model-aided networks and DNN-aided inference, are detailed in Sections IV and V, respectively. Finally, we provide a summary and discuss some future research challenges in Section VI.

II. MODEL-BASED VERSUS DATA-DRIVEN INFERENCE

We begin by reviewing the main conceptual differences between model-based and data-driven inference. To that aim, we first present a mathematical formulation of a generic inference problem. Then, we discuss how this problem is tackled from a purely model-based perspective and from a purely data-driven one, where, for the latter, we focus on deep learning as a family of generic data-driven approaches. We then formulate the notion of model-based deep learning based on these distinct strategies.

A. Inference Systems

The term inference refers to the ability to conclude based on evidence and reasoning. While this generic definition can refer to a broad range of tasks, we focus, in our description, on systems that estimate or make predictions based on a set of observed variables. In this wide family of problems, the system is required to map an input variable x ∈ X into a prediction of a label variable s ∈ S, denoted ŝ, where X and S are referred to as the input space and the label space, respectively. An inference rule can, thus, be expressed as

f : X ↦ S    (1)

and the space of inference mappings is denoted by F. We use l(·) to denote a cost measure defined over F × X × S, dictated by the specific task [30, Ch. 2]. The fidelity of an inference mapping is measured by the risk function, also known as the generalization error, given by E_{x,s∼px,s}{l(f, x, s)}, where px,s is the underlying statistical model relating the input and the label. The goal of both model-based methods and data-driven schemes is to design the inference rule f(·) to minimize the risk for a given problem. The main difference between these strategies is what information is utilized to tune f(·).

B. Model-Based Methods

Model-based algorithms, also referred to as hand-designed methods [31], set their inference rule, i.e., tune f in (1) to minimize the risk function, based on domain knowledge. The term domain knowledge typically refers to prior knowledge of the underlying statistics relating the input x and the label s. In particular, an analytical mathematical expression describing the underlying model, i.e., px,s, is required. Model-based algorithms can provably implement the risk-minimizing inference mapping, e.g., the maximum a posteriori probability (MAP) rule. While computing the risk-minimizing rule is often computationally prohibitive, various model-based methods approximate this rule at controllable complexity and, in some cases, also provably approach its performance. This is typically achieved using iterative methods comprised of multiple stages, where each stage involves generic mathematical manipulations and model-specific computations.

Model-based methods do not rely on data to learn their mapping, as illustrated in the right part of Fig. 2, though data are often used to estimate unknown model parameters. In practice, accurate knowledge of the statistical model relating the observations and the desired information is typically unavailable, and thus, applying such techniques commonly requires imposing some assumptions on the underlying statistics, which, in some cases, reflects the actual behavior, but may also constitute a crude approximation of the true dynamics. In the presence of inaccurate model knowledge, either as a result of estimation errors or due to enforcing a model, which does not fully capture the environment, the performance of model-based techniques tends to degrade. This limits the applicability of model-based schemes in scenarios where, e.g., px,s is unknown, costly to estimate accurately, or too complex to express analytically.

C. Data-Driven Schemes

Data-driven systems learn their mapping from data. In a supervised setting, data are comprised of a training set consisting of nt pairs of inputs and their corresponding labels, denoted {(xt, st)}_{t=1}^{nt}. Data-driven schemes do not have access to the underlying distribution and, thus, cannot compute the risk function. As a result, the inference mapping is typically tuned based on an empirical risk function, referred to henceforth as loss function, which, for


Fig. 2. Illustration of model-based versus data-driven inference. The red arrows correspond to the computation performed before the
particular inference data are received.

an inference mapping f, is given by

L(f ) = (1/nt ) Σ_{t=1}^{nt} l(f, xt , st ).    (2)

Since one can usually form an inference rule, which minimizes the empirical loss (2) by memorizing the data, i.e., overfit, data-driven schemes often restrict the domain of feasible inference rules [30, Ch. 2]. A leading strategy in data-driven systems, upon which deep learning is based, is to assume some highly expressive generic parametric model on the mapping in (1) while incorporating optimization mechanisms to avoid overfitting and allow the resulting system to infer reliably with new data samples. In such cases, the inference rule is dictated by a set of parameters denoted θ, and thus, the system mapping is written as fθ.

The conventional application of deep learning implements fθ using a DNN architecture, where θ represents the weights of the network. Such highly parametrized networks can effectively approximate any Borel measurable mapping as follows from the universal approximation theorem [32, Ch. 6.4.1]. Therefore, by properly tuning their parameters using sufficient training data, as we elaborate on in Section III, one should be able to obtain the desirable inference rule.

Unlike model-based algorithms, which are specifically tailored to a given scenario, purely data-driven methods are model-agnostic, as illustrated in the left part of Fig. 2. The unique characteristics of the specific scenario are encapsulated in the learned weights. The parametrized inference rule, e.g., the DNN mapping, is generic and can be applied to a broad range of different problems. While standard DNN structures are highly model-agnostic and are commonly treated as black boxes, one can still incorporate some level of domain knowledge in the selection of the specific network architecture. For instance, when the input is known to exhibit temporal correlation, architectures based on recurrent neural networks (RNNs) [33] or attention mechanisms [34] are often preferred. Alternatively, in the presence of spatial patterns, one may utilize convolutional layers [35]. An additional method to incorporate domain knowledge into a black box DNN is by preprocessing the input via, e.g., feature extraction.

The generic nature of data-driven strategies induces some drawbacks. Broadly speaking, learning a large number of parameters requires a massive dataset to train on. Even when a sufficiently large dataset is available, the resulting training procedure is typically lengthy and involves a high computational burden. Finally, the black-box nature of the resulting mapping implies that data-driven systems in general lack interpretability, making it difficult to provide performance guarantees and insights into the system operation.

D. Model-Based Deep Learning

Completely separating existing literature into model-based versus data-driven is a subjective and debatable task. Instead, we focus on some approaches that clearly lie in the middle ground to give a useful overview of the landscape. The considered families of methods incorporate domain knowledge in the form of an established model-based algorithm, which is suitable for the problem at hand, while combining capabilities to learn from data via deep learning techniques.

Model-based deep learning schemes tune their mapping of the input x based on both data, e.g., a labeled training set {(xt, st)}_{t=1}^{nt}, as well as some domain knowledge, such as partial knowledge of the underlying distribution px,s. Such hybrid data-driven model-aware systems can typically learn their mappings from smaller training sets compared to purely model-agnostic DNNs and commonly operate without full accurate knowledge of the underlying model upon which model-based methods are based.

Most existing techniques for implementing inference rules in a hybrid model-based/data-driven fashion are


designed for a specific application, i.e., to solve a given function will not overfit and be able to generalize, i.e., infer
problem rather than formulate a systematic methodol- reliably from new data samples. Since the optimization in
ogy. Nonetheless, one can identify a common rationale (3) is carried out over θ , we write the loss as L(θ) for
for categorizing existing schemes in a systematic manner brevity.
that is not tailored to a specific scenario. In particular, The above formulation naturally gives rise to three
model-based deep learning techniques can be divided into fundamental components of deep learning: the DNN archi-
two main strategies, as illustrated in Fig. 2. These strate- tecture that defines the function class F , the task-specific
gies may each be further specialized to various different loss function L(θ), and the optimizer that dictates how
tasks, as we show in the sequel. The first of the two, to search for the optimal fθ within F . Therefore, our
which we refer to as model-aided networks, utilizes DNNs review of the basics of deep learning commences with a
for inference; however, rather than using conventional description of the fundamental architecture and optimizer
DNN architectures, here, a specific DNN tailored for the components in Section III-A. We then present several repre-
problem at hand is designed by following the operation sentative tasks along with their corresponding typical loss
of suitable model-based methods. The second strategy, functions in Section III-B.
which we call DNN-aided inference systems, uses conven-
tional model-based methods for inference; however, unlike A. Deep Learning Preliminaries
purely model-based schemes, here, specific parts of the
The formulation of the parametric empirical risk in (3)
model-based algorithm are augmented with deep learning
is not unique to deep leaning and is in fact common
tools, allowing the resulting system to implement the algo-
to numerous machine learning schemes. The strength of
rithm while learning to overcome partial or mismatched
deep learning, i.e., its ability to learn accurate complex
domain knowledge from data.
mappings from large datasets, is due to its use of DNNs
The systematic categorization of model-based deep
to enable a highly expressive family of function classes F ,
learning methodologies can facilitate the study and design
along with dedicated optimization algorithms for tuning
of future techniques in different and diverse application
the parameters from data. In the following, we discuss the
areas. One may also propose schemes that combine aspects
high-level notion of DNNs, followed by a description of
from both categories, building upon the understanding of
how they are optimized.
the characteristics and gains of each approach, discussed
in the sequel. Since both strategies rely on deep learning 1) Neural Network Architecture: DNNs implement para-
tools, we first provide a brief overview of key concepts metric functions comprised of a sequence of differentiable
in deep learning in Section III, after which we elaborate transformations called layers, whose composition maps the
on model-aided networks and DNN-aided inference in input to a desired output. Specifically, a DNN fθ consisting
Sections IV and V, respectively. of k layers {h1 , . . . , hk } maps the input x to the output
ŝ = fθ (x) = hk ◦ · · · ◦ h1 (x), where ◦ denotes the function
III. B A S I C S O F D E E P L E A R N I N G composition. Since each layer hi is itself a parametric
Here, we cover the basics of deep learning required to function, the parameters’ set of the entire network fθ is the
understand the DNN-based components in the model- union of all of its layers’ parameters, and thus, fθ denotes
based/data-driven approaches discussed later. Our aim is a DNN with parameters θ . The architecture of a DNN refers
to equip the reader with the necessary foundations upon to the specification of its layers {hi }ki=1 .
which our formulations of model-based deep learning sys- A generic formulation that captures various
tems are presented. parametrized layers is that of an affine transformation,
As discussed in Section II-C, in deep learning, the target i.e., h(x) = W x + b whose parameters are {W , b}.
mapping is constrained to take the form of a parametrized For instance, in fully connected (FC) layers, also referred
function fθ : X → S . In particular, the inference mapping to as dense layers, one can optimize {W , b} to take
belongs to a fixed family of functions F specified by a any value. Another extremely common affine transform
predefined DNN architecture, which is represented by a layer is convolutional layers. Such layers apply a set of
specific choice of the parameter vector θ . Once the function discrete convolutional kernels to signals that are possibly
class F and the loss function L are defined, where the comprised of multiple channels, e.g., tensors. The vector
latter is dictated by the training data (2) while possibly representation of their output can be written as an affine
including some regularization on θ , one may attempt to mapping of the form W x + b, where x is the vectorization
find the function, which minimizes the loss within F , i.e., of the input, and W is constrained to represent multiple
channels of discrete convolutions [32, Ch. 9]. These
θ∗ = arg min_{fθ ∈F} L(fθ ).    (3)
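To make the three ingredients concrete, the following is a minimal PyTorch sketch of (3): a small multilayer network fixes the function class F, an empirical MSE loss plays the role of L(θ), and plain gradient descent searches over the parameters θ. The data, layer sizes, and step size are illustrative assumptions rather than choices made in the article.

```python
# Minimal sketch of (3): a parametric family F (a two-layer network), an empirical
# loss L(theta), and plain gradient descent on theta. All sizes are placeholders.
import torch
from torch import nn
import torch.nn.functional as F

x = torch.randn(200, 8)                       # training inputs x_t
s = torch.randn(200, 1)                       # training labels s_t

f_theta = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
eta = 0.05                                    # step size

for q in range(200):
    loss = F.mse_loss(f_theta(x), s)          # L(theta) evaluated on the training data
    grads = torch.autograd.grad(loss, f_theta.parameters())
    with torch.no_grad():                     # update each entry of theta by a gradient step
        for p, g in zip(f_theta.parameters(), grads):
            p -= eta * g
```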
important invariances such as translational invariance in
A common challenge in optimizing based on (3) is to image data.
guarantee that the inference mapping learned using the While many commonly used layers are affine, DNNs rely
data-based loss function rather than the model-based risk on the inclusion of nonlinear layers. If all the layers of a


DNN were affine, the composition of all such layers would As discussed in Section II-C, the goal is to recover a
also be affine, and thus, the resulting network would only mapping fθ that minimizes the risk function, i.e., the
represent affine functions. For this reason, layers in a DNN generalization error. This is done by optimizing the DNN
are interleaved with activation functions, which are simple mapping fθ using the data-based empirical loss function
nonlinear functions applied to each dimension of the input L(θ) (2). This setting encompasses a wide range of prob-
separately. Activations are often fixed, i.e., their mapping lems, including regression, classification, and structured
is not parametric and is thus not optimized in the learning prediction, through a judicious choice of the loss function.
process. Some notable examples of widely used activation In the following, we review commonly used loss functions
functions include the rectified linear unit (ReLU) defined for classification and regression tasks.
as ReLU(x) = max{x, 0} and the sigmoid σ(x) = (1 +
a) Classification: Perhaps, one of the most widely
exp(−x))−1 .
known success stories of DNNs, classification (image clas-
2) Choice of Optimizer: Given a DNN architecture and sification in particular), has remained a core benchmark
a loss function L(θ), finding a globally optimal θ that since the introduction of AlexNet [38]. In this setting,
minimizes L is a hopelessly intractable task, especially at we are given a training set {(xt , st )}n t=1 containing input-
t

the scale of millions of parameters or more. Fortunately, label pairs, where each xt is a fixed-size input, e.g.,
the recent success of deep learning has demonstrated that an image, and st is the one-hot encoding of the class. Such
gradient-based optimization methods work surprisingly one-hot encoding of class c can be viewed as a probability
well despite their inability to find global optima. The vector for a K -way categorical distribution, with K = |S|,
simplest such method is gradient descent, which iteratively with all probability mass placed on class c.
updates the parameters The DNN mapping fθ for this task is appropriately
designed to map an input xt to the probability vector
θ q+1 = θ q − ηq ∇θ L(θ q ) (4) ŝt ≜ f (xt ) = ⟨ŝt,1 , . . . , ŝt,K ⟩, where ŝt,c denotes the
cth component of ŝt . This parametrization allows for the
model to return a soft decision in the form of a categorical
where ηq is the step size that may change as a function of
distribution over the classes.
the step count q . Since the gradient ∇θ L(θ q ) is often too
A natural choice of loss function for this setting is the
costly to compute over the entire training data, it is esti-
cross-entropy loss, defined as
mated from a small number of randomly chosen samples
(i.e., a minibatch). The resulting optimization method is nt K
called minibatch stochastic gradient descent and belongs to 1 XX
LCE (θ) = (1/nt ) Σ_{t=1}^{nt} Σ_{c=1}^{K} st,c (− log ŝt,c ).    (5)
the family of stochastic first-order optimizers. nt
t=1 c=1
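As a small illustration of (5), the following hedged sketch evaluates the cross-entropy loss directly from one-hot labels and softmax outputs and compares it with PyTorch's built-in cross-entropy applied to the raw logits; the class count and logit values are made-up for this example.

```python
# Numeric check of the cross-entropy loss (5), assuming K classes and one-hot labels.
import torch
import torch.nn.functional as F

K, n_t = 3, 4
logits = torch.randn(n_t, K)                      # raw DNN outputs before softmax
s_hat = F.softmax(logits, dim=1)                  # soft decision: categorical probabilities
classes = torch.tensor([0, 2, 1, 2])              # ground-truth class indices
s = F.one_hot(classes, num_classes=K).float()     # one-hot encodings s_t

loss_manual = -(s * torch.log(s_hat)).sum(dim=1).mean()   # direct evaluation of (5)
loss_builtin = F.cross_entropy(logits, classes)           # softmax + (5) in one call
print(loss_manual.item(), loss_builtin.item())            # the two values coincide
```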
Stochastic first-order optimization techniques are
well-suited for training DNNs because their memory usage For a sufficiently large set of independent identically dis-
grows only linearly with the number of parameters, and tributed (i.i.d.) training pairs, the empirical cross-entropy
they avoid the need to process the entire training data loss approaches the expected cross-entropy measure,
at each step of optimization. Over the years, numerous which is minimized when the DNN output matches the true
variations of stochastic gradient descent have been pro- conditional distribution ps|x . Consequently, minimizing the
posed. Many modern optimizers, such as RMSProp [36] cross-entropy loss encourages the DNN output to match
and Adam [37], use statistics from previous parameter the ground-truth label, and its mapping closely approaches
updates to adaptively adjust the step size for each param- the true underlying posterior distribution when properly
eter separately (i.e., for each dimension of θ ). trained.
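The following is a minimal sketch of the minibatch training loop described above, assuming a synthetic regression dataset; replacing torch.optim.Adam with torch.optim.SGD or RMSprop changes only how the per-parameter step size is adapted, not the structure of the loop.

```python
# Hedged sketch of minibatch first-order training; dataset, batch size, and model
# are placeholders and not taken from the article.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

x = torch.randn(512, 10)
s = torch.randn(512, 1)
loader = DataLoader(TensorDataset(x, s), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # adaptive per-dimension steps
loss_fn = nn.MSELoss()

for epoch in range(5):
    for x_batch, s_batch in loader:               # gradient estimated from a minibatch
        optimizer.zero_grad()
        loss_fn(model(x_batch), s_batch).backward()
        optimizer.step()
```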
The formulation of the cross-entropy loss (5) implicitly
B. Common Deep Learning Tasks assumes that the DNN returns a valid probability vector,
PK
As detailed above, the data-driven nature of deep learn- i.e., ŝt,c ≥ 0 and c=1 ŝt,c = 1. However, there is no
ing is encapsulated in the dependence of the loss func- guarantee that this will be the case, especially at the begin-
tion on the training data. Thus, the loss function not ning of training when the parameters of the DNN are more
only implicitly defines the task of the resulting system or less randomly initialized. To guarantee that the DNN
but also dictates what kind of data is required. Based mapping yields a valid probability distribution, classifiers
on the requirements placed on the training data, prob- typically employ the softmax function (e.g., on top of the
lems in deep learning largely fall under three different output layer), given by
categories: supervised, semisupervised, and unsupervised. * +
Here, we define each category and list some example tasks exp(x1 ) exp(xd )
Softmax(x) = , . . . , Pd
and their typical loss functions. Pd
exp(xi ) exp(xi )
i=1 i=1
1) Supervised Learning: In supervised learning, the
training data consist of a set of input-label pairs where xi is the ith entry of x. Due to the exponentiation
{(xt , st )}n
t=1 , where each pair takes values in X × S .
t
followed by normalization, the output of the softmax


function is guaranteed to be a valid probability vector. For simplicity, in the following, we describe the original
In practice, one can compute the softmax function of the GAN formulation of [39]. Here, Dφ : X → [0, 1] is a binary
network outputs when evaluating the loss function, rather classifier trained to distinguish real examples xt from the
than using a dedicated output layer. fake examples generated by Gθ , and the GAN loss function
b) Regression: Another task where DNNs have been is the minmax loss.
successfully applied is regression, where one attempts to The loss is optimized in an alternating fashion by tun-
predict continuous variables instead of categorical ones. ning the discriminator φ to minimize LD (·) for a given
Here, the labels {st } in the training data represent some generator θ , followed by a corresponding optimization of
continuous value, e.g., in R or some specified range [a, b]. the generator based on its loss LG (·). These loss measures
Similar to the usage of softmax layers for classifica- are given by
tion, an appropriate final activation function σ is needed,
nt
depending on the range of the variable of interest. For −1 X 
LD (φ|θ) = log Dφ (xt ) + log 1−Dφ Gθ (z t )
example, when regressing on a strictly positive value, 2nt
t=1
a common choice is σ(x) = exp(x) or the softplus activa- nt
−1 X
tion σ(x) = log(1+exp(x)) so that the range of the network

LG (θ|φ) = log log Dφ Gθ (z t ) .
nt
fθ is constrained to be the positive reals. When the output t=1

is to be limited to an interval [a, b], then one may use the


mapping σ(x) = a + (b − a)(1 + tanh(x))/2. Here, the latent variables {z t } are drawn from its known
Arguably, the most common loss function for regression prior distribution for each minibatch.
tasks is the empirical mean square error (MSE), i.e., Among currently available deep generative models,
GANs achieve the best sample quality at an unprece-
1 X
nt dented resolution. For example, the current state-of-the-art
LMSE (θ) = (1/nt ) Σ_{t=1}^{nt} (st − ŝt )².    (6)
nt
t=1
(1024 × 1024) images that are nearly indistinguishable
from real photographs to a human observer. That said,
2) Unsupervised Learning: In unsupervised learning, GANs do come with several disadvantages as well. The
we are only given a set of examples {xt }n t=1 without
t
adversarial training procedure is known to be unstable,
labels. Since there is no label to predict, unsupervised and many tricks are necessary for practice to train a large
learning algorithms are often used to discover interesting GAN. Also, because GANs do not offer any probabilistic
patterns present in the given data. Common tasks in this interpretation, it is difficult to objectively evaluate the
setting include clustering, anomaly detection, generative quality of a GAN.
modeling, and compression.
a) Generative models: One goal in unsupervised learn- b) Autoencoders: Another well-studied task in unsu-
ing of a generative model is to train a generator network pervised learning is the training of an autoencoder, which
Gθ (z) such that the latent variables z , which follow a sim- has many uses such as dimensionality reduction and repre-
ple distribution, such as standard Gaussian, are mapped sentation learning. An autoencoder consists of two neural
into samples obeying a distribution similar to that of the networks: an encoder fenc : X 7→ Z and a decoder fdec :
training data [32, Ch. 20]. For instance, one can train Z 7→ X , where Z is some predefined latent space. The
a generative model to map Gaussian vectors into images primary goal of an autoencoder is to reconstruct a signal x
of human faces. A popular type of DNN-based generative from itself by mapping it through fdec ◦ fenc .
model that tries to achieve this goal is generative adver- The task of autoencoding may seem pointless at first;
sarial network (GAN) [39], which has shown remarkable indeed, one can trivially recover x by setting Z = X
success in many domains. and fenc , fdec to be identity functions. The interesting case
GANs learn the generative model by employing a dis- is when one imposes constraints that limit the ability of
criminator network Dφ to assess the generated samples, the network to learn the identity mapping [32, Ch. 14].
thus avoiding the need to mathematically handcraft a loss One way to achieve this is to form an undercomplete
measure quantifying their quality. The parameters {θ, φ} autoencoder, where the latent space Z is restricted to be
of the two networks are learned via adversarial training, lower dimensional than X , e.g., X = Rn and Z = Rm
where θ and φ are updated in an alternating manner. The for some m < n. This constraint forces the encoder to
two networks Gθ and Dφ “compete” against each other to map its input into a more compact representation while
achieve opposite goals: Gθ tries to fool the discriminator, retaining enough information so that the reconstruction is
whereas Dφ tries to reliably distinguish real examples from as close to the original input as possible. Additional mech-
the fake ones made by the generator. anisms for preventing an autoencoder from learning the
Various methods have been proposed to train generative identity mapping include imposing a regularizing term on
models in this adversarial fashion, including, e.g., the the latent representation, as done in sparse autoencoders
Wasserstein GAN [40], [41], the least-squares GAN [42], and contractive autoencoders, or alternatively, by distort-
the Hinge GAN [43], and the relativistic average GAN [44]. ing the input to the system, as carried out by denoising


autoencoders [32, Ch. 14.2]. A common metric used to diagram algorithmic representation, which neural building
measure the quality of reconstruction is the MSE loss. blocks rely upon, as presented in Section IV-B. The dedi-
Under this setting, we obtain the following loss function cated neural network is then formulated as a discrimina-
for training: tive architecture [54], [55] whose trainable parameters,
intermediate mathematical manipulations, and intercon-
1 X
nt nections follow the operations of the model-based algo-
LMSE (fenc , fdec ) = (1/nt ) Σ_{t=1}^{nt} ∥xt − fdec (fenc (xt ))∥₂².    (7)
nt
t=1
In the following, we describe these methodologies in a
systematic manner. In particular, our presentation of each
3) Semisupervised Learning: Semisupervised learning approach commences with a high-level description and
lies in the middle ground between the above two cate- generic design outline, followed by one or two concrete
gories, where one typically has access to a large amount of model-based deep learning examples from the literature,
unlabeled data and a small set of labeled data. The goal is and concludes with a summarizing discussion. For each
to leverage the unlabeled data to improve performance on example, we first detail the system model and model-based
some supervised tasks to be trained on the labeled data. algorithm from which it originates. Then, we describe
As labeling data is often a very costly process, semisu- the hybrid model-based/data-driven system by detailing
pervised learning provides a way to quickly learn desired its architecture and training procedure, and present some
inference rules without having to label all of the available representative quantitative results.
unlabeled data points.
Various approaches have been proposed in the literature
A. Deep Unfolding
to utilize unlabeled data for a supervised task; see the
detailed survey [46]. One such common technique is to Deep unfolding [56], also referred to as deep unrolling,
guess the missing labels while integrating dedicated mech- converts an iterative algorithm into a DNN by designing
anisms to boost confidence [47]. This can be achieved by, each layer to resemble a single iteration. Deep unfolding
e.g., applying the DNN to various augmentations of the was originally proposed by Greger and LeCun [8], where
unlabeled data [48] while combining multiple regulariza- a deep architecture was designed to learn to carry out
tion terms for encouraging consistency and low-entropy the iterative soft thresholding algorithm (ISTA) for sparse
of the guessed labels [49], as well as training a teacher recovery. Deep unfolded networks have since been applied
DNN using the available labeled data to produce guessed in various applications in image denoising [57], [58],
labels [50]. sparse recovery [9], [31], [59], dictionary learning [51],
[60], communications [18], [19], [61], [62], [63], [64],
ultrasound [65], and superresolution [66], [67], [68].
IV. M O D E L - A I D E D N E T W O R K S A recent review can be found in [7].
Model-aided networks implement model-based deep
learning by using model-aware algorithms to design deep 1) Design Outline: The application of deep unfolding
architectures. Broadly speaking, model-aided networks to design a model-aided deep network is based on the
implement the inference system using a DNN, similar following steps.
to conventional deep learning. Nonetheless, instead of 1) Identify an iterative optimization algorithm that is
applying generic off-the-shelf DNNs, the rationale here useful for the problem at hand. For instance, recov-
is to tailor the architecture specifically for the scenario ering a sparse vector from its noisy projections can be
of interest, based on a suitable model-based method. tackled using ISTA, unfolded into LISTA in [8].
By converting a model-based algorithm into a model-aided 2) Fix a number of iterations in the optimization algo-
network, which learns its mapping from data, one typically rithm.
achieves improved inference speed, as well as overcome 3) Design the layers to imitate the free parameters of
partial or mismatched domain knowledge. In particular, each iteration in a trainable fashion.
model-aided networks can learn missing model parame- 4) Train the overall resulting network end-to-end.
ters, such as channel matrices [19], dictionaries [51], and The selection of the free parameters to learn in the third
noise covariances [52], as part of the learning procedure. step determines the resulting trainable architecture. One
Alternatively, it can be used to learn a surrogate model can set these parameters to be the hyperparameters of the
for which the resulting inference rule best matches the iterative optimizer (such as step size), thus leveraging data
training data [53]. to automatically determine parameters typically selected
Model-aided networks obtain dedicated DNN architec- by hand [53]. Alternatively, the architecture may be
tures by identifying structures in a model-based algorithm designed to learn the parameters of the objective optimized
that one would have utilized for the problem given full in each iteration, thus achieving a more abstract family
domain knowledge and sufficient computational resources. of inference rules compared with the original iterative
Such structures can be given in the form of an iterative algorithm, or even convert the operation of each iteration
representation of the model-based algorithm, as exploited into a trainable neural architecture. We next demonstrate
by deep unfolding detailed in Section IV-A, or via a block how this rationale is translated into concrete architectures,


Fig. 3. Model-aided DNN illustration: (a) model-based algorithm comprised of a series of model-aware computations and generic
mathematical steps and (b) DNN whose architecture and interconnections are designed based on the model-based algorithm. Here, data can
be used to train the overall architecture end-to-end, typically requiring the intermediate mathematical steps to be either differentiable or
well-approximated by a differentiable mapping.

using two examples: the first is the DetNet system of [18] N × 1 observations x, which are related via
that unfolds projected gradient descent optimization; the
second is the unfolded dictionary learning for Poisson x = Hs + w. (8)
image denoising proposed in [51].

where H is a known deterministic N × K channel matrix


2) Example 1: Deep Unfolded Projected Gradient Descent: and w consists of N i.i.d. Gaussian random variables (RVs).
Projected gradient descent is a simple and common itera- For our presentation, we consider the case in which the
tive algorithm for tackling constrained optimization. While entries of s are symbols generated from a binary phase
the projected gradient descent method is quite generic shift keying (BPSK) constellation in a uniform i.i.d. man-
and can be applied in a broad range of constrained ner, i.e., S = {±1}K . In this case, the MAP rule given
optimization setups, in the following, we focus on its an observation x becomes the minimum distance estimate,
implementation for symbol detection in linear memoryless given by
multiple-input–multiple-output (MIMO) Gaussian chan-
ŝ = arg min ∥x − Hs∥2 . (9)
nels. In such cases, where the constraint follows from s∈{±1}K
the discrete nature of digital communication symbols, the
iterative projected gradient descent gives rise to the DetNet b) Projected gradient descent: While directly solving
architecture proposed in [18] via deep unfolding. (9) involves an exhaustive search over the 2K possible
symbol combinations, it can be tackled with affordable
a) System model: Consider the problem of symbol computational complexity using the iterative projected
detection in linear memoryless MIMO Gaussian channels. gradient descent algorithm. This method, whose derivation
The task is to recover the K -dimensional vector s from the is detailed in Appendix VI-A, is summarized as Algorithm 1,


where PS (·) denotes the projection operator into S , which, layers, given by
for BPSK constellations, is the elementwise sign function.
nt Q
1 XX
L(θ) = (1/nt ) Σ_{t=1}^{nt} Σ_{q=1}^{Q} log(q) ∥st − ŝq (xt ; θ)∥²    (11)
Algorithm 1 Projected Gradient Descent for Sys- nt
t=1 q=1
tem Model (8)
Init: Fix step-size η > 0. Set initial guess ŝ0 where ŝq (xt ; θ) is the output of the q th layer of Det-
1 for q = 0, 1, . . . do Net with parameters θ and input xt . This loss measure
2 Update   accounts for the interpretable nature of the unfolded net-
ŝq+1 = PS ŝq − ηH T x + ηH T Hŝq . work, in which the output of each layer is a further refined
estimate of s.
3 end
Quantitative Results: The experiments reported in [18]
Output: Estimate ŝ = sq .
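For illustration, the following NumPy sketch runs Algorithm 1 for the model in (8) with BPSK symbols, using the elementwise sign as the projection PS(·); the dimensions, noise level, step size, and iteration count are arbitrary choices for this example and are not taken from [18].

```python
# Hedged NumPy sketch of projected gradient descent (Algorithm 1) for BPSK MIMO detection.
import numpy as np

rng = np.random.default_rng(0)
K, N = 16, 32
H = rng.standard_normal((N, K))                   # known channel matrix
s_true = rng.choice([-1.0, 1.0], size=K)          # BPSK symbol vector
x = H @ s_true + 0.1 * rng.standard_normal(N)     # noisy observation, as in (8)

eta = 1.0 / np.linalg.norm(H.T @ H, 2)            # step size from the Lipschitz constant
s_hat = np.zeros(K)                               # initial guess
for q in range(50):
    grad = H.T @ (H @ s_hat - x)                  # gradient of ||x - H s||^2 / 2
    s_hat = np.sign(s_hat - eta * grad)           # gradient step followed by projection P_S
    s_hat[s_hat == 0] = 1.0                       # break ties arbitrarily

print(np.mean(s_hat != s_true))                   # symbol error rate on this realization
```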
indicate that, when provided sufficient training examples,
DetNet outperforms leading MIMO detection algorithms
based on approximate message passing and semidefinite
c) Unfolded DetNet: DetNet unfolds the projected gra-
relaxation. It is also noted in [18] that the unfolded
dient descent iterations, repeated until convergence in
network requires an order of magnitude fewer layers
Algorithm 1, into a DNN, which learns to carry out this
compared to the number of iterations required by the
optimization procedure from data. To formulate DetNet,
model-based optimizer to converge. This gain is shown
we first fix a number of iterations Q. Next, a DNN with
to be translated into reduced run time during inference,
Q layers is designed, where each layer imitates a single
particularly when processing batches of data in parallel.
iteration of Algorithm 1 in a trainable manner.
In particular, it is reported in [18, Tbl. 1] that DetNet suc-
Architecture: DetNet builds upon the observation that
cessfully detects a batch of 1000 channel outputs in a 60 ×
the update rule in Step 2 of Algorithm 1 consists of two
30 static MIMO channel at run time, which is three times
stages: gradient descent computation, i.e., gradient step
faster than that required by approximate message passing
ŝq − ηH T x + ηH T Hŝq ; and projection, namely, applying
to converge, and over 80 times faster than semidefinite
PS (·). Therefore, each unfolded iteration is represented
relaxation.
as two sublayers. The first sublayer learns to compute the
gradient descent stage by treating the step size as a learned 3) Example 2: Deep Unfolded Dictionary Learning: DetNet
parameter and applying an FC layer with ReLU activation exemplifies how deep unfolding can be used to realize
to the obtained value. For iteration index q , this results in rapid implementations of exhaustive optimization algo-
rithms that typically require a very large amount of iter-
ations to converge. However, DetNet requires full domain
   
z q = ReLU W 1,q (I +δ2,q H T H)ŝq−1 −δ1,q H T x +b1,q
knowledge, i.e., it assumes that the system model fol-
lows (8), and the channel parameters H are known.
in which {W 1,q , b1,q , δ1,q , δ2,q } are learnable parameters. An additional benefit of deep unfolding is its ability to learn
The second sublayer learns the projection operator by missing model parameters along with the overall optimiza-
approximating the sign operation with a soft sign activa- tion procedure, as we illustrate in the following example
tion preceded by an FC layer, leading to proposed in [51], which focuses on dictionary learning for
Poisson image denoising. Similar examples where channel
ŝq = soft sign (W 2,q z q + b2,q ) . (10) knowledge is not required in deep unfolding can be found
in, e.g., [19], [57], and [64].
a) System model: Consider the problem of recon-
Here, the learnable parameters are {W 2,q , b2,q }. The
structing an image µ ∈ RN from its noisy measure-
resulting deep network is depicted in Fig. 4, in which
ments x ∈ RN . The image is corrupted by Poisson noise,
the output after Q iterations, denoted ŝQ , is used as
namely, px|µ is a multivariate Poisson distribution with
the estimated symbol vector by taking the sign of each
mutually independent entries and mean µ. Furthermore,
element.
it is assumed that, for the clean image µ, it holds that
Training: Let θ = {(W 1,q , W 2,q , b1,q , b2,q , δ1,q , δ2,q )}Q
q=1
log(µ) (taken elementwise) can be written as
be the trainable parameters of DetNet.1 To tune θ , the
overall network is trained end-to-end to minimize the
empirical weighted ℓ2 norm loss over its intermediate log(µ) = Hs. (12)

1 The formulation of DetNet in [18] includes an additional sublayer in In (12), H , referred to as the dictionary, is an unknown
each iteration intended to further lift its input into higher dimensions and block-Toeplitz matrix (representing a convolutional dictio-
introduce additional trainable parameters, as well as reweighing of the nary), while s is an unknown sparse vector.
outputs of subsequent layers. As these operations do not follow directly
from the unfolding projected gradient descent, they are not included in b) Proximal gradient mapping: The recovery of the
the description here. clean image µ is tackled by alternating optimization [69].


Fig. 4. DetNet illustration. Parameters in red fonts are learned in training, while those in blue fonts are externally provided.
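A hedged PyTorch sketch of a DetNet-style layer is given below, mirroring the two sublayers described above (a gradient step with learned step sizes followed by an FC layer and a soft sign); it omits the lifting and reweighing components of the full architecture in [18], and the widths and depth are placeholders.

```python
# Simplified DetNet-style unfolded network: each layer imitates one projected
# gradient descent iteration with learned step sizes delta1, delta2.
import torch
from torch import nn
import torch.nn.functional as F

class DetNetLayer(nn.Module):
    def __init__(self, K, hidden=64):
        super().__init__()
        self.delta1 = nn.Parameter(torch.tensor(1e-2))   # learned step sizes
        self.delta2 = nn.Parameter(torch.tensor(1e-2))
        self.fc1 = nn.Linear(K, hidden)                   # W_{1,q}, b_{1,q}
        self.fc2 = nn.Linear(hidden, K)                   # W_{2,q}, b_{2,q}

    def forward(self, s_prev, Hx, HH):
        # gradient-step sublayer: s_prev - delta1 * H^T x + delta2 * H^T H s_prev
        grad_step = s_prev - self.delta1 * Hx + self.delta2 * (s_prev @ HH)
        z = torch.relu(self.fc1(grad_step))
        return F.softsign(self.fc2(z))                    # soft sign projection sublayer

class DetNet(nn.Module):
    def __init__(self, K, num_layers=10):
        super().__init__()
        self.layers = nn.ModuleList(DetNetLayer(K) for _ in range(num_layers))

    def forward(self, x, H):
        Hx, HH = x @ H, H.t() @ H                         # precomputed model-based terms
        s = torch.zeros(x.shape[0], H.shape[1])
        outputs = []
        for layer in self.layers:                         # one layer per unfolded iteration
            s = layer(s, Hx, HH)
            outputs.append(s)                             # kept for a layer-wise loss as in (11)
        return outputs
```

The per-layer outputs are retained so that a layer-wise loss such as (11) can be applied during training, and a hard symbol decision is obtained by taking the sign of the final output.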

In each iteration, one first recovers s for a fixed H , after (DCEA) architecture proposed in [51] unfolds the proximal
which s is set to be fixed and H is estimated. The resulting gradient iterations in Step 5 of Algorithm 2. By doing so,
iterative algorithm, whose detailed derivation is given in it avoids the need to learn the dictionary H by alternating
Appendix VI-B, is summarized as Algorithm 2. Here, η > 0 optimization, as it is implicitly learned from data in the
is the step size; 1 is the all-ones vector; b is a threshold training procedure of the unfolded network.
parameter; and Tb is the soft-thresholding operator, also Architecture: DCEA treats the two-step convolutional
referred to as the shrinkage operator, applied elementwise sparse coding problem as an autoencoder, where the
and is given by Tb (x) = sign(x) max{|x| − b, 0}. Further- encoder computes the sparse vector s by unfolding Q
more, the optimization variable H in Step 2 is constrained proximal gradient iterations as in Step 5 of Algorithm 2.
to be block-Toeplitz. The decoder then converts ŝ produced by the encoder into
a recovered clean image µ̂.
Algorithm 2 Alternating Image Recovery and In particular, Tolooshams et al. [51] proposed two
Dictionary Learning for System Model (12) implementations of DCEA. The first, referred to as DCEA-C,
Init: Fix step-size η > 0. Set initial guess s0 directly implements Q proximal gradient iterations fol-
1 for l = 0, 1, . . . do lowed by the decoding step, which computes µ̂, where
2 Update both the encoder and the decoder use the same value of the
Ĥ l = arg min 1T exp (Hsl ) − xT Hsl . dictionary matrix H . This is replaced with a convolutional
H layer and is learned via end-to-end training along with the
3 Set ŝ0 = sl . thresholding parameters, bypassing the need to explicitly
4 for q = 0, 1, . . . do recover the dictionary for each image, as in Step 2 of Algo-
5 Update rithm 2. The second implementation, referred to as DCEA-
UC, decouples the convolution kernels of the encoder and
ŝq+1 = Tb ŝq +ηH T x − exp (Hŝq ) the decoder, and lets the encoder carry out Q iterations of

.
the form
 
6 end ŝq+1 = Tb ŝq + ηW T2 (x − exp (W 1 ŝq )) . (13)
7 Set sl+1 = ŝq .
8 end
Here, W 1 and W 2 are convolutional kernels that are
Output: Estimate clean image via
 not constrained to be equal to H used by the decoder.2
µ̂ = exp Ĥ l sl .
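As a concrete reference for Step 5, the following NumPy sketch implements the soft-thresholding operator Tb and the inner proximal gradient loop of Algorithm 2 for the Poisson model (12); the dictionary, threshold, and step size are illustrative stand-ins (in particular, H is a dense random matrix here rather than a block-Toeplitz convolutional dictionary).

```python
# Hedged NumPy sketch of the inner proximal gradient iterations (Step 5 of Algorithm 2).
import numpy as np

def soft_threshold(v, b):
    """Shrinkage operator T_b(v) = sign(v) * max(|v| - b, 0)."""
    return np.sign(v) * np.maximum(np.abs(v) - b, 0.0)

rng = np.random.default_rng(1)
N, M = 64, 128
H = 0.1 * rng.standard_normal((N, M))                  # stand-in for the dictionary
s_true = soft_threshold(rng.standard_normal(M), 1.5)   # sparse code
x = rng.poisson(np.exp(H @ s_true)).astype(float)      # Poisson observations of mu = exp(Hs)

eta, b = 1e-2, 1e-3
s = np.zeros(M)
for q in range(200):                                   # Step 5 of Algorithm 2
    s = soft_threshold(s + eta * H.T @ (x - np.exp(H @ s)), b)
```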
2 The architecture proposed in [51] is applicable for various

c) Deep convolutional exponential-family autoencoder: exponential-family noise signals. Particularly, for Poisson noise, an addi-
tional exponential linear unit was applied to x − exp (W 1 ŝq ), which
The hybrid model-based/data-driven architecture enti- was empirically shown to improve the convergence properties of the
tled deep convolutional exponential-family autoencoder network.


Fig. 5. DCEA illustration. Parameters in red fonts are learned in training, while those in blue fonts are externally provided.
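The following hedged PyTorch sketch outlines a DCEA-UC-style unfolded autoencoder in the spirit of Fig. 5: the encoder runs Q iterations of the form (13) with learned convolutional kernels W1 and W2 and learned per-channel thresholds, and the decoder maps the resulting code back to an image estimate. Channel counts, kernel sizes, and the omission of the exponential linear unit mentioned in the footnote are simplifications of the architecture in [51].

```python
# Sketch of a DCEA-UC-style unfolded convolutional encoder/decoder for Poisson denoising.
import torch
from torch import nn

class DCEAUC(nn.Module):
    def __init__(self, channels=16, kernel=7, Q=10):
        super().__init__()
        self.Q = Q
        self.W1 = nn.Conv2d(channels, 1, kernel, padding=kernel // 2, bias=False)
        self.W2 = nn.Conv2d(1, channels, kernel, padding=kernel // 2, bias=False)
        self.decoder = nn.Conv2d(channels, 1, kernel, padding=kernel // 2, bias=False)
        self.b = nn.Parameter(1e-2 * torch.ones(1, channels, 1, 1))  # per-channel thresholds
        self.eta = nn.Parameter(torch.tensor(1e-2))                  # learned step size

    def shrink(self, v):
        return torch.sign(v) * torch.relu(torch.abs(v) - self.b)     # T_b applied elementwise

    def forward(self, x):
        s = torch.zeros_like(self.W2(x))
        for _ in range(self.Q):                                       # unfolded iterations as in (13)
            s = self.shrink(s + self.eta * self.W2(x - torch.exp(self.W1(s))))
        return torch.exp(self.decoder(s))                             # recovered clean image

model = DCEAUC()
noisy = torch.rand(1, 1, 64, 64) * 5                                  # placeholder noisy image
denoised = model(noisy)
```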

An illustration of the resulting architecture is depicted in DNN design, which resembles an iterative optimization
Fig. 5. algorithm. Compared to conventional DNNs, unfolded
Training: The parameters of DCEA are θ = {H, b} networks are typically interpretable, tend to have a smaller
for DCEA-C and θ = {W 1 , W 2 , H, b} for DCEA-UC. The number of parameters, and can, thus, be trained more
vector b ∈ RC is comprised of the thresholding parameters quickly [7], [61]. A key advantage of deep unfolding over
used at each channel. When applied for Poisson image model-based optimization is inference speed. For instance,
denoising, DCEA is trained in a supervised manner using unfolding projected gradient descent iterations into
the MSE loss, namely, a set of nt clean images {µt }n t=1 are
t
DetNet allows inferring with much fewer layers compared
nt
used along with their Poisson noisy version {xt }t=1 . By let- to the number of iterations required by the model-based
ting fθ (·) denote the resulting mapping of the unfolded algorithm to converge. Similar observations have been
network, the loss function is formulated as made in various unfolded algorithms [58], [66].
One of the key properties of unfolded networks is their
nt reliance on knowledge of the model describing the setup
1 X
L(θ) = ∥µt − fθ (xt )∥2 . (14) (though not necessarily on its parameters). For example,
nt
t=1
one must know that the image is corrupted by Poisson
noise to formulate the iterative procedure in Algorithm 2
Quantitative Results: The experimental results reported
unfolded into DCEA or that the observations obey a linear
in [51] evaluated the ability of the unfolded DCEA-C and
Gaussian model to unfold the projected gradient descent
DCEA-UC in recovering images corrupted with different
iterations into DetNet. However, the parameters of this
levels of Poisson noise. An example of an image denoised
model, e.g., the matrix H in (8) and (12), can be either
by the unfolded system is depicted in Fig. 6. It was
provided based on domain knowledge, as done in DetNet,
noted in [51] that the proposed approach allows achiev-
or alternatively, learned in the training procedure, as car-
ing similar and even improved results to those of purely
ried out by DCEA. The model awareness of deep unfolding
data-driven techniques based on black-box CNNs [70].
has its advantages and drawbacks. When the model is
However, the fact that the denoising system is obtained
accurately known, deep unfolding essentially incorporates
by unfolding the model-based optimizer in Step 5 of Algo-
it into the DNN architecture, as opposed to conventional
rithm 2 allows this performance to be achieved while utiliz-
black-box DNNs which must learn it from data. However,
ing 3%–10% of the overall number of trainable parameters
this approach does not exploit the model-agnostic nature
as those used by the conventional CNN.
of deep learning and, thus, may lead to degraded per-
4) Discussion: Deep unfolding incorporates formance when the true relationship between the mea-
model-based domain knowledge to obtain a dedicated surements and the desired quantities deviates from the


Fig. 6. Illustration of an image corrupted by different levels of Poisson noise and the resulting denoised images produced by the unfolded
DCEA-C and DCEA-UC. Figure reproduced from [51] with authors’ permission.
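Because the unfolded network is trained like any discriminative regression model, the MSE objective in (14) leads to a standard supervised loop. The sketch below is a minimal illustration in which `model` stands for the unfolded mapping fθ (e.g., an unfolded encoder followed by a decoder); the optimizer, learning rate, and batching are assumptions rather than the training setup of [51].

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

def train_unfolded_denoiser(model, clean_images, noisy_images,
                            epochs=50, lr=1e-3, batch_size=32):
    """Fit f_theta so that f_theta(x_t) approximates mu_t under the MSE loss (14).

    clean_images and noisy_images are tensors of matching shape (n_t, ...)."""
    loader = DataLoader(TensorDataset(noisy_images, clean_images),
                        batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(epochs):
        for x_t, mu_t in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x_t), mu_t)   # (1/n_t) sum ||mu_t - f_theta(x_t)||^2
            loss.backward()
            optimizer.step()
    return model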

model assumed in the design. Nonetheless, training an a learned fashion but also in the identification of the
unfolded network designed with a mismatched model specific task of each block, as well as the ability to convert
using data corresponding to the true underlying scenario known statistical relationships, such as causal graphs into
typically yields more accurate inference compared to the dedicated DNN architectures.
model-based iterative algorithm with the same model mis-
1) Design Outline: The application of neural building
match, as the unfolded network can learn to compensate
blocks to design a model-aided deep network is based on
for this mismatch [64].
the following steps.
1) Identify an algorithm or a flowchart structure that is
B. Neural Building Blocks useful for the problem at hand and can be decom-
posed into multiple building blocks.
The neural building block is an alternative approach
2) Identify which of these building blocks should be
to design model-aided networks, which can be treated
learned from data and what is their concrete task.
as a generalization of deep unfolding. It is based on
3) Design a dedicated neural network for each building
representing a model-based algorithm, or alternatively
block capable of learning to carry out its specific task.
prior knowledge of an underlying statistical model, as an
4) Train the overall resulting network, either in an end-
interconnection of distinct building blocks. Neural build-
to-end fashion or by training each building block
ing blocks implement a DNN comprised of multiple sub-
network individually.
networks. Each module learns to carry out the specific
We next demonstrate how one can design a model-aided
computations of the different building blocks constituting
network comprised of neural building blocks. Our example
the model-based algorithm, as done in [16], [71], [72],
focuses on symbol detection in flat MIMO channels, where
and [73], or to capture a known statistical relationship,
we consider the data-driven implementation of the itera-
as in [74].
tive soft interference cancellation (SIC) scheme of [75],
Neural building blocks are designed for scenarios that
which is the DeepSIC algorithm proposed in [16].
are tackled using algorithms with flow diagram repre-
sentations, which can be captured as a sequential and 2) Example 3: DeepSIC for MIMO Detection: Iterative
parallel interconnection of building blocks. In particular, SIC [75] is an MIMO detection method suitable for lin-
deep unfolding can be obtained as a special case of ear Gaussian channels, i.e., the same channel models as
neural building blocks, where the original algorithm is that described in the example of DetNet in Section IV-A.
an iterative optimizer, such that the building blocks are DeepSIC is a hybrid model-based/data-driven implemen-
interconnected in a sequential fashion and implemented tation of the iterative SIC scheme [16]. However, unlike
using a single layer. However, the generalization of neural its model-based counterpart and alternative deep MIMO
building blocks compared to deep unfolding is not encap- receivers [18], [19], [61], DeepSIC is not particularly
sulated merely in its ability to implement nonsequential tailored for linear Gaussian channels and can be utilized
interconnections between algorithmic building blocks in in various flat MIMO channels. We formulate DeepSIC by


first reviewing the model-based iterative SIC and present a set of neural building blocks, thus circumventing these
DeepSIC as its data-driven implementation. limitations of its model-based counterpart.
a) Iterative soft interference cancellation: The iterative Architecture: The iterative SIC algorithm can be viewed
SIC algorithm proposed in [75] is an MIMO detection as a set of interconnected basic building blocks, each
method that combines multistage interference cancella- implementing the two stages of interference cancellation
tion with soft decisions. The detector operates iteratively, and soft decoding, as illustrated in Fig. 7(a). While the
where, in each iteration, an estimate of the conditional block diagram in Fig. 7(a) is ignorant of the underly-
probability mass function (PMF) of sk , which is the kth ing channel model, the basic building blocks are model-
entry of s, given the observed x, is generated for every dependent. Although each of these basic building blocks
symbol k ∈ {1, 2, . . . , K} := K. Each PMF, which is an consists of two sequential procedures that are completely
(q)
|S| × 1 vector denoted p̂k at the q th iteration, is com- channel-model-based, the purpose of these computations
puted using the corresponding estimates of the interfer- is to carry out a classification task. In particular, the kth
ing symbols {sl }l̸=k obtained in the previous iteration. building block of the q th iteration, k ∈ K, produces p̂(q)k ,
Iteratively repeating this procedure refines the PMF esti- which is an estimate of the conditional PMF of sk given
(q−1)
mates, allowing to accurately recover each symbol from x based on {p̂l }l̸=k . Such computations are naturally
the output of the last iteration. This iterative procedure implemented by classification DNNs, e.g., FC networks
is illustrated in Fig. 7(a) and summarized as Algorithm 3, with a softmax output layer. Embedding these conditional
whose derivation is detailed in Appendix VI-C. Algorithm 3 PMF computations into the iterative SIC block diagram in
is detailed for linear Gaussian models as in (8), assuming Fig. 7(a) yields the overall receiver architecture depicted
2
that the noise w has variance σw . We use hl to denote in Fig. 7(b).
the lth column of H , while N (µ, Σ) is the Gaussian A major advantage of using classification DNNs as the
distribution with mean µ and covariance Σ. basic building blocks in Fig. 7(b) stems from their ability
to accurately compute conditional distributions in complex
nonlinear setups without requiring a priori knowledge of
Algorithm 3 Iterative SIC for System Model (8)
the channel model and its parameters. Consequently, when
(0)
Init: Set initial PMFs guess {p̂k }Kk=1 these building blocks are trained to properly implement
1 for q = 0, 1, . . . do their classification task, the receiver essentially realizes
(q)
2 For each k ∈ K, compute expected values ek iterative SIC for arbitrary channel models in a data-driven
(q)
and variance vk from p̂k .
(q) fashion.
3 Interference cancellation: For each k ∈ K Training: In order for DeepSIC to reliably implement
compute symbol detection, its building block classification DNNs
must be properly trained. Two possible training approaches
(q+1)
X (q) are considered based on a labeled set of nt samples
zk =x− hl el . {(st , xt )}n
t=1 .
t

l̸=k
1) End-to-end training: The first approach jointly trains
the entire network, i.e., all the building block DNNs.
4 Soft decoding: For each k ∈ K, estimate Since the output of the deep network is the set of
(q+1)
p̂k
(q+1)
as the PMF of sk given z k , PMFs {p̂(Q) K
k }k=1 , the sum cross-entropy loss is used.
assuming that Let θ be the network parameters and p̂(Q) k (x, α; θ) be
the entry of p̂(Q)
k corresponding to s k = α when the
(q+1)
 X (q)  input to the network parameterized by θ is x. The
zk 2
|sk ∼ N hk sk , σw IK + vl hl hTl . sum cross-entropy loss is
l̸=k

nt K
1 XX (Q) 
5 end L(θ) = − log p̂k xt , (st )k ; θ . (15)
nt
Output: Estimate ŝ by setting each ŝk as the t=1 k=1

symbol maximizing the estimated PMF


(q)
p̂k . Training the interconnection of DNNs in Fig. 7(b) end-
to-end based on (15) jointly updates the coefficients
of all the K · Q building block DNNs. For a large
b) DeepSIC: Iterative SIC is specifically designed for number of symbols, i.e., large K , training so many
linear channels of the form (8). In particular, the inter- parameters simultaneously is expected to require a
ference cancellation Step 3 of Algorithm 3 requires the large labeled set.
contribution of the interfering symbols to be additive. Fur- 2) Sequential training: The fact that DeepSIC is imple-
thermore, it requires accurate complete knowledge of the mented as an interconnection of neural building
underlying statistical model, i.e., of (8). DeepSIC proposed blocks implies that each block can be trained with
in [16] learns to implement the iterative SIC from data as a reduced number of training samples. Specifically,


Fig. 7. Iterative SIC illustration: (a) model-based method and (b) DeepSIC.
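For concreteness, one iteration of the model-based iterative SIC (Steps 2-4 of Algorithm 3) can be sketched as follows for a real-valued constellation; the Gaussian soft-decoding computation at the end is exactly what each DeepSIC building block learns to replace with a classification DNN. The NumPy formulation, names, and shapes are illustrative.

import numpy as np

def sic_iteration(x, H, constellation, pmfs, noise_var):
    """One iteration of soft interference cancellation for x = H s + w.

    pmfs: (K, |S|) current PMF estimates of the K symbols over `constellation`.
    Returns the updated (K, |S|) PMFs."""
    N, K = H.shape
    means = pmfs @ constellation                     # e_k = E[s_k]
    variances = pmfs @ constellation**2 - means**2   # v_k = Var[s_k]
    new_pmfs = np.zeros_like(pmfs)
    for k in range(K):
        others = [l for l in range(K) if l != k]
        # Interference cancellation: subtract the expected contribution of s_l, l != k.
        z_k = x - H[:, others] @ means[others]
        # Soft decoding: Gaussian likelihood with covariance sigma_w^2 I + sum_l v_l h_l h_l^T.
        cov = noise_var * np.eye(N) + (H[:, others] * variances[others]) @ H[:, others].T
        cov_inv = np.linalg.inv(cov)
        for j, alpha in enumerate(constellation):
            diff = z_k - H[:, k] * alpha
            new_pmfs[k, j] = np.exp(-0.5 * diff @ cov_inv @ diff)
        new_pmfs[k] /= new_pmfs[k].sum()
    return new_pmfs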

the goal of each building block DNN does not depend compares DeepSIC to the model-based iterative SIC and
on the iteration index: The kth building block of the the data-driven DetNet [18]. Fig. 8(b) considers a Poisson
q th iteration outputs a soft estimate of sk for each channel, where x is related to s via a multivariate Poisson
iteration q . Therefore, each building block DNN can distribution, for which schemes requiring a linear Gaussian
be trained individually by minimizing the conven- model, such as the iterative SIC algorithm, are not suitable.
tional cross-entropy loss. To formulate this objective, The ability to use DNNs as neural building blocks to carry
let θ (q)
k represent the parameters of the kth DNN at out their model-based algorithmic counterparts in a robust
(q−1)
iteration q and write p̂(q)
(q)
k (x, {p̂l }l̸=k , α; θ k ) as and model-agnostic fashion is demonstrated in Fig. 8.
(q)
the entry of p̂k corresponding to sk = α when the In particular, it is demonstrated that DeepSIC approaches
DNN parameters are θ (q) k and its inputs are x and the SER values of the iterative SIC algorithm in linear
(q−1)
{p̂l }l̸=k . The cross-entropy loss is Gaussian channels while notably outperforming it in the
presence of model mismatch, as well as when applied
−1 X
nt in non-Gaussian setups. It is also observed in Fig. 8(a)
(q)  (q) (q−1) (q) 
L θk = log p̂k x̃t , {p̂t,l }l̸=k , (s̃t )k ; θ k that the resulting architecture of DeepSIC can be trained
nt
t=1
(16) with smaller datasets compared to alternative data-driven
where {p̂(q− 1)
} represent the estimated PMFs associ- receivers, such as DetNet.
t,l
ated with xi computed at the previous iteration. The
problem with training each DNN individually is that 3) Discussion: The main rationale in designing DNNs
1)
the soft estimates {p̂(q−t,l } are not provided as part as interconnected neural building blocks is to facilitate
of the training set. This challenge can be tackled by learned inference by preserving the structured operation
training the DNNs corresponding to each layer in a of a model-based algorithm applicable to the problem
sequential manner, where, for each layer, the outputs at hand given full domain knowledge. As discussed ear-
of the trained previous iterations are used as the soft lier, this approach can be treated as an extension of
estimates fed as training samples. deep unfolding, allowing to exploit additional structures
beyond a sequential iterative operation. The generalization
Quantitative Results: Two experimental studies of Deep- of deep unfolding into a set of learned building blocks
SIC taken from [16] are depicted in Fig. 8. These results opens additional possibilities in designing model-aided
compare the symbol error rate (SER) achieved by DeepSIC, networks.
which learns to carry out Q = 5 SIC iterations from nt = First, the treatment of the model-based algorithm as a
5000 labeled samples. In particular, Fig. 8(a) considers a set of building blocks with concrete tasks allows a DNN
Gaussian channel of the form (8) with K = N = 32, result- architecture designed to comply with this structure not
ing in MAP detection being computationally infeasible, and only to learn to carry out the original model-based method


Fig. 8. Experimental results from [16] of DeepSIC compared to the model-based iterative SIC, the model-based MAP (when feasible), and
the data-driven DetNet of [18] (when applicable). Perfect CSI implies that the system is trained and tested using samples from the same
channel, while, under CSI uncertainty, they are trained using samples from a set of different channels. (a) 32 × 32 Gaussian channel.
(b) 4 × 4 Poisson channel.
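A minimal sketch of the sequential training option is given below: each building block is a small fully connected classifier (its softmax output is implicit in the cross-entropy loss), and the blocks of iteration q are fitted using the soft outputs produced by the already-trained blocks of iteration q - 1, in the spirit of (16). Layer sizes, the optimizer, and the schedule are assumptions, not the configuration of [16].

import torch
from torch import nn, optim

def make_block(obs_dim, num_users, num_classes, hidden=64):
    # One building block: classify s_k from x and the other users' soft estimates.
    in_dim = obs_dim + (num_users - 1) * num_classes
    return nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                         nn.Linear(hidden, num_classes))   # logits; softmax via the loss

def sequential_training(blocks, x, labels, num_classes, epochs=200, lr=1e-2):
    """blocks: list of Q lists of K modules. x: (n_t, obs_dim) float tensor.
    labels: (n_t, K) long tensor of symbol indices."""
    n_t, K = labels.shape
    prev_pmfs = torch.full((n_t, K, num_classes), 1.0 / num_classes)
    for q in range(len(blocks)):
        new_pmfs = torch.empty_like(prev_pmfs)
        for k in range(K):
            others = [l for l in range(K) if l != k]
            feats = torch.cat([x, prev_pmfs[:, others, :].reshape(n_t, -1)], dim=1)
            opt = optim.Adam(blocks[q][k].parameters(), lr=lr)
            for _ in range(epochs):                         # cross-entropy objective as in (16)
                opt.zero_grad()
                loss = nn.functional.cross_entropy(blocks[q][k](feats), labels[:, k])
                loss.backward()
                opt.step()
            with torch.no_grad():                           # soft outputs fed to iteration q + 1
                new_pmfs[:, k, :] = torch.softmax(blocks[q][k](feats), dim=1)
        prev_pmfs = new_pmfs
    return blocks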

from data but also to robustify it and enable its application V. D N N - A I D E D I N F E R E N C E


in diverse new scenarios. This follows since the block dia- DNN-aided inference is a family of model-based deep
gram structure of the algorithm may be ignorant of the spe- learning algorithms in which DNNs are incorporated into
cific underlying statistical model and only relies upon a set model-based methods. As opposed to model-aided net-
of generic assumptions, e.g., that the entries of the desired works discussed in Section IV, where the resultant system
vector s are mutually independent. Consequently, replac- is a deep network whose architecture imitates the oper-
ing these building blocks with dedicated DNNs allows to ation of a model-based algorithm, here, the inference is
exploit their model-agnostic nature, and thus, the original carried out using a traditional model-based method, while
algorithm can now be learned to be carried out in complex some of the intermediate computations are augmented by
environments. For instance, DeepSIC can be applied to DNNs. The main motivation of DNN-aided inference is to
nonlinear channels, owing to the implementation of the exploit the established benefits of model-based methods,
building blocks of the iterative SIC algorithm using generic in terms of performance, complexity, and suitability for the
DNNs, while the model-based algorithm is limited to setups problem at hand. Deep learning is incorporated to mitigate
of the form (8). sensitivity to inaccurate model knowledge, facilitate oper-
In addition, the division into building blocks gives rise ation in complex environments, and enable application in
to the possibility to train each block separately. The main new domains. An illustration of a DNN-aided inference
advantage in doing so is that a smaller training set is system is depicted in Fig. 9.
expected to be required, though, in the horizon of a DNN-aided inference is particularly suitable for sce-
sufficiently large amount of training, end-to-end training narios in which one only has access to partial domain
is likely to yield a more accurate model as its parameters knowledge. In such cases, the available domain knowledge
are jointly optimized. For example, in DeepSIC, sequential dictates the algorithm utilized, while the part that is not
training uses the nt input–output pairs to train each DNN available or is too complex to model analytically is tackled
individually. Compared to the end-to-end training that using deep learning. We divide our description of DNN-
utilizes the training samples to learn the complete set of aided inference schemes into three main families of meth-
parameters, which can be quite large, sequential training ods: The first, referred to as structure-agnostic DNN-aided
uses the same dataset to learn a significantly smaller num- inference detailed in Section V-A, utilizes deep learning
ber of parameters, reduced by a factor of K · Q, multiple to capture structures in the underlying data distribution,
times. This indicates that the ability to train the blocks e.g., to represent the domain of natural images. This
individually is expected to require much fewer training DNN is then utilized by model-based methods, allow-
samples, at the cost of a longer learning procedure for ing them to operate in a manner, which is invariant to
a given training set, due to its sequential operation, and these structures. The family of structure-oriented DNN-
possible performance degradation as the building blocks aided inference schemes, as detailed in Section V-B, utilizes
are not jointly trained. In addition, training each block sep- model-based algorithms to exploit a known tractable statis-
arately facilitates adding and removing blocks when such tical structure, such as an underlying Markovian behavior
operations are required in order to adapt the inference of the considered signals. In such methods, deep learn-
rule. ing is incorporated into the structure-aware algorithm,


Fig. 9. DNN-aided inference illustration: (a) model-based algorithm comprised of multiple iterations with intermediate model-based
computations and (b) data-driven implementation of the algorithm, where the specific model-based computations are replaced with
dedicated learned deep models. Here, one can possibly use data to train the internal DNNs individually or to train the overall inference
mapping end-to-end as a discriminative learning model [54], [55], typically requiring the intermediate mathematical steps to be either
differentiable or well-approximated by a differentiable mapping.
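The requirement mentioned in the caption, that intermediate model-based computations be differentiable when training end-to-end, is easy to see in code: as long as the model-based step is written with differentiable operations, backpropagation reaches the internal DNN through it. The toy sketch below, whose dimensions, correction network, and simple gradient-step update are assumptions, illustrates the pattern.

import torch
from torch import nn

class DnnAidedStep(nn.Module):
    """One DNN-aided iteration: an explicit model-based gradient step on the
    data-fit term ||x - H s||^2 plus a learned internal correction."""

    def __init__(self, H, step_size=0.1, hidden=32):
        super().__init__()
        self.register_buffer("H", torch.as_tensor(H, dtype=torch.float32))
        self.eta = step_size
        obs_dim, state_dim = self.H.shape
        self.net = nn.Sequential(nn.Linear(obs_dim + state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, state_dim))

    def forward(self, x, s):
        model_based = self.eta * (self.H.t() @ (x - self.H @ s))  # differentiable math step
        learned = self.net(torch.cat([x, s]))                      # DNN-aided computation
        return s + model_based + learned

# End-to-end training treats the composite mapping as one discriminative model:
# a loss on the final estimate backpropagates into `net` through the explicit
# model-based operations above.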

thereby capturing the remaining portions of the underlying Traditionally, the prior knowledge encapsulating the
model and mitigating sensitivity to uncertainty. Finally, structure and properties of the underlying signal is rep-
in Section V-C, we discuss neural augmentation methods resented by a handcrafted regularization term or con-
that are tailored to robustify model-based processing in the straint incorporated into the optimization objective. For
presence of inaccurate knowledge of the parameters of the example, a common model-based strategy used in various
underlying model. Here, the inference is carried out using inverse problems is to impose sparsity in some given
a model-based algorithm based on its available domain dictionary, which facilitates CS-based optimization. Deep
knowledge, while a deep learning system operating in learning brings forth the possibility to avoid such explicit
parallel is utilized to compensate for errors induced by constraints, thereby mitigating the detrimental effects of
model inaccuracy. Our description of these methodologies crude, handcrafted approximation of the true underlying
in Sections V-A–V-C follows the same systematic form structure of the signal while enabling optimization with
used in Section IV, where each approach is detailed by a implicit data-driven regularization. This can be imple-
high-level description, design outline, one or two concrete mented by incorporating deep denoisers as learned prox-
examples, and a summarizing discussion. imal mappings in iterative optimization, as carried out by
plug-and-play networks3 [13], [14], [80], [81], [82], [83],
A. Structure-Agnostic DNN-Aided Inference [84], [85]. DNN-based priors can also be used to enable,
e.g., CS beyond the domain of sparse signals [10], [11].
The first family of DNN-aided inference utilizes deep
learning to implicitly learn structures and statistical prop- 1) Design Outline: Designing structure-agnostic DNN-
erties of the signal of interest, in a manner that is amenable aided systems can be carried out via the following steps.
to model-based optimization. These inference systems are 1) Identify a suitable optimization procedure, given the
particularly relevant for various inverse problems in signal domain knowledge for the signal of interest.
processing, including denoising, sparse recovery, deconvo- 2) The specific parts of the optimization procedure,
lution, and superresolution [76]. Tackling such problems which rely on complicated and possibly analytically
typically involves imposing some structure on the signal intractable domain knowledge, are replaced with a
domain. This prior knowledge is then incorporated into a DNN.
model-based optimization procedure, such as alternating
direction method of multipliers (ADMM) [77], fast iter-
3 The term plug-and-play typically refers to the usage of an image
ative shrinkage and thresholding algorithm [78], and
denoiser as proximal mapping in regularized optimization [80]. As this
primal-dual splitting [79], which recover the desired signal approach can also utilize model-based denoisers, we use the term plug-
with provable performance guarantees. and-play networks for such methods with DNN-based denoisers.


3) The integrated data-driven module can either be Pretraining: To implement deep generative priors, one
trained separately from the inference system, possibly first has to train a generative network G to map a latent
in an unsupervised manner as in [10], or alterna- vector z into a signal s, which lies in the domain of inter-
tively, the complete inference system is trained end- est. A major advantage of employing a DNN-based prior in
to-end [12]. this setting is that generator networks are agnostic to how
We next demonstrate how these steps are carried out they are used and can be pretrained and reused for multiple
in two examples: CS over complicated domains, where downstream tasks. The pretraining, thus, follows the stan-
deep generative networks are used for capturing the signal dard unsupervised training procedure, as discussed, e.g.,
domain [10]; and plug-and-play networks, which augment in Section III-B for GANs.
ADMM with a DNN to bypass the need to express a In particular, the work [10] trained a deep convolutional
proximal mapping. GAN [86] on the CelebA dataset [87] to represent 64 ×
64 color images of human faces, as well as a variational
2) Example 4: Compressed Sensing Using Generative Mod- autoencoder (VAE) [88] for representing handwritten dig-
els: CS refers to the task of recovering some unknown its in 28 × 28 grayscale form based on the MNIST
signal from (possibly noisy) lower dimensional observa- dataset [89].
tions. The mapping that transforms the input signal into Architecture: Once a pretrained generator network G is
the observations is known as the forward operator. In our available, it can be incorporated as an alternative prior for
example, we focus on the setting where the forward opera- the inverse model in (17). The key intuition behind this
tor is a particular linear function that is known at the time approach is that the range of G should only contain plausi-
of signal recovery. ble signals. Thus, one can replace the handcrafted sparsity
The main challenge in CS is that there could be (poten- prior with a data-driven DNN prior G by constraining our
tially infinitely) many signals that agree with the given signal recovery to the range of G.
observations. Since such a problem is underdetermined, One natural way to impose this constraint is to perform
it is necessary to make some sort of structural assumptions the optimization in the latent space to find z whose image
on the unknown signal to identify the most plausible one. G(z) matches the observations. This is carried out by
A classic assumption is that the signal is sparse on some minimizing the following loss function in the latent space
known basis. of G:
a) System model: We consider the problem of noisy
CS, where we wish to reconstruct an unknown N - L(z) = ∥HG(z) − x∥22 . (19)
dimensional signal s∗ from the following observations:

Because the above loss function involves a highly non-


x = Hs∗ + w (17) convex function G, there is no closed-form solution or
guarantee for this optimization problem. However, the
where H is an M × N matrix, modeled as random Gaus- loss function is differentiable with respect to z , so it can
sian matrix with entries H ij ∼ N (0, 1/M ), with M < N , be tackled using conventional gradient-based optimization
and w is an M × 1 noise vector. techniques. Once a suitable latent z is found, the signal is
b) Sparsity-based CS: We next focus on a particular recovered as G(z).
technique as a representative example of model-based CS. In practice, Bora et al. [10] report that incorporating
We rely here on the assumption that s∗ is sparse and seek an ℓ2 regularizer on z helps. This is possibly due to the
to recover s∗ from x by solving the ℓ1 relaxed LASSO Gaussian prior assumption for the latent variable, as the
objective density of z is proportional to exp(−∥z∥22 ). Therefore,
minimizing ∥z∥22 is equivalent to maximizing the density of
LLASSO (s) ≜ ∥Hs − x∥22 + λ∥s∥1 . (18) z under the Gaussian prior. This has the effect of avoiding
images that are extremely unlikely under the Gaussian
prior even if it matches the observation well. The final loss
While the derivation above assumes that s∗ is sparse, the includes this regularization term
LASSO objective can also be used when s∗ is sparse in
some dictionary B , e.g., in the wavelet (WVT) domain,
and the detailed formulation is given in Appendix VI-D. LCS (z) = ∥HG(z) − x∥22 + λ∥z∥22 (20)
c) DNN-aided compressed sensing: In a data-driven
approach, we aim to replace the sparsity prior with a where λ is a regularization coefficient.
learned DNN. The following description is based on [10], In summary, DNN-aided CS replaces the constrained
which is proposed to use a deep generative prior. Specif- optimization over the complex input signal with tractable
ically, we replace the explicit sparsity assumption on true optimization over the latent variable z , which follows a
signal s∗ , with a requirement that it lies in the range of known simple distribution. This is achieved using a pre-
a pretrained generator network G : Rl → RN (e.g., the trained DNN-based prior G to map it into the domain of
generator network of a GAN). interest. Inference is performed by minimizing LCS in the


Fig. 10. High-level overview of CS with a DNN-based prior. The


generator network G is pretrained to map Gaussian latent variables
to plausible signals in the target domain. Then, signal recovery is
done by finding a point in the range of G that minimizes
reconstruction error via gradient-based optimization over the latent
variable.

Fig. 12. Visualization of the recovered signals from noisy CS on the


CelebA dataset. Reproduced from [10] with the authors’ permission.
latent space of G. An illustration of the system operation is
depicted in Fig. 10.
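In code, the recovery rule in (20) amounts to a few lines of gradient-based optimization over the latent variable. The sketch below assumes a pretrained generator G mapping an l-dimensional latent vector to an N-dimensional signal compatible with H; the optimizer, step count, and regularization weight are illustrative.

import torch

def recover_with_generative_prior(G, H, x, latent_dim, lam=0.1, steps=1000, lr=0.05):
    """Minimize ||H G(z) - x||^2 + lam ||z||^2 over z as in (20), then return G(z)."""
    z = torch.randn(latent_dim, requires_grad=True)
    optimizer = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = ((H @ G(z) - x) ** 2).sum() + lam * (z ** 2).sum()
        loss.backward()
        optimizer.step()
    return G(z).detach()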
Quantitative Results: To showcase the efficacy of the
data-driven prior at capturing complex high-dimensional
signal domains, we present the evaluation of its perfor- VAE and VAE + R EG) show notable performance gain
mance, as reported in [10]. The baseline model used for compared to the sparsity prior for a small number of
comparison is based on directly solving the LASSO loss measurements. Implicitly imposing a sparsity prior via the
(18). For CelebA, we formulate the LASSO objective in the LASSO objective outperforms the deep generative priors
discrete cosine transform (DCT) and the WVT basis, and as the number of observations approaches the dimension
minimize it via coordinate descent. of the signal. One explanation for this behavior is that
The first task is the recovery of handwritten digit images the pretrained generator G does not perfectly model the
from low-dimensional projections corrupted by additive MNIST digit distribution and may not actually contain the
Gaussian noise. The reconstruction error is evaluated ground-truth signal in its range. As such, its reconstruction
for various numbers of observations M . The results are error may never be exactly zero regardless of how many
depicted in Fig. 11. observations are given. The LASSO objective, on the other
We clearly see the benefit of using a data-driven deep hand, does not suffer from this issue and is able to make
prior in Fig. 11, where the VAE-based methods (labeled use of the extra observations available.
The ability of deep generative priors to facilitate recov-
ery from compressed measurements is also observed in
Fig. 12, which qualitatively evaluates GAN-based CS recov-
ery on the CelebA dataset. This experiment uses M =
500 noisy measurements (out of N = 12 288 total dimen-
sions). As shown in Fig. 12, in this low-measurement
regime, the data-driven prior again provides much more
reasonable samples.

3) Example 5: Plug-and-Play Networks for Image Restora-


tion: The above example of DNN-aided CS allows carrying
out regularized optimization over complex domains while
using deep learning to avoid regularizing explicitly. This
is achieved via deep priors, where the domain of inter-
est is captured by a generative network. An alternative
strategy, referred to as plug-and-play networks, applies
deep denoisers as learned proximal mappings. Namely,
instead of using DNNs to evaluate the regularized objective
Fig. 11. Experimental result for noisy CS on the MNIST dataset. as in [10], one uses DNNs to carry out an optimization
Reproduced from [10] with the authors’ permission. procedure, which relies on this objective without having


to express the desired signal domain. In the following, However, the proximal mapping in Step 3 of Algorithm 4
we exemplify the application of plug-and-play networks for is invariant of the task and the data. In particular, it is
image restoration using ADMM optimization [80]. the solution to the problem of MAP denoising ŝq+1 + uq ,
a) System model: We again consider the linear inverse assuming that the noise-free signal has prior ϕ(·) and
problem formulated in (17) in which the additive noise the noise is Gaussian with variance α. Now, denoisers
w is comprised of i.i.d. mutually independent Gaussian are common DNN models and are known to operate
2
entries with zero mean and variance σw . However, unlike reliably on signal domains with intractable priors (e.g.,
the setup considered in the previous example, the sensing natural images) [81]. One can, thus, implement ADMM
matrix H is not assumed to be random and can be any optimization without having to specify the prior ϕ(·) by
fixed matrix dictated by the underlying setup. replacing Step 3 of Algorithm 4 with a DNN denoiser [80],
The recovery of the desired signal s can be obtained via as illustrated in Fig. 13. Specifically, the proximal mapping
the MAP rule, which is given by is replaced with a DNN-based denoiser fθ such that

ŝ = arg min − log p(s|x) v q+1 = fθ (ŝq+1 + uq ; αq ) (23)


s

= arg min − log p(x|s) − log p(s)


s where αq denotes the noise level to which the denoiser
1 is tuned. This noise level can either be fixed to represent
= arg min ∥x − Hs∥2 + ϕ(s) (21)
s 2 that used during training, or alternatively, one can use
flexible DNN-based denoiser in which, e.g., the noise level
where ϕ(s) is a regularization term which equals is provided as an additional input [90].
2
−σw log p(s), with possibly some additive constant that Quantitative Results: As an illustrative example of the
does not affect the minimization in (21). quantitative gains on plug-and-play networks, we consider
b) Alternating direction method of multipliers: The reg- the setup of cardiac magnetic resonance imaging image
ularized optimization problem that stems from the MAP reconstruction reported in [80]. The proximal mapping
rule in (21) can be solved using ADMM [77]. ADMM here is replaced with a five-layer CNN with residual con-
introduces two auxiliary variables, denoted v and u, nection operating on spatiotemporal volumetric patches.
and is given by the iterative procedure in Algorithm 4, The CNN is trained offline to denoise clean images manu-
whose derivation is detailed in Appendix VI-E. In Step 2, ally corrupted by Gaussian noise. The experimental results
we defined f (v) ≜ (1/2)∥x − Hv∥2 , while the proximal reported in Fig. 14 demonstrate that the introduction of
mapping of some function g(·) used in Steps 2 and 3 is deep denoisers notably improves both the performance
defined as and the convergence rate of the iterative optimizer com-
pared to utilizing model-based approaches for approximat-
 
1 ing the proximal mapping.
proxg (v) := arg min g(z) + ∥z − v∥22 . (22)
z 2
4) Discussion: Using deep learning to strengthen
regularized optimization builds upon the model-agnostic
The ADMM algorithm is illustrated in Fig. 13(a).
nature of DNNs. Traditional optimization methods rely on
mathematical expressions to capture the structure of the
Algorithm 4 ADMM solution that one is looking for, inevitably inducing model
Init: Fix α > 0. Initialize u(0) , v (0) randomly mismatch in domains that are extremely challenging
1 for q = 0, 1, . . . do to describe analytically. The ability of deep learning to
2 Update ŝq+1 = proxαf (v q + uq ). learn complex mappings without relying on domain
3 Update v q+1 = proxαϕ (sq+1 + uq ). knowledge is exploited here to bypass the need for explicit
4 Update uq+1 = uq + (ŝq+1 − v q+1 ). regularization.
5 end The need to learn to capture the domain of interest
Output: Estimate ŝ = ŝq . facilitates using pretrained networks, thus reducing the
dependency on massive amounts of labeled data. For
instance, deep generative priors use DNN architectures
c) Plug-and-play ADMM: The key challenge in imple- that are trained in an unsupervised manner and, thus,
menting the ADMM iterations stems from the computation rely only on unlabeled data, e.g., natural images. Such
of the proximal mapping in Step 3. In particular, while unlabeled samples are typically more accessible and easy
one can evaluate Step 2 in the closed form, as shown in to aggregate compared to labeled data, e.g., tagged natural
Appendix VI-E, computing Step 3 of Algorithm 4 requires images. One can often utilize off-the-shelf pretrained DNNs
explicit knowledge of the prior ϕ(·), which is often not when such a network exists for domains related to the
available. Furthermore, even when one has a good approx- ones over which optimization is carried out, with possible
imation of ϕ(·), computing the proximal mapping in Step 3 adjustments to account for the subtleties of the problem by
may still be extremely challenging to carry out analytically. transfer learning.


Fig. 13. Illustration of (a) ADMM algorithm compared to (b) plug-and-play ADMM network.
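A compact sketch of the resulting plug-and-play iterations is given below. It is written in the standard scaled form of ADMM, so the sign convention for the auxiliary variable u may differ from the listing of Algorithm 4; the callable `denoiser` stands for the pretrained DNN used in place of the proximal mapping, as in (23), and the shapes and hyperparameters are assumptions.

import numpy as np

def plug_and_play_admm(x, H, denoiser, alpha=1.0, num_iters=30):
    """ADMM for 0.5||x - Hs||^2 + phi(s), with the prior proximal step replaced
    by a learned denoiser. `denoiser` maps a (flattened) noisy signal to a clean one."""
    M, N = H.shape
    A = H.T @ H + np.eye(N) / alpha                  # normal equations of prox_{alpha f}
    v = np.zeros(N)
    u = np.zeros(N)
    s = np.zeros(N)
    for _ in range(num_iters):
        # prox_{alpha f}: closed-form minimizer of 0.5||x - Hs||^2 + (1/(2 alpha))||s - (v - u)||^2.
        s = np.linalg.solve(A, H.T @ x + (v - u) / alpha)
        # prox_{alpha phi}: implicit prior realized by the DNN denoiser, as in (23).
        v = denoiser(s + u)
        # Dual (running residual) update.
        u = u + (s - v)
    return s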

Finally, while our description of DNN-aided regularized exploit an underlying statistical structure while integrat-
optimization relies on model-based iterative optimizers, ing DNNs to enable operation without additional explicit
which utilize a deep learning module, one can also incor- characterization of this model. The types of structures
porate deep learning into the optimization procedure. For exploited in the literature can come in the form of an
instance, the iterative optimization steps can be unfolded a priori known factorizable distribution, such as causality
into a DNN, as in, e.g., [12]. This approach allows bene- and finite memory in communication channels [15], [22],
fiting from both the ability of deep learning to implicitly [91]; it can follow from an established approximation of
represent complex domains, as well as the inference speed the statistical behavior, such as modeling of images as
reduction of deep unfolding along with its robustness to conditional random fields [92], [93], [94]; follow from
uncertainty and errors in the model parameters assumed to physical knowledge of the system operation [95], [96],
be known. Nonetheless, the fact that the iterative optimiza- [97]; or arise due to the distributed nature of the problem,
tion must be learned from data in addition to the structure as in [98].
of the domain of interest implies that larger amounts of The main advantage in accounting for such statistical
labeled data are required to train the system compared to structures stems from the availability of various model-
using the model-based optimizer. based methods, tailored specifically to exploit these struc-
tures to facilitate accurate inference at reduced complexity.
Many of these algorithms, such as the Kalman filter and
B. Structure-Oriented DNN-Aided Inference its variants [99, Ch. 7], which build upon an underlying
The family of structure-oriented DNN-aided inference state-space structure, or the Viterbi algorithm [100], which
algorithms utilizes model-based methods designed to exploits the presence of a hidden Markov model, can
be represented as special cases of the broad family of
factor graph methods. Consequently, our main example
used for describing structure-oriented DNN-aided infer-
ence focuses on the implementation of message passing
over data-driven factor graphs.
1) Design Outline: Structure-oriented DNN-aided
algorithms utilize deep learning not for the overall
inference task but for robustifying and relaxing the
model-dependence of established model-based inference
algorithms designed specifically for the structure induced
by the specific problem being solved. The design of
such DNN-aided hybrid inference systems consists of the
following steps.
1) A proper inference algorithm is chosen based on
the available knowledge of the underlying statistical
Fig. 14. Normalized MSE versus iteration for the recovery of structure. The domain knowledge is encapsulated in
cardiac MRI images. Here, plug-and-play networks using a CNN
the selection of the algorithm, which is learned from
denoiser (PnP-CNN) are compared to the model-based strategies of
data.
computing the proximal mapping by imposing as prior sparsity in
the undecimated WVT domain (PnP-UWT), as well as CS with a 2) Once a model-based algorithm is selected, we iden-
similar constraint (CS-UWT) and with total-variation prior (CS-TV). tify its model-specific computations and replace them
Figure reproduced from [80] with authors’ permission. with dedicated compact DNNs.


3) The resulting DNNs are either trained individually, Algorithm 5 SP Algorithm for System Model (24)
or the overall system can be trained in an end-to-end Init: Fix an initial forward message →

µ s0 (s) = 1 and a final backward
manner. message ← − (s) ≡ 1.
µ st
1 for i = t − 1, t − 2, . . . , 1 do
We next demonstrate how these steps are translated 2 For each si ∈ S l , compute backward message
in a hybrid model-based/data-driven algorithm, using the
example of learned factor graph inference for Markovian ←
− (s ) = f (xi+1 , si+1 , si )←

X
µ si i µ si+1 (si+1 ).
sequences proposed in [91] and [101]. si+1

2) Example 6: Learned Factor Graphs: Factor graph meth-


ods, such as the sum-product (SP) algorithm, exploit the 3 end
4 for i = 1, 2, . . . , t do
factorization of a joint distribution to efficiently compute a
5 For each si ∈ S l , compute forward message
desired quantity [102]. The application of the SP algorithm
for distributions that can be represented as noncyclic factor →
− f (xi , si , si−1 )→

X
µ si (si ) = µ si−1 (si−1 ).
graphs, such as Markovian models, allows computing the si−1
MAP rule, an operation whose burden typically grows
exponentially with the label space dimensionality, with
6 Estimate
complexity that only grows linearly with it. While the
following description focuses on Markovian stationary time


X
sequences, it can be extended to various forms of factoriz- ŝi = arg max µ si−1 (si−1 )f (xi , [si−l+1 , . . . , si ], si−1 )
si ∈S
si−1 ∈S l
able distributions.
a) System model: We consider the recovery of a time 7 ×←
− ([s
µ si i−l+1 , . . . , si ]).

series {si } taking values in a finite set S from an observed


sequence {xi } taking values in a set X . The subscript i 8 end
denotes the time index. The joint distribution of {si } and Output: ŝt = [ŝ1 , . . . , ŝt ]T

{xi } obeys an lth-order Markovian stationary model, l ≥ 1.


Consequently, when the initial state {si }0i=−l is given, the structure of the factor graph while using deep learning to
joint distribution of x = [x1 , . . . , xt ]T and s = [s1 , . . . , st ]T compute the function nodes without having to explicitly
satisfies specify their computations. Finally, it carries out the SP
t
Y method for inference over the resulting learned factor
p xi |sii−l p si |si−1
 
p(x, s) = i−l (24) graph.
i=1
Architecture: For Markovian relationships, the structure
for any fixed sequence length t > 0, where we write sji ≜ of the factor graph is illustrated in Fig. 15(a) regardless
[si , si+1 , . . . , sj ]T for i < j . of the specific statistical model. Furthermore, the station-
b) Sum-product algorithm: When the joint distribu- arity assumption implies that the complete factor graph is
tion of s and x is a priori known and can be computed, encapsulated in the single function f (·) (26) regardless of
the inference rule that minimizes the error probability for the block size t. Building upon this insight, DNNs can be
each time instance is the MAP detector utilized to learn the mapping carried out at the function
node separately from the inference task. The resulting
ŝi (x) = arg max p(si |x) (25) learned stationary factor graph is then used to recover {si }
si ∈S by message passing, as illustrated in Fig. 15(b). As learning
a single function node is expected to be a simpler task
for each i ∈ {1, . . . , t} ≜ T . This rule can be efficiently compared to learning the overall inference method for
approached when (24) holds using the SP algorithm [102]. recovering s from x, this approach allows using relatively
The SP algorithm represents the joint distribution (24) and compact DNNs, which can be learned from a relatively
computes the posterior distribution by message passing small dataset.
over this graph, as illustrated in Fig. 15(a). The resulting Training: In order to learn a stationary factor graph from
procedure, as detailed further in Appendix VI-F, is summa- samples, one must only learn its function node, which here
1
rized as Algorithm 5, where we define si ≜ sii−l+1 ∈ S l , boils down to learning p(xi |sii−l ) and p(si |si− i−l ) by (26).
1
and the function Since S is finite, the transition probability p(si |si− i−l ) can
be learned via a histogram.
f (xi , si , si−1 ) ≜ p (xi |si , si−1 ) p (si |si−1 ) . (26) For learning the distribution p(xi |sii−l ), it is noted that

−1
Algorithm 5 approaches the MAP detector in (25) with p(xi |si ) = p (si |xi ) p (xi ) p(si ) . (27)
complexity that only grows linearly with t.
c) Learned factor graphs: Learned factor graphs A parametric estimate of p (si |xi ), denoted P̂θ (si |xi ),
enable learning to implement MAP detection from labeled is obtained for each si ∈ S l+1 by training classification
data. It utilizes partial domain knowledge to determine the networks with softmax output layers to minimize the cross-


Fig. 15. Illustration of the SP method for Markovian sequences using (a) true factor graph and (b) learned factor graph.
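The sketch below illustrates the two ingredients of the learned factor graph for a first-order (l = 1) model: assembling the function node (26) from a trained classifier output and a histogram-based transition estimate via (27), with p(x_i) set to 1 thanks to the scale invariance of the SP mapping, and then running the forward-backward messages of Algorithm 5. Array names and shapes are illustrative.

import numpy as np

def learned_function_node(class_probs, trans_probs, marginal):
    """Assemble f[i, s_prev, s] of (26) up to scaling, using (27).

    class_probs[i, s]: trained estimate of p(s_i = s | x_i).
    trans_probs[s_prev, s]: histogram estimate of p(s_i = s | s_{i-1} = s_prev).
    marginal[s]: estimate of p(s_i = s)."""
    likelihood = class_probs / marginal[None, :]            # proportional to p(x_i | s_i)
    return likelihood[:, None, :] * trans_probs[None, :, :]

def sum_product(f):
    """Forward-backward message passing over f of shape (t, |S|, |S|);
    returns the per-step posterior marginals (Algorithm 5 with l = 1)."""
    t, S, _ = f.shape
    fwd = np.ones((t + 1, S))
    bwd = np.ones((t + 1, S))
    for i in range(t):                                      # forward messages
        fwd[i + 1] = fwd[i] @ f[i]
        fwd[i + 1] /= fwd[i + 1].sum()                      # rescale for numerical stability
    for i in range(t - 1, -1, -1):                          # backward messages
        bwd[i] = f[i] @ bwd[i + 1]
        bwd[i] /= bwd[i].sum()
    marginals = fwd[1:] * bwd[1:]
    return marginals / marginals.sum(axis=1, keepdims=True)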

entropy loss. As the SP mapping is invariant to scaling Markovian structure allows achieving improved perfor-
f (xi , si , si−1 ) with some factor, which does not depend on mance compared to utilizing black-box DNN architectures,
the si , si−1 , one can set p (xi ) ≡ 1 in (27) and use the such as the sliding bidirectional RNN detector, with limited
result to obtain a scaled value of the function node, which, datasets for training.
as discussed above, does not affect the inference mapping.
Quantitative Results: As a numerical example of learned 3) Discussion: The integration of deep learning into
factor graphs for Markovian models, we consider a sce- structure-oriented model-based algorithms allows to
nario of symbol detection over causal stationary communi- exploit the model-agnostic nature of DNNs while explic-
cation channels with finite memory, reproduced from [91]. itly accounting for available structural domain knowl-
Fig. 16 depicts the numerically evaluated SER achieved edge. Consequently, structure-oriented DNN-aided infer-
by applying the SP algorithm over a factor graph learned ence is most suitable for setups in which structured
from nt = 5000 labeled samples for channels with mem- domain knowledge naturally follows from established
ory l = 4. The results are compared to the performance models, while the subtleties of the complete statistical
of model-based SP, which requires complete knowledge knowledge may be challenging to accurately capture ana-
of the underlying statistical model, as well as the slid- lytically. Such structural knowledge is often present in
ing bidirectional RNN detector proposed in [103] for various problems in signal processing and communica-
such setups, which utilizes a conventional DNN archi- tions. For instance, modeling communication channels
tecture that does not explicitly account for the Marko- as causal finite-memory systems, as assumed in the
vian structure. Fig. 16(a) considers a Gaussian chan- above quantitative example, is a well-established repre-
nel, while, in Fig. 16(b), the conditional distribution sentation of many physical channels. The availability of
p(xi |sii−l ) represents a Poisson distribution. Fig. 16 demon- established structures in signal processing-related setups
strates the ability of learned factor graphs to enable makes structure-oriented DNN-aided inference a candidate
accurate message-passing inference in a data-driven man- approach to facilitate inference in such scenarios in a man-
ner, as the performance achieved using learned factor ner, which is ignorant of the possibly intractable subtleties
graphs approaches that of the SP algorithm, which oper- of the problem, by learning to account for them implicitly
ates with full knowledge of the underlying statistical from data.
model. The numerical results also demonstrate that com- The fact that DNNs are used to learn an intermediate
bining model-agnostic DNNs with model-aware inference computation rather than the complete predication rule
notably improves robustness to model uncertainty com- facilitates the usage of relatively compact DNNs. This
pared to applying SP with the inaccurate model. Further- property can be exploited to implement learned inference
more, it also observed that explicitly accounting for the on computationally limited devices, as was done in [97] for


Fig. 16. Experimental results from [91] of learned factor graphs (learned FG) compared to the model-based SP algorithm and the
data-driven sliding bidirectional RNN (SBRNN) of [103]. Perfect CSI implies that the system is trained and tested using samples from the
same channel, while, under CSI uncertainty, they are trained using samples from a set of different channels. (a) Gaussian channel.
(b) Poisson channel.

DNN-aided velocity tracking in autonomous racing cars. computations [21], [23], [105], [106]. An illustration of
An additional consequence is that the resulting system can this approach is depicted in Fig. 17.
be trained using scarce datasets. One can exploit the fact The main advantage of utilizing an external DNN for
that the system can be trained using small training sets correcting internal computations stems from its ability to
to, e.g., enable online adaptation to temporal variations in notably improve the robustness of model-based methods to
the statistical model based on some feedback on the cor- inaccurate knowledge of the underlying model parameters.
rectness of the inference rule. This property was exploited Since the model-based algorithm is individually imple-
in [104] to facilitate online training of DNN-aided receivers mented, one must posses the complete domain knowledge
in coded communications. it requires, and thus, the external correction DNN allows
A DNN integrated into a structure-oriented model-based the resulting system to overcome inaccuracies in this
inference method can be either trained individually, i.e., domain knowledge by learning to correct them from data.
independently of the inference task, or in an end-to-end Furthermore, the learned correction term incorporated
fashion. The first approach typically requires less training by neural augmentation can improve the performance
data, and the resulting trained DNN can be combined with of model-based algorithms in scenarios where they are
various inference algorithms. For instance, the learned suboptimal, as detailed in the example in the sequel.
function node used to carry out SP inference in the above
example can also be integrated into the Viterbi algorithm, 1) Design Outline: The design of neural-augmented
as done in [15]. Alternatively, the learned modules can inference systems is comprised of the following steps.
be tuned end-to-end by formulating their objective as that 1) Choose a suitable iterative optimization algorithm for
of the overall inference algorithm and backpropagating the problem of interest, and identify the informa-
through the model-based computations (see [94]). Learn- tion exchanged between the iterations, along with
ing in an end-to-end fashion facilitates overcoming inaccu- the intermediate computations used to produce this
racies in the assumed structures, possibly by incorporating information.
learned methods to replace the generic computations of 2) The information exchanged between the iterations is
the model-based algorithm, at the cost of requiring larger updated with a correction term learned by a DNN.
volumes of data for training purposes. The DNN is designed to combine the same quantities
used by the model-based algorithm, only in a learned
C. Neural Augmentation fashion.
3) The overall hybrid model-based/data-driven system
The DNN-aided inference strategies detailed in Sec-
is trained in an end-to-end fashion, where one can
tions V-A and V-B utilize model-based algorithms to carry
consider not only the algorithm outputs in the loss
out inference while replacing explicit domain-specific com-
function but also the intermediate outputs of the
putations with dedicated DNNs. An alternative approach,
internal iterations.
referred to as neural augmentation, utilizes the complete
model-based algorithm for inference, i.e., without embed- We next demonstrate how these steps are carried out
ding deep learning into its components, while using an in order to augment Kalman smoothing, as proposed
external DNN for correcting some of its intermediate in [105].


Fig. 17. Neural augmentation illustration.
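In its most generic form, the neural-augmentation pattern of Fig. 17 is a thin wrapper around an unchanged model-based iteration: the external network consumes the intermediate messages and returns an additive correction. The interface, network size, and update form in the sketch below are illustrative assumptions.

import torch
from torch import nn

class NeuralAugmentedIteration(nn.Module):
    """Run a user-supplied model-based step and add a learned correction to its output."""

    def __init__(self, message_dim, state_dim, hidden=32):
        super().__init__()
        self.correction = nn.Sequential(nn.Linear(message_dim, hidden), nn.ReLU(),
                                        nn.Linear(hidden, state_dim))

    def forward(self, model_based_step, s, num_iters=10):
        # model_based_step(s) is assumed to return (updated_estimate, messages),
        # where `messages` collects the intermediate quantities of the algorithm.
        for _ in range(num_iters):
            s_model, messages = model_based_step(s)
            s = s_model + self.correction(messages)
        return s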

2) Example 7: Neural-Augmented Kalman Smoothing: The DNN-aided Kalman smoother proposed in [105] implements state estimation in environments characterized by state-space models. Here, neural augmentation not only robustifies the smoother in the presence of inaccurate model knowledge but also improves its performance in nonlinear setups, where variants of the Kalman algorithm, such as the extended Kalman (E-Kalman) method, may be suboptimal [99, Ch. 7].

a) System model: Consider a linear Gaussian state-space model. Here, one is interested in recovering a sequence of t state RVs {s_i}_{i=1}^{t} taking values in a continuous set from an observed sequence {x_i}_{i=1}^{t}. The observations are related to the desired state sequence via

    x_i = H s_i + r_i        (28a)

while the state transition takes the form

    s_i = F s_{i-1} + w_i.        (28b)

In (28), r_i and w_i obey i.i.d. zero-mean Gaussian distributions with covariance R and W, respectively, while H and F are known linear mappings.

We focus on scenarios where the state-space model in (28) that is available to the inference system is an inaccurate approximation of the true underlying dynamics. For such scenarios, one can apply Kalman smoothing, which is known to achieve minimal MSE recovery when (28) holds, while introducing a neural augmentation correction term [105].

b) Kalman smoothing: The Kalman smoother computes the minimal MSE estimate of each s_i given a realization of x = [x_1, . . . , x_t]^T. Its procedure is comprised of forward and backward message passing, exploiting the Markovian structure of the state-space model to operate at complexity which only grows linearly with t. In particular, by writing s = [s_1, . . . , s_t]^T, one way to implement such smoothing to approach the minimal MSE estimate involves applying gradient descent optimization on the joint log-likelihood function, i.e., by iterating over

    s^(q+1) = s^(q) + η ∇_{s^(q)} log p(x, s^(q))        (29)

where η > 0 is a step size. Leveraging the state-space model (28), one can implement gradient descent iterations as message passing, via the procedure summarized in Algorithm 6, whose detailed formulation is given in Appendix VI-G.

Algorithm 6: Smoothing via Iterative Gradient Descent
Init: Fix step-size η > 0. Set initial guess ŝ^(0)
1: for q = 0, 1, . . . do
2:   for i = 1, . . . , t do
3:     Compute messages
         µ^(q)_{S_{i−1}→S_i} = −W^{−1}(s^(q)_i − F s^(q)_{i−1}),
         µ^(q)_{S_{i+1}→S_i} = F^T W^{−1}(s^(q)_{i+1} − F s^(q)_i),
         µ^(q)_{X_i→S_i} = H^T R^{−1}(x_i − H s^(q)_i).
4:     Update gradient step via
         ŝ^(q+1)_i = ŝ^(q)_i + η (µ^(q)_{S_{i−1}→S_i} + µ^(q)_{S_{i+1}→S_i} + µ^(q)_{X_i→S_i}).
5:   end
6: end
Output: Estimate ŝ = ŝ^(q).
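The following is a minimal NumPy sketch of Algorithm 6 under the linear Gaussian model (28). Variable names and the handling of the boundary indices i = 1 and i = t (where the corresponding messages are dropped) are our own illustrative choices, not the implementation of [105].

```python
import numpy as np

def gradient_descent_smoother(x, F, H, W, R, eta=0.01, num_iter=100):
    """Sketch of Algorithm 6: gradient-descent smoothing for the linear
    Gaussian state-space model (28). x has shape (t, dim_x)."""
    t = x.shape[0]
    dim_s = F.shape[0]
    Winv, Rinv = np.linalg.inv(W), np.linalg.inv(R)
    s = np.zeros((t, dim_s))                 # initial guess s^(0)
    for _ in range(num_iter):
        grad = np.zeros_like(s)
        for i in range(t):
            if i > 0:                        # message from the previous state
                grad[i] += -Winv @ (s[i] - F @ s[i - 1])
            if i < t - 1:                    # message from the next state
                grad[i] += F.T @ Winv @ (s[i + 1] - F @ s[i])
            grad[i] += H.T @ Rinv @ (x[i] - H @ s[i])   # observation message
        s = s + eta * grad                   # gradient step (29)
    return s
```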


Fig. 18. Neural-augmented Kalman smoother illustration. Blocks marked with Z^{-1} represent a single iteration delay.

c) Neural-augmented Kalman smoothing: The gradient descent formulation in (29) is evaluated by the messages in Step 3 of Algorithm 6, which, in turn, rely on accurate knowledge of the state-space model (28). To facilitate operation with inaccurate model knowledge due to, e.g., (28) being a linear approximation of a nonlinear setup, one can introduce neural augmentation to learn to correct inaccurate computations of the log-likelihood gradients. This is achieved by using an external DNN to map the messages in Step 3 into a correction term, denoted ϵ^(q+1).

Architecture: The learned mapping of the messages (28) into a correction term operates in the form of a graph neural network (GNN) [107]. This is implemented by maintaining an internal node variable for each variable in Step 3 of Algorithm 6, denoted h^(q)_{s_i} for each s_i and h_{x_i} for each x_i, as well as internal message variables m^(q)_{V_n→S_i} for each message computed by the model-based Algorithm 6. The node variables h^(q)_{s_i} are updated along with the model-based smoothing algorithm iterations as estimates of their corresponding variables, while the variables h_{x_i} are obtained once from x via a neural network. The GNN then maps the messages produced by the model-based Kalman smoother into its internal messages via a neural network f_e(·), which operates on the corresponding node variables, i.e.,

    m^(q)_{V_n→S_i} = f_e(h^(q)_{v_n}, h^(q)_{s_i}, µ^(q)_{V_n→S_i})        (30)

where h^(q)_{x_n} ≡ h_{x_n} for each q. These messages are then combined and forwarded into a gated recurrent unit (GRU), which produces the refined estimate of the node variables {h^(q+1)_{s_i}} based on their corresponding messages (30). Finally, each updated node variable h^(q+1)_{s_i} is mapped into its corresponding error term ϵ^(q+1)_i via a fourth neural network, denoted f_d(·).

The correction terms {ϵ^(q+1)_i} aggregated into the vector ϵ^(q+1) are used to update the log-likelihood gradients, resulting in the update equation (29) being replaced with

    s^(q+1) = s^(q) + η (∇_{s^(q)} log p(x, s^(q)) + ϵ^(q+1)).        (31)

The overall architecture is illustrated in Fig. 18.

Training: Let θ be the parameters of the GNN in Fig. 18. The hybrid system is trained end-to-end to minimize the empirical weighted ℓ2 norm loss over its intermediate layers, where the contribution of each iteration to the overall loss increases as the iterative procedure progresses. In particular, letting {(s_t, x_t)}_{t=1}^{n_t} be the training set, the loss function used to train the neural-augmented Kalman smoother is given by

    L(θ) = (1/(n_t Q)) Σ_{t=1}^{n_t} Σ_{q=1}^{Q} q ‖s_t − ŝ_q(x_t; θ)‖²        (32)

where ŝ_q(x_t; θ) is the estimate produced by the qth iteration, i.e., via (31), with parameters θ and input x_t.
bined and forwarded into a gated recurrent unit (GRU), ducing the DNN-based correction term allows the system
which produces the refined estimate of the node variables to learn to overcome the model inaccuracy and achieve
(q+1)
{hsi } based on their corresponding messages (30). an error, which decreases with the amount of available
(q+1)
Finally, each updated node variable hsi is mapped into training data. It is also observed that the hybrid approach
(q+1)
its corresponding error term ϵi via a fourth neural of combining model-based inference and deep learning
network, denoted fd (·). enables accurate inference with notably reduced volumes
(q+1)
The correction terms {ϵi } aggregated into the vector of training data, as the individual application of the GNN
ϵ(q+1) are used to update the log-likelihood gradients, for state estimation, which does not explicitly account


Fig. 19. MSE versus dataset size for the neural-augmented Kalman smoother (Hybrid) compared to the model-based E-Kalman smoother and a solely data-driven GNN for various linearizations of state-space models (represented by the index j). Figure reproduced from [105] with authors' permission.

Fig. 19 demonstrates the ability of neural augmentation to improve model-based inference: the hybrid system learns to overcome the model inaccuracy, achieving an error that decreases with the amount of available training data, whereas the individual application of the GNN for state estimation, which does not explicitly account for the available domain knowledge, requires much more training data to achieve similar accuracy as that of the neural-augmented Kalman smoother.

3) Discussion: Neural augmentation implements hybrid model-based/data-driven inference by utilizing two individual modules—a model-based algorithm and a DNN—with each capable of inferring on its own. The rationale here is to benefit from both approaches by interleaving the iterative operation of the modules, and specifically by utilizing the data-driven component to learn to correct the model-based algorithm, rather than produce individual estimates. This approach, thus, conceptually differs from the DNN-aided inference strategies discussed in Sections V-A and V-B, where a DNN is integrated into a model-based algorithm.

The fact that neural augmentation utilizes individual model-based and data-driven modules reflects its requirements and use cases. First, one must possess full domain knowledge, or at least an approximation of the true model, in order to implement model-based inference. For instance, the neural-augmented Kalman smoother requires full knowledge of the state-space model (28), or at least an approximation of this analytical closed-form model as used in the quantitative example, in order to compute the exchanged messages in Algorithm 6. In addition, the presence of an individual DNN module implies that relatively large amounts of data are required in order to train it. Nonetheless, the fact that this DNN only produces a correction term, which is interleaved with the model-based algorithm operation, implies that the amount of training data required to achieve a given accuracy is notably smaller compared to that required when using solely the DNN for inference. For instance, the quantitative example of the neural-augmented Kalman smoother demonstrates that it requires 10–20 times fewer samples compared to that required by the individual GNN to achieve similar MSE results.

VI. CONCLUSION AND FUTURE CHALLENGES

In this article, we presented a mapping of methods for combining domain knowledge and data-driven inference via model-based deep learning in a tutorial manner. We noted that hybrid model-based/data-driven systems can be categorized into model-aided networks, which utilize model-based algorithms to design DNN architectures, and DNN-aided inference, where deep learning is integrated into traditional model-based methods. We detailed representative design approaches for each strategy in a systematic manner, along with design guidelines and concrete examples. To conclude this overview, we first summarize the key advantages of model-based deep learning in Section VI-A. Then, we present guidelines for selecting a design approach for a given application in Section VI-B, intended to facilitate the derivation of future hybrid data-driven/model-based systems. Finally, we review some future research challenges in Section VI-C.

A. Advantages of Model-Based Deep Learning

The combination of traditional handcrafted algorithms with emerging data-driven tools via model-based deep learning brings forth several key advantages. Compared to purely model-based schemes, the integration of deep learning facilitates inference in complex environments, where accurately capturing the underlying model in a closed-form mathematical expression may be infeasible. For instance, incorporating DNN-based implicit regularization was shown to enable CS beyond its traditional domain of sparse signals, as discussed in Section V-A, while the implementation of the SIC method as an interconnection of neural building blocks enables its operation in nonlinear setups, as demonstrated in Section IV-B. The model-agnostic nature of deep learning also allows hybrid model-based/data-driven inference to achieve improved resiliency to model uncertainty compared to inferring solely based on domain knowledge. For example, augmenting model-based Kalman smoothing with a GNN was shown in Section V-C to notably improve its performance when the state-space model does not fully reflect the true dynamics, while the usage of learned factor graphs for SP inference was demonstrated to result in improved robustness to model uncertainty in Section V-B. Finally, the fact that hybrid systems learn to carry out part of their inference based on data allows inferring with a reduced delay compared to the corresponding fully model-based methods, as demonstrated by deep unfolding in Section IV-A.


Compared to utilizing conventional DNN architectures for inference, the incorporation of domain knowledge via a hybrid model-based/data-driven design results in systems that are tailored to the problem at hand. As a result, model-based deep learning systems require notably fewer data in order to learn an accurate mapping, as demonstrated in the comparison of learned factor graphs and the sliding bidirectional RNN system in the quantitative example in Section V-B, as well as the comparison between the neural-augmented Kalman smoother and the GNN state estimator in the corresponding example in Section V-C. This property of model-based deep learning systems enables quick adaptation to variations in the underlying statistical model, as shown in [104]. Finally, a system combining DNNs with model-based inference often provides the ability to analyze its resulting predictions, yielding interpretability and confidence which are commonly challenging to obtain with conventional black-box deep learning.

B. Choosing a Model-Based Deep Learning Strategy

The aforementioned gains of model-based deep learning are shared at some level by all the different approaches presented in Sections IV and V. However, each strategy is focused on exploiting a different advantage of hybrid model-based/data-driven inference, particularly in the context of signal processing-oriented applications. Consequently, to complement the mapping of model-based deep learning strategies and facilitate the implementation of future application-specific hybrid systems, we next enlist the main considerations that one should take into account when seeking to combine model-based methods with data-driven tools for a given problem.

Step 1 (Domain Knowledge and Data Characterization): First, one must ensure the availability of the two key ingredients in model-based deep learning, i.e., domain knowledge and data. The former corresponds to what is known a priori about the problem at hand, in terms of statistical models and established assumptions, as well as what is unknown, or is based on some approximation that is likely to be inaccurate. The latter addresses the amount of labeled and unlabeled samples that one possesses in advance for the considered problem, as well as whether or not they reflect the scenario in which the system is requested to infer in practice.

Step 2 (Identifying a Model-Based Method): Based on the available domain knowledge, the next step is to identify a suitable model-based algorithm for the problem. This choice should rely on the portion of the domain knowledge which is available, and not on what is unknown, as the latter can be compensated for by the integration of deep learning tools. This stage must also consider the requirements of the inference system in terms of performance, complexity, and real-time operation, as these are encapsulated in the selection of the algorithm.

The identification of a model-based algorithm, combined with the availability of domain knowledge and data, should also indicate whether model-based deep learning mechanisms are required for the application of interest.

Step 3 (Implementation Challenges): Having identified a suitable model-based algorithm, the selection of the approach to combine it with deep learning should be based on the understanding of its main implementation challenges. Some representative issues and their relationship with the recommended model-based deep learning approaches include the following.

1) Missing domain knowledge: Model-based deep learning can implement the model-based inference algorithm when parts of the underlying model are unknown, or alternatively, too complex to be captured analytically, by harnessing the model-agnostic nature of deep learning. In this case, the selection of the implementation approach depends on the format of the identified model-based algorithm. When it builds upon some known structures via, e.g., message-passing-based inference, structure-oriented DNN-aided inference detailed in Section V-B can be most suitable as means of integrating DNNs to enable operation with missing domain knowledge. Similarly, when the missing domain knowledge can be represented as some complex search domain, or alternatively, an unknown and possibly intractable regularization term, structure-agnostic DNN-aided inference detailed in Section V-A can typically facilitate optimization with implicitly learned regularizers. Finally, when the algorithm can be represented as an interconnection of model-dependent building blocks, one can maintain the overall flow of the algorithm while operating in a model-agnostic manner via neural building blocks, as discussed in Section IV-B.

2) Inaccurate domain knowledge: Model-based algorithms are typically sensitive to inaccurate knowledge of the underlying model and its parameters. In such cases, where one has access to a complete description of the underlying model up to some uncertainty, model-based deep learning can robustify the model-based algorithm and learn to achieve improved accuracy. A candidate approach to robustify model-based processing is by adding a learned correction term via neural augmentation, as detailed in Section V-C. Alternatively, when the model-based algorithm takes an iterative form, improved resiliency can be obtained by unfolding the algorithm into a DNN, as discussed in Section IV-A, as well as by using robust optimization in unfolding [108].

3) Inference speed: Model-based deep learning can learn to implement iterative inference algorithms, which typically require a large number of iterations to converge, with reduced inference runtime. This is achieved by designing model-aided networks, typically via deep unfolding (see Section IV-A) or neural building blocks (see Section IV-B).


The fact that model-aided networks learn their iterative computations from data allows the resulting system to infer reliably with a much smaller number of iteration-equivalent layers compared to the iterations required by the model-based algorithm. Alternatively, when the delaying aspect is an internal lengthy computation, one can improve run time by replacing it with a fixed run-time DNN via DNN-aided inference, as shown in, e.g., [109].

The aforementioned implementation challenges constitute only a partial list of the considerations that one should account for when selecting a model-based deep learning design approach. Additional considerations include computational capabilities during both training and inference; the need to handle variations in the statistical model, which, in turn, translate to a possible requirement to periodically re-train the system; and the quantity and the type of available data. Nonetheless, the above division provides systematic guidelines that one can utilize and possibly extend when seeking to implement an inference system relying on both data and domain knowledge. Finally, we note that some of the detailed model-based deep learning strategies can be combined, and thus, one can select more than a single design approach. For instance, one can interleave DNN-aided inference via implicitly learned regularization and/or priors, with deep unfolding of the iterative optimization algorithm, as discussed in Section V-A.

C. Future Research Directions

We end by discussing a few representative unexplored research aspects of model-based deep learning.

1) Performance Guarantees: One of the key strengths of model-based algorithms is their established theoretical performance guarantees. In particular, the analytical tractability of model-based methods implies that one can quantify their expected performance as a function of the parameters of underlying statistical or deterministic models. For conventional deep learning, such performance guarantees are very challenging to characterize, and deeper theoretical understanding is a crucial missing component. The combination of deep learning with model-based structure increases interpretability, thus possibly leading to theoretical guarantees. Theoretical guarantees improve the reliability of hybrid model-based/data-driven systems, as well as their performance. For example, some preliminary theoretical results were identified for specific model-based deep learning methods, such as the convergence analysis of the unfolded LISTA in [110] and plug-and-play networks in [82].

2) Deep Learning Algorithms: Improving model interpretability and incorporating human knowledge are crucial for artificial intelligence development. Model-based deep learning can constitute a systematic framework to incorporate domain knowledge into data-driven systems and can, thus, give rise to new forms of deep learning algorithms. For instance, while our description of the methodologies in Sections IV and V systematically commences with a model-based algorithm, which is then augmented into a data-aided design via deep learning techniques, one can also envision algorithms in which model-based algorithms are utilized to improve upon an existing DNN architecture. Alternatively, one can leverage model-based techniques to propose interpretable DNN architectures that follow traditional model-based methods to account for domain knowledge.

3) Collaborative Model-Based Deep Learning: The increasing demands for accessible and personalized artificial intelligence give rise to the need to operate DNNs on edge devices such as smartphones, sensors, and autonomous cars [6]. The limited computational and data resources of edge devices make model-based deep learning strategies particularly attractive for edge intelligence. Latency considerations and privacy constraints for mobile and sensitive data are further driving research in distributed training (e.g., through the framework of federated learning [111], [112]) and collaborative inference [113]. Combining model-based structures with federated learning and distributed inference remains an interesting research direction.

4) Unexplored Applications: The increasing interest in hybrid model-based/data-driven deep learning methods is motivated by the need for robustness and structural understanding. Applications falling under the broad family of signal processing, communications, and control problems are natural candidates to benefit due to the proliferation of established model-based algorithms. We believe that model-based deep learning can contribute to the development of technologies such as IoT networks, autonomous systems, and wireless communications.

APPENDIX

A. Detailed Formulation of Projected Gradient Descent (Example 1, Section IV)

Projected gradient descent iteratively refines its estimate by taking a gradient step with respect to the unconstrained objective, followed by projection into the constrained set of the optimization variable. For the system model in (8), this operation at iteration index q + 1 is obtained recursively as

    ŝ_{q+1} = P_S( ŝ_q − η ∂‖x − Hs‖²/∂s |_{s=ŝ_q} )
            = P_S( ŝ_q + ηH^T x − ηH^T H ŝ_q )        (A.1)

where η is the step size and ŝ_0 is an initial guess.
fied for specific model-based deep learning methods, such  
= PS ŝq − ηH T x + ηH T Hŝq (A.1)
as the convergence analysis of the unfolded LISTA in [110]
and plug-and-play networks in [82].
where η is the step size and ŝ0 is an initial guess.
2) Deep Learning Algorithms: Improving model inter-
pretability and incorporating human knowledge are crucial
for artificial intelligence development. Model-based deep B. Detailed Formulation of Proximal Gradient
learning can constitute a systematic framework to incorpo- Method (Example 2, Section IV)
rate domain knowledge into data-driven systems and can, The recovery of the clean image µ, which can be rep-
thus, give rise to new forms of deep learning algorithms. resented using a convolutional dictionary from the noisy
For instance, while our description of the methodologies observations x, can be formulated as a convolutional


    ŝ, Ĥ = arg min_{s,H} − log p_{x|µ}(x | µ = Hs) + λ‖s‖_1
         = arg min_{s,H} 1^T exp(Hs) − x^T Hs + λ‖s‖_1        (B.1)

where the dictionary optimization variable is constrained to be block-Toeplitz. The clean image is then obtained as

    µ̂ = exp(Ĥ ŝ)        (B.2)

where 1 is the all-ones vector and λ is a regularizing term that controls the degree of sparsity, boosted by the usage of the ℓ1 norm.

Algorithm 2 tackles (B.1) via alternating optimization, where the update equations at the iteration of index l are given by

    ŝ_{l+1} = arg min_s 1^T exp(Hs) − x^T Hs + λ‖s‖_1
              s.t. H = Ĥ_l        (B.3)

and

    Ĥ_{l+1} = arg min_H 1^T exp(Hs) − x^T Hs
              s.t. s = ŝ_{l+1}.        (B.4)

The ℓ1 regularized optimization problem (B.3) can be tackled for a given H and index l via proximal gradient descent iterations. This optimizer involves multiple iterations, indexed q = 0, 1, 2, . . ., of the form

    ŝ_{q+1} = T_b( ŝ_q + ηH^T (x − exp(H ŝ_q)) ).        (B.5)

The threshold parameter b is dictated by the regularization parameter λ.
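A minimal NumPy sketch of the iterations in (B.5) for a fixed dictionary H is given below, assuming T_b is the standard soft-thresholding operator (the usual choice for an ℓ1 regularizer); parameter values are illustrative only.

```python
import numpy as np

def soft_threshold(v, b):
    # Soft-thresholding operator T_b, the proximal mapping of b*||.||_1.
    return np.sign(v) * np.maximum(np.abs(v) - b, 0.0)

def proximal_gradient_sparse_code(x, H, b, eta=0.01, num_iter=100):
    """Sketch of (B.5): proximal gradient iterations for the sparse code s
    under the exponential-family data term in (B.3), with H held fixed."""
    s = np.zeros(H.shape[1])
    for _ in range(num_iter):
        grad = H.T @ (np.exp(H @ s) - x)       # gradient of 1^T exp(Hs) - x^T Hs
        s = soft_threshold(s - eta * grad, b)  # gradient step + thresholding
    return s
```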
j j
parameter λ.

and the overall estimate is set to ŝ = [ŝ1 , . . . , ŝK ].


D. Detailed Formulation of Sparsity-Based CS (Example 4, Section V)

Consider the case where s* is sparse in some dictionary B, e.g., in the WVT domain, such that s* = Bc*, where ‖c*‖_0 = l with l ≪ N. In this case, the goal is to find the sparsest c such that s = Bc agrees with the noisy observations

    minimize ‖c‖_0
    s.t. ‖HBc − x‖_2 ≤ ϵ


where ϵ is a noise threshold. Since one can define H̃ := HB, we henceforth focus on the setting where B is the identity matrix, and the optimization variable of the above ℓ0 norm optimization problem is s. Although the above problem is NP-hard, Candès et al. [114] and Donoho [115] showed that it suffices to minimize the ℓ1 relaxed LASSO objective in (18). The formulation (18) is convex, and for Gaussian A with l = ‖s*‖_0 and M = Θ(l log(N/l)), the unique minimizer of L_LASSO is equal to s* with high probability.

E. Detailed Formulation of ADMM (Example 5, Section V)

ADMM tackles the optimization problem in (21) by utilizing variable splitting. Namely, it introduces an additional auxiliary variable v in order to decouple the regularizer ϕ(s) from the likelihood term ‖x − Hs‖². The resulting formulation of (21) is expressed as

    ŝ = arg min_s min_v (1/2)‖x − Hs‖² + ϕ(v)        (E.1)
    s.t. v = s.        (E.2)

Problem (E.1) is then solved by formulating the augmented Lagrangian (which introduces an additional optimization variable u) and solving it in an alternating fashion. This results in the following update equations for the qth iteration [82]:

    ŝ_{q+1} = arg min_s (α/2)‖x − Hs‖² + (1/2)‖s − (v_q − u_q)‖²        (E.3a)
    v_{q+1} = arg min_v αϕ(v) + (1/2)‖v − (ŝ_{q+1} + u_q)‖²        (E.3b)
    u_{q+1} = u_q + (ŝ_{q+1} − v_{q+1}).        (E.3c)

Here, α > 0 is an optimization hyperparameter. Steps (E.3a) and (E.3b) are the proximal mappings with respect to the functions αf(·) and αϕ(·), respectively, with f(v) ≜ (1/2)‖x − Hv‖². Step (E.3c) represents a gradient ascent iteration. For brevity, in Algorithm 4, we write (E.3a) as ŝ_{q+1} = Prox_{αf}(v_q − u_q) and (E.3b) as v_{q+1} = Prox_{αϕ}(ŝ_{q+1} + u_q). In particular, it is noted that (E.3a) equals ŝ_{q+1} = (αH^T H + I)^{−1}(αH^T x + (v_q − u_q)).
is replaced with γi µ si (s) for some scale factor, which
sq+1 = (αH T H + I)−1 (αH T x + (v q − uq )). does not depend on s and, thus, does not affect the
MAP rule.
F. Detailed Formulation of Sum-Product Method
(Example 6, Section V) G. Detailed Formulation of Iterative Kalman
Smoother (Example 7, Section V)
To formulate the SP method, the factorizable distribu-
tion (24) is first represented as a factor graph. To that The state-space model (28) implies that the joint distri-
aim, we recall the definitions of the vector variable si = bution of the state and observations satisfies
sii−l+1 ∈ S l and the function f (xi , si , si−1 ) in (26). Y
When si is a shifted version of si−1 , (26) coincides p (x, s) = p (x|s) p (s) = p(xt |st )p(st |st−1 ). (G.1)
1
with p xi |sii−l p si |si−
 
i−l and equals zero otherwise. t


G. Detailed Formulation of Iterative Kalman Smoother (Example 7, Section V)

The state-space model (28) implies that the joint distribution of the state and observations satisfies

    p(x, s) = p(x|s) p(s) = Π_t p(x_t | s_t) p(s_t | s_{t−1}).        (G.1)

Consequently, it holds that

    ∂/∂s_t log p(x, s)
      = ∂/∂s_t ( Σ_τ log p(x_τ | s_τ) + Σ_τ log p(s_τ | s_{τ−1}) )
      = ∂/∂s_t log p(x_t | s_t) + ∂/∂s_t log p(s_t | s_{t−1}) + ∂/∂s_t log p(s_{t+1} | s_t)
      = −(1/2) ∂/∂s_t ( (x_t − H s_t)^T R^{−1}(x_t − H s_t)
                        + (s_t − F s_{t−1})^T W^{−1}(s_t − F s_{t−1})
                        + (s_{t+1} − F s_t)^T W^{−1}(s_{t+1} − F s_t) )
      = H^T R^{−1}(x_t − H s_t) − W^{−1}(s_t − F s_{t−1}) + F^T W^{−1}(s_{t+1} − F s_t).        (G.2)

Therefore, the tth entry of the log-likelihood gradient in (29), abbreviated henceforth as ∇_t^(q), can be obtained as ∇_t^(q) = µ^(q)_{S_{t−1}→S_t} + µ^(q)_{S_{t+1}→S_t} + µ^(q)_{X_t→S_t}, where the summands, referred to as messages, are given by

    µ^(q)_{S_{t−1}→S_t} = −W^{−1}(s_t^(q) − F s_{t−1}^(q))        (G.3a)
    µ^(q)_{S_{t+1}→S_t} = F^T W^{−1}(s_{t+1}^(q) − F s_t^(q))        (G.3b)
    µ^(q)_{X_t→S_t} = H^T R^{−1}(x_t − H s_t^(q)).        (G.3c)

The iterative procedure in (29) is repeated until convergence, as stated in Algorithm 6, and the resulting s^(q) is used as the estimate. ■

REFERENCES
[1] Y. LeCun, Y. Bengio, and G. Hinton, “Deep IEEE Trans. Wireless Commun., vol. 19, no. 5, deep learning,” in Proc. Int. Conf. Learn.
learning,” Nature, vol. 521, no. 7553, p. 436, pp. 3319–3331, May 2020. Represent., 2018.
Feb. 2015. [16] N. Shlezinger, R. Fu, and Y. C. Eldar, “DeepSIC: [29] M. B. Mashhadi, Q. Yang, and D. Gunduz,
[2] K. He, X. Zhang, S. Ren, and J. Sun, “Delving deep Deep soft interference cancellation for multiuser “Distributed deep convolutional compression for
into rectifiers: Surpassing human-level MIMO detection,” IEEE Trans. Wireless Commun., massive MIMO CSI feedback,” IEEE Trans. Wireless
performance on ImageNet classification,” in Proc. vol. 20, no. 2, pp. 1349–1362, Feb. 2021. Commun., vol. 20, no. 4, pp. 2621–2633,
IEEE Int. Conf. Comput. Vis. (ICCV), Dec. 2015, [17] E. Nachmani, E. Marciano, L. Lugosch, Apr. 2021.
pp. 1026–1034. W. J. Gross, D. Burshtein, and Y. Be’ery, “Deep [30] S. Shalev-Shwartz and S. Ben-David,
[3] D. Silver et al., “Mastering the game of go without learning methods for improved decoding of linear Understanding Machine Learning: From Theory to
human knowledge,” Nature, vol. 550, no. 7676, codes,” IEEE J. Sel. Topics Signal Process., vol. 12, Algorithms. Cambridge, U.K.: Cambridge Univ.
pp. 354–359, 2017. no. 1, pp. 119–131, Feb. 2018. Press, 2014.
[4] O. Vinyals et al., “Grandmaster level in StarCraft II [18] N. Samuel, T. Diskin, and A. Wiesel, “Learning to [31] C. Metzler, A. Mousavi, and R. Baraniuk, “Learned
using multi-agent reinforcement learning,” detect,” IEEE Trans. Signal Process., vol. 67, D-AMP: Principled neural network based
Nature, vol. 575, no. 7782, pp. 350–354, 2019. no. 10, pp. 2554–2564, May 2019. compressive image recovery,” in Proc. Adv. Neural
[5] Y. Bengio, “Learning deep architectures for AI,” [19] H. He, C.-K. Wen, S. Jin, and G. Y. Li, Inf. Process. Syst., 2017, pp. 1772–1783.
Found. Trends Mach. Learn., vol. 2, no. 1, “Model-driven deep learning for MIMO [32] I. Goodfellow, Y. Bengio, and A. Courville, Deep
pp. 1–127, 2009. detection,” IEEE Trans. Signal Process., vol. 68, Learning. Cambridge, MA, USA: MIT Press, 2016.
[6] J. Chen and X. Ran, “Deep learning with edge pp. 1702–1715, 2020. [33] S. Hochreiter and J. Schmidhuber, “Long
computing: A review,” Proc. IEEE, vol. 107, no. 8, [20] M. Khani, M. Alizadeh, J. Hoydis, and P. Fleming, short-term memory,” Neural Comput., vol. 9, no. 8,
pp. 1655–1674, Aug. 2019. “Adaptive neural signal detection for massive pp. 1735–1780, 1997.
[7] V. Monga, Y. Li, and Y. C. Eldar, “Algorithm MIMO,” IEEE Trans. Wireless Commun., vol. 19, [34] A. Vaswani et al., “Attention is all you need,” in
unrolling: Interpretable, efficient deep learning for no. 8, pp. 5635–5648, Aug. 2020. Proc. Adv. Neural Inf. Process. Syst., 2017,
signal and image processing,” IEEE Signal Process. [21] K. Pratik, B. D. Rao, and M. Welling, “RE-MIMO: pp. 5998–6008.
Mag., vol. 38, no. 2, pp. 18–44, Mar. 2021. Recurrent and permutation equivariant neural [35] Y. LeCun and Y. Bengio, “Convolutional networks
[8] K. Gregor and Y. LeCun, “Learning fast MIMO detection,” IEEE Trans. Signal Process., for images, speech, and time series,” in The
approximations of sparse coding,” in Proc. 27th vol. 69, pp. 459–473, 2021. Handbook of Brain Theory and Neural Networks,
Int. Conf. Mach. Learn., 2010, pp. 399–406. [22] N. Farsad, N. Shlezinger, A. J. Goldsmith, and vol. 3361, no. 10. MIT Press, 1995, p. 1995.
[9] S. Wu et al., “Learning a compressed sensing Y. C. Eldar, “Data-driven symbol detection via [36] T. Tieleman and G. Hinton, “Lecture
measurement matrix via gradient unrolling,” in model-based machine learning,” Commun. Inf. 6.5-RMSPROP: Divide the gradient by a running
Proc. Int. Conf. Mach. Learn., 2019, Syst., vol. 20, no. 3, pp. 283–317, 2020. average of its recent magnitude,” COURSERA,
pp. 6828–6839. [23] V. G. Satorras and M. Welling, “Neural enhanced Neural Netw. Mach. Learn., vol. 4, no. 2,
[10] A. Bora, A. Jalal, E. Price, and A. G. Dimakis, belief propagation on factor graphs,” in Proc. Int. pp. 26–31, 2012.
“Compressed sensing using generative models,” in Conf. Artif. Intell. Statist., 2021, pp. 685–693. [37] D. P. Kingma and J. Ba, “Adam: A method for
Proc. 34th Int. Conf. Mach. Learn. (JMLR), vol. 70, [24] A. Zappone, M. D. Renzo, M. Debbah, T. T. Lam, stochastic optimization,” in Proc. Int. Conf. Learn.
2017, pp. 537–546. and X. Qian, “Model-aided wireless artificial Represent., 2015.
[11] J. Whang, Q. Lei, and A. G. Dimakis, “Compressed intelligence: Embedding expert knowledge in [38] A. Krizhevsky, I. Sutskever, and G. E. Hinton,
sensing with invertible generative models and deep neural networks for wireless system “ImageNet classification with deep convolutional
dependent noise,” in Proc. Deep Learn. Inverse optimization,” IEEE Veh. Technol. Mag., vol. 14, neural networks,” Commun. ACM, vol. 60, no. 2,
Problems NeurIPS Workshop, 2021. no. 3, pp. 60–69, Sep. 2019. pp. 84–90, Jun. 2012.
[12] D. Gilton, G. Ongie, and R. Willett, “Neumann [25] A. Zappone, M. D. Renzo, and M. Debbah, [39] I. Goodfellow et al., “Generative adversarial nets,”
networks for linear inverse problems in imaging,” “Wireless networks design in the era of deep in Proc. Adv. Neural Inf. Process. Syst., 2014,
IEEE Trans. Comput. Imag., vol. 6, pp. 328–343, learning: Model-based, AI-based, or both?” IEEE pp. 2672–2680.
2020. Trans. Commun., vol. 67, no. 10, pp. 7331–7376, [40] M. Arjovsky, S. Chintala, and L. Bottou,
[13] S. V. Venkatakrishnan, C. A. Bouman, and Oct. 2019. “Wasserstein generative adversarial networks,” in
B. Wohlberg, “Plug-and-play priors for model [26] L. Liang, H. Ye, G. Yu, and G. Y. Li, Proc. Int. Conf. Mach. Learn., 2017, pp. 214–223.
based reconstruction,” in Proc. IEEE Global Conf. “Deep-learning-based wireless resource allocation [41] I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin,
Signal Inf. Process., Dec. 2013, pp. 945–948. with application to vehicular networks,” Proc. and A. C. Courville, “Improved training of
[14] H. K. Aggarwal, M. P. Mani, and M. Jacob, “Modl: IEEE, vol. 108, no. 2, pp. 341–356, Feb. 2020. Wasserstein GANs,” in Proc. Adv. Neural Inf.
Model-based deep learning architecture for [27] T. O’Shea and J. Hoydis, “An introduction to deep Process. Syst., vol. 30, 2017, pp. 5767–5777.
inverse problems,” IEEE Trans. Med. Imag., vol. 38, learning for the physical layer,” IEEE Trans. Cogn. [42] X. Mao, Q. Li, H. Xie, R. Y. K. Lau, Z. Wang, and
no. 2, pp. 394–405, Feb. 2019. Commun. Netw., vol. 3, no. 4, pp. 563–575, S. P. Smolley, “Least squares generative
[15] N. Shlezinger, N. Farsad, Y. C. Eldar, and Dec. 2017. adversarial networks,” in Proc. IEEE Int. Conf.
A. J. Goldsmith, “ViterbiNet: A deep learning [28] H. Kim, Y. Jiang, R. Rana, S. Kannan, S. Oh, and Comput. Vis. (ICCV), Oct. 2017,
based Viterbi algorithm for symbol detection,” P. Viswanath, “Communication algorithms via pp. 2794–2802.


[43] J. H. Lim and J. C. Ye, “Geometric GAN,” 2017, approach,” IEEE Access, vol. 7, pp. 93326–93338, vol. 40, no. 1, pp. 120–145, 2011.
arXiv:1705.02894. 2019. [80] R. Ahmad et al., “Plug-and-play methods for
[44] A. Jolicoeur-Martineau, “The relativistic [63] Q. Hu, Y. Cai, Q. Shi, K. Xu, G. Yu, and Z. Ding, magnetic resonance imaging: Using denoisers for
discriminator: A key element missing from “Iterative algorithm induced deep-unfolding image recovery,” IEEE Signal Process. Mag.,
standard GAN,” in Proc. Int. Conf. Learn. neural networks: Precoding design for multiuser vol. 37, no. 1, pp. 105–116, Jan. 2020.
Represent., 2019. MIMO systems,” IEEE Trans. Wireless Commun., [81] K. Zhang, W. Zuo, S. Gu, and L. Zhang, “Learning
[45] T. Karras, S. Laine, M. Aittala, J. Hellsten, vol. 20, no. 2, pp. 1394–1410, Feb. 2021. deep CNN denoiser prior for image restoration,”
J. Lehtinen, and T. Aila, “Analyzing and improving [64] S. Khobahi, N. Shlezinger, M. Soltanalian, and in Proc. IEEE Conf. Comput. Vis. Pattern Recognit.
the image quality of StyleGAN,” in Proc. IEEE/CVF Y. C. Eldar, “LoRD-Net: Unfolded deep detection (CVPR), Jul. 2017, pp. 3929–3938.
Conf. Comput. Vis. Pattern Recognit. (CVPR), network with low-resolution receivers,” IEEE [82] E. Ryu, J. Liu, S. Wang, X. Chen, Z. Wang, and
Jun. 2020, pp. 8110–8119. Trans. Signal Process., vol. 69, pp. 5651–5664, W. Yin, “Plug-and-play methods provably converge
[46] J. E. Van Engelen and H. H. Hoos, “A survey on 2021. with properly trained denoisers,” in Proc. Int.
semi-supervised learning,” Mach. Learn., vol. 109, [65] M. Mischi, M. A. L. Bell, R. J. van Sloun, and Conf. Mach. Learn., 2019, pp. 5546–5557.
no. 2, pp. 373–440, 2020. Y. C. Eldar, “Deep learning in medical [83] S. Ono, “Primal-dual plug-and-play image
[47] D.-H. Lee, “Pseudo-label: The simple and efficient ultrasound—From image formation to image restoration,” IEEE Signal Process. Lett., vol. 24,
semi-supervised learning method for deep neural analysis,” IEEE Trans. Ultrason., Ferroelectr., Freq. no. 8, pp. 1108–1112, Aug. 2017.
networks,” in Proc. Workshop Challenges Represent. Control, vol. 67, no. 12, pp. 2477–2480, [84] U. S. Kamilov, H. Mansour, and B. Wohlberg,
Learn., ICML, 2013, vol. 3, no. 2, p. 896. Dec. 2020. “A plug-and-play priors approach for solving
[48] S. Laine and T. Aila, “Temporal ensembling for [66] G. Dardikman-Yoffe and Y. C. Eldar, “Learned nonlinear imaging inverse problems,” IEEE Signal
semi-supervised learning,” in Proc. Int. Conf. SPARCOM: Unfolded deep super-resolution Process. Lett., vol. 24, no. 12, pp. 1872–1876,
Learn. Represent., 2016. microscopy,” Opt. Exp., vol. 28, no. 19, Dec. 2017.
[49] D. Berthelot, N. Carlini, I. Goodfellow, pp. 4797–4812, 2020. [85] T. Meinhardt, M. Moeller, C. Hazirbas, and
N. Papernot, A. Oliver, and C. A. Raffel, [67] K. Zhang, L. Van Gool, and R. Timofte, “Deep D. Cremers, “Learning proximal operators: Using
“MixMatch: A holistic approach to unfolding network for image super-resolution,” in denoising networks for regularizing inverse
semi-supervised learning,” in Proc. Adv. Neural Inf. Proc. IEEE/CVF Conf. Comput. Vis. Pattern imaging problems,” in Proc. IEEE Int. Conf.
Process. Syst., 2019, pp. 5050–5060. Recognit. (CVPR), Jun. 2020, pp. 3217–3226. Comput. Vis. (ICCV), Oct. 2017, pp. 1781–1790.
[50] Q. Xie, M.-T. Luong, E. Hovy, and Q. V. Le, [68] Y. Huang, S. Li, L. Wang, and T. Tan, “Unfolding [86] A. Radford, L. Metz, and S. Chintala,
“Self-training with noisy student improves the alternating optimization for blind super “Unsupervised representation learning with deep
ImageNet classification,” in Proc. IEEE/CVF Conf. resolution,” in Proc. Adv. Neural Inf. Process. Syst., convolutional generative adversarial networks,”
Comput. Vis. Pattern Recognit. (CVPR), Jun. 2020, vol. 33, 2020, pp. 5632–5643. 2015, arXiv:1511.06434.
pp. 10687–10698. [69] A. Agarwal, A. Anandkumar, P. Jain, P. Netrapalli, [87] Z. Liu, P. Luo, X. Wang, and X. Tang, “Deep
[51] B. Tolooshams, A. H. Song, S. Temereanca, and and R. Tandon, “Learning sparsely used learning face attributes in the wild,” in Proc. IEEE
D. Ba, “Convolutional dictionary learning based overcomplete dictionaries via alternating Int. Conf. Comput. Vis. (ICCV), Dec. 2015,
auto-encoders for natural exponential-family minimization,” SIAM J. Optim., vol. 26, no. 4, pp. 3730–3738.
distributions,” in Proc. Int. Conf. Mach. Learn., pp. 2775–2799, 2016. [88] D. P. Kingma and M. Welling, “Auto-encoding
2020, pp. 9493–9503. [70] T. Remez, O. Litany, R. Giryes, and A. M. variational Bayes,” 2013, arXiv:1312.6114.
[52] L. Xu and R. Niu, “EKFNet: Learning system noise Bronstein, “Class-aware fully convolutional [89] Y. LeCun and C. Cortes. (2010). MNIST
statistics from measurement data,” in Proc. IEEE Gaussian and Poisson denoising,” IEEE Trans. Handwritten Digit Database. [Online]. Available:
Int. Conf. Acoust., Speech Signal Process. (ICASSP), Image Process., vol. 27, no. 11, pp. 5707–5722, http://yann.lecun.com/exdb/mnist/
Jun. 2021, pp. 4560–4564. Nov. 2018.
[90] K. Zhang, W. Zuo, and L. Zhang, “FFDNet: Toward
[53] N. Shlezinger, Y. C. Eldar, and S. P. Boyd, [71] J. Duan et al., “VS-Net: Variable splitting network a fast and flexible solution for CNN-based image
“Model-based deep learning: On the intersection for accelerated parallel MRI reconstruction,” in denoising,” IEEE Trans. Image Process., vol. 27,
of deep learning and optimization,” IEEE Access, Proc. Int. Conf. Med. Image Comput. no. 9, pp. 4608–4622, Sep. 2018.
vol. 10, pp. 115384–115398, 2022. Comput.-Assist. Intervent. Cham, Switzerland:
[91] N. Shlezinger, N. Farsad, Y. C. Eldar, and
[54] A. Ng and M. Jordan, “On discriminative vs. Springer, 2019, pp. 713–722.
A. J. Goldsmith, “Data-driven factor graphs for
generative classifiers: A comparison of logistic [72] J. P. Merkofer, G. Revach, N. Shlezinger, and deep symbol detection,” in Proc. IEEE Int. Symp.
regression and naive Bayes,” in Proc. Adv. Neural R. J. G. van Sloun, “Deep augmented music Inf. Theory (ISIT), Jun. 2020, pp. 2682–2687.
Inf. Process. Syst., vol. 14, 2001, pp. 841–848. algorithm for data-driven doa estimation,” in Proc.
[92] A. Arnab et al., “Conditional random fields meet
[55] N. Shlezinger and T. Routtenberg, “Discriminative IEEE Int. Conf. Acoust., Speech Signal Process.
deep neural networks for semantic segmentation:
and generative learning for linear estimation of (ICASSP), May 2022, pp. 3598–3602.
Combining probabilistic graphical models with
random signals [lecture notes],” 2022, [73] T. Van Luong, N. Shlezinger, C. Xu, T. M. Hoang, deep learning for structured prediction,” IEEE
arXiv:2206.04432. Y. C. Eldar, and L. Hanzo, “Deep learning based Signal Process. Mag., vol. 35, no. 1, pp. 37–52,
[56] J. R. Hershey, J. L. Roux, and F. Weninger, “Deep successive interference cancellation for the Jan. 2018.
unfolding: Model-based inspiration of novel deep non-orthogonal downlink,” IEEE Trans. Veh.
[93] S. Chandra and I. Kokkinos, “Fast, exact and
architectures,” 2014, arXiv:1409.2574. Technol., vol. 71, no. 11, pp. 11876–11888,
multi-scale inference for semantic image
Nov. 2022.
[57] Y. Li, M. Tofighi, J. Geng, V. Monga, and segmentation with deep Gaussian CRFs,” in Proc.
Y. C. Eldar, “Efficient and interpretable deep blind [74] M. Kocaoglu, C. Snyder, A. G. Dimakis, and Eur. Conf. Comput. Vis. Cham, Switzerland:
image deblurring via algorithm unrolling,” IEEE S. Vishwanath, “CausalGAN: Learning causal Springer, 2016, pp. 402–418.
Trans. Comput. Imag., vol. 6, pp. 666–681, implicit generative models with adversarial
[94] P. Knobelreiter, C. Sormann, A. Shekhovtsov,
2020. training,” in Proc. Int. Conf. Learn. Represent.,
F. Fraundorfer, and T. Pock, “Belief propagation
2018.
[58] O. Solomon et al., “Deep unfolded robust PCA reloaded: Learning BP-layers for labeling
with application to clutter suppression in [75] W.-J. Choi, K.-W. Cheong, and J. M. Cioffi, problems,” in Proc. IEEE/CVF Conf. Comput. Vis.
ultrasound,” IEEE Trans. Med. Imag., vol. 39, “Iterative soft interference cancellation for Pattern Recognit. (CVPR), Jun. 2020,
no. 4, pp. 1051–1063, Apr. 2020. multiple antenna systems,” in Proc. IEEE Wireless pp. 7900–7909.
Commun. Netw. Conf. Conf. Rec., Sep. 2000,
[59] Y. Cui, S. Li, and W. Zhang, “Jointly sparse signal [95] B. Luijten et al., “Adaptive ultrasound
pp. 304–309.
recovery and support recovery via deep learning beamforming using deep learning,” IEEE Trans.
with applications in MIMO-based grant-free [76] G. Ongie, A. Jalal, C. A. Metzler, R. G. Baraniuk, Med. Imag., vol. 39, no. 12, pp. 3967–3978,
random access,” IEEE J. Sel. Areas Commun., A. G. Dimakis, and R. Willett, “Deep learning Dec. 2020.
vol. 39, no. 3, pp. 788–803, Mar. 2021. techniques for inverse problems in imaging,” IEEE
[96] G. Revach, N. Shlezinger, X. Ni, A. L. Escoriza,
J. Sel. Areas Inf. Theory, vol. 1, no. 1, pp. 39–56,
[60] T. Chang, B. Tolooshams, and D. Ba, “RandNet: R. J. G. van Sloun, and Y. C. Eldar, “KalmanNet:
May 2020.
Deep learning with compressed measurements of Neural network aided Kalman filtering for
images,” in Proc. IEEE 29th Int. Workshop Mach. [77] S. Boyd, N. Parikh, and E. Chu, Distributed partially known dynamics,” IEEE Trans. Signal
Learn. Signal Process. (MLSP), Oct. 2019, pp. 1–6. Optimization and Statistical Learning via the Process., vol. 70, pp. 1532–1547, 2022.
Alternating Direction Method of Multipliers.
[61] A. Balatsoukas-Stimming and C. Studer, “Deep [97] A. L. Escoriza, G. Revach, N. Shlezinger, and
Norwell, MA, USA: Now Publishers, 2011.
unfolding for communications systems: A survey R. J. G. van Sloun, “Data-driven Kalman-based
and some new directions,” in Proc. IEEE Int. [78] A. Beck and M. Teboulle, “A fast iterative velocity estimation for autonomous racing,” in
Workshop Signal Process. Syst. (SiPS), Oct. 2019, shrinkage-thresholding algorithm for linear Proc. IEEE Int. Conf. Auto. Syst. (ICAS), Aug. 2021,
pp. 266–271. inverse problems,” SIAM J. Imag. Sci., vol. 2, pp. 1–5.
no. 1, pp. 183–202, Jan. 2009.
[62] S. Takabe, M. Imanishi, T. Wadayama, [98] H. Palangi, R. Ward, and L. Deng, “Distributed
R. Hayakawa, and K. Hayashi, “Trainable [79] A. Chambolle and T. Pock, “A first-order compressive sensing: A deep learning approach,”
projected gradient detector for massive primal-dual algorithm for convex problems with IEEE Trans. Signal Process., vol. 64, no. 17,
overloaded MIMO channels: Data-driven tuning applications to imaging,” J. Math. Imag. Vis., pp. 4504–4518, Sep. 2016.


[99] S. S. Haykin, Adaptive Filter Theory. London, U.K.: “Combining generative and discriminative models Neural Inf. Process. Syst., 2018, pp. 9079–9089.
Pearson, 2005. for hybrid inference,” in Proc. Adv. Neural Inf. [111] T. Li, A. K. Sahu, A. Talwalkar, and V. Smith,
[100] A. J. Viterbi, “Error bounds for convolutional Process. Syst., 2019, pp. 13802–13812. “Federated learning: Challenges, methods, and
codes and an asymptotically optimum decoding [106] F. Gao, J. Zhang, and Y. Zhang, “Neural enhanced future directions,” IEEE Signal Process. Mag.,
algorithm,” IEEE Trans. Inf. Theory, vol. IT-13, dynamic message passing,” in Proc. Int. Conf. Artif. vol. 37, no. 3, pp. 50–60, May 2020.
no. 2, pp. 260–269, Apr. 1967. Intell. Statist., 2022, pp. 10471–10482. [112] T. Gafni, N. Shlezinger, K. Cohen, Y. C. Eldar, and
[101] N. Shlezinger, N. Farsad, Y. C. Eldar, and [107] K. Yoon et al., “Inference in probabilistic graphical H. V. Poor, “Federated learning: A signal
A. J. Goldsmith, “Learned factor graphs for models by graph neural networks,” in Proc. 53rd processing perspective,” IEEE Signal Process. Mag.,
inference from stationary time sequences,” IEEE Asilomar Conf. Signals, Syst., Comput., Nov. 2019, vol. 39, no. 3, pp. 14–41, May 2022.
Trans. Signal Process., vol. 70, pp. 366–380, 2022. pp. 868–875. [113] N. Shlezinger and I. V. Bajic, “Collaborative
[102] F. R. Kschischang, B. J. Frey, and H.-A. Loeliger, [108] W. Pu, C. Zhou, Y. C. Eldar, and inference for AI-empowered IoT devices,” IEEE
“Factor graphs and the sum-product algorithm,” M. R. D. Rodrigues, “REST: Robust lEarned Internet Things Mag., vol. 5, no. 4, pp. 92–98,
IEEE Trans. Inf. Theory, vol. 47, no. 2, shrinkage-thresholding network taming inverse Dec. 2022.
pp. 498–519, Feb. 2001. problems with model mismatch,” in Proc. IEEE Int. [114] E. J. Candès, J. Romberg, and T. Tao, “Robust
[103] N. Farsad and A. Goldsmith, “Neural network Conf. Acoust., Speech Signal Process. (ICASSP), uncertainty principles: Exact signal reconstruction
detection of data sequences in communication Jun. 2021, pp. 2885–2889. from highly incomplete frequency information,”
systems,” IEEE Trans. Signal Process., vol. 66, [109] X. Ni, G. Revach, N. Shlezinger, R. J. G. van Sloun, IEEE Trans. Inf. Theory, vol. 52, no. 2,
no. 21, pp. 5663–5678, Nov. 2018. and Y. C. Eldar, “RTSNet: Deep learning aided pp. 489–509, Feb. 2006.
[104] T. Raviv, S. Park, O. Simeone, Y. C. Eldar, and Kalman smoothing,” in Proc. IEEE Int. Conf. [115] D. L. Donoho, “Compressed sensing,” IEEE Trans.
N. Shlezinger, “Online meta-learning for hybrid Acoust., Speech Signal Process. (ICASSP), Inf. Theory, vol. 52, no. 4, pp. 1289–1306,
model-based deep receivers,” IEEE Trans. Wireless May 2022, pp. 5902–5906. Apr. 2006.
Commun., early access, Feb. 8, 2023, doi: [110] X. Chen, J. Liu, Z. Wang, and W. Yin, “Theoretical [116] H.-A. Loeliger, “An introduction to factor graphs,”
10.1109/TWC.2023.3241841. linear convergence of unfolded ISTA and its IEEE Signal Process. Mag., vol. 21, no. 1,
[105] V. G. Satorras, Z. Akata, and M. Welling, practical weights and thresholds,” in Proc. Adv. pp. 28–41, Jan. 2004.

ABOUT THE AUTHORS


Nir Shlezinger (Member, IEEE) received the B.Sc., M.Sc., and Ph.D. degrees in electrical and computer engineering from the Ben-Gurion University of the Negev, Be'er Sheva, Israel, in 2011, 2013, and 2017, respectively.
From 2017 to 2019, he was a Postdoctoral Researcher with the Technion—Israel Institute of Technology, Haifa, Israel. From 2019 to 2020, he was a Postdoctoral Researcher with the Weizmann Institute of Science, Rehovot, Israel. He is currently an Assistant Professor with the School of Electrical and Computer Engineering, Ben-Gurion University of the Negev. His research interests include communications, information theory, signal processing, and machine learning.
Dr. Shlezinger was awarded the FGS Prize for outstanding research achievements at the Weizmann Institute of Science.

Jay Whang is currently working toward the Ph.D. degree in computer science (CS) at The University of Texas at Austin (UT Austin), Austin, TX, USA, advised by Prof. Alex Dimakis.
His research interests lie primarily in deep generative models and their applications.

Yonina C. Eldar (Fellow, IEEE) received the B.Sc. degree in physics and the B.Sc. degree in electrical engineering from Tel Aviv University (TAU), Tel Aviv, Israel, in 1995 and 1996, respectively, and the Ph.D. degree in electrical engineering and computer science from the Massachusetts Institute of Technology (MIT), Cambridge, MA, USA, in 2002.
She was a Professor with the Department of Electrical Engineering, Technion—Israel Institute of Technology, Haifa, Israel, where she held the Edwards Chair in Engineering. She was a Visiting Professor with Stanford University, Stanford, CA, USA. She is currently a Professor with the Department of Mathematics and Computer Science, Weizmann Institute of Science, Rehovot, Israel. She is also a Visiting Professor with MIT, a Visiting Scientist with the Broad Institute, Cambridge, and an Adjunct Professor with Duke University, Durham, NC, USA. She is the author of the book Sampling Theory: Beyond Bandlimited Systems and a coauthor of five other books published by Cambridge University Press. Her research interests are in the broad areas of statistical signal processing, sampling theory and compressed sensing, learning and optimization methods, and their applications to biology, medical imaging, and optics.
Dr. Eldar was a member of the IEEE Signal Processing Theory and Methods and Bio Imaging Signal Processing Technical Committees. She was a member of the Young Israel Academy of Science and Humanities and the Israel Committee for Higher Education. She was a Horev Fellow of the Leaders in Science and Technology Program at the Technion and an Alon Fellow. She is a member of the IEEE Sensor Array and Multichannel Technical Committee. She is a member of the Israel Academy of Sciences and Humanities (elected in 2017) and an EURASIP Fellow. She received many awards for excellence in research and teaching, including the IEEE Signal Processing Society Technical Achievement Award in 2013, the IEEE/AESS Fred Nathanson Memorial Radar Award in 2014, and the IEEE Kiyo Tomiyasu Award in 2016. She received the Michael Bruno Memorial Award from the Rothschild Foundation, the Weizmann Prize for Exact Sciences, the Wolf Foundation Krill Prize for Excellence in Scientific Research, the Henry Taub Prize for Excellence in Research (twice), the Hershel Rich Innovation Award (three times), the Award for Women with Distinguished Contributions, the Andre and Bella Meyer Lectureship, the Career Development Chair at the Technion,


the Muriel & David Jacknow Award for Excellence in Teaching, and the Technion's Award for Excellence in Teaching (two times). She received several best paper awards and best demo awards together with her research students and colleagues, including the SIAM Outstanding Paper Prize, the UFFC Outstanding Paper Award, the Signal Processing Society Best Paper Award, and the IET Circuits, Devices and Systems Premium Award. She was selected as one of the 50 most influential women in Israel and Asia. She is a highly cited researcher. She was the co-chair and the technical co-chair of several international conferences and workshops. She serves on several other IEEE committees. In the past, she was a Signal Processing Society Distinguished Lecturer. She has served as an Associate Editor for the IEEE TRANSACTIONS ON SIGNAL PROCESSING, the EURASIP Journal on Advances in Signal Processing, the SIAM Journal on Matrix Analysis and Applications, and the SIAM Journal on Imaging Sciences. She is the Editor-in-Chief of Foundations and Trends in Signal Processing.

Alexandros G. Dimakis (Fellow, IEEE) received the Diploma degree from the National Technical University of Athens (NTU), Athens, Greece, and the Ph.D. degree from the University of California at Berkeley (UC Berkeley), Berkeley, CA, USA, in 2008.
He is currently a Professor and the Co-Director of the National AI Institute on the Foundations of Machine Learning, The University of Texas at Austin (UT Austin), Austin, TX, USA. His research interests include information theory and machine learning.
Dr. Dimakis is an IEEE Fellow for contributions to distributed coding and learning. He received several awards, including the James Massey Award, the NSF Career Award, the Google Research Award, the UC Berkeley Eli Jury Dissertation Award, and several best paper awards. He has served as an Associate Editor for several journals, including IEEE TRANSACTIONS ON INFORMATION THEORY, and as the Area Chair of machine learning conferences (Advances in Neural Information Processing Systems (NeurIPS), International Conference on Machine Learning (ICML), and Association for the Advancement of Artificial Intelligence (AAAI)).
