Neural ODEs: the relationship between backpropagation and the adjoint sensitivity method

neural networks, numerical methods, optimization, ordinary differential equations

I was reading the Neural ODE paper by Chen, Duvenaud, et al. and trying to understand the relationship between backpropagation and the adjoint sensitivity method. I also looked at Gil Strang's latest book Linear Algebra and Learning from Data for more background on both backpropagation and the adjoint method.

It seems like both backpropagation and the adjoint method will compute the gradient of a scalar function. So if I have my loss function $\mathcal{L} = \sum_{i=1}^{N}(y_i - f(x_i, \theta))^2$, and I want to optimize the parameters $\theta$ to minimize the loss, then I could use reverse-mode automatic differentiation, a.k.a. backpropagation, to get $\frac{d\mathcal{L}}{d\theta}$.
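For concreteness, here is a minimal sketch of that step in JAX; the model $f(x, \theta) = \theta x$ and the data are made-up placeholders, not anything from the paper:

```python
# Minimal sketch: reverse-mode AD (backprop) for dL/dθ on a least-squares loss.
# The model f(x, theta) = theta * x is a hypothetical placeholder.
import jax
import jax.numpy as jnp

def f(x, theta):
    return theta * x  # illustrative one-parameter model

def loss(theta, x, y):
    return jnp.sum((y - f(x, theta)) ** 2)

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.1, 0.9, 2.1])

dL_dtheta = jax.grad(loss)(0.5, x, y)  # reverse-mode AD gives dL/dθ at θ = 0.5
print(dL_dtheta)
```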

But it also seems like the adjoint method is a different way to obtain the same $\frac{d\mathcal{L}}{d\theta}$?

I was not clear on whether the adjoint method is a necessary step in reverse-mode automatic differentiation, or whether applying the adjoint method just reorganizes some steps of the reverse-mode AD calculation to make it faster.

I understand that the paper's big claim is that using adjoint sensitivities lets us train neural ODEs more efficiently than backpropagating through the solver, but I was not clear on how backprop and the adjoint method actually differ.

Thanks.

Best Answer

This is an old question in control theory and other related (design-)optimization tasks: Adjoint first or discretization first.

Discretization first derives the dual (adjoint) problem for the actual discretization of the primal problem. For an ODE solver this might work well with the Euler method, but it requires more effort and more stored intermediate results for more accurate solver methods.
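A small sketch of what "discretization first" looks like in the neural ODE setting, assuming a fixed-step Euler solver and an invented vector field: reverse-mode AD is applied through the solver's own steps, so every intermediate state becomes part of the AD graph.

```python
# Sketch of "discretization first": backprop through the Euler steps themselves.
# The vector field dz/dt = tanh(theta * z), the loss, and all names are illustrative.
import jax
import jax.numpy as jnp

def vector_field(z, theta):
    return jnp.tanh(theta * z)

def euler_solve(z0, theta, dt=0.01, n_steps=100):
    z = z0
    for _ in range(n_steps):          # every intermediate z is stored in the AD graph
        z = z + dt * vector_field(z, theta)
    return z

def loss(theta, z0, target):
    zT = euler_solve(z0, theta)
    return jnp.sum((zT - target) ** 2)

z0 = jnp.array([1.0, -0.5])
target = jnp.array([0.8, -0.3])
dL_dtheta = jax.grad(loss)(0.3, z0, target)  # reverse-mode AD through all solver steps
print(dL_dtheta)
```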

Adjoint first derives the dual problem for the continuous, exact version of the primal problem. Solving that dual problem numerically, with the time direction reversed, then does not require the internals of the forward solution process; it only needs an interpolation of the forward solution, for instance in the form of a "dense output".
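Here is a sketch of "adjoint first" under the same illustrative setup: the adjoint ODE $\frac{da}{dt} = -a^\top \frac{\partial f}{\partial z}$ is integrated backwards in time from $a(T) = \frac{\partial \mathcal{L}}{\partial z(T)}$, and $\frac{d\mathcal{L}}{d\theta} = \int_0^T a(t)^\top \frac{\partial f}{\partial \theta}\,dt$ is accumulated along the way. The stored forward trajectory stands in for the solver's dense output; only vector-Jacobian products of the vector field are needed, not the forward solver's internal AD graph.

```python
# Sketch of "adjoint first": step the continuous adjoint ODE backwards in time.
# Same illustrative vector field as above; all names are made up.
import jax
import jax.numpy as jnp

def vector_field(z, theta):
    return jnp.tanh(theta * z)

dt, n_steps = 0.01, 100

def forward(z0, theta):
    # fixed-step Euler forward pass; the stored states stand in for a dense output
    zs = [z0]
    for _ in range(n_steps):
        zs.append(zs[-1] + dt * vector_field(zs[-1], theta))
    return zs

def adjoint_grad(z0, theta, target):
    zs = forward(z0, theta)
    a = 2.0 * (zs[-1] - target)          # a(T) = dL/dz(T) for the squared loss
    dL_dtheta = 0.0
    for z in reversed(zs[:-1]):          # integrate the adjoint ODE backwards in time
        _, vjp = jax.vjp(vector_field, z, theta)
        a_dot, theta_bar = vjp(a)        # a^T ∂f/∂z  and  a^T ∂f/∂θ
        dL_dtheta += dt * theta_bar      # accumulate dL/dθ = ∫ a^T ∂f/∂θ dt
        a = a + dt * a_dot               # da/dt = -a^T ∂f/∂z, stepped in reverse
    return dL_dtheta

print(adjoint_grad(jnp.array([1.0, -0.5]), 0.3, jnp.array([0.8, -0.3])))
```

For this simple explicit Euler scheme the result coincides with backpropagating through the solver steps; the practical difference is that the adjoint pass only needs $z(t)$ along the trajectory (via interpolation, or by re-integrating the state backwards as the paper does), not the solver's stored computation graph.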

I do not know the more precise arguments for and against these variants, but in the current case "adjoints first" seems to be a sensible decision.
