Learning under distribution shift
Also transfer learning, covariate shift, transferable learning, domain adaptation, transportability, etc.
October 17, 2020 — October 24, 2023
Predictive models can be trained on independent and identically distributed data without much fuss. Sometimes our data is not identically distributed but is drawn from several different distributions. Say I am training a model which predicts customer behaviour, and I have customers in Australia and customers in India. Can I nonetheless train a model which works well on all of the data?
If we are using a parametric hierarchical model, we can pool data in the usual way and learn interaction effects.
If we are doing Neural Network Stuff, though, it is not really clear how to do that. We might be vexed, and then surprised, and then write an article about it. If we are a typical researcher, that article might be blind to prior art in statistics; e.g. Google AI Blog: How Underspecification Presents Challenges for Machine Learning, or Sebastian Ruder’s NN-style introduction to “transfer learning”.
I hope I don’t sound (too) snarky; there can be virtue in reinventing things with fresh eyes. Transfer learning, domain adaptation and such are all concepts that arise in the NN framing; sometimes the methods overlap with statistical classics and sometimes they extend the repertoire.
Here I will investigate as many of them as I have time for.
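For the simplest version of that pooling, here is a minimal sketch of partial pooling in a normal hierarchical model, assuming \(y_{ij} \sim \mathcal{N}(\theta_j, \sigma^2)\) and \(\theta_j \sim \mathcal{N}(\mu, \tau^2)\) with \(\mu\), \(\tau\), \(\sigma\) fixed by hand (a full treatment would put priors on them and sample, e.g. with HMC). The function name and toy numbers are hypothetical. Each group's estimate is a precision-weighted compromise between its own sample mean and the shared mean, so groups with little data borrow more strength from the others.

```python
def partially_pooled_means(groups, mu, tau, sigma):
    """Posterior mean of each group mean theta_j under the model
    y_ij ~ Normal(theta_j, sigma^2),  theta_j ~ Normal(mu, tau^2),
    with mu, tau and sigma treated as known (hypothetical, hand-set values)."""
    prior_precision = 1.0 / tau ** 2
    estimates = {}
    for name, ys in groups.items():
        n = len(ys)
        if n == 0:
            # No data for this group: fall back to the shared prior mean.
            estimates[name] = mu
            continue
        ybar = sum(ys) / n
        data_precision = n / sigma ** 2  # precision of the group's sample mean
        estimates[name] = (data_precision * ybar + prior_precision * mu) / (
            data_precision + prior_precision
        )
    return estimates


# Many Australian customers, few Indian ones (toy numbers).
groups = {"AU": [10.0] * 100, "IN": [20.0] * 2}
est = partially_pooled_means(groups, mu=15.0, tau=1.0, sigma=2.0)
print(est)  # AU stays near 10; IN is pulled strongly towards 15
```

Setting \(\tau\) very large recovers separate per-group means (no pooling); setting it near zero collapses every group onto \(\mu\) (complete pooling).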
1 What is transfer learning or domain adaptation actually?
Everyone I talk to seems to have a different notion, and also to think that their idea is canonical.
We need a taxonomy. How about this one? In thuml/A-Roadmap-for-Transfer-Learning, Junguang Jiang, Yang Shu, Jianmin Wang and Mingsheng Long propose the following taxonomy of transfer methods (Jiang et al. 2022):

 Meta-Learning (see my meta-learning page)
 Causal Learning (see my causal learning page)
They handball to zhaoxin94/awesome-domain-adaptation for a finer domain adaptation taxonomy.
One survey paper not enough? Want a better taxonomy? Here are survey papers harvested from the above links:
(Csurka 2017; Gulrajani and Lopez-Paz 2020; Jiang et al. 2022; Kouw and Loog 2019; Ouali, Hudelot, and Tami 2020; Pan and Yang 2010; Patel et al. 2015; Sun, Shi, and Wu 2015; Tan et al. 2018; M. Wang and Deng 2018; Wilson and Cook 2020; Yuchen Zhang et al. 2019; J. Zhang et al. 2019; L. Zhang and Gao 2020; S. Zhao et al. 2020; Zhuang et al. 2020).
Transfer learning also connects to semi-supervised learning and fairness, argue Schölkopf et al. (2012) and Schölkopf (2022).
2 Generic theory
(Bareinboim and Pearl 2016, 2013, 2014; Ben-David et al. 2010; Kaddour et al. 2022; Kulinski and Inouye 2022; Mansour, Mohri, and Rostamizadeh 2009; Pearl and Bareinboim 2014; Schölkopf et al. 2012, 2021; Schölkopf 2022; Subbaswamy, Schulam, and Saria 2019; Zellinger, Moser, and Saminger-Platz 2021; Yuchen Zhang et al. 2019; J. Wang and Chen 2023; Yang et al. 2020)
3 Graphical models
To my mind the most straightforward approach: simply do causal inference in a hierarchical model which encodes all the causal constraints. All the tools of graphical modeling stuff are still well-posed. It is easy to explain in a Bayesian framework in particular. I think this is what Elias Bareinboim’s data fusion framing refers to (Bareinboim and Pearl 2016, 2013, 2014; Pearl and Bareinboim 2014). In this case we can use standard statistical tooling, such as HMC, to sample from some posterior under various interventions, e.g. a shift in some parameter of the population distribution.
The hairy part is that this breaks down for neural networks. There is a million-dimensional nuisance parameter that we need to integrate out, i.e. the network weights. For reasons of size alone that is frequently impractical, with the computational cost blowing out.
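Running HMC on a full model is beyond a short example, but the core move of the transportability framing fits in a discrete toy. This is a minimal sketch of the transport formula of Pearl and Bareinboim, under the assumption that the source and target populations differ only in the distribution of a pre-treatment covariate \(Z\) (here a hypothetical age band), and that the \(z\)-specific causal effects are the same in both. Then \(P^*(y \mid do(x)) = \sum_z P(y \mid do(x), z)\, P^*(z)\). All names and numbers are made up for illustration.

```python
def transport(p_y_given_do_x_z, p_star_z):
    """Transport formula P*(y | do(x)) = sum_z P(y | do(x), z) * P*(z).

    p_y_given_do_x_z: z -> P(y=1 | do(x), z), estimated in the source population
                      (e.g. from a randomised experiment there).
    p_star_z:         z -> P*(z), the covariate distribution in the target population.
    Only valid if the populations differ solely in the distribution of the
    pre-treatment covariate Z (an assumption, not something this code checks).
    """
    if abs(sum(p_star_z.values()) - 1.0) > 1e-9:
        raise ValueError("target distribution of Z must sum to 1")
    missing = set(p_star_z) - set(p_y_given_do_x_z)
    if missing:
        raise ValueError(f"no source estimate for strata {sorted(missing)}")
    return sum(p_y_given_do_x_z[z] * p for z, p in p_star_z.items())


# Hypothetical z-specific effects measured in the source population...
effect_by_age = {"young": 0.3, "old": 0.6}
# ...transported to two target populations with different age mixes.
print(transport(effect_by_age, {"young": 0.5, "old": 0.5}))  # about 0.45
print(transport(effect_by_age, {"young": 0.8, "old": 0.2}))  # about 0.36
```

In the Bayesian version one would propagate posterior uncertainty in both the \(z\)-specific effects and \(P^*(z)\) rather than plugging in point estimates.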
Some other works that look related: (Gong et al. 2018; Moraffah et al. 2019; Yue et al. 2021; Xu, Wang, and Ni 2022; Rothenhäusler et al. 2020).
A graphical model approach has many things to recommend it if it works, though: we do not need to worry about missing values (they can be inferred too); we can estimate interventional distributions, etc.
4 Pretraining
The LLM approach. Out of scope for my current investigation, but very much in the news.
5 Sample weighting
If the proportions of the various subpopulations have changed, we can reweight the data by stratum, as in stratified sampling, to estimate the quantity of interest over the new population.
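A minimal sketch of this reweighting, assuming the stratum label (e.g. country) is observed for every sample and that behaviour within each stratum is the same in the old and new populations. Each source sample gets weight \(P_\text{target}(s) / P_\text{source}(s)\) for its stratum \(s\); with self-normalised weights this reduces to the post-stratified estimate \(\sum_s P_\text{target}(s)\, \bar{y}_s\). The function name and toy data are hypothetical.

```python
from collections import Counter


def reweighted_mean(strata, outcomes, target_props):
    """Self-normalised reweighted estimate of E_target[outcome].

    strata:       stratum label of each source sample (e.g. country)
    outcomes:     observed outcome of each source sample
    target_props: stratum -> proportion in the target population
    """
    n = len(strata)
    source_props = {s: c / n for s, c in Counter(strata).items()}
    # Reweighting cannot conjure up strata that never appear in the source data.
    unseen = [s for s, p in target_props.items() if p > 0 and s not in source_props]
    if unseen:
        raise ValueError(f"target strata absent from source data: {unseen}")
    weights = [target_props.get(s, 0.0) / source_props[s] for s in strata]
    return sum(w * y for w, y in zip(weights, outcomes)) / sum(weights)


# The source sample is 75% Australian; the target population is 50/50.
strata = ["AU", "AU", "AU", "IN"]
outcomes = [1.0, 1.0, 1.0, 0.0]
print(sum(outcomes) / len(outcomes))                              # naive: 0.75
print(reweighted_mean(strata, outcomes, {"AU": 0.5, "IN": 0.5}))  # about 0.5
```

The same idea generalises to continuous covariates via importance weights \(p_\text{target}(x) / p_\text{source}(x)\), at the cost of having to estimate that density ratio.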
6 Model stacking
NumPyro worked example: Bayesian Hierarchical Stacking: Well Switching Case Study (Yao et al. 2022).
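Yao et al. (2022) let the stacking weights vary with the inputs through a hierarchical prior; the non-hierarchical core is simpler. A minimal sketch, assuming we already have held-out (e.g. leave-one-out) predictive densities for each observation under two candidate models, say one fitted to Australian customers and one to Indian customers: pick the mixture weight that maximises the log score of the mixture. Grid search stands in for a proper optimiser, and the names and numbers are hypothetical.

```python
import math


def stacking_weight(dens_a, dens_b, grid_size=1001):
    """Choose the weight w on model A that maximises the stacked log score
    sum_i log(w * p_A(y_i) + (1 - w) * p_B(y_i)) by grid search over [0, 1].

    dens_a, dens_b: held-out predictive densities of each observation under
    models A and B; assumed strictly positive.
    """
    best_w, best_score = 0.0, -math.inf
    for k in range(grid_size):
        w = k / (grid_size - 1)
        score = sum(math.log(w * a + (1 - w) * b) for a, b in zip(dens_a, dens_b))
        if score > best_score:
            best_w, best_score = w, score
    return best_w


# Each model predicts one observation well and the other badly: mix them evenly.
print(stacking_weight([0.9, 0.1], [0.1, 0.9]))  # 0.5
# Model A is better on every observation: all weight goes to A.
print(stacking_weight([0.5, 0.4], [0.1, 0.2]))  # 1.0
```

Unlike choosing the single best model, stacking can prefer a mixture when the models are good on different subsets of the data, which is exactly the situation under distribution shift between subpopulations.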
7 Bilevel / adversarial
OK, all that graphical model stuff failed to scale to my problem of interest; what next? As noted in Yuchen Zhang et al. (2019), many domain adaptation strategies can be framed as bilevel optimisation problems of minimax type, so that presumably corresponds to domain adversarial learning. I think that invariant risk minimisation can probably be put in this minimax framework too, but “learning invariants” is somehow conceptually separate.
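For concreteness, one well-known instance of this minimax type is the domain-adversarial (DANN) objective of Ganin and Lempitsky. A feature extractor with parameters \(\theta_f\), a label predictor \(\theta_y\) and a domain classifier \(\theta_d\) share the objective below, where \(L_y\) is the label loss on the \(n\) labelled source points, \(L_d\) is the domain-classification loss with domain label 0 for source points and 1 for the \(n'\) unlabelled target points, and \(\lambda > 0\) trades the two off:

```latex
E(\theta_f, \theta_y, \theta_d)
  = \frac{1}{n} \sum_{i=1}^{n} L_y(\theta_f, \theta_y; x_i, y_i)
  - \lambda \Big( \frac{1}{n} \sum_{i=1}^{n} L_d(\theta_f, \theta_d; x_i, 0)
  + \frac{1}{n'} \sum_{j=1}^{n'} L_d(\theta_f, \theta_d; x'_j, 1) \Big),
\qquad
(\hat\theta_f, \hat\theta_y) = \arg\min_{\theta_f, \theta_y} E(\theta_f, \theta_y, \hat\theta_d),
\quad
\hat\theta_d = \arg\max_{\theta_d} E(\hat\theta_f, \hat\theta_y, \theta_d).
```

The features are pushed to make source and target indistinguishable to the best domain classifier while staying predictive of the source labels; in practice the saddle point is usually sought with a gradient reversal layer and ordinary SGD rather than by solving the inner problem exactly.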
Update: Yes, Ahuja et al. (2020) are helpful in giving us some semblance of taxonomy:
The standard risk minimization paradigm of machine learning is brittle when operating in environments whose test distributions are different from the training distribution due to spurious correlations. Training on data from many environments and finding invariant predictors reduces the effect of spurious features by concentrating models on features that have a causal relationship with the outcome. In this work, we pose such invariant risk minimization as finding the Nash equilibrium of an ensemble game among several environments. By doing so, we develop a simple training algorithm that uses best response dynamics and, in our experiments, yields similar or better empirical accuracy with much lower variance than the challenging bilevel optimization problem of Arjovsky et al. (2020). One key theoretical contribution is showing that the set of Nash equilibria for the proposed game are equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear classifiers and transformations. As a result, our method also retains the generalization guarantees to a large set of environments shown in Arjovsky et al. (2020). The proposed algorithm adds to the collection of successful game-theoretic machine learning algorithms such as generative adversarial networks.
I’m a little confused that people describe the method of Arjovsky et al. (2020) as bilevel optimisation; the paper discusses a bilevel optimisation, but the authors go on to implement an approximation which looks like a basic single-level regularized optimisation. I am missing something, either in the original paper or in the critiques.
I will inspect IBM/OoD: Repository for theory and methods for Out-of-Distribution (OoD) generalization.
8 Semi-supervised learning
9 Source and target empirical risks
What does this heading even mean? I had some idea, but I have forgotten it, I confess (Ben-David et al. 2006; Ben-David et al. 2010; Blitzer et al. 2007; Mansour, Mohri, and Rostamizadeh 2009).
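One plausible reading of the heading, offered as a guess rather than a recovered memory: the cited papers bound a hypothesis's risk on the target domain by its risk on the source domain plus a measure of how different the two domains look to the hypothesis class. In the form given by Ben-David et al. (2010), for a hypothesis \(h\) in class \(\mathcal{H}\) with source and target risks \(\epsilon_S(h)\) and \(\epsilon_T(h)\), and domain distributions \(\mathcal{D}_S\), \(\mathcal{D}_T\):

```latex
\epsilon_T(h) \le \epsilon_S(h)
  + \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T)
  + \lambda^*,
\qquad
\lambda^* = \min_{h' \in \mathcal{H}} \big[ \epsilon_S(h') + \epsilon_T(h') \big].
```

In the empirical version, \(\epsilon_S\) is replaced by the empirical source risk and the divergence by an estimate from unlabelled samples of each domain, at the cost of finite-sample complexity terms (omitted here). The bound is only informative when some single hypothesis does well on both domains, i.e. when \(\lambda^*\) is small.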
10 Learning invariants
I am not sure if the various sub-methods in this category are in fact distinct. H. Zhao et al. (2019) derive necessary conditions for invariant representation learning to work. Possibly this is a special case or particular framing of what I called “bilevel” optimisation above.
10.1 Regularising features towards invariance
DAN (Long et al. 2015)
Recent studies reveal that a deep neural network can learn transferable features which generalize well to novel tasks for domain adaptation. However, as deep features eventually transition from general to specific along the network, the feature transferability drops significantly in higher layers with increasing domain discrepancy. Hence, it is important to formally reduce the dataset bias and enhance the transferability in task-specific layers. In this paper, we propose a new Deep Adaptation Network (DAN) architecture, which generalizes deep convolutional neural network to the domain adaptation scenario. In DAN, hidden representations of all task-specific layers are embedded in a reproducing kernel Hilbert space where the mean embeddings of different domain distributions can be explicitly matched. The domain discrepancy is further reduced using an optimal multi-kernel selection method for mean embedding matching. DAN can learn transferable features with statistical guarantees, and can scale linearly by unbiased estimate of kernel embedding. Extensive empirical evidence shows that the proposed architecture yields state-of-the-art image classification error rates on standard domain adaptation benchmarks.
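The quantity DAN matches is the maximum mean discrepancy (MMD) between source and target feature distributions. A minimal sketch of the linear-time unbiased MMD² estimator on plain feature vectors, assuming a fixed, equally weighted mixture of Gaussian kernels; DAN itself applies this to the activations of several task-specific layers and optimises the kernel weights, which this sketch does not. Function names, bandwidths and data are hypothetical.

```python
import math
import random


def multi_gaussian_kernel(a, b, bandwidths=(0.5, 1.0, 2.0)):
    """Equally weighted mixture of Gaussian RBF kernels on feature tuples."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return sum(math.exp(-sq_dist / (2.0 * h ** 2)) for h in bandwidths) / len(bandwidths)


def linear_time_mmd2(xs, ys, kernel=multi_gaussian_kernel):
    """Unbiased linear-time estimate of MMD^2 between samples xs and ys.

    Consecutive pairs (x, x') from xs and (y, y') from ys each contribute
    h = k(x, x') + k(y, y') - k(x, y') - k(x', y); we average h over pairs.
    """
    m = min(len(xs), len(ys)) // 2
    if m == 0:
        raise ValueError("need at least two samples from each domain")
    total = 0.0
    for i in range(m):
        x, x2 = xs[2 * i], xs[2 * i + 1]
        y, y2 = ys[2 * i], ys[2 * i + 1]
        total += kernel(x, x2) + kernel(y, y2) - kernel(x, y2) - kernel(x2, y)
    return total / m


def sample(rng, shift, n=2000):
    """Toy 2-D 'features' drawn from a Gaussian centred at (shift, shift)."""
    return [(rng.gauss(shift, 1.0), rng.gauss(shift, 1.0)) for _ in range(n)]


rng = random.Random(0)
source, source_again, target = sample(rng, 0.0), sample(rng, 0.0), sample(rng, 1.5)
print(linear_time_mmd2(source, source_again))  # close to 0
print(linear_time_mmd2(source, target))        # clearly positive
```

Used as a training penalty, this statistic would be computed on network features for a source batch and a target batch and added to the source classification loss, so that minimising the total pulls the two feature distributions together.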
10.2 Invariant risk minimisation
A trick from Arjovsky et al. (2020). Emin Orhan summarises the method plus several negative results about IRM (Gulrajani and Lopez-Paz 2020; Rosenfeld, Ravikumar, and Risteski 2020):
Take invariant risk minimization (IRM), one of the more popular domain generalization methods proposed recently. IRM considers a classification problem that takes place in multiple domains or environments, \(e_1, e_2, \ldots, e_E\) (in an image classification setting, these could be natural images, drawings, paintings, computer-rendered images etc.). We decompose the learning problem into learning a feature backbone \(\Phi\) (a featurizer), and a linear readout \(\beta\) on top of it. Intuitively, in our classifier, we only want to make use of features that are invariant across different environments (for instance, the shapes of objects in our image classification example), and not features that vary from environment to environment (for example, the local textures of objects). This is because the invariant features are more likely to generalize to a new environment. We could, of course, do the old, boring empirical risk minimization (ERM), your grandmother’s dumb method. This would simply lump the training data from all environments into one single giant training set and minimize the loss on that, with the hope that whatever features are more or less invariant across the environments will automatically emerge out of this optimization. Mathematically, ERM in this setting corresponds to solving the following well-known optimization problem (assuming the same amount of training data from each domain): \(\min_{\Phi, \beta} \frac{1}{E} \sum_{e} \mathfrak{R}^e(\Phi, \beta)\), where \(\mathfrak{R}^e\) is the empirical risk in environment \(e\). IRM proposes something much more complicated instead: why don’t we learn a featurizer with the same optimal linear readout on top of it in every environment? The hope is that in this way, the extractor will only learn the invariant features, because the non-invariant features will change from environment to environment and can’t be decoded optimally using the same fixed readout.
The IRM objective thus involves a difficult bilevel optimization problem…
Does it though? The general IRM objective is difficult, but the paper also gives a simple approximation, IRMv1, which is claimed to be easier to optimise. Either way, the critiques of Gulrajani and Lopez-Paz (2020) and Rosenfeld, Ravikumar, and Risteski (2020) are useful.
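To make the contrast concrete, in the notation of the quote (featurizer \(\Phi\), readout \(\beta\), per-environment risk \(\mathfrak{R}^e\)), the general IRM problem of Arjovsky et al. (2020) is the constrained, bilevel problem on the first line. IRMv1 replaces the inner argmin with a gradient penalty on a fixed scalar “dummy” readout \(w = 1.0\), giving the single-level problem on the second line, with penalty weight \(\lambda\):

```latex
\min_{\Phi, \beta} \sum_{e} \mathfrak{R}^e(\beta \circ \Phi)
  \quad \text{subject to} \quad
  \beta \in \arg\min_{\bar\beta} \mathfrak{R}^e(\bar\beta \circ \Phi) \ \text{for all } e,
\\[1ex]
\min_{\Phi} \sum_{e} \Big[ \mathfrak{R}^e(\Phi)
  + \lambda \, \big\| \nabla_{w \mid w = 1.0} \, \mathfrak{R}^e(w \cdot \Phi) \big\|^2 \Big].
```

The penalty vanishes exactly when the fixed readout is already stationary for the risk in every environment, which is a first-order stand-in for the inner argmin constraint.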
 facebookresearch/InvariantRiskMinimization: PyTorch code to run synthetic experiments. (Arjovsky et al. 2020)
 reiinakano/invariant-risk-minimization: Implementation of Invariant Risk Minimization
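A minimal sketch of the IRMv1 penalty for squared-error loss, written without autodiff so it stays self-contained. Assumptions: the model's predictions \(f(x)\) play the role of \(\Phi\), the dummy readout is a scalar \(w\), and the risk is the mean squared error, so the derivative of \(\mathfrak{R}^e(w \cdot f)\) at \(w = 1\) is \(\operatorname{mean}(2 (f(x) - y) f(x))\). The function names and toy environments are hypothetical; the facebookresearch code linked above computes the equivalent gradient with PyTorch autograd instead.

```python
def mse(preds, ys):
    """Mean squared error of predictions against targets."""
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)


def irmv1_penalty(preds, ys):
    """Squared derivative of mean((w * f - y)^2) with respect to the scalar
    dummy readout w, evaluated at w = 1.0; the derivative is mean(2 * (f - y) * f)."""
    grad = sum(2.0 * (p - y) * p for p, y in zip(preds, ys)) / len(ys)
    return grad ** 2


def irmv1_objective(envs, predict, lam):
    """Sum over environments of risk + lam * IRMv1 penalty.
    envs: list of (xs, ys) pairs, one per environment; predict: x -> prediction."""
    total = 0.0
    for xs, ys in envs:
        preds = [predict(x) for x in xs]
        total += mse(preds, ys) + lam * irmv1_penalty(preds, ys)
    return total


# Two hypothetical environments in which y depends on x with different slopes.
envs = [([1.0, 2.0], [1.0, 2.0]),  # slope 1
        ([1.0, 2.0], [3.0, 6.0])]  # slope 3


def pooled(x):
    """Pooled least-squares fit through the origin: fits neither environment exactly."""
    return 2.0 * x


print(irmv1_penalty([2.0, 4.0], [1.0, 2.0]))   # 100.0
print(irmv1_objective(envs, pooled, lam=1.0))  # 2.5 + 100 + 2.5 + 100 = 205.0
```

The pooled (ERM-style) predictor is heavily penalised because rescaling its output would help in each environment, in opposite directions; a predictor whose output is already optimally scaled in every environment pays no penalty.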
Interesting variants:
11 Conformal
Conformal learning + distributional shift.
12 FiLM
13 Justification for batch normalization
Apparently a thing? Should probably note some of the literature about that.
14 Tools
14.1 Transfer-Learning-Library
TLlib (Jiang et al. 2022) is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms, or readily apply existing algorithms.
Our API is divided by methods, which include:
 domain alignment methods (tllib.alignment)
 domain translation methods (tllib.translation)
 self-training methods (tllib.self_training)
 regularization methods (tllib.regularization)
 data reweighting/resampling methods (tllib.reweight)
 model ranking/selection methods (tllib.ranking)
 normalization-based methods (tllib.normalization)
14.2 DomainBed
facebookresearch/DomainBed: DomainBed is a suite to test domain generalization algorithms
DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in Gulrajani and Lopez-Paz (2020).
14.3 Salad
salad is a library to easily set up experiments using the current state-of-the-art techniques in domain adaptation. It features several recent approaches, with the goal of being able to run fair comparisons between algorithms and transfer them to real-world use cases.
14.4 transferlearning.xyz
Jindong Wang’s site is a good resource, and also includes a library of popular techniques.
 Transfer Learning: Transfer learning / domain adaptation / domain generalization / multi-task learning etc. Papers, codes, datasets, applications, tutorials (迁移学习, i.e. transfer learning).
 jindongwang/transferlearning: Transfer learning / domain adaptation / domain generalization / multi-task learning etc. Papers, codes, datasets, applications, tutorials (迁移学习, i.e. transfer learning).
14.5 WILDS
WILDS: A Benchmark of in-the-Wild Distribution Shifts
To facilitate the development of ML models that are robust to real-world distribution shifts, our ICML 2021 paper presents WILDS, a curated benchmark of 10 datasets that reflect natural distribution shifts arising from different cameras, hospitals, molecular scaffolds, experiments, demographics, countries, time periods, users, and codebases.