Neural nets with implicit layers

Also, declarative networks, bi-level optimization and other ingenious uses of the implicit function theorem

December 8, 2020 — August 9, 2022

Figure 1

To my mind, several things fit in the category of implicit networks. These are all networks whose layers are not explicit forward operators; instead, their outputs are defined implicitly as the fixed point of some iteration, such as an optimisation. There are a few names and subcategories: declarative networks, equilibrium networks, etc. In general

NB: This is different to implicit representation methods. Since implicit layers and implicit representations also occur in the same problems (such as some PINNs), avoidable terminological confusion will haunt us.

A common feature of these architectures is that they exploit the implicit function theorem to compute gradients.

1 The implicit function theorem in learning

Say your model includes an iterative step, e.g. “optimise this to convergence”. How do we learn with such a thing? The good old implicit function theorem. A beautiful explanation of what is special about differentiating systems at convergence is Blondel et al. (2021).
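For orientation, here is the standard calculation in generic form (the symbols F, z^⋆, θ and ℓ are introduced here for illustration rather than taken from a particular paper). Suppose the iteration converges to z^⋆(θ) satisfying an optimality or fixed-point condition F(z^⋆(θ), θ) = 0, and that ∂_z F is invertible there. Differentiating that identity with respect to θ gives

\[ \begin{aligned} \partial_z F \, \frac{\mathrm{d} z^\star}{\mathrm{d}\theta} + \partial_\theta F &= 0 \quad\Longrightarrow\quad \frac{\mathrm{d} z^\star}{\mathrm{d}\theta} = -\left(\partial_z F\right)^{-1} \partial_\theta F, \\ \frac{\mathrm{d}\, \ell(z^\star)}{\mathrm{d}\theta} &= -\lambda^{\top} \partial_\theta F, \qquad \text{where } \left(\partial_z F\right)^{\top} \lambda = \nabla_z \ell(z^\star). \end{aligned} \]

For “optimise to convergence”, F(z, θ) = ∇_z f(z; θ) and ∂_z F is the Hessian of the objective; for a fixed-point iteration z = g(z, θ), F(z, θ) = z − g(z, θ). Either way the backward pass costs one linear solve at the solution and does not need the iterates that found it.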

For further tutorial-form background, see the NeurIPS 2020 tutorial, Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond, by Zico Kolter, David Duvenaud, and Matt Johnson, or ADCME: Automatic Differentiation for Implicit Operators.

Under-regarded paper: Domke (2012) shows that we can be very sloppy with our optimisation iterations, and still use the implicit function theorem.
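A toy illustration of the flavour of this (a quadratic inner problem chosen here for simplicity, not one of Domke's experiments): solve the inner problem only approximately with a few gradient-descent steps, then plug the inexact solution into the implicit-gradient formula. In this example the gradient error is exactly A⁻¹ times the inner solver's error, so a loose solve still gives a usable gradient. The helpers inner_solve and implicit_grad are hypothetical names.

```python
import numpy as np

def inner_solve(A, theta, steps, lr):
    """Approximately minimise f(z) = 0.5 z^T A z - theta^T z by gradient descent from z = 0."""
    z = np.zeros_like(theta)
    for _ in range(steps):
        z = z - lr * (A @ z - theta)  # grad_z f = A z - theta
    return z

def implicit_grad(A, z, target):
    """d loss / d theta for loss = 0.5 ||z*(theta) - target||^2, evaluated at z.

    Stationarity F(z, theta) = A z - theta = 0 gives dz*/dtheta = A^{-1},
    so the chain rule yields A^{-T} (z - target): a single linear solve.
    """
    return np.linalg.solve(A.T, z - target)

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5.0 * np.eye(5)           # symmetric positive definite
theta = rng.standard_normal(5)
target = rng.standard_normal(5)
lr = 1.0 / np.linalg.eigvalsh(A).max()  # step size that guarantees convergence

g_exact = implicit_grad(A, np.linalg.solve(A, theta), target)
errors = [
    np.linalg.norm(implicit_grad(A, inner_solve(A, theta, k, lr), target) - g_exact)
    for k in (5, 20, 200)
]
print(errors)  # shrinks as the inner solve gets more accurate
```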

2 Optimization layers

Differentiable Convex Optimization Layers introduces cvxpylayers:

Optimization layers add domain-specific knowledge or learnable hard constraints to machine learning models. Many of these layers solve convex and constrained optimization problems of the form

\[ \begin{array}{rl} x^{\star}(\theta)=\operatorname{arg\,min}_{x} & f(x ; \theta) \\ \text{subject to} & g(x ; \theta) \leq 0, \\ & h(x ; \theta)=0 \end{array} \]

with parameters θ, objective f, and constraint functions g,h and do end-to-end learning through them with respect to θ.

In this tutorial we introduce our new library cvxpylayers for easily creating differentiable new convex optimization layers. This lets you express your layer with the CVXPY domain specific language as usual and then export the CVXPY object to an efficient batched and differentiable layer with a single line of code. This project turns every convex optimization problem expressed in CVXPY into a differentiable layer.
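The simplest case where this differentiation can be written out by hand is a quadratic program with only equality constraints. There the optimality (KKT) conditions are a linear system, so the implicit function theorem reduces to one extra solve against the same KKT matrix. This is in the spirit of the OptNet derivation (Amos and Kolter 2017); cvxpylayers itself handles general convex problems by a different route (differentiating through a conic reformulation, Agrawal et al. 2019), and inequality constraints would need extra complementarity terms omitted here. The names qp_forward and qp_backward are hypothetical.

```python
import numpy as np

def qp_forward(Q, q, A, b):
    """Solve min_x 0.5 x^T Q x + q^T x  s.t.  A x = b through its linear KKT system.

    KKT conditions: Q x + q + A^T nu = 0 and A x = b, i.e. K [x; nu] = [-q; b].
    """
    n, m = Q.shape[0], A.shape[0]
    K = np.block([[Q, A.T], [A, np.zeros((m, m))]])
    sol = np.linalg.solve(K, np.concatenate([-q, b]))
    return sol[:n], K

def qp_backward(K, n, grad_x):
    """Map d loss / d x to (d loss / d q, d loss / d b) with one solve against K^T.

    Differentiating K [x; nu] = [-q; b] gives K [dx; dnu] = [-dq; db].
    """
    adj = np.linalg.solve(K.T, np.concatenate([grad_x, np.zeros(K.shape[0] - n)]))
    return -adj[:n], adj[n:]

rng = np.random.default_rng(1)
n, m = 4, 2
M = rng.standard_normal((n, n))
Q = M @ M.T + np.eye(n)          # positive definite objective
q = rng.standard_normal(n)
A = rng.standard_normal((m, n))  # full row rank with probability one
b = rng.standard_normal(m)
target = rng.standard_normal(n)

x, K = qp_forward(Q, q, A, b)
dq, db = qp_backward(K, n, x - target)  # loss = 0.5 ||x - target||^2
```

Reusing the factorisable KKT matrix K in the backward pass is what makes this cheap: no solver iterations are differentiated.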

3 Unrolling algorithms

Turning iterations into layers.

See model based NN.
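A minimal sketch of unrolling, assuming a toy quadratic inner problem f(z) = ½zᵀAz − θᵀz: run a fixed number of gradient-descent steps as “layers” and backpropagate through them by hand. Each step z_{k+1} = (I − ηA)z_k + ηθ is linear, so the backward pass only needs the adjoint recursion; for a nonlinear update one would also have to store (or recompute) the iterates, which is the memory cost implicit differentiation avoids. With few steps the result is the exact gradient of the truncated computation, not of the true minimiser. The function name unrolled_grad is hypothetical.

```python
import numpy as np

def unrolled_grad(A, theta, target, steps, lr):
    """Gradient of 0.5 ||z_K - target||^2 w.r.t. theta, where z_K is the output of
    `steps` unrolled gradient-descent updates z <- z - lr * (A z - theta) from z_0 = 0."""
    z = np.zeros_like(theta)
    for _ in range(steps):                 # forward pass: the unrolled "layers"
        z = z - lr * (A @ z - theta)
    B = np.eye(len(theta)) - lr * A        # each layer is z_{k+1} = B z_k + lr * theta
    adj = z - target                       # d loss / d z_K
    grad = np.zeros_like(theta)
    for _ in range(steps):                 # backward pass, last layer first
        grad += lr * adj                   # theta enters every layer with weight lr
        adj = B.T @ adj                    # pull the adjoint back to the previous iterate
    return z, grad

rng = np.random.default_rng(2)
M = rng.standard_normal((4, 4))
A = M @ M.T + 4.0 * np.eye(4)
theta = rng.standard_normal(4)
target = rng.standard_normal(4)
lr = 1.0 / np.linalg.eigvalsh(A).max()

z10, g10 = unrolled_grad(A, theta, target, 10, lr)    # gradient of a 10-layer unroll
z500, g500 = unrolled_grad(A, theta, target, 500, lr)
g_implicit = np.linalg.solve(A.T, np.linalg.solve(A, theta) - target)
```

As the number of unrolled steps grows, g500 approaches the implicit gradient g_implicit, while g10 is the (different) gradient of a shallow truncated solver.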

4 Deep declarative networks

Figure 2

A different terminology, although AFAICT closely related technology, is used by Stephen Gould in Gould, Hartley, and Campbell (2019), under the banner of Deep Declarative Networks. Fun applications he highlights: robust losses in pooling layers, projection onto shapes, convex programming and warping, matching problems, (relaxed) graph alignment, noisy point-cloud surface reconstruction… (I am sitting in his seminar as I write this.) They have implemented a ddn library (PyTorch).

Following up on that presentation: learning basis decompositions, hyperparameter optimisation… Stephen relates both of these to deep declarative nets by casting them as “bi-level optimisation problems”. He also relates some minimax-like optimisations to “Stackelberg games”, leader-follower games in which one player optimises subject to the other's optimal response, i.e. an optimisation problem embedded in game theory.
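Hyperparameter optimisation is the textbook bi-level example. A minimal sketch, assuming ridge regression as the inner (lower-level) problem and validation squared error as the outer (upper-level) objective: the inner optimality condition Xᵀ(Xw − y) + λw = 0 defines w^⋆(λ) implicitly, and the implicit function theorem gives the hypergradient dL_val/dλ without differentiating through any solver iterations. The function names and data are made up for illustration.

```python
import numpy as np

def ridge_solution(X, y, lam):
    """Inner problem: argmin_w 0.5 ||X w - y||^2 + 0.5 * lam * ||w||^2 (closed form)."""
    H = X.T @ X + lam * np.eye(X.shape[1])
    return np.linalg.solve(H, X.T @ y), H

def hypergradient(X, y, X_val, y_val, lam):
    """Validation loss 0.5 ||X_val w*(lam) - y_val||^2 and its derivative in lam, via the IFT.

    Optimality: F(w, lam) = X^T (X w - y) + lam * w = 0,
    so dF/dw = H = X^T X + lam I and dF/dlam = w, giving dw*/dlam = -H^{-1} w*.
    """
    w, H = ridge_solution(X, y, lam)
    g = X_val.T @ (X_val @ w - y_val)   # gradient of the outer loss w.r.t. w
    v = np.linalg.solve(H.T, g)          # adjoint solve (H happens to be symmetric here)
    val_loss = 0.5 * np.sum((X_val @ w - y_val) ** 2)
    return val_loss, -v @ w

rng = np.random.default_rng(3)
w_true = rng.standard_normal(5)
X = rng.standard_normal((20, 5))
y = X @ w_true + 0.5 * rng.standard_normal(20)
X_val = rng.standard_normal((10, 5))
y_val = X_val @ w_true + 0.5 * rng.standard_normal(10)

loss, dloss_dlam = hypergradient(X, y, X_val, y_val, lam=0.5)
```

The hypergradient could then drive gradient descent on λ (often on log λ, to keep it positive).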

5 Deep equilibrium networks

Related: Deep equilibrium networks (Bai, Kolter, and Koltun 2019; Bai, Koltun, and Kolter 2020). Here we assume that the network has a single layer which is iterated, and we solve for a fixed point of that iterated layer; this turns out to be memory-efficient and, in fact, powerful (you need to scale up the width of that magic layer to make it match the effective depth of a non-iterative layer stack, but not by very much).

Example code: locuslab/deq.
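A minimal sketch of the equilibrium idea, assuming a single tanh layer z = tanh(Wz + Ux + b) with W rescaled to spectral norm below one so that plain fixed-point iteration converges. The backward pass never looks at the forward iterations: it solves one linear system with (I − J)ᵀ, where J is the layer's Jacobian at the fixed point, which is where the memory efficiency comes from. The DEQ papers and locuslab/deq use quasi-Newton (Broyden) root-finders rather than a dense Jacobian solve; the dense solve here only works at toy sizes, and all names are hypothetical.

```python
import numpy as np

def deq_forward(W, U, b, x, tol=1e-12, max_iter=1000):
    """Solve z = tanh(W z + U x + b) by plain fixed-point iteration from z = 0."""
    z = np.zeros(W.shape[0])
    for _ in range(max_iter):
        z_next = np.tanh(W @ z + U @ x + b)
        if np.linalg.norm(z_next - z) < tol:
            return z_next
        z = z_next
    return z

def deq_backward(W, U, z, grad_z):
    """Implicit gradients of a scalar loss w.r.t. x, b and W at the fixed point z."""
    d = 1.0 - z ** 2                    # tanh' at the fixed point, since z = tanh(pre)
    J = d[:, None] * W                  # Jacobian of the layer w.r.t. z
    u = np.linalg.solve((np.eye(len(z)) - J).T, grad_z)  # (I - J)^T u = grad_z
    du = d * u                          # adjoint pushed back through tanh'
    return U.T @ du, du, np.outer(du, z)

rng = np.random.default_rng(4)
n, p = 6, 3
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)         # spectral norm 0.5, so the map is a contraction
U = rng.standard_normal((n, p))
b = rng.standard_normal(n)
x = rng.standard_normal(p)
target = rng.standard_normal(n)

z = deq_forward(W, U, b, x)
gx, gb, gW = deq_backward(W, U, z, z - target)  # loss = 0.5 ||z - target||^2
```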

6 Deep Ritz method

As seen in NN-for-PDEs. Fits here, maybe? TBC (E, Han, and Jentzen 2017; E and Yu 2018; Müller and Zeinhofer 2020)
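For reference, the Deep Ritz method (E and Yu 2018) trains a network u_θ by minimising a variational energy rather than a PDE residual. For the Poisson problem −Δu = f on Ω with u = 0 on ∂Ω, the objective in its boundary-penalty form is (with penalty weight β, and the integrals estimated by Monte Carlo sampling)

\[ \min_{\theta}\; \int_{\Omega} \left( \tfrac{1}{2} \left| \nabla u_\theta(x) \right|^2 - f(x)\, u_\theta(x) \right) \mathrm{d}x \;+\; \beta \int_{\partial\Omega} u_\theta(s)^2 \, \mathrm{d}s \]

The PDE solution is thus characterised implicitly, as the minimiser of an energy whose Euler–Lagrange equation is −Δu = f, which is one plausible reason it sits near implicit layers.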

7 In practice

In general we are using autodiff to find the gradients of our systems. Writing custom gradients to exploit the efficiencies of implicit gradients: how do we do that in practice?

Overriding autodiff is surprisingly easy in JAX: Custom derivative rules for JAX-transformable Python functions, including implicit functions. Blondel et al. (2021) adds some extra conveniences in the form of google/jaxopt: Hardware accelerated, batchable and differentiable optimizers in JAX.


  • Hardware accelerated: our implementations run on GPU and TPU, in addition to CPU.
  • Batchable: multiple instances of the same optimization problem can be automatically vectorized using JAX’s vmap.
  • Differentiable: optimization problem solutions can be differentiated with respect to their inputs either implicitly or via autodiff of unrolled algorithm iterations.
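The pattern these tools support can be sketched without any framework: a forward solver plus a hand-written backward rule that is itself a fixed-point iteration for the adjoint, u = ḡ + (∂f/∂z)ᵀu, and needs only vector-Jacobian products of f. In JAX one would register such a pair with jax.custom_vjp (the JAX custom-derivatives tutorial works through a fixed-point example along these lines); the NumPy function names below are hypothetical, and the Jacobian products are written by hand for a toy elementwise map f(z, θ) = 0.5 cos z + θ.

```python
import numpy as np

def fixed_point(f, theta, z0, tol=1e-12, max_iter=10_000):
    """Forward pass: iterate z <- f(z, theta) until the update stalls."""
    z = z0
    for _ in range(max_iter):
        z_next = f(z, theta)
        if np.max(np.abs(z_next - z)) < tol:
            return z_next
        z = z_next
    return z

def fixed_point_vjp(vjp_z, vjp_theta, z_star, theta, g, tol=1e-12, max_iter=10_000):
    """Backward pass: given g = d loss / d z*, return d loss / d theta.

    Solves the adjoint equation u = g + (df/dz)^T u by iteration, then maps u
    through (df/dtheta)^T. Only vector-Jacobian products of f are used.
    """
    u = g
    for _ in range(max_iter):
        u_next = g + vjp_z(z_star, theta, u)
        if np.max(np.abs(u_next - u)) < tol:
            u = u_next
            break
        u = u_next
    return vjp_theta(z_star, theta, u)

# Toy layer f(z, theta) = 0.5 cos(z) + theta (elementwise); a contraction in z.
def f(z, theta):
    return 0.5 * np.cos(z) + theta

def vjp_z(z, theta, u):
    return -0.5 * np.sin(z) * u  # df/dz is diagonal with entries -0.5 sin(z)

def vjp_theta(z, theta, u):
    return u                     # df/dtheta is the identity

theta = np.array([0.0, 1.0, -2.0, 0.3])
z_star = fixed_point(f, theta, np.zeros_like(theta))
grad = fixed_point_vjp(vjp_z, vjp_theta, z_star, theta, np.ones_like(theta))  # loss = sum(z*)
```

Because the backward rule only consumes vector-Jacobian products, the same structure carries over when f is a full network layer and the products come from an autodiff system.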

Julia autodiff also allows convenient overrides, and in fact the community discourse around them is full of helpful tips.

8 Incoming

9 References

Ablin, Peyré, and Moreau. 2020. “Super-Efficiency of Automatic Differentiation for Functions Defined as a Minimum.” In Proceedings of the 37th International Conference on Machine Learning.
Adler, and Öktem. 2018. “Learned Primal-Dual Reconstruction.” IEEE Transactions on Medical Imaging.
Agrawal, Amos, Barratt, et al. 2019. “Differentiable Convex Optimization Layers.” In Advances In Neural Information Processing Systems.
Amos, and Kolter. 2017. “OptNet: Differentiable Optimization as a Layer in Neural Networks.”
Amos, Rodriguez, Sacks, et al. 2018. “Differentiable MPC for End-to-End Planning and Control.”
Andersson, Gillis, Horn, et al. 2019. “CasADi: A Software Framework for Nonlinear Optimization and Optimal Control.” Mathematical Programming Computation.
Arora, Ge, Ma, et al. 2015. “Simple, Efficient, and Neural Algorithms for Sparse Coding.” In Proceedings of The 28th Conference on Learning Theory.
Bai, Kolter, and Koltun. 2019. “Deep Equilibrium Models.” In Advances in Neural Information Processing Systems.
Bai, Koltun, and Kolter. 2020. “Multiscale Deep Equilibrium Models.” In Advances in Neural Information Processing Systems.
———. 2021. “Stabilizing Equilibrium Models by Jacobian Regularization.” arXiv:2106.14342 [Cs, Stat].
Banert, Rudzusika, Öktem, et al. 2021. “Accelerated Forward-Backward Optimization Using Deep Learning.” arXiv:2105.05210 [Math].
Barratt. 2018. “On the Differentiability of the Solution to Convex Optimization Problems.”
Blondel, Berthet, Cuturi, et al. 2021. “Efficient and Modular Implicit Differentiation.” arXiv:2105.15183 [Cs, Math, Stat].
Border. 2019. “Notes on the Implicit Function Theorem.”
Borgerding, and Schniter. 2016. “Onsager-Corrected Deep Networks for Sparse Linear Inverse Problems.” arXiv:1612.01183 [Cs, Math].
Djolonga, and Krause. 2017. “Differentiable Learning of Submodular Models.” In Proceedings of the 31st International Conference on Neural Information Processing Systems. NIPS’17.
Domke. 2012. “Generic Methods for Optimization-Based Modeling.” In International Conference on Artificial Intelligence and Statistics.
Donti, Amos, and Kolter. 2017. “Task-Based End-to-End Model Learning in Stochastic Optimization.”
E, Han, and Jentzen. 2017. “Deep Learning-Based Numerical Methods for High-Dimensional Parabolic Partial Differential Equations and Backward Stochastic Differential Equations.” Communications in Mathematics and Statistics.
E, and Yu. 2018. “The Deep Ritz Method: A Deep Learning-Based Numerical Algorithm for Solving Variational Problems.” Communications in Mathematics and Statistics.
Gould, Fernando, Cherian, et al. 2016. “On Differentiating Parameterized Argmin and Argmax Problems with Application to Bi-Level Optimization.”
Gould, Hartley, and Campbell. 2019. “Deep Declarative Networks: A New Hope.”
Granas, and Dugundji. 2003. Fixed Point Theory. Springer Monographs in Mathematics.
Gregor, and LeCun. 2010. “Learning fast approximations of sparse coding.” In Proceedings of the 27th International Conference on Machine Learning (ICML-10).
———. 2011. “Efficient Learning of Sparse Invariant Representations.” arXiv:1105.5307 [Cs].
Haber, and Ruthotto. 2018. “Stable Architectures for Deep Neural Networks.” Inverse Problems.
Huang, Bai, and Kolter. 2021. “(Implicit)\(^2\): Implicit Layers for Implicit Representations.” In.
Krantz, and Parks. 2002. The Implicit Function Theorem.
Landry, Lorenzetti, Manchester, et al. 2019. “Bilevel Optimization for Planning Through Contact: A Semidirect Method.”
Lee, Maji, Ravichandran, et al. 2019. “Meta-Learning with Differentiable Convex Optimization.”
Ma, Han, Liu, et al. 2021. “Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces.” In.
Mena, Belanger, Linderman, et al. 2018. “Learning Latent Permutations with Gumbel-Sinkhorn Networks.”
Müller, and Zeinhofer. 2020. “Deep Ritz Revisited.”
Poli, Massaroli, Yamashita, et al. 2020. “Hypersolvers: Toward Fast Continuous-Depth Models.” In Advances in Neural Information Processing Systems.
Rajeswaran, Finn, Kakade, et al. 2019. “Meta-Learning with Implicit Gradients.”
Sulam, Aberdam, Beck, et al. 2020. “On Multi-Layer Basis Pursuit, Efficient Algorithms and Convolutional Neural Networks.” IEEE Transactions on Pattern Analysis and Machine Intelligence.
Takikawa, Litalien, Yin, et al. 2021. “Neural Geometric Level of Detail: Real-Time Rendering with Implicit 3D Shapes.”
Wang, Donti, Wilder, et al. 2019. “SATNet: Bridging Deep Learning and Logical Reasoning Using a Differentiable Satisfiability Solver.”
Zhu, Peng, Larsson, et al. 2022. “NICE-SLAM: Neural Implicit Scalable Encoding for SLAM.”