# Optimal transport inference

I feel the earth mover under my feet, I feel the ψ tumbling down, I feel my heart start to trembling, whenever you’re around (my barycentre)

March 16, 2021 — May 3, 2023

Doing inference where the probability metric measuring the discrepancy between some target distribution and the implied inferential distribution is an optimal-transport one. Frequently intractable, but neat when we can get it. Sometimes we might get there by estimating an actual OT loss (or its gradients), or even the transport maps implying that loss.
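As a toy illustration of what inference with an OT loss can look like, here is a minimum-Wasserstein estimator for a 1D location model. This is my own construction, not something taken from the works cited on this page. It assumes equal sample sizes for the data and the model draws, so that 1D OT reduces to matching sorted samples and the optimal location has a closed form:

```python
import numpy as np

def w2_squared_1d(x, y):
    """Squared 2-Wasserstein distance between two equal-size 1D empirical measures.

    In 1D the optimal coupling pairs order statistics, so sorting suffices.
    """
    return np.mean((np.sort(x) - np.sort(y)) ** 2)

def fit_location_min_w2(data, z):
    """Minimum-W2 estimate of theta in the toy model X = theta + Z.

    Adding theta to z does not change its sort order, so
    W2^2(data, theta + z) = mean((data_(i) - z_(i) - theta)^2),
    which is minimised in closed form at theta = mean(data_(i) - z_(i)).
    """
    return np.mean(np.sort(data) - np.sort(z))

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=1.0, size=2000)  # observations; true location is 3.0
z = rng.normal(size=2000)                          # fixed base-noise draws for the model
theta_hat = fit_location_min_w2(data, z)
print(theta_hat, w2_squared_1d(data, theta_hat + z))
```

In higher dimensions or for richer models there is no such closed form, which is where the approximations below come in.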

Placeholder/grab bag.

TODO: should we break this into discrete-state and continuous-state cases? Machinery looks different.

## 1 NNs

Wasserstein GANs and OT-GANs (Salimans et al. 2018) are argued to perform approximate optimal transport inference, albeit indirectly.
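For WGANs the usual link runs through Kantorovich-Rubinstein duality, which expresses the 1-Wasserstein distance as a supremum over 1-Lipschitz "critic" functions $f$:

$$
W_1(P, Q) = \sup_{\|f\|_{\mathrm{Lip}} \le 1} \; \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{y \sim Q}[f(y)].
$$

A WGAN swaps the supremum for a maximisation over a neural critic $f_w$ (kept roughly Lipschitz by weight clipping or a gradient penalty) and trains a generator $g_\theta$ against it:

$$
\min_{\theta} \; \max_{w} \; \mathbb{E}_{x \sim P_{\text{data}}}[f_w(x)] - \mathbb{E}_{z \sim p(z)}[f_w(g_\theta(z))].
$$

Since the critic class only approximates the set of 1-Lipschitz functions, the quantity being minimised is at best an approximation to $W_1$, which is one sense in which this OT inference is approximate and indirect.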

## 2 Surprise connection to matrix factorisation

Non-negative matrix factorisation via OT is a thing, e.g. in topic modeling (Huynh, Zhao, and Phung 2020; Zhao et al. 2020).

## 3 Via Fisher distance

See e.g. (J. H. Huggins et al. 2018b, 2018a) for a particular Bayes posterior approximation using OT and Fisher distance.

## 4 Minibatched

Daniel Daza, in *Approximating Wasserstein distances with PyTorch*, touches upon Fatras et al. (2020):

Optimal transport distances are powerful tools to compare probability distributions and have found many applications in machine learning. Yet their algorithmic complexity prevents their direct use on large scale datasets. To overcome this challenge, practitioners compute these distances on minibatches i.e., they average the outcome of several smaller optimal transport problems. We propose in this paper an analysis of this practice, which effects are not well understood so far. We notably argue that it is equivalent to an implicit regularization of the original problem, with appealing properties such as unbiased estimators, gradients and a concentration bound around the expectation, but also with defects such as loss of distance property.
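To see the "loss of distance property" concretely, here is a minimal NumPy/SciPy sketch. It is my own construction rather than code from Fatras et al. or Daza's post. It assumes equal-size, uniformly weighted point clouds, so that exact OT reduces to a linear assignment problem solvable with `scipy.optimize.linear_sum_assignment`. The full-sample cost of a cloud against itself is zero, but the minibatch average is not:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def exact_w2_squared(x, y):
    """Exact squared W2 between equal-size, uniformly weighted point clouds.

    With uniform weights and equal sizes an optimal coupling can be taken to be a
    permutation, so a linear assignment solver gives the exact OT cost.
    """
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean()

def minibatch_w2_squared(x, y, batch_size, n_batches, rng):
    """Average the exact OT cost over independently drawn minibatch pairs."""
    costs = []
    for _ in range(n_batches):
        i = rng.choice(len(x), size=batch_size, replace=False)
        j = rng.choice(len(y), size=batch_size, replace=False)
        costs.append(exact_w2_squared(x[i], y[j]))
    return float(np.mean(costs))

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2))
full = exact_w2_squared(x, x)  # 0.0: the identity matching is optimal
mb = minibatch_w2_squared(x, x, batch_size=32, n_batches=50, rng=rng)
print(full, mb)  # mb is strictly positive although the two measures coincide
```

The minibatch average is cheap (many small assignment problems instead of one big one), but it estimates a different, regularised quantity, which is the trade-off the paper analyses.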

## 5 Linearized embedding

Noted in Bai et al. (2023) via Cheng-Soon Ong:

Comparing K (probability) measures requires the pairwise calculation of transport-based distances, which, despite the significant recent computational speed-ups, remains to be relatively expensive. To address this problem, W. Wang et al. (2013) proposed the Linear Optimal Transport (LOT) framework, which linearizes the 2-Wasserstein distance utilizing its weak Riemannian structure. In short, the probability measures are embedded into the tangent space at a fixed reference measure (e.g., the measures’ Wasserstein barycenter) through a logarithmic map. The Euclidean distances between the embedded measures then approximate the 2-Wasserstein distance between the probability measures. The LOT framework is computationally attractive as it only requires the computation of one optimal transport problem per input measure, reducing the otherwise quadratic cost to linear. Moreover, the framework provides theoretical guarantees on convexifying certain sets of probability measures […], which is critical in supervised and unsupervised learning from sets of probability measures.
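A minimal sketch of the LOT idea for discrete measures, under simplifying assumptions of mine rather than the full construction of W. Wang et al.: every measure is a uniformly weighted cloud of the same size n as the reference, so each optimal plan is a permutation and gives a discrete Monge map T. The embedding is the displacement T(r) − r, and Euclidean distances between embeddings stand in for W2. The reference here is just another Gaussian sample, a hypothetical stand-in for the barycentre:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def sq_cost(x, y):
    """Pairwise squared Euclidean costs between rows of x and rows of y."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)

def exact_w2(x, y):
    """Exact W2 between equal-size uniform clouds (OT reduces to an assignment)."""
    cost = sq_cost(x, y)
    rows, cols = linear_sum_assignment(cost)
    return np.sqrt(cost[rows, cols].mean())

def lot_embed(reference, target):
    """LOT embedding of `target`: displacement T(r) - r of the discrete Monge map.

    linear_sum_assignment returns rows in order 0..n-1 for a square cost, so
    target[cols] lists T(reference[k]) for each k. Scaling by 1/sqrt(n) makes the
    Euclidean norm match the L2(reference) norm for uniform weights.
    """
    _, cols = linear_sum_assignment(sq_cost(reference, target))
    return (target[cols] - reference).ravel() / np.sqrt(len(reference))

rng = np.random.default_rng(1)
n = 200
reference = rng.normal(size=(n, 2))  # hypothetical reference measure
measures = [rng.normal(loc=m, size=(n, 2)) for m in (0.0, 1.0, 3.0)]
# One OT solve per measure, then only cheap Euclidean distances between embeddings.
embeddings = np.stack([lot_embed(reference, mu) for mu in measures])
lot_d = np.linalg.norm(embeddings[0] - embeddings[2])
w2_d = exact_w2(measures[0], measures[2])
print(lot_d, w2_d)
```

Because the two maps T_μ and T_ν jointly define a valid coupling between μ and ν, the LOT distance never under-estimates W2 in this setting, while the distance from any measure to the reference itself is exact.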

## 6 Tools

### 6.1 OTT

Optimal Transport Tools (OTT) (Cuturi et al. 2022), a toolbox for all things Wasserstein:

The goal of OTT is to provide sturdy, versatile and efficient optimal transport solvers, taking advantage of JAX features, such as JIT, auto-vectorization and implicit differentiation.

A typical OT problem has two ingredients: a pair of weight vectors `a` and `b` (one for each measure), with a ground cost matrix that is either directly given, or derived as the pairwise evaluation of a cost function on pairs of points taken from two measures. The main design choice in OTT comes from encapsulating the cost in a `Geometry` object, and [bundling] it with a few useful operations (notably kernel applications). The most common geometry is that of two clouds of vectors compared with the squared Euclidean distance, as illustrated in the example below:

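```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Sample two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Wrap the (default squared-Euclidean) cost between the clouds in a Geometry,
# pose the linear OT problem, and compute the coupling with Sinkhorn.
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a=a, b=b)
out = sinkhorn.Sinkhorn()(prob)
P = out.matrix  # the n x m transport (coupling) matrix
```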

The call to `sinkhorn` above works out the optimal transport solution by storing its output. The transport matrix can be instantiated using those optimal solutions and the `Geometry` again. That transport matrix links each point from the first point cloud to one or more points from the second. To be more precise, the `sinkhorn` algorithm operates on the `Geometry`, taking into account weights `a` and `b`, to solve the OT problem, producing a named tuple that contains two optimal dual potentials `f` and `g` (vectors of the same size as `a` and `b`), the objective `reg_ot_cost`, a log of the `errors` of the algorithm as it converges, and a `converged` flag.
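To make the dual potentials concrete, here is a plain NumPy sketch of log-domain Sinkhorn. It is not OTT's implementation. It assumes the convention P_ij = a_i b_j exp((f_i + g_j − C_ij)/ε), i.e. entropic regularisation relative to the product measure a bᵀ; OTT's internal normalisation of `f`, `g` and `reg_ot_cost` may differ. Each half-step makes one marginal of P exact given the other potential:

```python
import numpy as np
from scipy.special import logsumexp

def sinkhorn_log(a, b, C, eps=1.0, n_iters=1000):
    """Log-domain Sinkhorn returning dual potentials f, g and the coupling P.

    Assumed convention: P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / eps).
    """
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        # Choose f so that the row sums of P equal a (given the current g) ...
        f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        # ... then g so that the column sums equal b (given the new f).
        g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    P = np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / eps)
    return f, g, P

# Same shapes as the OTT example above, but sampled with NumPy.
rng = np.random.default_rng(0)
n, m, d = 12, 14, 2
x = rng.normal(size=(n, d)) + 1
y = rng.uniform(size=(m, d))
a = rng.uniform(size=n); a /= a.sum()
b = rng.uniform(size=m); b /= b.sum()
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean ground cost
f, g, P = sinkhorn_log(a, b, C)
reg_cost = a @ f + b @ g  # dual objective under this convention
```

As in OTT, `f` has the shape of `a` and `g` the shape of `b`, and the transport matrix is rebuilt from the potentials and the cost rather than stored during the iterations.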

### 6.2 POT

POT: Python Optimal Transport (Rémi Flamary et al. 2021)

This open source Python library provides several solvers for optimization problems related to Optimal Transport for signal, image processing and machine learning.

Website and documentation: https://PythonOT.github.io/

Source Code (MIT): https://github.com/PythonOT/POT

POT provides the following generic OT solvers:

- OT Network Simplex solver for the linear program / Earth Mover's Distance.
- Conditional gradient and Generalized conditional gradient for regularized OT.
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm, stabilized version, greedy Sinkhorn and Screening Sinkhorn.
- Bregman projections for Wasserstein barycenter, convolutional barycenter and unmixing.
- Sinkhorn divergence and entropic regularization OT from empirical data.
- Debiased Sinkhorn barycenters (Sinkhorn divergence barycenter).
- Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations.
- Weak OT solver between empirical distributions.
- Non regularized Wasserstein barycenters with LP solver (only small scale).
- Gromov-Wasserstein distances and GW barycenters (exact and regularized), differentiable using gradients from Graph Dictionary Learning.
- Fused-Gromov-Wasserstein distances solver and FGW barycenters.
- Stochastic solver and differentiable losses for Large-scale Optimal Transport (semi-dual problem and dual problem).
- Sampled solver of Gromov Wasserstein for large-scale problem with any loss functions.
- Non regularized free support Wasserstein barycenters.
- One dimensional Unbalanced OT with KL relaxation and barycenter. Also exact unbalanced OT with KL and quadratic regularization and the regularization path of UOT.
- Partial Wasserstein and Gromov-Wasserstein (exact and entropic formulations).
- Sliced Wasserstein and Max-sliced Wasserstein that can be used for gradient flows.
- Graph Dictionary Learning solvers.
- Several backends for easy use of POT with PyTorch/JAX/NumPy/CuPy/TensorFlow arrays.

POT provides the following Machine Learning related solvers:

- Optimal transport for domain adaptation with group lasso regularization, Laplacian regularization and semi supervised setting.
- Linear OT mapping and Joint OT mapping estimation.
- Wasserstein Discriminant Analysis (requires autograd + pymanopt).
- JCPOT algorithm for multi-source domain adaptation with target shift.

Some other examples are available in the documentation.
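POT ships its own sliced Wasserstein functions; the following is an independent NumPy sketch of the idea, not POT's code, assuming equal-size, uniformly weighted point clouds. Each random direction reduces the problem to 1D OT, which only needs a sort:

```python
import numpy as np

def sliced_wasserstein(x, y, n_projections=200, p=2, rng=None):
    """Monte Carlo sliced p-Wasserstein distance between equal-size uniform point clouds."""
    rng = np.random.default_rng() if rng is None else rng
    theta = rng.normal(size=(n_projections, x.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # uniform random unit directions
    # Project onto each direction; in 1D, OT pairs sorted samples.
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    # Average |difference|^p over samples and directions, then take the p-th root.
    return np.mean(np.abs(x_proj - y_proj) ** p) ** (1 / p)

rng = np.random.default_rng(0)
x = rng.normal(size=(300, 3))
sw_same = sliced_wasserstein(x, x, rng=rng)
sw_shift = sliced_wasserstein(x, x + 1.0, rng=rng)  # translation by v = (1, 1, 1)
print(sw_same, sw_shift)
```

For a pure translation by v every slice is shifted by θ·v, so SW₂² = E[(θ·v)²] = ‖v‖²/d; with v = (1, 1, 1) in d = 3 the estimate should land near 1. The cost per slice is a sort, which is why sliced distances scale so well.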

### 6.3 GeomLoss

The GeomLoss library provides efficient GPU implementations for:

- Kernel norms (also known as Maximum Mean Discrepancies).
- Hausdorff divergences, which are positive definite generalizations of the Chamfer-ICP loss and are analogous to log-likelihoods of Gaussian Mixture Models.
- Debiased Sinkhorn divergences, which are affordable yet positive and definite approximations of Optimal Transport (Wasserstein) distances.

It is hosted on GitHub and distributed under the permissive MIT license.

GeomLoss functions are available through the custom PyTorch layers `SamplesLoss`, `ImagesLoss` and `VolumesLoss`, which allow you to work with weighted point clouds (of any dimension), density maps and volumetric segmentation masks.
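The debiased Sinkhorn divergence is S_ε(α, β) = OT_ε(α, β) − ½ OT_ε(α, α) − ½ OT_ε(β, β). Below is a small NumPy sketch of that formula, not GeomLoss's GPU implementation; the cost ½‖x − y‖², the regularisation strength `eps` and the fixed iteration count are my assumptions:

```python
import numpy as np
from scipy.special import logsumexp

def entropic_ot(a, x, b, y, eps=1.0, n_iters=1000):
    """Entropic OT value <a, f> + <b, g> via log-domain Sinkhorn, cost 0.5*|x - y|^2."""
    C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iters):
        f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    return a @ f + b @ g

def sinkhorn_divergence(a, x, b, y, eps=1.0):
    """Debiased Sinkhorn divergence S = OT(a, b) - OT(a, a)/2 - OT(b, b)/2."""
    return (entropic_ot(a, x, b, y, eps)
            - 0.5 * entropic_ot(a, x, a, x, eps)
            - 0.5 * entropic_ot(b, y, b, y, eps))

rng = np.random.default_rng(0)
n = 50
x = rng.normal(size=(n, 2))
w = np.full(n, 1.0 / n)                          # uniform weights
s_same = sinkhorn_divergence(w, x, w, x)         # debiasing makes this zero
s_shift = sinkhorn_divergence(w, x, w, x + 2.0)  # translation by v = (2, 2)
print(s_same, s_shift)
```

Subtracting the self-transport terms removes the entropic bias, so S(α, α) = 0 even though OT_ε(α, α) itself is not zero. For a translation by v, the cross term in the quadratic cost is separable and cancels, leaving S = ½‖v‖², so the shifted example should come out near 4.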

## 7 Incoming

Rigollet and Weed (2018):

We give a statistical interpretation of entropic optimal transport by showing that performing maximum-likelihood estimation for Gaussian deconvolution corresponds to calculating a projection with respect to the entropic optimal transport distance.

Thomas Viehmann's *An efficient implementation of the Sinkhorn algorithm for the GPU* is a PyTorch CUDA extension (Viehmann 2019).

Marco Cuturi’s course notes on OT include a 400-page slide deck.

## 8 References

*SIAM Journal on Mathematical Analysis*.

*Advances in Neural Information Processing Systems*.

*Proceedings of the 32nd International Conference on Neural Information Processing Systems*. NIPS’18.

*Gradient Flows: In Metric Spaces and in the Space of Probability Measures*. Lectures in Mathematics. ETH Zürich.

*SIAM Journal on Mathematical Analysis*.

*International Conference on Machine Learning*.

*arXiv:1703.00573 [Cs]*.

*arXiv:1805.00753 [Stat]*.

*Acta Numerica*.

*arXiv:1412.5154 [Math]*.

*IFAC Proceedings Volumes*, 19th IFAC World Congress.

*arXiv:1802.04885 [Stat]*.

*arXiv:1810.07717 [Cs]*.

*arXiv:1610.05627 [Math, Stat]*.

*arXiv:1906.01614 [Math, Stat]*.

*AISTATS 2018*.

*Electronic Journal of Probability*.

*arXiv:1209.1077 [Cs, Stat]*.

*arXiv:1607.05816 [Math]*.

*ICML*.

*arXiv:2102.07850 [Cs, Stat]*.

*arXiv:1507.00504 [Cs]*.

*Advances in Neural Information Processing Systems 26*.

*International Conference on Machine Learning*.

*Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics*.

*von Mises calculus for statistical functionals*. Lecture Notes in Statistics 19.

*Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics*.

*Journal of Machine Learning Research*.

*Machine Learning*.

*Advances in Neural Information Processing Systems 28*.

*SIAM Journal on Applied Dynamical Systems*.

*Advances in Neural Information Processing Systems 29*.

*arXiv:1706.00292 [Stat]*.

*Advances in Neural Information Processing Systems 27*. NIPS’14.

*arXiv:1003.3852 [Math]*.

*arXiv:1704.00028 [Cs, Stat]*.

*arXiv:1705.07164 [Cs, Stat]*.

*Advances in Neural Information Processing Systems 30*.

*arXiv:1806.10234 [Cs, Stat]*.

*arXiv:1809.09505 [Cs, Math, Stat]*.

*arXiv:1910.04102 [Cs, Math, Stat]*.

*Advances in Neural Information Processing Systems*.

*Information Geometry*.

*Discrete & Continuous Dynamical Systems - A*.

*International Conference on Machine Learning*.

*Advances In Neural Information Processing Systems*.

*PMLR*.

*arXiv:1906.03317 [Cs, Math, Stat]*.

*Information Geometry*.

*Handbook of Uncertainty Quantification*.

*SIAM/ASA Journal on Uncertainty Quantification*.

*Mathematical Programming*.

*Advances in Neural Information Processing Systems 29*.

*Annual Review of Statistics and Its Application*.

*Computational Optimal Transport*.

*International Conference on Machine Learning*.

*Proceedings of the Twenty-Fifth International Joint Conference on Artificial Intelligence*. IJCAI’16.

*The 22nd International Conference on Artificial Intelligence and Statistics*.

*International Conference on Machine Learning*. ICML’15.

*Stat*.

*Geophysical Journal International*.

*Optimal Transport for Applied Mathematicians*. Edited by Filippo Santambrogio. Progress in Nonlinear Differential Equations and Their Applications.

*Proceedings of the 38th International Conference on Machine Learning*.

*arXiv:1610.06519 [Cs, Math]*.

*SIAM Journal on Imaging Sciences*.

*ACM Transactions on Graphics*.

*Journal of Machine Learning Research*.

*IEEE Transactions on Automatic Control*.

*UAI18*.

*Electronic Journal of Statistics*.

*International Journal of Computer Vision*.

*Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)*.

*Proceedings of NeurIPS 2020*.

*IEEE Transactions on Information Theory*.