Adjoint sensitivity analysis for a cost functional under an ODE constraint

Tags: derivations, machine-learning

I am trying to recover the result given by equation 10 in the article [1]. I am unable to get rid of the integral; any help would be much appreciated. To keep the description as self-contained as possible, I will describe the relevant notation below; a more detailed reference is the article [1] itself. Here is my attempt:

Suppose the overparametrized deep ResNet is modeled via the mean-field ODE:

$$
\dot{X_\rho}(x,t)=\int_{\theta}f(X_{\rho}(x,t),\theta)\rho(\theta,t)d\theta.
$$

Here $x$ denotes the input at the layer $t=0$ and $X_{\rho}(x,1)$ is the output at the final layer $t=1$.
In this model, every residual block $f(\cdot,\theta_i)$ is considered a particle, and optimization (training) is done over the distribution of the particles $\rho(\theta,t)$, where $\theta$ denotes the parameters of the residual block and $t$ indexes the layer. We will further write $\int_{\theta}f(X_{\rho}(x,t),\theta)\rho(\theta, t)d\theta= F(X_{\rho}(x,t);\rho)$. We also know that $\rho(\theta, t)$ is a probability density for every $t$. Thus the ODE above reduces to:

$\dot{X_\rho}(x,t)=F(X_{\rho}(x,t);\rho).$
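To make the forward model concrete, here is a minimal Euler-discretization sketch. The scalar state, the block $f(x,\theta)=\tanh(\theta x)$, and the standard-normal particle samples for $\rho(\cdot,t)$ are all my own illustrative choices, not the article's:

```python
import numpy as np

rng = np.random.default_rng(0)

def f(x, theta):
    # a single residual block f(x, theta); tanh(theta * x) is an
    # illustrative choice, not the article's specific block
    return np.tanh(theta * x)

def F(x, thetas):
    # mean-field drift: empirical average of f over particles theta ~ rho(., t)
    return np.mean([f(x, th) for th in thetas], axis=0)

def forward_euler(x0, particle_paths, dt):
    # particle_paths[k] holds the particles sampled from rho(., t_k)
    x = x0
    traj = [x]
    for thetas in particle_paths:
        x = x + dt * F(x, thetas)   # X_{t+dt} = X_t + dt * F(X_t; rho)
        traj.append(x)
    return np.array(traj)

n_steps, n_particles = 50, 200
dt = 1.0 / n_steps
paths = [rng.normal(size=n_particles) for _ in range(n_steps)]
traj = forward_euler(x0=0.5, particle_paths=paths, dt=dt)
```

Each Euler step plays the role of one residual layer, so the discretized model is exactly a (randomly weighted) ResNet.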

Let $E(x,\rho)$ be the cost function that depends on the mismatch between the output of the neural net, $X_{\rho}(x,1)$, and the true output $y(x)$. For emphasis we note that $X_{\rho}(x,1)$ is the final output of the neural net at layer $t=1$ corresponding to the true input $x$. Now we calculate the sensitivity of the cost function $E(x,\rho)$ with respect to the parameter $\rho(\theta,t)$ that describes the distribution of the weight parameters $\theta$ at every $t$-th layer of the neural net. We will sometimes suppress the explicit dependence of $E(x,\rho)$ on its arguments and simply write $E:=E(x,\rho)$ for convenience.

$\frac{dE(x,\rho)}{d\rho}=\frac{\partial E}{\partial X_{\rho}(x,1)}\frac{d X_{\rho}(x,1)}{d \rho}.$

To calculate $\frac{d X_{\rho}(x,1)}{d \rho}$ we will use the adjoint sensitivity method.
Recall:

$ X_{\rho}(x,1)=x+\int_{0}^1 F(X_{\rho}(x,t);\rho) dt$

$ \frac{d X_{\rho}(x,1)}{d\rho}=\int_{0}^1 \bigg[\frac{\partial F(X_{\rho}(x,t);\rho)}{\partial X_{\rho}(x,t)}\frac{d X_{\rho}(x,t)}{d \rho}+\frac{\partial F(X_{\rho}(x,t);\rho)}{\partial \rho}\bigg]dt-\int_{0}^1 \lambda(t) \bigg[\frac{d \dot{X_\rho}(x,t)}{d\rho}-\frac{\partial F}{\partial X_{\rho}(x,t) }\frac{d X_{\rho}(x,t)}{d \rho}-\frac{\partial F}{\partial \rho}\bigg]dt$.

Note that the second integral is zero due to the ODE equation above. More specifically, $\frac{d}{d\rho}\bigg(\dot{X_\rho}(x,t)-F(X_{\rho}(x,t);\rho)\bigg)=0$.

Consider the term
$-\int_{0}^1 \lambda (t)\frac{d \dot{X_\rho}(x,t)}{d\rho} dt=-\int_{0}^1 \lambda (t)\frac{d}{dt}\frac{d {X_\rho}(x,t)}{d\rho} dt.$
Integrating by parts,

$-\int_{0}^1 \lambda (t)\frac{d \dot{X_\rho}(x,t)}{d\rho} dt=-\lambda(t) \frac{d X_{\rho}(x,t)}{d\rho}\bigg|_{t=0}^{t=1}
+\int_0^1 \frac{d\lambda (t)}{dt} \frac{d X_{\rho}(x,t)}{d\rho} dt. $

We will choose $\lambda$ such that $\lambda(1)=0$. We also note that $\frac{dX_{\rho}}{d\rho}\big|_{t=0}=0$, since the initial condition $X_{\rho}(x,0)=x$ does not depend on $\rho$.

Using these we can rewrite,

$ \frac{d X_{\rho}(x,1)}{d\rho}=\int_{0}^1 \bigg[(\lambda(t)+Id)\frac{\partial F}{\partial X_{\rho}(x,t)}+\frac{d\lambda(t)}{dt}\bigg]\frac{d X_{\rho}(x,t)}{d\rho} dt+\int_{0}^1 (\lambda(t)+Id)\frac{\partial F}{\partial \rho}\, dt. $

Now we choose $\lambda(t)$ to satisfy the ODE:
$(\lambda(t)+Id)\frac{\partial F}{\partial X_{\rho}(x,t)}+\frac{d\lambda(t)}{dt}=0$
along with the condition $\lambda(1)=0$. This is equivalent to the system for $\tilde{\lambda}=\lambda+Id$,

$-\tilde{\lambda}(t)\frac{\partial F}{\partial X_{\rho}(x,t)}=\frac{d\tilde{\lambda}(t)}{dt} $
and $\tilde{\lambda}(1)=Id$.
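As a quick numerical sanity check of this backward system, one can integrate it for a toy linear drift of my own choosing, $F(X)=aX$ with constant $a$, where $\partial F/\partial X=a$ and the system has the closed form $\tilde{\lambda}(t)=e^{a(1-t)}$:

```python
import numpy as np

a = 0.7                      # dF/dX for the linear drift F(X) = a * X
n = 2000
dt = 1.0 / n
t = np.linspace(0.0, 1.0, n + 1)

# integrate d(lambda~)/dt = -lambda~ * dF/dX backward from lambda~(1) = Id
lam = np.empty(n + 1)
lam[-1] = 1.0
for k in range(n, 0, -1):
    lam[k - 1] = lam[k] + dt * lam[k] * a   # explicit Euler, backward in t

# closed form: J(t) = dX(1)/dX(t) = exp(a * (1 - t)) for this drift
J = np.exp(a * (1.0 - t))
print(np.max(np.abs(lam - J)))   # small discretization error
```

The backward Euler sweep reproduces the Jacobian $J_{\rho}(x,t)$ up to $O(dt)$ discretization error, which is the content of the identification $\tilde{\lambda}=J_{\rho}$.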

It can be independently verified that $\tilde{\lambda}(t)=J_{\rho}(x,t)$, where $J_{\rho}(x,t)=\frac{d X_{\rho}(x,1)}{d X_{\rho}(x,t)}$, satisfies this system together with the final value at $t=1$; see also [1] (eqn. 9).

Thus we get,
$\frac{d X_{\rho}(x,1)}{d\rho}=\int_{0}^1 J_{\rho}(x,t)\frac{\partial F}{\partial \rho} dt$ and so,

$\frac{dE(x,\rho)}{d\rho}=\frac{\partial E}{\partial X_{\rho}(x,1)} \int_{0}^1 J_{\rho}(x,t)\frac{\partial F}{\partial \rho} dt$
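This integral formula can be checked numerically on a toy scalar problem of my own construction: take $\dot{X}=\theta X$ with a single scalar parameter $\theta$ in place of $\rho$, and $E=\frac12(X(1)-y)^2$, so that $J(t)=e^{\theta(1-t)}$ and $\partial F/\partial\theta=X(t)=x_0 e^{\theta t}$ are available in closed form:

```python
import numpy as np

x0, y, theta = 0.5, 1.3, 0.8

def x1(th):
    # closed-form solution of dX/dt = th * X on [0, 1]: X(1) = x0 * exp(th)
    return x0 * np.exp(th)

def E(th):
    # quadratic mismatch cost at the final layer
    return 0.5 * (x1(th) - y) ** 2

# gradient from the derived formula:
#   dE/dtheta = dE/dX(1) * int_0^1 J(t) * dF/dtheta dt,
# with J(t) = exp(theta*(1-t)) and dF/dtheta = X(t) = x0*exp(theta*t)
n = 10000
t = (np.arange(n) + 0.5) / n            # midpoint quadrature nodes
integrand = np.exp(theta * (1.0 - t)) * x0 * np.exp(theta * t)
grad_adjoint = (x1(theta) - y) * integrand.mean()

# independent check: central finite difference on E
h = 1e-6
grad_fd = (E(theta + h) - E(theta - h)) / (2 * h)
print(grad_adjoint, grad_fd)   # the two estimates agree
```

For a constant (in $t$) parameter the sensitivity is the integral of the kernel over the whole interval, which is exactly the point the accepted answer makes below about constant perturbations.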

whereas in the article ([1], eqn. 10) it is evaluated as:

$\frac{dE(x,\rho)}{d\rho}=\frac{\partial E}{\partial X_{\rho}(x,1)} J_{\rho}(x,t)\frac{\partial F}{\partial \rho}$

Best Answer

Ah, that is just about the meaning of the expression $\frac{\partial E(x,\rho)}{\partial\rho}$. Since $\rho$ is a function of $t$, it really means "a function $D(t)$ such that $$ E(x,\rho+\Delta\rho)-E(x,\rho)\approx \int_0^1 D(t)\Delta\rho(t)\,dt $$ for all small perturbations $\Delta\rho(t)$".

What you did was to compute the derivative treating $\rho$ like a constant, i.e., your computation is formally valid assuming $\Delta\rho(t)=h$ throughout the whole interval, in which case your formula is just a special case of their formula, i.e., $$ E(x,\rho+h)-E(x,\rho)\approx h\int_0^1 D(t)\,dt. $$ However, your derivation is incomplete because you need to find the linearization for all $\Delta\rho$, not just constants. Fortunately, you hardly need to change anything in it: an almost mechanical insertion of $\Delta\rho$ where it belongs should do the trick.
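Carrying out the suggested insertion explicitly (my reading of the answer, in the question's notation): perturbing $\rho$ by $\Delta\rho$ perturbs the drift by $\int_{\theta}f(X_{\rho}(x,t),\theta)\,\Delta\rho(\theta,t)\,d\theta$, and repeating the adjoint computation with this perturbation in place of $\frac{\partial F}{\partial\rho}$ gives

$$ E(x,\rho+\Delta\rho)-E(x,\rho)\approx \frac{\partial E}{\partial X_{\rho}(x,1)}\int_0^1 J_{\rho}(x,t)\int_{\theta} f(X_{\rho}(x,t),\theta)\,\Delta\rho(\theta,t)\,d\theta\,dt, $$

so the sought kernel is $D(\theta,t)=\frac{\partial E}{\partial X_{\rho}(x,1)}\,J_{\rho}(x,t)\,f(X_{\rho}(x,t),\theta)$, which recovers the pointwise formula quoted from eqn. 10.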
