Learning under distribution shift

Also transfer learning, covariate shift, transferable learning, domain adaptation, transportability, etc.

October 17, 2020 — October 24, 2023

graphical models
how do science
machine learning
Figure 1

Predictive models can be trained on independent and identically distributed data without much fuss. Sometimes, though, our data is not identically distributed but is drawn from several different distributions. Say I am training a model that predicts customer behaviour, and I have customers in Australia and customers in India. Can I nonetheless train a model that works well on all of the data?

If we are using a parametric hierarchical model, we can pool data across groups in the usual way and learn interaction effects.
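A minimal sketch of what that pooling does, under loudly simplifying assumptions: a normal-normal hierarchical model with known per-group standard errors, where each group mean (say, Australian vs Indian customers) is shrunk towards a precision-weighted grand mean in proportion to how noisy it is. The function name and the crude method-of-moments estimate of the between-group variance are my own illustrative choices; a real analysis would fit the full hierarchical model, e.g. with HMC.

```python
import numpy as np

def partial_pool(group_means, group_ses):
    """Shrink noisy per-group means towards a precision-weighted grand mean.

    Assumed model: y_g ~ N(theta_g, se_g^2), theta_g ~ N(mu, tau^2), se_g known and > 0.
    tau^2 is estimated by a crude method of moments (a simplification, not full Bayes).
    """
    y = np.asarray(group_means, dtype=float)
    se2 = np.asarray(group_ses, dtype=float) ** 2
    # Between-group variance: observed spread minus average sampling noise, floored at 0.
    tau2 = max(np.var(y, ddof=1) - se2.mean(), 0.0)
    # Precision-weighted grand mean under the marginal model y_g ~ N(mu, se_g^2 + tau^2).
    w = 1.0 / (se2 + tau2)
    mu = np.sum(w * y) / np.sum(w)
    if tau2 == 0.0:
        return np.full_like(y, mu)  # no detectable between-group variation: complete pooling
    # Posterior mean of theta_g: weight se2 / (se2 + tau2) on mu, the rest on the group's own mean.
    shrink = se2 / (se2 + tau2)
    return shrink * mu + (1.0 - shrink) * y
```

With two groups whose raw means are 10 and 2 (standard error 1 each), the estimates become 9.875 and 2.125: the groups are clearly different, so there is little shrinkage. With means 1.0 and 1.2 the observed spread is smaller than the noise, and both collapse to the pooled value 1.1. With only two groups the variance estimate is very rough, which is one reason to prefer a proper Bayesian fit.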

If we are doing Neural Network Stuff, though, it is not really clear how to do that. We might be vexed, then surprised, and then write an article about it. If we are a typical researcher, that article might be blind to prior art in statistics; see, e.g., Google AI Blog: How Underspecification Presents Challenges for Machine Learning, or Sebastian Ruder’s NN-style introduction to “transfer learning”.

I hope I don’t sound (too) snarky; there can be virtue in reinventing things with fresh eyes. Transfer learning, domain adaptation and such are all concepts that arise in the NN framing; sometimes the methods overlap with statistical classics and sometimes they extend the repertoire.

Here I will investigate as many of them as I have time for.

1 What is transfer learning or domain adaptation actually?

Everyone I talk to seems to have a different notion, and also to think that their idea is canonical.

We need a taxonomy. How about this one? In thuml/A-Roadmap-for-Transfer-Learning, Junguang Jiang, Yang Shu, Jianmin Wang and Mingsheng Long propose the following taxonomy of transfer methods (Jiang et al. 2022):

They handball to zhaoxin94/awesome-domain-adaptation for a finer domain adaptation taxonomy.

One survey paper not enough? Want a better taxonomy? Here are survey papers harvested from the above links:

(Csurka 2017; Gulrajani and Lopez-Paz 2020; Jiang et al. 2022; Kouw and Loog 2019; Ouali, Hudelot, and Tami 2020; Pan and Yang 2010; Patel et al. 2015; Sun, Shi, and Wu 2015; Tan et al. 2018; M. Wang and Deng 2018; Wilson and Cook 2020; Yuchen Zhang et al. 2019; J. Zhang et al. 2019; L. Zhang and Gao 2020; S. Zhao et al. 2020; Zhuang et al. 2020).

Schölkopf et al. (2012) and Schölkopf (2022) argue that transfer learning also connects to semi-supervised learning and fairness.

2 Generic theory

(Bareinboim and Pearl 2016, 2013, 2014; Ben-David et al. 2010; Kaddour et al. 2022; Kulinski and Inouye 2022; Mansour, Mohri, and Rostamizadeh 2009; Pearl and Bareinboim 2014; Schölkopf et al. 2012, 2021; Schölkopf 2022; Subbaswamy, Schulam, and Saria 2019; Zellinger, Moser, and Saminger-Platz 2021; Yuchen Zhang et al. 2019; J. Wang and Chen 2023; Yang et al. 2020)

3 Graphical models

To my mind the most straightforward approach is simply to do causal inference in a hierarchical model that encodes all the causal constraints. All the tools of graphical modeling are still well-posed, and the approach is easy to explain in a Bayesian framework in particular. I think this is what Elias Bareinboim’s data fusion framing refers to (Bareinboim and Pearl 2016, 2013, 2014; Pearl and Bareinboim 2014). In this case we can use standard statistical tooling, such as HMC, to sample from some posterior under various interventions, e.g. a shift in some parameter of the population distribution.
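In the simplest case this framework covers, the transport formula of Pearl and Bareinboim (2014) shows how a causal effect measured in a source population transfers to a target population (starred quantities), provided a set of covariates \(Z\) is S-admissible, i.e. it accounts for all the differences between the populations that matter for \(Y\):

\[
P^*(y \mid \mathrm{do}(x)) = \sum_{z} P(y \mid \mathrm{do}(x), z)\, P^*(z)
\]

The first factor comes from the source experiment and the second only needs observational data on \(Z\) in the target population; in the hierarchical Bayesian version sketched above, both would be posterior quantities rather than plug-in estimates.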

The hairy part is that this breaks down in neural networks. There is a million-dimensional nuisance parameter that we need to integrate out, i.e. the neural weights. For reasons of size alone that is frequently impractical, with the computation cost blowing out.

Some other works that look related: (Gong et al. 2018; Moraffah et al. 2019; Yue et al. 2021; Xu, Wang, and Ni 2022; Rothenhäusler et al. 2020).

A graphical model approach has many things to recommend it if it works, though: we do not need to worry about missing values (they can be inferred too), we can estimate intervention distributions, etc.

4 Pre-training

The LLM approach. Out of scope for my current investigation, but very much in the news.

5 Sample weighting

If the proportions of the various subpopulations have changed, we can use Stratified sampling, or reweight existing samples stratum by stratum, to estimate the quantity of interest over the new population.
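A minimal sketch of post-stratification, assuming we know the stratum proportions in the new population (the strata, outcomes and proportions in the example are invented for illustration): estimate the quantity within each stratum from the old data, then average using the new population’s proportions. Equivalently, each old observation gets an importance weight \(\pi^{\text{new}}_s / \pi^{\text{old}}_s\); the function names are hypothetical.

```python
import numpy as np

def poststratified_mean(y, strata, target_props):
    """Estimate E[y] in a new population whose stratum proportions are target_props.

    y: outcomes observed in the old population.
    strata: stratum label of each observation.
    target_props: dict mapping stratum label -> proportion in the new population.
    """
    y = np.asarray(y, dtype=float)
    strata = np.asarray(strata)
    total = sum(target_props.values())
    est = 0.0
    for s, p in target_props.items():
        in_s = strata == s
        if not in_s.any():
            raise ValueError(f"no observations in stratum {s!r}")
        est += (p / total) * y[in_s].mean()  # new-population weight times within-stratum mean
    return est

def importance_weights(strata, target_props):
    """Per-observation weights pi_new(s) / pi_old(s), with pi_old the empirical proportions."""
    strata = np.asarray(strata)
    labels, counts = np.unique(strata, return_counts=True)
    old = dict(zip(labels, counts / counts.sum()))
    return np.array([target_props[s] / old[s] for s in strata])
```

For example, if the old sample is one third Australian (mean spend 11) and two thirds Indian (mean spend 3.5), the raw mean is 6, but a new population split 50/50 has estimated mean 7.25; `np.average(y, weights=importance_weights(...))` gives the same number.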

6 Model stacking

Numpyro worked example: Bayesian Hierarchical Stacking: Well Switching Case Study (Yao et al. 2022).
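The case study uses the hierarchical version, where stacking weights vary with the inputs. As a baseline, here is a minimal sketch of plain (non-hierarchical) stacking of predictive distributions: given each model’s held-out (e.g. leave-one-out) log predictive density at each data point, choose weights on the simplex that maximise the log score of the mixture. I solve this with the EM fixed-point update for mixture proportions; that choice of optimiser and the function name are mine, not necessarily what the NumPyro example does.

```python
import numpy as np

def stacking_weights(lpd, n_iter=500, tol=1e-10):
    """Stacking weights maximising sum_i log sum_k w_k p_k(y_i) over the simplex.

    lpd: array (n_points, n_models) of held-out log predictive densities log p_k(y_i).
    """
    lpd = np.asarray(lpd, dtype=float)
    n, k = lpd.shape
    # Rescale densities row by row for numerical stability; this adds a constant
    # to the objective and leaves the optimal weights unchanged.
    dens = np.exp(lpd - lpd.max(axis=1, keepdims=True))
    w = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        # E-step: responsibility of model k for point i under the current mixture.
        resp = dens * w
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: each weight becomes its average responsibility.
        w_new = resp.mean(axis=0)
        if np.max(np.abs(w_new - w)) < tol:
            return w_new
        w = w_new
    return w
```

The objective is concave in the weights, so the fixed point is the global optimum. If one model predicts every point better, it gets essentially all the weight; if two models are each better on different halves of the data, stacking splits the weight between them, which is where the hierarchical version earns its keep by letting that split depend on the inputs.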

7 Bi-level / adversarial

OK, all that graphical model stuff failed to scale to my problem of interest; what next? As noted in Yuchen Zhang et al. (2019), many domain adaptation strategies can be framed as bi-level optimisation problems of minimax type, which presumably corresponds to Domain Adversarial Learning. I think that Invariant risk minimisation can probably be put in this minimax framework too, but “learning invariants” is somehow conceptually separate.
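The inner “max” player in domain-adversarial learning is a domain discriminator trying to tell source features from target features. A closely related diagnostic from Ben-David et al. (2006) is the proxy \(\mathcal{A}\)-distance, \(\hat d_{\mathcal{A}} = 2(1 - 2\epsilon)\), where \(\epsilon\) is the held-out error of such a discriminator: near 0 when the domains are indistinguishable and near 2 when they are perfectly separable. Below is a minimal numpy sketch with a logistic-regression discriminator trained by plain gradient descent; the function names, the 50/50 train/test split, the standardisation step and the optimiser settings are my assumptions.

```python
import numpy as np

def _fit_logistic(X, y, lr=0.5, n_steps=2000):
    """Logistic regression with intercept, fitted by full-batch gradient descent."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    w = np.zeros(Xb.shape[1])
    for _ in range(n_steps):
        p = 1.0 / (1.0 + np.exp(-(Xb @ w)))
        w -= lr * Xb.T @ (p - y) / len(y)  # gradient of the mean cross-entropy
    return w

def proxy_a_distance(feat_src, feat_tgt, seed=0):
    """Proxy A-distance 2 * (1 - 2 * err), err = held-out error of a linear domain classifier."""
    rng = np.random.default_rng(seed)
    X = np.vstack([feat_src, feat_tgt]).astype(float)
    X = (X - X.mean(axis=0)) / X.std(axis=0)  # standardise features to ease optimisation
    d = np.concatenate([np.zeros(len(feat_src)), np.ones(len(feat_tgt))])  # domain labels
    idx = rng.permutation(len(X))
    train, test = idx[: len(X) // 2], idx[len(X) // 2 :]
    w = _fit_logistic(X[train], d[train])
    Xt = np.hstack([X[test], np.ones((len(test), 1))])
    err = np.mean((Xt @ w > 0) != d[test])
    return 2.0 * (1.0 - 2.0 * err)
```

In adversarial training, the feature extractor is updated to push this kind of discriminator towards chance (distance near 0) while the discriminator itself is updated to separate the domains; that tug-of-war is the minimax structure.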

Update: Yes, Ahuja et al. (2020) are helpful in giving us some semblance of taxonomy:

The standard risk minimization paradigm of machine learning is brittle when operating in environments whose test distributions are different from the training distribution due to spurious correlations. Training on data from many environments and finding invariant predictors reduces the effect of spurious features by concentrating models on features that have a causal relationship with the outcome. In this work, we pose such invariant risk minimization as finding the Nash equilibrium of an ensemble game among several environments. By doing so, we develop a simple training algorithm that uses best response dynamics and, in our experiments, yields similar or better empirical accuracy with much lower variance than the challenging bi-level optimization problem of Arjovsky et al. (2020). One key theoretical contribution is showing that the set of Nash equilibria for the proposed game are equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear classifiers and transformations. As a result, our method also retains the generalization guarantees to a large set of environments shown in Arjovsky et al. (2020). The proposed algorithm adds to the collection of successful game-theoretic machine learning algorithms such as generative adversarial networks.

I’m a little confused that people seem to describe the method of Arjovsky et al. (2020) as bi-level optimisation; the paper discusses a bi-level optimization, but then implements an approximation (IRMv1) which seems to be a basic single-level regularized optimization. I must be missing something, either in the original paper or in the critiques.

I will inspect IBM/OoD: Repository for theory and methods for Out-of-Distribution (OoD) generalization.

8 Semi-supervised learning

See Semi-Supervised Learning.

9 Source and target empirical risks

What does this heading even mean? I had some idea, but I have forgotten it, I confess (Ben-David et al. 2006; Ben-David et al. 2010; Blitzer et al. 2007; Mansour, Mohri, and Rostamizadeh 2009).
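For what it is worth, the best-known result relating the two is the bound of Ben-David et al. (2010): for any hypothesis \(h\) in a class \(\mathcal{H}\), the target risk is controlled by the source risk, a divergence between the source and target input distributions, and the risk of the best joint hypothesis:

\[
\epsilon_T(h) \le \epsilon_S(h) + \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda,
\qquad
\lambda = \min_{h' \in \mathcal{H}} \big[\epsilon_S(h') + \epsilon_T(h')\big].
\]

The finite-sample version in that paper replaces the population quantities with empirical risks and an empirical divergence, plus complexity terms, which is probably where the “empirical risks” in the heading come from.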

10 Learning invariants

I am not sure if the various sub-methods in this category are in fact distinct. H. Zhao et al. (2019) derive necessary conditions for invariant representation learning to work. Possibly this is a special case, or a particular framing, of what I called “bi-level” optimisation above.

10.1 Regularising features towards invariance

DAN (Long et al. 2015)

Recent studies reveal that a deep neural network can learn transferable features which generalize well to novel tasks for domain adaptation. However, as deep features eventually transition from general to specific along the network, the feature transferability drops significantly in higher layers with increasing domain discrepancy. Hence, it is important to formally reduce the dataset bias and enhance the transferability in task-specific layers. In this paper, we propose a new Deep Adaptation Network (DAN) architecture, which generalizes deep convolutional neural network to the domain adaptation scenario. In DAN, hidden representations of all task-specific layers are embedded in a reproducing kernel Hilbert space where the mean embeddings of different domain distributions can be explicitly matched. The domain discrepancy is further reduced using an optimal multi-kernel selection method for mean embedding matching. DAN can learn transferable features with statistical guarantees, and can scale linearly by unbiased estimate of kernel embedding. Extensive empirical evidence shows that the proposed architecture yields state-of-the-art image classification error rates on standard domain adaptation benchmarks.
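The quantity being matched in DAN is the maximum mean discrepancy (MMD) between source and target features in an RKHS. A minimal numpy sketch of the unbiased squared-MMD estimate with a kernel that is an equally weighted sum of Gaussian kernels; DAN itself learns the kernel weights and uses a linear-time estimator, so the fixed bandwidths, equal weights and quadratic-time estimator here are my simplifying assumptions. In a network, this term would be added to the classification loss at each adapted layer.

```python
import numpy as np

def _sq_dists(A, B):
    """Pairwise squared Euclidean distances between rows of A and rows of B."""
    return np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2.0 * A @ B.T

def multi_kernel(A, B, bandwidths=(0.5, 1.0, 2.0, 4.0)):
    """Equal-weight sum of Gaussian kernels exp(-||a - b||^2 / (2 s^2))."""
    d2 = np.maximum(_sq_dists(A, B), 0.0)  # clip tiny negative values from rounding
    return sum(np.exp(-d2 / (2.0 * s**2)) for s in bandwidths) / len(bandwidths)

def mmd2_unbiased(Xs, Xt, bandwidths=(0.5, 1.0, 2.0, 4.0)):
    """Unbiased estimate of squared MMD between samples Xs (m, d) and Xt (n, d)."""
    m, n = len(Xs), len(Xt)
    Kss = multi_kernel(Xs, Xs, bandwidths)
    Ktt = multi_kernel(Xt, Xt, bandwidths)
    Kst = multi_kernel(Xs, Xt, bandwidths)
    # Exclude the diagonal k(x_i, x_i) terms from the within-sample averages.
    term_ss = (Kss.sum() - np.trace(Kss)) / (m * (m - 1))
    term_tt = (Ktt.sum() - np.trace(Ktt)) / (n * (n - 1))
    return term_ss + term_tt - 2.0 * Kst.mean()
```

Because the estimate is unbiased it can come out slightly negative when the two samples share a distribution; when they do not, it is clearly positive, and minimising it pulls the feature distributions together.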

10.2 Invariant risk minimisation

A trick from Arjovsky et al. (2020). Emin Orhan summarises the method plus several negative results about IRM (Gulrajani and Lopez-Paz 2020; Rosenfeld, Ravikumar, and Risteski 2020):

Take invariant risk minimization (IRM), one of the more popular domain generalization methods proposed recently. IRM considers a classification problem that takes place in multiple domains or environments, \(e_1, e_2, \ldots, e_E\) (in an image classification setting, these could be natural images, drawings, paintings, computer-rendered images etc.). We decompose the learning problem into learning a feature backbone \(\Phi\) (a featurizer), and a linear readout \(\beta\) on top of it. Intuitively, in our classifier, we only want to make use of features that are invariant across different environments (for instance, the shapes of objects in our image classification example), and not features that vary from environment to environment (for example, the local textures of objects). This is because the invariant features are more likely to generalize to a new environment. We could, of course, do the old, boring empirical risk minimization (ERM), your grandmother’s dumb method. This would simply lump the training data from all environments into one single giant training set and minimize the loss on that, with the hope that whatever features are more or less invariant across the environments will automatically emerge out of this optimization. Mathematically, ERM in this setting corresponds to solving the following well-known optimization problem (assuming the same amount of training data from each domain): \(\min_{\Phi, \beta} \frac{1}{E} \sum_{i=1}^{E} \mathfrak{R}^{e_i}(\Phi, \beta)\), where \(\mathfrak{R}^{e_i}\) is the empirical risk in environment \(e_i\). IRM proposes something much more complicated instead: why don’t we learn a featurizer with the same optimal linear readout on top of it in every environment? The hope is that in this way, the extractor will only learn the invariant features, because the non-invariant features will change from environment to environment and can’t be decoded optimally using the same fixed readout.
The IRM objective thus involves a difficult bi-level optimization problem…

Does it, though? The general IRM objective is difficult, but the paper also gives a simpler approximation, IRMv1, which is claimed to be easier to optimise. Either way, the critiques of (Gulrajani and Lopez-Paz 2020; Rosenfeld, Ravikumar, and Risteski 2020) are useful.
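For concreteness, IRMv1 in Arjovsky et al. (2020) replaces the bi-level constraint with a gradient penalty evaluated at a fixed scalar “dummy” readout \(w = 1\):

\[
\min_{\Phi} \sum_{e} \Big[ \mathfrak{R}^e(\Phi) + \lambda \, \big\| \nabla_{w \mid w = 1.0}\, \mathfrak{R}^e(w \cdot \Phi) \big\|^2 \Big].
\]

The penalty vanishes in environment \(e\) when rescaling the predictor cannot lower that environment’s risk, i.e. when the shared readout is already optimal there. Here is a minimal numpy sketch for squared-error loss, where that gradient has a closed form; the function names and toy numbers are hypothetical, and a real implementation would compute the gradient by autodiff inside the training loop over \(\Phi\). This is a single-level penalised objective, which is the point of my grumble above.

```python
import numpy as np

def irmv1_penalty_sq_loss(pred, y):
    """Squared gradient of R(w) = mean((w * pred - y)^2) with respect to w, at w = 1."""
    pred, y = np.asarray(pred, float), np.asarray(y, float)
    grad = np.mean(2.0 * (pred - y) * pred)  # dR/dw evaluated at w = 1
    return grad**2

def irmv1_objective(preds_by_env, ys_by_env, lam=1.0):
    """Sum over environments of empirical risk plus lam times the IRMv1 penalty."""
    total = 0.0
    for pred, y in zip(preds_by_env, ys_by_env):
        pred, y = np.asarray(pred, float), np.asarray(y, float)
        total += np.mean((pred - y) ** 2) + lam * irmv1_penalty_sq_loss(pred, y)
    return total
```

A predictor that is exactly right in one environment but off by a factor of two in another pays both the extra risk and the penalty in the second environment, since a different scale would fit it better there.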

Interesting variants:

(Ahuja et al. 2022, 2020; Shah et al. 2021)

11 Conformal

Conformal prediction under distributional shift (Tibshirani et al. 2019).
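Tibshirani et al. (2019) handle covariate shift by weighting the calibration scores with the likelihood ratio \(w(x) = \mathrm{d}P_{\text{test}}(x) / \mathrm{d}P_{\text{train}}(x)\) and putting the test point’s weight on \(+\infty\). A minimal sketch of the resulting weighted split-conformal quantile, assuming the likelihood ratios are known (in practice they must be estimated); the function name is hypothetical. With all weights equal it reduces to ordinary split conformal prediction.

```python
import numpy as np

def weighted_conformal_quantile(scores, cal_weights, test_weight, alpha=0.1):
    """Level-(1 - alpha) quantile of sum_i p_i delta(V_i) + p_test delta(+inf).

    scores: nonconformity scores V_i on the calibration set, e.g. |y_i - mu(x_i)|.
    cal_weights: likelihood ratios w(x_i) at the calibration points.
    test_weight: likelihood ratio w(x) at the test point.
    """
    scores = np.asarray(scores, dtype=float)
    cal_weights = np.asarray(cal_weights, dtype=float)
    order = np.argsort(scores)
    total = cal_weights.sum() + test_weight
    cum = np.cumsum(cal_weights[order]) / total  # probability mass at or below each sorted score
    k = np.searchsorted(cum, 1.0 - alpha, side="left")  # first index with cum >= 1 - alpha
    if k >= len(scores):
        return np.inf  # the mass at +inf is needed, so the interval is unbounded
    return scores[order][k]

# The prediction interval at test point x is then mu(x) -/+ the returned quantile.
```

With nine calibration scores 1, …, 9 and equal weights, \(\alpha = 0.25\) gives the 8th smallest score, matching the usual \(\lceil (1-\alpha)(n+1) \rceil\) rule; a test point with a large likelihood ratio (it looks unlike the calibration data) gets an unbounded interval, which is the honest answer.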

Figure 2: This Maori gentleman (name unspecified) from the 1800s demonstrates an artful transfer learning from the western fashion domain. Or maybe that is style transfer, I forget.

12 FiLM
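FiLM layers (Perez et al. 2017; see also Dumoulin et al. 2018) condition a network on side information by a feature-wise affine transformation, \(\mathrm{FiLM}(x) = \gamma(c) \odot x + \beta(c)\), where \(\gamma\) and \(\beta\) are functions of a conditioning input \(c\); under distribution shift, \(c\) could plausibly be a domain or environment descriptor. A minimal numpy sketch with \(\gamma\) and \(\beta\) as linear maps of \(c\); the class name, shapes and near-identity initialisation are my assumptions, not the paper’s exact architecture.

```python
import numpy as np

class FiLM:
    """Feature-wise linear modulation: gamma(c) * x + beta(c), with gamma, beta linear in c."""

    def __init__(self, cond_dim, n_features, seed=0):
        rng = np.random.default_rng(seed)
        # Small random weights; biases set so the layer starts close to the identity map.
        self.W_gamma = 0.01 * rng.standard_normal((cond_dim, n_features))
        self.b_gamma = np.ones(n_features)
        self.W_beta = 0.01 * rng.standard_normal((cond_dim, n_features))
        self.b_beta = np.zeros(n_features)

    def __call__(self, x, c):
        """x: (batch, n_features) activations; c: (batch, cond_dim) conditioning inputs."""
        gamma = c @ self.W_gamma + self.b_gamma  # per-example, per-feature scale
        beta = c @ self.W_beta + self.b_beta  # per-example, per-feature shift
        return gamma * x + beta
```

With a zero conditioning vector the layer is exactly the identity, so a shared backbone can be trained first and the modulation learned afterwards; in a real model the weights would of course be trained along with everything else.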

13 Justification for batch normalization

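Apparently this is a thing? I should probably note some of the literature on it, starting with the original motivation of reducing “internal covariate shift” (Ioffe and Szegedy 2015).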

14 Tools

14.1 Transfer-Learning-Library

TLlib (Jiang et al. 2022) is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms, or readily apply existing algorithms.

Our API is divided by methods, which include:

  • domain alignment methods (tllib.aligment)
  • domain translation methods (tllib.translation)
  • self-training methods (tllib.self_training)
  • regularization methods (tllib.regularization)
  • data reweighting/resampling methods (tllib.reweight)
  • model ranking/selection methods (tllib.ranking)
  • normalization-based methods (tllib.normalization)

14.2 DomainBed

facebookresearch/DomainBed: DomainBed is a suite to test domain generalization algorithms

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in Gulrajani and Lopez-Paz (2020).

14.3 Salad

salad is a library to easily set up experiments using the current state-of-the-art techniques in domain adaptation. It features several recent approaches, with the goal of being able to run fair comparisons between algorithms and transfer them to real-world use cases.

14.4 transferlearning.xyz

Jindong Wang’s site is a good resource, and also includes a library of popular techniques.

14.5 WILDS

WILDS: A Benchmark of in-the-Wild Distribution Shifts

To facilitate the development of ML models that are robust to real-world distribution shifts, our ICML 2021 paper presents WILDS, a curated benchmark of 10 datasets that reflect natural distribution shifts arising from different cameras, hospitals, molecular scaffolds, experiments, demographics, countries, time periods, users, and codebases.

15 Incoming

16 References

Ahuja, Caballero, Zhang, et al. 2022. “Invariance Principle Meets Information Bottleneck for Out-of-Distribution Generalization.”
Ahuja, Shanmugam, Varshney, et al. 2020. “Invariant Risk Minimization Games.”
Arjovsky, Bottou, Gulrajani, et al. 2020. “Invariant Risk Minimization.”
Bareinboim, and Pearl. 2013. “A General Algorithm for Deciding Transportability of Experimental Results.” Journal of Causal Inference.
———. 2014. “Transportability from Multiple Environments with Limited Experiments: Completeness Results.” In Proceedings of the 27th International Conference on Neural Information Processing Systems - Volume 1. NIPS’14.
———. 2016. “Causal Inference and the Data-Fusion Problem.” Proceedings of the National Academy of Sciences.
Ben-David, Blitzer, Crammer, et al. 2006. “Analysis of Representations for Domain Adaptation.” In Advances in Neural Information Processing Systems.
Ben-David, Blitzer, Crammer, et al. 2010. “A Theory of Learning from Different Domains.” Machine Learning.
Besserve, Mehrjou, Sun, et al. 2019. “Counterfactuals Uncover the Modular Structure of Deep Generative Models.” In arXiv:1812.03253 [Cs, Stat].
Blitzer, Crammer, Kulesza, et al. 2007. “Learning Bounds for Domain Adaptation.” In Advances in Neural Information Processing Systems.
Chapelle, Schölkopf, and Zien, eds. 2010. Semi-Supervised Learning. Adaptive Computation and Machine Learning.
Chu, Jin, Zhu, et al. 2022. “DNA: Domain Generalization with Diversified Neural Averaging.” In Proceedings of the 39th International Conference on Machine Learning.
Csurka. 2017. “Domain Adaptation for Visual Applications: A Comprehensive Survey.”
Dumoulin, Perez, Schucher, et al. 2018. “Feature-Wise Transformations.” Distill.
Ganin, and Lempitsky. 2015. “Unsupervised Domain Adaptation by Backpropagation.” In Proceedings of the 32nd International Conference on Machine Learning.
Gong, Zhang, Huang, et al. 2018. “Causal Generative Domain Adaptation Networks.”
Gulrajani, and Lopez-Paz. 2020. “In Search of Lost Domain Generalization.”
Henzi, Shen, Law, et al. 2023. “Invariant Probabilistic Prediction.”
Ioffe, and Szegedy. 2015. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.”
Jiang, Shu, Wang, et al. 2022. “Transferability in Deep Learning: A Survey.”
Kaddour, Lynch, Liu, et al. 2022. “Causal Machine Learning: A Survey and Open Problems.”
Koh, Sagawa, Marklund, et al. 2021. “WILDS: A Benchmark of in-the-Wild Distribution Shifts.” arXiv:2012.07421 [Cs].
Kosoy, Chan, Liu, et al. 2022. “Towards Understanding How Machines Can Learn Causal Overhypotheses.”
Kouw, and Loog. 2019. “An Introduction to Domain Adaptation and Transfer Learning.”
Kulinski, and Inouye. 2022. “Towards Explaining Distribution Shifts.”
Kuroki, Charoenphakdee, Bao, et al. 2018. “Unsupervised Domain Adaptation Based on Source-Guided Discrepancy.”
Lagemann, Lagemann, Taschler, et al. 2023. “Deep Learning of Causal Structures in High Dimensions Under Data Limitations.” Nature Machine Intelligence.
Lattimore. 2017. “Learning How to Act: Making Good Decisions with Machine Learning.”
Li, Pan, Wang, et al. 2018. “Domain Generalization with Adversarial Feature Learning.” In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition.
Liu, Zhang, Gong, et al. 2022. “Identifying Latent Causal Content for Multi-Source Domain Adaptation.”
Long, Cao, Cao, et al. 2019. “Transferable Representation Learning with Deep Adaptation Networks.” IEEE Transactions on Pattern Analysis and Machine Intelligence.
Long, Cao, Wang, et al. 2015. “Learning Transferable Features with Deep Adaptation Networks.” In Proceedings of the 32nd International Conference on Machine Learning.
Long, Zhu, Wang, et al. 2017. “Deep Transfer Learning with Joint Adaptation Networks.” In Proceedings of the 34th International Conference on Machine Learning.
Mansour, Mohri, and Rostamizadeh. 2009. “Domain Adaptation: Learning Bounds and Algorithms.”
Moraffah, Shu, Raglin, et al. 2019. “Deep Causal Representation Learning for Unsupervised Domain Adaptation.”
Ouali, Hudelot, and Tami. 2020. “An Overview of Deep Semi-Supervised Learning.”
Pan, and Yang. 2010. “A Survey on Transfer Learning.” IEEE Transactions on Knowledge and Data Engineering.
Patel, Gopalan, Li, et al. 2015. “Visual Domain Adaptation: A Survey of Recent Advances.” IEEE Signal Processing Magazine.
Pearl, and Bareinboim. 2014. “External Validity: From Do-Calculus to Transportability Across Populations.” Statistical Science.
Perez, Strub, de Vries, et al. 2017. “FiLM: Visual Reasoning with a General Conditioning Layer.”
Peters, Bühlmann, and Meinshausen. 2016. “Causal Inference by Using Invariant Prediction: Identification and Confidence Intervals.” Journal of the Royal Statistical Society Series B: Statistical Methodology.
Quiñonero-Candela. 2009. Dataset Shift in Machine Learning.
Ramchandran, and Mukherjee. 2021. “On Ensembling Vs Merging: Least Squares and Random Forests Under Covariate Shift.” arXiv:2106.02589 [Math, Stat].
Rosenfeld, Ravikumar, and Risteski. 2020. “The Risks of Invariant Risk Minimization.”
Rothenhäusler, Meinshausen, Bühlmann, et al. 2020. “Anchor Regression: Heterogeneous Data Meets Causality.” arXiv:1801.06229 [Stat].
Schölkopf. 2022. “Causality for Machine Learning.” In Probabilistic and Causal Inference: The Works of Judea Pearl.
Schölkopf, Janzing, Peters, et al. 2012. “On Causal and Anticausal Learning.” In ICML 2012.
Schölkopf, Locatello, Bauer, et al. 2021. “Toward Causal Representation Learning.” Proceedings of the IEEE.
Shah, Ahuja, Shanmugam, et al. 2021. “Treatment Effect Estimation Using Invariant Risk Minimization.”
Simchoni, and Rosset. 2023. “Integrating Random Effects in Deep Neural Networks.”
Subbaswamy, Schulam, and Saria. 2019. “Preventing Failures Due to Dataset Shift: Learning Predictive Models That Transport.” In The 22nd International Conference on Artificial Intelligence and Statistics.
Sun, Shi, and Wu. 2015. “A Survey of Multi-Source Domain Adaptation.” Information Fusion.
Tan, Sun, Kong, et al. 2018. “A Survey on Deep Transfer Learning.”
Tibshirani, Foygel Barber, Candes, et al. 2019. “Conformal Prediction Under Covariate Shift.” In Advances in Neural Information Processing Systems.
Ventola, Braun, Yu, et al. 2023. “Probabilistic Circuits That Know What They Don’t Know.” arXiv.org.
Wang, Jindong, and Chen. 2023. Introduction to Transfer Learning: Algorithms and Practice. Machine Learning: Foundations, Methodologies, and Applications.
Wang, Mei, and Deng. 2018. “Deep Visual Domain Adaptation: A Survey.” Neurocomputing.
Wilson, and Cook. 2020. “A Survey of Unsupervised Deep Domain Adaptation.”
Xu, Wang, and Ni. 2022. “Graphical Modeling for Multi-Source Domain Adaptation.” IEEE Transactions on Pattern Analysis and Machine Intelligence.
Yang, Zhang, Dai, et al. 2020. Transfer Learning.
Yao, Pirš, Vehtari, et al. 2022. “Bayesian Hierarchical Stacking: Some Models Are (Somewhere) Useful.” Bayesian Analysis.
Yue, Sun, Hua, et al. 2021. “Transporting Causal Mechanisms for Unsupervised Domain Adaptation.”
Zellinger, Moser, and Saminger-Platz. 2021. “On Generalization in Moment-Based Domain Adaptation.” Annals of Mathematics and Artificial Intelligence.
Zhang, Yabin, Deng, Tang, et al. 2020. “Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice.” IEEE Transactions on Pattern Analysis and Machine Intelligence.
Zhang, Lei, and Gao. 2020. “Transfer Adaptation Learning: A Decade Survey.”
Zhang, Jing, Li, Ogunbona, et al. 2019. “Recent Advances in Transfer Learning for Cross-Dataset Visual Recognition: A Problem-Oriented Perspective.”
Zhang, Yuchen, Liu, Long, et al. 2019. “Bridging Theory and Algorithm for Domain Adaptation.” In Proceedings of the 36th International Conference on Machine Learning.
Zhao, Han, Combes, Zhang, et al. 2019. “On Learning Invariant Representation for Domain Adaptation.”
Zhao, Sicheng, Yue, Zhang, et al. 2020. “A Review of Single-Source Deep Unsupervised Visual Domain Adaptation.”
Zhuang, Qi, Duan, et al. 2020. “A Comprehensive Survey on Transfer Learning.”