
No Representation, No Trust: Connecting Representation, Collapse, and Trust Issues in PPO

Skander Moalla (CLAIRE, EPFL), skander.moalla@epfl.ch
Andrea Miele (CLAIRE, EPFL)
Razvan Pascanu (Google DeepMind)
Caglar Gulcehre (CLAIRE, EPFL), caglar.gulcehre@epfl.ch

Abstract

Reinforcement learning (RL) is inherently rife with non-stationarity since the states and rewards the agent observes during training depend on its changing policy. Therefore, networks in deep RL must be capable of adapting to new observations and fitting new targets. However, previous works have observed that networks in off-policy deep value-based methods exhibit a decrease in representation rank, often correlated with an inability to continue learning or a collapse in performance. Although this phenomenon has generally been attributed to neural network learning under non-stationarity, it has been overlooked in on-policy policy optimization methods which are often thought capable of training indefinitely. In this work, we empirically study representation dynamics in Proximal Policy Optimization (PPO) on the Atari and MuJoCo environments, revealing that PPO agents are also affected by feature rank deterioration and loss of plasticity. We show that this is aggravated with stronger non-stationarity, ultimately driving the actor’s performance to collapse, regardless of the performance of the critic. We draw connections between representation collapse, performance collapse, and trust region issues in PPO, and present Proximal Feature Optimization (PFO), a novel auxiliary loss, that along with other interventions shows that regularizing the representation dynamics improves the performance of PPO agents. Code and run histories are available at https://github.com/CLAIRE-Labo/no-representation-no-trust.

1 Introduction

Reinforcement learning (RL) agents are inherently subject to non-stationarity as the states and rewards they observe change during learning. Therefore, neural networks in deep RL must be capable of adapting to new inputs and fitting new targets. However, previous works have observed that value networks in off-policy value-based algorithms exhibit a decrease in the rank of their representations, termed feature rank, and a decrease in their ability to regress to arbitrary targets, called plasticity (Kumar et al., 2021; Lyle et al., 2022). Consequently, this decrease in representation fitness was often correlated with the value network’s inability to continue learning and adapting to new tasks, especially in sparse-reward scenarios. Although this phenomenon is more generally attributed to neural networks trained under non-stationarity (Lyle et al., 2023), it has been overlooked in on-policy policy optimization methods that are often thought capable of training indefinitely. In particular, Proximal Policy Optimization (PPO) (Schulman et al., 2017), one of the most popular policy optimization methods, introduces additional non-stationarity by making several minibatch updates over non-stationary data and by optimizing a surrogate loss depending on a moving policy. This raises the question of how much PPO agents are impacted by the same representation degradation attributed to non-stationarity. Igl et al. (2021) have shown that non-stationarity affects the generalization of PPO agents (learning speed when training episodes are very different otherwise performance at test time on novel episodes) but does not prevent training, and no connection was made with the feature rank and plasticity measures used in the recent value-based works.

In this work, we conduct an empirical study of representation dynamics in PPO, highlighting the detrimental effects of non-stationarity on policy optimization. Our contributions are the following:

  1. We provide the first study of feature rank and plasticity in policy optimization, revealing that PPO agents in the Arcade Learning Environment (Bellemare et al., 2013) and MuJoCo (Todorov et al., 2012) environments are subject to representation collapse.

  2. We draw connections between representation collapse, performance collapse, and trust region issues in PPO, showing that PPO’s clipping becomes ineffective under poor representations and fails to prevent performance collapse, which is irrecoverable due to loss of plasticity.

  3. We corroborate these connections by performing interventions that regularize non-stationarity and representations and result in a better trust region and mitigation of performance collapse, incidentally giving insights on sharing an actor-critic trunk.

  4. We propose Proximal Feature Optimization (PFO), a new regularization on the representation of the policy that regularizes the change in pre-activations. By addressing the representation issues, PFO can mitigate performance collapse and improve the agent’s performance.

  5. We open source our code and run histories, providing a comprehensive and reproducible codebase for studying representation dynamics in policy optimization and a large database of run histories with extensive logging for further investigation on this topic.

2 Background

Reinforcement Learning (Sutton & Barto, 2018)

We formalize our RL setting with the finite-horizon undiscounted Markov decision process, describing the interaction between an agent and an environment with finite¹ sets of states $\mathcal{S}$ and actions $\mathcal{A}$, and a reward function $r: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R}$. An initial state $S_0 \in \mathcal{S}$ is sampled from the environment, then at each time step $t \in \{0, \dots, t_{\max}-1\}$, the agent observes the state $S_t \in \mathcal{S}$, picks an action $A_t \in \mathcal{A}$ according to its policy $\pi: \mathcal{S} \to \Delta(\mathcal{A})$ with probability $\pi(A_t | S_t)$,² observes the next state $S_{t+1} \in \mathcal{S}$ sampled from the environment, and receives a reward $R_{t+1} \doteq r(S_t, A_t, S_{t+1})$. We denote by $G_t \doteq \sum_{k=t}^{t_{\max}-1} R_{k+1}$ the return after the action at time step $t$. The goal of the agent is to maximize its expected return $J(\pi) \doteq \mathbb{E}_{\pi}\left[\sum_{t=0}^{t_{\max}-1} R_{t+1}\right] = \mathbb{E}_{\pi}\left[G_0\right]$ over the induced random trajectories. We discuss the choice of this setting in Appendix A.1.

¹ The pixel-based environment with discrete actions used in our experiments and our simple theoretical example in Section 3.2.1 fit the finite state and action formalism but not our continuous action space environment. We refer the reader to Szepesvári (2022) for a formalism of RL in that setting.

² The time step $t$ is included in the representation of $S_t$ to preserve the Markov property in finite-horizon tasks as done by Pardo et al. (2018) and is analogous to considering time-dependent policies in the classical formulation of finite-horizon MDPs.
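
As a small illustration of the return definition above, the undiscounted returns $G_t$ of one episode can be computed with a reverse cumulative sum. This is only a sketch: the function name `returns_to_go` and the array layout (entry `t` holds $R_{t+1}$) are assumptions made for the example, not part of our codebase.

```python
import numpy as np

def returns_to_go(rewards):
    """Undiscounted returns G_t = sum_{k=t}^{t_max-1} R_{k+1} for one episode.

    Assumed layout: rewards[t] holds R_{t+1}, the reward received after
    the action taken at time step t.
    """
    rewards = np.asarray(rewards, dtype=float)
    # Reverse, accumulate, reverse back: G_t sums the rewards from step t onward.
    return np.cumsum(rewards[::-1])[::-1]

episode_rewards = [1.0, 0.0, 2.0, -1.0]  # illustrative values only
G = returns_to_go(episode_rewards)       # G[0] is the episode return G_0
```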

Actor-Critic Agent

We consider on-policy deep actor-critic agents which train a policy network $\pi(\cdot; \bm{\theta})$, called the actor and denoted $\pi_{\bm{\theta}}$, and a value network $\hat{v}(\cdot; \mathbf{w})$, called the critic, that approximates the return of $\pi_{\bm{\theta}}$ at every state (expected sum of rewards starting from a state). At every training stage, the agent collects a batch of samples, called a rollout, with its current policy $\pi_{\bm{\theta}}$, and both networks are trained with gradient descent on this data. The critic is trained to minimize the Euclidean distance to an estimator of the returns (e.g., $G_t$). We use $\lambda$-returns computed with the Generalized Advantage Estimator (GAE) (Schulman et al., 2015b). The actor is trained with Proximal Policy Optimization (PPO) (Schulman et al., 2017).
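
The sketch below illustrates how GAE advantages and the corresponding $\lambda$-returns could be computed for one complete episode. It is not taken from our implementation: the function name, the default `lam=0.95`, the choice `gamma=1.0` (matching the undiscounted formalism above), and the assumption that the episode terminates (so the value after the last step is zero) are all assumptions made for the example.

```python
import numpy as np

def gae_advantages(rewards, values, lam=0.95, gamma=1.0):
    """GAE advantages and lambda-returns for one complete episode.

    Assumed layout: rewards[t] holds R_{t+1}; values[t] holds the critic
    estimate of the value of S_t. The episode is assumed to terminate,
    so the value following the last step is taken to be 0.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    next_values = np.append(values[1:], 0.0)
    deltas = rewards + gamma * next_values - values  # one-step TD errors
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        # Exponentially weighted sum of future TD errors.
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    lambda_returns = advantages + values  # regression targets for the critic
    return advantages, lambda_returns

rewards = [1.0, 0.0, 2.0, -1.0]  # illustrative values only
values = [0.5, 0.2, 1.0, 0.0]
adv_mc, ret_mc = gae_advantages(rewards, values, lam=1.0)
adv_td, ret_td = gae_advantages(rewards, values, lam=0.0)
```

With `lam=1.0` and `gamma=1.0` the $\lambda$-returns reduce to the Monte Carlo returns $G_t$, and with `lam=0.0` they reduce to one-step bootstrapped targets.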

Proximal Policy Optimization

PPO-Clip, the most popular variant of PPO algorithms proposed by Schulman et al. (2017), optimizes the actor by maximizing the surrogate objective in Equation 1 at each rollout.

$$L^{CLIP}_{\pi_{\text{old}}}(\bm{\theta}) = \mathbb{E}_{\pi_{\text{old}}}\left[\sum_{t=0}^{t_{\max}-1} \min\left(\frac{\pi_{\bm{\theta}}(A_t|S_t)}{\pi_{\text{old}}(A_t|S_t)}\Psi_t,\ \text{clip}\left(\frac{\pi_{\bm{\theta}}(A_t|S_t)}{\pi_{\text{old}}(A_t|S_t)}, 1-\epsilon, 1+\epsilon\right)\Psi_t\right)\right] \tag{1}$$

The objective is defined for some small hyperparameter $\epsilon$, where $\pi_{\text{old}}$ is the last $\pi_{\bm{\theta}}$ of the previous optimization stage and is used to collect the training batch, and $\Psi_t$ is an estimator of the advantage of $\pi_{\text{old}}$ (e.g., $\Psi_t = G_t - \hat{v}(S_t; \mathbf{w}_{\text{old}})$). We use the GAE in our experiments. The objective is optimized with minibatch gradient steps over multiple epochs on the rollout data (batch). We refer to PPO-Clip as PPO and provide a high-level pseudocode in Algorithm 1.

Intuitively, PPO aims to maximize the policy advantage $\mathbb{E}_{\pi_{\text{old}}}\left[\sum_{t=0}^{t_{\max}-1} \frac{\pi_{\bm{\theta}}(A_t|S_t)}{\pi_{\text{old}}(A_t|S_t)}\Psi_t\right]$ defined by Kakade & Langford (2002), which participates in a lower bound to the improvement of $\pi_{\bm{\theta}}$ given that it is close to $\pi_{\text{old}}$ (Schulman et al., 2015a, see Theorem 1). In this regard, a gradient step on $L^{CLIP}_{\pi_{\text{old}}}(\bm{\theta})$ would increase (decrease) the probability of actions at states yielding positive (negative) advantage until the ratio between the policies for those actions reaches $1+\epsilon$ ($1-\epsilon$), at which point the gradient at those samples becomes null. This is a heuristic to ensure a trust region that keeps policies close to each other, resulting in policy improvement.
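
To make the clipping behaviour described above concrete, the following sketch evaluates the per-sample terms of the clipped surrogate and flags the samples whose gradient is null. It is an illustration under assumptions, not our implementation: the function name `ppo_clip_terms`, the use of log-probabilities as inputs, and the value `eps=0.2` are chosen for the example.

```python
import numpy as np

def ppo_clip_terms(logp_new, logp_old, advantages, eps=0.2):
    """Per-sample clipped surrogate terms and a mask of null-gradient samples."""
    ratio = np.exp(np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float))
    adv = np.asarray(advantages, dtype=float)
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * adv
    terms = np.minimum(unclipped, clipped)
    # The min picks the constant clipped branch (zero gradient w.r.t. theta) only when
    # the ratio has already passed the limit in the direction favoured by the advantage.
    null_grad = ((adv > 0) & (ratio > 1.0 + eps)) | ((adv < 0) & (ratio < 1.0 - eps))
    return terms, null_grad

# Illustrative ratios and advantages (made-up values).
ratios = np.array([1.0, 1.3, 0.7, 1.3, 0.7])
advs = np.array([1.0, 1.0, -1.0, -1.0, 1.0])
terms, null_grad = ppo_clip_terms(np.log(ratios), np.zeros(5), advs, eps=0.2)
surrogate = terms.mean()  # empirical estimate of the objective on this batch
```

Note that the last two samples are not clipped: their ratios lie outside $[1-\epsilon, 1+\epsilon]$ but on the side that the advantage penalizes, so the unclipped branch is selected and the gradient pushes the ratio back.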

Non-stationarity in deep RL and PPO

The actor and the critic networks are both subject to non-stationarity in deep RL. As the agent improves, it visits different states, shifting the distribution of states which makes the networks’ input distribution non-stationary. This also holds for the targets to fit the critic, which change as the policy returns change. This form of non-stationarity has been shown by previous work to hinder the ability of several deep RL agents (e.g., DQN (Mnih et al., 2015), SAC (Haarnoja et al., 2018)) to continue learning (Lyle et al., 2022; Nikishin et al., 2022). PPO introduces additional non-stationarity to the actor compared to policy gradient methods (such as vanilla policy gradient (Sutton et al., 1999), A2C (Mnih et al., 2016)) by making its objective non-stationary as it depends on the previous policy. In addition, the practical benefit of the PPO objective is that it can be optimized by performing multiple epochs of mini-batch gradient descent on every new collected batch; however, this makes the networks more likely to be impacted by previous training rollouts as new rollouts are collected. Increasing the number of epochs can cause the agent to “overfit” more to previous experience, making the impact of non-stationarity stronger.

Feature rank

As done in most works studying feature dynamics in deep RL (Lyle et al., 2022; Kumar et al., 2021), we refer to the activations of the last hidden layer of a network (the penultimate layer) as the features or representation learned by the network. On a batch of $N$ samples, this gives a matrix of dimension $N \times D$ denoted by $\Phi$, where $D < N$ is the width of the penultimate layer. Several measures of the rank of this matrix have been used to quantify the “quality” of the representation (Kumar et al., 2021; Gulcehre et al., 2022; Lyle et al., 2022; Andriushchenko et al., 2023). Their absolute values differ significantly, but their dynamics are often correlated. We track all of the different rank metrics in our experiments, compare them in Appendix E.1 and E, and use the approximate rank in our main figures for its connection to principal component analysis (PCA). Given a threshold $\delta \in \mathbb{R}$ and the singular values of the feature matrix $\langle \sigma_1(\Phi), \dots, \sigma_D(\Phi) \rangle$ in decreasing order, the approximate rank of $\Phi$ is $\min\left\{k : \frac{\sum_{i=1}^{k} \sigma_i^2(\Phi)}{\sum_{j=1}^{D} \sigma_j^2(\Phi)} > 1 - \delta\right\}$, which corresponds to the smallest dimension of the subspace recovering a fraction $1-\delta$ of the variance of $\Phi$. We use $\delta = 0.01$, i.e., the reconstruction recovers $99\%$ of the variance, as done by Andriushchenko et al. (2023) and Yang et al. (2020). We also refer to this metric as feature rank with reference to the rank of the feature matrix when there is no ambiguity.
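
The approximate rank defined above can be sketched in a few lines of NumPy. The function name `approximate_rank` and the convention of returning 0 for an all-zero feature matrix are assumptions made for this example; as in the formula, the features are not centered before computing singular values.

```python
import numpy as np

def approximate_rank(features, delta=0.01):
    """Smallest k such that the top-k squared singular values of the (N x D)
    feature matrix account for more than a fraction 1 - delta of their sum."""
    # np.linalg.svd returns singular values sorted in descending order.
    singular_values = np.linalg.svd(np.asarray(features, dtype=float), compute_uv=False)
    energy = singular_values ** 2
    total = energy.sum()
    if total == 0.0:
        return 0  # all-zero features; convention assumed for this sketch
    cumulative = np.cumsum(energy) / total
    # Index of the first entry exceeding 1 - delta, converted to a 1-based count.
    return int(np.argmax(cumulative > 1.0 - delta)) + 1

rng = np.random.default_rng(0)
# A 64 x 16 feature matrix built from only two directions (illustrative data).
low_rank = rng.normal(size=(64, 2)) @ rng.normal(size=(2, 16))
low_rank_estimate = approximate_rank(low_rank)  # at most 2, since the matrix has rank 2
```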

Plasticity loss

Plasticity is computed on checkpoints of a network undergoing some training to measure the evolution of its ability to fit targets. Given a fixed target and a fixed optimization budget (number of gradient steps), a checkpoint’s plasticity loss is the loss from fitting the checkpoint to the target at the end of the optimization budget. Usually, the plasticity of a deep RL agent throughout training is measured by its ability to fit the outputs of a model initialized from the same distribution as the agent, on a fixed rollout collected by this target random model (Lyle et al., 2022; Nikishin et al., 2023). The data would in expectation be from the same distribution as the agent’s initial checkpoint. To fit the critic we use an $L^2$ loss on the outputs of the models. To fit the actor we use a KL divergence between the target and the checkpoint (forward KL).

3 Deteriorating representations, collapse, and loss of trust

It is well-known that non-stationarity in deep RL can be a factor causing issues in representation learning. However, most of the observations have been made in value-based methods showing that value networks are prone to rank collapse, harming their expressivity, and in turn, the performance of the agent (Lyle et al., 2022; Kumar et al., 2022); Igl et al. (2021) studied non-stationarity in PPO but only showed that it harms its generalization, with no evidence of rank deterioration or performance collapse. Our motivation is to reuse the tools that showed that value-based methods are prone to representation collapse but in policy optimization methods for the first time. We focus on PPO for its popularity and extra non-stationarity through multi-epoch optimization.

Experimental setup

We begin our experiments by training PPO agents on the Arcade Learning Environment (ALE) (Bellemare et al., 2013) for pixel-based observations with discrete actions and on MuJoCo (Todorov et al., 2012) for continuous observations with continuous actions. To keep our experiments tractable, we choose the Atari-5 subset recommended by Aitchison et al. (2023) and add Gravitar to include at least one sparse-reward hard-exploration game from the taxonomy presented by Bellemare et al. (2016). For MuJoCo, we train on Ants, Half-Cheetahs, Humanoids, and Hoppers, which have varying complexity and observation and output sizes. We use the same model architectures and hyperparameters as popular implementations of PPO on ALE and MuJoCo (Raffin et al., 2021; Huang et al., 2022b); these are also the architectures and hyperparameters used by Schulman et al. (2017) in the original implementation of PPO, and they do not include normalization layers. The ALE model uses ReLU activations (Nair & Hinton, 2010) and the MuJoCo one tanh; we also experiment with ReLU on MuJoCo. We use separate actor and critic models for both environments unless specified in Section 4. Details on the performance metrics and tables of all environment parameters, model architectures, and algorithm hyperparameters are presented in Appendix B. Observing that the previous findings on the feature dynamics of value-based approaches (Gulcehre et al., 2022; Lyle et al., 2022) apply to the critic of PPO as well since the loss function is the same, we focus on studying the feature dynamics of the actor unless stated otherwise in the text or figures.

We vary the number of epochs as a way to control the effect of non-stationarity, which gives the agent a larger number of optimization steps per rollout while not changing the optimal target it can reach due to clipping, as opposed to changing the value of $\epsilon$ in the trust region, for example. We keep the learning rate constant throughout training and use the same value for all the epoch configurations. To understand the feature dynamics, we measure different metrics that are proposed in the literature, including feature rank, number of dead neurons (Gulcehre et al., 2022), plasticity loss (Lyle et al., 2022), and penultimate layer pre-activation norm. Previous work has monitored feature norm values as well (Lyle et al., 2024); however, in our case, we found that as the neurons in the policy network die, the feature norm might be stable while the pre-activation norm blows up. All the metrics are computed on on-policy rollouts except for the plasticity loss.

We run five seeds per hyperparameter configuration and report mean curves with min/max shaded regions unless specified otherwise. All curves, except for plasticity loss, are smoothed using an exponentially weighted moving average of coefficient $0.05$. We seek to answer the following questions:

  • Q1. How do a PPO agent’s representation metrics, such as the feature rank and the plasticity loss, evolve during training? Are they subject to the same decline observed by Kumar et al. (2021); Lyle et al. (2022) in value-based methods? Does it affect performance?

  • Q2. How does increasing the number of epochs per rollout to vary non-stationarity affect a PPO agent’s representation? Does it degrade as observed in DQN and SAC agents when increasing the replay ratio (Nikishin et al., 2022; Kumar et al., 2022)?
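
For reference, the curve smoothing mentioned in the experimental setup can be sketched as follows. The exact convention is an assumption here: we take the coefficient $0.05$ to be the weight `alpha` on the newest value and initialize the average at the first value; the function name `ewma` is hypothetical, and a non-empty input is assumed.

```python
import numpy as np

def ewma(values, alpha=0.05):
    """Exponentially weighted moving average of a non-empty 1-D sequence.

    Assumed convention: smoothed[t] = alpha * values[t] + (1 - alpha) * smoothed[t-1],
    initialized with smoothed[0] = values[0].
    """
    values = np.asarray(values, dtype=float)
    smoothed = np.empty_like(values)
    smoothed[0] = values[0]
    for t in range(1, len(values)):
        smoothed[t] = alpha * values[t] + (1.0 - alpha) * smoothed[t - 1]
    return smoothed

step_curve = ewma([0.0, 1.0, 1.0])  # illustrative input: a unit step
```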

3.1 PPO suffers from deteriorating representations

Q1. Deteriorating representation

As illustrated in Figure 1 with ALE/Phoenix as an example, we observe a consistent increase in the norm of the pre-activations of the feature layer of the policy network. Learning curves for all the ALE games and MuJoCo tasks considered can be found in Appendix D. The increase in feature norm is present in all the games/tasks considered in both environments, that is, with the two different model architectures and activation functions in the case of MuJoCo. We associate the rapid growth in the norm of the pre-activations with an eventual decline in the policy network’s feature rank. We observe a rank decline in five out of six ALE games and seven out of eight MuJoCo tasks (four with ReLU and three with tanh). The same observations about the increasing norm of the pre-activations can be made about the critic network. However, its rank varies more with the sparsity of the reward. In most environments, its rank experiences a significant deterioration after the policy’s performance (not its rank) declines and rewards become sparser. In the case of the sparse-reward game Gravitar, the critic’s rank collapses before the policy. Furthermore, plasticity loss is increasing for the critic, as observed in value-based plasticity studies, and it is also the case for the actor, for which it explodes around rank collapse.

Q2. Worse consequences

Increasing the replay ratio in DQN and SAC deteriorates the agent’s representation and, in turn, its performance (Kumar et al., 2022; D’Oro et al., 2023). This is commonly attributed to “overfitting” to previous experience (Nikishin et al., 2022). Increasing the number of epochs in PPO is analogous, and a natural hypothesis is that this would accelerate the deterioration of the policy’s representation. Figure 1 shows that increasing the number of epochs accelerates the increase of pre-activations norm and the decrease of the policy’s feature rank. In some cases, the rank eventually collapses, which coincides with a collapse in the policy’s performance. We observe the performance collapse in three of the six ALE games and three of the four MuJoCo tasks.

Figure 1: Deteriorating performance and representation metrics. The policy network of a PPO-Clip agent on ALE/Phoenix-v5 is subject to a deteriorating representation. The norm of the pre-activations of the penultimate layer consistently increases, and its rank eventually decreases. Performing more optimization epochs per rollout to increase the effects of non-stationarity accelerates the growth of the norm of the pre-activations and the collapse of its rank. This ultimately leads to the collapse of the policy. This collapse is not driven by the value network, whose rank is still high.

Figure 2: Rank collapse gives a high but trivial entropy. The rank collapse of the policy network gives a policy with high entropy but zero variance across states. The network outputs the same high-entropy action distribution in all states, as all the neurons in the feature layer are dead. Its output only depends on the constant bias term.

Characterizing the collapse

The collapse observed in the policy’s representation differs from the typical entropy collapse. As one can observe from Figure 2, the collapse in representation gives a policy that has high entropy (still lower than a uniform policy and its initialization). However, in MDPs, entropy is computed as the average across states. A high overall entropy can come from an average of high-entropy states with different action distributions or trivially from the same high-entropy distribution in all states. In our case, it is the latter: we compute the policy variance across states and observe that it collapses to zero, indicating that the policy network outputs the same action distribution at every state, which is often close to uniform. This is consistent with the policy representation, which has collapsed to a layer with mostly dead neurons.³ The output layer gives the output distribution based on only the bias term, as the linear weights are multiplied by a null feature vector. Such a policy takes the same actions across all states; therefore, performance for non-trivial tasks would collapse.

³ We consider a ReLU neuron as dead when its values are zero for all the samples in the batch and a tanh neuron dead when its standard deviation across samples is less than 0.001.
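
The following sketch illustrates the dead-neuron criteria and why a fully dead feature layer yields the same action distribution in every state. The helper names, the toy layer sizes, and the particular variance measure (variance of each action probability across states, summed over actions) are assumptions made for the example and may differ from the metric we log.

```python
import numpy as np

def dead_neuron_mask(features, activation="relu", tol=1e-3):
    """Boolean mask over the D neurons of an (N x D) batch of feature-layer outputs."""
    features = np.asarray(features, dtype=float)
    if activation == "relu":
        # Dead if the output is zero for every sample in the batch.
        return np.all(features == 0.0, axis=0)
    if activation == "tanh":
        # Dead if the output barely varies across the samples of the batch.
        return features.std(axis=0) < tol
    raise ValueError(f"unsupported activation: {activation}")

def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
features = np.zeros((32, 8))       # toy case: every ReLU neuron is dead on 32 states
weights = rng.normal(size=(8, 4))  # output layer of a 4-action policy (made-up sizes)
bias = rng.normal(size=4)

# With null features, the logits equal the bias in every state.
probs = softmax(features @ weights + bias)
policy_variance = probs.var(axis=0).sum()
```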

3.2 Collapsed representations create trust issues and unrecoverable loss

Intuitively, the heuristic trust region set by PPO-Clip should prevent sudden catastrophic changes and, therefore, limit the rank collapse, which induces worse performance. However, empirically, it seems the trust region is not able to mitigate the collapse. In this section, we seek to understand the interaction between the rank collapse and the trust region. We argue that as rank collapses, the clipping constraint becomes unreliable and unable to restrict learning. This is in line with previous works that have pointed out that probability ratios during training can go beyond the clipping limits with PPO-Clip (Engstrom et al., 2020; Wang et al., 2020; Sun et al., 2022). We believe, however, that this behavior is systematic when rank collapses and does not merely happen occasionally.

Wang et al. (2020, Theorem 2) state that when the gradients of the unclipped samples align with the gradients of clipped samples, the clipped samples’ ratios will have their probabilities continue to go beyond the clip limit. They claim this condition would hold in practice because of “optimization tricks” or optimizer accumulated moments; however, there is no evidence that these factors induce the gradient alignment or that the alignment is present in practice. Our intuition is that representation degradation leads to alignment in the gradients and, therefore, a breakdown of the trust region constraint. This can create a snowball effect, making PPO-Clip unable to prevent representation collapse. We summarize this in two observations:

1. Loss of trust is extreme around poor representations

The average of probability ratios outside the clipping limits (below $1-\epsilon$ in Figure 3) significantly diverges from the clipping limit around the collapse of the agent’s representation. This gives one more reason why the PPO trust region can be violated. We isolate this in a toy setting and analyze it formally in the next section. We further show in Figure 4 scatter plots of the lowest average probability ratios in runs with their associated representation metrics. Each run contributes 20 points, each being the average over a window of 1% of training progress; windows span at least the horizon of the environment, so that the points are well spaced in the run. We observe no significant correlation in the regions where the representation is rich (high rank, low pre-activation norm), but an apparent decrease of the average of probability ratios below $1-\epsilon$ as the representation reaches poor values.
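
A minimal sketch of the statistic discussed above, the average probability ratio among samples that fell below the lower clipping limit, is given here. The function name, the log-probability inputs, the value `eps=0.2`, and the choice of returning NaN when no ratio is below the limit are assumptions made for illustration.

```python
import numpy as np

def mean_ratio_below_clip(logp_new, logp_old, eps=0.2):
    """Average probability ratio over the samples whose ratio is below 1 - eps.

    Returns NaN when no ratio is below the limit (nothing to report).
    """
    ratio = np.exp(np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float))
    below = ratio[ratio < 1.0 - eps]
    return float(below.mean()) if below.size else float("nan")

# Illustrative ratios (made-up values): two of them violate the lower limit 0.8.
ratios = np.array([1.0, 0.5, 0.7, 1.4])
violation_level = mean_ratio_below_clip(np.log(ratios), np.zeros(4), eps=0.2)
```

A value far below $1-\epsilon$ indicates that the violating samples moved well past the limit, which is the behaviour we associate with poor representations.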

Figure 3: Focusing on individual runs. Individual training curves on ALE/NameThisGame-v5 with different epochs per batch. Extremely low ratios are observed around the representation collapse of a PPO-Clip agent, implying that the heuristic trust region breaks down when representation power is lacking. The last-minibatch value of the PPO objective decreases towards 0 around the representation collapse, implying a reduced ability to improve the policy and recover, which is corroborated by the increase of the plasticity loss. (The ratios are trivially within the $1-\epsilon$ region after a collapse as the collapsed model does not change much, so no value below $1-\epsilon$ is logged.)
Figure 4: Representation vs. trust region. Samples from ALE/Phoenix-v5 training curves. Each point maps an average of the probability ratios below the clipping limit vs. its corresponding average representation metric (dead neurons, feature rank, feature norm). The average ratios are significantly lower around poor representations (high dead neurons, low policy rank, high feature norm), reflecting the failure of the trust region. Averages are over non-overlapping windows larger than episodes.

2. Loss of plasticity renders performance collapse unrecoverable

The monotonic decrease in performance overlaps with a monotonic decrease in policy variance and PPO objective. It appears that as the policy loses its ability to distinguish states, it can also ascend the PPO objective less and less at each batch (recall: after collecting a batch, the loss starts around zero with a normalized advantage, and through minibatch updates, the clipped policy advantage is ascended). Intuitively, this is implied by a loss of plasticity or a collapse in entropy (no new actions to learn from). As seen in Section 2, the entropy does not collapse, and measuring the plasticity loss in Figure 3 shows that the decrease in objective gain is associated with a significant increase in plasticity loss.
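The remark that the objective starts around zero can be checked with a minimal sketch of the standard PPO-Clip surrogate. This is an illustration under stated assumptions, not the released implementation: the function names, the value `epsilon=0.1`, and the sample numbers are hypothetical.

```python
def clip(value, low, high):
    """Clamp value to the interval [low, high]."""
    return max(low, min(value, high))


def ppo_clip_objective(ratios, advantages, epsilon=0.1):
    """Batch mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    terms = [
        min(r * a, clip(r, 1.0 - epsilon, 1.0 + epsilon) * a)
        for r, a in zip(ratios, advantages)
    ]
    return sum(terms) / len(terms)


def normalize(values):
    """Zero-mean, unit-variance normalization of a batch of advantages."""
    mean = sum(values) / len(values)
    std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5
    return [(v - mean) / (std + 1e-8) for v in values]


# Hypothetical batch of raw advantages.
advantages = normalize([2.0, -1.0, 0.5, 3.5])

# Right after collection the policy equals the old policy, so every ratio is 1
# and the objective is the mean of the normalized advantages, i.e. about zero.
objective_at_collection = ppo_clip_objective([1.0] * 4, advantages)

# After some updates, ratios rise where A > 0 and fall where A < 0.
objective_after_updates = ppo_clip_objective([1.05, 0.8, 0.95, 1.3], advantages)
```

The gain `objective_after_updates - objective_at_collection` is the quantity that shrinks towards zero around collapse: a policy that can no longer move its ratios in the direction of the advantages stays near the starting value.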

Connecting the dots

Hence, around collapse, the representation of the policy is getting so poor that it is impacting its ability to distinguish and act differently across states; the trust region cannot prevent this sudden and catastrophic change as it also breaks down with a poor representation; finally, the policy’s plasticity is also becoming so poor that the agent cannot recover by optimizing the surrogate objective.

3.2.1 A toy setting to understand the effects of rank collapse on trust region

We present a toy example that illustrates how a collapsed representation bypasses the clipping set by PPO and cannot satisfy the trust region it seeks to set. PPO constructs a trust region around the policy $\pi_{\bm{\theta}}(\cdot|s)$ of the agent evaluated at a given state $s$, enforcing (in an approximate way) that the update computed on state $s$ cannot move the policy $\pi_{\bm{\theta}}(\cdot|s)$ outside of the trust region. However, the constraint does not capture how updates computed on another state $s'$ affect the policy’s probability distribution over the current state $s$. The underlying assumption is that updates computed on different states are, at least in expectation, approximately orthogonal to each other, and they do not interact. Therefore, restricting the update of the current state is sufficient to keep the policy within the region.

In our case, however, one can show that as the rank collapses or the neurons die, the representations corresponding to different states become more colinear (the expected angle between representations shrinks to 0). Therefore, the gradients also become more colinear. In the extreme case, when the rank collapses to 1, or there is only one neuron alive, all representations are exactly colinear; therefore, all gradients are also. This means that even though clipping prevents the policy $\pi_{\bm{\theta}}(\cdot|s)$ on the current state $s$ from changing due to the update of that state, $\nabla L(\pi_{\bm{\theta}}(\cdot|s))$, the policy $\pi_{\bm{\theta}}(\cdot|s)$ will still change and move outside of the trust region due to the updates on other states $s'$. This leaves the trust region constraint ineffective: it does not constrain the learning process in any meaningful sense. This gives a clear situation where the theorem of Wang et al. (2020) holds and which can easily be analyzed as below, without resorting to the theorem, either for an end-to-end proof or to get a better intuition.

Formal statement of the toy setting

Let us consider a batch containing two state-action pairs $(x,a_1)$ and $(y,a_1)$ with sampled probabilities $\pi_{\text{old}}(a_1|x)$ and $\pi_{\text{old}}(a_1|y)$ and positive estimated advantages $A(x,a_1), A(y,a_1) > 0$. Let $\phi(x),\phi(y)\in\mathbb{R}$ be fixed 1-dimensional representations of $x$ and $y$ that can be seen as the output of the (frozen) penultimate layer of a policy network with collapsed representation (all but one dead neuron), and let $\alpha\in\mathbb{R}$ be such that $\phi(y)=\alpha\phi(x)$. Let $\bm{\theta}=[\theta_1,\theta_2]$ be the last layer of the network, computing the logits of two possible actions, $a_1$ and $a_2$, that are then fed into a softmax to compute the probabilities, i.e., $\pi_{\bm{\theta}}(a_i|s)=\frac{e^{\theta_i\phi(s)}}{e^{\theta_1\phi(s)}+e^{\theta_2\phi(s)}}$. Consider PPO minibatch updates alternating between $(x,a_1)$ and $(y,a_1)$. Ideally, the PPO loss increases $\pi_{\bm{\theta}}(a_1|s)$ with gradient steps on $(x,a_1)$ until it reaches the clip ratio, and similarly on $(y,a_1)$. However, we show in Appendix C that a gradient step on $(x,a_1)$ also affects $\pi_{\bm{\theta}}(a_1|y)$ and, depending on $\alpha$, will increase it past its clipped ratio or decrease it below its initial value.

Figure 5: Simulation of the toy setting. Left ($\alpha>0$): a gradient on $(x,a_1)$ takes the probability of $(y,a_1)$ up and vice versa. When one is above the threshold and should not increase, the other still pushes it. Right ($\alpha<0$): a gradient on $(x,a_1)$ takes the probability of $(y,a_1)$ down and vice versa. Both slow each other down, with one forcing the other to be lower than its initial value.
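The cross-state effect can be seen from a one-step calculation. The following is an illustrative sketch, not a reproduction of the Appendix C proof; it assumes a plain gradient-ascent step with a hypothetical step size $\eta>0$ on the unclipped sample $(x,a_1)$, and writes $z(s)=(\theta_1-\theta_2)\phi(s)$ for the logit difference and $\sigma$ for the logistic function.

$$
\begin{aligned}
\pi_{\bm{\theta}}(a_1|s) &= \sigma(z(s)), \qquad z(s) = (\theta_1-\theta_2)\,\phi(s),\\
\nabla_{\theta_1}\log\pi_{\bm{\theta}}(a_1|x) &= \phi(x)\left(1-\pi_{\bm{\theta}}(a_1|x)\right) = -\nabla_{\theta_2}\log\pi_{\bm{\theta}}(a_1|x),\\
\Delta\theta_1 = -\Delta\theta_2 &= \eta\,A(x,a_1)\,\frac{\pi_{\bm{\theta}}(a_1|x)}{\pi_{\text{old}}(a_1|x)}\,\phi(x)\left(1-\pi_{\bm{\theta}}(a_1|x)\right),\\
\Delta z(y) = (\Delta\theta_1-\Delta\theta_2)\,\phi(y) &= 2\eta\,A(x,a_1)\,\frac{\pi_{\bm{\theta}}(a_1|x)}{\pi_{\text{old}}(a_1|x)}\left(1-\pi_{\bm{\theta}}(a_1|x)\right)\alpha\,\phi(x)^2 .
\end{aligned}
$$

Every factor multiplying $\alpha$ in the last line is positive when $A(x,a_1)>0$ and $\phi(x)\neq 0$, so under these assumptions the step on $(x,a_1)$ moves the logit difference at $y$, and hence $\pi_{\bm{\theta}}(a_1|y)$, in the direction given by the sign of $\alpha$, whether or not the sample $(y,a_1)$ is itself clipped.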

Essentially, when $\alpha \geq 0$, a gradient on $(x,a_1)$ increases $\theta_1^{\text{new}}$, therefore increasing both $\pi_{\bm{\theta}^{\text{new}}}(a_1|x)$ and $\pi_{\bm{\theta}^{\text{new}}}(a_1|y)$. The same holds for a gradient on $(y,a_1)$, so one state reaches the clip limit first (which one depends on whether $\alpha > 1$) while the updates on the other state keep pushing its probability upwards.
However, when $\alpha \leq 0$, a gradient on $(x,a_1)$ increases $\theta_1^{\text{new}}$, therefore increasing $\pi_{\bm{\theta}^{\text{new}}}(a_1|x)$ but decreasing $\pi_{\bm{\theta}^{\text{new}}}(a_1|y)$. For a gradient on $(y,a_1)$ it is the opposite: $\theta_1^{\text{new}}$ decreases, therefore $\pi_{\bm{\theta}^{\text{new}}}(a_1|x)$ decreases and $\pi_{\bm{\theta}^{\text{new}}}(a_1|y)$ increases. Each state thus reduces the probability of the other, and depending on whether $\alpha < 1$, one of the probabilities will dominate and push the other one down.
Figure 5 shows the evolution of the probabilities when simulating the updates empirically.
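The toy setting is small enough to simulate in a few lines. The sketch below is a hypothetical re-implementation and not the code used for Figure 5: it assumes $\phi(x)=1$, an initial $\bm{\theta}=[0,0]$, unit advantages, plain gradient ascent with learning rate 0.01, and $\epsilon=0.1$.

```python
import math


def prob_a1(theta, phi):
    """Softmax probability of a1 for logits (theta[0] * phi, theta[1] * phi)."""
    return 1.0 / (1.0 + math.exp(-(theta[0] - theta[1]) * phi))


def ppo_clip_grad(theta, phi, pi_old, advantage, epsilon):
    """Gradient w.r.t. theta of the clipped surrogate for one (state, a1) sample."""
    pi = prob_a1(theta, phi)
    ratio = pi / pi_old
    # The clipped term is flat once the ratio has left the trust region
    # in the direction favoured by the advantage, so the gradient is zero.
    if advantage > 0 and ratio > 1.0 + epsilon:
        return [0.0, 0.0]
    if advantage < 0 and ratio < 1.0 - epsilon:
        return [0.0, 0.0]
    g = advantage * ratio * phi * (1.0 - pi)
    return [g, -g]


def simulate(alpha, steps=200, lr=0.01, epsilon=0.1):
    """Alternate gradient-ascent steps on (x, a1) and (y, a1); return final ratios."""
    phi = {"x": 1.0, "y": alpha}      # phi(y) = alpha * phi(x) with phi(x) = 1
    theta = [0.0, 0.0]                # assumed initialization
    pi_old = {s: prob_a1(theta, phi[s]) for s in phi}
    for step in range(steps):
        state = "x" if step % 2 == 0 else "y"
        grad = ppo_clip_grad(theta, phi[state], pi_old[state], 1.0, epsilon)
        theta = [t + lr * g for t, g in zip(theta, grad)]
    return {s: prob_a1(theta, phi[s]) / pi_old[s] for s in phi}


ratios_aligned = simulate(alpha=2.0)    # updates on x keep pushing y upwards
ratios_opposed = simulate(alpha=-2.0)   # updates on y push x below its start
```

With these assumed values, `ratios_aligned["y"]` ends well above $1+\epsilon$ because the unclipped updates on $x$ keep raising it after its own gradient has been clipped, while `ratios_opposed["x"]` ends below 1 even though its advantage is positive.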

4 Intervening to regularize representations and non-stationarity

Having observed that PPO is affected by a consistent representation degradation that impacts its trust region heuristic and causes its performance to collapse, we turn to study interventions that aim at regularizing the representation of the policy network or reducing the non-stationarity in the optimization. We investigate whether these interventions improve the representation metrics we track and if, in turn, this affects performance. We choose simple interventions that do not apply modifications to the models during training (e.g., resetting or adding neurons) or require significantly more memory (e.g., maintaining separate copies of the models). We perform interventions on the games/tasks where the collapse is the most significant. We are interested in the state of the agent at the end of the training budget. We record the performance and representation metrics for each run as averages over the last 5% of training progress. We measure the excess ratio at a timestep as the average probability ratio above $1+\epsilon$ divided by the average probability ratio below $1-\epsilon$ at that timestep. This metric gives an idea of how much the policy exceeds the trust region. Its average value is computed over the last 5% of training progress where the ratios are non-trivial. When there is no collapse, this is the same window at the end of training as for the other metrics; otherwise, it is a window covering 5% of training progress before total collapse, since after collapse the model does not change anymore and the ratios are trivially within the $1+\epsilon$ and $1-\epsilon$ limits. We give additional details on the computation of these aggregate metrics and the interventions performed in Appendix B.
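The excess ratio defined above can be computed from the probability ratios logged at one timestep. The following is a minimal sketch; the function name, the default `epsilon=0.1`, and the choice to return `None` when either side of the trust region is empty are assumptions, and Appendix B remains the reference for the exact aggregation.

```python
def excess_ratio(ratios, epsilon=0.1):
    """Mean ratio above 1 + eps divided by mean ratio below 1 - eps.

    Returns None when no ratio lies above the upper limit or none lies below
    the lower limit (assumed convention for the trivial, within-limits case).
    """
    above = [r for r in ratios if r > 1.0 + epsilon]
    below = [r for r in ratios if r < 1.0 - epsilon]
    if not above or not below:
        return None
    mean_above = sum(above) / len(above)
    mean_below = sum(below) / len(below)
    return mean_above / mean_below


# Hypothetical ratios for one minibatch: two above 1.1, two below 0.9.
example = excess_ratio([1.0, 1.5, 0.5, 1.3, 0.7], epsilon=0.1)
```

A value close to $(1+\epsilon)/(1-\epsilon)$ means the ratios barely leave the trust region, while larger values indicate ratios far beyond the clipping limits on one or both sides.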

PFO: Regularizing features to mitigate trust issues

The motivation for our first intervention and our proposed regularization method comes from our observation that the norm of the pre-activation features is consistently increasing, which can be linked to the trust issues discussed in Section 3. We seek to mitigate this effect in a way that is analogous to the PPO trust region between the optimized policy and the policy that collected the batch. We apply an $L^2$ loss on the difference between the pre-activated features of the optimized policy and the policy that collected the batch. This can also be thought of as a way to keep the pre-activations of the network during an update within a trust region. We apply this regularization to the pre-activations and not the activations, as dead neurons cannot propagate gradients, and even when they do, depending on the activation function, do so with a low magnitude. The regularization is an additional loss/penalty added to the overall loss. We term this loss the Proximal Feature Optimization (PFO) loss.

$$L_{\pi_{\text{old}}}^{PFO}(\bm{\theta}) = \mathbb{E}_{\pi_{\text{old}}}\left[\sum_{t=0}^{t_{\max}-1}\left(\phi_{\bm{\theta}}(S_t)-\phi_{\text{old}}(S_t)\right)^{2}\right] \tag{2}$$

We apply two versions of PFO: one on only the penultimate layer’s pre-activations and one on all the pre-activations up to the penultimate layer. In the scope of this work, we do not tune the coefficient of PFO; we pick the closest power of 10 that brings the magnitude of this loss close to that of the clipped PPO objective tracked in the experiments without intervention. This gives a coefficient of 1 for ALE, 1 for MuJoCo with tanh, and 10 for MuJoCo with ReLU. The goal is not necessarily to obtain better performance but to see if PFO improves the representations learned by PPO and if, in turn, it affects its performance. As shown in Figure 6, the regularization of PFO effectively brings the norm of the pre-activations down, the number of dead neurons down, the plasticity loss down, and the rank up. This coincides with a significant decrease in the excess probability ratio, especially in the upper tail. More importantly, we also see a significant increase in the lower tail of the returns, where no collapse in performance is observed anymore on ALE/NameThisGame and ALE/Phoenix, with a slight increase in the upper tail showing that PFO can increase performance. Among the interventions we have tried, PFO provided the most consistent improvements across the tasks and metrics we monitored.
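The PFO penalty of Equation (2) can be sketched on plain Python lists as follows. This is an illustration only: the function name and data layout are hypothetical, the squared difference is summed over feature dimensions and averaged over the states of a minibatch (an assumed estimator of the expectation), and a practical implementation would operate on tensors of an automatic differentiation framework with the old pre-activations held constant.

```python
def pfo_loss(pre_acts_new, pre_acts_old, coefficient=1.0):
    """Squared difference between current and old pre-activations.

    pre_acts_new / pre_acts_old: one entry per regularized layer; each entry
    is a list over the states of the minibatch, and each state holds that
    layer's pre-activation vector. Passing only the penultimate layer gives
    the first PFO variant; passing every layer gives the second.
    """
    total = 0.0
    for layer_new, layer_old in zip(pre_acts_new, pre_acts_old):
        per_state = [
            sum((n - o) ** 2 for n, o in zip(feat_new, feat_old))
            for feat_new, feat_old in zip(layer_new, layer_old)
        ]
        total += sum(per_state) / len(per_state)  # average over states
    return coefficient * total


# Hypothetical minibatch of two states and a single 2-unit layer.
old = [[[0.0, 0.0], [0.0, 0.0]]]
new = [[[1.0, 2.0], [0.0, 0.0]]]
penalty = pfo_loss(new, old, coefficient=10.0)
```

The penalty is zero at the start of each batch, when the optimized policy still equals the policy that collected the data, and grows as the pre-activations drift during the minibatch updates, mirroring how the probability ratio starts at 1 in the PPO objective.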

Figure 6: Effects of regularizing features and non-stationarity. Top: ALE/Phoenix-v5. Regularizing the difference between the features of consecutive policies with PFO results in better representations and mitigates performance collapse while improving performance in some cases. Bottom: ALE/Gravitar. Sharing the feature trunk between the actor and the critic results in a worse policy representation as the value network is subject to rank collapse due to reward sparsity. Each boxplot is formed by 15 runs including 3 different epoch configurations with 5 seeds each.

Sharing the actor-critic trunk

In deep RL, the decision to use the same feature network trunk for both the actor and the critic is not trivial. Depending on the complexity of the environment, it can significantly affect the performance of a PPO agent (Andrychowicz et al., 2021; Huang et al., 2022a). We, therefore, attempt to draw a connection between sharing the feature trunk, the resulting representation, and its effects on the PPO objective. In this intervention, we make the actor and the critic share all the layers except their respective output layers and backpropagate the gradients from both the value and policy losses to the shared trunk. Figure 6 shows that the value loss acts as a regularizer, which decreases the feature rank and, depending on the reward’s sparsity, gives two distinct effects. In dense-reward environments such as ALE/Phoenix and ALE/NameThisGame, the ranks are concentrated at low but non-zero values: the upper tail significantly decreases compared to the baselines while the lower tail increases. This coincides with a lower feature norm, lower excess probability ratio, and, in turn, a high tail for the returns. It also increases performance in some cases. However, the opposite is true in the sparse-reward environment Gravitar: the rank completely collapses, and the feature norms and excess ratios are very high, collapsing the model’s performance. This is consistent with the observations made in the plasticity works studying value-based methods where sparse rewards deteriorate the rank of the value network. We provide training curves showing the difference in the evolution of the feature rank when sharing the actor-critic trunk in Appendix D.

Resetting the optimizer

Asadi et al. (2023) argue that as the targets of the value function change with the changing policy rollouts, the old moments accumulated by Adam become harmful to fit the new targets and find that resetting the moments of Adam helps performance in DQN-like algorithms. Since the PPO objective depends on the previous policy and, more generally, the advantages in the policy gradient change with the policy, the same argument about Adam moments can be made for PPO. Therefore, we experiment with resetting Adam’s moments after each batch collection (to avoid tuning its frequency) for both the actor and the critic; the moments are thus only accumulated over the epochs on the same batch. We observe that this intervention reduces the feature norm and increases the feature rank on ALE, which also reduces the excess probability ratio and, in some cases, improves performance; however, it is not as effective as the other interventions in preventing collapse and, like sharing the actor-critic trunk, results in poor performance on ALE/Gravitar. Nevertheless, further tuning the frequency of the reset or coupling it with representation dynamics may yield different conclusions.
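The reset schedule described above can be sketched with a small, self-contained Adam implementation. This is an illustration under stated assumptions rather than the code used in the experiments: the class, `train`, and `grad_fn` are hypothetical names, the update is written as a descent step on a loss, and resetting the step counter together with the moments (so that bias correction restarts) is an assumed design choice.

```python
import math


class Adam:
    """Minimal Adam over a flat list of scalar parameters."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = params
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.reset_moments()

    def reset_moments(self):
        """Forget the first and second moments and restart bias correction."""
        self.m = [0.0] * len(self.params)
        self.v = [0.0] * len(self.params)
        self.t = 0

    def step(self, grads):
        """Apply one Adam descent step given the gradient of the loss."""
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            self.params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


def train(optimizer, batches, epochs, grad_fn):
    """Reset the moments at every new batch, then run several epochs on it."""
    for batch in batches:
        optimizer.reset_moments()  # moments only span the epochs of one batch
        for _ in range(epochs):
            optimizer.step(grad_fn(optimizer.params, batch))
```

Under this schedule the moments never mix gradients computed against different data-collecting policies, which is the property the intervention relies on; the actor and the critic would each use their own optimizer reset in the same way.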

5 Conclusion and Discussion

Conclusion

In this work, we have provided evidence that the representation deterioration under non-stationarity observed by previous work in value-based methods generalizes to PPO agents in ALE and MuJoCo with their common model architectures and can lead to performance collapse. We have shown that this is particularly concerning for the heuristic trust region set by PPO-Clip, which fails to prevent collapse as it becomes less effective when the agent’s representation becomes poor. Finally, we presented Proximal Feature Optimization (PFO), a simple novel auxiliary loss based on regularizing the features’ evolution, which mitigates representation degradation and along with other interventions shows that controlling the representation degradation improves performance.

Limitations and open questions

Although we study pixel-based discrete action environments and continuous control environments, we only study the common architectures of each environment, which do not include normalization layers and are limited to relatively small architectures with no memory (i.e., exclude Transformers or RNNs). Therefore, the generalization of the results of this work to other settings is unknown. Still, this work should raise awareness about the representation collapse phenomenon observed in PPO and encourage future work to monitor representations when training PPO agents, as failure to train indefinitely and sudden collapse may be due to representations. We have presented here methods that mitigate representation degradation, but none of them seem to be a fundamental solution and a deeper understanding of the reasons driving representation deterioration under non-stationarity is still needed. We plan to address these issues in future work.

References

  • Aitchison et al. (2023) Matthew Aitchison, Penny Sweetser, and Marcus Hutter. Atari-5: Distilling the arcade learning environment down to five games. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett (eds.), Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp.  421–438. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/aitchison23a.html.
  • Andriushchenko et al. (2023) Maksym Andriushchenko, Dara Bahri, Hossein Mobahi, and Nicolas Flammarion. Sharpness-aware minimization leads to low-rank features. In A. Oh, T. Neumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp.  47032–47051. Curran Associates, Inc., 2023. URL https://proceedings.neurips.cc/paper_files/paper/2023/file/92dd1adab39f362046f99dfe3c39d90f-Paper-Conference.pdf.
  • Andrychowicz et al. (2021) Marcin Andrychowicz, Anton Raichuk, Piotr Stańczyk, Manu Orsini, Sertan Girgin, Raphaël Marinier, Leonard Hussenot, Matthieu Geist, Olivier Pietquin, Marcin Michalski, Sylvain Gelly, and Olivier Bachem. What matters for on-policy deep actor-critic methods? a large-scale study. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=nIAxjsniDzg.
  • Asadi et al. (2023) Kavosh Asadi, Rasool Fakoor, and Shoham Sabach. Resetting the optimizer in deep RL: An empirical study. In A. Oh, T. Neumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp.  72284–72324. Curran Associates, Inc., 2023. URL https://proceedings.neurips.cc/paper_files/paper/2023/file/e4bf5c3245fd92a4554a16af9803b757-Paper-Conference.pdf.
  • Bellemare et al. (2016) Marc Bellemare, Sriram Srinivasan, Georg Ostrovski, Tom Schaul, David Saxton, and Remi Munos. Unifying count-based exploration and intrinsic motivation. In D. Lee, M. Sugiyama, U. Luxburg, I. Guyon, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc., 2016. URL https://proceedings.neurips.cc/paper_files/paper/2016/file/afda332245e2af431fb7b672a68b659d-Paper.pdf.
  • Bellemare et al. (2013) Marc G Bellemare, Yavar Naddaf, Joel Veness, and Michael Bowling. The arcade learning environment: An evaluation platform for general agents. Journal of Artificial Intelligence Research, 47:253–279, 2013.
  • Bou et al. (2024) Albert Bou, Matteo Bettini, Sebastian Dittert, Vikash Kumar, Shagun Sodhani, Xiaomeng Yang, Gianni De Fabritiis, and Vincent Moens. TorchRL: A data-driven decision-making library for PyTorch. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=QxItoEAVMb.
  • D’Oro et al. (2023) Pierluca D’Oro, Max Schwarzer, Evgenii Nikishin, Pierre-Luc Bacon, Marc G Bellemare, and Aaron Courville. Sample-efficient reinforcement learning by breaking the replay ratio barrier. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=OpC-9aBBVJe.
  • Engstrom et al. (2020) Logan Engstrom, Andrew Ilyas, Shibani Santurkar, Dimitris Tsipras, Firdaus Janoos, Larry Rudolph, and Aleksander Madry. Implementation matters in deep RL: A case study on PPO and TRPO. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=r1etN1rtPB.
  • Gulcehre et al. (2022) Caglar Gulcehre, Srivatsan Srinivasan, Jakub Sygnowski, Georg Ostrovski, Mehrdad Farajtabar, Matthew Hoffman, Razvan Pascanu, and Arnaud Doucet. An empirical study of implicit regularization in deep offline RL. Transactions on Machine Learning Research, 2022. ISSN 2835-8856. URL https://openreview.net/forum?id=HFfJWx60IT.
  • Haarnoja et al. (2018) Tuomas Haarnoja, Aurick Zhou, Kristian Hartikainen, George Tucker, Sehoon Ha, Jie Tan, Vikash Kumar, Henry Zhu, Abhishek Gupta, Pieter Abbeel, et al. Soft actor-critic algorithms and applications. arXiv preprint arXiv:1812.05905, 2018.
  • Huang et al. (2022a) Shengyi Huang, Rousslan Fernand Julien Dossa, Antonin Raffin, Anssi Kanervisto, and Weixun Wang. The 37 implementation details of proximal policy optimization. The ICLR Blog Track 2023, 2022a.
  • Huang et al. (2022b) Shengyi Huang, Rousslan Fernand Julien Dossa, Chang Ye, Jeff Braga, Dipam Chakraborty, Kinal Mehta, and João G.M. Araújo. CleanRL: High-quality single-file implementations of deep reinforcement learning algorithms. Journal of Machine Learning Research, 23(274):1–18, 2022b. URL http://jmlr.org/papers/v23/21-1342.html.
  • Huh et al. (2023) Minyoung Huh, Hossein Mobahi, Richard Zhang, Brian Cheung, Pulkit Agrawal, and Phillip Isola. The low-rank simplicity bias in deep networks. Transactions on Machine Learning Research, 2023. ISSN 2835-8856. URL https://openreview.net/forum?id=bCiNWDmlY2.
  • Igl et al. (2021) Maximilian Igl, Gregory Farquhar, Jelena Luketina, Wendelin Boehmer, and Shimon Whiteson. Transient non-stationarity and generalisation in deep reinforcement learning. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=Qun8fv4qSby.
  • Kakade & Langford (2002) Sham Kakade and John Langford. Approximately optimal approximate reinforcement learning. In Proceedings of the Nineteenth International Conference on Machine Learning, ICML ’02, pp.  267–274, San Francisco, CA, USA, 2002. Morgan Kaufmann Publishers Inc. ISBN 1558608737.
  • Kendall (1938) M. G. Kendall. A new measure of rank correlation. Biometrika, 30(1/2):81–93, 1938. ISSN 00063444. URL http://www.jstor.org/stable/2332226.
  • Kumar et al. (2021) Aviral Kumar, Rishabh Agarwal, Dibya Ghosh, and Sergey Levine. Implicit under-parameterization inhibits data-efficient deep reinforcement learning. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=O9bnihsFfXU.
  • Kumar et al. (2022) Aviral Kumar, Rishabh Agarwal, Tengyu Ma, Aaron Courville, George Tucker, and Sergey Levine. DR3: Value-based deep reinforcement learning requires explicit regularization. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=POvMvLi91f.
  • Lyle et al. (2022) Clare Lyle, Mark Rowland, and Will Dabney. Understanding and preventing capacity loss in reinforcement learning. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=ZkC8wKoLbQ7.
  • Lyle et al. (2023) Clare Lyle, Zeyu Zheng, Evgenii Nikishin, Bernardo Avila Pires, Razvan Pascanu, and Will Dabney. Understanding plasticity in neural networks. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett (eds.), Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp.  23190–23211. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/lyle23b.html.
  • Lyle et al. (2024) Clare Lyle, Zeyu Zheng, Khimya Khetarpal, Hado van Hasselt, Razvan Pascanu, James Martens, and Will Dabney. Disentangling the causes of plasticity loss in neural networks. arXiv preprint arXiv:2402.18762, 2024.
  • Machado et al. (2018) Marlos C Machado, Marc G Bellemare, Erik Talvitie, Joel Veness, Matthew Hausknecht, and Michael Bowling. Revisiting the arcade learning environment: Evaluation protocols and open problems for general agents. Journal of Artificial Intelligence Research, 61:523–562, 2018.
  • Mnih et al. (2015) Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A Rusu, Joel Veness, Marc G Bellemare, Alex Graves, Martin Riedmiller, Andreas K Fidjeland, Georg Ostrovski, et al. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015.
  • Mnih et al. (2016) Volodymyr Mnih, Adria Puigdomenech Badia, Mehdi Mirza, Alex Graves, Timothy Lillicrap, Tim Harley, David Silver, and Koray Kavukcuoglu. Asynchronous methods for deep reinforcement learning. In Maria Florina Balcan and Kilian Q. Weinberger (eds.), Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pp.  1928–1937, New York, New York, USA, 20–22 Jun 2016. PMLR. URL https://proceedings.mlr.press/v48/mniha16.html.
  • Nair & Hinton (2010) Vinod Nair and Geoffrey E Hinton. Rectified linear units improve restricted boltzmann machines. In Proceedings of the 27th international conference on machine learning (ICML-10), pp.  807–814, 2010.
  • Nikishin et al. (2022) Evgenii Nikishin, Max Schwarzer, Pierluca D’Oro, Pierre-Luc Bacon, and Aaron Courville. The primacy bias in deep reinforcement learning. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato (eds.), Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pp.  16828–16847. PMLR, 17–23 Jul 2022. URL https://proceedings.mlr.press/v162/nikishin22a.html.
  • Nikishin et al. (2023) Evgenii Nikishin, Junhyuk Oh, Georg Ostrovski, Clare Lyle, Razvan Pascanu, Will Dabney, and Andre Barreto. Deep reinforcement learning with plasticity injection. In A. Oh, T. Neumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp.  37142–37159. Curran Associates, Inc., 2023. URL https://proceedings.neurips.cc/paper_files/paper/2023/file/75101364dc3aa7772d27528ea504472b-Paper-Conference.pdf.
  • Nota & Thomas (2020) Chris Nota and Philip S. Thomas. Is the policy gradient a gradient? In Proceedings of the 19th International Conference on Autonomous Agents and MultiAgent Systems, AAMAS ’20, pp.  939–947, Richland, SC, 2020. International Foundation for Autonomous Agents and Multiagent Systems. ISBN 9781450375184.
  • Pardo et al. (2018) Fabio Pardo, Arash Tavakoli, Vitaly Levdik, and Petar Kormushev. Time limits in reinforcement learning. In Jennifer Dy and Andreas Krause (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp.  4045–4054. PMLR, 10–15 Jul 2018. URL https://proceedings.mlr.press/v80/pardo18a.html.
  • Press et al. (2007) William H. Press, Saul A. Teukolsky, William T. Vetterling, and Brian P. Flannery. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge University Press, USA, 3 edition, 2007. ISBN 0521880688.
  • Raffin et al. (2021) Antonin Raffin, Ashley Hill, Adam Gleave, Anssi Kanervisto, Maximilian Ernestus, and Noah Dormann. Stable-baselines3: Reliable reinforcement learning implementations. Journal of Machine Learning Research, 22(268):1–8, 2021. URL http://jmlr.org/papers/v22/20-1364.html.
  • Roy & Vetterli (2007) Olivier Roy and Martin Vetterli. The effective rank: A measure of effective dimensionality. In 2007 15th European Signal Processing Conference, pp.  606–610, 2007.
  • Schulman et al. (2015a) John Schulman, Sergey Levine, Pieter Abbeel, Michael Jordan, and Philipp Moritz. Trust region policy optimization. In Francis Bach and David Blei (eds.), Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pp.  1889–1897, Lille, France, 07–09 Jul 2015a. PMLR. URL https://proceedings.mlr.press/v37/schulman15.html.
  • Schulman et al. (2015b) John Schulman, Philipp Moritz, Sergey Levine, Michael Jordan, and Pieter Abbeel. High-dimensional continuous control using generalized advantage estimation. arXiv preprint arXiv:1506.02438, 2015b.
  • Schulman et al. (2017) John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, and Oleg Klimov. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.
  • Spearman (1987) C. Spearman. The proof and measurement of association between two things. The American Journal of Psychology, 100(3/4):441–471, 1987. ISSN 00029556. URL http://www.jstor.org/stable/1422689.
  • Sun et al. (2022) Mingfei Sun, Vitaly Kurin, Guoqing Liu, Sam Devlin, Tao Qin, Katja Hofmann, and Shimon Whiteson. You may not need ratio clipping in PPO. arXiv preprint arXiv:2202.00079, 2022.
  • Sutton & Barto (2018) Richard S Sutton and Andrew G Barto. Reinforcement learning: An introduction. MIT press, 2018.
  • Sutton et al. (1999) Richard S Sutton, David McAllester, Satinder Singh, and Yishay Mansour. Policy gradient methods for reinforcement learning with function approximation. In S. Solla, T. Leen, and K. Müller (eds.), Advances in Neural Information Processing Systems, volume 12. MIT Press, 1999. URL https://proceedings.neurips.cc/paper_files/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf.
  • Szepesvári (2022) Csaba Szepesvári. Algorithms for reinforcement learning. Springer Nature, 2022.
  • Todorov et al. (2012) Emanuel Todorov, Tom Erez, and Yuval Tassa. MuJoCo: A physics engine for model-based control. In 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, pp.  5026–5033. IEEE, 2012. doi: 10.1109/IROS.2012.6386109.
  • Towers et al. (2023) Mark Towers, Jordan K. Terry, Ariel Kwiatkowski, John U. Balis, Gianluca de Cola, Tristan Deleu, Manuel Goulão, Andreas Kallinteris, Arjun KG, Markus Krimmel, Rodrigo Perez-Vicente, Andrea Pierré, Sander Schulhoff, Jun Jet Tai, Andrew Tan Jin Shen, and Omar G. Younis. Gymnasium, March 2023. URL https://zenodo.org/record/8127025.
  • Wang et al. (2020) Yuhui Wang, Hao He, and Xiaoyang Tan. Truly proximal policy optimization. In Ryan P. Adams and Vibhav Gogate (eds.), Proceedings of The 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research, pp.  113–122. PMLR, 22–25 Jul 2020. URL https://proceedings.mlr.press/v115/wang20b.html.
  • Yang et al. (2020) Yuzhe Yang, Guo Zhang, Zhi Xu, and Dina Katabi. Harnessing structures for value-based planning and reinforcement learning. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=rklHqRVKvH.

Appendix A Additional background

A.1 Reinforcement Learning

The undiscounted formulation presented in the background (Section 2) has also been used by Schulman et al. (2015b) and does not preclude using a discount factor to discount future rewards: as we consider a finite-horizon setting, we can assume that discounting is already present in the reward, which depends on time through the state. This allows us to isolate the discount factor $\gamma$ for the purpose of value estimation with GAE, where it serves as a trade-off between the bias and the variance of the estimator, in addition to $\lambda$, which is used for the $\lambda$-returns that combine multiple $n$-step returns. More importantly, this also allows us to reuse the policy gradient and PPO losses without discount factors, as the deep RL community is used to them, while avoiding their incorrect use in the discounted setting, as pointed out by Nota & Thomas (2020). In any case, our results can also be translated to the discounted setting using a biased gradient estimator (missing a discount factor), which is the typical setting considered in deep RL works.
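
To make the roles of $\gamma$ and $\lambda$ inside the estimator concrete, the following is a minimal sketch of GAE over a single trajectory without early termination. It is an illustration under stated assumptions, not the code used in this work: the function name `gae_advantages`, its argument layout, and the toy rewards and values are hypothetical.

```python
def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    """Generalized advantage estimates for one trajectory without early termination.

    rewards[t] and values[t] are the reward and critic estimate at step t;
    last_value bootstraps from the state reached after the final step.
    """
    advantages = [0.0] * len(rewards)
    next_value, running = last_value, 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error; gamma appears only inside the estimator.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of TD errors with weight gamma * lam.
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages


# Toy data (hypothetical): three steps of rewards and critic estimates.
rewards = [1.0, 0.0, 1.0]
values = [0.5, 0.4, 0.3]
# lam = 0 reduces to the one-step TD error (low variance, more bias).
adv_td = gae_advantages(rewards, values, last_value=0.0, lam=0.0)
# gamma = lam = 1 reduces to the full return minus the value (no bias, more variance).
adv_mc = gae_advantages(rewards, values, last_value=0.0, gamma=1.0, lam=1.0)
```

The two calls at the end show the two extremes of the bias-variance trade-off that the pair $(\gamma, \lambda)$ interpolates between.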

Appendix B Experiment details

B.1 Code and run histories

Our code is available at https://github.com/CLAIRE-Labo/no-representation-no-trust. It includes the development environment distributed as a Docker image for GPU-accelerated machines and a Conda environment for MPS-accelerated machines, the training code, scripts to run all the experiments, and the notebook that generated the plots. The codebase uses TorchRL (Bou et al., 2024) and provides a comprehensive toolbox to study representation dynamics in policy optimization. We also provide modified scripts of CleanRL (Huang et al., 2022b) to replicate the collapse observed in this work and ensure it is not a bug from our novel codebase.

The code repository contains links to the Weights&Biases (W&B) project with all of our run histories, a summary W&B report of the runs, and a W&B report with the replication with CleanRL.

Runs are fully reproducible on the same acceleration device on which they were run.

B.2 Additional details on our experimental setup

Environment
Repeat action probability (Sticky actions) 0.25
Frameskip 3
Max environment steps per episode 108,000
Noop reset steps 0
Observation transforms
Grayscale True
Resize width (‘resize_w’) 84
Resize height (‘resize_h’) 84
Frame stack 4
Normalize observations False
Reward transforms
Sign True
Collector
Total environment steps 100,000,000
Num envs in parallel 8
Num envs in parallel plasticity 1
Agent steps per batch 1,024 (128 per env)
Total agent steps plasticity 36,000 (at least one full episode)
Models (actor and critic)
Activation ReLU
Convolutional Layers
Filters [32, 64, 64]
Kernel sizes [8, 4, 3]
Strides [4, 2, 1]
Linear Layers
Number of layers 1
Layer size 512
Optimization
Advantage estimator
Advantage estimator GAE
Gamma 0.99
Lambda 0.95
Value loss
Value loss coefficient 0.5
Loss type L2
Policy loss
Normalize advantages minibatch normalization
Clipping epsilon 0.1
Entropy coefficient 0.01
Feature regularization coefficient 1 (last pre-activation), 10 (all pre-activations)
Optimizer (actor and critic)
Optimizer Adam
Learning rate 0.00025
Max grad norm 0.5
Annealing linearly False
Number of epochs 4, 6, 8
Number of epochs plasticity fit 1
Minibatch size 256
Logging (% of the total number of batches)
Training every 0.1% (~100,000 env steps)
Plasticity every 2.5% (41 times in total)
Table 1: Hyperparameters for ALE.
Environment
Frameskip 1
Max env steps per episode 1,000
Noop reset steps 0
Observation transforms
Normalize observations True (from initial steps collected by uniform policy)
Initial random steps for normalization 4000 (at least 4 episodes)
Collector
Total environment steps 5,000,000
Num envs in parallel 2
Num envs in parallel plasticity 4
Agent steps per batch 2048 (1024 per env)
Total environment steps plasticity 4,000 (at least 4 full episodes)
Models (actor and critic)
Activation Tanh, ReLU
Convolutional layers
Number of Layers 0
Linear layers
Number of layers 2
Layer size 64
Optimization
Advantage estimator
Advantage estimator GAE
Gamma 0.99
Lambda 0.95
Value loss
Value coefficient 0.5
Loss type L2
Policy loss
Normalize advantages minibatch normalization
Clipping epsilon (PPO-Clip) 0.2
Entropy coefficient 0.0
Feature regularization coefficient 1 (tanh), 10 (ReLU)
Optimizer (actor and critic)
Optimizer Adam
Learning rate 0.0003
Max grad norm 0.5
Annealing linearly False
Number of epochs 10, 15, 20
Number of epochs plasticity fit 4
Minibatch size 64
Logging (% of the total number of batches)
Training every 0.1% (6,144 env steps)
Plasticity every 2.5% (41 times in total)
Table 2: Hyperparameters for MuJoCo.
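
Several entries in Tables 1 and 2 can be derived from one another. The sketch below recomputes the batch size and the training-logging interval from the table values. It rests on two assumptions that are not stated in the tables: that one agent step corresponds to `frameskip` environment steps, and that a logging interval is rounded up to a whole number of batches. The function name `derived_schedule` is hypothetical.

```python
import math


def derived_schedule(total_env_steps, num_envs, steps_per_env, frameskip, log_fraction):
    """Return (agent steps per batch, env steps between two training logs)."""
    agent_steps_per_batch = num_envs * steps_per_env
    # Assumption: each agent step consumes `frameskip` environment steps.
    env_steps_per_batch = agent_steps_per_batch * frameskip
    num_batches = total_env_steps / env_steps_per_batch
    # Assumption: the interval is rounded up to a whole number of batches.
    batches_per_log = math.ceil(num_batches * log_fraction)
    return agent_steps_per_batch, batches_per_log * env_steps_per_batch


# ALE (Table 1): 8 envs, 128 steps per env, frame skip 3, log every 0.1%.
ale = derived_schedule(100_000_000, 8, 128, 3, 0.001)
# MuJoCo (Table 2): 2 envs, 1024 steps per env, frame skip 1, log every 0.1%.
mujoco = derived_schedule(5_000_000, 2, 1024, 1, 0.001)
```

Under these assumptions the MuJoCo interval comes out at 6,144 environment steps, matching Table 2, and the ALE interval at 101,376, consistent with the approximate 100,000 in Table 1.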

We conduct experiments on an environment with pixel-based observations and discrete actions and an environment with continuous observations and actions, each with a different model architecture. For the discrete action case, we use the Arcade Learning Environment (ALE) (Bellemare et al., 2013) with the specification recommended by Machado et al. (2018) in v5 on Gymnasium (Towers et al., 2023). That is, with a sticky action probability of 0.25 as the only form of environment stochasticity, using only the game-over signal for termination (as opposed to end-of-life signals) with the default maximum of $108 \times 10^{3}$ environment frames per episode, and reporting performance over training episodes (i.e., with sampling according to the policy distribution as opposed to taking the mode action). We train all models for 100 million environment frames. We use standard algorithmic choices to make our setting and results relevant to previous work. This includes taking only the sign of rewards (clipping) and frame skipping. We use a frame skip of 3, as opposed to the standard value of 4, due to limitations in the ALE-v5 environment, which does not implement frame pooling (that is, taking the max over the last two skipped and unskipped frames to capture elements that only appear in even or odd frames of the game; see https://github.com/Farama-Foundation/Arcade-Learning-Environment/issues/467). Using an odd frame skip value alleviates the issue. We use the standard architecture of Mnih et al. (2015) consisting of convolutional layers followed by linear layers, all with ReLU activations, and no normalization layers. We also use Mnih et al. (2015)’s standard observation transformations with a resizing to 84x84, grayscaling, and a frame stacking of 4.
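
As an illustration of two of the transformations mentioned above (reward sign and frame stacking), here is a small library-free sketch. It is not the TorchRL pipeline used in this work: the names `sign_reward` and `FrameStack` are hypothetical, frames are represented by placeholder strings, and filling the stack with copies of the first frame on reset is an assumption.

```python
from collections import deque


def sign_reward(reward):
    """Map a raw reward to -1, 0, or +1 (reward clipping by sign)."""
    return (reward > 0) - (reward < 0)


class FrameStack:
    """Keep the k most recent frames as one stacked observation."""

    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        # Assumption: the stack is filled with copies of the first frame.
        for _ in range(self.frames.maxlen):
            self.frames.append(frame)
        return list(self.frames)

    def step(self, frame):
        self.frames.append(frame)  # the oldest frame is dropped automatically
        return list(self.frames)


# Maximum agent steps per episode implied by 108,000 frames and a frame skip of 3.
max_agent_steps = 108_000 // 3

stack = FrameStack(k=4)
obs = stack.reset("f0")
obs = stack.step("f1")
```

The value of `max_agent_steps` corresponds to the 36,000 agent steps listed for the plasticity dataset in Table 1.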

For the continuous case, we use MuJoCo (Todorov et al., 2012) with v4 on Gymnasium (Towers et al., 2023) with the default maximum of 1,000 environment frames to mark episode termination. Similarly to Atari, we report performance as the average episode return over training episodes. We train all models for 5 million environment frames. We standardize the observations (subtract mean and divide by standard deviation) according to an initial rollout of 4,000 environment steps (at least four episodes). The standardization parameters are kept the same to avoid adding non-stationarity. We use the same architecture as Schulman et al. (2017), with only linear layers, tanh activations, and no normalization layers. We also experiment with ReLU activations. The network outputs a mean and a standard deviation (with softplus), both conditioning on the observation independently for each action dimension, which are then used to create a TanhNormal distribution, similarly to Haarnoja et al. (2018).
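
The frozen standardization described above can be sketched as follows. This is a minimal illustration rather than the code used in this work: the helper name `fit_standardizer`, the small `eps` constant, and the toy rollout are assumptions.

```python
import math


def fit_standardizer(observations, eps=1e-8):
    """Compute per-dimension mean and std once and return a frozen transform."""
    n = len(observations)
    dims = len(observations[0])
    means = [sum(obs[d] for obs in observations) / n for d in range(dims)]
    stds = [
        math.sqrt(sum((obs[d] - means[d]) ** 2 for obs in observations) / n)
        for d in range(dims)
    ]

    def standardize(obs):
        # The statistics are never updated, so the input distribution seen by
        # the network does not drift because of the normalization itself.
        return [(x - m) / (s + eps) for x, m, s in zip(obs, means, stds)]

    return standardize


# Toy "initial rollout" with two 2-dimensional observations (hypothetical data).
initial_rollout = [[0.0, 10.0], [2.0, 30.0]]
standardize = fit_standardizer(initial_rollout)
z = standardize([2.0, 30.0])
```

Keeping the statistics fixed after the initial rollout is what avoids the additional source of non-stationarity mentioned above.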

To measure the plasticity loss of a checkpoint, we use the same optimization hyperparameters used to train the checkpoint, i.e. the same batch size and learning rate. The dataset sizes and fitting budgets for plasticity are listed in Tables 1 and 2.

We provide a high-level pseudocode for PPO in Algorithm 1 and list all hyperparameters considered in Tables 1 and 2.

Algorithm 1 High-level Pseudocode for PPO
$N$: number of environments in parallel.
$B_{\text{env}}$: agent steps per environment to collect in a batch.
$K$: number of optimization epochs per batch.

$L_{\pi_{\text{old}}}^{\text{CLIP}}(\bm{\theta})$: PPO-Clip objective.
$H(\bm{\theta})$: entropy bonus/loss; $c_H$: entropy bonus coefficient.
$L^{VF}(\mathbf{w})$: critic loss ($L^2$ to GAE); $c_{VF}$: critic loss coefficient.

while collected environment steps $\leq$ total environment steps do
     Collect a batch of interaction steps of size $B = N \times B_{\text{env}}$ and compute advantages:
     for $\text{actor} = 1$ to $N$ do
         Run policy $\pi_{\text{old}}$ in environment for $B_{\text{env}}$ agent steps.
         Compute advantage estimates $\Psi_1^{\text{actor}}, \ldots, \Psi_{B_{\text{env}}}^{\text{actor}}$ with GAE.
     end for
     Minimize overall policy and value loss $(-L_{\pi_{\text{old}}}^{\text{CLIP}}(\bm{\theta}) - c_H H(\bm{\theta}) + c_{VF} L^{VF}(\mathbf{w}))$ with autograd on the collected batch over $K$ epochs with minibatch size $M \leq B$.
     $\pi_{\text{old}} \leftarrow \pi_{\bm{\theta}}$
end while
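
For readers who prefer code, the minimization step of Algorithm 1 can be illustrated with a small sketch of the clipped surrogate and of the number of gradient steps taken per batch. It assumes the standard PPO-Clip form of Schulman et al. (2017) and that the minibatch size divides the batch size; the function names are hypothetical and no autograd is involved.

```python
def ppo_clip_objective(ratios, advantages, eps=0.1):
    """Average clipped surrogate over a minibatch (a quantity to maximize).

    ratios[i] is pi_theta(a_i | s_i) / pi_old(a_i | s_i).
    """
    total = 0.0
    for ratio, adv in zip(ratios, advantages):
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        # Pessimistic bound: the objective stops rewarding moves of the ratio
        # beyond the clipping range in the direction favored by the advantage.
        total += min(ratio * adv, clipped * adv)
    return total / len(ratios)


def gradient_steps_per_batch(num_envs, steps_per_env, epochs, minibatch_size):
    """K epochs over a batch of B = N * B_env samples split into minibatches of size M."""
    batch_size = num_envs * steps_per_env
    return epochs * (batch_size // minibatch_size)


clipped_up = ppo_clip_objective([1.5], [2.0], eps=0.1)     # ratio too high, A > 0
clipped_down = ppo_clip_objective([0.5], [-1.0], eps=0.1)  # ratio too low, A < 0
ale_steps = gradient_steps_per_batch(8, 128, epochs=4, minibatch_size=256)
mujoco_steps = gradient_steps_per_batch(2, 1024, epochs=10, minibatch_size=64)
```

With the smallest epoch settings from Tables 1 and 2, this gives 16 gradient steps per batch on ALE and 320 on MuJoCo, all computed against the same $\pi_{\text{old}}$.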
Proximal Feature Regularization

With a coefficient $c_{PFO}$, the PFO loss is added to the overall loss $(-L_{\pi_{\text{old}}}^{\text{CLIP}}(\bm{\theta}) + c_{PFO} L_{\pi_{\text{old}}}^{PFO}(\bm{\theta}) - c_H H(\bm{\theta}) + c_{VF} L^{VF}(\mathbf{w}))$ optimized with autograd over multiple minibatch epochs.
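
The way the four terms combine can be sketched as below. The signs and coefficients follow the overall loss above, with defaults taken from Table 1. The helper `feature_change_penalty` is only a hypothetical stand-in for the PFO term (a mean squared change in features between the old and current networks), included so that the example runs; the actual definition of the PFO loss is the one given in the main text.

```python
def feature_change_penalty(features_new, features_old):
    """Hypothetical stand-in for the PFO term: mean squared L2 change in features."""
    total = 0.0
    for new, old in zip(features_new, features_old):
        total += sum((n - o) ** 2 for n, o in zip(new, old))
    return total / len(features_new)


def total_loss(clip_objective, entropy, value_loss, pfo_loss,
               c_pfo=1.0, c_h=0.01, c_vf=0.5):
    """Overall loss to minimize: -L_CLIP + c_PFO * L_PFO - c_H * H + c_VF * L_VF."""
    return -clip_objective + c_pfo * pfo_loss - c_h * entropy + c_vf * value_loss


# Toy numbers (hypothetical): two samples with two features each.
pfo = feature_change_penalty([[1.0, 2.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]])
loss = total_loss(clip_objective=2.0, entropy=1.0, value_loss=4.0, pfo_loss=pfo)
```

Setting `c_pfo=0.0` recovers the loss minimized in Algorithm 1.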

B.3 Additional details on metrics used in the figures

Figure 4

A window of size 1% of training progress represents approximately 1 million training steps on ALE and 50,000 training steps on MuJoCo. We average the metrics per window and then take the 20 windows with the lowest average probability ratios below $1-\epsilon$. The probability ratios in a run can be trivially within the $1-\epsilon$ region after the model collapses, resulting in fewer than 20 points if the model collapses before 20% of the training progress. When all runs give 20 points, we can observe 300 points in total per scatter plot.
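
One way to read this selection procedure is sketched below. The data layout (pairs of training progress in $[0, 1)$ and a logged ratio below $1-\epsilon$) and the function name `lowest_ratio_windows` are hypothetical; the notebook in the code repository remains the reference for the actual computation.

```python
def lowest_ratio_windows(logged, num_windows=100, keep=20):
    """Average logged ratios per window and keep the windows with the lowest averages.

    logged: iterable of (progress, ratio) pairs, where ratio < 1 - eps.
    Returns a list of (average ratio, window index), sorted by average ratio.
    """
    sums = [0.0] * num_windows
    counts = [0] * num_windows
    for progress, ratio in logged:
        window = min(int(progress * num_windows), num_windows - 1)
        sums[window] += ratio
        counts[window] += 1
    # Windows without any logged ratio (e.g. after a collapse) yield no point.
    averages = [(sums[w] / counts[w], w) for w in range(num_windows) if counts[w] > 0]
    return sorted(averages)[:keep]


# Toy log (hypothetical): ratios fall in windows 0, 1, and 50.
logged = [(0.005, 0.8), (0.006, 0.6), (0.015, 0.5), (0.5, 0.85)]
selected = lowest_ratio_windows(logged, keep=2)
selected_windows = [w for _, w in selected]
```

A run that stops logging ratios early contributes fewer than `keep` windows, which is why some scatter plots contain fewer than 300 points.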

Figure 6

A window of size 5% of training progress represents approximately five million training steps in ALE and captures at least five episodes per environment so in total at least 40 episodes. For MuJoCo this represents approximately 256,000 training steps and captures at least 128 episodes per environment so in total at least 256 episodes.

When a model collapses, it typically stops changing, so its optimization trivially gives ratios within the clipping limits (no value above $1+\epsilon$ or below $1-\epsilon$ is logged). In that case, we are more interested in the evolution of the excess ratio before the ratios become trivial. Therefore, the upper limit of the 5% window of training progress is taken as the latest timestep where there are at least 10 non-trivial ratios, i.e., 10 logged excess ratios. This coincides with a window before the collapse of the model, capturing the values we are interested in. Note that when a model collapses, this window may not coincide with the window used to report the other metrics, such as the average return; however, these other metrics typically do not change after a collapse, so it is more robust to capture them at the end of training rather than looking for an arbitrary window after the collapse. We give training curves similar to Figure 1 with the interventions performed.

In MuJoCo, with continuous action distributions, the ratios diverge to infinity and 0 before collapse; therefore, to get meaningful plots, we clip average probability ratios above $1+\epsilon$ and below $1-\epsilon$ to $10^{12}$ and $10^{-12}$, respectively, before computing the average excess ratio.
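
The clipping of diverging ratios can be illustrated with a short sketch. The function names are hypothetical and the example only shows the bounding step, not the computation of the excess ratio itself.

```python
def clamp_ratio(ratio, low=1e-12, high=1e12):
    """Bound a diverging probability ratio to the interval [low, high]."""
    return min(max(ratio, low), high)


def mean_clamped(ratios):
    """Average the ratios after clamping so that the result stays finite and positive."""
    clamped = [clamp_ratio(r) for r in ratios]
    return sum(clamped) / len(clamped)


above = mean_clamped([float("inf"), 3.0])  # ratios logged above 1 + eps
below = mean_clamped([0.0, 0.5])           # ratios logged below 1 - eps
```

Without the bounds, a single infinite ratio would make the window average infinite and a zero ratio could not be shown on a logarithmic axis.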

We group the different epoch configurations of an intervention on the same environment, giving 15 runs per boxplot (three epochs with five seeds each).

B.4 Hardware and runtime

The experiments in this project took a total of 10,368 hours between NVIDIA V100 and A100 GPUs (ALE) and CPUs (MuJoCo). A run on ALE takes around 12 hours on an A100 and 24 hours on a V100. A run on MuJoCo takes around 5 hours on CPUs.

Appendix C Toy setting derivation details

The derivatives of the softmax probability $\pi_{\bm{\theta}}(a_1|s)$ with respect to $\theta_1$ and $\theta_2$ are as follows:

$$\frac{\partial \pi_{\bm{\theta}}(a_1|s)}{\partial \theta_1} = \frac{\partial}{\partial \theta_1}\left(\frac{e^{\theta_1 \phi(s)}}{e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)}}\right) = \phi(s) \cdot \frac{e^{\theta_1 \phi(s)} \cdot e^{\theta_2 \phi(s)}}{(e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)})^{2}} \tag{3}$$
$$\frac{\partial \pi_{\bm{\theta}}(a_1|s)}{\partial \theta_2} = \frac{\partial}{\partial \theta_2}\left(\frac{e^{\theta_1 \phi(s)}}{e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)}}\right) = -\phi(s) \cdot \frac{e^{\theta_1 \phi(s)} \cdot e^{\theta_2 \phi(s)}}{(e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)})^{2}} \tag{4}$$
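
Equations (3) and (4) can be checked numerically against central finite differences. The sketch below is only a sanity check; the function names and the chosen values of $\theta_1$, $\theta_2$, and $\phi(s)$ are arbitrary assumptions.

```python
import math


def pi_a1(theta1, theta2, phi):
    """Two-action softmax probability of a_1 with scalar feature phi = phi(s)."""
    z1, z2 = math.exp(theta1 * phi), math.exp(theta2 * phi)
    return z1 / (z1 + z2)


def grad_pi_a1(theta1, theta2, phi):
    """Closed-form partial derivatives from Equations (3) and (4)."""
    z1, z2 = math.exp(theta1 * phi), math.exp(theta2 * phi)
    common = phi * z1 * z2 / (z1 + z2) ** 2
    return common, -common


theta1, theta2, phi, h = 0.3, -0.2, 1.5, 1e-6
g1, g2 = grad_pi_a1(theta1, theta2, phi)
# Central finite differences in each parameter.
fd1 = (pi_a1(theta1 + h, theta2, phi) - pi_a1(theta1 - h, theta2, phi)) / (2 * h)
fd2 = (pi_a1(theta1, theta2 + h, phi) - pi_a1(theta1, theta2 - h, phi)) / (2 * h)
```

The two partial derivatives are equal in magnitude and opposite in sign, which is what makes the subsequent updates of $\theta_1$ and $\theta_2$ symmetric.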

The update rule for each parameter $\theta_i$ in $\theta$ with SGD is $\theta_i^{\text{new}} = \theta_i + \eta \frac{\partial L}{\partial \theta_i}$, where $\eta$ is the learning rate. Therefore, given the partial derivatives, the updated values for $\theta_1$ and $\theta_2$ after taking a gradient step are (if the probability ratio is still below $1+\epsilon$; otherwise the gradient is 0)

$$\theta_1^{\text{new}} = \theta_1 + \eta \cdot \frac{A(s,a_1)}{\pi_{\text{old}}(a_i|s)} \cdot \left(\phi(s) \cdot \frac{e^{\theta_1 \phi(s)} \cdot e^{\theta_2 \phi(s)}}{(e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)})^{2}}\right) \quad \text{and} \quad \theta_2^{\text{new}} = \theta_2 - \eta \cdot \frac{A(s,a_1)}{\pi_{\text{old}}(a_i|s)} \cdot \left(\phi(s) \cdot \frac{e^{\theta_1 \phi(s)} \cdot e^{\theta_2 \phi(s)}}{(e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)})^{2}}\right)$$

Hence,

$$\theta_1^{\text{new}} = \theta_1 + \delta_s \quad \text{with } \delta_s = \eta \cdot \frac{A(s,a_1)}{\pi_{\text{old}}(a_i|s)} \cdot \left(\phi(s) \cdot \frac{e^{\theta_1 \phi(s)} \cdot e^{\theta_2 \phi(s)}}{(e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)})^{2}}\right)$$
$$\theta_2^{\text{new}} = \theta_2 - \delta_s$$
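
The symmetric update can be illustrated numerically. The sketch below takes one unclipped gradient step in the toy setting; the function names, the argument `pi_old_prob` (standing for the old-policy probability in the denominator), and the numerical values are assumptions chosen for illustration.

```python
import math


def pi_a1(theta1, theta2, phi):
    """Two-action softmax probability of a_1 with scalar feature phi = phi(s)."""
    z1, z2 = math.exp(theta1 * phi), math.exp(theta2 * phi)
    return z1 / (z1 + z2)


def sgd_step(theta1, theta2, phi, advantage, pi_old_prob, lr):
    """One gradient step on the unclipped surrogate; returns the new parameters and delta_s."""
    z1, z2 = math.exp(theta1 * phi), math.exp(theta2 * phi)
    delta = lr * advantage / pi_old_prob * phi * z1 * z2 / (z1 + z2) ** 2
    # theta_1 moves by +delta_s and theta_2 by -delta_s.
    return theta1 + delta, theta2 - delta, delta


theta1, theta2, phi = 0.0, 0.0, 1.0
p_before = pi_a1(theta1, theta2, phi)
new1, new2, delta = sgd_step(theta1, theta2, phi, advantage=1.0,
                             pi_old_prob=p_before, lr=0.1)
p_after = pi_a1(new1, new2, phi)
```

With a positive advantage and a positive feature, the step raises the probability of $a_1$ on the state used for the update while leaving $\theta_1 + \theta_2$ unchanged.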

Let $\alpha \geq 0$ and, without loss of generality, let’s take $\alpha \geq 1$. After a gradient step on $x$ one has

$$
\begin{aligned}
\pi_{\bm{\theta}^{\text{new}}}(a_1 | x) &= \frac{e^{\theta_1^{\text{new}} \phi(x)}}{e^{\theta_1^{\text{new}} \phi(x)} + e^{\theta_2^{\text{new}} \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_x) \phi(x)}}{e^{(\theta_1 + \delta_x) \phi(x)} + e^{(\theta_2 - \delta_x) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{(\theta_2 - 2\delta_x) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x) - 2\delta_x \phi(x)}} \\
&\geq \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x)}} \quad \text{(since } -2\delta_x \phi(x) \leq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | x) \\[1ex]
\pi_{\bm{\theta}^{\text{new}}}(a_1 | y) &= \frac{e^{\theta_1^{\text{new}} \alpha \phi(x)}}{e^{\theta_1^{\text{new}} \alpha \phi(x)} + e^{\theta_2^{\text{new}} \alpha \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_x) \alpha \phi(x)}}{e^{(\theta_1 + \delta_x) \alpha \phi(x)} + e^{(\theta_2 - \delta_x) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{(\theta_2 - 2\delta_x) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x) - 2\delta_x \alpha \phi(x)}} \\
&\geq \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x)}} \quad \text{(since } -2\delta_x \alpha \phi(x) \leq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | y)
\end{aligned}
$$
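
The same conclusion can be read off a more compact form of the policy. As a sketch, write $\sigma(z) = 1 / (1 + e^{-z})$ for the logistic function (notation introduced here for illustration; it is not part of the original derivation). The two-action softmax depends only on the difference $\theta_1 - \theta_2$, and a step on $x$ shifts that difference by $2\delta_x$:

$$
\begin{aligned}
\pi_{\bm{\theta}}(a_1 | s) &= \frac{e^{\theta_1 \phi(s)}}{e^{\theta_1 \phi(s)} + e^{\theta_2 \phi(s)}} = \sigma\big((\theta_1 - \theta_2)\,\phi(s)\big) \\
\pi_{\bm{\theta}^{\text{new}}}(a_1 | s) &= \sigma\big((\theta_1 - \theta_2)\,\phi(s) + 2\delta_x \phi(s)\big)
\end{aligned}
$$

Since $\sigma$ is increasing, the probability of $a_1$ at a state $s$ does not decrease exactly when $\delta_x \phi(s) \geq 0$. For $s = y$ with $\phi(y) = \alpha \phi(x)$, this is the sign condition on $\delta_x \alpha \phi(x)$ used in the last inequality.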

And after a gradient step on $y$:

$$
\begin{aligned}
\pi_{\bm{\theta}^{\text{new}}}(a_1 | x) &= \frac{e^{\theta_1^{\text{new}} \phi(x)}}{e^{\theta_1^{\text{new}} \phi(x)} + e^{\theta_2^{\text{new}} \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_y) \phi(x)}}{e^{(\theta_1 + \delta_y) \phi(x)} + e^{(\theta_2 - \delta_y) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{(\theta_2 - 2\delta_y) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x) - 2\delta_y \phi(x)}} \\
&\geq \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x)}} \quad \text{(since } -2\delta_y \phi(x) \leq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | x) \\[1ex]
\pi_{\bm{\theta}^{\text{new}}}(a_1 | y) &= \frac{e^{\theta_1^{\text{new}} \alpha \phi(x)}}{e^{\theta_1^{\text{new}} \alpha \phi(x)} + e^{\theta_2^{\text{new}} \alpha \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_y) \alpha \phi(x)}}{e^{(\theta_1 + \delta_y) \alpha \phi(x)} + e^{(\theta_2 - \delta_y) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{(\theta_2 - 2\delta_y) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x) - 2\delta_y \alpha \phi(x)}} \\
&\geq \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x)}} \quad \text{(since } -2\delta_y \alpha \phi(x) \leq 0 \text{)} \\
&= \pi(a_1, \alpha x, \bm{\theta}) \\
&= \pi_{\bm{\theta}}(a_1 | y)
\end{aligned}
$$
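
The case $\alpha \geq 1$ can be checked numerically. The following is a minimal sketch under stated assumptions, not the authors' code: features are scalars with $\phi(y) = \alpha \phi(x)$, the advantage of $a_1$ is positive (which the sign conditions in the derivation appear to rely on), $\pi_{\text{old}}$ equals the current policy, and the helper names `pi_a1`, `step`, and `probability_changes` are hypothetical.

```python
import math


def pi_a1(theta1, theta2, phi):
    """Two-action softmax probability of a_1 (assumption: scalar feature phi)."""
    z1, z2 = theta1 * phi, theta2 * phi
    m = max(z1, z2)  # subtract the max logit for numerical stability
    e1, e2 = math.exp(z1 - m), math.exp(z2 - m)
    return e1 / (e1 + e2)


def step(theta1, theta2, phi_s, advantage, lr):
    """One gradient step on a state with feature phi_s.

    Assumption: pi_old equals the current policy at that state.
    """
    p1 = pi_a1(theta1, theta2, phi_s)
    d = lr * advantage / p1 * phi_s * p1 * (1.0 - p1)
    return theta1 + d, theta2 - d


def probability_changes(alpha, on="x", theta1=0.3, theta2=-0.2,
                        phi_x=1.5, advantage=1.0, lr=0.1):
    """Return the change in pi(a1|x) and pi(a1|y) after a step on x or on y."""
    phi_y = alpha * phi_x  # collinear features
    phi_s = phi_x if on == "x" else phi_y
    new1, new2 = step(theta1, theta2, phi_s, advantage, lr)
    dx = pi_a1(new1, new2, phi_x) - pi_a1(theta1, theta2, phi_x)
    dy = pi_a1(new1, new2, phi_y) - pi_a1(theta1, theta2, phi_y)
    return dx, dy


if __name__ == "__main__":
    for alpha in (1.0, 2.0, 5.0):
        for on in ("x", "y"):
            dx, dy = probability_changes(alpha, on)
            print(f"alpha={alpha}, step on {on}: d pi(a1|x)={dx:+.4f}, d pi(a1|y)={dy:+.4f}")
```

For these illustrative values both changes come out positive whichever state the step is taken on, in line with the four inequalities above.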

Let $\alpha \leq 0$ and, without loss of generality, take $\alpha \leq -1$. After a gradient step on $x$ one has

$$
\begin{aligned}
\pi_{\bm{\theta}^{\text{new}}}(a_1 | x) &= \frac{e^{\theta_1^{\text{new}} \phi(x)}}{e^{\theta_1^{\text{new}} \phi(x)} + e^{\theta_2^{\text{new}} \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_x) \phi(x)}}{e^{(\theta_1 + \delta_x) \phi(x)} + e^{(\theta_2 - \delta_x) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{(\theta_2 - 2\delta_x) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x) - 2\delta_x \phi(x)}} \\
&\geq \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x)}} \quad \text{(since } -2\delta_x \phi(x) \leq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | x) \\[1ex]
\pi_{\bm{\theta}^{\text{new}}}(a_1 | y) &= \frac{e^{\theta_1^{\text{new}} \alpha \phi(x)}}{e^{\theta_1^{\text{new}} \alpha \phi(x)} + e^{\theta_2^{\text{new}} \alpha \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_x) \alpha \phi(x)}}{e^{(\theta_1 + \delta_x) \alpha \phi(x)} + e^{(\theta_2 - \delta_x) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{(\theta_2 - 2\delta_x) \alpha \phi(x)}} \\
&= \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x) - 2\delta_x \alpha \phi(x)}} \\
&\leq \frac{e^{\theta_1 \alpha \phi(x)}}{e^{\theta_1 \alpha \phi(x)} + e^{\theta_2 \alpha \phi(x)}} \quad \text{(since } -2\delta_x \alpha \phi(x) \geq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | y)
\end{aligned}
$$
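
For negative $\alpha$ the two probabilities move in opposite directions after a step on $x$, which a short sweep can illustrate. As before, this is only a sketch under assumptions that are not fixed by the text: scalar features with $\phi(y) = \alpha \phi(x)$, a positive advantage for $a_1$, $\pi_{\text{old}}$ equal to the current policy, and hypothetical helper names `pi_a1` and `changes_after_step_on_x`.

```python
import math


def pi_a1(theta1, theta2, phi):
    """Two-action softmax probability of a_1 (assumption: scalar feature phi)."""
    z1, z2 = theta1 * phi, theta2 * phi
    m = max(z1, z2)  # subtract the max logit for numerical stability
    e1, e2 = math.exp(z1 - m), math.exp(z2 - m)
    return e1 / (e1 + e2)


def changes_after_step_on_x(alpha, theta1=0.3, theta2=-0.2,
                            phi_x=1.5, advantage=1.0, lr=0.1):
    """Change in pi(a1|x) and pi(a1|y) after one gradient step on x."""
    phi_y = alpha * phi_x  # anti-collinear features when alpha < 0
    p1 = pi_a1(theta1, theta2, phi_x)
    # delta_x with pi_old taken to be the current policy (assumption)
    d = lr * advantage / p1 * phi_x * p1 * (1.0 - p1)
    new1, new2 = theta1 + d, theta2 - d
    dx = pi_a1(new1, new2, phi_x) - p1
    dy = pi_a1(new1, new2, phi_y) - pi_a1(theta1, theta2, phi_y)
    return dx, dy


if __name__ == "__main__":
    for alpha in (-1.0, -2.0, -5.0):
        dx, dy = changes_after_step_on_x(alpha)
        print(f"alpha={alpha}: d pi(a1|x)={dx:+.4f}, d pi(a1|y)={dy:+.4f}")
```

With these illustrative values, $\pi(a_1 | x)$ increases while $\pi(a_1 | y)$ decreases, matching the $\geq$ and $\leq$ signs in the derivation above.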

And after a gradient step on $y$:

$$
\begin{aligned}
\pi_{\bm{\theta}^{\text{new}}}(a_1 | x) &= \frac{e^{\theta_1^{\text{new}} \phi(x)}}{e^{\theta_1^{\text{new}} \phi(x)} + e^{\theta_2^{\text{new}} \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_y) \phi(x)}}{e^{(\theta_1 + \delta_y) \phi(x)} + e^{(\theta_2 - \delta_y) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{(\theta_2 - 2\delta_y) \phi(x)}} \\
&= \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x) - 2\delta_y \phi(x)}} \\
&\leq \frac{e^{\theta_1 \phi(x)}}{e^{\theta_1 \phi(x)} + e^{\theta_2 \phi(x)}} \quad \text{(since } -2\delta_y \phi(x) \geq 0 \text{)} \\
&= \pi_{\bm{\theta}}(a_1 | x) \\[1ex]
\pi_{\bm{\theta}^{\text{new}}}(a_1 | y) &= \frac{e^{\theta_1^{\text{new}} \alpha \phi(x)}}{e^{\theta_1^{\text{new}} \alpha \phi(x)} + e^{\theta_2^{\text{new}} \alpha \phi(x)}} \\
&= \frac{e^{(\theta_1 + \delta_y) \alpha \phi(x)}}{e^{(\theta_1 + \delta_y) \alpha \phi(x)} + e^{(\theta_2 - \delta_y) \alpha \phi(x)}}
\end{aligned}
$$
=eθ1αϕ(x)eθ1αϕ(x)+e(θ22δy)αϕ(x)absentsuperscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃22subscript𝛿𝑦𝛼italic-ϕ𝑥\displaystyle=\frac{e^{\theta_{1}\alpha\phi(x)}}{e^{\theta_{1}\alpha\phi(x)}+e% ^{(\theta_{2}-2\delta_{y})\alpha\phi(x)}}= divide start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT + italic_e start_POSTSUPERSCRIPT ( italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT - 2 italic_δ start_POSTSUBSCRIPT italic_y end_POSTSUBSCRIPT ) italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG
=eθ1αϕ(x)eθ1αϕ(x)+eθ2αϕ(x)2δyαϕ(x)absentsuperscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃2𝛼italic-ϕ𝑥2subscript𝛿𝑦𝛼italic-ϕ𝑥\displaystyle=\frac{e^{\theta_{1}\alpha\phi(x)}}{e^{\theta_{1}\alpha\phi(x)}+e% ^{\theta_{2}\alpha\phi(x)-2\delta_{y}\alpha\phi(x)}}= divide start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT + italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) - 2 italic_δ start_POSTSUBSCRIPT italic_y end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG
eθ1αϕ(x)eθ1αϕ(x)+eθ2αϕ(x)(since 2δyαϕ(x)0)formulae-sequenceabsentsuperscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃1𝛼italic-ϕ𝑥superscript𝑒subscript𝜃2𝛼italic-ϕ𝑥(since 2subscript𝛿𝑦𝛼italic-ϕ𝑥0)\displaystyle\geq\frac{e^{\theta_{1}\alpha\phi(x)}}{e^{\theta_{1}\alpha\phi(x)% }+e^{\theta_{2}\alpha\phi(x)}}\quad\text{(since }-2\delta_{y}\alpha\phi(x)\leq 0% \text{)}≥ divide start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG start_ARG italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT + italic_e start_POSTSUPERSCRIPT italic_θ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) end_POSTSUPERSCRIPT end_ARG (since - 2 italic_δ start_POSTSUBSCRIPT italic_y end_POSTSUBSCRIPT italic_α italic_ϕ ( italic_x ) ≤ 0 )
=π𝜽(a1|y)absentsubscript𝜋𝜽conditionalsubscript𝑎1𝑦\displaystyle=\pi_{\bm{\theta}}(a_{1}|y)= italic_π start_POSTSUBSCRIPT bold_italic_θ end_POSTSUBSCRIPT ( italic_a start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT | italic_y )
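
The two chains of (in)equalities above can be checked numerically. The sketch below is only an illustration under assumed values: it takes a two-action softmax policy with logits $\theta_i \phi(\cdot)$, uses $\phi(y) = \alpha \phi(x)$ as the $\alpha \phi(x)$ terms suggest, and picks hypothetical numbers for $\theta_1$, $\theta_2$, $\phi(x)$, $\alpha$, and $\delta_y$ such that $-2\delta_y \phi(x) \geq 0$ and $-2\delta_y \alpha \phi(x) \leq 0$. The helper name `policy_a1` is ours.

```python
import math


def policy_a1(theta1: float, theta2: float, feature: float) -> float:
    """Probability of action a1 under a two-action softmax with logits theta_i * feature."""
    z1, z2 = theta1 * feature, theta2 * feature
    m = max(z1, z2)  # subtract the max logit for numerical stability
    e1, e2 = math.exp(z1 - m), math.exp(z2 - m)
    return e1 / (e1 + e2)


# Hypothetical values (assumptions, not taken from the paper).
theta1, theta2 = 0.5, -0.3
phi_x = 1.0
alpha = -2.0      # phi(y) = alpha * phi(x)
delta_y = -0.1    # gives -2*delta_y*phi_x >= 0 and -2*delta_y*alpha*phi_x <= 0
phi_y = alpha * phi_x

# Sign conditions used in the two inequality steps.
assert -2 * delta_y * phi_x >= 0
assert -2 * delta_y * alpha * phi_x <= 0

old_x = policy_a1(theta1, theta2, phi_x)
new_x = policy_a1(theta1 + delta_y, theta2 - delta_y, phi_x)
old_y = policy_a1(theta1, theta2, phi_y)
new_y = policy_a1(theta1 + delta_y, theta2 - delta_y, phi_y)

print(f"x: old={old_x:.4f} new={new_x:.4f}")
print(f"y: old={old_y:.4f} new={new_y:.4f}")
```

With these values the update lowers the probability of $a_1$ at $x$ and raises it at $y$, matching the direction of the two inequalities.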

Appendix D Main paper figures on all environments

Figure 7: Figure 1 on ALE. QBert is the only game where rank decline and collapse are not observed, apart from an outlier run that collapsed at initialization.
Figure 8: Figure 1 on MuJoCo with the tanh activation.
Figure 9: Figure 1 on MuJoCo with the ReLU activation.
Figure 10: Figure 2 on ALE.
Figure 11: Figure 2 on MuJoCo with the tanh activation. With a continuous action distribution, the policy variance can either drop or explode. Dead neurons for the tanh activation are hard to compute as they are dependent on an arbitrary threshold.
Figure 12: Figure 2 on MuJoCo with the ReLU activation.
Figure 13: Figure 3 on ALE. (No other environments considered; same figure as Figure 3).
Figure 14: Figure 3 on MuJoCo with the tanh activation. The PPO-Clip objective explodes in the negative direction after collapse, so we clip the y-axis of that plot to $-1$.
Figure 15: Figure 4 on ALE. Qbert and Gravitar do not have runs with poor representation regions (dead neurons $> 510$) to exhibit the correlation around collapse. Qbert has one outlier where the agent collapsed at the very beginning of training and kept a high (but lower than 510) number of dead neurons and a trivial rank, but a low excess ratio.
Figure 16: Figure 4 on MuJoCo with the tanh activation. Dead neurons for the tanh activation are hard to compute as they are dependent on an arbitrary threshold. In Humanoid, the rank does not reach values low enough to exhibit the correlation around collapse.
Figure 17: Figure 4 on MuJoCo with the ReLU activation. In Ant, the rank does not reach values low enough to exhibit the correlation around collapse.
Figure 18: Figure 6 on ALE. On Phoenix, the tails of the plasticity loss with interventions can be higher than without interventions: in the runs where the models without interventions collapse too early, the plasticity loss of the non-collapsed models with interventions eventually becomes higher. This can be observed from the training curves with interventions. Nevertheless, their medians are lower.
Figure 19: Figure 6 on MuJoCo with the tanh activation.
Figure 20: Figure 6 on MuJoCo with the ReLU activation.
Figure 21: Figure 1 on ALE/Phoenix-v5 with interventions.
Figure 22: Figure 1 on ALE/NameThisGame-v5 with interventions.
Figure 23: Figure 1 on ALE/NameThisGame-v5 with interventions.
Figure 24: Figure 1 on MuJoCo Hopper with the tanh activation.
Figure 25: Figure 1 on MuJoCo Humanoid with the tanh activation.
Figure 26: Figure 1 on MuJoCo Hopper with the ReLU activation.
Figure 27: Figure 1 on MuJoCo Humanoid with the ReLU activation.

Appendix E Measuring and comparing rank dynamics

Several matrix rank approximations have been used in the deep learning literature, and more specifically the deep RL literature, to measure the rank of the feature representations learned by a deep network. Complementing the background presented in Section 2, we list here all the rank metrics we have tracked in this work and their correlations, showing that although their absolute values differ, their dynamics tend to describe the same evolution.

E.1 Definitions of different rank metrics

Essentially, the main difference between the rank metrics considered in the literature is whether they apply a relative thresholding of the singular values or an absolute one. Their implementation can be found under src/po_dynamics/modules/metrics.py in our codebase.

Denoting by $\Phi$ the $N \times D$ matrix of representations as in Section 2, letting $\delta = 0.01$ be the threshold and $\langle \sigma_1(\Phi), \dots, \sigma_D(\Phi) \rangle$ the singular values of $\Phi$ in decreasing order, the different rank definitions are as follows.

Effective rank (Roy & Vetterli, 2007)

A relative measure of the rank. Let $H(p_1, \dots, p_k)$ denote the Shannon entropy of a probability distribution over $k$ events and $\|\bm{\sigma}\|_1$ be the sum of the singular values. Let $\tilde{\sigma}_i(\Phi) = \frac{\sigma_i(\Phi)}{\|\bm{\sigma}\|_1}$ be the normalized singular values. The effective rank is

$$\exp\left(H\left(\tilde{\sigma}_1(\Phi), \dots, \tilde{\sigma}_D(\Phi)\right)\right)$$

This rank measure has also been used in deep learning by Huh et al. (2023).
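
As an illustration, the effective rank can be sketched in a few lines of NumPy. This is not the implementation from our codebase: the function name `effective_rank` and the convention of returning 0 for an all-zero matrix are assumptions made for this example.

```python
import numpy as np


def effective_rank(features: np.ndarray) -> float:
    """exp of the Shannon entropy of the L1-normalized singular values."""
    s = np.linalg.svd(features, compute_uv=False)
    total = s.sum()
    if total == 0:
        return 0.0  # assumed convention for the all-zero matrix
    p = s / total
    p = p[p > 0]  # 0 * log(0) is treated as 0
    entropy = -np.sum(p * np.log(p))
    return float(np.exp(entropy))


# Equal singular values give the full rank; a rank-one matrix gives about 1.
print(effective_rank(np.eye(4)))
print(effective_rank(np.outer([1.0, 2.0, 3.0], [1.0, 1.0])))
```

Unlike the thresholded metrics below, this quantity is real-valued and needs no threshold $\delta$.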

Approximate rank (PCA)

A relative measure of the rank. Intuitively, this rank measures the number of principal components that together explain $99\%$ of the variance of the matrix. This can also be viewed as the lowest-rank reconstruction of the feature matrix with an error lower than 1% (footnote 6: https://github.com/epfml/ML_course/blob/94d3f8458e31fb619038660ed2704cef3f4bb512/lectures/12/lecture12b_pca_annotated.pdf). It is also used in RL by Yang et al. (2020).

$$\min_k \left\{ \frac{\sum_{i=1}^{k} \sigma_i^2(\Phi)}{\sum_{j=1}^{D} \sigma_j^2(\Phi)} > 1 - \delta \right\}$$
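
A minimal NumPy sketch of this definition, given only as an illustration (the name `approximate_rank` is hypothetical and the input is assumed to be a nonzero matrix):

```python
import numpy as np


def approximate_rank(features: np.ndarray, delta: float = 0.01) -> int:
    """Smallest k whose top-k squared singular values exceed a 1 - delta share."""
    s = np.linalg.svd(features, compute_uv=False)  # sorted in decreasing order
    explained = np.cumsum(s**2) / np.sum(s**2)
    # argmax returns the index of the first True entry; ranks are 1-based.
    return int(np.argmax(explained > 1 - delta)) + 1


# Two dominant directions and one negligible one.
print(approximate_rank(np.diag([3.0, 3.0, 0.01])))
```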
srank (Kumar et al., 2021)

A relative measure of the rank. This is a relative thresholding of the singular values, similar to the approximate rank but with no connection to low-rank reconstruction or variance of the feature matrix.

$$\min_k \left\{ \frac{\sum_{i=1}^{k} \sigma_i(\Phi)}{\sum_{j=1}^{D} \sigma_j(\Phi)} > 1 - \delta \right\}$$
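
The same thresholding applied to the singular values themselves, rather than to their squares, can be sketched as follows. Again this is an illustrative example with a hypothetical function name, assuming a nonzero input matrix.

```python
import numpy as np


def srank(features: np.ndarray, delta: float = 0.01) -> int:
    """Smallest k whose top-k singular values exceed a 1 - delta share of their sum."""
    s = np.linalg.svd(features, compute_uv=False)  # sorted in decreasing order
    share = np.cumsum(s) / np.sum(s)
    # argmax returns the index of the first True entry; ranks are 1-based.
    return int(np.argmax(share > 1 - delta)) + 1


# Without squaring, small singular values weigh more than in the PCA-based rank.
print(srank(np.diag([10.0, 1.0, 1.0])))
```

Because the singular values are not squared, a spectrum such as $(10, 1, 1)$ needs all three values to pass the $1 - \delta$ threshold here, whereas the squared version passes it with two.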
Feature Rank (Lyle et al., 2022)

An absolute measure of the rank. The number of singular values of the normalized $\Phi$ that are larger than a threshold $\delta$.

$$\left| \left\{ \frac{\sigma_i(\Phi)}{\sqrt{N}} > \delta \;\text{for}\; i \in \{1, \dots, D\} \right\} \right|$$
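
For illustration, this count can be sketched in NumPy as below; the function name `feature_rank` is an assumption of this example, with rows of the input taken as the $N$ samples.

```python
import numpy as np


def feature_rank(features: np.ndarray, delta: float = 0.01) -> int:
    """Number of singular values of features / sqrt(N) that exceed delta."""
    n = features.shape[0]  # N: number of rows (samples)
    s = np.linalg.svd(features, compute_uv=False)
    return int(np.sum(s / np.sqrt(n) > delta))


# The third direction falls below the absolute threshold after normalization.
print(feature_rank(np.diag([1.0, 0.5, 0.001])))
```

Since the threshold is absolute, rescaling the features changes the result, which is not the case for the relative metrics above.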
PyTorch rank

An absolute measure of the rank. This is the rank computed by torch.linalg.matrix_rank. Let $\epsilon$ be the smallest difference possible between points of the data type of the singular values, i.e., for torch.float32 that is $1.19209 \times 10^{-7}$. This rank is computed as follows.

$$\left| \left\{ \frac{\sigma_i(\Phi)}{\sigma_1 \times N} > \epsilon \;\text{for}\; i \in \{1, \dots, D\} \right\} \right|$$

It also appears in Press et al. (2007) in the discussion of SVD solutions for linear least squares.
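
The formula above can be mirrored in NumPy as a small sketch. It is an illustration only: the function name `numerical_rank` is hypothetical, the input is assumed to have $N \geq D$, and $\epsilon$ is taken from the data type of the computed singular values.

```python
import numpy as np


def numerical_rank(features: np.ndarray) -> int:
    """Number of singular values above sigma_1 * N * machine epsilon."""
    n = features.shape[0]  # N: number of rows, assumed >= number of columns
    s = np.linalg.svd(features, compute_uv=False)  # sorted in decreasing order
    eps = np.finfo(s.dtype).eps  # about 1.19209e-07 for float32
    return int(np.sum(s > s[0] * n * eps))


features = np.array(
    [[1, 0, 0], [0, 2, 0], [0, 0, 0], [0, 0, 0]], dtype=np.float32
)
print(numerical_rank(features))
```

The threshold scales with the largest singular value and with machine precision, so this metric only discards directions that are numerically indistinguishable from zero.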

E.2 Correlations between the rank metrics

We compute various correlation coefficients and distance measures between the rank metrics. To compute a correlation/distance on a pair of rank metrics $(X, Y)$, we take for each training run the set $\{(x_t, y_t) \mid t \in \{0, \dots, T\}\}$ of coinciding values of the curves of the two rank metrics during the run, which had $T$ logged steps, compute the correlation/distance on this set, and average the correlation/distance values across all considered runs. We also compute the worst correlation/distance between each rank metric pair for a worst-case analysis. We separate the average values and worst-case values by environment (ALE vs. MuJoCo) for a more granular analysis. We consider all the runs without the interventions and exclude a few runs where the models collapse from the beginning of training, giving constant trivial ranks, as these result in undefined or trivial correlation coefficients.

We compute Kendall’s $\tau$ coefficient (Kendall, 1938), Spearman’s $\rho$ coefficient (Spearman, 1987), the Pearson correlation coefficient, and a normalized L2-distance computed as $\frac{\sqrt{\sum_{t=1}^{T}(x_t - y_t)^2}}{\sqrt{T} \times L}$, where $L$ is the width of the feature layer considered (i.e., 512 for ALE and 64 for MuJoCo).
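
As a rough sketch of these per-run computations, the NumPy code below evaluates the four quantities on a pair of rank curves. It is illustrative only: the function names are ours, the Kendall coefficient is the simple tau-a variant (the variant is not specified above), ties in Spearman's coefficient receive average ranks, and the number of logged points is used for $T$.

```python
import numpy as np


def average_ranks(v) -> np.ndarray:
    """Ranks starting at 1, with tied values sharing their average rank."""
    v = np.asarray(v, dtype=float).ravel()
    _, inverse, counts = np.unique(v, return_inverse=True, return_counts=True)
    last = np.cumsum(counts)  # last rank occupied by each distinct value
    return (last - (counts - 1) / 2.0)[inverse]


def pearson(x, y) -> float:
    return float(np.corrcoef(x, y)[0, 1])


def spearman(x, y) -> float:
    """Pearson correlation of the rank-transformed curves."""
    return pearson(average_ranks(x), average_ranks(y))


def kendall_tau_a(x, y) -> float:
    """(concordant - discordant) pairs over all pairs; assumed tau-a variant."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    dx = np.sign(x[:, None] - x[None, :])
    dy = np.sign(y[:, None] - y[None, :])
    n = len(x)
    return float(np.sum(dx * dy) / (n * (n - 1)))  # ordered pairs counted twice


def normalized_l2(x, y, width: int) -> float:
    """L2 distance between the curves divided by sqrt(T) * L."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    return float(np.sqrt(np.sum((x - y) ** 2)) / (np.sqrt(len(x)) * width))


# Hypothetical rank curves of two metrics logged during one run (L = 512).
x = [512, 400, 300, 100]
y = [500, 390, 310, 90]
print(pearson(x, y), spearman(x, y), kendall_tau_a(x, y), normalized_l2(x, y, 512))
```

The per-run values would then be averaged, or their worst case taken, across runs as described above.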

Results

We visualize the correlation/distance between the pairs of ranks as heatmaps annotated with averages and standard deviations. Overall, the metrics are highly correlated, with the average correlation coefficients varying between 0.51 and 0.99. Individually, no rank metric correlates significantly more on average with the other metrics. Interestingly, the average correlations show two consistent clusters of stronger correlations: one among the relative rank metrics (approximate rank (PCA) and effective rank (Roy & Vetterli, 2007)) and one among the absolute rank metrics (Feature Rank (Lyle et al., 2022) and PyTorch rank). The srank (Kumar et al., 2021), which is technically a relative metric but with a weak normalization rationale, correlates more with the relative metrics on MuJoCo with tanh but more with the absolute metrics on ALE and MuJoCo with ReLU.

Figure 28: Average correlation between rank metrics on ALE.
Figure 29: Average correlation between rank metrics on MuJoCo with the tanh activation.
Figure 30: Average correlation between rank metrics on MuJoCo with the ReLU activation.
Figure 31: Worst-case correlations between rank metrics on ALE.
Figure 32: Worst-case correlations between rank metrics on MuJoCo with the tanh activation.
Figure 33: Worst-case correlations between rank metrics on MuJoCo with the ReLU activation.