License: arXiv.org perpetual non-exclusive license
arXiv:2402.16821v1 [math.NA] 26 Feb 2024

Numerical analysis on Neural network projected schemes for approximating one dimensional Wasserstein Gradient flows

Xinzhe Zuo zxz@math.ucla.edu Department of Mathematics, University of California, Los Angeles, CA, 90095. Jiaxi Zhao jiaxi.zhao@u.nus.edu Department of Mathematics, National University of Singapore. Shu Liu shuliu@math.ucla.edu Department of Mathematics, University of California, Los Angeles, CA, 90095. Stanley Osher sjo@math.ucla.edu Department of Mathematics, University of California, Los Angeles, CA, 90095.  and  Wuchen Li wuchen@mailbox.sc.edu Department of Mathematics, University of South Carolina, Columbia, SC, 29208.
Abstract.

We provide a numerical analysis and computation of neural network projected schemes for approximating one-dimensional Wasserstein gradient flows. We approximate the Lagrangian mapping functions of gradient flows by the class of two-layer neural network functions with ReLU (rectified linear unit) activation functions. The numerical scheme is based on a projected gradient method, namely the Wasserstein natural gradient, where the projection is constructed from the $L^{2}$ mapping spaces onto the neural network parameterized mapping space. We establish theoretical guarantees for the performance of the neural projected dynamics. We derive a closed-form update for the scheme with well-posedness and an explicit consistency guarantee for a particular choice of network structure. General truncation error analysis is also established on the basis of the projective nature of the dynamics. Numerical examples, including gradient drift Fokker-Planck equations, porous medium equations, and Keller-Segel models, verify the accuracy and effectiveness of the proposed neural projected algorithm.

Key words and phrases:
Optimal transport; Information Geometry; Natural gradient; Neural network functions; Convergence analysis.
Xinzhe Zuo and Jiaxi Zhao contributed equally. Jiaxi Zhao, Xinzhe Zuo, and Shu Liu are partially supported by AFOSR YIP award No. FA9550-23-1-008; Xinzhe Zuo, Shu Liu, and Stanley Osher are partially funded by AFOSR MURI FA9550-18-502 and ONR N00014-20-1-2787; Wuchen Li’s work is partially supported by AFOSR YIP award No. FA9550-23-1-008, NSF DMS-2245097, and NSF RTG: 2038080.

1. Introduction

Simulating gradient flows of free energies is a central problem in the computational physics of complex systems [8] and in data science [1, 2]. In physics, gradient flows often arise from first principles, such as the Onsager principle [32]. Onsager gradient flows are widely used in phase-field modeling, chemistry, and biology. In recent years, a particular type of Onsager gradient flow, known as the Wasserstein gradient flow, has been widely studied in the optimal transport community [3, 33, 37]. It is the gradient flow with respect to an infinite-dimensional Riemannian metric, the Wasserstein metric, on the space of probability densities, known as the density manifold. The gradient flow in the Wasserstein space naturally captures the dissipation of free energy. Depending on the choice of free energy, Wasserstein gradient flows include a vast class of differential equations, such as gradient drift Fokker-Planck equations, porous medium equations, and Keller-Segel models. These models are widely used in population dynamics and sampling-related optimization problems.

In recent years, machine learning has brought a class of new methods to computational physics, where free energies are identified with loss functions [11, 30]. Meanwhile, computing Wasserstein gradient flows of loss functions in terms of samples has found various applications, such as generative artificial intelligence [4] and transport map-based sampling methods [35]. In these applications, one often describes Wasserstein gradient flows through Lagrangian mapping functions and approximates these mappings with deep neural networks, owing to their high expressivity and the adaptivity afforded by their compositional structure. While empirical successes of this framework have been observed in various applications [35, 4], very few theoretical results exist to explain the underlying mechanism.

Moreover, projected dynamics in neural network spaces are widely used to approximate Wasserstein gradient flows [14, 25]. These dynamics restrict the space of probability densities to a finite-dimensional subset parameterized by neural network mapping functions. For this reason, we call them neural projected gradient dynamics. This approach originates from the natural gradient method in information geometry [1] and extends the framework set up in [20]. Some basic questions about its accuracy and efficiency remain open: even in one-dimensional space, how well do the neural projected dynamics approximate the Wasserstein gradient flow? What is the accuracy of the neural network approximation of the Lagrangian mapping functions?

In this paper, we study the numerical analysis and computation of neural network projected schemes for one-dimensional Wasserstein gradient flows. The main results are sketched below. By formulating gradient flows in Lagrangian coordinates, the proposed numerical scheme takes the form of a ‘preconditioned’ gradient descent, where the preconditioner is the metric tensor of the statistical manifold of the parameter space. Theoretically, we first derive an analytic formula for the inverse of the neural mapping metric for a special class of ReLU networks (Theorem 2). We use this analytic form of the projected gradient flow to prove the consistency of the numerical scheme. We then prove in Theorem 3 that the numerical schemes derived from the neural projected dynamics are first- or second-order consistent for general Wasserstein gradient directions, including the heat flow and the Fokker-Planck equation. Furthermore, viewing our neural network model as a moving mesh method, we show in Proposition 7 that the mesh does not degenerate during the simulation.

In numerics, the advantages of the proposed method are twofold. First, using a two-layer neural network as the basis function, the proposed method can be regarded as a moving mesh method in Lagrangian coordinates, which shows promising performance even when the number of network parameters is very limited. In particular, our numerical examples achieve an accuracy of $10^{-3}$ with fewer than 100 neurons. Second, owing to the Wasserstein gradient flow formulation, the proposed method is easy to implement, since it can use the automatic differentiation features of popular machine learning libraries such as PyTorch.
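The moving mesh interpretation can be illustrated with a minimal sketch. It assumes a hypothetical one-dimensional parameterization $T(z)=c+\sum_i a_i\,\mathrm{ReLU}(z-b_i)$ on $Z=[0,1]$ with the uniform reference density; this is for illustration only and is not necessarily the network structure analyzed later in the paper. Such a map is piecewise linear: the breakpoints $b_i$ act as mesh nodes in $Z$, their images $T(b_i)$ act as moving mesh nodes in $\Omega$, and the pushforward density is piecewise constant, equal to the reciprocal of the local slope. All function and parameter names are hypothetical.

```python
import numpy as np

def relu(u):
    return np.maximum(u, 0.0)

def two_layer_relu_map(z, a, b, c):
    """T(z) = c + sum_i a_i * ReLU(z - b_i): a hypothetical 1D two-layer ReLU map."""
    z = np.asarray(z, dtype=float)
    return c + relu(z[..., None] - b) @ a

# Hypothetical parameters on Z = [0, 1]: breakpoints b are mesh nodes in Z,
# and each a_i changes the slope of T at breakpoint b_i.
b = np.array([0.0, 0.25, 0.5, 0.75])
a = np.array([2.0, -1.0, 0.5, 0.5])
c = -1.0

slopes = np.cumsum(a)          # slope of T on [b_k, b_{k+1}): 2, 1, 1.5, 2
assert np.all(slopes > 0)      # positive slopes => T strictly increasing on [0, 1]

z_nodes = np.append(b, 1.0)
x_nodes = two_layer_relu_map(z_nodes, a, b, c)   # moving mesh nodes in Omega
density = 1.0 / slopes                           # pushforward of Uniform[0, 1], cell by cell
mass = np.sum(density * np.diff(x_nodes))        # total mass of the pushforward density
```

When every partial sum of the $a_i$ stays positive, $T$ is increasing, and each cell $[T(b_k),T(b_{k+1})]$ carries exactly the reference mass of $[b_k,b_{k+1}]$; here each of the four cells carries mass $0.25$.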

Nowadays, the computation of Wasserstein gradient flows (WGFs) has attracted great interest from researchers in various communities, such as mathematics, physics, statistics, and machine learning. Classical numerical methods [9] directly evaluate the probability density function. More recently, algorithms that approximate the Lagrangian mapping functions associated with WGFs have been developed; we refer readers to [8] and references therein for related discussions. These treatments automatically preserve non-negativity and total mass. Combined with fast-developing deep learning techniques, they have inspired a line of research on scalable, sampling-friendly computational methods for WGFs in higher-dimensional spaces [25, 27, 13, 16, 18]. Deep learning-based algorithms for computing the Lagrangian coordinates of Wasserstein Hamiltonian flows, or more generally mean field control problems, have also been introduced in [38, 28, 34].

Our treatment of projecting WGFs onto the parameter space is also known as the natural gradient method, which was first introduced in [1] (w.r.t. the Fisher-Rao metric) and [10] (w.r.t. the Wasserstein metric). The projected metric matrix is often called the information matrix, namely the Fisher information matrix or the Wasserstein information matrix, depending on the metric used on the probability space. This method has recently found applications in large-scale optimization problems [29]. In recent research [12, 7, 14], the authors compute general evolution equations by directly leveraging a neural network representation of the time-dependent solution. They transfer the evolution of the equation from the function space to the parameter space of the neural network to obtain a finite-dimensional ordinary differential equation, which can be readily integrated by Runge-Kutta solvers. Numerical properties of ReLU neural network families have been investigated in [15].
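The natural gradient idea can be seen in a toy model small enough to check by hand. The sketch below assumes a hypothetical two-parameter Lagrangian model $T_\theta(z)=\mu+e^{s}z$ with $\theta=(\mu,s)$ and reference $z\sim\mathcal{N}(0,1)$, so that $T_{\theta\#}p_r=\mathcal{N}(\mu,e^{2s})$; it is not the neural network model studied in this paper. We take the metric tensor to be $G(\theta)=\mathbb{E}_z[\partial_\theta T_\theta(z)\,\partial_\theta T_\theta(z)^{\top}]=\mathrm{diag}(1,e^{2s})$, i.e., the pullback of the $L^2(p_r)$ inner product, and the energy $F(\theta)=\mathrm{KL}(\mathcal{N}(\mu,e^{2s})\,\|\,\mathcal{N}(0,1))$, whose Wasserstein gradient flow is the Fokker-Planck equation with $V(x)=x^2/2$ and $\gamma=1$ (the Ornstein-Uhlenbeck process). The "natural gradient" step is the preconditioned update $\theta\leftarrow\theta-h\,G(\theta)^{-1}\nabla_\theta F(\theta)$.

```python
import numpy as np

# Hypothetical toy model: T_theta(z) = mu + exp(s) * z, theta = (mu, s), z ~ N(0, 1),
# so the pushforward density is N(mu, exp(2 s)).

def metric_tensor(theta):
    """G(theta) = E_z[dT/dtheta dT/dtheta^T] with dT/dmu = 1, dT/ds = exp(s) z."""
    _, s = theta
    return np.diag([1.0, np.exp(2.0 * s)])

def energy_grad(theta):
    """Euclidean gradient of F(theta) = KL(N(mu, e^{2s}) || N(0, 1)) = (mu^2 + e^{2s} - 1)/2 - s."""
    mu, s = theta
    return np.array([mu, np.exp(2.0 * s) - 1.0])

def natural_gradient_flow(theta0, t_final, n_steps):
    """Forward Euler for d theta/dt = -G(theta)^{-1} grad F(theta)."""
    theta = np.array(theta0, dtype=float)
    h = t_final / n_steps
    for _ in range(n_steps):
        theta = theta - h * np.linalg.solve(metric_tensor(theta), energy_grad(theta))
    return theta

mu0, s0, t_final = 2.0, np.log(3.0), 1.0
mu_t, s_t = natural_gradient_flow([mu0, s0], t_final, 10_000)

# Closed-form Fokker-Planck (Ornstein-Uhlenbeck) solution started from N(mu0, e^{2 s0}):
mu_exact = mu0 * np.exp(-t_final)
var_exact = 1.0 + (np.exp(2.0 * s0) - 1.0) * np.exp(-2.0 * t_final)
```

Because Gaussians are preserved by this particular Fokker-Planck flow, the projected dynamics are exact here, and the forward Euler iterates match the closed-form mean $\mu_0e^{-t}$ and variance $1+(\sigma_0^2-1)e^{-2t}$ up to $O(h)$ time-stepping error. For genuine neural network models the projection is generally not exact, which is what the error analysis in this paper addresses.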

Compared to previous studies, we focus on the numerical analysis of neural network projected dynamics for approximating WGFs. In one-dimensional space, we provide an error analysis of the neural projected dynamics with a two-layer neural network and verify it numerically. In particular, we formulate a class of explicit schemes from the neural network projected dynamics. This work continues the study of the Wasserstein information matrix on neural network models; see related discussions in [22, 21, 25].

The paper is organized as follows. In Section 2, we briefly review the formulation of Wasserstein gradient flows of free energies in both Eulerian and Lagrangian coordinates. We formulate the projected Wasserstein gradient flows over neural network models in Section 3. In Section 4, we conduct the numerical analysis of the proposed neural projected dynamics in two-layer neural network functions. In Section 5, we verify the accuracy of the proposed algorithm with numerical examples in Fokker-Planck equations, porous medium equations, and Keller-Segel models.

2. Review of Wasserstein gradient flows and Lagrangian coordinates

In this section, we prepare the theoretical foundations of Wasserstein gradient flows, with a focus on the Lagrangian description (diffeomorphism mapping functions) and the associated microscopic particle dynamics. See details in [3, 37].

2.1. Wasserstein gradient flows

Suppose $\Omega$ is a domain in the Euclidean space $\mathbb{R}^{d}$. Denote the probability space

$$\mathcal{P}(\Omega)=\left\{p(\cdot)\in C^{\infty}(\Omega):\ \int_{\Omega}p(x)\,dx=1,\quad p(\cdot)\geq 0\right\}.$$

Given an energy functional $\mathcal{F}(\cdot):\mathcal{P}(\Omega)\rightarrow\mathbb{R}$, we consider the following evolution equation associated with $\mathcal{F}(\cdot)$:

$$\partial_{t}p(t,x)=\nabla_{x}\cdot\left(p(t,x)\nabla_{x}\frac{\delta}{\delta p}\mathcal{F}(p)\right),\quad p(0,\cdot)=p_{0},\tag{1}$$

with the Neumann boundary condition $p(t,x)\nabla_{x}\frac{\delta}{\delta p}\mathcal{F}(p)\cdot\boldsymbol{n}=0$, where $\boldsymbol{n}$ is the outward unit normal vector on the boundary $\partial\Omega$. Here $\frac{\delta}{\delta p}$ is the $L^{2}$ first variation operator w.r.t. the density variable $p$. The mass of $p(t,\cdot)$ is conserved and always equals $1$. An important fact about (1) is that this equation can be treated as the gradient flow of $\mathcal{F}$ on $\mathcal{P}(\Omega)$. More specifically, by endowing the probability space $\mathcal{P}(\Omega)$ with the $L^{2}$ Wasserstein metric $g_{W}$, we can view $(\mathcal{P}(\Omega),g_{W})$ as a Riemannian manifold, and (1) is the gradient flow on this manifold with respect to $g_{W}$.

Let us briefly review several facts. We first define the metric $g_{W}$ at an arbitrary $p\in\mathcal{P}(\Omega)$. Tangent vectors at $p$ are identified via the continuity equation, whose driving vector field belongs to the $L^{2}(p)$-closure of the set of gradient fields $\nabla_{x}\psi:\Omega\rightarrow\mathbb{R}^{d}$ with $\psi\in C^{\infty}(\Omega)$. Consider smooth curves $\{p_{i}(t,\cdot)\}_{t\in(-\epsilon,\epsilon)}$ ($i=1,2$) on $\mathcal{P}(\Omega)$ passing through $p$ at $t=0$. Suppose the evolution of $p_{i}(t,\cdot)$ is driven by the gradient field $\nabla_{x}\psi_{i}(\cdot)$ at $t=0$, i.e., $\psi_{i}(\cdot)$ solves

$$\partial_{t}p_{i}(0,x)+\nabla_{x}\cdot\left(p_{i}(0,x)\nabla_{x}\psi_{i}(x)\right)=0,\quad i=1,2.$$

We define the $L^{2}$ Wasserstein metric $g_{W}(\cdot,\cdot)$ at $p$ as the symmetric, positive-definite bilinear form

$$g_{W}\left(\partial_{t}p_{1}(0,\cdot),\partial_{t}p_{2}(0,\cdot)\right)=\int_{\Omega}\nabla_{x}\psi_{1}(x)\cdot\nabla_{x}\psi_{2}(x)\,p(x)\,dx.$$
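In one dimension the metric can be evaluated directly from the tangent vectors: integrating the continuity equation gives the flux $p\,\partial_x\psi_i(x)=-\int_{-L}^{x}\partial_t p_i(0,y)\,dy$ when no mass enters through the left endpoint, and then $g_W=\int (p\,\partial_x\psi_1)(p\,\partial_x\psi_2)/p\,dx$. The sketch below checks this numerically on a truncated grid $[-8,8]$, assuming a standard Gaussian $p$ and the hypothetical potentials $\psi_1(x)=x^2/2$ and $\psi_2(x)=-\cos x$. For these choices Stein's identity gives the exact value $\int x\sin(x)\,p(x)\,dx=\mathbb{E}[\cos X]=e^{-1/2}$.

```python
import numpy as np

x = np.linspace(-8.0, 8.0, 4001)
dx = x[1] - x[0]
p = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)   # standard Gaussian density on a truncated grid

def trapezoid(f):
    """Trapezoidal rule on the uniform grid x."""
    return dx * (np.sum(f) - 0.5 * (f[0] + f[-1]))

def cumulative_trapezoid(f):
    """Running trapezoidal integral from the left endpoint, starting at 0."""
    return np.concatenate(([0.0], np.cumsum(0.5 * (f[1:] + f[:-1]) * dx)))

dpsi1 = x            # d/dx of psi_1(x) = x^2 / 2
dpsi2 = np.sin(x)    # d/dx of psi_2(x) = -cos(x)

# Tangent vectors from the continuity equation: d/dt p_i = -d/dx (p dpsi_i/dx)
sigma1 = -np.gradient(p * dpsi1, dx, edge_order=2)
sigma2 = -np.gradient(p * dpsi2, dx, edge_order=2)

# Recover the fluxes p dpsi_i/dx from the tangent vectors (zero flux at the left end)
flux1 = -cumulative_trapezoid(sigma1)
flux2 = -cumulative_trapezoid(sigma2)

g_from_potentials = trapezoid(dpsi1 * dpsi2 * p)   # definition of g_W via potentials
g_from_tangents = trapezoid(flux1 * flux2 / p)     # same quantity via tangent vectors only
```

The two routes agree up to discretization error, which reflects the fact that in one dimension a tangent vector determines its gradient field uniquely under the zero-flux condition.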

Recall that the gradient of a smooth function $f$ on a Riemannian manifold $(M,g)$ is defined by

$$g\left(\mathrm{grad}f(x),\dot{x}(0)\right)=\frac{d}{dt}f(x(t))\Big|_{t=0},$$

for any smooth curve $\{x(t)\}_{t\in(-\epsilon,\epsilon)}$ passing through $x$ at $t=0$. Returning to our setting, for the functional $\mathcal{F}$ defined on $(\mathcal{P}(\Omega),g_{W})$, we define the gradient of $\mathcal{F}$ w.r.t. the Wasserstein metric $g_{W}$ at $p$ by

$$g_{W}\left(\mathrm{grad}_{W}\mathcal{F}(p),\partial_{t}p(0,\cdot)\right)=\frac{d}{dt}\mathcal{F}(p(t,\cdot))\Big|_{t=0}.$$

Here $\{p(t,\cdot)\}_{t\in(-\epsilon,\epsilon)}$ is an arbitrary curve on $\mathcal{P}(\Omega)$ with $p(0,\cdot)=p(\cdot)$. Suppose $p(t,\cdot)$ is driven by the gradient field $\nabla_{x}\psi$ at time $t=0$. Then the right-hand side can be computed as

$$\begin{aligned}\frac{d}{dt}\mathcal{F}(p(t,\cdot))\Big|_{t=0}&=\int_{\Omega}\frac{\delta\mathcal{F}(p(0,\cdot))}{\delta p}(x)\,\partial_{t}p(0,x)\,dx\\&=\int_{\Omega}\frac{\delta\mathcal{F}(p)}{\delta p}(x)\left(-\nabla_{x}\cdot\left(p(x)\nabla_{x}\psi(x)\right)\right)dx\\&=\int_{\Omega}\nabla_{x}\frac{\delta\mathcal{F}(p)}{\delta p}(x)\cdot\nabla_{x}\psi(x)\,p(x)\,dx,\end{aligned}$$
where the last equality follows from integration by parts, assuming the boundary flux $p\,\nabla_{x}\psi\cdot\boldsymbol{n}$ vanishes on $\partial\Omega$.
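The identity above can be checked numerically on a case with a closed form. The sketch assumes $\mathcal{F}(p)=\int p\log p\,dx$ (the Fokker-Planck energy with $V=0$, $\gamma=1$), $p=\mathcal{N}(0,1)$, and the hypothetical potential $\psi(x)=ax^2/2$. The curve $p(t,\cdot)=\mathcal{N}(0,(1+ta)^2)$, i.e., the pushforward of $p$ by $x\mapsto(1+ta)x$, is driven by $\nabla_x\psi(x)=ax$ at $t=0$, and both sides equal $-a$ exactly: the left side because $\mathcal{F}(\mathcal{N}(0,\sigma^2))=-\tfrac12\log(2\pi e\sigma^2)$, the right side because $\int(-x)(ax)\,p(x)\,dx=-a$.

```python
import numpy as np

x = np.linspace(-12.0, 12.0, 6001)
dx = x[1] - x[0]

def trapezoid(f):
    """Trapezoidal rule on the uniform grid x."""
    return dx * (np.sum(f) - 0.5 * (f[0] + f[-1]))

def gaussian(sigma):
    return np.exp(-x**2 / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)

def free_energy(p):
    """F(p) = int p log p dx (Fokker-Planck energy with V = 0, gamma = 1)."""
    return trapezoid(p * np.log(p))

a = 0.7    # hypothetical velocity field grad psi(x) = a x
h = 1e-4   # finite-difference step in t

# Left-hand side: d/dt F(p(t)) at t = 0 along p(t) = N(0, (1 + t a)^2)
lhs = (free_energy(gaussian(1 + h * a)) - free_energy(gaussian(1 - h * a))) / (2 * h)

# Right-hand side: int d/dx (dF/dp) * d/dx psi * p dx with dF/dp = log p + 1
p0 = gaussian(1.0)
rhs = trapezoid(np.gradient(np.log(p0) + 1.0, dx) * (a * x) * p0)
```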

Comparing the result of the computation above with the definition of $g_{W}$, and noting that $\psi$ is arbitrary, one verifies that the gradient field associated with $\mathrm{grad}_{W}\mathcal{F}(p)$ is $\nabla_{x}\frac{\delta}{\delta p}\mathcal{F}(p)$. Thus,

$$\mathrm{grad}_{W}\mathcal{F}(p)=-\nabla_{x}\cdot\left(p(x)\nabla_{x}\frac{\delta\mathcal{F}(p)}{\delta p}(x)\right),$$

and the Wasserstein gradient flow $\partial_{t}p=-\mathrm{grad}_{W}\mathcal{F}(p)$ is exactly equation (1).

We provide several examples of WGFs. In these examples, we assume $\Omega=\mathbb{R}^{d}$.

  • (Fokker-Planck equation) Consider

    $$\mathcal{F}(p)=\int_{\Omega}V(x)p(x)\,dx+\gamma\int_{\Omega}p(x)\log p(x)\,dx.$$

    Then the Wasserstein gradient of $\mathcal{F}$ equals

    $$\begin{aligned}\mathrm{grad}_{W}\mathcal{F}(p)&=-\nabla_{x}\cdot\left(p(x)\nabla_{x}\left(V(x)+\gamma(\log p(x)+1)\right)\right)\\&=-\nabla_{x}\cdot\left(p(x)\nabla_{x}V(x)\right)-\gamma\Delta_{x}p(x).\end{aligned}$$

    The corresponding WGF is the Fokker-Planck equation

    $$\partial_{t}p(t,x)=\nabla_{x}\cdot\left(p(t,x)\nabla_{x}V(x)\right)+\gamma\Delta_{x}p(t,x).\tag{2}$$
  • (Porous medium equation) Consider

    $$\mathcal{F}(p)=\int_{\Omega}\frac{p(x)^{m}}{m-1}\,dx,\quad m>1.$$

    One computes

    $$\mathrm{grad}_{W}\mathcal{F}(p)=-\nabla_{x}\cdot\left(p(x)\nabla_{x}\left(\frac{m}{m-1}p(x)^{m-1}\right)\right)=-\nabla_{x}\cdot\left(\nabla_{x}\left(p(x)^{m}\right)\right)=-\Delta_{x}p(x)^{m}.$$

    Thus, the corresponding WGF yields the porous medium equation

    $$\partial_{t}p(t,x)=\Delta_{x}p(t,x)^{m}.\tag{3}$$
  • (Keller-Segel equation) Another well-known WGF is obtained by choosing $\mathcal{F}$ as the sum of an internal energy and an interaction energy

    $$\mathcal{F}(p)=\int_{\Omega}U(p(x))\,dx+\frac{1}{2}\iint_{\Omega\times\Omega}W(|x-y|)\,p(x)p(y)\,dx\,dy,$$

    where $U$ is a smooth function defined on $\mathbb{R}_{+}$, and $W(\cdot)\in C(\mathbb{R}_{+};\mathbb{R})$ is a kernel function.

    We calculate

    $$\mathrm{grad}_{W}\mathcal{F}(p)=-\nabla_{x}\cdot\left(p(x)\nabla_{x}\left(U'(p(x))+W*p(x)\right)\right),$$

    where we denote the convolution $W*p(x)=\int_{\Omega}W(|x-y|)p(y)\,dy$. The WGF associated with this functional is the Keller-Segel equation

    $$\partial_{t}p(t,x)=\nabla_{x}\cdot\left(p(t,x)\nabla_{x}U'(p(t,x))\right)+\nabla_{x}\cdot\left(p(t,x)\nabla_{x}\left(W*p(t,\cdot)\right)(x)\right).\tag{4}$$
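The flux formulas in the first two examples can be sanity-checked numerically in one dimension. The sketch below compares $p\,\partial_x\frac{\delta\mathcal{F}}{\delta p}$, computed by finite differences, with the closed forms $pV'+\gamma p'$ (Fokker-Planck) and $\partial_x(p^m)$ (porous medium). The test density (a standard Gaussian on $[-3,3]$), the double-well potential $V(x)=x^4/4-x^2/2$, and the values $\gamma=0.5$ and $m=2$ are hypothetical choices made for illustration.

```python
import numpy as np

x = np.linspace(-3.0, 3.0, 2001)
dx = x[1] - x[0]
p = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)   # smooth, strictly positive test density
dp = -x * p                                   # exact derivative of p

def flux(p, first_variation):
    """Wasserstein flux p * d/dx (dF/dp); the WGF reads dp/dt = d/dx(flux)."""
    return p * np.gradient(first_variation, dx, edge_order=2)

# Fokker-Planck: F(p) = int V p dx + gamma int p log p dx (hypothetical V, gamma)
gamma = 0.5
V = x**4 / 4 - x**2 / 2
dV = x**3 - x
fp_flux = flux(p, V + gamma * (np.log(p) + 1.0))
fp_expected = p * dV + gamma * dp                 # p V' + gamma p'

# Porous medium: F(p) = int p^m / (m - 1) dx with m = 2
m = 2.0
pme_flux = flux(p, m / (m - 1) * p ** (m - 1))
pme_expected = m * p ** (m - 1) * dp              # d/dx (p^m)

fp_err = np.max(np.abs(fp_flux - fp_expected))
pme_err = np.max(np.abs(pme_flux - pme_expected))
```

Both errors are of the order of the finite-difference truncation error, consistent with the identities $p\,\partial_x\log p=\partial_x p$ and $p\,\partial_x\big(\tfrac{m}{m-1}p^{m-1}\big)=\partial_x(p^m)$ used above.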

2.2. Lagrangian coordinates & Particle dynamics

Consider a mapping function $T\colon Z\rightarrow\Omega$. Here $Z$ is the input space with elements $z\in Z$, and $\Omega\subset\mathbb{R}^{d}$ is the domain on which the WGF is defined. To simplify the discussion, we assume $Z=\Omega$. Let us further assume $T\in C^{\infty}(Z,\Omega)$ and that the Jacobian matrix $D_{z}T(z)$ is non-singular for all $z\in Z$, i.e., $\det(D_{z}T(z))\neq 0$ on $Z$. This guarantees that $T$ is locally injective; we additionally assume that $T$ is globally injective, which holds automatically in one dimension when $Z$ is an interval, since $T$ is then strictly monotone. Given a smooth reference probability density $p_{r}\in\mathcal{P}(Z)$, we denote the pushforward of $p_{r}$ by $T$ as

$$p=T_{\#}p_{r},$$

where $T_{\#}:\mathcal{P}(Z)\rightarrow\mathcal{P}(\Omega)$ is the pushforward operator defined by

$$\int_{\Omega}f(x)\,T_{\#}p_{r}(x)\,dx=\int_{Z}f(T(z))\,p_{r}(z)\,dz,\quad\text{for all } f \text{ with } f\circ T\in L^{1}(p_{r}).$$

The density function of p𝑝pitalic_p satisfies

$$p(T(z))\,\left|\det(D_{z}T(z))\right|=p_{r}(z)\quad\forall\,z\in Z,\quad\text{i.e.,}\quad p(x)=\frac{p_{r}}{\left|\det(D_{z}T)\right|}\circ T^{-1}(x)\quad\forall\,x\in\Omega.\tag{5}$$
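The pushforward identity and formula (5) can be checked in one dimension, where $\det(D_zT)=T'(z)$. The sketch below assumes the standard Gaussian reference $p_r$ and the hypothetical strictly increasing map $T(z)=\sinh(z)$, so $T'(z)=\cosh(z)>0$ and $T^{-1}=\operatorname{arcsinh}$. It evaluates the density $p=T_{\#}p_r$ from (5) and compares $\int f\,p\,dx$ with $\int f(T(z))\,p_r(z)\,dz$ and with a sample average of $f(T(z))$, for the bounded test function $f(x)=e^{-x^2/2}$.

```python
import numpy as np

def phi(z):
    """Standard Gaussian reference density p_r."""
    return np.exp(-z**2 / 2) / np.sqrt(2 * np.pi)

def trapezoid(f, grid):
    """Trapezoidal rule on an arbitrary increasing grid."""
    return np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(grid))

# Hypothetical strictly increasing map T(z) = sinh(z), with T'(z) = cosh(z) > 0.
T, dT, T_inv = np.sinh, np.cosh, np.arcsinh

def pushforward_density(x):
    """Formula (5) in 1D: p(x) = p_r(T^{-1}(x)) / |T'(T^{-1}(x))|."""
    z = T_inv(x)
    return phi(z) / np.abs(dT(z))

def f(x):
    return np.exp(-x**2 / 2)   # bounded test function

x = np.linspace(-40.0, 40.0, 8001)
z = np.linspace(-10.0, 10.0, 4001)

mass = trapezoid(pushforward_density(x), x)          # close to 1 (tails beyond |x| = 40 are tiny)
lhs = trapezoid(f(x) * pushforward_density(x), x)    # int f(x) T#p_r(x) dx
rhs = trapezoid(f(T(z)) * phi(z), z)                 # int f(T(z)) p_r(z) dz

rng = np.random.default_rng(0)
mc = np.mean(f(T(rng.standard_normal(100_000))))     # sample-based estimate of the same integral
```

The sampling route only requires evaluating $T$, whereas the density route requires $T^{-1}$ and the Jacobian; this is one practical reason mapping-based (Lagrangian) formulations suit sample-based computation.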

Such a pushforward map $T$, used to construct the probability distribution $p$, is usually called the Lagrangian coordinate. We now imitate the derivation of the WGF to formulate its counterpart in Lagrangian coordinates.

We denote by $\mathcal{O}$ the space of smooth, $L^2(p_r)$-integrable pushforward maps with non-zero Jacobian determinant, i.e.,

$$\mathcal{O}=\left\{T\in C^{\infty}(Z,\Omega)\ :\ \det(D_zT)\neq 0,\ \ \int_{Z}|T(z)|^{2}p_r(z)\,dz<\infty\right\}.$$

Then the pushforward operation $\#\colon\mathcal{O}\rightarrow\mathcal{P}(\Omega)$, $T\mapsto T_{\#}p_r$, defines a submersion from the space of pushforward maps (diffeomorphisms) to the space of probability densities.

To derive Wasserstein gradient flows (WGFs) on the space $\mathcal{O}$ of pushforward maps instead of on the probability space $\mathcal{P}(\Omega)$, we first equip $\mathcal{O}$ with a metric $\langle\cdot,\cdot\rangle$ that corresponds to the Wasserstein metric $g_W$. As illustrated in [33], $g_W$ is induced by the $L^2(p_r)$ metric on $\mathcal{O}$ through the submersion $\#$. Thus, a natural choice of the metric is

$$\langle\mathbf{u}_1,\mathbf{u}_2\rangle=\int_{Z}\mathbf{u}_1(z)\cdot\mathbf{u}_2(z)\,p_r(z)\,dz,\quad\forall\,\mathbf{u}_1,\mathbf{u}_2\in L^{2}(p_r)\cap C^{\infty}(Z,\Omega).$$

Now, for any smooth functional $\mathcal{F}\colon\mathcal{P}(\Omega)\rightarrow\mathbb{R}$, the composition $\mathcal{F}^{\#}\triangleq\mathcal{F}\circ\#\colon\mathcal{O}\rightarrow\mathbb{R}$ defines the corresponding functional on $\mathcal{O}$. Following arguments similar to those in Section 2.1, we compute the gradient of $\mathcal{F}^{\#}$ with respect to the metric $\langle\cdot,\cdot\rangle$ as

$$\mathrm{grad}_{\langle\cdot,\cdot\rangle}\mathcal{F}^{\#}(T)=\frac{1}{p_r(\cdot)}\frac{\delta\mathcal{F}^{\#}(T)}{\delta T}(\cdot).$$

Here $\frac{\delta}{\delta T}$ denotes the first variation with respect to the pushforward map $T$ in $L^2(m)$, where $m$ is the Lebesgue measure.

Thus, the gradient flow of $\mathcal{F}^{\#}$ on $\mathcal{O}$ is formulated as

$$\partial_t T(t,\cdot)=-\mathrm{grad}_{\langle\cdot,\cdot\rangle}\mathcal{F}^{\#}(T(t,\cdot))=-\frac{1}{p_r(\cdot)}\frac{\delta\mathcal{F}^{\#}(T(t,\cdot))}{\delta T}(\cdot).$$

The first variation of $\mathcal{F}^{\#}$ is computed as

$$\frac{\delta\mathcal{F}^{\#}(T)}{\delta T}(z)=\left(\nabla_x\frac{\delta\mathcal{F}(T_{\#}p_r)}{\delta p}\right)\circ T(z)\,p_r(z).$$

Substituting this expression into the gradient flow above, the factor $p_r$ cancels and the flow can be written as

$$\partial_t T(t,z)=-\left(\nabla_x\frac{\delta\mathcal{F}(T(t,\cdot)_{\#}p_r)}{\delta p}\right)\circ T(t,z).\tag{6}$$
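To see (6) at work in the simplest setting, take the Fokker-Planck energy with $V\equiv 0$ (the heat equation $\partial_t p=\gamma\Delta p$), a standard Gaussian reference $p_r$ and $T(0,z)=z$. Under the linear ansatz $T(t,z)=s(t)z$ (our assumption for illustration), $p_t=\mathcal{N}(0,s^2)$, so $-\gamma\nabla_x\log p_t(x)=\gamma x/s^2$ and (6) reduces to $\dot s=\gamma/s$, with solution $s(t)=\sqrt{1+2\gamma t}$, matching the heat-kernel variance $1+2\gamma t$. The sketch below integrates (6) at Gauss-Hermite nodes with forward Euler and compares with this closed form; the step size and $\gamma$ are arbitrary choices.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

gamma, dt, t_final = 0.5, 1e-3, 2.0
z, w = hermegauss(20)              # nodes/weights for the weight exp(-z^2/2)
w = w / np.sqrt(2.0 * np.pi)       # normalized: quadrature for E_{z ~ N(0,1)}

T = z.copy()                       # T(0, z) = z, so p_0 = p_r
for _ in range(int(round(t_final / dt))):
    var = np.sum(w * T**2)         # variance of p_t = T(t, .)_# p_r (mean stays 0)
    velocity = gamma * T / var     # -gamma * grad log p_t at x = T(t, z), Gaussian p_t
    T = T + dt * velocity          # forward Euler step of (6)

s_numeric = np.sqrt(np.sum(w * T**2))
s_exact = np.sqrt(1.0 + 2.0 * gamma * t_final)
print(s_numeric, s_exact)
```

The Gaussian form of the score is exact here only because linear maps keep $p_t$ Gaussian; for general energies the term $\nabla_x\log p_t$ has to be approximated.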

If we denote $p(t,\cdot)=T(t,\cdot)_{\#}p_r$, one can verify that $p(t,\cdot)$ solves the WGF equation (1) with initial condition $p_0=T(0,\cdot)_{\#}p_r$. This justifies the equivalence between the gradient flow (6) in Lagrangian coordinates (i.e., in terms of the map $T(t,\cdot)$) and the WGF (1) in Eulerian coordinates (i.e., in terms of the density $p(t,\cdot)$).

The gradient flow (6) on the space of diffeomorphisms also provides a microscopic, particle-level picture of the WGF (1). For a random reference sample $z\sim p_r$, setting $\mathbf{x}_t=T(t,z)$, one can readily verify that $\mathbf{x}_t$ evolves according to the dynamics

$$\frac{d\mathbf{x}_t}{dt}=-\left(\nabla_x\frac{\delta}{\delta p}\mathcal{F}(p_t)\right)(\mathbf{x}_t),\quad\mathbf{x}_0=T(0,z),\quad z\sim p_r.\tag{7}$$

Here $p_t=T(t,\cdot)_{\#}p_r$, which can equivalently be viewed as the probability density of the random particle $\mathbf{x}_t$. In this dynamics, the motion of a single agent $\mathbf{x}_t$ is driven by the instantaneous population density $p_t$ through the velocity field $-\nabla_x\frac{\delta\mathcal{F}}{\delta p}(p_t)$ evaluated at $\mathbf{x}_t$. This approach offers a microscopic and deterministic interpretation of various diffusive processes with WGF structure.
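The particle picture (7) suggests a simple deterministic particle method. The sketch below treats the Fokker-Planck case listed in Table 1 with the illustrative choice $V(x)=x^2/2$, whose stationary state is $\mathcal{N}(0,\gamma)$. The score $\nabla_x\log p_t$ cannot be read off from particles directly; here we assume $p_t$ stays Gaussian (true for this linear drift with Gaussian initial data) and plug in the empirical mean and variance. This Gaussian closure is our own simplification, not a general-purpose score estimator.

```python
import numpy as np

rng = np.random.default_rng(0)
gamma, dt, n_steps = 0.3, 1e-2, 1000

x = rng.normal(loc=2.0, scale=0.5, size=5000)   # x_0 ~ p_0 = N(2, 0.25)
for _ in range(n_steps):
    m, v = x.mean(), x.var()
    grad_V = x                                   # V(x) = x^2 / 2
    grad_log_p = -(x - m) / v                    # score of the Gaussian fit N(m, v)
    x = x + dt * (-grad_V - gamma * grad_log_p)  # forward Euler step of (7)

print(x.mean(), x.var())   # approach 0 and gamma
```

No noise is injected: diffusion enters only through the score term, which is what makes the particle dynamics deterministic.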

The examples of WGFs above can also be formulated as gradient flows in Lagrangian coordinates (6) and as particle dynamics (7); we summarize them in Table 1. We take $T(0,\cdot)_{\#}p_r=p_0$ as the initial condition for (6) and $\mathbf{x}_0\sim p_0$ as the initial distribution of the random particle $\mathbf{x}_t$ in (7). In equation (6) we write $p_t=T(t,\cdot)_{\#}p_r$; accordingly, in (7), $p_t$ denotes the probability density of the random particle $\mathbf{x}_t$.

| WGF | Gradient flow in Lagrangian coordinates | Particle dynamics |
|---|---|---|
| Fokker-Planck (2) | $\partial_t T(t,z)=-\nabla_x(V+\gamma\log p_t)\circ T(t,z)$ | $\frac{d\mathbf{x}_t}{dt}=-\nabla_x V(\mathbf{x}_t)-\gamma\nabla_x\log p_t(\mathbf{x}_t)$ |
| Porous-medium (3) | $\partial_t T(t,z)=-\frac{m}{m-1}p_t(T(t,z))^{m-1}\nabla_x p_t\circ T(t,z)$ | $\frac{d\mathbf{x}_t}{dt}=-\frac{m}{m-1}p_t(\mathbf{x}_t)^{m-1}\nabla p_t(\mathbf{x}_t)$ |
| Keller-Segel (4) | $\partial_t T(t,z)=-\nabla_x(U'(p_t)+W*p_t)\circ T(t,z)$ | $\frac{d\mathbf{x}_t}{dt}=-\nabla_x U'(p_t(\mathbf{x}_t))-\nabla_x W*p_t(\mathbf{x}_t)$ |

Table 1. Gradient flows in Lagrangian coordinates and particle dynamics associated with the WGFs.

3. Neural projected Wasserstein gradient flows and their algorithms

As discussed in Section 2, instead of evaluating the density function of the Wasserstein gradient flow directly, it suffices to compute the time-dependent Lagrangian mapping $T(t,\cdot)$. In this work, we approximate $T(t,\cdot)$ by neural networks with time-dependent parameters $\{\theta_t\}$. The evolution of $\theta_t$ is obtained by projecting the gradient flow (6) onto the parameter space $\Theta$. In this section, we first briefly review the basic definitions of neural network mapping functions. We then study a metric on the space of neural mapping functions and formulate several neural mapping dynamics for $\{\theta_t\}$.

3.1. Neural network activation functions

We first provide the definition of a neural network mapping function. Consider a mapping function

$$f\colon\Theta\times Z\rightarrow\Omega,$$

where $Z\subset\mathbb{R}^l$ is the latent space, $\Omega\subset\mathbb{R}^d$ is the sample space and $\Theta\subset\mathbb{R}^D$ is the parameter space. In this paper, we consider the following network structure

$$f(\theta,z)=\frac{1}{N}\sum_{i=1}^{N}a_i\,\sigma(z-b_i),$$

where $\theta=(a_i,b_i)_{i=1}^N\in\mathbb{R}^D$ with $D=(l+1)N$. Here $N$ is the number of hidden units (neurons), $a_i\in\mathbb{R}$ is the weight of unit $i$, and $b_i\in\mathbb{R}^l$ is an offset (location variable). The activation function $\sigma\colon\mathbb{R}\rightarrow\mathbb{R}$ satisfies $\sigma(0)=0$ and $1\in\partial\sigma(0)$. From now on, we assume that $f$ is invertible, monotone, and continuous in both $z$ and $\theta$.
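For $l=d=1$ and the ReLU activation, the network is a piecewise linear function of $z$ with kinks at the offsets $b_i$; its slope is $\frac{1}{N}\sum_{b_i<z}a_i$, so nonnegative weights $a_i\ge 0$ give a nondecreasing map. A minimal vectorized NumPy sketch (the function names are ours, and we take $\sigma'(0)=0$ at the kinks):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def f(a, b, z):
    """f(theta, z) = (1/N) * sum_i a_i * relu(z - b_i) for l = d = 1.

    a, b : shape (N,), the parameters theta = (a_i, b_i)
    z    : shape (M,), latent points
    returns shape (M,)
    """
    a, b, z = map(np.asarray, (a, b, z))
    return relu(z[:, None] - b[None, :]) @ a / a.size

def df_dz(a, b, z):
    """Slope in z: (1/N) * sum_{b_i < z} a_i (convention relu'(0) = 0)."""
    a, b, z = map(np.asarray, (a, b, z))
    return (z[:, None] > b[None, :]).astype(float) @ a / a.size

a, b = np.array([2.0, 4.0]), np.array([0.0, 1.0])
z = np.array([-1.0, 0.5, 2.0])
print(f(a, b, z), df_dz(a, b, z))   # values 0, 0.5, 4 and slopes 0, 1, 3
```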

For example, let $N=d=1$ and $b_1=0$. This defines the two-layer neural network $f(\theta,z)=a_1\sigma(z)$, illustrated by the following diagram.

[Figure: diagram of the two-layer neural network.]

The following neural network mapping functions have been widely used.

Example 1 (Linear).

Let $\sigma(x)=x$. Consider

$$f(\theta,z)=\theta z,\quad\theta\in\mathbb{R}_{+}.$$
Example 2 (ReLU).

Let $\sigma(x)=\max\{x,0\}$. Consider

$$f(\theta,z)=\theta\max\{z,0\},\quad\theta\in\mathbb{R}_{+}.$$
Example 3 (Sigmoid).

Let $\sigma(x)=\frac{1}{1+e^{-2x}}$. Consider

$$f(\theta,z)=\frac{\theta}{1+e^{-2z}},\quad\theta\in\mathbb{R}_{+}.$$

In Section 4 (theoretical results) and Section 5 (numerical examples), we focus mainly on the case $l=d=1$, so that $D=2N$, with $\sigma(\cdot)$ the ReLU activation function.

3.2. Neural mapping models and energies

In this subsection, we consider the following probability density functions generated by neural network mapping functions. We call them the neural mapping models.

Definition 1 (Neural mapping models).

Fix a reference probability density $p_{\mathrm{r}}\in\mathcal{P}(Z)=\Big\{p\in C^{\infty}(Z)\colon\int_Z p(z)\,dz=1,\ p(z)\geq 0\Big\}$. The probability density generated by a neural network mapping function is defined through the pushforward operator:

$$p={f_\theta}_{\#}p_{\mathrm{r}}\in\mathcal{P}(\Omega).$$

In other words, $p$ satisfies the Monge-Ampère equation

$$p(f(\theta,z))\,\det(D_zf(\theta,z))=p_{\mathrm{r}}(z),\tag{8}$$

where $D_zf(\theta,z)$ is the Jacobian of the mapping function $f(\theta,z)$ with respect to $z$.
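In one dimension, (8) reads $p(f(\theta,z))=p_{\mathrm{r}}(z)/\partial_z f(\theta,z)$ wherever $\partial_z f>0$. The sketch below evaluates the pushforward density of a ReLU network along latent points and compares it with the fraction of pushed-forward points falling in a small bin. The parameters $a,b$ and the reference density $p_{\mathrm{r}}=\mathrm{Uniform}[0,1]$ on $Z=[0,1]$ are hypothetical choices for illustration.

```python
import numpy as np

def f(a, b, z):
    # (1/N) * sum_i a_i * relu(z - b_i), vectorized over z
    return np.maximum(z[:, None] - b[None, :], 0.0) @ a / a.size

def df_dz(a, b, z):
    # slope (1/N) * sum_{b_i < z} a_i of the piecewise linear map
    return (z[:, None] > b[None, :]).astype(float) @ a / a.size

def pushforward_density_at_f(a, b, z, p_r):
    """p(f(theta, z)) = p_r(z) / d_z f(theta, z), i.e. (8) in 1D; requires slope > 0."""
    slope = df_dz(a, b, z)
    if np.any(slope <= 0):
        raise ValueError("f(theta, .) must be strictly increasing on the latent points")
    return p_r(z) / slope

# Illustrative parameters (assumed) and p_r = Uniform[0, 1].
a = np.array([1.0, 1.0, 1.0])
b = np.array([-0.1, 0.3, 0.6])
p_r = lambda z: np.ones_like(z)

z = (np.arange(100_000) + 0.5) / 100_000      # midpoint grid on Z = [0, 1]
x = f(a, b, z)
p_at_x = pushforward_density_at_f(a, b, z, p_r)

# Empirical check: fraction of pushed-forward points in a bin inside f([0, 0.3]).
lo, hi = 0.05, 0.08
empirical = np.mean((x >= lo) & (x < hi)) / (hi - lo)
print(empirical, p_at_x[0])                   # both close to 3
```

Where the slope increases, the pushforward density decreases, so the three linear pieces give density values $3$, $1.5$ and $1$.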

Definition 2 (Neural mapping energies).

Given an energy functional $\mathcal{F}\colon\mathcal{P}(\Omega)\rightarrow\mathbb{R}$, we construct a neural mapping energy $F\colon\Theta\rightarrow\mathbb{R}$ by

$$F(\theta)=\mathcal{F}({f_\theta}_{\#}p_{\mathrm{r}}).$$
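As an illustration of Definition 2, take the potential energy $\mathcal{F}(p)=\int_\Omega V(x)p(x)\,dx$, one of the functionals listed among the examples of Section 3.5. By the pushforward identity, $F(\theta)=\mathbb{E}_{z\sim p_{\mathrm{r}}}[V(f(\theta,z))]$, so no density evaluation is needed. The sketch assumes a standard Gaussian $p_{\mathrm{r}}$, the quadratic $V(x)=x^2/2$ and the linear network of Example 1, for which $F(\theta)=\theta^2/2$ exactly; the expectation is computed by Gauss-Hermite quadrature.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def expectation_std_normal(g, n_nodes=40):
    """E_{z ~ N(0,1)}[g(z)] via probabilists' Gauss-Hermite quadrature."""
    z, w = hermegauss(n_nodes)
    return np.sum(w * g(z)) / np.sqrt(2.0 * np.pi)

def neural_mapping_energy(theta, f, V):
    """F(theta) = F(f_theta # p_r) = E_{z ~ p_r}[V(f(theta, z))] for a potential energy."""
    return expectation_std_normal(lambda z: V(f(theta, z)))

f_linear = lambda theta, z: theta * z     # Example 1
V = lambda x: 0.5 * x**2

for theta in (0.5, 1.0, 2.0):
    print(theta, neural_mapping_energy(theta, f_linear, V), 0.5 * theta**2)
```

Internal energies such as entropy do require the density, which for neural mapping models can be obtained from (8).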

Many applications in machine learning and scientific computing can be cast into the following optimization problem

$$\min_{\theta\in\Theta}F(\theta).$$

Here, $F$ often measures the closeness between the neural mapping model and a target or data density. For simplicity of presentation, we often write integration with respect to the density $p_{\mathrm{r}}$ over the domain $Z$ as the expectation $\mathbb{E}_{z\sim p_{\mathrm{r}}}$. In Section 3.5, we provide concrete examples of the energy functional $\mathcal{F}$, including the potential, interaction (e.g., maximum mean discrepancy) and internal (information entropy/divergence) functionals. They are commonly used in the machine learning and optimal transport communities; see details in [3, Section 9].

To summarize, a neural mapping energy $F$ is the functional $\mathcal{F}$ written in terms of the mapping function $f(\theta,z)$. This allows us to optimize over the finite-dimensional space $\Theta$ instead of the infinite-dimensional space $\mathcal{P}(\Omega)$.

3.3. Neural mapping metric space

We next consider the space of mappings parameterized by a neural mapping function $f(\theta,\cdot)$. We measure the difference between two neural mapping functions by their $L^2$ distance, as in the following definition.

Definition 3 (Neural mapping distance).

Define a distance function $\mathrm{Dist}_{\mathrm{W}}\colon\Theta\times\Theta\rightarrow\mathbb{R}$ by

$$\begin{aligned}\mathrm{Dist}_{\mathrm{W}}({f_{\theta^0}}_{\#}p_{\mathrm{r}},{f_{\theta^1}}_{\#}p_{\mathrm{r}})^2&=\int_Z\|f(\theta^0,z)-f(\theta^1,z)\|^2\,p_{\mathrm{r}}(z)\,dz\\&=\sum_{m=1}^d\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\|f_m(\theta^0,z)-f_m(\theta^1,z)\|^2\Big],\end{aligned}$$

where $\theta^0,\theta^1\in\Theta$ are two sets of neural network parameters and $\|\cdot\|$ is the Euclidean norm in $\mathbb{R}^d$.

In the above definition, $\mathrm{Dist}_{\mathrm{W}}$ is a distance between the two neural mapping functions $f(\theta^0,\cdot)$ and $f(\theta^1,\cdot)$. In this way, the $L^2$ distance between neural mapping functions induces a metric on the neural network parameters. A similar Riemannian geometry for feed-forward neural networks is studied in [31].
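Because $\mathrm{Dist}_{\mathrm{W}}$ is an expectation over $p_{\mathrm{r}}$, it can be approximated by quadrature over latent points. In one dimension, when both $f(\theta^0,\cdot)$ and $f(\theta^1,\cdot)$ are nondecreasing, this $L^2(p_{\mathrm{r}})$ distance between maps coincides with the 2-Wasserstein distance between the pushforward densities, because monotone rearrangement is the optimal coupling on the line. The sketch below checks this for two ReLU networks, assuming $p_{\mathrm{r}}=\mathrm{Uniform}[0,1]$ discretized by the midpoint rule; the parameters are hypothetical.

```python
import numpy as np

def relu_net(a, b, z):
    # f(theta, z) = (1/N) * sum_i a_i * relu(z - b_i) for l = d = 1
    return np.maximum(z[:, None] - b[None, :], 0.0) @ a / a.size

def dist_w_squared(theta0, theta1, z, w):
    """Dist_W^2 = E_{z ~ p_r} |f(theta0, z) - f(theta1, z)|^2 with nodes z and weights w."""
    (a0, b0), (a1, b1) = theta0, theta1
    diff = relu_net(a0, b0, z) - relu_net(a1, b1, z)
    return np.sum(w * diff**2)

# Assumed reference density p_r = Uniform[0, 1], midpoint rule with M points.
M = 200_000
z = (np.arange(M) + 0.5) / M
w = np.full(M, 1.0 / M)

theta0 = (np.array([1.0, 2.0]), np.array([0.0, 0.5]))
theta1 = (np.array([1.5, 1.0]), np.array([0.1, 0.4]))
d2 = dist_w_squared(theta0, theta1, z, w)

# Both maps are nondecreasing (a_i >= 0), so pairing sorted pushed-forward points,
# the 1D optimal coupling, yields the same squared distance.
x0, x1 = relu_net(*theta0, z), relu_net(*theta1, z)
w2_sq = np.mean((np.sort(x0) - np.sort(x1)) ** 2)
print(d2, w2_sq)
```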

We next Taylor-expand the squared distance. For $\Delta\theta\in\mathbb{R}^D$,

$$\begin{aligned}\mathrm{Dist}_{\mathrm{W}}({f_{\theta+\Delta\theta}}_{\#}p_{\mathrm{r}},{f_\theta}_{\#}p_{\mathrm{r}})^2&=\sum_{m=1}^d\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\|f_m(\theta+\Delta\theta,z)-f_m(\theta,z)\|^2\Big]\\&=\sum_{m=1}^d\sum_{i=1}^D\sum_{j=1}^D\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\partial_{\theta_i}f_m(\theta,z)\,\partial_{\theta_j}f_m(\theta,z)\Big]\Delta\theta_i\Delta\theta_j+o(\|\Delta\theta\|^2)\\&=\Delta\theta^{\mathsf{T}}G_{\mathrm{W}}(\theta)\Delta\theta+o(\|\Delta\theta\|^2).\end{aligned}$$

Here $G_{\mathrm{W}}$ is a Gram-type matrix function. We summarize its definition below.

Definition 4 (Neural mapping metric).

Define a matrix function $G_{\mathrm{W}}\colon\Theta\rightarrow\mathbb{R}^{D\times D}$. Denote $G_{\mathrm{W}}(\theta)=(G_{\mathrm{W}}(\theta)_{ij})_{1\leq i,j\leq D}$, where

$$
G_{\mathrm{W}}(\theta)_{ij}=\sum_{m=1}^{d}\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\partial_{\theta_{i}}f_{m}(\theta,z)\,\partial_{\theta_{j}}f_{m}(\theta,z)\Big].
$$

We also write

$$
G_{\mathrm{W}}(\theta)=\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,z)\,\nabla_{\theta}f(\theta,z)^{\mathsf{T}}\Big],
$$

where we denote $\nabla_{\theta}f(\theta,z)=(\partial_{\theta_{i}}f_{m}(\theta,z))_{1\leq i\leq D,\,1\leq m\leq d}\in\mathbb{R}^{D\times d}$.
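As a quick consistency check on Definition 4, the sketch below evaluates both the entrywise formula and the matrix form $\mathbb{E}[\nabla_\theta f\,\nabla_\theta f^{\mathsf T}]$ by Monte Carlo. It assumes a hypothetical two-dimensional map ($d=2$, $D=3$), $f(\theta,z)=(\theta_1 z_1+\theta_3\sin z_2,\ \theta_2 z_2+\theta_3 z_1)$, with $p_{\mathrm r}=N(0,I_2)$; the map and sample size are illustrative only. Because this map is linear in $\theta$, its Jacobian does not depend on $\theta$.

```python
import numpy as np

def grad_theta_f(z):
    """nabla_theta f(theta, z) stacked over samples, shape (N, D, d) = (N, 3, 2).

    Hypothetical map: f_1 = th1*z1 + th3*sin(z2), f_2 = th2*z2 + th3*z1.
    It is linear in theta, so the Jacobian does not depend on theta.
    """
    z1, z2 = z[:, 0], z[:, 1]
    zeros = np.zeros_like(z1)
    df1 = np.stack([z1, zeros, np.sin(z2)], axis=1)  # d f_1 / d(th1, th2, th3)
    df2 = np.stack([zeros, z2, z1], axis=1)          # d f_2 / d(th1, th2, th3)
    return np.stack([df1, df2], axis=2)

rng = np.random.default_rng(1)
z = rng.standard_normal((50_000, 2))   # samples from p_r = N(0, I_2)
grads = grad_theta_f(z)
N, D, d = grads.shape

# Entrywise definition: G_ij = sum_m E[ d_{theta_i} f_m * d_{theta_j} f_m ]
G_entry = np.zeros((D, D))
for i in range(D):
    for j in range(D):
        G_entry[i, j] = sum(np.mean(grads[:, i, m] * grads[:, j, m]) for m in range(d))

# Matrix form: G = E[ nabla_theta f (nabla_theta f)^T ], contracting over m
G_matrix = np.einsum("nim,njm->ij", grads, grads) / N

print(np.allclose(G_entry, G_matrix))  # True
```

The `einsum` contraction over the output index `m` is exactly the sum over $m$ in the entrywise definition.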

From now on, we call $(\Theta,G_{\mathrm{W}})$ the neural mapping metric space. Throughout, we assume that $G_{\mathrm{W}}(\theta)$ is a positive definite matrix in $\mathbb{R}^{D\times D}$.
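The quadratic expansion that motivates $G_{\mathrm W}$, together with the positive definiteness assumption, can be checked numerically. The sketch below assumes $d=1$, $p_{\mathrm r}=N(0,1)$ and a hypothetical smooth map $f(\theta,z)=a+bz+c\tanh(sz)$ with $\theta=(a,b,c,s)$ (not the network studied later). It estimates $G_{\mathrm W}(\theta)$ by Monte Carlo and compares the squared $L^2$ mapping distance $\mathbb{E}_z|f(\theta+\Delta\theta,z)-f(\theta,z)|^2$ with $\Delta\theta^{\mathsf T}G_{\mathrm W}(\theta)\Delta\theta$; the relative gap should shrink roughly in proportion to $\|\Delta\theta\|$.

```python
import numpy as np

# Hypothetical smooth 1D map f(theta, z) = a + b*z + c*tanh(s*z), theta = (a, b, c, s).
def f(theta, z):
    a, b, c, s = theta
    return a + b * z + c * np.tanh(s * z)

def jac(theta, z):
    """Partial derivatives d f / d theta, one row per sample: shape (N, 4)."""
    a, b, c, s = theta
    th = np.tanh(s * z)
    return np.stack([np.ones_like(z), z, th, c * z * (1.0 - th**2)], axis=1)

rng = np.random.default_rng(0)
z = rng.standard_normal(20_000)        # samples from p_r = N(0, 1)
theta = np.array([0.1, 1.0, 0.5, 2.0])
J = jac(theta, z)
G = J.T @ J / len(z)                   # Monte Carlo estimate of G_W(theta)
print("smallest eigenvalue of G_W:", np.linalg.eigvalsh(G).min())

direction = rng.standard_normal(4)
gaps = []
for eps in [1e-1, 1e-2, 1e-3]:
    dtheta = eps * direction
    # squared L2 mapping distance E_z |f(theta + dtheta, z) - f(theta, z)|^2 (same samples)
    exact = np.mean((f(theta + dtheta, z) - f(theta, z)) ** 2)
    quad = dtheta @ G @ dtheta         # quadratic form from the expansion
    gaps.append(abs(exact - quad) / exact)
    print(f"eps={eps:.0e}  relative gap={gaps[-1]:.2e}")
```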

3.4. Neural mapping dynamics

In this subsection, we derive analogues of Wasserstein gradient flows in the neural mapping metric space $(\Theta,G_{\mathrm{W}})$. We then use them to define the neural mapping dynamics and compare these with their counterparts in the $L^{2}$ mapping metric space and in the probability space equipped with the $L^{2}$-Wasserstein metric. From now on, we assume that $f$ is smooth with respect to the parameter $\theta$. This assumption fails for the ReLU activation function, which is studied in detail in later sections.

The next proposition gives the gradient operator of a function $F\in C^{2}(\Theta;\mathbb{R})$ in the neural mapping metric space $(\Theta,G_{\mathrm{W}})$.

Proposition 1 (Neural mapping gradient operators).

The gradient operator of $F$ in $(\Theta,G_{\mathrm{W}})$, $\mathrm{grad}_{\mathrm{W}}F(\theta)=(\mathrm{grad}_{\mathrm{W}}F(\theta)_{k})_{k=1}^{D}$, is given by

$$
\mathrm{grad}_{\mathrm{W}}F(\theta)_{k}=\sum_{i=1}^{D}G^{-1}_{\mathrm{W}}(\theta)_{ki}\,\partial_{\theta_{i}}F(\theta).
$$
Proof.

We briefly derive the gradient operator of $F$ in $(\Theta,G_{\mathrm{W}})$. Suppose $\theta(t)=\theta_{t}$ is a smooth curve passing through the point $\theta(0)=\theta$. A Taylor expansion of $F(\theta_{t})$ at $t=0$ gives

$$
\begin{aligned}
F(\theta_{t})&=F(\theta)+t\cdot\frac{d}{dt}F(\theta_{t})\Big|_{t=0}+o(t)\\
&=F(\theta)+t\cdot\big(G_{\mathrm{W}}(\theta)\,\mathrm{grad}_{\mathrm{W}}F(\theta),\dot{\theta}\big)+o(t),
\end{aligned}
\tag{9}
$$

where we denote $\dot{\theta}=\frac{d}{dt}\theta_{t}\big|_{t=0}$. Comparing the terms linear in $t$ in (9), we have

$$
\big(G_{\mathrm{W}}(\theta)\,\mathrm{grad}_{\mathrm{W}}F(\theta),\dot{\theta}\big)=\frac{d}{dt}F(\theta_{t})\Big|_{t=0}=\big(\nabla_{\theta}F(\theta),\dot{\theta}\big),
$$

for any $\dot{\theta}\in T_{\theta}\Theta=\mathbb{R}^{D}$. Thus

$$
\mathrm{grad}_{\mathrm{W}}F(\theta)=G^{-1}_{\mathrm{W}}(\theta)\nabla_{\theta}F(\theta).
$$
∎
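In computations one would typically avoid forming $G_{\mathrm W}^{-1}$ explicitly. The sketch below, an implementation choice rather than something prescribed by the text, obtains $\mathrm{grad}_{\mathrm W}F(\theta)$ from a linear solve and checks the defining identity $(G_{\mathrm W}\,\mathrm{grad}_{\mathrm W}F,\dot\theta)=(\nabla_\theta F,\dot\theta)$ used in the proof. The metric and the Euclidean gradient are random stand-ins, not quantities from the paper.

```python
import numpy as np

def natural_gradient(G, euclid_grad):
    """grad_W F = G_W^{-1} nabla_theta F, computed by a linear solve (no explicit inverse)."""
    return np.linalg.solve(G, euclid_grad)

rng = np.random.default_rng(2)
D = 5
A = rng.standard_normal((D, D))
G = A @ A.T + 0.1 * np.eye(D)   # stand-in symmetric positive definite metric G_W(theta)
g = rng.standard_normal(D)      # stand-in Euclidean gradient nabla_theta F(theta)
nat = natural_gradient(G, g)

# Defining property from the proof: (G grad_W F, v) = (nabla_theta F, v) for every tangent v
for _ in range(3):
    v = rng.standard_normal(D)
    assert np.isclose((G @ nat) @ v, g @ v)
```

Since $G_{\mathrm W}$ is assumed symmetric positive definite, a Cholesky-based solve would also be a natural choice here.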

We are now ready to present the neural mapping gradient flow, on which our first-order algorithm for neural mapping optimization problems is based.

Proposition 2 (Neural mapping gradient flows).

Consider an energy functional $\mathcal{F}\colon\mathcal{P}(\Omega)\rightarrow\mathbb{R}$. Then the gradient flow of the function $F(\theta)=\mathcal{F}({f_{\theta}}_{\#}p_{\mathrm{r}})$ in $(\Theta,G_{\mathrm{W}})$ is given by

$$
\frac{d\theta}{dt}=-\mathrm{grad}_{\mathrm{W}}F(\theta).
\tag{10}
$$

In particular,

$$
\begin{aligned}
\frac{d\theta_{i}}{dt}=&-\sum_{j=1}^{D}\sum_{m=1}^{d}\Big(\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,z)\nabla_{\theta}f(\theta,z)^{\mathsf{T}}\Big]\Big)_{ij}^{-1}\cdot\\
&\qquad\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\Big[\nabla_{x_{m}}\frac{\delta}{\delta p}\mathcal{F}(p)(f(\theta,\tilde{z}))\cdot\partial_{\theta_{j}}f_{m}(\theta,\tilde{z})\Big],
\end{aligned}
$$

where $\frac{\delta}{\delta p(x)}$ is the $L^{2}$ first variation with respect to the variable $p(x)$, with $x=f(\theta,z)$.

Proof.

Since the neural mapping metric is given in Definition 4, it suffices to compute the Euclidean gradient $\partial_{\theta_{j}}F(\theta)$:

$$
\begin{aligned}
\partial_{\theta_{j}}F(\theta)
&=\int_{\Omega}\partial_{\theta_{j}}\rho_{\theta}(x)\,\frac{\delta}{\delta p}\mathcal{F}(\rho_{\theta})(x)\,dx\\
&=\int_{\Omega}-\nabla_{x}\cdot\Big[\rho_{\theta}(x)\,\partial_{\theta_{j}}f\big(\theta,f(\theta,\cdot)^{-1}(x)\big)\Big]\,\frac{\delta}{\delta p}\mathcal{F}(\rho_{\theta})(x)\,dx\\
&=\int_{\Omega}\partial_{\theta_{j}}f\big(\theta,f(\theta,\cdot)^{-1}(x)\big)\cdot\nabla_{x}\Big(\frac{\delta}{\delta p}\mathcal{F}(\rho_{\theta})\Big)(x)\,\rho_{\theta}(x)\,dx\\
&=\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\partial_{\theta_{j}}f(\theta,z)\cdot\nabla_{x}\Big(\frac{\delta}{\delta p}\mathcal{F}(\rho_{\theta})\Big)(f(\theta,z))\Big].
\end{aligned}
$$

Here we denote $\rho_{\theta}={f_{\theta}}_{\#}p_{\mathrm{r}}$. ∎
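For a linear energy $\mathcal F(p)=\int_\Omega V p\,dx$ the first variation is $V$, so in one dimension the last line of the proof reads $\partial_{\theta_j}F(\theta)=\mathbb{E}_z[\partial_{\theta_j}f(\theta,z)\,V'(f(\theta,z))]$. The sketch below compares this expression with central finite differences of $F(\theta)=\mathbb{E}_z[V(f(\theta,z))]$ on a fixed Monte Carlo sample. The potential $V(x)=x^4/4-x^2/2$ and the map $f(\theta,z)=a+bz+c\tanh(sz)$ are hypothetical choices for illustration.

```python
import numpy as np

V = lambda x: 0.25 * x**4 - 0.5 * x**2   # hypothetical potential
dV = lambda x: x**3 - x                  # its derivative V'

def f(theta, z):
    a, b, c, s = theta
    return a + b * z + c * np.tanh(s * z)

def jac(theta, z):
    """d f / d theta at each sample, shape (N, 4)."""
    a, b, c, s = theta
    th = np.tanh(s * z)
    return np.stack([np.ones_like(z), z, th, c * z * (1.0 - th**2)], axis=1)

rng = np.random.default_rng(3)
z = rng.standard_normal(10_000)          # fixed samples from p_r = N(0, 1)
theta = np.array([0.2, 0.8, 0.3, 1.5])

F = lambda th: np.mean(V(f(th, z)))      # F(theta) = E_z[V(f(theta, z))]

# Formula from the proof: E_z[ d_theta f(theta, z) * V'(f(theta, z)) ]
grad_formula = jac(theta, z).T @ dV(f(theta, z)) / len(z)

# Central finite differences on the same samples
h = 1e-5
grad_fd = np.array([(F(theta + h * e) - F(theta - h * e)) / (2 * h) for e in np.eye(4)])

print(np.max(np.abs(grad_formula - grad_fd)))
```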

3.5. Neural projected Wasserstein flows

The dynamics in parameter space can also be written in terms of mapping functions and probability densities. These reformulations show that the neural mapping gradient flow is a projected Wasserstein gradient flow, where the projection is from the full mapping space onto the neural-network-parameterized mapping space. Concretely, we present the following reformulations of equation (10) in terms of mapping functions and probability density functions. The proof follows from the gradient flow equation in Proposition 2 and the chain rule.

Proposition 3 (Neural projected Wasserstein gradient flows).

In terms of the mapping functions $f(\theta,z)=(f_{m}(\theta,z))_{m=1}^{d}$, dynamics (10) read

$$
\begin{aligned}
\frac{\partial}{\partial t}f_{m}(\theta(t),z)=&-\sum_{i=1}^{D}\sum_{j=1}^{D}\sum_{n=1}^{d}\partial_{\theta_{i}}f_{m}(\theta,z)\Big(\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,\tilde{z})\nabla_{\theta}f(\theta,\tilde{z})^{\mathsf{T}}\Big]\Big)_{ij}^{-1}\cdot\\
&\qquad\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\Big[\nabla_{x_{n}}\frac{\delta}{\delta p(x)}\mathcal{F}(p)(f(\theta,\tilde{z}))\cdot\partial_{\theta_{j}}f_{n}(\theta,\tilde{z})\Big].
\end{aligned}
$$
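The formula above makes the projection explicit: the velocity $\partial_t f(\theta(t),\cdot)$ is the $L^2(p_{\mathrm r})$-orthogonal projection of the Wasserstein velocity $-\nabla_x\frac{\delta}{\delta p}\mathcal F(p)(f(\theta,\cdot))$ onto $\mathrm{span}\{\partial_{\theta_i}f(\theta,\cdot)\}_{i=1}^D$. With Monte Carlo samples this projection is an ordinary least-squares problem, which the sketch below checks for $d=1$. The map $f(\theta,z)=a+bz+c\tanh(sz)$ and the linear energy with $V'(x)=x^3-x$ are hypothetical choices.

```python
import numpy as np

def f(theta, z):
    a, b, c, s = theta
    return a + b * z + c * np.tanh(s * z)

def jac(theta, z):
    """d f / d theta at each sample, shape (N, 4)."""
    a, b, c, s = theta
    th = np.tanh(s * z)
    return np.stack([np.ones_like(z), z, th, c * z * (1.0 - th**2)], axis=1)

dV = lambda x: x**3 - x                  # hypothetical V' (grad_x of the first variation)

rng = np.random.default_rng(4)
z = rng.standard_normal(5_000)           # samples from p_r = N(0, 1)
theta = np.array([0.2, 0.8, 0.3, 1.5])
J = jac(theta, z)
target = -dV(f(theta, z))                # Wasserstein velocity at the particles x = f(theta, z)

# Natural-gradient velocity from Proposition 2: theta_dot = -G_W^{-1} E[J^T V'(f)]
G = J.T @ J / len(z)
theta_dot = np.linalg.solve(G, J.T @ target / len(z))

# The same coefficients, obtained as a least-squares fit of the target by the columns of J
theta_ls, *_ = np.linalg.lstsq(J, target, rcond=None)

velocity = J @ theta_dot                 # d/dt f(theta(t), z) at the samples
residual = target - velocity             # orthogonal to every column of J
print(np.allclose(theta_dot, theta_ls), np.abs(J.T @ residual / len(z)).max())
```

The residual is the part of the exact Wasserstein velocity that the network parameterization cannot represent.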

We now present several examples of neural mapping Wasserstein gradient flows obtained from Proposition 2.

Example 4 (Neural projected linear transport equation).

Consider a linear energy given by

$$
\mathcal{F}(p)=\int_{\Omega}V(x)p(x)\,dx.
$$

In this case, the neural projected gradient flow satisfies

$$
\frac{d\theta}{dt}=-G_{\mathrm{W}}^{-1}(\theta)\cdot\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}V(f(\theta,\tilde{z}))\Big].
\tag{11}
$$

In detail,

$$
\frac{d\theta_{i}}{dt}=-\sum_{j=1}^{D}\Big(\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,z)\nabla_{\theta}f(\theta,z)^{\mathsf{T}}\Big]\Big)_{ij}^{-1}\cdot\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\Big[\nabla_{x}V(f(\theta,\tilde{z}))\cdot\partial_{\theta_{j}}f(\theta,\tilde{z})\Big].
$$
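As a closed-form check of (11), take $d=1$, $p_{\mathrm r}=N(0,1)$, the affine map $f(\theta,z)=\theta_1+\theta_2 z$ and $V(x)=x^2/2$; these are illustrative assumptions under which the affine family contains the exact solution. Then $G_{\mathrm W}=I$ and (11) reduces to $\dot\theta=-\theta$, matching the particle dynamics $\dot x=-V'(x)=-x$. The sketch integrates (11) with forward Euler, uses Gauss-Hermite quadrature (NumPy's `hermegauss`) for the expectations, and compares with $\theta(0)e^{-t}$.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Gauss-Hermite nodes/weights for the weight exp(-z^2/2); normalizing gives E over N(0, 1).
z, w = hermegauss(20)
w = w / w.sum()

def rhs(theta):
    """Right-hand side of (11) for f(theta, z) = theta[0] + theta[1]*z and V(x) = x^2/2."""
    f = theta[0] + theta[1] * z
    J = np.stack([np.ones_like(z), z], axis=1)   # d f / d theta at the nodes, shape (Q, 2)
    G = J.T @ (w[:, None] * J)                   # G_W(theta) = E[J^T J] (the identity here)
    b = J.T @ (w * f)                            # E[V'(f) * d f / d theta] with V'(x) = x
    return -np.linalg.solve(G, b)

theta0 = np.array([1.0, 2.0])
theta = theta0.copy()
dt, T = 1e-3, 1.0
for _ in range(int(round(T / dt))):
    theta = theta + dt * rhs(theta)              # forward Euler step

print(theta, theta0 * np.exp(-T))                # agree up to the O(dt) Euler error
```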
Example 5 (Neural projected interaction transport equation).

Consider an interaction energy given by

$$
\mathcal{F}(p)=\frac{1}{2}\int_{\Omega}\int_{\Omega}W(x_{1},x_{2})\,p(x_{1})\,p(x_{2})\,dx_{1}\,dx_{2}.
$$

In this case, the neural mapping gradient flow satisfies

$$
\frac{d\theta}{dt}=-\frac{1}{2}G_{\mathrm{W}}^{-1}(\theta)\cdot\mathbb{E}_{(z_{1},z_{2})\sim p_{\mathrm{r}}\times p_{\mathrm{r}}}\Big[\nabla_{\theta}W(f(\theta,z_{1}),f(\theta,z_{2}))\Big].
\tag{12}
$$

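In detail, assuming $W$ is symmetric, i.e. $W(x_{1},x_{2})=W(x_{2},x_{1})$, so that both arguments contribute equally to $\nabla_{\theta}W$,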

$$
\begin{aligned}
\frac{d\theta_{i}}{dt}=&-\sum_{j=1}^{D}\Big(\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,z)\nabla_{\theta}f(\theta,z)^{\mathsf{T}}\Big]\Big)_{ij}^{-1}\cdot\\
&\qquad\mathbb{E}_{(z_{1},z_{2})\sim p_{\mathrm{r}}\times p_{\mathrm{r}}}\Big[\nabla_{x_{1}}W(f(\theta,z_{1}),f(\theta,z_{2}))\cdot\partial_{\theta_{j}}f(\theta,z_{1})\Big].
\end{aligned}
$$
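For a symmetric kernel, the factor $\tfrac12$ in (12) and the single $\nabla_{x_1}W$ term in the detailed form give the same vector. The sketch below checks this with $W(x_1,x_2)=(x_1-x_2)^2/2$, $p_{\mathrm r}=N(0,1)$ and the affine map $f(\theta,z)=\theta_1+\theta_2 z$ (illustrative assumptions), using tensor-product Gauss-Hermite quadrature for $p_{\mathrm r}\times p_{\mathrm r}$. Here $G_{\mathrm W}=I$ and the flow is $\dot\theta_1=0$, $\dot\theta_2=-\theta_2$: the mean is preserved while the spread contracts.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

z, w = hermegauss(20)
w = w / w.sum()                               # quadrature for p_r = N(0, 1)
Z1, Z2 = np.meshgrid(z, z, indexing="ij")     # tensor grid for p_r x p_r
W12 = np.outer(w, w)                          # product weights

theta = np.array([0.5, 2.0])
f = lambda zz: theta[0] + theta[1] * zz                      # affine map
df = lambda zz: np.stack([np.ones_like(zz), zz], axis=-1)    # d f / d theta, shape (..., 2)

x1, x2 = f(Z1), f(Z2)
# Form (12): (1/2) E[ nabla_theta W(f(z1), f(z2)) ] with W(x1, x2) = (x1 - x2)^2 / 2
grad_theta_W = (x1 - x2)[..., None] * (df(Z1) - df(Z2))
b_half = 0.5 * np.einsum("ab,abk->k", W12, grad_theta_W)
# Detailed form: E[ d_{x1} W(f(z1), f(z2)) * d_theta f(z1) ]
b_detail = np.einsum("ab,abk->k", W12, (x1 - x2)[..., None] * df(Z1))

J = np.stack([np.ones_like(z), z], axis=1)
G = J.T @ (w[:, None] * J)                    # G_W(theta), the identity for this map
theta_dot = -np.linalg.solve(G, b_detail)
print(np.allclose(b_half, b_detail), theta_dot)   # True, approximately [0, -2]
```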
Example 6 (Neural projected negative entropy).

Consider an internal energy functional given by

$$
\mathcal{F}(p)=\int_{\Omega}U(p(x))\,dx.
$$

In this case, the neural mapping gradient flow satisfies

$$
\frac{d\theta}{dt}=-G_{\mathrm{W}}^{-1}(\theta)\cdot\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}\hat{U}\Big(\frac{p_{\mathrm{r}}(z)}{\det(D_{z}f(\theta,z))}\Big)\Big],
\tag{13}
$$

where $\hat{U}(p)=U(p)/p$. This follows from the change of variables $x=f(\theta,z)$:

$$
\begin{split}
\mathcal{F}({f_{\theta}}_{\#}p_{\mathrm{r}}) &= \int_{\Omega}U(p(f(\theta,z)))\,df(\theta,z)\\
&= \int_{Z}U\Big(\frac{p_{\mathrm{r}}(z)}{\det(D_{z}f(\theta,z))}\Big)\frac{\det(D_{z}f(\theta,z))}{p_{\mathrm{r}}(z)}\,p_{\mathrm{r}}(z)\,dz\\
&= \mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\hat{U}\Big(\frac{p_{\mathrm{r}}(z)}{\det(D_{z}f(\theta,z))}\Big)\Big].
\end{split}
$$
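As a sanity check of this identity, the following minimal NumPy sketch estimates $\mathcal{F}$ by Monte Carlo in 1D, where $\det(D_zf(\theta,z))$ reduces to $f'(\theta,z)$. It assumes, for illustration only, the negative entropy $U(p)=p\log p$ (so $\hat U=\log$), the reference $p_{\mathrm r}=N(0,1)$, and a hypothetical affine map $f(\theta,z)=\mu+sz$ with $s>0$, whose pushforward $N(\mu,s^2)$ has the known negative entropy $-\tfrac12\log(2\pi e s^2)$. The helper name `internal_energy_mc` is ours.

```python
import numpy as np


def internal_energy_mc(u_hat, pr_vals, f_prime):
    """Monte Carlo estimate of F(f_theta # p_r) = E_{z ~ p_r}[U_hat(p_r(z) / f'(theta, z))].

    u_hat   : callable implementing U_hat(q) = U(q) / q
    pr_vals : reference density p_r evaluated at samples z_l ~ p_r
    f_prime : 1D Jacobian f'(theta, z_l), assumed strictly positive (monotone map)
    """
    q = pr_vals / f_prime  # density of the pushforward at the point f(theta, z_l)
    return np.mean(u_hat(q))


# Negative entropy (U_hat = log), p_r = N(0, 1), hypothetical affine map f(z) = mu + s z.
rng = np.random.default_rng(0)
z = rng.standard_normal(200_000)
s = 1.7
pr_vals = np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi)
estimate = internal_energy_mc(np.log, pr_vals, np.full_like(z, s))
exact = -0.5 * np.log(2.0 * np.pi * np.e * s**2)  # negative entropy of N(mu, s^2)
print(estimate, exact)
```

The estimate depends on $\theta$ only through $f'$, which is what makes the energy cheap to differentiate with respect to the network parameters.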

The choice $U(p)=p\log p$, for which $\hat{U}(p)=\log p$, corresponds to the negative entropy, a member of the internal energy family. In detail, the projected dynamics for an internal energy read

$$
\begin{split}
\frac{d\theta_{i}}{dt} = &-\sum_{j=1}^{D}\Big(\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}f(\theta,z)\nabla_{\theta}f(\theta,z)^{\mathsf{T}}\Big]\Big)_{ij}^{-1}\cdot\\
&\quad\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[-\mathrm{tr}\Big(D_{z}f(\theta,z)^{-1}\colon\partial_{\theta_{j}}D_{z}f(\theta,z)\Big)\,\hat{U}'\Big(\frac{p_{\mathrm{r}}(z)}{\det(D_{z}f(\theta,z))}\Big)\frac{p_{\mathrm{r}}(z)}{\det(D_{z}f(\theta,z))}\Big].
\end{split}
$$

Here we denote $\mathrm{tr}(A\colon B)=\mathrm{tr}(AB)$ for matrices $A,B\in\mathbb{R}^{d\times d}$.
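In one dimension $D_zf(\theta,z)=f'(\theta,z)$ is a scalar, so the trace term reduces to $\partial_{\theta_j}f'(\theta,z)/f'(\theta,z)$. The sketch below (hypothetical helper `energy_gradient_1d`) estimates the resulting force term $\mathbb{E}[-(\partial_{\theta_j}f'/f')\,\hat U'(q)\,q]$ with $q=p_{\mathrm r}/f'$, which equals $\partial_{\theta_j}\mathcal{F}$, and checks it against a finite difference of the Monte Carlo energy. It assumes, purely for illustration, $U(p)=p^2$ (so $\hat U(q)=q$ and $\hat U'(q)=1$), $p_{\mathrm r}=N(0,1)$, and the affine map $f=\mu+sz$; then $\partial_s\mathcal{F}=-\mathbb{E}[p_{\mathrm r}(z)]/s^2=-1/(2\sqrt{\pi}\,s^2)$.

```python
import numpy as np


def energy_gradient_1d(u_hat_prime, pr_vals, f_prime, df_prime_dtheta):
    """Estimate dF/dtheta_j = E[-(d_theta_j f' / f') * U_hat'(q) * q] with q = p_r / f'.

    pr_vals         : p_r(z_l) at samples z_l ~ p_r, shape (M,)
    f_prime         : f'(theta, z_l) > 0, shape (M,)
    df_prime_dtheta : d f'(theta, z_l) / d theta_j, shape (M, D)
    """
    q = pr_vals / f_prime
    weight = u_hat_prime(q) * q                   # U_hat'(q) q, shape (M,)
    ratio = df_prime_dtheta / f_prime[:, None]    # 1D form of tr(D_z f^{-1} : d_theta_j D_z f)
    return np.mean(-ratio * weight[:, None], axis=0)


# Illustrative check: U(p) = p^2 (U_hat'(q) = 1), p_r = N(0, 1), f = mu + s z, theta = (mu, s).
rng = np.random.default_rng(1)
z = rng.standard_normal(200_000)
pr_vals = np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi)
s = 1.3
f_prime = np.full_like(z, s)
df = np.stack([np.zeros_like(z), np.ones_like(z)], axis=1)  # d f'/d mu = 0, d f'/d s = 1
grad = energy_gradient_1d(np.ones_like, pr_vals, f_prime, df)

# Central finite difference of F(s) = mean(U_hat(p_r / s)) = mean(p_r / s) on the same samples.
eps = 1e-5
fd = (np.mean(pr_vals / (s + eps)) - np.mean(pr_vals / (s - eps))) / (2 * eps)
exact = -1.0 / (2.0 * np.sqrt(np.pi) * s**2)
print(grad, fd, exact)
```

Reusing the same samples for the finite difference isolates the Jacobi-formula step from Monte Carlo noise; the comparison with `exact` then measures only the sampling error.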

The above examples are projected Wasserstein gradient flows in the neural mapping metric space. In particular, Examples 4, 5, and 6 correspond, respectively, to the following classical PDEs:

$$
\partial_{t}p(t,x)=\nabla_{x}\cdot\Big(p(t,x)\nabla_{x}V(x)\Big), \tag{14}
$$
$$
\partial_{t}p(t,x)=\nabla_{x}\cdot\Big(p(t,x)\int_{\Omega}\nabla_{x}W(x,y)\,p(t,y)\,dy\Big), \tag{15}
$$
$$
\partial_{t}p(t,x)=\nabla_{x}\cdot\Big(p(t,x)\nabla_{x}U'(p(t,x))\Big). \tag{16}
$$

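These are the potential transport, interaction transport, and porous medium (nonlinear diffusion) equations, respectively. The Fokker-Planck equation combines the first and third: for $U(p)=p\log p$ we have $U'(p)=\log p+1$ and hence $p\nabla_{x}U'(p)=\nabla_{x}p$, so adding the right-hand sides of (14) and (16) gives $\partial_{t}p=\nabla_{x}\cdot(p\nabla_{x}V)+\Delta_{x}p$.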
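The reduction of (16) to linear diffusion for $U(p)=p\log p$ can be checked numerically. The following finite-difference sketch is our own illustration, using `np.gradient` on a uniform grid and an assumed quadratic potential $V(x)=x^2/2$. It evaluates the right-hand side of the Fokker-Planck equation, which should vanish (up to rounding) at the Gibbs density $p_\infty\propto e^{-V}$ but not at a wider Gaussian.

```python
import numpy as np


def fokker_planck_rhs(p, V, dx):
    """Finite-difference RHS of d_t p = d_x(p d_x V) + d_x(p d_x U'(p)) with U(p) = p log p.

    U'(p) = log p + 1, so the second flux p * d_x U'(p) equals d_x p (linear diffusion).
    Derivatives use np.gradient (centered in the interior, one-sided at the ends).
    """
    drift_flux = p * np.gradient(V, dx)                     # potential transport, cf. (14)
    diffusion_flux = p * np.gradient(np.log(p) + 1.0, dx)   # internal energy, cf. (16)
    return np.gradient(drift_flux + diffusion_flux, dx)


x = np.linspace(-6.0, 6.0, 2001)
dx = x[1] - x[0]
V = 0.5 * x**2                                         # assumed confining potential
p_inf = np.exp(-V) / np.sqrt(2.0 * np.pi)              # Gibbs density, proportional to exp(-V)
p_wide = np.exp(-x**2 / 8.0) / np.sqrt(8.0 * np.pi)    # N(0, 4): not stationary
res_inf = np.max(np.abs(fokker_planck_rhs(p_inf, V, dx)))
res_wide = np.max(np.abs(fokker_planck_rhs(p_wide, V, dx)))
print(res_inf, res_wide)
```

At $p_\infty$ the two fluxes cancel pointwise because $\log p_\infty+1=-V+\text{const}$, which is exactly the stationarity condition of the combined equation.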

3.6. Algorithm

In this section, we discuss the implementation of gradient flows projected onto the parameter space. We apply the forward Euler discretization to the natural gradient flow (10). Let $h>0$ be the step size. Then the update is given by

$$
\theta^{k+1}=\theta^{k}-h\Big(\tilde{G}_{\mathrm{W}}(\theta^{k})\Big)^{-1}\nabla_{\theta}\tilde{F}(\theta^{k}), \tag{17}
$$

where $\tilde{G}_{\mathrm{W}}(\theta)=(\tilde{G}_{\mathrm{W}}(\theta)_{ij})_{1\leq i,j\leq D}\in\mathbb{R}^{D\times D}$ and $\nabla_{\theta}\tilde{F}(\theta)$ are empirical estimates of the matrix $G_{\mathrm{W}}$ and the gradient $\nabla F(\theta)=\{\partial_{\theta_{j}}F(\theta)\}_{j=1}^{D}$, respectively. In detail, if $(z_{l})_{l=1}^{M}\sim p_{\mathrm{r}}$, where $M$ is the number of empirical samples, then

$$
\tilde{G}_{\mathrm{W}}(\theta)_{ij}=\frac{1}{M}\sum_{l=1}^{M}\sum_{m=1}^{d}\partial_{\theta_{i}}f_{m}(\theta,z_{l})\,\partial_{\theta_{j}}f_{m}(\theta,z_{l}).
$$
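This estimator is an average of outer products of parameter Jacobians, which maps directly onto `np.einsum`. In the sketch below the Jacobian samples are assumed to be stored in an array of shape $(M,d,D)$; the function name and the affine test map $f(\theta,z)=\mu+sz$ (for which $G_{\mathrm W}=I$ when $p_{\mathrm r}=N(0,1)$) are illustrative choices, not part of the original method.

```python
import numpy as np


def empirical_metric(jac):
    """G_tilde_ij = (1/M) sum_l sum_m d_theta_i f_m(theta, z_l) d_theta_j f_m(theta, z_l).

    jac : parameter Jacobians, shape (M, d, D), jac[l, m, i] = d f_m(theta, z_l) / d theta_i.
    """
    M = jac.shape[0]
    return np.einsum("lmi,lmj->ij", jac, jac) / M


# Hypothetical 1D affine map f(theta, z) = mu + s z, theta = (mu, s): Jacobian (1, z).
rng = np.random.default_rng(2)
z = rng.standard_normal(100_000)
jac = np.stack([np.ones_like(z), z], axis=1)[:, None, :]  # shape (M, 1, 2)
G = empirical_metric(jac)
print(G)  # close to the identity, since E[z] = 0 and E[z^2] = 1 under N(0, 1)
```

Equivalently, after reshaping the Jacobians to a $(Md)\times D$ matrix $J$, one has $\tilde G_{\mathrm W}=J^{\mathsf T}J/M$, which makes its symmetry and positive semidefiniteness evident.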

In practice, the condition number of $\tilde{G}_{\mathrm{W}}(\theta)$ can be very large, and it is more stable to replace its inverse in (17) by the pseudoinverse $\tilde{G}_{\mathrm{W}}(\theta)^{\dagger}$. The update then reads

$$
\theta^{k+1}=\theta^{k}-h\,\tilde{G}_{\mathrm{W}}(\theta^{k})^{\dagger}\nabla_{\theta}\tilde{F}(\theta^{k}).
$$
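A minimal sketch of one such step, using `numpy.linalg.pinv`. The toy metric below is an assumed example of a singular matrix, as arises for instance when two neurons share identical parameters so that their Jacobian columns coincide; `np.linalg.inv` raises `LinAlgError` on it, whereas the pseudoinverse returns the minimum-norm solution of $G\,\delta=\nabla_\theta\tilde F$. The helper name `natural_gradient_step` is hypothetical.

```python
import numpy as np


def natural_gradient_step(theta, G, grad, h):
    """One step theta - h * pinv(G) @ grad of the pseudoinverse update."""
    return theta - h * (np.linalg.pinv(G) @ grad)


# Assumed toy metric: two parameters with identical Jacobian columns, so G is singular.
G = np.array([[1.0, 1.0],
              [1.0, 1.0]])
grad = np.array([1.0, 1.0])
theta = np.zeros(2)

try:
    np.linalg.inv(G)
    inverse_failed = False
except np.linalg.LinAlgError:
    inverse_failed = True

theta_next = natural_gradient_step(theta, G, grad, h=0.1)
print(inverse_failed, theta_next)
```

The redundant direction is split evenly between the two coinciding parameters, which is the minimum-norm choice; the inverse-based update (17) is simply undefined here.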

When the reference measure is the one-dimensional standard Gaussian distribution, $G_{\mathrm{W}}(\theta)$ can be computed explicitly for our choice of neural network (see Proposition 4). In this case, we have

$$
\theta^{k+1}=\theta^{k}-h\,G_{\mathrm{W}}(\theta^{k})^{\dagger}\nabla_{\theta}\tilde{F}(\theta^{k}).
$$
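For the standard Gaussian reference, the integrals in Proposition 4 reduce to the tails $\int_c^\infty z^k\phi(z)\,dz$ for $k=0,1,2$, namely $1-\Phi(c)$, $\phi(c)$, and $c\,\phi(c)+1-\Phi(c)$, with $c=\max\{b_i,b_j\}$. The sketch below assembles $G_{\mathrm W}$ from these tails, assuming the parameter ordering $\theta=(b_1,\ldots,b_N,a_1,\ldots,a_N)$ suggested by the block layout (19), and compares it with the Monte Carlo estimator $\tilde G_{\mathrm W}$. Function names and parameter values are hypothetical; `math.erf` is used to avoid a SciPy dependency.

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)


def gaussian_tails(c):
    """Return int_c^inf z^k phi(z) dz for k = 0, 1, 2, with phi the standard normal density."""
    phi = np.exp(-0.5 * c**2) / np.sqrt(2.0 * np.pi)
    t0 = 0.5 * (1.0 - _erf(c / np.sqrt(2.0)))  # 1 - Phi(c)
    t1 = phi
    t2 = c * phi + t0
    return t0, t1, t2


def relu_metric_gaussian(a, b):
    """Closed-form G_W for f(theta, z) = (1/N) sum_i a_i relu(z - b_i) with p_r = N(0, 1).

    Ordering theta = (b_1..b_N, a_1..a_N), i.e. blocks [[G^bb, G^ba], [(G^ba)^T, G^aa]].
    """
    N = len(a)
    c = np.maximum.outer(b, b)                                  # max{b_i, b_j}
    t0, t1, t2 = gaussian_tails(c)
    bb = np.outer(a, a) * t0                                    # a_i a_j (1 - Phi(c))
    ba = -a[:, None] * (t1 - b[None, :] * t0)                   # -a_i int_c (z - b_j) phi
    aa = t2 - (b[:, None] + b[None, :]) * t1 + np.outer(b, b) * t0
    return np.block([[bb, ba], [ba.T, aa]]) / N**2


# Monte Carlo comparison with the empirical estimator built from the derivatives of f.
rng = np.random.default_rng(3)
a = np.array([1.0, 0.5, 2.0])
b = np.array([-0.5, 0.0, 0.8])
N = len(a)
z = rng.standard_normal(400_000)
active = (z[:, None] > b[None, :]).astype(float)                # 1{z > b_i}
J = np.hstack([-a[None, :] * active / N,                        # d f / d b_i
               np.maximum(z[:, None] - b[None, :], 0.0) / N])   # d f / d a_i
G_mc = J.T @ J / len(z)
G_exact = relu_metric_gaussian(a, b)
print(np.max(np.abs(G_mc - G_exact)))
```

Because `np.maximum.outer` evaluates $\max\{b_i,b_j\}$ entrywise, this sketch does not require the breakpoints to be sorted.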

We summarize the above explicit update formulas in Algorithm 1.

  Input: Initial parameters $\theta^{1}\in\mathbb{R}^{D}$; step size $h>0$; total number of steps $L$; samples $\{z_{i}\}_{i=1}^{M}\sim p_{\mathrm{r}}$ for estimating $\tilde{G}_{\mathrm{W}}(\theta)$ and $\nabla_{\theta}\tilde{F}(\theta)$.
  for $k=1,2,\ldots,L$ do
     $\theta^{k+1}=\theta^{k}-h\,\tilde{G}_{\mathrm{W}}(\theta^{k})^{\dagger}\nabla_{\theta}\tilde{F}(\theta^{k})$;  (when $G_{\mathrm{W}}(\theta)$ is unknown)
     or
     $\theta^{k+1}=\theta^{k}-h\,G_{\mathrm{W}}(\theta^{k})^{\dagger}\nabla_{\theta}\tilde{F}(\theta^{k})$;  (when $G_{\mathrm{W}}(\theta)$ is known)
  end for
Algorithm 1 Projected Wasserstein gradient flows
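To illustrate Algorithm 1 end to end, the following sketch runs its pseudoinverse variant on a deliberately simple, hypothetical family: the affine map $f(\theta,z)=\mu+sz$ with $\theta=(\mu,s)$, $p_{\mathrm r}=N(0,1)$, and the negative entropy. Here the force term equals $(0,-1/s)$ exactly and $G_{\mathrm W}=I$, so the projected flow reads $\dot s=1/s$, and the pushforward $N(\mu,s^2)$ should track the heat equation, whose variance grows as $s_0^2+2t$. This is a consistency check under these assumptions, not one of the numerical examples reported here.

```python
import numpy as np


def projected_heat_flow(mu0, s0, h, L, M=100_000, seed=0):
    """Algorithm 1 (pseudoinverse variant) for the hypothetical affine family
    f(theta, z) = mu + s z, theta = (mu, s), with p_r = N(0, 1) and U(p) = p log p."""
    z = np.random.default_rng(seed).standard_normal(M)  # fixed samples {z_l}, as in the input
    jac = np.stack([np.ones_like(z), z], axis=1)         # d f/d mu = 1, d f/d s = z
    G = jac.T @ jac / M                                  # empirical metric G_tilde_W
    G_pinv = np.linalg.pinv(G)  # the Jacobian does not depend on theta here, so G is fixed
    theta = np.array([mu0, s0], dtype=float)
    for _ in range(L):
        s = theta[1]
        # Force term E[-(d_theta_j f' / f') * U_hat'(q) q] with U_hat = log, so U_hat'(q) q = 1;
        # f' = s gives d f'/d mu = 0 and d f'/d s = 1, hence the exact value (0, -1/s).
        grad = np.array([0.0, -1.0 / s])
        theta = theta - h * (G_pinv @ grad)
    return theta


mu_T, s_T = projected_heat_flow(mu0=0.0, s0=1.0, h=0.01, L=100)
print(mu_T, s_T**2)  # s_T**2 should be close to s0**2 + 2 * T = 3 for T = h * L = 1
```

The remaining discrepancy comes from the $O(h)$ forward Euler error and from the sampling error in $\tilde G_{\mathrm W}$; for a ReLU network the metric would instead be recomputed at every $\theta^{k}$.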

4. Numerical analysis on neural network projected gradient flows

In this section, we establish theoretical guarantees for the performance of the neural projected dynamics. In section 4.1, we derive an analytic formula for the inverse of the neural mapping metric of a special ReLU family. Based on the resulting closed-form projected dynamics, we establish a truncation error analysis for these dynamics in section 4.2. The truncation error analysis for general dynamics is presented in section 4.3.

4.1. Analytic formula for the inverse of neural mapping metric

In this section, we consider the following special case of the ReLU model in 1D, and first rewrite the neural network mapping function in the form

$$
f(\theta,z)=\frac{1}{N}\sum_{i=1}^{N}a_{i}\sigma(z-b_{i}),\qquad \sigma(z)=\begin{cases}0, & z<0,\\ z, & z\geq 0.\end{cases} \tag{18}
$$

In particular, in the 1D case each neuron is described by the single pair $(a_{i},b_{i})$. Under this parameterization, the $a_{i}$ are the slopes of the ReLU components and the $b_{i}$ are their breakpoints (the intercepts at which each component switches on). We make one last assumption on this ReLU mapping function, which enables an analytic formula for the neural mapping metric: all slope parameters stay non-negative, i.e., $a_{i}\geq 0$. Although this assumption is imposed to obtain analytic formulas, it is natural, since non-negative slopes yield a monotone (non-decreasing) mapping function, and solutions of the Monge problem in 1D are known to be monotone. Figure 1 plots a typical ReLU mapping function.

Figure 1. ReLU network mapping function considered in this section. The figure plots a typical monotone map parameterized by the ReLU network, where the parameters $a_{i}$ are required to be non-negative.
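A direct NumPy transcription of (18) and of its piecewise-constant derivative $f'(\theta,z)=\frac{1}{N}\sum_i a_i\mathbf{1}\{z>b_i\}$ makes the monotonicity argument concrete: with $a_i\geq 0$ every slope is non-negative, so the map is non-decreasing. The function names and the sample parameters are illustrative assumptions.

```python
import numpy as np


def relu_map(a, b, z):
    """f(theta, z) = (1/N) sum_i a_i * relu(z - b_i), cf. (18); z may be a scalar or an array."""
    z = np.asarray(z, dtype=float)
    return np.maximum(z[..., None] - b, 0.0) @ a / len(a)


def relu_map_slope(a, b, z):
    """Piecewise-constant derivative f'(theta, z) = (1/N) sum_i a_i 1{z > b_i}."""
    z = np.asarray(z, dtype=float)
    return (z[..., None] > b).astype(float) @ a / len(a)


a = np.array([1.0, 0.5, 2.0])    # non-negative slopes (illustrative values)
b = np.array([-0.5, 0.0, 0.8])   # breakpoints
zs = np.linspace(-3.0, 3.0, 601)
values = relu_map(a, b, zs)
print(relu_map(a, b, 1.0), relu_map_slope(a, b, 1.0))
```

Note that $f'$ vanishes to the left of $\min_i b_i$, so the map is non-decreasing rather than strictly increasing there.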

We start with the analytic formula for the neural mapping metric, assuming the reference measure is given by $p_{\mathrm{r}}(\cdot)$ with associated cumulative distribution function $\mathfrak{F}_{0}(\cdot)$.

Proposition 4 (Neural mapping metric of two-layer ReLU network).

The neural mapping metric of the two-layer ReLU network with reference measure $p_{\mathrm{r}}(\cdot)$ is given by

$$
G_{\mathrm{W}}=\frac{1}{N^{2}}\begin{pmatrix}G_{\mathrm{W}}^{bb}&G_{\mathrm{W}}^{ba}\\ \left(G_{\mathrm{W}}^{ba}\right)^{T}&G_{\mathrm{W}}^{aa}\end{pmatrix}, \tag{19}
$$
$$
G_{\mathrm{W}}^{bb}=\begin{pmatrix}
a_{1}^{2}(1-\mathfrak{F}_{0}(b_{1})) & a_{1}a_{2}(1-\mathfrak{F}_{0}(b_{2})) & \cdots & a_{1}a_{N}(1-\mathfrak{F}_{0}(b_{N}))\\
a_{1}a_{2}(1-\mathfrak{F}_{0}(b_{2})) & a_{2}^{2}(1-\mathfrak{F}_{0}(b_{2})) & \cdots & a_{2}a_{N}(1-\mathfrak{F}_{0}(b_{N}))\\
\vdots & \vdots & \ddots & \vdots\\
a_{1}a_{N-1}(1-\mathfrak{F}_{0}(b_{N-1})) & a_{2}a_{N-1}(1-\mathfrak{F}_{0}(b_{N-1})) & \cdots & a_{N}a_{N-1}(1-\mathfrak{F}_{0}(b_{N}))\\
a_{1}a_{N}(1-\mathfrak{F}_{0}(b_{N})) & a_{2}a_{N}(1-\mathfrak{F}_{0}(b_{N})) & \cdots & a_{N}^{2}(1-\mathfrak{F}_{0}(b_{N}))
\end{pmatrix},
$$
$$
G_{\mathrm{W}}^{ba}=-\begin{pmatrix}
a_{1}\int_{b_{1}}^{\infty}(z-b_{1})p_{\mathrm{r}}(z)\,dz & a_{1}\int_{b_{2}}^{\infty}(z-b_{2})p_{\mathrm{r}}(z)\,dz & \cdots & a_{1}\int_{b_{N}}^{\infty}(z-b_{N})p_{\mathrm{r}}(z)\,dz\\
a_{2}\int_{b_{2}}^{\infty}(z-b_{1})p_{\mathrm{r}}(z)\,dz & a_{2}\int_{b_{2}}^{\infty}(z-b_{2})p_{\mathrm{r}}(z)\,dz & \cdots & a_{2}\int_{b_{N}}^{\infty}(z-b_{N})p_{\mathrm{r}}(z)\,dz\\
\vdots & \vdots & \ddots & \vdots\\
a_{N}\int_{b_{N}}^{\infty}(z-b_{1})p_{\mathrm{r}}(z)\,dz & a_{N}\int_{b_{N}}^{\infty}(z-b_{2})p_{\mathrm{r}}(z)\,dz & \cdots & a_{N}\int_{b_{N}}^{\infty}(z-b_{N})p_{\mathrm{r}}(z)\,dz
\end{pmatrix},
$$
$$
\left(G_{\mathrm{W}}^{aa}\right)_{ij}=\int_{\max\{b_{i},b_{j}\}}^{\infty}(z-b_{j})(z-b_{i})\,p_{\mathrm{r}}(z)\,dz.
$$
The matrix displays list the entries with the breakpoints ordered as $b_{1}\leq b_{2}\leq\cdots\leq b_{N}$, so that $\max\{b_{i},b_{j}\}=b_{\max\{i,j\}}$; equivalently, $(G_{\mathrm{W}}^{bb})_{ij}=a_{i}a_{j}(1-\mathfrak{F}_{0}(\max\{b_{i},b_{j}\}))$ and $(G_{\mathrm{W}}^{ba})_{ij}=-a_{i}\int_{\max\{b_{i},b_{j}\}}^{\infty}(z-b_{j})\,p_{\mathrm{r}}(z)\,dz$.
Proof.

We first calculate the derivatives of the neural network map $f(\theta,z)$ with respect to the network parameters $\theta$:

$$
\partial_{b_{i}}f(\theta,z)=\begin{cases}0, & z<b_{i},\\ -\dfrac{a_{i}}{N}, & z>b_{i},\end{cases}\qquad \partial_{a_{i}}f(\theta,z)=\frac{1}{N}\sigma(z-b_{i}), \tag{20}
$$

where the derivative with respect to $b_{i}$ does not exist at the single point $z=b_{i}$; this point has measure zero under $p_{\mathrm{r}}$ and can therefore be omitted. By Definition 4, the blocks of the metric tensor are given by the following integrals:

$$\begin{aligned}
\left(G_{\mathrm{W}}^{bb}\right)_{ij} &= \int_{\mathbb{R}} \partial_{b_i} f(\theta,z)\,\partial_{b_j} f(\theta,z)\,p_{\mathrm{r}}(z)\,dz = \frac{a_i a_j}{N^2}\Big(1-\mathfrak{F}_0\big(\max\{b_i,b_j\}\big)\Big),\\
\left(G_{\mathrm{W}}^{ba}\right)_{ij} &= \int_{\mathbb{R}} \partial_{b_i} f(\theta,z)\,\partial_{a_j} f(\theta,z)\,p_{\mathrm{r}}(z)\,dz = -\frac{a_i}{N^2}\int_{\max\{b_i,b_j\}}^{\infty} (z-b_j)\,p_{\mathrm{r}}(z)\,dz,\\
\left(G_{\mathrm{W}}^{aa}\right)_{ij} &= \int_{\mathbb{R}} \partial_{a_i} f(\theta,z)\,\partial_{a_j} f(\theta,z)\,p_{\mathrm{r}}(z)\,dz = \frac{1}{N^2}\int_{\max\{b_i,b_j\}}^{\infty} (z-b_j)(z-b_i)\,p_{\mathrm{r}}(z)\,dz.
\end{aligned}$$
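The three block formulas can be sanity-checked by quadrature. The minimal Python sketch below assembles $G_{\mathrm{W}}^{bb}$, $G_{\mathrm{W}}^{ba}$ and $G_{\mathrm{W}}^{aa}$ directly from the derivatives (20). It uses trapezoid weights on a truncated uniform grid and then compares the $bb$ block with its closed form.

The sketch makes several assumptions:
- The model is $f(\theta,z)=\frac{1}{N}\sum_{i} a_i\,\sigma(z-b_i)$. This is consistent with (20), although eq. 18 may contain additional terms.
- $p_{\mathrm{r}}$ is the standard Gaussian, chosen for concreteness.
- The function names and grid settings are hypothetical.

```python
import math

import numpy as np


def std_normal_pdf(z):
    """Density of the standard Gaussian (assumed choice of reference measure p_r)."""
    return np.exp(-0.5 * np.asarray(z, dtype=float) ** 2) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    """CDF of the standard Gaussian, playing the role of F_0."""
    z = np.asarray(z, dtype=float)
    return 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))


def metric_blocks_quadrature(a, b, pdf=std_normal_pdf, lo=-10.0, hi=10.0, m=200_001):
    """Approximate (G^bb, G^ba, G^aa) by trapezoidal quadrature in z.

    Assumes f(theta, z) = (1/N) * sum_i a_i * relu(z - b_i), whose parameter
    derivatives are given in (20).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n = a.size
    z = np.linspace(lo, hi, m)
    w = np.full(m, z[1] - z[0])  # trapezoid weights on a uniform grid
    w[0] *= 0.5
    w[-1] *= 0.5
    w = w * pdf(z)  # fold the reference density into the weights
    active = (z[None, :] > b[:, None]).astype(float)  # 1_{z > b_i}, shape (n, m)
    d_b = -(a[:, None] / n) * active  # partial f / partial b_i
    d_a = np.maximum(z[None, :] - b[:, None], 0.0) / n  # partial f / partial a_i
    g_bb = (d_b * w) @ d_b.T
    g_ba = (d_b * w) @ d_a.T
    g_aa = (d_a * w) @ d_a.T
    return g_bb, g_ba, g_aa


a = np.array([1.0, -0.5, 2.0])
b = np.array([-1.0, 0.3, 1.2])
g_bb, g_ba, g_aa = metric_blocks_quadrature(a, b)
closed_bb = np.outer(a, a) / a.size**2 * (1.0 - std_normal_cdf(np.maximum.outer(b, b)))
print(np.max(np.abs(g_bb - closed_bb)))  # small; limited by the quadrature error at the kinks
```

Other reference densities can be passed through the `pdf` argument. The accuracy is only first order in the grid spacing, because $\partial_{b_i} f$ jumps at $z=b_i$.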

For a general reference measure $p_{\mathrm{r}}$, the elements of $G_{\mathrm{W}}^{ba}$ and $G_{\mathrm{W}}^{aa}$ involve truncated first and second moments of $p_{\mathrm{r}}$, which may not admit closed-form expressions. We therefore consider a Gaussian reference measure, for which all elements of the metric are analytic.

Corollary 1.

With the same setting as proposition 4 and the standard Gaussian reference measure $p_{\mathrm{r}} = \mathcal{N}(0,1)$, whose cumulative distribution function is $\mathfrak{F}_0$, the matrix elements of the neural mapping metric can be written analytically as

$$\begin{aligned}
\left(G_{\mathrm{W}}^{ba}\right)_{ij} &= -\frac{a_i}{N^2}\Big(p_{\mathrm{r}}(b_i) - b_j\big(1-\mathfrak{F}_0(b_i)\big)\Big), \quad b_i > b_j,\\
\left(G_{\mathrm{W}}^{aa}\right)_{ij} &= \frac{1}{N^2}\Big(b_ib_j\big(1-\mathfrak{F}_0(b_i)\big) - b_j\,p_{\mathrm{r}}(b_i) + \big(1-\mathfrak{F}_0(b_i)\big)\Big), \quad b_i > b_j.
\end{aligned}\tag{21}$$

The remaining elements follow from the same integrals with lower limit $\max\{b_i,b_j\} = b_j$. Since $G_{\mathrm{W}}^{aa}$ is symmetric, its entries with $b_i < b_j$ are obtained by switching $b_i$ and $b_j$. The block $G_{\mathrm{W}}^{ba}$ is not symmetric; for $b_i < b_j$ one obtains $\left(G_{\mathrm{W}}^{ba}\right)_{ij} = -\frac{a_i}{N^2}\big(p_{\mathrm{r}}(b_j) - b_j(1-\mathfrak{F}_0(b_j))\big)$.

Proof.

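The result follows from elementary integration. For the standard Gaussian density, $p_{\mathrm{r}}'(z) = -z\,p_{\mathrm{r}}(z)$, so that $\int_{b}^{\infty} z\,p_{\mathrm{r}}(z)\,dz = p_{\mathrm{r}}(b)$ and, after an integration by parts, $\int_{b}^{\infty} z^2\,p_{\mathrm{r}}(z)\,dz = b\,p_{\mathrm{r}}(b) + 1 - \mathfrak{F}_0(b)$. Taking $b = b_i$ and expanding the integrands gives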

$$\begin{aligned}
\int_{b_i}^{\infty}(z-b_j)\,p_{\mathrm{r}}(z)\,dz
&= p_{\mathrm{r}}(z)\Big|_{\infty}^{b_i} - b_j\big(1-\mathfrak{F}_0(b_i)\big) = p_{\mathrm{r}}(b_i) - b_j\big(1-\mathfrak{F}_0(b_i)\big),\\
\int_{b_i}^{\infty}(z-b_j)(z-b_i)\,p_{\mathrm{r}}(z)\,dz
&= b_ib_j\big(1-\mathfrak{F}_0(b_i)\big) - (b_i+b_j)\,p_{\mathrm{r}}(b_i) + b_i\,p_{\mathrm{r}}(b_i) + \big(1-\mathfrak{F}_0(b_i)\big)\\
&= b_ib_j\big(1-\mathfrak{F}_0(b_i)\big) - b_j\,p_{\mathrm{r}}(b_i) + \big(1-\mathfrak{F}_0(b_i)\big).
\end{aligned}\tag{22}$$
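The closed forms of Corollary 1 can be checked against direct quadrature of the integrals in the block formulas. The sketch below evaluates both blocks for all index pairs through $m=\max\{b_i,b_j\}$, so it also covers the case $b_i<b_j$. It includes the prefactors $-a_i/N^2$ and $1/N^2$ from those block formulas. It assumes the standard Gaussian reference measure. The helper names, the truncation of the integrals at $z=12$, and the grid size are illustrative choices.

```python
import math

import numpy as np

SQRT_2PI = math.sqrt(2.0 * math.pi)


def pdf(z):
    """Standard Gaussian density p_r (assumed reference measure)."""
    return math.exp(-0.5 * z * z) / SQRT_2PI


def cdf(z):
    """Standard Gaussian CDF, i.e. F_0."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def g_ba_entry(a, b, i, j):
    """(G^ba)_ij = -(a_i / N^2) * (p_r(m) - b_j (1 - F_0(m))), m = max(b_i, b_j)."""
    n, m = len(a), max(b[i], b[j])
    return -a[i] / n**2 * (pdf(m) - b[j] * (1.0 - cdf(m)))


def g_aa_entry(a, b, i, j):
    """(G^aa)_ij = (b_i b_j (1 - F_0(m)) - min(b_i, b_j) p_r(m) + 1 - F_0(m)) / N^2."""
    n, m = len(a), max(b[i], b[j])
    tail = 1.0 - cdf(m)
    return (b[i] * b[j] * tail - min(b[i], b[j]) * pdf(m) + tail) / n**2


def tail_integral(g, lower, upper=12.0, k=400_001):
    """Trapezoid rule for int_lower^upper g(z) p_r(z) dz; the tail beyond 12 is negligible."""
    z = np.linspace(lower, upper, k)
    y = g(z) * np.exp(-0.5 * z**2) / SQRT_2PI
    return (z[1] - z[0]) * (y.sum() - 0.5 * (y[0] + y[-1]))


a, b = [1.0, -0.5, 2.0], [-1.0, 0.3, 1.2]
n = len(a)
worst = 0.0
for i in range(n):
    for j in range(n):
        m = max(b[i], b[j])
        quad_ba = -a[i] / n**2 * tail_integral(lambda z: z - b[j], m)
        quad_aa = tail_integral(lambda z: (z - b[i]) * (z - b[j]), m) / n**2
        worst = max(worst, abs(g_ba_entry(a, b, i, j) - quad_ba),
                    abs(g_aa_entry(a, b, i, j) - quad_aa))
print(worst)  # at the level of the quadrature error
```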

We now turn to the diagonal block $G_{\mathrm{W}}^{bb}$ of the neural mapping metric and derive an analytic formula for its inverse.

Theorem 2 (Analytic inverse of the neural mapping metric).

Assume $b_1 < b_2 < \cdots < b_N$ and $a_i \neq 0$ for all $i$. Then the inverse of the block $G_{\mathrm{W}}^{bb}$ in proposition 4 can be written analytically as

$$\frac{1}{N^2}\left(G_{\mathrm{W}}^{-1}(\mathbf{b})\right)_{ij} = \begin{cases}
\dfrac{1}{a_i^2}\left(\dfrac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})} + \dfrac{1}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}\right), & i=j\neq 1,N,\\[2ex]
\dfrac{1}{a_N^2}\left(\dfrac{1}{\mathfrak{F}_0(b_N)-\mathfrak{F}_0(b_{N-1})} + \dfrac{1}{1-\mathfrak{F}_0(b_N)}\right), & i=j=N,\\[2ex]
\dfrac{1}{a_1^2}\,\dfrac{1}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)}, & i=j=1,\\[2ex]
-\dfrac{1}{a_ia_{i-1}}\,\dfrac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}, & j=i-1,\\[2ex]
-\dfrac{1}{a_ia_{i+1}}\,\dfrac{1}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}, & j=i+1,\\[2ex]
0, & \text{otherwise}.
\end{cases}\tag{23}$$
Proof.

First, we decompose the neural mapping metric into the following matrix product

$$G_{\mathrm{W}} = \frac{1}{N^2}\,D\begin{pmatrix}
1-\mathfrak{F}_0(b_1) & 1-\mathfrak{F}_0(b_2) & \cdots & 1-\mathfrak{F}_0(b_N)\\
1-\mathfrak{F}_0(b_2) & 1-\mathfrak{F}_0(b_2) & \cdots & 1-\mathfrak{F}_0(b_N)\\
\vdots & \vdots & \ddots & \vdots\\
1-\mathfrak{F}_0(b_N) & 1-\mathfrak{F}_0(b_N) & \cdots & 1-\mathfrak{F}_0(b_N)
\end{pmatrix}D, \tag{24}$$

where $D = \operatorname{diag}(a_1, a_2, \ldots, a_N)$. With $b_1 < b_2 < \cdots < b_N$, the middle matrix has entries $1-\mathfrak{F}_0(b_{\max\{i,j\}})$, and a direct computation verifies that its inverse is the tridiagonal matrix

$$\begin{pmatrix}
\frac{1}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)} & -\frac{1}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)} & 0 & \cdots & 0\\
-\frac{1}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)} & \frac{1}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)} + \frac{1}{\mathfrak{F}_0(b_3)-\mathfrak{F}_0(b_2)} & -\frac{1}{\mathfrak{F}_0(b_3)-\mathfrak{F}_0(b_2)} & \cdots & 0\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
0 & 0 & 0 & \cdots & \frac{1}{\mathfrak{F}_0(b_N)-\mathfrak{F}_0(b_{N-1})} + \frac{1}{1-\mathfrak{F}_0(b_N)}
\end{pmatrix}. \tag{25}$$

Denote the middle matrix in (24) by $M$. Then $G_{\mathrm{W}}^{-1} = N^2 D^{-1} M^{-1} D^{-1}$, so multiplying (25) by $D^{-1}$ on both sides yields (23). ∎
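Because the inverse is tridiagonal, its nonzero entries can be written down directly from (23). The minimal sketch below builds $G_{\mathrm{W}}^{bb}$ from its closed form and its inverse from (23), then checks that their product is the identity. It assumes a standard Gaussian $\mathfrak{F}_0$, ordered biases and nonzero $a_i$. The helper names `g_bb` and `g_bb_inverse` are hypothetical.

```python
import math

import numpy as np


def std_normal_cdf(x):
    """Standard Gaussian CDF, used as F_0 (assumed reference measure)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def g_bb(a, b):
    """Closed form (G^bb)_ij = a_i a_j / N^2 * (1 - F_0(max(b_i, b_j)))."""
    a = np.asarray(a, dtype=float)
    tail = np.array([[1.0 - std_normal_cdf(max(x, y)) for y in b] for x in b])
    return np.outer(a, a) * tail / a.size**2


def g_bb_inverse(a, b):
    """Tridiagonal inverse (23); requires b_1 < ... < b_N and all a_i != 0."""
    a = np.asarray(a, dtype=float)
    n = a.size
    F = np.array([std_normal_cdf(x) for x in b])
    # 0-based: d[k] = 1 / (F_0(b[k+1]) - F_0(b[k])) for k < n - 1, d[n-1] = 1 / (1 - F_0(b[n-1]))
    d = np.empty(n)
    d[:-1] = 1.0 / np.diff(F)
    d[-1] = 1.0 / (1.0 - F[-1])
    m_inv = np.zeros((n, n))
    for k in range(n):
        m_inv[k, k] = d[k] + (d[k - 1] if k > 0 else 0.0)
        if k + 1 < n:
            m_inv[k, k + 1] = m_inv[k + 1, k] = -d[k]
    # G^{-1} = N^2 D^{-1} M^{-1} D^{-1} with D = diag(a)
    return n**2 * m_inv / np.outer(a, a)


a = np.array([1.0, -0.5, 2.0, 0.7])
b = np.array([-1.2, -0.1, 0.4, 1.5])
G, G_inv = g_bb(a, b), g_bb_inverse(a, b)
print(np.allclose(G @ G_inv, np.eye(a.size)))  # True
```

A product $G_{\mathrm{W}}^{-1}v$ can likewise be applied using only the three nonzero diagonals. The dense array here serves only for the comparison.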

This analytic inverse of the metric is used extensively in the next subsection to establish the consistency of the numerical scheme based on the ReLU neural network.

4.2. Truncation error analysis of the neural projected Wasserstein gradient flows based on the analytic formula

In this subsection, we analyze the neural mapping projected Wasserstein flows introduced in section 3.5, using the analytic formulas of section 4.1. Since the inverse of the neural mapping metric is available in closed form, the right-hand side of the projected Wasserstein gradient flow can be computed explicitly. Its consistency and order of accuracy can then be studied in the same spirit as in classical numerical analysis. We carry out this derivation explicitly for the Wasserstein projected gradient flow of the potential energy functional.

Recall that the neural projected Wasserstein gradient flow is given by

$$\frac{d\theta}{dt} = -G_{\mathrm{W}}^{-1}(\theta)\cdot\mathbb{E}_{\tilde z\sim p_{\mathrm{r}}}\Big[\nabla_\theta V\big(f(\theta,\tilde z)\big)\Big]. \tag{26}$$
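Equation (26) is a gradient flow preconditioned by the inverse metric. The sketch below is a minimal time stepper, assuming an explicit Euler discretization. This scheme is chosen here for illustration and is not taken from the paper. At each step it solves the linear system $G_{\mathrm{W}}(\theta)v = \mathbb{E}_{\tilde z\sim p_{\mathrm{r}}}[\nabla_\theta V(f(\theta,\tilde z))]$ instead of forming $G_{\mathrm{W}}^{-1}$. The callables `grad_fn` and `metric_fn` are hypothetical user-supplied functions. The toy check uses a constant metric, for which the flow reduces to $\dot\theta=-\theta$.

```python
import numpy as np


def natural_gradient_euler(theta0, grad_fn, metric_fn, dt, n_steps):
    """Explicit Euler for d(theta)/dt = -G(theta)^{-1} grad_fn(theta), as in (26)."""
    theta = np.array(theta0, dtype=float)
    for _ in range(n_steps):
        # Solve G v = grad instead of forming the inverse metric explicitly.
        v = np.linalg.solve(metric_fn(theta), grad_fn(theta))
        theta = theta - dt * v
    return theta


# Toy check: constant metric G and energy E(theta) = 0.5 theta^T G theta, so grad = G theta
# and the flow is d(theta)/dt = -theta, whose Euler iterates are (1 - dt)^k theta0.
G = np.diag([2.0, 4.0])
theta = natural_gradient_euler([1.0, -3.0], lambda th: G @ th, lambda th: G, dt=0.1, n_steps=10)
print(theta)  # equals 0.9**10 * [1, -3] up to rounding
```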

We have the following analytic formula for the projected gradient flow in the ReLU network model that we introduced in section 4.1.

Proposition 5 (Wasserstein gradient flow of potential functionals in ReLU network).

The projected potential flow in the ReLU network model eq. 18 has the following form:

$$\begin{aligned}
\dot{b}_i &= \frac{N}{a_i}\left[\frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_i,b_{i+1}]}\big]}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)} - \frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_{i-1},b_i]}\big]}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}\right], \quad i\neq 1,N,\\
\dot{b}_N &= \frac{N}{a_N}\left[\frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_N,\infty)}\big]}{1-\mathfrak{F}_0(b_N)} - \frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_{N-1},b_N]}\big]}{\mathfrak{F}_0(b_N)-\mathfrak{F}_0(b_{N-1})}\right],\\
\dot{b}_1 &= \frac{N}{a_1}\,\frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_1,b_2]}\big]}{\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)}.
\end{aligned}\tag{27}$$

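Each ratio in (27) is the conditional average of $V'(f(b,\cdot))$ under $p_{\mathrm{r}}$ over an interval between neighboring biases. Approximating these averages by the trapezoid rule, i.e., by the mean of the two endpoint values, gives the following spatial discretization for the interior indices $i \neq 1, N$. This discretization can be used to simulate the projected gradient flow: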

$$\dot{b}_i = \frac{N}{2a_i}\Big(V'\big(f(b,b_{i+1})\big) - V'\big(f(b,b_{i-1})\big)\Big). \tag{28}$$
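The minimal Python sketch below checks Proposition 5 numerically and shows how (28) approximates (27). It makes the following assumptions:
- $p_{\mathrm{r}}$ is the standard Gaussian.
- The model is $f(b,z)=\frac{1}{N}\sum_i a_i\sigma(z-b_i)$ with fixed weights $a$.
- The potential is the hypothetical $V(x)=x^2/2$, so $V'(x)=x$.
- The last interval is truncated at $z=12$.
- All helper names are illustrative.

The sketch runs two comparisons:
- It compares (27) with a dense linear solve of (26), built from the gradient (30) and the closed-form $G_{\mathrm{W}}^{bb}$.
- It compares the interior components of (27) with the trapezoid discretization (28).

```python
import math

import numpy as np


def cdf(x):
    """Standard Gaussian CDF F_0 (assumed reference measure)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu_map(a, b, z):
    """Assumed model f(b, z) = (1/N) sum_i a_i relu(z - b_i) with fixed a, evaluated at each z."""
    return np.maximum(np.asarray(z)[:, None] - b[None, :], 0.0) @ a / a.size


def interval_integrals(a, b, dV, pts=4001, cutoff=12.0):
    """I_k = int_{b_k}^{b_{k+1}} V'(f(b, z)) p_r(z) dz by the trapezoid rule, with b_{N+1} := cutoff."""
    edges = np.append(b, cutoff)
    out = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        z = np.linspace(lo, hi, pts)
        y = dV(relu_map(a, b, z)) * np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)
        out.append((z[1] - z[0]) * (y.sum() - 0.5 * (y[0] + y[-1])))
    return np.array(out)


def b_dot_prop5(a, b, dV):
    """Right-hand side (27); requires b_1 < ... < b_N."""
    F = np.array([cdf(x) for x in b])
    q = interval_integrals(a, b, dV) / np.append(np.diff(F), 1.0 - F[-1])  # conditional averages
    out = q.copy()
    out[1:] -= q[:-1]
    return b.size / a * out


def b_dot_solve(a, b, dV):
    """-(G^bb)^{-1} grad via a dense solve, with the gradient (30) and the closed-form G^bb."""
    n = b.size
    tails = np.cumsum(interval_integrals(a, b, dV)[::-1])[::-1]  # E[V'(f) 1_[b_i, inf)]
    grad = -a / n * tails
    m = np.array([[1.0 - cdf(max(x, y)) for y in b] for x in b])
    return -np.linalg.solve(np.outer(a, a) * m / n**2, grad)


def b_dot_trapezoid(a, b, dV):
    """Discretization (28) for the interior indices i = 2, ..., N-1."""
    vp = dV(relu_map(a, b, b))  # V'(f(b, b_k)) at every node b_k
    return b.size / (2.0 * a[1:-1]) * (vp[2:] - vp[:-2])


def dV(x):
    return x  # hypothetical potential V(x) = x^2 / 2


a = np.array([1.0, -0.5, 2.0, 0.7])
b = np.array([-1.0, 0.0, 0.8, 1.5])
print(np.max(np.abs(b_dot_prop5(a, b, dV) - b_dot_solve(a, b, dV))))  # rounding level

n = 41
a_fine, b_fine = np.ones(n), np.linspace(-1.0, 1.0, n)
exact = b_dot_prop5(a_fine, b_fine, dV)[1:-1]
approx = b_dot_trapezoid(a_fine, b_fine, dV)
print(np.max(np.abs(approx - exact)) / np.max(np.abs(exact)))  # small relative gap
```

Both right-hand sides in the first comparison are built from the same interval integrals, so they agree to rounding error. The gap in the second comparison comes from the trapezoid approximation of the conditional averages, and it is small for closely spaced biases.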
Proof.

It suffices to compute the gradient of the potential energy, which is a linear functional of the density, with respect to $b$. In the ReLU network mapping model, the potential energy reads

$$\mathbb{E}_{x\sim f_{b\#}p_{\mathrm{r}}}[V(x)] = \mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V(f(b,z))\big], \tag{29}$$

where we used the change of variables $x = f(b,z)$ that defines the push-forward measure. Differentiating under the expectation and using (20), the gradient of this functional with respect to $b$ is

$$\partial_{b_i}\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V(f(b,z))\big] = \mathbb{E}_{z\sim p_{\mathrm{r}}}\big[\partial_{b_i}V(f(b,z))\big] = -\frac{a_i}{N}\,\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_i,\infty)}(z)\big], \tag{30}$$

where $\mathbf{1}_A$ denotes the indicator function of the interval $A$. Now, plugging this result into the projected gradient flow eq. 26, together with the analytical formula for the inverse matrix $G_{\mathrm{W}}^{-1}$ in theorem 2, we obtain

$$\begin{aligned}
\dot{b}_i ={}& \frac{N^2}{a_i^2}\left(\frac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}+\frac{1}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}\right)\frac{a_i}{N}\,\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_i,\infty)}(z)\big]\\
&-\frac{N^2}{a_i a_{i-1}}\,\frac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}\,\frac{a_{i-1}}{N}\,\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_{i-1},\infty)}(z)\big]\\
&-\frac{N^2}{a_i a_{i+1}}\,\frac{1}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}\,\frac{a_{i+1}}{N}\,\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_{i+1},\infty)}(z)\big]\\
={}& \frac{N}{a_i}\left[\frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_i,b_{i+1}]}(z)\big]}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}-\frac{\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_{i-1},b_i]}(z)\big]}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}\right].
\end{aligned}\tag{31}$$

The two quotients inside the brackets are the averages of $V'(f(b,\cdot))$ over the intervals $[b_i,b_{i+1}]$ and $[b_{i-1},b_i]$, weighted by the base distribution $p_{\mathrm{r}}(\cdot)$. Lastly, to complete the spatial discretization, one needs a quadrature rule for the integrals in the above formula. One example is the trapezoid rule:

$$\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[V'(f(b,z))\,\mathbf{1}_{[b_i,b_{i+1}]}(z)\big]\approx\big(\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)\big)\,\frac{V'(f(b,b_i))+V'(f(b,b_{i+1}))}{2},$$

which provides the desired discretization. Special attention should be paid to the boundary nodes $b_1$ and $b_N$, whose evolution equations and discretizations must be derived separately. ∎
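To make the discretization concrete, the following is a minimal NumPy sketch of the interior-node right-hand side of eq. 31 with the trapezoid rule. It assumes, as suggested by $\partial_{b_i}f=-\frac{a_i}{N}\mathbf{1}_{[b_i,\infty)}$, that the mapping is $f(b,z)=\frac1N\sum_j a_j\max(z-b_j,0)$ with sorted biases; the function names `relu_map` and `bdot_trapezoid` are hypothetical, and the two boundary nodes are simply frozen because the text treats them separately. With the trapezoid rule both interval averages reduce to endpoint means, so the bracket collapses to $\frac{N}{2a_i}\big(V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))\big)$.

```python
import numpy as np

def relu_map(b, a, z):
    """Hypothetical two-layer ReLU map f(b, z) = (1/N) * sum_j a_j * max(z - b_j, 0)."""
    N = len(b)
    z = np.atleast_1d(np.asarray(z, dtype=float))[:, None]  # shape (M, 1) for broadcasting
    return (a * np.maximum(z - b, 0.0)).sum(axis=1) / N

def bdot_trapezoid(b, a, dV):
    """Right-hand side of eq. (31) with the trapezoid rule, interior nodes only.

    dV is V'. The first and last nodes are kept fixed (velocity 0) in this sketch.
    """
    N = len(b)
    Vp = dV(relu_map(b, a, b))                # V'(f(b, b_i)) at every node
    avg_right = 0.5 * (Vp[1:-1] + Vp[2:])     # trapezoid average of V' on [b_i, b_{i+1}]
    avg_left = 0.5 * (Vp[:-2] + Vp[1:-1])     # trapezoid average of V' on [b_{i-1}, b_i]
    bdot = np.zeros(N)
    bdot[1:-1] = N / a[1:-1] * (avg_right - avg_left)
    return bdot

# Example: quadratic potential V(x) = x^2 / 2, so V'(x) = x.
b = np.linspace(0.0, 1.0, 5)
a = np.ones(5)
print(bdot_trapezoid(b, a, lambda x: x))
```

Replacing the endpoint means by exact $p_{\mathrm{r}}$-weighted interval averages would recover the undiscretized bracket of eq. 31; the trapezoid version is what makes each node velocity depend only on its two neighbours.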

Given this spatial discretization, we now analyze its order of consistency, which is the content of the following proposition.

Proposition 6 (Consistency of the projected gradient flow).

Assume that the potential function satisfies $\|V''\|_\infty<\infty$. Then the spatial discretization eq. 28 is first-order accurate in both the mapping and the density coordinates.

Proof.

We prove consistency in two spaces: the space of mapping functions and the space of the distributions they induce. For the mapping functions, we have

$$\begin{aligned}
\partial_t f(b(t),z) ={}& \dot{b}^T\partial_b f(b,z)\\
={}& -\sum_{i=1}^N\frac{N}{2a_i}\big(V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))\big)\frac{a_i}{N}\,\mathbf{1}_{[b_i,\infty)}(z)\\
={}& -\sum_{i=1}^N\frac{V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))}{2}\,\mathbf{1}_{[b_i,\infty)}(z)\\
={}& -\frac{V'(f(b,b_{i+1}))+V'(f(b,b_i))}{2},\quad z\in[b_i,b_{i+1}].
\end{aligned}\tag{32}$$
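The last equality in eq. 32 rests on a telescoping sum. For fixed $z\in[b_i,b_{i+1}]$ only the indices $j\le i$ contribute; writing $V'_j:=V'(f(b,b_j))$ purely as shorthand (with a ghost value $V'_0$ produced by applying the interior formula at $j=1$), one has

$$\sum_{j=1}^{i}\frac{V'_{j+1}-V'_{j-1}}{2}=\frac{V'_{i+1}+V'_{i}}{2}-\frac{V'_{1}+V'_{0}}{2},$$

so the stated speed is obtained once the boundary contribution $\frac{V'_1+V'_0}{2}$ is discarded, which appears to be what the simplification of the boundary nodes amounts to.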

In the above derivation we gloss over the boundary nodes, so that a single formula describes the evolution of all nodes $b_i$. Hence our discretization moves the mapping function $f$ with the constant speed $-\frac{V'(f(b,b_{i+1}))+V'(f(b,b_i))}{2}$ on each interval $[b_i,b_{i+1}]$. Recall that, in mapping space, the Wasserstein gradient flow of the potential function $V(x)$ corresponds to the velocity field $-V'(x)$. For $z\in[b_i,b_{i+1}]$, the point $f(b,z)$ lies between $f(b,b_i)$ and $f(b,b_{i+1})$, so the mean value theorem gives
$$\Big|V'(f(b,z))-\frac{V'(f(b,b_i))+V'(f(b,b_{i+1}))}{2}\Big|\le\|V''\|_\infty\,\big|f(b,b_{i+1})-f(b,b_i)\big|.$$
Since $f(b,\cdot)$ is linear on each interval with slope $\frac1N\sum_{j\le i}a_j$, bounded by $\max_j|a_j|$, and the length of each interval is of order $\Delta b$, the right-hand side is $O(\Delta b)$. We conclude that our spatial discretization is first-order consistent in the mapping space.
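The first-order rate can also be observed numerically. The sketch below again assumes the hypothetical map $f(b,z)=\frac1N\sum_j a_j\max(z-b_j,0)$, here with uniformly spaced biases on $[0,1]$ and unit weights, and takes $V'(x)=\sin(3x)$ as an arbitrary smooth example. It measures the largest gap between the piecewise-constant discrete speed and the exact speed $-V'(f(b,z))$ on a fine grid of $z$; doubling $N$ (halving $\Delta b$) should roughly halve the gap.

```python
import numpy as np

def relu_map(b, a, z):
    """Hypothetical ReLU map f(b, z) = (1/N) * sum_j a_j * max(z - b_j, 0)."""
    z = np.atleast_1d(np.asarray(z, dtype=float))[:, None]
    return (a * np.maximum(z - b, 0.0)).sum(axis=1) / len(b)

def mapping_speed_error(N, dV, n_eval=20001):
    """Sup-norm gap between the discrete speed and -V'(f(b, z)) for uniform biases, unit weights."""
    b = np.linspace(0.0, 1.0, N)
    a = np.ones(N)
    Vp_nodes = dV(relu_map(b, a, b))
    z = np.linspace(b[0], b[-1], n_eval)
    # index i of the interval [b_i, b_{i+1}] containing each z
    idx = np.clip(np.searchsorted(b, z, side="right") - 1, 0, N - 2)
    v_disc = -0.5 * (Vp_nodes[idx] + Vp_nodes[idx + 1])   # constant speed on each interval
    v_exact = -dV(relu_map(b, a, z))                       # Lagrangian speed -V'(f(b, z))
    return np.max(np.abs(v_disc - v_exact))

dV = lambda x: np.sin(3.0 * x)
Ns = [20, 40, 80, 160]
errs = [mapping_speed_error(N, dV) for N in Ns]
for N, e in zip(Ns, errs):
    print(f"N = {N:4d}   max speed error = {e:.3e}")
```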

Next, we prove the statement for the induced distribution. To do so, we derive the evolution equation of the distribution according to eq. 28. For $x\in[f(b,b_i),f(b,b_{i+1})]$, we have

$$\begin{aligned}
\partial_t p(t,x) ={}& \dot{b}^T\partial_b p(t,x)\\
={}& \sum_{i=1}^N\frac{N}{2a_i}\big(V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))\big)\frac{a_i}{N}\Big(\partial_x p(t,x)\,\mathbf{1}_{[f(b,b_i),\infty)}(x)+p(t,x)\,\delta_{f(b,b_i)}(x)\Big)\\
={}& \partial_x p(t,x)\sum_{i=1}^N\frac{V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))}{2}\,\mathbf{1}_{[f(b,b_i),\infty)}(x)\\
&+p(t,x)\sum_{i=1}^N\frac{V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))}{2}\,\delta_{f(b,b_i)}(x)\\
={}& \partial_x p(t,x)\,\frac{V'(f(b,b_{i+1}))+V'(f(b,b_i))}{2}+p(t,x)\,\frac{V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))}{2}\,\delta_{f(b,b_i)}(x).
\end{aligned}\tag{33}$$

A quick way to derive the formula for $\partial_b p(t,x)$ is to view it as the probability flow corresponding to the cotangent vector $\partial_b f$, and then to compute it from the continuity equation under the Wasserstein metric, as in proposition 2. Recall that the Wasserstein gradient flow of the potential energy on the density manifold is given by

$$\partial_t p(t,x)=\nabla\cdot\big(p(t,x)\nabla V(x)\big)=\partial_x p(t,x)\,\partial_x V(x)+p(t,x)\,\partial_{xx}V(x). \tag{34}$$

Comparing eq. 33 with eq. 34, the first term in eq. 33 approximates its continuous counterpart $\partial_x p\,\partial_x V$ to first order, by the same estimate as in the mapping case. The remaining Dirac terms correspond to $p\,\partial_{xx}V$: here the approximation is first order not in the strong (pointwise) sense but in the weak sense, i.e., after integration against test functions, since eq. 33 contains Dirac measures. Combining the two parts completes the proof. ∎
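The weak-sense statement can be illustrated with a short computation. The sketch below sets $p\equiv1$ (an assumption made only to isolate the Dirac part), uses the same hypothetical ReLU map with uniform biases on $[0,1]$ and unit weights, and pairs the discrete measure $\sum_i\frac{V'(f(b,b_{i+1}))-V'(f(b,b_{i-1}))}{2}\delta_{f(b,b_i)}$ (interior nodes) with a smooth test function $\varphi$, comparing the result with $\int V''(x)\varphi(x)\,dx$ over the mapped range. The choices $V'(x)=\sin x$ and $\varphi(x)=\cos x$ are arbitrary; for them the gap shrinks roughly like $\Delta b$, even though the discrete measure has no pointwise limit.

```python
import numpy as np

def relu_map(b, a, z):
    """Hypothetical ReLU map f(b, z) = (1/N) * sum_j a_j * max(z - b_j, 0)."""
    z = np.atleast_1d(np.asarray(z, dtype=float))[:, None]
    return (a * np.maximum(z - b, 0.0)).sum(axis=1) / len(b)

def weak_dirac_error(N):
    """|<Dirac part, phi> - int V''(x) phi(x) dx| with p = 1, V' = sin, phi = cos."""
    b = np.linspace(0.0, 1.0, N)
    a = np.ones(N)
    x = relu_map(b, a, b)                        # node images x_i = f(b, b_i)
    Vp = np.sin(x)                               # V'(x_i)
    weights = 0.5 * (Vp[2:] - Vp[:-2])           # Dirac coefficients at interior nodes
    discrete = np.sum(weights * np.cos(x[1:-1]))
    # exact integral of V''(x) * phi(x) = cos(x)^2 over [x_0, x_{N-1}]
    F = lambda t: t / 2.0 + np.sin(2.0 * t) / 4.0
    exact = F(x[-1]) - F(x[0])
    return abs(discrete - exact)

Ns = [20, 40, 80, 160]
errs = [weak_dirac_error(N) for N in Ns]
for N, e in zip(Ns, errs):
    print(f"N = {N:4d}   weak error = {e:.3e}")
```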

4.2.1. Projected dynamics of negative entropy gradient flow

The potential energy is a linear functional of the density, so its projected gradient flow has a rather simple expression. For general nonlinear internal energies, such as the entropy, the corresponding formula is more involved. We begin by calculating the negative entropy functional of a neural mapping measure $f_\#p_{\mathrm{r}}$:

$$\begin{aligned}
H(f_\#p_{\mathrm{r}}) ={}& \mathbb{E}_{x\sim\mathrm{cont}(f_\#p_{\mathrm{r}})}\big[\log f_\#p_{\mathrm{r}}(x)\big]+\mathfrak{F}_0(b_1)\log\mathfrak{F}_0(b_1)\\
={}& \mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log f_\#p_{\mathrm{r}}(f(z))\big]+\mathfrak{F}_0(b_1)\log\mathfrak{F}_0(b_1)\\
={}& \mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\Big[\log\frac{p_{\mathrm{r}}(z)}{f'(z)}\Big]+\mathfrak{F}_0(b_1)\log\mathfrak{F}_0(b_1)\\
={}& \mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log p_{\mathrm{r}}(z)\big]-\mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log f'(z)\big]+\mathfrak{F}_0(b_1)\log\mathfrak{F}_0(b_1),
\end{aligned}\tag{35}$$

where $\mathrm{cont}(\cdot)$ denotes the continuous part of a distribution, and we use the one-dimensional Monge-Ampère equation $f_\#p_{\mathrm{r}}(f(z))=\frac{p_{\mathrm{r}}(z)}{f'(z)}$. Moreover, the last term is the entropy of the discrete part of $f_\#p_{\mathrm{r}}$: the ReLU mapping function maps $(-\infty,b_1]$ to $0$, which creates an atom of mass $\mathfrak{F}_0(b_1)$ at the origin. Similarly, the relative entropy functional is given by

$$\begin{aligned}
\mathrm{D}_{\mathrm{KL}}\big(f_\#p_{\mathrm{r}}\,\big\|\,\nu\big) ={}& \mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log p_{\mathrm{r}}(z)\big]-\mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log f'(z)\big]\\
&-\mathbb{E}_{z\sim\mathrm{cont}(p_{\mathrm{r}})}\big[\log\nu(f(z))\big]+\mathfrak{F}_0(b_1)\log\mathfrak{F}_0(b_1).
\end{aligned}$$
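The decomposition in eq. 35 can be sanity-checked numerically. The sketch below assumes a standard normal base distribution $p_{\mathrm{r}}$ (so $\mathfrak{F}_0$ is the normal CDF), positive weights $a_j$, and the hypothetical map $f(b,z)=\frac1N\sum_j a_j\max(z-b_j,0)$, for which $f'$ equals $\frac1N\sum_{j\le i}a_j$ on $[b_i,b_{i+1})$. It evaluates the right-hand side of eq. 35 in closed form and compares it with a direct trapezoid quadrature of $\int q\log q\,dx$ for the pushforward density $q$ of the continuous part, plus the atom term. Function names, parameters, and the truncation $z\le 10$ are illustrative choices.

```python
import math
import numpy as np

def std_normal_pdf(z):
    return np.exp(-0.5 * np.asarray(z) ** 2) / math.sqrt(2.0 * math.pi)

def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def entropy_eq35(b, a):
    """Right-hand side of eq. (35) for p_r = N(0, 1), in closed form (b sorted, a > 0)."""
    N = len(b)
    slopes = np.cumsum(a) / N                    # f' on [b_i, b_{i+1}); last one on [b_N, inf)
    F = np.array([std_normal_cdf(t) for t in b] + [1.0])
    b1, tail = b[0], 1.0 - F[0]
    # int_{b1}^inf p_r log p_r dz for the standard normal
    e_log_pr = -0.5 * math.log(2.0 * math.pi) * tail - 0.5 * (b1 * float(std_normal_pdf(b1)) + tail)
    e_log_fp = np.sum(np.log(slopes) * np.diff(F))  # f' is piecewise constant
    atom = F[0] * math.log(F[0])                 # mass F_0(b_1) sent to x = 0
    return e_log_pr - e_log_fp + atom

def entropy_direct(b, a, z_max=10.0, M=400001):
    """Quadrature of int q log q dx for the continuous part of f_# p_r, plus the atom term."""
    N = len(b)
    f = lambda z: (a * np.maximum(np.atleast_1d(z)[:, None] - b, 0.0)).sum(axis=1) / N
    z_nodes = np.append(b, z_max)
    x_nodes = f(z_nodes)                         # strictly increasing since a > 0
    x = np.linspace(0.0, x_nodes[-1], M)
    z = np.interp(x, x_nodes, z_nodes)           # inverse of the piecewise-linear map
    idx = np.clip(np.searchsorted(b, z, side="right") - 1, 0, N - 1)
    q = std_normal_pdf(z) / (np.cumsum(a)[idx] / N)   # q(f(z)) = p_r(z) / f'(z)
    g = q * np.log(q)
    integral = np.sum(0.5 * (g[1:] + g[:-1]) * np.diff(x))
    F1 = std_normal_cdf(b[0])
    return integral + F1 * math.log(F1)

b = np.array([-1.0, -0.3, 0.2, 0.8, 1.5])
a = np.array([1.0, 0.5, 2.0, 1.0, 0.7])
print(entropy_eq35(b, a), entropy_direct(b, a))
```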

Moreover, the gradient flow of the KL divergence differs from that of the negative entropy only by a potential-energy term, namely the one with potential $-\log\nu$, which is treated as in the potential functional gradient flow above. The same relation holds on the density manifold between the heat equation and the Fokker-Planck equation. We now calculate the derivatives of the continuous parts with respect to the parameter $b_i$:

$$\begin{aligned}
\partial_{b_i}\mathbb{E}_{x\sim p_{\mathrm{r}}}\big[\log f'(x)\big] &= \begin{cases}\log\dfrac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i), & i\neq 1,\\[2mm] -p_{\mathrm{r}}(b_1)\log\dfrac{a_1}{N}, & i=1,\end{cases}\\
\partial_{b_i}\mathbb{E}_{x\sim p_{\mathrm{r}}}\big[\log\nu(f(x))\big] &= \mathbb{E}_{x\sim p_{\mathrm{r}}}\left[\frac{\nu'(y(x))\,\partial_{b_i}y(x)}{\nu(y(x))}\right],\quad y(x):=f(b,x).
\end{aligned}$$
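The second identity is the chain rule under the expectation. A minimal check, assuming a standard normal base $p_{\mathrm{r}}$, a standard normal target $\nu$ (so that $\nu'/\nu=-y$), and the hypothetical ReLU map with $\partial_{b_i}y(x)=-\frac{a_i}{N}\mathbf{1}_{[b_i,\infty)}(x)$, compares the formula with central finite differences of a trapezoid approximation of $\mathbb{E}_{x\sim p_{\mathrm{r}}}[\log\nu(f(b,x))]$. The grid, parameters, and function names are illustrative and not taken from the text.

```python
import math
import numpy as np

def relu_map(b, a, z):
    """Hypothetical ReLU map f(b, z) = (1/N) * sum_j a_j * max(z - b_j, 0)."""
    z = np.atleast_1d(np.asarray(z, dtype=float))[:, None]
    return (a * np.maximum(z - b, 0.0)).sum(axis=1) / len(b)

# Trapezoid quadrature for E_{x ~ p_r}[.] with p_r = N(0, 1) (illustrative choice)
XS = np.linspace(-10.0, 10.0, 200001)
W = np.full(XS.size, XS[1] - XS[0])
W[[0, -1]] *= 0.5
PR = np.exp(-0.5 * XS**2) / math.sqrt(2.0 * math.pi)

def expected_log_nu(b, a):
    """E_{x ~ p_r}[log nu(f(b, x))] for nu = N(0, 1), up to an additive constant."""
    y = relu_map(b, a, XS)
    return np.sum(W * PR * (-0.5 * y**2))

def grad_formula(b, a):
    """E[nu'(y)/nu(y) * d_{b_i} y] with nu'/nu = -y and d_{b_i} y = -(a_i/N) 1[x >= b_i]."""
    N = len(b)
    y = relu_map(b, a, XS)
    active = XS[:, None] >= b[None, :]           # indicator of [b_i, inf), shape (M, N)
    dy = -(a / N) * active
    return np.sum((W * PR * (-y))[:, None] * dy, axis=0)

b = np.array([-1.0, -0.3, 0.2, 0.8, 1.5])
a = np.array([1.0, 0.5, 2.0, 1.0, 0.7])
h = 1e-5
fd = np.array([(expected_log_nu(b + h * e, a) - expected_log_nu(b - h * e, a)) / (2 * h)
               for e in np.eye(len(b))])
print(np.c_[grad_formula(b, a), fd])
```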

The first identity follows from the observation that $\log f'(x)$ is a step function that changes its value only at the nodes $b_i$: on the interval $[b_{i-1},b_i]$ it takes the value $\log\sum_{j=1}^{i-1}a_j$. Moving $b_i$ therefore only shifts the location of the jump from $\log\sum_{j=1}^{i-1}a_j$ to $\log\sum_{j=1}^{i}a_j$, and the factor $p_{\mathrm{r}}(b_i)$ appears because the expectation is taken with respect to $p_{\mathrm{r}}(x)$. Consequently, the derivatives of the entropy and relative entropy functionals read

$$
\begin{aligned}
\partial_{b_i}H\left(f_{\#}p_{\mathrm{r}}\right) &= \begin{cases} -\log\dfrac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i), & i\neq 1,\\[2ex] p_{\mathrm{r}}(b_1)\left(\log\mathfrak{F}_0(b_1)+1+\log\dfrac{a_1}{N}\right), & i=1,\end{cases}\\
\partial_{b_i}\mathrm{D}_{\mathrm{KL}}\left(f_{\#}p_{\mathrm{r}}\big\|\nu\right) &= \log\frac{\sum_{j=1}^{i}a_j}{\sum_{j=1}^{i-1}a_j}\,p_{\mathrm{r}}(b_i)-\mathbb{E}_{x\sim p_{\mathrm{r}}}\left[\frac{\nu'(f(x))\,\partial_{b_i}f(x)}{\nu(f(x))}\right],
\end{aligned}
\tag{36}
$$

where the second identity, valid for $i\neq 1$, follows from $\mathrm{D}_{\mathrm{KL}}(f_{\#}p_{\mathrm{r}}\|\nu)=H(f_{\#}p_{\mathrm{r}})-\mathbb{E}_{x\sim p_{\mathrm{r}}}[\log\nu(f(x))]$.
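Both identities in (36) can be sanity-checked numerically for $i\neq 1$. The sketch below is illustrative only and rests on assumptions not stated in this section: the network is taken to be $f(x)=\sum_j a_j\,\mathrm{ReLU}(x-b_j)$ with sorted nodes and positive weights, both $p_{\mathrm{r}}$ and $\nu$ are standard Gaussians, and the region $x<b_1$ (which only matters for $i=1$) is ignored. All function names are hypothetical, and indices in the code are 0-based, so `i` refers to the node $b_{i+1}$ of the text. The entropy derivative is compared with a central finite difference of the exact interval-sum form of $\mathbb{E}[\log f'(x)]$, and the cross term $\mathbb{E}[\log\nu(f(x))]$ is estimated by Monte Carlo with common random numbers.

```python
import math
import numpy as np

def gauss_pdf(x):
    """Standard Gaussian density (assumed choice for both p_r and nu)."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def gauss_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def expected_log_slope(a, b):
    """Integral of log f'(x) p_r(x) over (b_1, inf).

    On [b_k, b_{k+1}) the slope f' equals the partial sum S_k = a_1 + ... + a_k.
    The lower limit b_1 does not affect derivatives with respect to b_i, i >= 2.
    """
    total, partial = 0.0, 0.0
    for k in range(len(b)):
        partial += a[k]
        upper = gauss_cdf(b[k + 1]) if k + 1 < len(b) else 1.0
        total += math.log(partial) * (upper - gauss_cdf(b[k]))
    return total

def entropy_grad(a, b, i):
    """Closed form d/db_i H(f#p_r) = -log(S_{i-1}/S_i) p_r(b_i), for 0-based i >= 1."""
    return -math.log(sum(a[:i]) / sum(a[: i + 1])) * gauss_pdf(b[i])

def network(x, a, b):
    """Hypothetical two-layer ReLU map f(x) = sum_j a_j ReLU(x - b_j), vectorized over x."""
    return np.maximum(x[:, None] - np.asarray(b)[None, :], 0.0) @ np.asarray(a)

def mean_log_nu(x, a, b):
    """Monte Carlo estimate of E[log nu(f(x))] for a standard Gaussian target nu."""
    y = network(x, a, b)
    return float(np.mean(-0.5 * y**2 - 0.5 * math.log(2.0 * math.pi)))

def kl_grad(x, a, b, i):
    """d/db_i KL = d/db_i H - E[nu'(f) d_{b_i} f / nu(f)].

    For a Gaussian nu, nu'(y)/nu(y) = -y; for the assumed network, d_{b_i} f = -a_i 1{x > b_i}.
    """
    y = network(x, a, b)
    pathwise = np.mean((-y) * (-a[i] * (x > b[i])))
    return entropy_grad(a, b, i) - float(pathwise)

def central_diff(fun, b, i, h):
    """Central finite difference of fun with respect to the i-th node."""
    bp, bm = list(b), list(b)
    bp[i] += h
    bm[i] -= h
    return (fun(bp) - fun(bm)) / (2.0 * h)

a = [1.0, 0.5, 2.0, 1.5]
b = [-1.5, -0.2, 0.4, 1.3]
x = np.random.default_rng(0).standard_normal(2000)
for i in range(1, len(b)):
    fd_entropy = -central_diff(lambda bb: expected_log_slope(a, bb), b, i, 1e-6)
    fd_kl = fd_entropy - central_diff(lambda bb: mean_log_nu(x, a, bb), b, i, 1e-7)
    print(i, entropy_grad(a, b, i), fd_entropy, kl_grad(x, a, b, i), fd_kl)
```

Each printed row should show matching closed-form and finite-difference values; the Monte Carlo cross term is exact for the fixed sample, so agreement does not depend on the sample size.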

With all these preparations, we can write out the gradient flow equation of the entropy functional:

$$
\begin{aligned}
\dot{b}_i ={}& \frac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}\left(\frac{\log\frac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i)}{a_i^2}-\frac{\log\frac{\sum_{j=1}^{i-2}a_j}{\sum_{j=1}^{i-1}a_j}\,p_{\mathrm{r}}(b_{i-1})}{a_ia_{i-1}}\right)\\
&+\frac{1}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}\left(\frac{\log\frac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i)}{a_i^2}-\frac{\log\frac{\sum_{j=1}^{i}a_j}{\sum_{j=1}^{i+1}a_j}\,p_{\mathrm{r}}(b_{i+1})}{a_ia_{i+1}}\right),\quad i=2,\dots,N-1,\\
\dot{b}_1 ={}& -\frac{1}{a_1\left(\mathfrak{F}_0(b_2)-\mathfrak{F}_0(b_1)\right)}\left(\frac{p_{\mathrm{r}}(b_1)\left(\log\mathfrak{F}_0(b_1)+1+\log\frac{a_1}{N}\right)}{a_1}+\frac{\log\frac{\sum_{j=1}^{1}a_j}{\sum_{j=1}^{2}a_j}\,p_{\mathrm{r}}(b_2)}{a_2}\right),\\
\dot{b}_N ={}& \frac{\log\frac{\sum_{j=1}^{N-1}a_j}{\sum_{j=1}^{N}a_j}\,p_{\mathrm{r}}(b_N)}{a_N^2\left(1-\mathfrak{F}_0(b_N)\right)}-\frac{1}{\mathfrak{F}_0(b_N)-\mathfrak{F}_0(b_{N-1})}\left(\frac{\log\frac{\sum_{j=1}^{N-2}a_j}{\sum_{j=1}^{N-1}a_j}\,p_{\mathrm{r}}(b_{N-1})}{a_Na_{N-1}}-\frac{\log\frac{\sum_{j=1}^{N-1}a_j}{\sum_{j=1}^{N}a_j}\,p_{\mathrm{r}}(b_N)}{a_N^2}\right).
\end{aligned}
\tag{37}
$$
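The interior equations of (37) can be evaluated directly, which makes the coupling between neighboring nodes visible. This is a minimal sketch that assumes a standard Gaussian reference, so that $\mathfrak{F}_0$ is taken to be the Gaussian CDF; it only covers $3\le i\le N-1$, because for $i=2$ the formula contains the logarithm of the empty sum $\sum_{j=1}^{0}a_j$. The helper names `jump_term` and `interior_velocity` are hypothetical.

```python
import math

def gauss_pdf(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def gauss_cdf(x):
    """Assumed base CDF F_0 (standard Gaussian)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def jump_term(a, b, k):
    """A_k = log(S_{k-1}/S_k) p_r(b_k) for a 1-based index k >= 2, with S_k = a_1 + ... + a_k."""
    return math.log(sum(a[: k - 1]) / sum(a[:k])) * gauss_pdf(b[k - 1])

def interior_velocity(a, b, i):
    """Right-hand side of eq. (37) for a 1-based interior index 3 <= i <= N-1.

    Python lists are 0-based, so b_i is b[i-1] and a_i is a[i-1].
    """
    if not 3 <= i <= len(b) - 1:
        raise ValueError("this sketch only covers 3 <= i <= N-1")
    a_prev, a_cur, a_next = a[i - 2], a[i - 1], a[i]
    A_prev, A_cur, A_next = (jump_term(a, b, k) for k in (i - 1, i, i + 1))
    left_gap = gauss_cdf(b[i - 1]) - gauss_cdf(b[i - 2])   # F_0(b_i) - F_0(b_{i-1})
    right_gap = gauss_cdf(b[i]) - gauss_cdf(b[i - 1])      # F_0(b_{i+1}) - F_0(b_i)
    left = (A_cur / a_cur**2 - A_prev / (a_cur * a_prev)) / left_gap
    right = (A_cur / a_cur**2 - A_next / (a_cur * a_next)) / right_gap
    return left + right

a = [1.0] * 6
b = [-2.0, -1.0, 0.0, 1e-4, 1.0, 2.0]   # b_3 and b_4 are almost coincident
print(interior_velocity(a, b, 3), interior_velocity(a, b, 4))
```

With unit weights and $b_3$, $b_4$ almost coincident, $\dot b_3$ comes out large and negative while $\dot b_4$ is large and positive, i.e. the two nodes are pushed apart.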

As in the proof of Proposition 6, one can carefully expand the neural projected dynamics of the entropy functional and show that it converges to the heat equation in the limit where the number of neurons tends to infinity and the gap between neighboring nodes tends to zero.

4.2.2. Analysis of the long-time existence of the neural-projected heat flow

In general, the projected Wasserstein gradient flow need not be linear even when the original gradient flow is; e.g., the projected gradient flow corresponding to the heat equation is highly nonlinear. This makes it considerably harder to analyze the projected dynamics and to establish their long-time existence, as mentioned in [25]. Here we focus on the nonlinear projected gradient flow of the entropy, which corresponds to the heat equation in the full space. If we view the nodes $b_i$, $i\in[N]$, as grid points and regard the scheme as an instance of the moving mesh method [17], then mesh quality is an important quantity to monitor during the simulation: it should not deteriorate too much, and the mesh should never degenerate. Therefore, we study the well-posedness of the nonlinear ODE (37).

Proposition 7.

The neural projected dynamics (37) of the heat flow is well-posed, i.e., its solution extends to arbitrary time.

Proof.

We consider the special scenario in which two adjacent nodes $b_{i-1},b_i$ become close to each other while remaining relatively far from all other nodes, i.e.,

$$
o(1)=b_i-b_{i-1}=o(b_p-b_q),\quad\forall q\in[N]\backslash\{i-1,i\},\quad p=i-1,i. \tag{38}
$$

Without loss of generality, we write $b_i=b_{i-1}+\Delta b>b_{i-1}$ and reduce the following term, which appears in the time derivatives of both nodes in eq. (37):

$$
\begin{aligned}
&\frac{1}{\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})}\left(\frac{\log\frac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i)}{a_i^2}-\frac{\log\frac{\sum_{j=1}^{i-2}a_j}{\sum_{j=1}^{i-1}a_j}\,p_{\mathrm{r}}(b_{i-1})}{a_ia_{i-1}}\right)\\
={}&\left(\frac{1}{p_{\mathrm{r}}(b_i)\Delta b}+O(1)\right)\left(\log\frac{i-1}{i}\,p_{\mathrm{r}}(b_i)-\log\frac{i-2}{i-1}\,p_{\mathrm{r}}(b_{i-1})\right)\\
={}&\left(\frac{1}{p_{\mathrm{r}}(b_i)\Delta b}+O(1)\right)\left(\log\frac{i-1}{i}\,p_{\mathrm{r}}(b_i)-\log\frac{i-2}{i-1}\left(p_{\mathrm{r}}(b_i)+O(\Delta b)\right)\right)\\
={}&\left(\frac{1}{p_{\mathrm{r}}(b_i)\Delta b}+O(1)\right)\left(\log\frac{i^2-2i+1}{i^2-2i}\,p_{\mathrm{r}}(b_i)+O(\Delta b)\right)\\
={}&\frac{1}{\Delta b}\log\frac{i^2-2i+1}{i^2-2i}+O(1)\rightarrow+\infty,\qquad \Delta b\rightarrow 0^+,
\end{aligned}
\tag{39}
$$

where we use the simplified model in which all weights $a_i$ are set to $1$, together with the Taylor expansions $\mathfrak{F}_0(b_i)-\mathfrak{F}_0(b_{i-1})=p_{\mathrm{r}}(b_i)\Delta b+O(\Delta b^2)$ and $p_{\mathrm{r}}(b_{i-1})=p_{\mathrm{r}}(b_i)+O(\Delta b)$. This term appears with a positive sign on the right-hand side of $\dot{b}_i$ and with a negative sign on the right-hand side of $\dot{b}_{i-1}$, so the left node $b_{i-1}$ (right node $b_i$) moves rapidly to the left (right). This repulsion prevents adjacent Lagrangian coordinates from colliding, so mesh degeneracy does not occur.
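The blow-up rate in (39) can be checked numerically: $\Delta b$ times the first line of (39) should approach $\log\frac{i^2-2i+1}{i^2-2i}$ as $\Delta b\to 0^+$. This sketch assumes a standard Gaussian $p_{\mathrm{r}}$ (so $\mathfrak{F}_0$ is its CDF) and the unit-weight model used above; `coupling_term` and `predicted_rate` are hypothetical names.

```python
import math

def gauss_pdf(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def gauss_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def coupling_term(b_left, b_right, i):
    """First line of (39) with unit weights; b_left = b_{i-1}, b_right = b_i, 1-based i >= 3.

    With a_j = 1 the partial sums are S_k = k, so log(S_{i-1}/S_i) = log((i-1)/i).
    """
    numerator = (math.log((i - 1) / i) * gauss_pdf(b_right)
                 - math.log((i - 2) / (i - 1)) * gauss_pdf(b_left))
    return numerator / (gauss_cdf(b_right) - gauss_cdf(b_left))

def predicted_rate(i):
    """The constant log((i^2 - 2i + 1) / (i^2 - 2i)) multiplying 1/Delta b in (39)."""
    return math.log((i * i - 2 * i + 1) / (i * i - 2 * i))

i, b_right = 4, 0.5
for delta in (1e-1, 1e-2, 1e-3, 1e-4, 1e-5):
    term = coupling_term(b_right - delta, b_right, i)
    print(f"delta={delta:.0e}  term={term:12.3f}  delta*term={delta * term:.6f}")
print(f"predicted limit of delta*term: {predicted_rate(i):.6f}")
```

The term stays positive and grows like $1/\Delta b$, consistent with the repulsion argument; the remaining discrepancy in `delta*term` is the $O(\Delta b)$ correction.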

Next, we analyze the scheme through the time derivative of the Lagrangian coordinate. It is well known that the mean of the distribution is preserved by the heat flow. Owing to the diffusive nature of the heat equation, one therefore expects quantiles above the mean to move to the right and quantiles below the mean to move to the left. Suppose $x\in[b_i,b_{i+1}]$ is a quantile with $b_i$ greater than the mean $0$. Since the base measure is the standard Gaussian distribution, whose probability density function decreases on $[0,\infty)$, we conclude that

$$
0<b_i<b_{i+1}\Longrightarrow p_{\mathrm{r}}(b_i)>p_{\mathrm{r}}(b_{i+1})\Longrightarrow\log\frac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}\,p_{\mathrm{r}}(b_i)<\log\frac{\sum_{j=1}^{i}a_j}{\sum_{j=1}^{i+1}a_j}\,p_{\mathrm{r}}(b_{i+1})<0. \tag{40}
$$

Consequently, the Lagrangian coordinate $f_b(z)$ indeed moves to the right, which matches the intuition from the heat equation.
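The ordering (40) can be spot-checked under the unit-weight simplification used in the proof above, where $\log\frac{\sum_{j=1}^{i-1}a_j}{\sum_{j=1}^{i}a_j}=\log\frac{i-1}{i}$. For general positive weights, a sufficient extra condition for the second implication is $\frac{S_{i-1}}{S_i}\le\frac{S_i}{S_{i+1}}$ with $S_k=\sum_{j=1}^{k}a_j$. The sketch assumes a standard Gaussian $p_{\mathrm{r}}$; the helper names are hypothetical.

```python
import math
import random

def gauss_pdf(x):
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def jump_term_unit(i, b_i):
    """log(S_{i-1}/S_i) p_r(b_i) with unit weights, i.e. S_k = k (1-based i >= 2)."""
    return math.log((i - 1) / i) * gauss_pdf(b_i)

def satisfies_ordering(nodes, first_index):
    """Check A_i < A_{i+1} < 0 from (40) for consecutive nodes b_i < b_{i+1}."""
    for offset in range(len(nodes) - 1):
        i = first_index + offset
        a_i = jump_term_unit(i, nodes[offset])
        a_next = jump_term_unit(i + 1, nodes[offset + 1])
        if not a_i < a_next < 0.0:
            return False
    return True

random.seed(1)
nodes = sorted(random.uniform(0.01, 4.0) for _ in range(20))  # nodes to the right of the mean 0
print(satisfies_ordering(nodes, first_index=2))
print(satisfies_ordering([-2.0, -1.0], first_index=2))          # left of the mean, ordering fails
```

The first check holds for any sorted positive nodes, because both $\log\frac{i}{i-1}$ and $p_{\mathrm{r}}(b_i)$ decrease along the chain; the second case shows that positivity of the nodes matters.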

The neural projected dynamics can be understood as a Lagrangian scheme [8, 24, 26] with a neural network basis. Specifically, if we fix the ReLU components in eq. (18) as basis functions, the $a_i$'s and $b_i$'s act as the shape and location coefficients of these basis functions, respectively. Updating the $a_i$'s resembles a classical finite element method with fixed basis functions, while the additional degrees of freedom in the $b_i$'s resemble a moving mesh method. Lagrangian schemes handle free-boundary problems such as the porous medium equation well; e.g., [24] uses a finite element method to compute the mapping function of the porous medium equation with high accuracy. While most Lagrangian schemes update only parameters playing the role of the $a_i$'s, our method also updates the $b_i$'s, which provides more flexibility and expressivity. Lagrangian schemes that exploit the primal-dual structure of the Wasserstein gradient flow have also been developed [8].
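The shape/location reading of $a_i$ and $b_i$ can be made concrete by rewriting the network as a piecewise-linear function on the mesh $\{b_i\}$. This is a minimal sketch assuming the hypothetical form $f(x)=\sum_i a_i\,\mathrm{ReLU}(x-b_i)$ with sorted nodes (the network in eq. (18) may contain further terms); `relu_network` and `piecewise_linear` are illustrative names.

```python
import numpy as np

def relu_network(x, a, b):
    """Hypothetical map f(x) = sum_i a_i ReLU(x - b_i) evaluated at the points x."""
    return np.maximum(x[:, None] - b[None, :], 0.0) @ a

def piecewise_linear(x, a, b):
    """The same map in 'mesh' form: nodal values on b_1 < ... < b_N, linear in between.

    Left of b_1 the map vanishes; right of b_N its slope is S_N = a_1 + ... + a_N.
    """
    values = relu_network(b, a, b)                # f(b_1), ..., f(b_N)
    inside = np.interp(x, b, values)              # exact, since f is linear between nodes
    beyond = values[-1] + a.sum() * (x - b[-1])
    return np.where(x < b[0], 0.0, np.where(x > b[-1], beyond, inside))

a = np.array([1.0, 0.5, 2.0, 1.5])    # slope increments ("shape" coefficients)
b = np.array([-1.5, -0.2, 0.4, 1.3])  # kink locations ("mesh" nodes)
x = np.linspace(-3.0, 3.0, 601)
print(np.max(np.abs(relu_network(x, a, b) - piecewise_linear(x, a, b))))
```

Changing `a` with `b` fixed only changes the nodal values on a fixed mesh, much like a fixed-basis finite element update, whereas changing `b` moves the mesh nodes themselves.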

We also compare our numerical algorithm with the moving mesh method. The principal ingredients of the moving mesh method are the equidistribution principle, the moving mesh equation, and the method of lines [36]. The moving mesh equation is solved during the simulation to keep the mesh adaptive, so that it can resolve detailed structures of the solution. In many classical moving mesh methods, the mesh equation is solved separately from the governing PDE to guarantee adaptivity, so the mesh motion does not depend explicitly on the underlying PDE. Other moving mesh methods do take the governing PDE into account when updating the mesh (e.g., the arbitrary Lagrangian-Eulerian methods [5]). From this perspective, the projected dynamics provide a PDE-specific moving mesh equation: the mesh moves according to the dynamics of the PDE being simulated, which makes it more adaptive and efficient. Moreover, through a detailed study of a simple case, Proposition 7 provides a theoretical guarantee on the mesh quality of our method.

4.3. Truncation error analysis for general neural projected Wasserstein gradient flows

The consistency proof above relies on the explicit analytic formula derived earlier, which is only available in restrictive settings. In this section, we provide another methodology for proving the consistency of the numerical schemes derived in this paper. Instead of computing the evolution of the mapping explicitly, we estimate the deviation of the projected gradient from the original gradient direction. We first state a geometric proposition in as general a setting as possible. This result is also proved in [25]; we prove it here for completeness.

Let $\mathcal{X}$ be a (possibly infinite-dimensional) manifold with a Riemannian metric $g_{\mathcal{X}}$, which provides an inner product on the tangent space $T_x\mathcal{X}$ (possibly an infinite-dimensional Hilbert space) for each $x\in\mathcal{X}$. Let $\mathcal{Y}\subset\mathcal{X}$ be a submanifold with the induced metric $g_{\mathcal{Y}}$, i.e., for all $y\in\mathcal{Y}$,

$$
g_{\mathcal{Y}}(y):T_y\mathcal{Y}\times T_y\mathcal{Y}\rightarrow\mathbb{R},\qquad g_{\mathcal{Y}}(y)(v,w)=g_{\mathcal{X}}(y)(v,w),\quad\forall v,w\in T_y\mathcal{Y}.
$$

Furthermore, let $H:\mathcal{X}\rightarrow\mathbb{R}$ be a functional defined on $\mathcal{X}$, and let $\widetilde{H}:\mathcal{Y}\rightarrow\mathbb{R}$ denote its restriction to $\mathcal{Y}$. We have the following proposition.

Proposition 8.

Let $\nabla_{g_{\mathcal{X}}}H(y)\in T_y\mathcal{X}$ (resp. $\nabla_{g_{\mathcal{Y}}}\widetilde{H}(y)\in T_y\mathcal{Y}$) denote the gradient of the functional $H$ (resp. $\widetilde{H}$) w.r.t. the metric $g_{\mathcal{X}}$ (resp. $g_{\mathcal{Y}}$) at $y\in\mathcal{X}$ (resp. $y\in\mathcal{Y}$). Then we have

$$\nabla_{g_{\mathcal{Y}}}\widetilde{H}(y)=\Pi(y)\nabla_{g_{\mathcal{X}}}H(y),\tag{41}$$

where $\Pi(y)$ is the orthogonal projection operator from $T_y\mathcal{X}$ onto $T_y\mathcal{Y}$.

Proof.

Since $\mathcal{Y}$ is a submanifold of $\mathcal{X}$, for each $y\in\mathcal{Y}$ we have the inclusion map $\mathrm{I}(y):T_y\mathcal{Y}\rightarrow T_y\mathcal{X}$ and the restriction map $\mathrm{I}^*(y):T_y^*\mathcal{X}\rightarrow T_y^*\mathcal{Y}$. Both maps are linear and adjoint to each other. Therefore, viewing the metric tensor $g_{\mathcal{Y}}(y)$ as a linear map $T_y\mathcal{Y}\rightarrow T_y^*\mathcal{Y}$, we have

$$g_{\mathcal{Y}}(y)=\mathrm{I}^*(y)\circ g_{\mathcal{X}}(y)\circ\mathrm{I}(y),\quad\forall y\in\mathcal{Y}.$$

Moreover, the inner product $g_{\mathcal{X}}(y)$ on the Hilbert space $T_y\mathcal{X}$ induces an orthogonal decomposition:

$$T_y\mathcal{X}=T_y\mathcal{Y}\oplus T_y\mathcal{Y}^{\perp},\quad\forall y\in\mathcal{Y},$$

along with an orthogonal projection operator $\Pi(y)$. Now, recall that the Riemannian gradient $\nabla_{g_{\mathcal{X}}}H(y)$ is defined by

$$g_{\mathcal{X}}(y)\nabla_{g_{\mathcal{X}}}H(y)=dH(y).$$

The differentials of $H(\cdot)$ and $\widetilde{H}(\cdot)$ are related by

$$d\widetilde{H}(y)=\mathrm{I}^*(y)\,dH(y),\quad\forall y\in\mathcal{Y}.$$

Therefore, gathering all the ingredients, we have the following commutative diagram

[Commutative diagram (tikzcd) not recoverable from the extracted source.]

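Since $\Pi(y)$ is the orthogonal projection, and since $u=(\mathrm{I}^*(y)g_{\mathcal{X}}(y)\mathrm{I}(y))^{-1}\mathrm{I}^*(y)\,dH(y)$ satisfies $g_{\mathcal{X}}(y)\bigl(\nabla_{g_{\mathcal{X}}}H(y)-\mathrm{I}(y)u,\ \mathrm{I}(y)v\bigr)=0$ for all $v\in T_y\mathcal{Y}$ (by $dH(y)=g_{\mathcal{X}}(y)\nabla_{g_{\mathcal{X}}}H(y)$ and the adjointness of $\mathrm{I}(y)$ and $\mathrm{I}^*(y)$), we conclude that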

$$\nabla_{g_{\mathcal{Y}}}\widetilde{H}(y)=\bigl(\mathrm{I}^*(y)g_{\mathcal{X}}(y)\mathrm{I}(y)\bigr)^{-1}\mathrm{I}^*(y)\,dH(y)=\Pi(y)\nabla_{g_{\mathcal{X}}}H(y).$$
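A finite-dimensional sanity check of (41) may help. The sketch below is our own illustration, not code from the paper: we take $\mathcal{X}=\mathbb{R}^n$ at a fixed point $y$, a random symmetric positive definite matrix `G` standing in for $g_{\mathcal{X}}(y)$, a hypothetical matrix `J` whose columns span $T_y\mathcal{Y}$ (so `J` plays the role of $\mathrm{I}(y)$ and `J.T` of $\mathrm{I}^*(y)$), and an arbitrary covector `dH` for $dH(y)$. It also checks the least-squares (minimal distance) property that is restated as (42).

```python
import numpy as np

def g_norm(G, x):
    """Norm of x under the inner product <x, y>_G = x^T G y."""
    return float(np.sqrt(x @ G @ x))

def check_projected_gradient(n=6, k=3, seed=0, trials=100):
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((n, n))
    G = A @ A.T + n * np.eye(n)        # SPD matrix standing in for g_X(y)
    J = rng.standard_normal((n, k))    # columns span T_y Y; J plays the role of I(y)
    dH = rng.standard_normal(n)        # coordinates of the covector dH(y)

    grad_X = np.linalg.solve(G, dH)             # g_X(y) grad_X = dH(y)
    G_Y = J.T @ G @ J                           # g_Y(y) = I^*(y) g_X(y) I(y)
    grad_Y = np.linalg.solve(G_Y, J.T @ dH)     # g_Y(y) grad_Y = I^*(y) dH(y)
    Pi = J @ np.linalg.solve(G_Y, J.T @ G)      # G-orthogonal projection onto range(J)

    lhs = J @ grad_Y                  # left-hand side of (41), viewed inside T_y X
    rhs = Pi @ grad_X                 # right-hand side of (41)

    # Minimality: grad_Y should minimise ||grad_X - J c||_G over coefficient vectors c.
    best = g_norm(G, grad_X - lhs)
    perturbed = min(g_norm(G, grad_X - J @ (grad_Y + 0.1 * rng.standard_normal(k)))
                    for _ in range(trials))
    return lhs, rhs, best, perturbed

lhs, rhs, best, perturbed = check_projected_gradient()
print(np.allclose(lhs, rhs), best < perturbed)
```

By the Pythagorean identity in the `G` inner product, any perturbation of the coefficients strictly increases the distance, which is what the second comparison tests.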

We can prove the consistency of our numerical schemes for various PDEs with a Wasserstein gradient flow structure by applying this proposition to the case where $\mathcal{X}=P_2^{\infty}(\mathbb{R})$ is the density manifold and $g_{\mathcal{X}}$ is the $W_2$ metric. To this end, we rewrite eq. (41) as

$$\nabla_{g_{\mathcal{Y}}}\widetilde{H}(y)=\operatorname{argmin}_{v\in T_y\mathcal{Y}}\left\|\nabla_{g_{\mathcal{X}}}H(y)-v\right\|_{g_{\mathcal{X}}}.\tag{42}$$

Therefore, for any $v\in T_y\mathcal{Y}$, the quantity $\|\nabla_{g_{\mathcal{X}}}H(y)-v\|_{g_{\mathcal{X}}}$ provides an upper bound for the truncation error of our approximation scheme. Moreover, assume that the submanifold $\mathcal{Y}\subset P_2^{\infty}(\mathbb{R})$ is given by a generative model with mapping function $f_\theta$, i.e., $\mathcal{Y}=\{f_{\theta\#}p_{\mathrm{r}}:\theta\in\Theta\}$, where $p_{\mathrm{r}}$ is the base measure. Then the projected gradient direction can also be characterized using the metric on the mapping space, i.e.,

$$\nabla_{\Theta}\widetilde{H}(\theta)=\operatorname{argmin}_{v\in T_\theta\Theta}\int\bigl(\nabla_\theta H(\theta)(x)-v(x)\bigr)^2f_{\theta\#}p_{\mathrm{r}}(x)\,dx,\tag{43}$$

where $\theta$ is mapped to the point $y$, and we abuse the notation $\nabla_{\Theta}H(\theta)$ to denote the gradient vector in the mapping coordinate. Moreover, we can perform the truncation error analysis directly on the mapping space, which is more convenient and transparent. Let us focus on the ReLU network mapping in eq. (18). The tangent space in the mapping coordinate is spanned by the vectors in eq. (20), while the tangent space in the density coordinate is spanned by

$$\partial_{b_i}f_{\theta\#}p_{\mathrm{r}}(x)=\frac{a_i}{N}p_{\mathrm{r}}'(x)\mathbf{1}_{[f(\theta,b_i),\infty)},\qquad\partial_{a_i}f_{\theta\#}p_{\mathrm{r}}(x)=\frac{f_\theta^{-1}(x)-b_i}{N}p_{\mathrm{r}}'(x)\mathbf{1}_{[f(\theta,b_i),\infty)},\tag{44}$$

where $f_\theta(\cdot)=f(\theta,\cdot)$. Here we use the fact that the mapping $f_\theta$ is linear with slope $\frac{1}{N}\sum_{j=1}^{i}a_j$ on the interval $[b_i,b_{i+1}]$. If the $b_i$ are fixed, the projected dynamics falls into the framework of projection-based model reduction [6], with the basis fixed to be the neurons; letting the $b_i$ evolve corresponds to model reduction with an adaptive basis.
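The slope statement and the sample form of (43) can be sketched numerically. Since eq. (18) is not reproduced here, the code assumes the hypothetical form $f_\theta(z)=\sum_{i=1}^N\frac{a_i}{N}\mathrm{ReLU}(z-b_i)$ with sorted $b_i$, which is consistent with the slope stated above; its parameter derivatives (ramps in $a_i$, steps in $b_i$) are taken as our stand-in for the vectors of eq. (20), up to sign conventions. After the change of variables $x=f_\theta(z)$, the integral in (43) becomes an expectation over $z\sim p_{\mathrm{r}}$, which we approximate with samples from a standard normal $p_{\mathrm{r}}$ (an assumption) and solve as linear least squares. The target velocity `u` is a hypothetical choice.

```python
import numpy as np

def relu_map(z, a, b):
    """Assumed ReLU mapping f_theta(z) = sum_i (a_i / N) * max(z - b_i, 0)."""
    return np.maximum(z[:, None] - b[None, :], 0.0) @ a / len(a)

def mapping_tangent_vectors(z, a, b):
    """Columns d f/d a_i = ReLU(z - b_i)/N and d f/d b_i = -(a_i/N) 1[z > b_i]."""
    N = len(a)
    d_a = np.maximum(z[:, None] - b[None, :], 0.0) / N
    d_b = -(z[:, None] > b[None, :]).astype(float) * a[None, :] / N
    return np.hstack([d_a, d_b])            # shape (len(z), 2N)

def project_velocity(z, a, b, u):
    """Sample version of (43): least-squares fit of u(f_theta(z_k)) by tangent vectors."""
    jac = mapping_tangent_vectors(z, a, b)
    target = u(relu_map(z, a, b))
    xi, *_ = np.linalg.lstsq(jac, target, rcond=None)
    mse = float(np.mean((target - jac @ xi) ** 2))
    return xi, mse

a = np.array([1.0, 2.0, 0.5, 1.5])
b = np.array([-1.0, 0.0, 0.5, 2.0])
# Slope of f_theta on each [b_i, b_{i+1}] versus the cumulative sums (1/N) sum_{j<=i} a_j.
slopes = (relu_map(b[1:], a, b) - relu_map(b[:-1], a, b)) / np.diff(b)
print(slopes, np.cumsum(a)[:-1] / len(a))

z = np.random.default_rng(0).standard_normal(4000)   # samples from p_r
xi, mse = project_velocity(z, a, b, lambda x: -x)   # project u(x) = -x
```

For $u(x)=-x$ the target $-f_\theta(z)=\sum_i(-a_i)\,\partial_{a_i}f_\theta(z)$ lies in the tangent space, so the fit recovers $\xi_{a}=-a$, $\xi_{b}=0$ with essentially zero residual; a target outside the span would leave a positive residual, i.e., a truncation error.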

Proposition 9.

The numerical scheme based on the ReLU network mapping is consistent of order $2$ when both the $a$ and $b$ parameters are used, and of order $1$ when only the $a$ or only the $b$ parameters are used.

Proof.

In view of eq. (20), the approximation using only $\partial_{b_i}f_\theta$ is simply a piecewise constant approximation. As each basis function has the shape of a Heaviside function, it is consistent of order $1$. In contrast, the approximation using both $\partial_{b_i}f_\theta$ and $\partial_{a_i}f_\theta$ is a piecewise linear approximation, and is thereby consistent of order $2$. This is because another set of ReLU-shaped functions is added to the basis. ∎
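A quick numerical experiment illustrates the two orders. Under simplifying assumptions of our own ($p_{\mathrm{r}}$ uniform on $[0,1]$, $f_\theta$ the identity, uniformly spaced $b_i$, and a smooth test function $v(z)=\sin(3z)$), we identify the span of the Heaviside-type vectors with cellwise constants and the span of Heaviside plus ReLU-type vectors with discontinuous cellwise affine functions, fit $v$ by least squares on each cell, and estimate the observed $L^2$ convergence rate when the cell width is halved.

```python
import numpy as np

def cellwise_l2_error(v, n_cells, degree, pts_per_cell=50):
    """L^2(0,1) error of the best cellwise polynomial fit of the given degree.

    degree=0: piecewise constants (Heaviside-type basis);
    degree=1: discontinuous piecewise linears (Heaviside + ReLU-type basis).
    Integrals are approximated by the midpoint rule on each cell.
    """
    edges = np.linspace(0.0, 1.0, n_cells + 1)
    sq_err = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        h = (hi - lo) / pts_per_cell
        z = lo + (np.arange(pts_per_cell) + 0.5) * h
        coeffs = np.polyfit(z, v(z), degree)       # least-squares fit on this cell
        resid = v(z) - np.polyval(coeffs, z)
        sq_err += np.sum(resid ** 2) * h
    return np.sqrt(sq_err)

def observed_rate(v, degree, n_cells=32):
    """log2 of the error ratio when the cell width is halved."""
    e_coarse = cellwise_l2_error(v, n_cells, degree)
    e_fine = cellwise_l2_error(v, 2 * n_cells, degree)
    return np.log2(e_coarse / e_fine)

def v(z):
    return np.sin(3.0 * z)

print(observed_rate(v, degree=0), observed_rate(v, degree=1))
```

The observed rates should be close to $1$ and $2$ respectively, in line with Proposition 9; with moving $b_i$ the grid is no longer uniform, which this sketch does not model.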

The connection between ReLU neural networks and linear finite element spaces is studied systematically in [15], where it is established that a ReLU neural network needs at least two hidden layers to represent every linear finite element function on $\Omega\subset\mathbb{R}^d$ when $d\geq 2$.

Based on this concrete understanding of the structure of the tangent space, we can calculate the local truncation error of the projected gradient flow.

Theorem 3.

Given a tangent vector $v(x)\in T_{f_{\theta\#}p_{\mathrm{r}}}\mathcal{P}(\mathbb{R})$ whose approximating tangent vector in the projected dynamics is $\nabla_\theta H(\theta)$, the local truncation error of the ReLU network mapping is given by

$$\sum_{i=1}^{N}\left[\int_{b_i}^{b_{i+1}}v^2(f_\theta(z))\,p_{\mathrm{r}}(z)\,dz-\frac{\left(\int_{b_i}^{b_{i+1}}v(f_\theta(z))(z-m_i)\,p_{\mathrm{r}}(z)\,dz\right)^2}{\int_{b_i}^{b_{i+1}}(z-m_i)^2\,p_{\mathrm{r}}(z)\,dz}-\frac{\left(\int_{b_i}^{b_{i+1}}v(f_\theta(z))\,p_{\mathrm{r}}(z)\,dz\right)^2}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}\right],\tag{45}$$

where $m_i$ is the center of mass of $p_{\mathrm{r}}(z)$ on $[b_i,b_{i+1}]$ and $b_{N+1}$ is understood as $+\infty$. Moreover, under the assumption that $v$ has a bounded second-order derivative and $b_{i+1}-b_i<\Delta b$ for all $i$, we have

$$\left\|v(x)-\nabla_\theta H(\theta)\right\|_{L^2(f_{\theta\#}p_{\mathrm{r}})}^2=\frac{1}{4}\left(\frac{\sum_{j=1}^{N}a_j}{N}\right)^2\left\|v''\right\|_\infty O(\Delta b^4).\tag{46}$$
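The closed form (45) can be compared cell by cell against a direct weighted least-squares fit of $c_if_\theta(z)+d_i$. The sketch below is our illustration under assumptions we chose: $p_{\mathrm{r}}$ is the standard normal density, the ReLU map has the hypothetical form $f_\theta(z)=\sum_i\frac{a_i}{N}\mathrm{ReLU}(z-b_i)$ with positive $a_i$, the last cell $[b_N,+\infty)$ is truncated at $z=8$, all integrals use the midpoint rule, and, reading $\mathfrak{F}_0$ as the cumulative distribution function of $p_{\mathrm{r}}$, the quadrature mass of each cell stands in for $\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)$. With these choices the two computations agree up to rounding.

```python
import numpy as np

def relu_map(z, a, b):
    """Hypothetical ReLU mapping f_theta(z) = sum_i (a_i / N) * max(z - b_i, 0)."""
    return np.maximum(z[:, None] - b[None, :], 0.0) @ a / len(a)

def cell_errors(v, a, b, z_max=8.0, pts=400):
    """Per-cell error from the closed form (45) and from a direct weighted LS fit."""
    edges = np.append(b, z_max)                 # b_{N+1} = +inf, truncated at z_max
    closed, direct = [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        h = (hi - lo) / pts
        z = lo + (np.arange(pts) + 0.5) * h
        w = np.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi) * h   # p_r(z) dz, midpoint rule
        y = v(relu_map(z, a, b))                # v(f_theta(z))
        mass = w.sum()                          # stands in for F0(b_{i+1}) - F0(b_i)
        m = (w * z).sum() / mass                # centre of mass m_i
        closed.append((w * y ** 2).sum()
                      - (w * y * (z - m)).sum() ** 2 / (w * (z - m) ** 2).sum()
                      - (w * y).sum() ** 2 / mass)
        # Direct fit of c_i f_theta(z) + d_i with weights p_r(z) dz.
        X = np.stack([relu_map(z, a, b), np.ones_like(z)], axis=1)
        sw = np.sqrt(w)
        coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
        direct.append((w * (y - X @ coef) ** 2).sum())
    return np.array(closed), np.array(direct)

a = np.array([1.0, 2.0, 0.5, 1.5])
b = np.array([-1.5, -0.3, 0.4, 1.2])
closed, direct = cell_errors(np.sin, a, b)
print(np.max(np.abs(closed - direct)))
```

The agreement relies only on $f_\theta$ being affine with nonzero slope on each cell, so that $\{f_\theta(z),1\}$ and $\{z-m_i,1\}$ span the same space and the latter pair is orthogonal with respect to $p_{\mathrm{r}}$ on the cell.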
Proof.

As discussed in the proof of Proposition 9, the approximation using the ReLU network mapping is equivalent to a piecewise linear approximation in the mapping coordinate. Moreover, at each node $b_i$, the slope and the value of the function do not need to be continuous, which is exactly the same as in linear spline interpolation. The main difference is that the grid points $b_i$ are not fixed, since they can evolve over time. Therefore, we rewrite the optimization problem (43) as

$$\arg\min_{c_i,d_i}\quad\sum_{i=1}^{N}\int_{f_\theta(b_i)}^{f_\theta(b_{i+1})}\bigl(v(x)-c_ix-d_i\bigr)^2f_{\theta\#}p_{\mathrm{r}}(x)\,dx,\tag{47}$$

which can be further decoupled into $N$ separate optimization problems in $(c_i,d_i)$, each over a small interval $[f_\theta(b_i),f_\theta(b_{i+1})]$. For each subproblem, the change of variables $x=f_\theta(z)$ gives

$$\int_{f_\theta(b_i)}^{f_\theta(b_{i+1})}\bigl(v(x)-c_ix-d_i\bigr)^2f_{\theta\#}p_{\mathrm{r}}(x)\,dx=\int_{b_i}^{b_{i+1}}\bigl(v(f_\theta(z))-c_if_\theta(z)-d_i\bigr)^2p_{\mathrm{r}}(z)\,dz.$$

This is a quadratic optimization problem in $(c_i,d_i)$ with a positive definite Hessian matrix. Setting the derivatives w.r.t. $c_i$ and $d_i$ to zero, we obtain

$$\begin{aligned}
\int_{b_i}^{b_{i+1}}\bigl(v(f_\theta(z))-c_if_\theta(z)-d_i\bigr)p_{\mathrm{r}}(z)\,dz&=0,\\
\int_{b_i}^{b_{i+1}}f_\theta(z)\bigl(v(f_\theta(z))-c_if_\theta(z)-d_i\bigr)p_{\mathrm{r}}(z)\,dz&=0.
\end{aligned}$$

Now, using the fact that $f_\theta(z)$ is a linear function on the interval $[b_i,b_{i+1}]$, we have

$$c_if_\theta(z)+d_i=\frac{\int_{b_i}^{b_{i+1}}v(f_\theta(z))(z-m_i)\,p_{\mathrm{r}}(z)\,dz}{\int_{b_i}^{b_{i+1}}(z-m_i)^2\,p_{\mathrm{r}}(z)\,dz}\,(z-m_i)+\frac{\int_{b_i}^{b_{i+1}}v(f_\theta(z))\,p_{\mathrm{r}}(z)\,dz}{\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)}.\tag{48}$$
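The step from the two normal equations to (48) can be spelled out as follows. This is our expansion of the argument; it assumes the slope of $f_\theta$ on $[b_i,b_{i+1}]$ is nonzero and reads $\mathfrak{F}_0$ as the cumulative distribution function of $p_{\mathrm{r}}$.

```latex
On $[b_i,b_{i+1}]$ write $f_\theta(z)=s_iz+t_i$ with $s_i\neq 0$. Then
$\operatorname{span}\{f_\theta(z),1\}=\operatorname{span}\{z-m_i,1\}$, so
$c_if_\theta(z)+d_i=\alpha_i(z-m_i)+\beta_i$ for some $\alpha_i,\beta_i$, and the two
normal equations are equivalent to
\[
\int_{b_i}^{b_{i+1}}\bigl(v(f_\theta(z))-\alpha_i(z-m_i)-\beta_i\bigr)p_{\mathrm r}(z)\,dz=0,\qquad
\int_{b_i}^{b_{i+1}}(z-m_i)\bigl(v(f_\theta(z))-\alpha_i(z-m_i)-\beta_i\bigr)p_{\mathrm r}(z)\,dz=0.
\]
Since $m_i$ is the centre of mass, $\int_{b_i}^{b_{i+1}}(z-m_i)p_{\mathrm r}(z)\,dz=0$,
so the system decouples:
\[
\beta_i=\frac{\int_{b_i}^{b_{i+1}}v(f_\theta(z))\,p_{\mathrm r}(z)\,dz}{\int_{b_i}^{b_{i+1}}p_{\mathrm r}(z)\,dz},\qquad
\alpha_i=\frac{\int_{b_i}^{b_{i+1}}v(f_\theta(z))(z-m_i)\,p_{\mathrm r}(z)\,dz}{\int_{b_i}^{b_{i+1}}(z-m_i)^2\,p_{\mathrm r}(z)\,dz},
\]
and $\int_{b_i}^{b_{i+1}}p_{\mathrm r}(z)\,dz=\mathfrak{F}_0(b_{i+1})-\mathfrak{F}_0(b_i)$.
```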

Plugging back, we obtain the approximation error as

$$
\begin{aligned}
\int_{b_i}^{b_{i+1}}\left(v(f_\theta(z)) - c_i f_\theta(z) - d_i\right)^2 p_{\mathrm{r}}(z)\,dz
&= \int_{b_i}^{b_{i+1}} v(f_\theta(z))\left(v(f_\theta(z)) - c_i f_\theta(z) - d_i\right) p_{\mathrm{r}}(z)\,dz \\
&= \int_{b_i}^{b_{i+1}} v^2(f_\theta(z))\,p_{\mathrm{r}}(z)\,dz
- \frac{\left(\int_{b_i}^{b_{i+1}} v(f_\theta(z))(z - m_i)\,p_{\mathrm{r}}(z)\,dz\right)^2}{\int_{b_i}^{b_{i+1}} (z - m_i)^2\,p_{\mathrm{r}}(z)\,dz}
- \frac{\left(\int_{b_i}^{b_{i+1}} v(f_\theta(z))\,p_{\mathrm{r}}(z)\,dz\right)^2}{\mathfrak{F}_0(b_{i+1}) - \mathfrak{F}_0(b_i)}.
\end{aligned}
\tag{49}
$$

Next, we assume that all intervals $[b_i, b_{i+1}]$ are short, of length $O(\Delta)$, and expand $v(f_\theta(z))$ in a Taylor series around $m_i$:

$$
v(f_\theta(z)) = v(f_\theta(m_i)) + \frac{\sum_{j=1}^{i} a_j}{N}\,v'(f_\theta(m_i))(z - m_i)
+ \frac{1}{2}\left(\frac{\sum_{j=1}^{i} a_j}{N}\right)^2 v''(f_\theta(m_i))(z - m_i)^2 + O(\Delta^3).
\tag{50}
$$

Here we used the fact that, on the interval $[b_i, b_{i+1}]$, $f_\theta(z)$ is a linear function with slope $\frac{1}{N}\sum_{j=1}^{i} a_j$. Substituting this expansion into eq. (48), we obtain

$$
\begin{aligned}
c_i f_\theta(z) + d_i ={}& \frac{\sum_{j=1}^{i} a_j}{N}\,v'(f_\theta(m_i))(z - m_i) + v(f_\theta(m_i)) \\
&+ \frac{1}{2}\left(\frac{\sum_{j=1}^{i} a_j}{N}\right)^2 v''(f_\theta(m_i))\,\frac{\int_{b_i}^{b_{i+1}} (z - m_i)^2\,p_{\mathrm{r}}(z)\,dz}{\mathfrak{F}_0(b_{i+1}) - \mathfrak{F}_0(b_i)} \\
&+ \frac{1}{2}\left(\frac{\sum_{j=1}^{i} a_j}{N}\right)^2 v''(f_\theta(m_i))\,\frac{\int_{b_i}^{b_{i+1}} (z - m_i)^3\,p_{\mathrm{r}}(z)\,dz}{\int_{b_i}^{b_{i+1}} (z - m_i)^2\,p_{\mathrm{r}}(z)\,dz}\,(z - m_i) + O(\Delta^3).
\end{aligned}
\tag{51}
$$

The first two terms are exactly the zeroth- and first-order Taylor terms of $v(f_\theta(z))$, as in classical linear approximation, where all higher-order terms are discarded. The remaining terms appear because the approximation is taken in the $L^2(p_{\mathrm{r}})$ sense. For the $L^2$ approximation error, we then have

$$
\begin{aligned}
&\left[\frac{1}{2}\left(\frac{\sum_{j=1}^{i} a_j}{N}\right)^2 v''(f_\theta(m_i))\right]^2 \\
&\quad\cdot\int_{b_i}^{b_{i+1}}\left((z - m_i)^2 - \frac{\int_{b_i}^{b_{i+1}} (z - m_i)^2\,p_{\mathrm{r}}(z)\,dz}{\mathfrak{F}_0(b_{i+1}) - \mathfrak{F}_0(b_i)} - \frac{\int_{b_i}^{b_{i+1}} (z - m_i)^3\,p_{\mathrm{r}}(z)\,dz}{\int_{b_i}^{b_{i+1}} (z - m_i)^2\,p_{\mathrm{r}}(z)\,dz}\,(z - m_i) + O(\Delta^3)\right)^2 p_{\mathrm{r}}(z)\,dz \\
={}& \left[\frac{1}{2}\left(\frac{\sum_{j=1}^{i} a_j}{N}\right)^2 v''(f_\theta(m_i))\right]^2 O\big((b_{i+1} - b_i)^5\big)\,p_{\mathrm{r}}(b_i) + O\big((b_{i+1} - b_i)^6\big)\,p_{\mathrm{r}}(b_i).
\end{aligned}
\tag{52}
$$

In summary, the $L^2$ approximation error is a sum over all intervals $[b_i, b_{i+1}]$. Each term depends on the weights $a_j$ through the slope factor $\frac{1}{N}\sum_{j=1}^{i} a_j$, on the biases through $(b_{i+1} - b_i)^5$, and on both through $v''(f_\theta(m_i))$, which also involves $a_i$ and $b_i$. ∎
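The $O((b_{i+1} - b_i)^5)$ scaling of the local error can be checked numerically. The plain-Python sketch below is an illustration under stated assumptions, not the authors' code: it takes $p_{\mathrm{r}}$ to be the standard Gaussian density, replaces $v(f_\theta(z))$ by the hypothetical smooth stand-in $g(z) = e^{z/2}$, computes the best $L^2(p_{\mathrm{r}})$ fit in $\mathrm{span}\{1, z - m\}$ on one interval with composite Simpson quadrature (with $m$ the $p_{\mathrm{r}}$-weighted mean, so that $1$ and $z - m$ are orthogonal), and compares the squared residual for two interval lengths. Halving the interval should divide the error by roughly $2^5 = 32$.

```python
import math

def gaussian_pdf(z):
    """Standard Gaussian density, used here as the reference density p_r."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def simpson_rule(a, b, n=200):
    """Composite Simpson nodes and weights on [a, b] (n must be even)."""
    h = (b - a) / n
    nodes = [a + k * h for k in range(n + 1)]
    weights = [h / 3.0 * (1 if k in (0, n) else 4 if k % 2 else 2) for k in range(n + 1)]
    return nodes, weights

def local_projection_error(g, a, b, pdf=gaussian_pdf, n=200):
    """Squared L2(pdf) residual on [a, b] of the best fit of g in span{1, z - m}."""
    z, w = simpson_rule(a, b, n)
    w = [wk * pdf(zk) for zk, wk in zip(z, w)]        # quadrature weights times density
    mass = sum(w)                                     # approximates F0(b) - F0(a)
    m = sum(wk * zk for zk, wk in zip(z, w)) / mass   # weighted mean: 1 and z - m are orthogonal
    gz = [g(zk) for zk in z]
    alpha = sum(wk * gk for wk, gk in zip(w, gz)) / mass
    beta = (sum(wk * gk * (zk - m) for zk, wk, gk in zip(z, w, gz))
            / sum(wk * (zk - m) ** 2 for zk, wk in zip(z, w)))
    # Residual evaluated directly (not via the subtracted form (49)) to avoid cancellation.
    return sum(wk * (gk - alpha - beta * (zk - m)) ** 2 for zk, wk, gk in zip(z, w, gz))

if __name__ == "__main__":
    g = lambda z: math.exp(0.5 * z)   # hypothetical stand-in for v(f_theta(z))
    c = 0.3                           # hypothetical interval center
    e_coarse = local_projection_error(g, c - 0.05, c + 0.05)
    e_fine = local_projection_error(g, c - 0.025, c + 0.025)
    print(e_coarse / e_fine)          # expected to be close to 2**5 = 32
```

Because the residual of the quadratic Taylor term is even about the interval center while the asymmetry of $p_{\mathrm{r}}$ and the cubic term are odd, the first-order corrections cancel, and the observed ratio should sit very close to 32 for small intervals.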

As a special case, consider the Fokker-Planck equation

$$
\partial_t p(t,x) - \nabla\cdot\big(p(t,x)\nabla V(x)\big) - \gamma\,\Delta p(t,x) = 0.
$$

Under the Wasserstein metric, the tangent vector in the mapping space is given by

$$
v(x) = -V'(x) - \gamma\,\frac{p'(t,x)}{p(t,x)}.
$$

Differentiating twice in $x$, we obtain

$$
v''(x) = -V^{(3)}(x) - \gamma\,\frac{p^{(3)}(t,x)\,p(t,x)^2 + 2p'(t,x)^3 - 3p(t,x)\,p'(t,x)\,p''(t,x)}{p(t,x)^3}.
$$

Evaluated at $x = f_\theta(m_i)$, this expression enters the local error (52) through $v''(f_\theta(m_i))$ and therefore determines the approximation quality of the projected dynamics.
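The closed form for $v''$ can be sanity-checked against a central finite difference. In the plain-Python sketch below, the test data are hypothetical: $V(x) = x^4/4$ and the positive profile $p(x) = 2 + \sin x$ at a fixed time (the formula is pointwise, so normalization plays no role). For a Gaussian $p$ and a quadratic $V$ (the Ornstein-Uhlenbeck case), $v$ is linear, so $v'' \equiv 0$ and the leading term in (52) vanishes; the closed form reproduces this.

```python
import math

def v_second_derivative(x, V3, p, dp, d2p, d3p, gamma):
    """Closed-form v''(x) for v(x) = -V'(x) - gamma * p'(x) / p(x), at a fixed time t."""
    P, P1, P2, P3 = p(x), dp(x), d2p(x), d3p(x)
    return -V3(x) - gamma * (P3 * P**2 + 2.0 * P1**3 - 3.0 * P * P1 * P2) / P**3

def v_finite_difference(x, dV, p, dp, gamma, h=1e-4):
    """Central second difference of v(x) = -V'(x) - gamma * p'(x) / p(x)."""
    v = lambda y: -dV(y) - gamma * dp(y) / p(y)
    return (v(x + h) - 2.0 * v(x) + v(x - h)) / h**2

if __name__ == "__main__":
    gamma = 0.5
    # Hypothetical test data: V(x) = x^4 / 4 and the positive profile p(x) = 2 + sin(x).
    p, dp = (lambda x: 2.0 + math.sin(x)), math.cos
    d2p, d3p = (lambda x: -math.sin(x)), (lambda x: -math.cos(x))
    x = 0.3
    print(v_second_derivative(x, lambda y: 6.0 * y, p, dp, d2p, d3p, gamma),
          v_finite_difference(x, lambda y: y**3, p, dp, gamma))
```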

Remark 1.

Higher-order neural mapping function classes and the associated higher-order projected dynamics can be derived by a similar procedure. For example, we can add a quadratic ReLU term to the network mapping function:

$$
f(\theta, z) = \frac{1}{N}\sum_{i=1}^{N}\left(a_i\,\sigma(z - b_i) + c_i\,\sigma^2(z - b_i)\right).
\tag{53}
$$

Note that adding higher-order ReLU terms differs from adding layers to the ReLU network, which corresponds to function composition. We leave a detailed analysis of, and numerical experiments with, higher-order methods to future work.
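To illustrate (53), here is a minimal plain-Python sketch that evaluates the quadratic-ReLU mapping next to the plain ReLU mapping; all parameter values are hypothetical. On each cell between consecutive biases, (53) is a quadratic polynomial with second derivative $\frac{2}{N}\sum_{b_i < z} c_i$, whereas the plain ReLU mapping is linear there, and setting all $c_i = 0$ recovers the ReLU mapping.

```python
def relu(x):
    return max(x, 0.0)

def f_quadratic_relu(a, b, c, z):
    """Mapping (53): (1/N) * sum_i [a_i relu(z - b_i) + c_i relu(z - b_i)^2]."""
    N = len(a)
    return sum(ai * relu(z - bi) + ci * relu(z - bi) ** 2 for ai, bi, ci in zip(a, b, c)) / N

def f_relu(a, b, z):
    """Plain two-layer ReLU mapping (1/N) * sum_i a_i relu(z - b_i), for comparison."""
    return sum(ai * relu(z - bi) for ai, bi in zip(a, b)) / len(a)

if __name__ == "__main__":
    a, b, c = [1.0, 2.0, 3.0], [-1.0, 0.0, 1.0], [0.5, -1.0, 2.0]   # hypothetical parameters
    z, h = 0.5, 1e-3
    curvature = (f_quadratic_relu(a, b, c, z + h) - 2 * f_quadratic_relu(a, b, c, z)
                 + f_quadratic_relu(a, b, c, z - h)) / h**2
    print(curvature)   # approximately 2 * (0.5 - 1.0) / 3 on the cell (0, 1)
```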

5. Numerical Examples

In this section, we present several numerical experiments to test our algorithm and theory. We focus on the linear transport equation, the Fokker-Planck equation, the porous medium equation, and the Keller-Segel equation, each of which corresponds to a specific energy functional on the probability space equipped with the Wasserstein-2 distance.

5.1. Neural Network structure

We first describe the neural network used in the experiments. We focus on a two-layer neural network with ReLU activation functions:

$$
f(\theta, z) = \sum_{i=1}^{N} a_i\cdot\sigma(z - b_i) + \sum_{i=N+1}^{2N} a_i\cdot\sigma(b_i - z).
\tag{54}
$$

Here $\theta \in \mathbb{R}^{4N}$ collects the weights $\{a_i\}_{i=1}^{2N}$ and the biases $\{b_i\}_{i=1}^{2N}$. To simplify notation, we have absorbed the factor $1/N$ into the $a_i$'s. At initialization, we set $a_i = 1/N$ for $i \in \{1, \ldots, N\}$ and $a_i = -1/N$ for $i \in \{N+1, \ldots, 2N\}$. To choose the $b_i$'s, we first set $\mathbf{b} = \mathrm{linspace}(-B, B, N)$ for some positive constant $B$ (e.g., $B = 4$ or $B = 10$). We then set $b_i = \mathbf{b}[i]$ for $i = 1, \ldots, N$ and $b_j = \mathbf{b}[j - N] + \varepsilon$ for $j = N+1, \ldots, 2N$. Here $\varepsilon = 5\times 10^{-6}$ is a small offset, explained in Section 5.3. With this initialization, $f(\theta, \cdot)$ approximates the identity map.

In practice, we find it beneficial to rescale the weights: we replace $a_i$ with $\overline{a}_i/\beta$ for some fixed constant $\beta > 0$, and initialize $\overline{a}_i = \beta/N$ for $i \in \{1, \ldots, N\}$ and $\overline{a}_i = -\beta/N$ for $i \in \{N+1, \ldots, 2N\}$. This rescaling ensures that $f(\theta, \cdot)$ still approximates the identity map at initialization. The intuition behind the rescaling is as follows. Consider $g(a, b, z) = a\cdot\sigma(b - z)$ with $b = \mathcal{O}(1)$, $z = \mathcal{O}(1)$, and $a = \mathcal{O}(1/N)$. Then $\partial_a g = \sigma(b - z) = \mathcal{O}(1)$, whereas $\partial_b g = a\cdot\sigma'(b - z) = \mathcal{O}(1/N)$. Thus the partial derivatives of (54) with respect to the weights and the biases are of different scales. After rescaling, $\partial_{\overline{a}} g = \sigma(b - z)/\beta$ while $\partial_b g = (\overline{a}/\beta)\,\sigma'(b - z) = \mathcal{O}(1/N)$, so a natural choice that puts both on the same scale is $\beta = \mathcal{O}(N)$.
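The identity-map claim can be checked directly. In the plain-Python sketch below (function names are hypothetical; $\beta = N$ is one choice consistent with $\beta = \mathcal{O}(N)$), each pair of units $i$ and $N+i$ contributes $(z - b_i)/N$ up to an error in $[-\varepsilon/N, 0]$, and the grid $\mathbf{b}$ is symmetric about $0$, so $|f(\theta, z) - z| \le \varepsilon$ for every $z \in \mathbb{R}$, including $z$ outside $[-B, B]$.

```python
def relu(x):
    return max(x, 0.0)

def init_params(N, B=4.0, eps=5e-6, beta=None):
    """Initialization described above for network (54); beta defaults to N (hypothetical choice)."""
    beta = float(N) if beta is None else beta
    grid = [-B + 2.0 * B * k / (N - 1) for k in range(N)]   # linspace(-B, B, N), needs N >= 2
    a_bar = [beta / N] * N + [-beta / N] * N                # rescaled weights a_bar_i
    b = grid + [g + eps for g in grid]                      # second half shifted by eps
    return a_bar, b, beta

def f_network(a_bar, b, beta, z):
    """Network (54) with weights a_i = a_bar_i / beta."""
    N = len(a_bar) // 2
    left = sum(a_bar[i] / beta * relu(z - b[i]) for i in range(N))
    right = sum(a_bar[i] / beta * relu(b[i] - z) for i in range(N, 2 * N))
    return left + right

if __name__ == "__main__":
    a_bar, b, beta = init_params(20)
    print(max(abs(f_network(a_bar, b, beta, 0.05 * k) - 0.05 * k) for k in range(-200, 201)))
```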

Remark 2.

The network (54) is slightly more complicated than the one studied in Section 4. We use this symmetric structure in the numerical experiments to overcome a drawback of ReLU, namely that only positive inputs are activated. Moreover, (54) makes it easy to construct an approximation to the identity map over $\mathbb{R}$. The results of Proposition 9 still hold for (54), and Theorem 3 generalizes to networks of the form (54) in a straightforward manner. The metric tensor $G_{\mathrm{W}}$ is now a $4N \times 4N$ matrix, whose individual entries are computed by the same procedure as in Proposition 4.

Remark 3.

Recall that our algorithm takes the form

$$
\theta^{k+1} = \theta^k - h\,G_{\mathrm{W}}(\theta^k)^{\dagger}\,\nabla_\theta \tilde{F}(\theta^k),
$$

where $\dagger$ denotes the Moore-Penrose pseudo-inverse.

In the implementation, $\nabla_\theta \tilde{F}(\theta)$ can be obtained by backpropagation through $\tilde{F}(\theta)$ for Examples 4 and 5. For Example 6, however, the partial derivatives $\partial_{b_i} \tilde{F}(\theta)$ require special attention; this is elaborated in Sections 5.3 and 5.4.
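To make the update concrete, here is a minimal NumPy sketch of one step with the biases held fixed (variant (i) in Section 5.2). Several ingredients are assumptions rather than the authors' implementation: $G_{\mathrm{W}}$ is taken to be the $L^2(p_{\mathrm{r}})$ Gram matrix $\mathbb{E}_{z\sim p_{\mathrm{r}}}[\nabla_a f\,\nabla_a f^{\top}]$ estimated from $M$ Monte Carlo samples; the pseudo-inverse is computed from an eigendecomposition with a relative cutoff (the Gram matrix of (54) is nearly singular); gradients are formed by hand rather than by backpropagation; and the potential is the hypothetical $V(x) = x^2/2$, whose exact Lagrangian map is $T(t, z) = z e^{-t}$. With fixed biases and this $V$, the projected update reduces to $f \leftarrow (1 - h) f$ on the samples.

```python
import numpy as np

def weight_features(z, b, N):
    """d f(theta, z) / d a_i for network (54) at samples z (biases b held fixed)."""
    left = np.maximum(z[:, None] - b[None, :N], 0.0)    # sigma(z - b_i), i = 1..N
    right = np.maximum(b[None, N:] - z[:, None], 0.0)   # sigma(b_i - z), i = N+1..2N
    return np.hstack([left, right])                      # shape (M, 2N)

def pinv_psd(G, rel_tol=1e-10):
    """Pseudo-inverse of a symmetric PSD matrix; eigenvalues below rel_tol * max are dropped."""
    w, U = np.linalg.eigh(G)
    keep = w > rel_tol * w.max()
    inv_w = np.zeros_like(w)
    inv_w[keep] = 1.0 / w[keep]
    return (U * inv_w) @ U.T

def natural_gradient_step(a, phi, dV, h):
    """a <- a - h * G_W^+ grad F~, with F~ = sample mean of V(f(theta, z)); weights only."""
    M = phi.shape[0]
    f = phi @ a                       # f(theta, z_m) for all samples
    G = phi.T @ phi / M               # Monte Carlo estimate of E[grad_a f grad_a f^T]
    grad = phi.T @ dV(f) / M          # Monte Carlo estimate of E[V'(f) grad_a f]
    return a - h * pinv_psd(G) @ grad

if __name__ == "__main__":
    N, B, eps, M = 10, 4.0, 5e-6, 20000
    rng = np.random.default_rng(0)
    z = rng.standard_normal(M)                         # samples from p_r
    grid = np.linspace(-B, B, N)
    b = np.concatenate([grid, grid + eps])
    a = np.concatenate([np.full(N, 1.0 / N), np.full(N, -1.0 / N)])
    phi = weight_features(z, b, N)
    h, K = 0.01, 100
    for _ in range(K):
        a = natural_gradient_step(a, phi, lambda x: x, h)   # V(x) = x^2 / 2, so V'(x) = x
    print(np.mean(np.abs(phi @ a - z * np.exp(-h * K))))     # small: mostly time-stepping error
```

Since the biases are fixed, the feature matrix could be precomputed together with $G_{\mathrm{W}}$ and its pseudo-inverse; the step function recomputes them only to mirror the update formula.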

5.2. Linear transport PDE

We investigate the linear transport PDE (14) with several choices of the potential $V(x)$; for each choice, (14) is the Wasserstein gradient flow of the associated potential energy. For a simple potential, this example serves as a sanity check of the projected dynamics formulation. In the Lagrangian formulation, the particle trajectories of (14) follow the ODE

$$
\dot{x}(t) = -\nabla V(x).
\tag{55}
$$
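In one dimension $\nabla V = V'$, and the solution of (55) with $x(0) = z_0$ gives the reference map $T(t, z_0)$ used in the error below. When no closed form is available, it can be approximated with a standard ODE solver. A minimal sketch with the classical fourth-order Runge-Kutta method (our choice of integrator, not necessarily the one used in the experiments), with $V'$ passed as a callable; for the hypothetical potentials $V(x) = x^2/2$ and $V(x) = x^4/4$ the exact maps are $z_0 e^{-t}$ and $z_0/\sqrt{1 + 2 z_0^2 t}$.

```python
import math

def transport_map(dV, z0, t, n_steps=1000):
    """Approximate T(t, z0) by solving x' = -V'(x), x(0) = z0, with classical RK4."""
    h = t / n_steps
    x = z0
    for _ in range(n_steps):
        k1 = -dV(x)
        k2 = -dV(x + 0.5 * h * k1)
        k3 = -dV(x + 0.5 * h * k2)
        k4 = -dV(x + h * k3)
        x += h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return x

if __name__ == "__main__":
    z0, t = 1.5, 1.0
    print(transport_map(lambda x: x, z0, t), z0 * math.exp(-t))                           # V = x^2/2
    print(transport_map(lambda x: x**3, z0, t), z0 / math.sqrt(1.0 + 2.0 * z0**2 * t))   # V = x^4/4
```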

Let $T(t, z_0)$ denote the solution of Eq. (55) with initial condition $x(0) = z_0$; in other words, $T(t, z_0)$ is the transport map at time $t$ evaluated at the starting position $z_0$. We define the error at time $t$ by

$$
\begin{aligned}
\mathrm{error} &= \int_{-\infty}^{\infty} \left|f(\theta_t, z_0) - T(t, z_0)\right| p_0(z_0)\,dz_0 \\
&\approx \frac{1}{N_1}\sum_{j=1}^{N_1} \left|f(\theta_t, z_j) - T(t, z_j)\right| p_0(z_j),
\end{aligned}
\tag{56}
$$

where we discretize the integration domain by $N_1$ equally spaced points $z_j$ to approximate the integral, and $p_0(z_0)$ denotes the initial distribution of $z_0$. Below we test our projected dynamics under three choices of potential functions and investigate the convergence behavior of two variants of the projected dynamics: (i) fixing the biases $b_i$ and updating only the weights $a_i$, and (ii) updating both the biases $b_i$ and the weights $a_i$. Note that when the biases $b_i$ are fixed, $G_{\mathrm{W}}\in\mathbb{R}^{2N\times 2N}$. Recall that we are essentially simulating the gradient flow on the parameter $\theta_t$ given by Eq. (11).
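For concreteness, the discrete error in (56) can be computed as below. This is a minimal NumPy sketch rather than the code used for the experiments: the function name `lagrangian_error` is hypothetical, and we assume that $p_0$ is the standard Gaussian density and that the mesh is equally spaced on $[-6,6]$, as in the sixth order example of Section 5.2.3.

```python
import math
import numpy as np

def lagrangian_error(f_learned, T_exact, z):
    """Discrete error of Eq. (56): (1/N1) * sum_j |f(theta_t, z_j) - T(t, z_j)| * p0(z_j).

    Assumption for this sketch: p0 is the standard Gaussian density.
    """
    p0 = np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)
    return np.mean(np.abs(f_learned(z) - T_exact(z)) * p0)

# Usage: exact map of the quadratic example (mu0 = 0, t = 1) versus a shifted copy.
z = np.linspace(-6.0, 6.0, 10_001)               # N1 = 10001 equally spaced points
T_exact = lambda s: math.exp(-1.0) * s
f_learned = lambda s: math.exp(-1.0) * s + 1e-3  # hypothetical learned map, off by 1e-3
err = lagrangian_error(f_learned, T_exact, z)
```

Because (56) averages over the $N_1$ mesh points instead of multiplying by the mesh spacing, on $[-6,6]$ it approximates the integral divided by 12 (here `err` is close to $10^{-3}/12$). A constant factor of this kind shifts the log-log error curves vertically but does not change their slopes.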
We use $M=5\times 10^{5}$ particles sampled from a standard Gaussian distribution for approximating $\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\big[V(f(\theta,\tilde{z}))\big]$. Once we have the empirical loss function

$$\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\big[V(f(\theta,\tilde{z}))\big]\approx\frac{1}{M}\sum_{i=1}^{M}V(f(\theta,z_i))\,,$$

we can backpropagate this loss to obtain

$$\mathbb{E}_{\tilde{z}\sim p_{\mathrm{r}}}\big[\nabla_{\theta}V(f(\theta,\tilde{z}))\big]\approx\frac{1}{M}\sum_{i=1}^{M}\nabla_{\theta}V(f(\theta,z_i))\,,$$

which will be used in the update of $\theta_t$ given by Eq. (11).
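The empirical loss and its gradient can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes the map $f(\theta,z)=\sum_{i=1}^{N}a_i\,\mathrm{ReLU}(z-b_i)+\sum_{i=N+1}^{2N}a_i\,\mathrm{ReLU}(b_i-z)$ with no extra output bias (consistent with the expression for $D_zf$ used in Section 5.3), and it applies the chain rule by hand instead of backpropagation. The names `features`, `network` and `mc_potential_grad`, the sample size and the potential are illustrative.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def features(b, z, N):
    """Columns ReLU(z - b_i), i < N, and ReLU(b_i - z), i >= N (hypothetical network form)."""
    zc = z[:, None]
    return np.concatenate([relu(zc - b[:N]), relu(b[N:] - zc)], axis=1)

def network(a, b, z, N):
    """f(theta, z) with theta = (a, b) and 2N neurons."""
    return features(b, z, N) @ a

def mc_potential_grad(a, b, z, N, dV):
    """Monte Carlo estimate of E_{z~p_r}[grad_theta V(f(theta, z))] via the chain rule."""
    zc = z[:, None]
    g = dV(network(a, b, z, N))[:, None]                       # V'(f(theta, z_m))
    df_da = features(b, z, N)                                  # df/da_i
    df_db = np.concatenate([-a[:N] * (zc > b[:N]),             # df/db_i, i < N
                            a[N:] * (zc < b[N:])], axis=1)     # df/db_i, i >= N
    return np.mean(g * df_da, axis=0), np.mean(g * df_db, axis=0)

# Usage with M = 20000 samples from p_r = N(0, 1) and an illustrative V(x) = (x - 1)^2 / 2.
rng = np.random.default_rng(0)
N = 4
a, b = rng.normal(size=2 * N), rng.normal(size=2 * N)
z = rng.standard_normal(20_000)
V = lambda x: 0.5 * (x - 1.0) ** 2
dV = lambda x: x - 1.0
ga, gb = mc_potential_grad(a, b, z, N, dV)

# Central finite differences of the same empirical loss, for comparison.
loss = lambda a_, b_: np.mean(V(network(a_, b_, z, N)))
h = 1e-7
fd_a = np.array([(loss(a + h * e, b) - loss(a - h * e, b)) / (2 * h) for e in np.eye(2 * N)])
fd_b = np.array([(loss(a, b + h * e) - loss(a, b - h * e)) / (2 * h) for e in np.eye(2 * N)])
```

The finite-difference arrays `fd_a` and `fd_b` are only a check; in practice an automatic differentiation framework would produce `ga` and `gb` directly from the empirical loss.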

5.2.1. Quadratic potential

As a first example of the linear transport PDE, we consider the quadratic potential $V(x)=\frac{1}{2}(x-\mu_0)^2$ as a sanity check. The stationary distribution is the Dirac mass at $\mu_0$. Using the method of characteristics, one can show that the solution at time $t>0$ is given by

$$p(t,x)=p_0\big((x-\mu_0)e^{t}+\mu_0\big)\,e^{t}\,,\qquad (57)$$

where $p_0(x)=p(0,x)$ is the initial distribution. In Lagrangian coordinates, the transport map of a point $z_0$ at time $t$ is given by

$$T(t,z_0)=\mu_0+e^{-t}(z_0-\mu_0)\,.\qquad (58)$$

One can check that $T(0,z_0)=z_0$ and $T(t,z_0)\to\mu_0$ as $t\to\infty$. It is worth mentioning that at each $t>0$, the Lagrangian map $x_t\colon z_0\mapsto T(t,z_0)$ is an affine map (linear once $\mu_0=0$). For simplicity, we take $\mu_0=0$, so that the stationary distribution is $\delta_0$. We choose $dt=10^{-3}$ and run for 1000 steps, i.e., up to $t=1$. We compare our numerical results with Eq. (58); the results are shown in Fig. 2. In Fig. 2(b), we visualize the analytic solution to the linear transport PDE in Lagrangian coordinates at $t=1$ together with our computed solution. As shown in Fig. 2(b), the analytic transport map is linear while the neural mapping function is piecewise linear. Increasing $N$ does not necessarily reduce the approximation error. In fact, Fig. 2(a) shows that larger $N$ usually gives a larger error, a behavior reminiscent of overfitting in machine learning.
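Eq. (58) can be checked against a direct time integration of the particle dynamics. The sketch below assumes, as stated in Section 5.2.3, that forward Euler applied to (55) is gradient descent, i.e., that (55) reads $\dot x=-\nabla V(x)$; the helper names `T_quadratic` and `forward_euler` are hypothetical, and this check does not involve the neural network.

```python
import math
import numpy as np

def T_quadratic(t, z0, mu0=0.0):
    """Exact Lagrangian map of Eq. (58): T(t, z0) = mu0 + exp(-t) (z0 - mu0)."""
    return mu0 + math.exp(-t) * (np.asarray(z0, dtype=float) - mu0)

def forward_euler(z0, grad_V, dt, n_steps):
    """Forward Euler for dx/dt = -V'(x), applied particle-wise (plain gradient descent)."""
    x = np.array(z0, dtype=float)
    for _ in range(n_steps):
        x = x - dt * grad_V(x)
    return x

z0 = np.linspace(-4.0, 4.0, 81)
# mu0 = 0 so V'(x) = x; dt = 1e-3 and 1000 steps give the final time t = 1
x_euler = forward_euler(z0, lambda x: x, dt=1e-3, n_steps=1000)
max_err = np.max(np.abs(x_euler - T_quadratic(1.0, z0)))
```

The discrepancy `max_err` is the $O(dt)$ error of forward Euler, $|(1-dt)^{1000}-e^{-1}|\,|z_0|$, which stays below $10^{-3}$ on $[-4,4]$.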

[Figure 2: (a) Error; (b) Mapping comparison]
Figure 2. Left: log-log plot of the error for the linear transport PDE with a quadratic potential. The y-axis represents the $\log_{10}$ error defined by (56), and the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=4$. The red line shows results when only the weights are updated; the black line shows results when both weights and biases are updated. Right: mapping comparison between $T(t,z)$ given by Eq. (58) and our computed solution $f(\theta_t,z)$.

5.2.2. Quartic potential

Let us consider $V(x)=(x-1)^4/4-(x-1)^2/2$. The analytic solution of the transport map is given by

$$T(t,z_0)=\begin{cases}\mathrm{sgn}(z_0-1)\,\dfrac{e^{t}}{\sqrt{(z_0-1)^{-2}+e^{2t}-1}}+1\,, & z_0\neq 1\,,\\[2mm] 1\,, & z_0=1\,.\end{cases}\qquad (59)$$

Basic settings are the same as in the previous case. We choose $dt=2\times 10^{-4}$ and run for 1000 steps, i.e., up to $t=0.2$. We compare our numerical results with Eq. (59) and present them in Fig. 3. In Fig. 3(a), we observe a clear decrease in error as the number of neurons grows. In Fig. 3(b), we visualize the analytic solution to the linear transport PDE in Lagrangian coordinates at $t=0.2$ together with our computed solution. Even though the optimal transport map is now nonlinear, our computed solution still matches the analytic solution closely.
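The closed form (59) can be checked in the same way, by integrating $\dot x=-\nabla V(x)$ with the step size and number of steps used here (again assuming that (55) has this gradient-flow form; the helper `T_quartic` is hypothetical). Substituting $y=x-1$ gives $\dot y=y-y^3$, and $u=y^{-2}$ then satisfies the linear ODE $\dot u=2-2u$, whose solution yields (59).

```python
import math
import numpy as np

def T_quartic(t, z0):
    """Exact map of Eq. (59) for V(x) = (x - 1)^4 / 4 - (x - 1)^2 / 2."""
    y0 = np.asarray(z0, dtype=float) - 1.0
    out = np.ones_like(y0)                      # z0 = 1 is a stationary point
    nz = y0 != 0.0
    out[nz] = 1.0 + np.sign(y0[nz]) * math.exp(t) / np.sqrt(y0[nz] ** -2 + math.exp(2 * t) - 1.0)
    return out

grad_V = lambda x: (x - 1.0) ** 3 - (x - 1.0)  # V'(x)
z0 = np.linspace(-3.0, 3.0, 61)
x = z0.copy()
for _ in range(1000):                          # forward Euler, dt = 2e-4, final time t = 0.2
    x = x - 2e-4 * grad_V(x)
max_err = np.max(np.abs(x - T_quartic(0.2, z0)))
```

Points starting near $z_0=1$ move very little, while points far from $1$ are pulled quickly toward the wells at $x=0$ and $x=2$, which is the nonlinear behavior visible in the mapping comparison.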

[Figure 3: (a) Error; (b) Mapping comparison]
Figure 3. Left: log-log plot of the error for the linear transport PDE with a quartic polynomial potential. The y-axis represents the $\log_{10}$ error defined by (56), and the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=4$. The red line shows results when only the weights are updated; the black line shows results when both weights and biases are updated. Right: mapping comparison between $T(t,z)$ given by Eq. (59) and our computed solution $f(\theta_t,z)$.

5.2.3. Sixth order polynomial potential

Let us consider $V(x)=(x-4)^6/6$. The analytic solution of the transport map is given by

$$T(t,z_0)=\begin{cases}4+\mathrm{sgn}(z_0-4)\,\dfrac{1}{\sqrt{2\sqrt{\frac{1}{4(z_0-4)^4}+t}}}\,, & z_0\neq 4\,,\\[2mm] 4\,, & z_0=4\,.\end{cases}\qquad (60)$$

We choose $dt=10^{-6}$ and run for 1000 steps, i.e., up to $t=10^{-3}$. The reason for such a small step size is that the ODE (55) is stiff when $V(x)$ is a sixth order polynomial. This can be readily seen by noting that the forward Euler scheme for (55) is exactly the popular gradient descent algorithm, for which the standard stability condition requires a step size below $2/L$, where $L$ is the Lipschitz constant of $\nabla V$. In our case, $\nabla V(x)=(x-4)^5$ is not globally Lipschitz. Even restricted to a fixed interval $(-l,l)$, its Lipschitz constant is $L=\sup_{x\in(-l,l)}V''(x)=5(l+4)^4$. If we take $l=10$, then $L=5\cdot 14^4\approx 1.9\times 10^{5}$, so stability already requires $dt\lesssim 2/L\approx 10^{-5}$. We compare our numerical results with Eq. (60). In Eq. (56), we take $\{z_j\}_{j=1}^{N_1}$ to be a uniform mesh of size $N_1=4\times 10^{6}$ on $[-6,6]$, and $p_0$ is the standard Gaussian distribution. Note that $N_1$ is chosen to be much larger than the number of neurons $N$ in the network mapping function, since it is used to evaluate the accuracy of our algorithm.
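The step-size estimate and the closed form (60) can both be checked numerically. The sketch below (the helper name `T_sixth` is hypothetical) evaluates $L=5(l+4)^4$ for $l=10$ and the bound $2/L$, then runs forward Euler on $\dot x=-(x-4)^5$, which we assume is the form of (55) here, with $dt=10^{-6}$ for 1000 steps. Setting $y=x-4$ and $u=y^{-4}$ gives $\dot u=4$, which is where (60) comes from.

```python
import math
import numpy as np

l = 10.0
L = 5.0 * (l + 4.0) ** 4        # sup of V''(x) = 5 (x - 4)^4 over (-l, l); equals 192080
max_step = 2.0 / L              # gradient-descent stability bound, about 1.04e-5

def T_sixth(t, z0):
    """Exact map of Eq. (60) for V(x) = (x - 4)^6 / 6."""
    y0 = np.asarray(z0, dtype=float) - 4.0
    out = np.full_like(y0, 4.0)                 # z0 = 4 is a stationary point
    nz = y0 != 0.0
    out[nz] = 4.0 + np.sign(y0[nz]) / np.sqrt(2.0 * np.sqrt(1.0 / (4.0 * y0[nz] ** 4) + t))
    return out

# Forward Euler (gradient descent) with dt = 1e-6 for 1000 steps, i.e. up to t = 1e-3.
z0 = np.linspace(-6.0, 6.0, 121)
x = z0.copy()
dt = 1e-6
for _ in range(1000):
    x = x - dt * (x - 4.0) ** 5
max_err = np.max(np.abs(x - T_sixth(1e-3, z0)))
```

At $z_0=-6$ the initial product $dt\cdot V''(z_0)=0.05$ is still well inside the stability region, whereas a step of $10^{-4}$ would already give $dt\cdot V''(z_0)=5>2$ there.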
We present our results in Fig. 4. From Fig. 4(a), we see a clear decrease in error as $N$ increases. Fig. 4(a) also shows that updating both weights and biases tends to give a smaller error than updating only the weights, although the difference shrinks as $N$ increases and more mesh points become available. Comparing the dashed and solid lines in Fig. 4(a), we find that the initialization of the $b_i$ also affects the overall performance of our solution. The error is smaller when the initial mesh points (i.e., the $b_i$'s) are more concentrated near the center of the reference measure. In our case, the reference measure is a standard Gaussian, which is “almost” supported on $[-4,4]$. Hence the solid lines ($B=4$) show a smaller error than the dashed lines ($B=10$) in Fig. 4(a). In Fig. 4(b), we visualize the analytic solution to the linear transport PDE in Lagrangian coordinates at $t=10^{-3}$ together with our computed solution. It is worth noting from Fig. 4(b) that our learned Lagrangian map approximates the analytic one well near the center of the reference distribution, which is concentrated near the origin. Even though the error of the learned Lagrangian map is larger outside of $[-4,4]$, the overall error (56) is still small, since the reference (standard Gaussian) measure of $\mathbb{R}\setminus[-4,4]$ is exponentially small.
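The claim that the reference measure outside $[-4,4]$ is exponentially small can be quantified with the complementary error function: for $Z\sim N(0,1)$, $P(|Z|>l)=\mathrm{erfc}(l/\sqrt{2})$. The short sketch below (the function name is ours) evaluates it for $l=4$ and for the mesh half-width $l=6$ used in (56).

```python
import math

def gaussian_tail_mass(l):
    """Standard Gaussian mass outside [-l, l], i.e. P(|Z| > l) = erfc(l / sqrt(2))."""
    return math.erfc(l / math.sqrt(2.0))

mass_outside_4 = gaussian_tail_mass(4.0)   # about 6.3e-5
mass_outside_6 = gaussian_tail_mass(6.0)   # about 2.0e-9
```

Hence the region outside $[-4,4]$ contributes to the integral in (56) at most about $6.3\times 10^{-5}$ times the largest pointwise map error there, and truncating the mesh to $[-6,6]$ discards only about $2\times 10^{-9}$ of the reference mass.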

Remark 4.

According to Proposition 9, updating both $a_i$ and $b_i$ yields a second order method. This can be seen from Fig. 4(a) when $N$ is small. When $N$ is large, the numerical advantage of updating both $a_i$ and $b_i$ over updating only $a_i$ becomes less significant. This is partially explained by the fact that the condition number of $G_{\mathrm{W}}(\theta)$ grows too large when $\theta$ contains all of the $a_i$ and $b_i$. This phenomenon is also observed in our other experiments. Using an implicit or proximal scheme (which avoids directly solving the linear system involving $G_{\mathrm{W}}(\theta)$) might help with this difficulty, which we leave for future study.
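The conditioning effect described in this remark can be illustrated with a small Monte Carlo experiment. The sketch rests on an assumption that is ours rather than a formula restated in this section: $G_{\mathrm{W}}(\theta)=\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[\nabla_\theta f(\theta,z)\,\nabla_\theta f(\theta,z)^{\top}\big]$, i.e., the $L^2(p_{\mathrm{r}})$ metric on mapping functions pulled back to the parameters, for the two-layer map $f=\sum_{i\le N}a_i\,\mathrm{ReLU}(z-b_i)+\sum_{i>N}a_i\,\mathrm{ReLU}(b_i-z)$. Because the $a$-only matrix is a principal submatrix of the $(a,b)$ matrix, Cauchy interlacing guarantees that adding the biases can never decrease the condition number; the code measures both. The helper name and all parameter values are illustrative.

```python
import numpy as np

def param_jacobian(a, b, z, N):
    """Rows are grad_theta f(theta, z_m) for theta = (a, b) and the hypothetical map
    f = sum_{i<N} a_i ReLU(z - b_i) + sum_{i>=N} a_i ReLU(b_i - z)."""
    zc = z[:, None]
    df_da = np.concatenate([np.maximum(zc - b[:N], 0.0), np.maximum(b[N:] - zc, 0.0)], axis=1)
    df_db = np.concatenate([-a[:N] * (zc > b[:N]), a[N:] * (zc < b[N:])], axis=1)
    return df_da, df_db

rng = np.random.default_rng(1)
N = 8
b = np.concatenate([np.linspace(-2.0, 2.0, N), np.linspace(-1.9, 2.1, N)])  # distinct biases
a = np.concatenate([np.full(N, 0.5), np.full(N, -0.5)])                      # increasing map
z = rng.standard_normal(200_000)                                             # z ~ p_r = N(0, 1)
df_da, df_db = param_jacobian(a, b, z, N)
J = np.hstack([df_da, df_db])
G_a = df_da.T @ df_da / z.size       # Monte Carlo G_W when theta = a only
G_ab = J.T @ J / z.size              # Monte Carlo G_W when theta = (a, b)
eig_a, eig_ab = np.linalg.eigvalsh(G_a), np.linalg.eigvalsh(G_ab)   # ascending order
cond_a, cond_ab = eig_a[-1] / eig_a[0], eig_ab[-1] / eig_ab[0]
```

The actual values of `cond_a` and `cond_ab` depend on the samples, the biases and $N$, so we do not quote numbers here; the interlacing inequalities on the extreme eigenvalues hold regardless.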

[Figure 4: (a) Error; (b) Mapping comparison]
Figure 4. Left: log-log plot of the error for the linear transport PDE with a sixth order polynomial potential. The y-axis represents the $\log_{10}$ error defined by (56), and the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=10$ for the dashed lines and $B=4$ for the solid lines. Red lines show results when only the weights are updated; black lines show results when both weights and biases are updated. Right: mapping comparison between $T(t,z)$ given by Eq. (60) and our computed solution $f(\theta_t,z)$.

5.3. Fokker-Planck Equation

We now consider Fokker-Planck equations. In general, there is no closed-form solution in either Eulerian or Lagrangian coordinates, except for some special potentials $V$ (e.g., quadratic). We can still approximate the analytic transport map by noting that the optimal transport map of a point $z_0$ at time $t$ is given by

$$T(t,z_0)=\mathfrak{F}_t^{-1}\big(\mathfrak{F}_0(z_0)\big)\,,\qquad (61)$$

where $\mathfrak{F}_t$ is the cumulative distribution function (CDF) of $p(t,x)$. $\mathfrak{F}_0$ has a closed-form expression when we choose our reference measure to be a standard Gaussian, but we still need to know $p(t,x)$. Therefore, to investigate the performance of our algorithm, we compute $p(t,x)$ with a numerical solver: a centered finite difference discretization in space and an implicit discretization in time, with vanishing boundary conditions. Recall that we are essentially simulating the gradient flow on the parameter $\theta_t$ given by Eq. (11) and Eq. (13). To calculate the derivatives of the energy functionals, we use $M=10^{6}$ particles sampled from a standard Gaussian distribution to approximate $\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[V(f(\theta,z))+\hat{U}\Big(\frac{p_{\mathrm{r}}(z)}{D_z f(\theta,z)}\Big)\Big]$.
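The reference map can be sketched as follows: solve for $p(t,x)$ on a grid with a centered, implicit scheme, then evaluate the one-dimensional monotone transport map $T(t,z_0)=\mathfrak{F}_t^{-1}(\mathfrak{F}_0(z_0))$ by inverting the discrete CDF. This is a minimal NumPy illustration, not the authors' solver. It assumes the equation takes the form $\partial_t p=\partial_x(p\,V')+D\,\partial_{xx}p$ with a hypothetical diffusion coefficient `D`, uses backward Euler as the implicit time discretization, and takes $\mathfrak{F}_0$ to be the standard Gaussian CDF; the names `solve_fokker_planck` and `reference_map` are illustrative.

```python
import math
import numpy as np

def solve_fokker_planck(p0, x, grad_V, D, dt, n_steps):
    """Hypothetical reference solver for dp/dt = d/dx(p V'(x)) + D d^2p/dx^2.

    Centered differences in space, backward Euler in time, p = 0 at both grid ends.
    """
    n, dx = x.size, x[1] - x[0]
    v = grad_V(x)
    L = np.zeros((n, n))
    for i in range(1, n - 1):
        # centered difference of the drift flux p * V'
        L[i, i + 1] += v[i + 1] / (2.0 * dx)
        L[i, i - 1] -= v[i - 1] / (2.0 * dx)
        # centered second difference for the diffusion term
        L[i, i + 1] += D / dx**2
        L[i, i - 1] += D / dx**2
        L[i, i] -= 2.0 * D / dx**2
    # Boundary rows of L are zero, so the boundary values stay at their initial value 0.
    step = np.linalg.inv(np.eye(n) - dt * L)   # computed once, since V does not depend on t
    p = p0.copy()
    p[0] = p[-1] = 0.0
    for _ in range(n_steps):
        p = step @ p
    return p

def gaussian_cdf(z):
    return np.array([0.5 * (1.0 + math.erf(v / math.sqrt(2.0))) for v in np.atleast_1d(z)])

def reference_map(z0, x, p_t):
    """T(t, z0) = F_t^{-1}(F_0(z0)), with F_0 the standard Gaussian CDF."""
    dx = x[1] - x[0]
    cdf = np.concatenate(([0.0], np.cumsum(0.5 * (p_t[1:] + p_t[:-1]) * dx)))  # trapezoid
    cdf /= cdf[-1]                               # renormalize the discrete CDF
    return np.interp(gaussian_cdf(z0), cdf, x)   # invert the monotone CDF by interpolation

# Usage: V(x) = (x - 1)^2 / 2, D = 1, standard Gaussian initial density, t = 1.
x = np.linspace(-8.0, 8.0, 401)
p0 = np.exp(-0.5 * x**2) / math.sqrt(2.0 * math.pi)
p1 = solve_fokker_planck(p0, x, lambda y: y - 1.0, D=1.0, dt=1e-3, n_steps=1000)
z0 = np.linspace(-2.0, 2.0, 9)
T1 = reference_map(z0, x, p1)
T1_exact = z0 + 1.0 - math.exp(-1.0)   # exact: p_t = N(1 - e^{-t}, 1) in this special case
```

In this quadratic test case the exact density stays Gaussian with mean $1-e^{-t}$ and unit variance, so the computed map should be close to $z_0+1-e^{-1}$; for general potentials no such closed form is available, which is the reason for the numerical reference.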
Approximating $\mathbb{E}_{z\sim p_{\mathrm{r}}}\big[\nabla_{\theta}V(f(\theta,z))\big]$ is straightforward and has been explained in detail in Section 5.2. On the other hand, some care is needed when approximating $\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_{\theta}\hat{U}\Big(\frac{p_{\mathrm{r}}(z)}{D_z f(\theta,z)}\Big)\Big]$, as explained in Section 4.2.1. Suppose that all of the $\{b_k\}_{k=1}^{2N}$ are distinct, and assume that they are ordered so that $b_1\le b_2\le\cdots\le b_N$. Then, for $2\le j\le N$, we have

$$\begin{aligned}\mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_j}\log\big(D_z f(\theta,z)\big) &= \mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_j}\log\Big(\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(z)-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(z)\Big)\\ &= p_{\mathrm{r}}(b_j)\log\left(\frac{\sum_{i=1}^{j-1}a_i-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(b_j)}{\sum_{i=1}^{j}a_i-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(b_j)}\right)\,.\end{aligned}\qquad (62)$$

For $j=1$, we have

$$\mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_1}\log\big(D_z f(\theta,z)\big)=p_{\mathrm{r}}(b_1)\log\left(\frac{-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(b_1)}{a_1-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(b_1)}\right)\,.\qquad (63)$$

Similarly, if we assume that $b_{N+1}\geq b_{N+2}\geq\cdots\geq b_{2N}$ and let $N+2\leq j\leq 2N$, we have

$$\begin{aligned}\mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_j}\log\big(D_z f(\theta,z)\big) &= \mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_j}\log\Big(\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(z)-\sum_{i=N+1}^{2N}a_i\mathbf{1}_{(-\infty,b_i]}(z)\Big)\\ &= p_{\mathrm{r}}(b_j)\log\left(\frac{\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(b_j)-\sum_{i=N+1}^{j}a_i}{\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(b_j)-\sum_{i=N+1}^{j-1}a_i}\right)\,.\end{aligned}\qquad (64)$$

For $j=N+1$, we have

$$\mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_{N+1}}\log\big(D_z f(\theta,z)\big)=p_{\mathrm{r}}(b_{N+1})\log\left(\frac{\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(b_{N+1})-a_{N+1}}{\sum_{i=1}^{N}a_i\mathbf{1}_{[b_i,\infty)}(b_{N+1})}\right)\,.\qquad (65)$$
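Formulas (62)-(65) only require the left and right limits of the piecewise-constant function $D_zf(\theta,\cdot)$ at each $b_j$. The sketch below evaluates these limits by comparing biases directly, and checks the result against a central finite difference of $\mathbb{E}_{z\sim p_{\mathrm{r}}}\log D_zf(\theta,z)$. For $p_{\mathrm{r}}=N(0,1)$ that expectation can be computed exactly, because $D_zf$ is constant between consecutive biases. It assumes distinct biases, $D_zf>0$ everywhere, and the same indicator convention as in (62); the function names and parameter values are illustrative.

```python
import math
import numpy as np

def phi(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)   # p_r = N(0, 1) density

def Phi(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))           # its CDF (also valid at +-inf)

def dzf_at(a, b, z, N):
    """D_z f(theta, z) = sum_{i<N} a_i 1[z >= b_i] - sum_{i>=N} a_i 1[z <= b_i]."""
    return (sum(a[i] for i in range(N) if z >= b[i])
            - sum(a[i] for i in range(N, 2 * N) if z <= b[i]))

def grad_b_expected_log(a, b, N):
    """p_r(b_j) * log(left limit / right limit) of D_z f at b_j, as in Eqs. (62)-(65)."""
    out = np.empty(2 * N)
    for j in range(2 * N):
        # contribution of all neurons except j at z = b_j (biases assumed distinct)
        base = (sum(a[i] for i in range(N) if b[i] < b[j])
                - sum(a[i] for i in range(N, 2 * N) if b[i] > b[j]))
        left, right = (base, base + a[j]) if j < N else (base - a[j], base)
        out[j] = phi(b[j]) * math.log(left / right)
    return out

def expected_log(a, b, N):
    """Exact E_{z~N(0,1)} log D_z f, summing over the intervals between sorted biases."""
    edges = [-math.inf] + sorted(b) + [math.inf]
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        zm = hi - 1.0 if lo == -math.inf else (lo + 1.0 if hi == math.inf else 0.5 * (lo + hi))
        total += math.log(dzf_at(a, b, zm, N)) * (Phi(hi) - Phi(lo))
    return total

N = 3
a = np.array([0.8, 1.2, 0.5, -0.7, -0.3, -0.9])     # chosen so that D_z f > 0 everywhere
b = np.array([-1.0, 0.2, 1.1, 1.5, 0.5, -0.4])      # pairwise distinct biases
g = grad_b_expected_log(a, b, N)
h = 1e-6
fd = np.array([(expected_log(a, b + h * e, N) - expected_log(a, b - h * e, N)) / (2 * h)
               for e in np.eye(2 * N)])
```

No sorting is needed in `grad_b_expected_log`: each limit is obtained by checking which other neurons are active just to the left and right of $b_j$.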

Note that in the implementation, the $b_j$'s need not be sorted to evaluate the above partial derivatives. Let $0<\delta\leq\frac{1}{2}\min_{i\neq j}|b_i-b_j|$. Then a straightforward calculation gives

$$\mathbb{E}_{z\sim p_{\mathrm{r}}}\partial_{b_j}\log(D_zf(\theta,z))=\begin{cases}p_{\mathrm{r}}(b_j)\log\left(\frac{D_zf(\theta,b_j-\delta)}{D_zf(\theta,b_j)}\right), & 1\leq j\leq N,\\ p_{\mathrm{r}}(b_j)\log\left(\frac{D_zf(\theta,b_j)}{D_zf(\theta,b_j+\delta)}\right), & N+1\leq j\leq 2N.\end{cases} \tag{66}$$

In our experiments, we set $\delta=\varepsilon/2$, where $\varepsilon$ is the small offset introduced in Section 5.1 during initialization.
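The one-sided formula (66) is straightforward to implement once $D_zf(\theta,\cdot)$ can be evaluated pointwise. The following Python sketch (standard library only) is an illustration under stated assumptions, not the authors' code: `dzf` and `p_r` are hypothetical callables standing in for $D_zf(\theta,\cdot)$ and $p_{\mathrm{r}}$, the biases are stored 0-based with the first $N$ entries forming the first group, and the check uses a hypothetical one-kink map $f(z)=a_0z+a_1\mathrm{ReLU}(z-b_1)$ with $\mathrm{ReLU}'(0):=1$ and a standard Gaussian reference, for which $\mathbb{E}_{z\sim p_{\mathrm{r}}}\log D_zf$ has a closed form. Only the first-group case is checked, since the second-group terms depend on the network layout of Section 5.1.

```python
import math


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def max_delta(b):
    """Largest admissible offset 0.5 * min_{i != j} |b_i - b_j| (needs >= 2 biases).

    Sorting is used only to find the minimum gap; Eq. (66) itself needs no ordering.
    """
    s = sorted(b)
    return 0.5 * min(s[k + 1] - s[k] for k in range(len(s) - 1))


def bias_grad_logdet(b, N, dzf, p_r, delta):
    """Eq. (66): E_{z~p_r} d/db_j log D_z f for each bias; b[:N] is the first group."""
    grads = []
    for j, bj in enumerate(b):
        if j < N:   # 1 <= j <= N in the paper's 1-based indexing
            ratio = dzf(bj - delta) / dzf(bj)
        else:       # N+1 <= j <= 2N
            ratio = dzf(bj) / dzf(bj + delta)
        grads.append(p_r(bj) * math.log(ratio))
    return grads


# Hypothetical one-kink map f(z) = a0*z + a1*ReLU(z - b1) with ReLU'(0) := 1,
# so D_z f = a0 to the left of b1 and a0 + a1 from b1 onward.
a0, a1, b1 = 1.0, 2.0, 0.3


def dzf(z):
    return a0 + (a1 if z >= b1 else 0.0)


def objective(b):
    """Closed form of E_{z ~ N(0,1)} log D_z f when the kink sits at b."""
    cdf = std_normal_cdf(b)
    return cdf * math.log(a0) + (1.0 - cdf) * math.log(a0 + a1)


g = bias_grad_logdet([b1], 1, dzf, std_normal_pdf, delta=1e-3)[0]
h = 1e-5
g_fd = (objective(b1 + h) - objective(b1 - h)) / (2.0 * h)  # central difference
```

In this check the one-sided value `g` matches the central finite difference `g_fd` of the closed-form objective; `max_delta` returns the largest admissible $\delta$, whereas the experiments use $\delta=\varepsilon/2$.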

5.3.1. Quadratic potential

As a first example of the Fokker-Planck equation, we use a quadratic potential $V(x)$ as a sanity check. This is one of the rare cases in which the Fokker-Planck equation has a closed-form analytic solution. In Lagrangian coordinates, the particle trajectories satisfy the following SDE, commonly known as the Ornstein-Uhlenbeck process:

$$dX_t=-\gamma_0(X_t-\mu_0)\,dt+\sigma_0\,dW_t. \tag{67}$$

The corresponding Fokker-Planck equation for the density $p(t,x)$ is given by

$$\frac{\partial p}{\partial t}=\gamma_0\frac{\partial}{\partial x}\big((x-\mu_0)p\big)+D\frac{\partial^2 p}{\partial x^2}, \tag{68}$$

where $D=\sigma_0^2/2$. It can be shown that the solution to (68) is given by

$$p(t,x)=\sqrt{\frac{\gamma_0}{2\pi D(1-\mathrm{e}^{-2\gamma_0 t})}}\int_{-\infty}^{\infty}\exp\Big(-\frac{\gamma_0}{2D}\frac{\big(x-\mu_0(1-\mathrm{e}^{-\gamma_0 t})-x'\mathrm{e}^{-\gamma_0 t}\big)^2}{1-\mathrm{e}^{-2\gamma_0 t}}\Big)p_0(x')\,\mathrm{d}x', \tag{69}$$

where $p_0(x)=p(0,x)$ is the initial distribution. In our experiments, $p_0$ is a standard Gaussian, so (69) implies that $p(t,x)$ is also Gaussian, with mean $\mu_0(1-e^{-\gamma_0 t})$ and variance $e^{-2\gamma_0 t}+\frac{D(1-e^{-2\gamma_0 t})}{\gamma_0}$. The transport map is therefore the optimal transport map between two Gaussians, which has a closed-form expression. In this example, it is

$$T(t,z)=\mu_0(1-e^{-\gamma_0 t})+z\sqrt{e^{-2\gamma_0 t}+D(1-e^{-2\gamma_0 t})/\gamma_0}, \tag{70}$$

which is always an affine map, regardless of the choice of $\mu_0$, $\gamma_0$ and $D$. We use $M=10^6$ particles sampled from a standard Gaussian distribution to approximate $\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_\theta V(f(\theta,z))+\nabla_\theta\hat{U}\big(\frac{p_{\mathrm{r}}(z)}{D_zf(\theta,z)}\big)\Big]$. We choose $dt=10^{-3}$ and run for 1000 steps, using a neural network with $m=32$ and $B=4$ following the setup in Section 5.1. We consider the following two choices of parameters, corresponding to different dynamics.

  • Moving and widening Gaussian. We choose $\gamma_0=1$, $\mu_0=30$, $\sigma_0=4$. Under this setting, the solution at time $t$ is a Gaussian distribution with mean $30(1-e^{-t})$ and variance $e^{-2t}+8(1-e^{-2t})$. This evolution is shown on the left panel of Fig. 5.

  • Moving and shrinking Gaussian. We choose $\gamma_0=1$, $\mu_0=10$, $\sigma_0=0.01$. Under this setting, the solution at time $t$ is a Gaussian distribution with mean $10(1-e^{-t})$ and variance $e^{-2t}+5\times10^{-5}(1-e^{-2t})$. This evolution is shown on the right panel of Fig. 5.
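The closed-form solution above can be checked numerically. The sketch below (Python, standard library only; the function names and the quadrature parameters `L` and `n` are illustrative assumptions) computes the mean and variance of $p(t,\cdot)$ for a standard Gaussian $p_0$, the affine map (70), and a trapezoidal-rule evaluation of the convolution formula using the Ornstein-Uhlenbeck transition density, whose mean is $\mu_0+(x'-\mu_0)e^{-\gamma_0 t}$ and whose variance is $D(1-e^{-2\gamma_0 t})/\gamma_0$.

```python
import math


def ou_moments(t, gamma0, mu0, sigma0):
    """Mean and variance of p(t, .) for p0 = N(0, 1), with D = sigma0**2 / 2."""
    D = 0.5 * sigma0 ** 2
    decay = math.exp(-gamma0 * t)
    mean = mu0 * (1.0 - decay)
    var = decay ** 2 + D * (1.0 - decay ** 2) / gamma0
    return mean, var


def transport_map(t, z, gamma0, mu0, sigma0):
    """Affine map of Eq. (70): pushes N(0, 1) forward to p(t, .)."""
    mean, var = ou_moments(t, gamma0, mu0, sigma0)
    return mean + z * math.sqrt(var)


def gaussian_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def density_by_quadrature(t, x, gamma0, mu0, sigma0, L=10.0, n=4001):
    """Trapezoidal rule on [-L, L] for the convolution of the OU transition
    density (mean mu0 + (x' - mu0) e^{-gamma0 t}) with p0 = N(0, 1)."""
    D = 0.5 * sigma0 ** 2
    decay = math.exp(-gamma0 * t)
    s = 1.0 - decay ** 2
    norm = math.sqrt(gamma0 / (2.0 * math.pi * D * s))
    h = 2.0 * L / (n - 1)
    total = 0.0
    for k in range(n):
        xp = -L + k * h
        weight = 0.5 if k in (0, n - 1) else 1.0  # trapezoid end-point weights
        shift = x - mu0 * (1.0 - decay) - xp * decay
        kernel = math.exp(-gamma0 / (2.0 * D) * shift ** 2 / s)
        total += weight * kernel * gaussian_pdf(xp, 0.0, 1.0)
    return norm * total * h


# The two parameter sets of the list above, (gamma0, mu0, sigma0), at t = 1.
widening = (1.0, 30.0, 4.0)
shrinking = (1.0, 10.0, 0.01)
mean_w, var_w = ou_moments(1.0, *widening)
mean_s, var_s = ou_moments(1.0, *shrinking)
```

For both parameter sets the computed moments reproduce the means and variances stated in the list, and the quadrature of the convolution agrees with the Gaussian density of the same mean and variance.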

Our results are shown in Fig. 5: the computed density closely follows the analytic density of the Fokker-Planck equation from $t=0$ to $t=1$.

[Figure 5 panels: (a) Moving, widening Gaussian; (b) Moving, shrinking Gaussian]
Figure 5. Density evolution of Eq. (68). The orange curve represents the solution given by Eq. (69). Blue rectangles represent the histogram of $10^6$ particles in 100 bins from $t=0$ to $t=1$. Left panel: a Gaussian distribution shifting to the right with increasing variance. Right panel: a Gaussian distribution shifting to the right with decreasing variance.

5.3.2. Quartic potential

We consider $V(x)=(x-1)^4/4-(x-1)^2/2$. We choose $dt=2\times10^{-4}$, run for 1000 steps, and compare our numerical results with the transport map computed from Eq. (61). The results are shown in Fig. 6. In Fig. 6(a), we observe a clear decrease in error as the number of neurons increases. In Fig. 6(b), we compare our computed Lagrangian map $f(\theta,z)$ with the transport map computed from Eq. (61) using a numerical solver. The evolution of the density from $t=0$ to $t=0.2$ is shown in Fig. 6(c).

[Figure 6 panels: (a) Error; (b) Mapping comparison; (c) Density evolution]
Figure 6. Left: log-log plot of the error for the Fokker-Planck equation with a quartic polynomial potential. The y-axis represents the $\log_{10}$ error defined by (56); the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=4$. Red lines represent results when only the weights are updated; black lines represent results when both weights and biases are updated. Middle: mapping comparison between $T(t,z)$ (using Eq. (61)) and our computed solution $f(\theta_t,z)$. Right: density evolution of the Fokker-Planck equation with a quartic polynomial potential. The orange curve represents the density $p(t,x)$ computed by a numerical solver; blue rectangles represent the histogram of $10^6$ particles in 100 bins from $t=0$ to $t=0.2$.

5.3.3. Sixth order polynomial potential

We consider $V(x)=(x-4)^6/6$. We choose $dt=10^{-6}$, run for 1000 steps, and compare our numerical results with the transport map computed from Eq. (61). The results are shown in Fig. 7. We observe behavior similar to that in the case of the linear transport PDE: the error decreases as $N$ increases. Moreover, comparing the dashed and solid lines in Fig. 7(a), we see that the errors are smaller when the initial mesh points (i.e., the $b_i$'s) are concentrated closer to the center of the reference measure. In Fig. 7(b), we compare the Lagrangian maps computed by our method and by the numerical solver. The evolution of the density from $t=0$ to $t=10^{-3}$ is shown in Fig. 7(c).

[Figure 7 panels: (a) Error; (b) Mapping comparison; (c) Density evolution]
Figure 7. Left: log-log plot of the error for the Fokker-Planck equation with a sixth order polynomial potential. The y-axis represents the $\log_{10}$ error defined by (56); the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=10$ for the dashed line and $B=4$ for the solid line. Red lines represent results when only the weights are updated; black lines represent results when both weights and biases are updated. Middle: mapping comparison between $T(t,z)$ (using Eq. (61)) and our computed solution $f(\theta_t,z)$. Right: density evolution of the Fokker-Planck equation with a sixth order polynomial potential. The orange curve represents the density $p(t,x)$ computed by a numerical solver; blue rectangles represent the histogram of $10^6$ particles in 100 bins from $t=0$ to $t=10^{-3}$.
[Figure 8 panels: (a) Error; (b) Mapping comparison; (c) Density evolution]
Figure 8. Left: log-log plot of the error for the porous medium equation. The y-axis represents the $\log_{10}$ error defined by (56); the x-axis represents $\log_{10}(N)$. The bias terms $b_i$ are initialized as in Section 5.1 with $B=3^{2/3}$. Red lines represent results when only the weights are updated; black lines represent results when both weights and biases are updated. Middle: mapping comparison between $T(t,z)$ (using Eq. (61)) and our computed solution $f(\theta_t,z)$. Right: density evolution of the porous medium equation. The orange curve represents the density $p(t,x)$ given by Eq. (72); blue rectangles represent the histogram of $10^6$ particles in 100 bins from $t=0$ to $t=1$.

5.4. Porous Medium Equation

We consider Example 6 with the functional $U(p(x))=\frac{1}{m-1}p(x)^m$, $m>1$. This choice of $U$ yields the porous medium equation

$$\partial_t p(t,x)=\Delta p(t,x)^m. \tag{71}$$

It is known that Eq. (71) admits solutions given by the Barenblatt profile

$$p(t,x)=(t_0+t)^{-\alpha}\Big(C-\beta\|x\|^2(t_0+t)^{-2\alpha/d}\Big)_+^{\frac{1}{m-1}}, \tag{72}$$

where $x\in\mathbb{R}^d$, $\alpha=\frac{d}{d(m-1)+2}$, $\beta=\frac{(m-1)\alpha}{2dm}$, $t_0>0$, and $C$ is a normalization constant chosen so that Eq. (72) integrates to 1 for all $t\geq0$. In this example, we consider $m=2$ and $d=1$. Then $\alpha=\frac{1}{3}$, $\beta=\frac{1}{12}$ and $C=\frac{3^{1/3}}{4}$. Eq. (72) also shows that the support of the reference measure $p_0(x)=p(0,x)$ is the interval $[-3^{2/3}t_0^{1/3},3^{2/3}t_0^{1/3}]$, which helps with initializing the bias.
More precisely, we initialize our $b_i$'s following Section 5.1 with $B=3^{2/3}t_0^{1/3}$. In our experiment, we set $t_0=1$. We use $dt=10^{-3}$ and run for 1000 steps. We use $M=10^6$ particles sampled from $p(0,x)$ defined in Eq. (72), via importance sampling, to approximate $\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\nabla_\theta\hat{U}\big(\frac{p_{\mathrm{r}}(z)}{D_zf(\theta,z)}\big)\Big]$, where $\hat{U}(p)=p$.
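The constants quoted above can be checked directly. The following Python sketch (standard library only; the helper names and the trapezoidal-rule resolution `n` are illustrative assumptions) evaluates Eq. (72) for $m=2$, $d=1$, confirms $\alpha=1/3$ and $\beta=1/12$ from their general formulas, and checks numerically that $C=3^{1/3}/4$ gives unit mass, with support edge $3^{2/3}(t_0+t)^{1/3}$; at $t=0$ and $t_0=1$ this edge is $3^{2/3}$, the value of $B$ used for the biases.

```python
import math

m, d = 2, 1                             # porous medium exponent and spatial dimension
alpha = d / (d * (m - 1) + 2)           # = 1/3
beta = (m - 1) * alpha / (2 * d * m)    # = 1/12
C = 3.0 ** (1.0 / 3.0) / 4.0            # normalization constant quoted for m = 2, d = 1


def barenblatt(t, x, t0=1.0):
    """Eq. (72) for m = 2, d = 1 (the exponent 1/(m-1) equals 1)."""
    s = t0 + t
    core = C - beta * x * x * s ** (-2.0 * alpha / d)
    return s ** (-alpha) * max(core, 0.0)


def support_radius(t, t0=1.0):
    """p(t, x) vanishes for |x| >= 3^(2/3) (t0 + t)^(1/3)."""
    return 3.0 ** (2.0 / 3.0) * (t0 + t) ** (1.0 / 3.0)


def mass(t, t0=1.0, n=4001):
    """Trapezoidal-rule approximation of the integral of p(t, .) over its support."""
    R = support_radius(t, t0)
    h = 2.0 * R / (n - 1)
    vals = [barenblatt(t, -R + k * h, t0) for k in range(n)]
    return h * (sum(vals) - 0.5 * (vals[0] + vals[-1]))
```

Because the profile is a parabola on its support, the trapezoidal error here is of order $10^{-7}$, so `mass(t)` returning a value close to 1 at different times is a meaningful check of the stated $C$.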
As in the case of the Fokker-Planck equation, special care is needed when evaluating $\partial_{b_i}\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\hat{U}\big(\frac{p_{\mathrm{r}}(z)}{D_zf(\theta,z)}\big)\Big]$. An analysis similar to that in Section 5.3 gives

$$\partial_{b_i}\mathbb{E}_{z\sim p_{\mathrm{r}}}\Big[\hat{U}\Big(\frac{p_{\mathrm{r}}(z)}{D_zf(\theta,z)}\Big)\Big]=\begin{cases}\frac{p_{\mathrm{r}}(b_i)^2}{D_zf(\theta,b_i-\delta)}-\frac{p_{\mathrm{r}}(b_i)^2}{D_zf(\theta,b_i)}, & 1\leq i\leq N,\\ \frac{p_{\mathrm{r}}(b_i)^2}{D_zf(\theta,b_i)}-\frac{p_{\mathrm{r}}(b_i)^2}{D_zf(\theta,b_i+\delta)}, & N+1\leq i\leq 2N.\end{cases} \tag{73}$$
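Eq. (73) has the same one-sided structure as Eq. (66), with $\log D_zf$ replaced by $p_{\mathrm{r}}/D_zf$. Below is a minimal Python sketch (standard library only) under stated assumptions, not the authors' code: `dzf` and `p_r` are hypothetical callables for $D_zf(\theta,\cdot)$ and $p_{\mathrm{r}}$, the first $N$ biases (0-based) form the first group, and the check uses a hypothetical one-kink map $f(z)=a_0z+a_1\mathrm{ReLU}(z-b_1)$ with $\mathrm{ReLU}'(0):=1$ and a standard Gaussian reference density (in the experiment above, the reference is the Barenblatt profile at $t=0$). For this map, $\mathbb{E}_{z\sim p_{\mathrm{r}}}[p_{\mathrm{r}}(z)/D_zf]=\frac{1}{2\sqrt{\pi}}\big[\Phi(\sqrt{2}b_1)/a_0+(1-\Phi(\sqrt{2}b_1))/(a_0+a_1)\big]$, where $\Phi$ is the standard normal CDF, so the first-group case can be compared with a finite difference.

```python
import math


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def bias_grad_porous(b, N, dzf, p_r, delta):
    """Eq. (73) with U_hat(p) = p; b is 0-based, entries b[:N] form the first group."""
    grads = []
    for i, bi in enumerate(b):
        w = p_r(bi) ** 2
        if i < N:   # 1 <= i <= N in the paper's indexing
            grads.append(w / dzf(bi - delta) - w / dzf(bi))
        else:       # N+1 <= i <= 2N
            grads.append(w / dzf(bi) - w / dzf(bi + delta))
    return grads


# Hypothetical one-kink map f(z) = a0*z + a1*ReLU(z - b1) with ReLU'(0) := 1.
a0, a1, b1 = 1.0, 2.0, 0.3


def dzf(z):
    return a0 + (a1 if z >= b1 else 0.0)


def objective(b):
    """Closed form of E_{z ~ N(0,1)}[p_r(z) / D_z f] when the kink sits at b."""
    cdf_val = std_normal_cdf(math.sqrt(2.0) * b)
    return (cdf_val / a0 + (1.0 - cdf_val) / (a0 + a1)) / (2.0 * math.sqrt(math.pi))


g = bias_grad_porous([b1], 1, dzf, std_normal_pdf, delta=1e-3)[0]
h = 1e-5
g_fd = (objective(b1 + h) - objective(b1 - h)) / (2.0 * h)  # central difference
```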

Our results are shown in Fig. 8. In Fig. 8(b) and 8(c), we use $N=32$ and update both the biases and the weights.

[Figure 9 panels: (a) $\chi=1.5$; (b) $\chi=0.5$]
Figure 9. Second moment comparison between our numerical solution and the analytic solution (75). The $x$-axis represents time.
[Figure 10 panels: (a) $\chi=1.5$; (b) $\chi=0.5$]
Figure 10. Density evolution of the Keller-Segel equation with different $\chi$. Blue rectangles represent the histogram of $10^6$ particles in 100 bins from $t=0$ to $t=0.3$.
[Figure 11 panels: (a) $\chi=1.5$; (b) $\chi=0.5$]
Figure 11. Lagrangian mapping of the Keller-Segel equation with different $\chi$ at $t=0.3$.

5.5. Keller-Segel equation

We consider the one-dimensional modified Keller-Segel equation, which combines the interaction energy of Example 5 with the internal energy of Example 6:

$$\partial_t p(t,x)=\nabla\cdot\big(p(t,x)\nabla(U'(p)+W\ast p)\big), \tag{74}$$

where $U(p)=p\log p$, $W(x)=2\chi\log|x|$, and $\chi>0$ is a constant. The second moment of $p(t,x)$ has an analytic form given by

$$\mathbb{E}_{z\sim p(t,\cdot)}[z^2]=\mathbb{E}_{z\sim p(0,\cdot)}[z^2]+2(1-\chi)t. \tag{75}$$

It is clear from Eq. (75) that $\chi=1$ is a critical value. When $\chi>1$, the second moment decreases linearly and would reach zero at $t^*=\mathbb{E}_{z\sim p(0,\cdot)}[z^2]/(2(\chi-1))$, so the solution blows up in finite time, no later than $t^*$ (for a standard Gaussian initial distribution and $\chi=1.5$, $t^*=1$, beyond the final time $t=0.3$ of our simulations). We therefore consider two cases, $\chi=1.5$ and $\chi=0.5$, and present our results in Figs. 9, 10 and 11. We use 2000 particles with a standard Gaussian initial distribution, set $dt=3\times10^{-4}$, and run for 1000 steps. The interaction term $W\ast p$ is evaluated using the 2000 particles with self-interaction excluded. We use a neural network with $N=32$ and $B=4$, following the setup and initialization in Section 5.1, and update both the bias and weight terms.
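Multiplying Eq. (74) by $x^2$ and integrating by parts gives $\frac{d}{dt}\mathbb{E}[z^2]=2-2\iint\frac{2\chi x}{x-y}p(x)p(y)\,dx\,dy=2(1-\chi)$, since symmetrizing in $x$ and $y$ shows that the double integral equals $\chi$. The Python sketch below (standard library only; all names and the particle count are illustrative, not the authors' code) evaluates the resulting linear second moment and the time at which it would vanish, and computes the particle estimate of $\partial_x(W\ast p)$ with self-interaction excluded. The same symmetrization shows that for $M$ distinct particles the average of $x_i\,\partial_x(W\ast p)(x_i)$ equals $\chi(M-1)/M$ exactly, so this estimator reproduces the rate $2(1-\chi)$ up to an $O(1/M)$ factor.

```python
import math
import random


def second_moment(t, chi, m0):
    """E_t[z^2] = m0 + 2 (1 - chi) t, where m0 is the initial second moment."""
    return m0 + 2.0 * (1.0 - chi) * t


def vanishing_time(chi, m0):
    """Time at which the linear second moment would hit zero (finite only if chi > 1)."""
    return m0 / (2.0 * (chi - 1.0)) if chi > 1.0 else math.inf


def interaction_gradient(xs, chi):
    """Particle estimate of d/dx (W * p)(x_i) with W(x) = 2 chi log|x|.

    Self-interaction (j == i) is excluded; the cost is O(M^2) and the
    particles must be distinct.
    """
    M = len(xs)
    grads = []
    for i, xi in enumerate(xs):
        acc = 0.0
        for j, xj in enumerate(xs):
            if j != i:
                acc += 2.0 * chi / (xi - xj)   # W'(x) = 2 chi / x
        grads.append(acc / M)
    return grads


# Illustration with M = 200 standard Gaussian particles (an arbitrary choice).
rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(200)]
chi = 1.5
grad = interaction_gradient(xs, chi)
# d/dt E[z^2] = 2 - 2 E[z * d/dz (W * p)]; the diffusion term contributes the 2.
rate = 2.0 - 2.0 * sum(x * gi for x, gi in zip(xs, grad)) / len(xs)
```

With $M=200$ and $\chi=1.5$, the particle rate is $2-2\chi\cdot199/200=-0.985$, close to the continuum value $2(1-\chi)=-1$.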

6. Discussion

This paper analyzes neural network projected dynamics for one-dimensional Wasserstein gradient flows of general energy functionals. For two-layer neural network functions with ReLU activations, we analyze the convergence and stability of the proposed numerical schemes with respect to the location parameters $b$ and the scale parameters $a$. In numerical experiments, we demonstrate second-order accuracy in the spatial domain, as discussed in the numerical analysis.

In future work, we shall study neural projected dynamics as a computational framework for building machine learning numerical schemes with theoretical guarantees. Several topics in this direction remain to be studied. First, we shall design neural network approximations for initial value problems of high-dimensional PDEs, which traditional PDE solvers cannot solve efficiently due to the curse of dimensionality. In particular, how can we understand the numerical accuracy of deep neural network functions when approximating PDEs in high dimensions? Next, we shall generalize the neural projected dynamics to dynamical systems for conservative-dissipative equations in statistical physics. These equations involve Hamiltonian structures induced by the conservative part of the system, as well as related mean-field control problems. Furthermore, considering the close relationship between the Wasserstein density manifold and sampling algorithms, we shall investigate sampling using the projected dynamics on neural parameter spaces and study its theoretical and statistical properties. We will also consider time-implicit (proximal-type) computations of the proposed algorithm [23, 19], which could improve the performance and stability of the scheme.

References

  • [1] Shun-ichi Amari. Natural gradient works efficiently in learning. Neural computation, 10(2):251–276, 1998.
  • [2] Shun-ichi Amari. Information geometry and its applications, volume 194. Springer, 2016.
  • [3] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media, 2005.
  • [4] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International Conference on Machine Learning, pages 214–223. PMLR, 2017.
  • [5] Andrew J Barlow, Pierre-Henri Maire, William J Rider, Robert N Rieben, and Mikhail J Shashkov. Arbitrary Lagrangian–Eulerian methods for modeling high-speed compressible multimaterial flows. Journal of Computational Physics, 322:603–665, 2016.
  • [6] Peter Benner, Serkan Gugercin, and Karen Willcox. A survey of projection-based model reduction methods for parametric dynamical systems. SIAM Review, 57(4):483–531, 2015.
  • [7] Joan Bruna, Benjamin Peherstorfer, and Eric Vanden-Eijnden. Neural Galerkin schemes with active learning for high-dimensional evolution equations. Journal of Computational Physics, 496:112588, 2024.
  • [8] Jose A Carrillo, Daniel Matthes, and Marie-Therese Wolfram. Lagrangian schemes for Wasserstein gradient flows. Handbook of Numerical Analysis, 22:271–311, 2021.
  • [9] JS Chang and G Cooper. A practical difference scheme for Fokker–Planck equations. Journal of Computational Physics, 6(1):1–16, 1970.
  • [10] Yifan Chen and Wuchen Li. Optimal transport natural gradient for statistical manifolds with continuous sample space. Information Geometry, 3(1):1–32, 2020.
  • [11] Casey Chu, Kentaro Minami, and Kenji Fukumizu. The equivalence between Stein variational gradient descent and black-box variational inference. arXiv preprint arXiv:2004.01822, 2020.
  • [12] Yifan Du and Tamer A Zaki. Evolutional deep neural network. Physical Review E, 104(4):045303, 2021.
  • [13] Jiaojiao Fan, Qinsheng Zhang, Amirhossein Taghvaei, and Yongxin Chen. Variational Wasserstein gradient flow. arXiv preprint arXiv:2112.02424, 2021.
  • [14] Nathan Gaby, Xiaojing Ye, and Haomin Zhou. Neural control of parametric solutions for high-dimensional evolution PDEs. arXiv preprint arXiv:2302.00045, 2023.
  • [15] Juncai He, Lin Li, Jinchao Xu, and Chunyue Zheng. ReLU deep neural networks and linear finite elements. Journal of Computational Mathematics, 38(3):502–527, June 2020.
  • [16] Ziqing Hu, Chun Liu, Yiwei Wang, and Zhiliang Xu. Energetic variational neural network discretizations to gradient flows. arXiv preprint arXiv:2206.07303, 2022.
  • [17] Weizhang Huang and Robert D Russell. Adaptive moving mesh methods, volume 174. Springer Science & Business Media, 2010.
  • [18] Wonjun Lee, Li Wang, and Wuchen Li. Deep JKO: time-implicit particle methods for general nonlinear gradient flows. arXiv preprint arXiv:2311.06700, 2023.
  • [19] Wuchen Li, Alex Tong Lin, and Guido Montúfar. Affine natural proximal learning. In Geometric Science of Information: 4th International Conference, GSI 2019, Toulouse, France, August 27–29, 2019, Proceedings 4, pages 705–714. Springer, 2019.
  • [20] Wuchen Li and Guido Montúfar. Natural gradient via optimal transport. Information Geometry, 1:181–214, 2018.
  • [21] Wuchen Li and Jiaxi Zhao. Scaling limits of the Wasserstein information matrix on Gaussian mixture models. arXiv preprint arXiv:2309.12997, 2023.
  • [22] Wuchen Li and Jiaxi Zhao. Wasserstein information matrix. Information Geometry, pages 1–53, 2023.
  • [23] Alex Tong Lin, Wuchen Li, Stanley Osher, and Guido Montúfar. Wasserstein proximal of GANs. In International Conference on Geometric Science of Information, pages 524–533. Springer, 2021.
  • [24] Chun Liu and Yiwei Wang. On Lagrangian schemes for porous medium type generalized diffusion equations: A discrete energetic variational approach. Journal of Computational Physics, 417:109566, 2020.
  • [25] Shu Liu, Wuchen Li, Hongyuan Zha, and Haomin Zhou. Neural parametric Fokker–Planck equation. SIAM Journal on Numerical Analysis, 60(3):1385–1449, 2022.
  • [26] Pierre-Henri Maire, Rémi Abgrall, Jérôme Breil, and Jean Ovadia. A cell-centered Lagrangian scheme for two-dimensional compressible flow problems. SIAM Journal on Scientific Computing, 29(4):1781–1824, 2007.
  • [27] Petr Mokrov, Alexander Korotin, Lingxiao Li, Aude Genevay, Justin M Solomon, and Evgeny Burnaev. Large-scale Wasserstein gradient flows. Advances in Neural Information Processing Systems, 34:15243–15256, 2021.
  • [28] Kirill Neklyudov, Rob Brekelmans, Alexander Tong, Lazar Atanackovic, Qiang Liu, and Alireza Makhzani. A computational framework for solving Wasserstein Lagrangian flows. arXiv preprint arXiv:2310.10649, 2023.
  • [29] Levon Nurbekyan, Wanzhou Lei, and Yunan Yang. Efficient natural gradient descent methods for large-scale PDE-based optimization problems. SIAM Journal on Scientific Computing, 45(4):A1621–A1655, 2023.
  • [30] N Nüsken. On the geometry of Stein variational gradient descent. Journal of Machine Learning Research, 24:1–39, 2023.
  • [31] Yann Ollivier. Riemannian metrics for neural networks I: feedforward networks. Information and Inference: A Journal of the IMA, 4(2):108–153, 2015.
  • [32] Lars Onsager. Reciprocal relations in irreversible processes. I. Phys. Rev., 37:405–426, Feb 1931.
  • [33] Felix Otto. The geometry of dissipative evolution equations: the porous medium equation. Communications in Partial Differential Equations, 26(1-2):101–174, 2001.
  • [34] Lars Ruthotto, Stanley J Osher, Wuchen Li, Levon Nurbekyan, and Samy Wu Fung. A machine learning framework for solving high-dimensional mean field game and mean field control problems. Proceedings of the National Academy of Sciences, 117(17):9183–9193, 2020.
  • [35] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021.
  • [36] Tao Tang. Moving mesh methods for computational fluid dynamics. Contemporary Mathematics, 383(8):141–173, 2005.
  • [37] Cédric Villani. Optimal Transport: Old and New, volume 338. Springer, 2009.
  • [38] Hao Wu, Shu Liu, Xiaojing Ye, and Haomin Zhou. Parameterized Wasserstein Hamiltonian flow. arXiv preprint arXiv:2306.00191, 2023.