Abstract
Learning invariant (causal) features for out-of-distribution (OOD) generalization has attracted extensive attention recently, and among the proposals, invariant risk minimization (IRM) is a notable solution. In spite of its theoretical promise for linear regression, challenges remain in using IRM for linear classification problems. By introducing the information bottleneck (IB) principle into the learning of IRM, the IB-IRM approach has demonstrated its power to solve these challenges. In this paper, we further improve IB-IRM in two aspects. First, we show that the key assumption of support overlap of invariant features used in IB-IRM is strong for guaranteeing OOD generalization, and that it is still possible to achieve the optimal solution without this assumption. Second, we illustrate two failure modes in which IB-IRM (and IRM) could fail to learn the invariant features, and to address such failures, we propose a Counterfactual Supervision-based Information Bottleneck (CSIB) learning algorithm that recovers the invariant features. By requiring counterfactual inference, CSIB works even when accessing data from a single environment. Empirical experiments on several datasets verify our theoretical results.
Keywords: out-of-distribution generalization, information bottleneck, causal learning
1. Introduction
Modern machine learning models are prone to catastrophic performance loss during deployment when the test distribution differs from the training distribution. This phenomenon has been repeatedly witnessed and intentionally exposed in many examples [1,2,3,4,5]. Among the explanations, shortcut learning [6] is considered a main factor causing this phenomenon. A good example is the classification of images of cows and camels: a trained convolutional network tends to recognize cows or camels by learning spurious features from image backgrounds (e.g., green pastures for cows and deserts for camels), rather than the causal shape features of the animals [7]; decisions based on such spurious features make the learned models fail when cows or camels appear in unusual or different environments. Machine learning models are therefore expected to generalize out of distribution (OOD) and avoid shortcut learning.
To achieve OOD generalization, recent theories [8,9,10,11,12], motivated by the causality literature [13,14], resort to extracting the invariant, causal features and establishing the conditions under which machine learning models have guaranteed generalization. Among these works, invariant risk minimization (IRM) [8] is a notable learning paradigm that puts the invariance principle [15] into practice. Despite this theoretical promise, the guarantees of IRM apply only to linear regression. For other problems, such as linear classification, Ahuja et al. [12] first show that OOD generalization is more difficult (see Theorem 1) and propose a new learning method, information bottleneck-based invariant risk minimization (IB-IRM), based on the support overlap assumption (Assumption 7). In this work, we closely investigate the conditions identified in [12] and propose improved results for OOD generalization of linear classification.
Our technical contributions are as follows. In [12], support overlap of invariant features is assumed in order to make OOD generalization of linear classification successful. In this work, we first show that this assumption is strong and that it is still possible to achieve this goal without it. Then, we examine whether the IB-IRM proposed in [12] is sufficient to learn invariant features for linear classification and find that IB-IRM (and IRM) could fail in two modes. We analyze these two failure modes of IB-IRM and IRM, in particular when the spurious features in training environments capture sufficient information for the task of interest but have less information than the invariant features. Based on these analyses, we propose a new method, termed counterfactual supervision-based information bottleneck (CSIB), to address such failures. We prove that, without needing the support overlap assumption, CSIB is theoretically guaranteed to succeed in OOD generalization of linear classification. Notably, CSIB works even when accessing data from a single environment. Finally, we design three synthetic datasets and a colored MNIST dataset based on our examples; experiments demonstrate the effectiveness of CSIB empirically.
The rest of this article is organized as follows. The learning problem of out-of-distribution (OOD) generalization is formulated in Section 2. In Section 3, we study the learnability of OOD generalization under different assumptions about the training and test environments. Using these assumptions, two failure modes of previous methods (IRM and IB-IRM) are analysed in Section 4. Based on this analysis, our method is proposed in Section 5. The experiments are reported in Section 6. Finally, we discuss related works in Section 7 and provide conclusions and limitations of our work in Section 8. All proofs and experimental details are given in Appendices A and B.
2. OOD Generalization: Background and Formulations
2.1. Background on Structural Equation Models
Before introducing our formulations of OOD generalization, we provide a detailed background on structural equation models (SEMs) [8,13].
Definition 1
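(Structural Equation Model (SEM)). A structural equation model (SEM) $\mathcal{C} := (\mathcal{S}, N)$ governing the random vector $X = (X_1, \ldots, X_d)$ is a set of structural equations:
$$\mathcal{S}_i : X_i \leftarrow f_i(\mathrm{Pa}(X_i), N_i),$$
where $\mathrm{Pa}(X_i) \subseteq \{X_1, \ldots, X_d\} \setminus \{X_i\}$ are called the parents of $X_i$, and the $N_i$ are independent noise random variables. For every SEM, we obtain a directed acyclic graph (DAG) $\mathcal{G}$ by adding one vertex for each $X_i$ and directed edges from each parent $X_j \in \mathrm{Pa}(X_i)$ (the causes) to the child $X_i$ (the effect).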
Definition 2
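(Intervention). Consider an SEM $\mathcal{C} = (\mathcal{S}, N)$. An intervention e on $\mathcal{C}$ consists of replacing one or several of its structural equations to obtain an intervened SEM $\mathcal{C}^e = (\mathcal{S}^e, N^e)$, with structural equations:
$$\mathcal{S}^e_i : X^e_i \leftarrow f^e_i(\mathrm{Pa}^e(X^e_i), N^e_i).$$
The variable $X^e_i$ is intervened if $\mathcal{S}_i \neq \mathcal{S}^e_i$ or $N_i \neq N^e_i$.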
In an SEM $\mathcal{C}$, we can draw samples from the observational distribution $P(X)$ according to the topological ordering of its DAG $\mathcal{G}$. We can also manipulate (intervene on) a unique SEM $\mathcal{C}$ in different ways, indexed by e, to obtain different but related SEMs $\mathcal{C}^e$, which result in different interventional distributions $P(X^e)$. Such a family of interventions is used to model the environments.
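A minimal sketch of ancestral sampling and intervention, assuming a hypothetical three-variable linear SEM $X_1 \to X_2 \to X_3$ with Gaussian noise (the coefficients and the helper names `sample_sem` and `intervene` are illustrative, not part of the formal setup). Sampling follows the topological order, and an intervention replaces one structural equation while leaving the others unchanged.

```python
import numpy as np

# Hypothetical linear SEM over X1 -> X2 -> X3. Each structural equation maps
# (values of already-sampled variables, own noise) to the variable's value.
# Dict insertion order is used as the topological order of the DAG.
BASE_SEM = {
    "X1": lambda v, n: n,
    "X2": lambda v, n: 2.0 * v["X1"] + n,
    "X3": lambda v, n: -1.0 * v["X2"] + 0.5 * n,
}


def sample_sem(equations, n_samples, rng):
    """Ancestral sampling: evaluate the equations in topological order."""
    values = {}
    for name, f in equations.items():
        noise = rng.standard_normal(n_samples)  # independent noise N_i per variable
        values[name] = f(values, noise)
    return values


def intervene(equations, name, new_equation):
    """Return a new SEM in which only the structural equation of `name` is replaced."""
    intervened = dict(equations)  # the copy keeps the same (topological) order
    intervened[name] = new_equation
    return intervened


rng = np.random.default_rng(0)
observational = sample_sem(BASE_SEM, 10_000, rng)

# Environment e: replace the equation of X1 (shift and shrink its noise);
# the mechanisms of X2 and X3 are reused unchanged.
env_sem = intervene(BASE_SEM, "X1", lambda v, n: 3.0 + 0.1 * n)
interventional = sample_sem(env_sem, 10_000, rng)
```

Because the mechanisms downstream of $X_1$ are reused as-is, the conditional distribution of $X_2$ given $X_1$ is the same in both SEMs even though the marginal distribution of $X_1$ changes; environments are modeled in this way.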
2.2. Formulations of OOD Generalization
In this paper, we study the OOD generalization problem by following the linear classification structural equation model below [12].
Assumption 1
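(Linear classification SEM $\mathcal{C}_{\text{ood}}$).
$$Y \leftarrow \mathbb{1}(w^*_{\text{inv}} \cdot Z_{\text{inv}}) \oplus N, \quad N \sim \text{Bernoulli}(q),\ q < \tfrac{1}{2}, \quad N \perp (Z_{\text{inv}}, Z_{\text{spu}}),$$
$$X \leftarrow S(Z_{\text{inv}}, Z_{\text{spu}}), \qquad (1)$$
where $w^*_{\text{inv}} \in \mathbb{R}^m$ with $\|w^*_{\text{inv}}\| = 1$ is the labeling hyperplane, $Z_{\text{inv}} \in \mathbb{R}^m$, $Z_{\text{spu}} \in \mathbb{R}^o$, N is binary noise, ⊕ is the XOR operator, $S \in \mathbb{R}^{d \times (m+o)}$ is invertible ($d = m + o$), · is the dot product function, and $\mathbb{1}(x) = 1$ if $x \geq 0$, otherwise 0.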
The SEM of Assumption 1 governs four random variables $\{X, Y, Z_{\text{inv}}, Z_{\text{spu}}\}$, and its directed acyclic graph (DAG) is illustrated in Figure 1a, where the exogenous noise variable N is omitted. Following Definition 2, each intervention e generates a new environment e with interventional distribution $P(X^e, Y^e)$. We assume that only the variables X and Y are observable. In OOD generalization, we are interested in a set of environments $\mathcal{E}_{all}$ defined as below.
Figure 1.
(a) DAG of the SEM (Assumption 1); (b–d) DAGs of the interventional SEM in the training environments with respect to different correlations between and . Grey nodes denote observed variables, and white nodes represent unobserved variables. Dashed lines denote the edges which might vary across the interventional environments and even be absent in some scenarios, whilst solid lines indicate that they are invariant across all the environments. All exogenous noise variables are omitted in the DAGs.
Definition 3
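($\mathcal{E}_{all}$). Consider the SEM of Assumption 1 and the learning goal of predicting Y from X. Then, the set of all environments $\mathcal{E}_{all}$ indexes all the interventional distributions $P(X^e, Y^e)$ obtainable by valid interventions e. An intervention $e \in \mathcal{E}_{all}$ is valid as long as (i) the DAG remains acyclic, (ii) $P(Y^e \mid Z^e_{\text{inv}}) = P(Y \mid Z_{\text{inv}})$, and (iii) $X^e \leftarrow S(Z^e_{\text{inv}}, Z^e_{\text{spu}})$.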
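Assumption 1 shows that $Z_{\text{inv}}$ is the cause of the response Y. We call $Z_{\text{inv}}$ the invariant features or causal features because $P(Y^e \mid Z^e_{\text{inv}}) = P(Y \mid Z_{\text{inv}})$ always holds among all valid interventional SEMs, as defined in Definition 3. $Z_{\text{spu}}$ are called spurious features because $P(Y^e \mid Z^e_{\text{spu}})$ may vary across the environments in $\mathcal{E}_{all}$.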
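Let $D = \{D^e\}_{e \in \mathcal{E}_{tr}}$ be the training data gathered from a set of training environments $\mathcal{E}_{tr} \subseteq \mathcal{E}_{all}$, where $D^e = \{(x^e_i, y^e_i)\}_{i=1}^{n_e}$ is the dataset from environment e with each instance i.i.d. drawn from $P(X^e, Y^e)$. Let $\mathcal{X} \subseteq \mathbb{R}^d$ and $\mathcal{Y} = \{0, 1\}$ be the support sets of X and Y, respectively. Given observed data D, the goal of OOD generalization is to find a predictor $f: \mathcal{X} \to \mathcal{Y}$ that performs well across a set of OOD environments (test environments) $\mathcal{E}_{ood}$ of interest, where $\mathcal{E}_{ood} \subseteq \mathcal{E}_{all}$. Formally, it is expected to minimize
$$\max_{e \in \mathcal{E}_{ood}} R^e(f), \qquad (2)$$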
where $R^e(f) := \mathbb{E}_{(X^e, Y^e)}[\ell(f(X^e), Y^e)]$ is the risk under environment e with the 0-1 loss function $\ell$. Since $\mathcal{E}_{ood}$ may differ from $\mathcal{E}_{tr}$, this learning problem is called OOD generalization. We assume the predictor $f = w \circ \Phi$ consists of a feature extractor $\Phi$ and a classifier $w$. With a slight abuse of notation, we also let the classifier w and the feature extractor Φ denote their own parameters, $w \in \mathbb{R}^c$ and $\Phi \in \mathbb{R}^{c \times d}$, with c the feature dimension.
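The objective in Equation (2) is a worst case over environments of the 0-1 risk. A minimal sketch of its empirical counterpart, assuming a linear predictor that outputs 1 when $w \cdot \Phi x + b \geq 0$ and finite samples standing in for each environment's distribution (the helper names `predict`, `zero_one_risk`, and `worst_case_risk` and the toy data are hypothetical):

```python
import numpy as np


def predict(w, Phi, X, b=0.0):
    """Linear predictor f = w o Phi: label 1 if w . (Phi x) + b >= 0, else 0."""
    return (X @ Phi.T @ w + b >= 0).astype(int)


def zero_one_risk(w, Phi, X, y, b=0.0):
    """Empirical 0-1 risk of w o Phi on one environment's samples."""
    return float(np.mean(predict(w, Phi, X, b) != y))


def worst_case_risk(w, Phi, environments, b=0.0):
    """Empirical counterpart of max_e R^e(w o Phi) over {name: (X_e, y_e)}."""
    return max(zero_one_risk(w, Phi, X, y, b) for X, y in environments.values())


# Toy check: Phi keeps only the first input coordinate, and w = [1].
Phi = np.array([[1.0, 0.0]])
w = np.array([1.0])
envs = {
    "e1": (np.array([[1.0, 5.0], [-1.0, -5.0]]), np.array([1, 0])),
    "e2": (np.array([[2.0, -3.0], [-2.0, 3.0]]), np.array([1, 1])),
}
print(worst_case_risk(w, Phi, envs))  # 0.5: perfect on e1, one mistake out of two on e2
```

A predictor that is perfect on one environment can still have a large worst-case risk, which is why Equation (2) takes the maximum rather than the average over environments.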
2.3. Background on IRM and IB-IRM
To minimize Equation (2), two notable solutions, IRM [8] and IB-IRM [12], are formulated as follows:
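$$\text{IRM:} \quad \min_{w, \Phi} \frac{1}{|\mathcal{E}_{tr}|} \sum_{e \in \mathcal{E}_{tr}} R^e(w \circ \Phi) \quad \text{s.t.} \quad w \in \arg\min_{\tilde{w}} R^e(\tilde{w} \circ \Phi),\ \forall e \in \mathcal{E}_{tr}, \qquad (3)$$
$$\text{IB-IRM:} \quad \min_{w, \Phi} \sum_{e \in \mathcal{E}_{tr}} h^e(\Phi) \quad \text{s.t.} \quad \frac{1}{|\mathcal{E}_{tr}|} \sum_{e \in \mathcal{E}_{tr}} R^e(w \circ \Phi) \leq r^{th},\ w \in \arg\min_{\tilde{w}} R^e(\tilde{w} \circ \Phi),\ \forall e \in \mathcal{E}_{tr}, \qquad (4)$$
where $R^e$ is the risk in environment e defined above, $h^e(\Phi) = H(\Phi(X^e))$ with H the Shannon entropy (or a lower bounded differential entropy), and $r^{th}$ is the threshold on the average risk. If we drop the invariance constraint from IRM and IB-IRM, we obtain standard empirical risk minimization (ERM) and information bottleneck-based empirical risk minimization (IB-ERM), respectively. The use of an entropy constraint in IB-IRM is inspired by the information bottleneck principle [16], where mutual information is used for information compression. Since the representation $\Phi(X)$ is a deterministic mapping of X, we have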
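$$H(\Phi(X)) = I(X; \Phi(X)) + H(\Phi(X) \mid X) = I(X; \Phi(X)), \qquad (5)$$
since $H(\Phi(X) \mid X) = 0$; thus, minimizing the entropy of $\Phi(X)$ is equivalent to minimizing the mutual information $I(X; \Phi(X))$. In brief, IB-IRM selects, among all highly predictive invariant predictors, the one whose representation has the least entropy.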
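Equation (5) can be checked numerically with plug-in (empirical) entropies. A small sketch on a hypothetical discrete example, assuming X uniform over eight values and the deterministic representation $\Phi(X) = X \bmod 3$ (the helper names are ours). Because Φ is deterministic, the joint entropy $H(X, \Phi(X))$ equals $H(X)$, so the estimated mutual information coincides with $H(\Phi(X))$.

```python
from collections import Counter

import numpy as np


def entropy(samples):
    """Plug-in Shannon entropy (in bits) of a list of hashable outcomes."""
    counts = np.array(list(Counter(samples).values()), dtype=float)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())


def mutual_information(a, b):
    """I(A; B) = H(A) + H(B) - H(A, B), estimated from paired samples."""
    return entropy(a) + entropy(b) - entropy(list(zip(a, b)))


# X uniform on {0, ..., 7}; a deterministic representation Phi(X) = X mod 3.
x = list(range(8)) * 100
phi_x = [v % 3 for v in x]

h_phi = entropy(phi_x)             # H(Phi(X))
mi = mutual_information(x, phi_x)  # I(X; Phi(X)), equal to H(Phi(X)) here
```

The representation also discards information: $H(\Phi(X)) \approx 1.56$ bits, compared with $H(X) = 3$ bits, which is the kind of compression the entropy term in IB-IRM rewards.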
3. OOD Generalization: Assumptions and Learnability
To study the learnability of OOD generalization, we make the following definition.
Definition 4.
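Given $\mathcal{E}_{tr}$ and $\mathcal{E}_{ood}$, we say an algorithm succeeds in solving OOD generalization with respect to $(\mathcal{E}_{tr}, \mathcal{E}_{ood})$ if the predictor $f^*$ returned by this algorithm satisfies the following equation:
$$\max_{e \in \mathcal{E}_{ood}} R^e(f^*) = \min_{f \in \mathcal{F}} \max_{e \in \mathcal{E}_{ood}} R^e(f), \qquad (6)$$
where $\mathcal{F}$ is the learning hypothesis (a function set including all possible linear classifiers). Otherwise, we say it fails to solve OOD generalization.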
So far, we have not specified how different the environments in $\mathcal{E}_{tr}$ and $\mathcal{E}_{ood}$ can be while still enabling OOD generalization. Different assumptions about $\mathcal{E}_{tr}$ and $\mathcal{E}_{ood}$ lead to different OOD generalization problems.
3.1. Assumptions about the Training Environments
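Define the support set of the invariant (resp., spurious) features $Z^e_{\text{inv}}$ (resp., $Z^e_{\text{spu}}$) in environment e as $\mathcal{Z}^e_{\text{inv}}$ (resp., $\mathcal{Z}^e_{\text{spu}}$). In general, we make the following assumptions about the invariant features in the training environments $\mathcal{E}_{tr}$.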
Assumption 2
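(Bounded invariant features). $\bigcup_{e \in \mathcal{E}_{tr}} \mathcal{Z}^e_{\text{inv}}$ is a bounded set. (A set $\mathcal{Z}$ is bounded if $\exists M < \infty$ such that $\forall z \in \mathcal{Z}$, $\|z\| \leq M$.)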
Assumption 3
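(Strictly separable invariant features). $\forall z \in \bigcup_{e \in \mathcal{E}_{tr}} \mathcal{Z}^e_{\text{inv}}$, $w^*_{\text{inv}} \cdot z \neq 0$.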
The difficulties of OOD generalization are due to the spurious correlations between and in the training environments . In this paper, we consider three modes induced by different correlations between and as shown below.
Assumption 4
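(Spurious correlation 1). Assume each $e \in \mathcal{E}_{tr}$,
$$Z^e_{\text{spu}} \leftarrow A Z^e_{\text{inv}} + W^e, \qquad (7)$$
where $A \in \mathbb{R}^{o \times m}$, and $W^e \in \mathbb{R}^o$ is a continuous (or discrete with each component supported on at least two distinct values), bounded, and zero mean noise variable.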
Assumption 5
(Spurious correlation 2). Assume each $e \in \mathcal{E}_{tr}$,
(8) where , and is a continuous (or discrete with each component supported on at least two distinct values), bounded, and zero mean noise variable.
Assumption 6
(Spurious correlation 3). Assume each $e \in \mathcal{E}_{tr}$,
(9) where and are independent noise variables.
For each $e \in \mathcal{E}_{tr}$, the DAGs of its corresponding interventional SEMs with respect to Assumptions 4–6 are illustrated in Figure 1b–d, respectively. It is worth noting that although the DAGs are identical across all training environments in each mode of Assumptions 4–6, the interventional SEMs among different training environments are different due to the interventions on the exogenous noise variables.
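To make the data-generating process concrete, the sketch below samples from the linear classification SEM of Assumption 1 under Spurious Correlation 1, assuming that Assumption 4 takes the form $Z^e_{\text{spu}} = A Z^e_{\text{inv}} + W^e$ used in [12]. All concrete choices (uniform invariant features, uniform noise $W^e$, the dimensions, q = 0.1, the identity S, and the helper name `sample_environment`) are illustrative assumptions. The two environments differ only in the scale of $W^e$, mirroring interventions on exogenous noise variables.

```python
import numpy as np


def sample_environment(n, w_inv, A, S, q, z_low, z_high, noise_scale, rng):
    """Draw n samples (X, Y) from one environment of the linear classification SEM.

    Hypothetical environment-specific choices: Z_inv ~ Uniform[z_low, z_high]^m and
    W^e ~ Uniform[-noise_scale, noise_scale]^o (continuous, bounded, zero mean).
    """
    m, o = len(w_inv), A.shape[0]
    z_inv = rng.uniform(z_low, z_high, size=(n, m))
    w_noise = rng.uniform(-noise_scale, noise_scale, size=(n, o))
    z_spu = z_inv @ A.T + w_noise                        # assumed form of Assumption 4
    label_noise = rng.binomial(1, q, size=n)             # N ~ Bernoulli(q)
    y = (z_inv @ w_inv >= 0).astype(int) ^ label_noise   # Y <- 1(w* . Z_inv) XOR N
    x = np.hstack([z_inv, z_spu]) @ S.T                  # X <- S(Z_inv, Z_spu)
    return x, y, z_inv, z_spu


rng = np.random.default_rng(0)
m, o = 2, 3
w_inv = np.array([1.0, 0.0])   # unit-norm labeling hyperplane
A = rng.standard_normal((o, m))
S = np.eye(m + o)              # identity scrambling (invertible)
envs = {
    e: sample_environment(5_000, w_inv, A, S, q=0.1, z_low=-1.0, z_high=1.0,
                          noise_scale=s, rng=rng)
    for e, s in [("e1", 0.1), ("e2", 1.0)]
}
```

Only X and Y would be handed to a learning algorithm; $Z_{\text{inv}}$ and $Z_{\text{spu}}$ are returned here solely for inspection.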
3.2. Assumptions about the OOD Environments
Theorem 1
(Impossibility of guaranteed OOD generalization for linear classification [12]). Suppose $\mathcal{E}_{ood} = \mathcal{E}_{all}$. If for all the training environments $\mathcal{E}_{tr}$, the latent invariant features are bounded and strictly separable, i.e., Assumptions 2 and 3 hold, then every deterministic algorithm fails to solve the OOD generalization.
The above theorem shows that it is impossible to solve OOD generalization if $\mathcal{E}_{ood} = \mathcal{E}_{all}$. To make it learnable, Ahuja et al. [12] propose the support overlap assumption (Assumption 7) for the invariant features.
Assumption 7
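(Invariant feature support overlap). $\forall e \in \mathcal{E}_{ood}$, $\mathcal{Z}^e_{\text{inv}} \subseteq \bigcup_{e' \in \mathcal{E}_{tr}} \mathcal{Z}^{e'}_{\text{inv}}$.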
However, Assumption 7 is strong, and we show that OOD generalization can still be solved without it. For illustration, consider an OOD generalization task from to with and , where the support sets of the corresponding invariant features and are illustrated in Figure 2c (assume in this example). As Figure 2c shows, although the support sets of the invariant features differ between the two environments, OOD generalization is still possible if the learned feature extractor captures only the invariant features, e.g., $\Phi(X) = Z_{\text{inv}}$.
Figure 2.
(a) Example 1; (b) example 2; (c) example illustration. Here, and . The blue and black regions represent the support sets of and , corresponding to the environments and , respectively. is the training environment and is the OOD environment. Although Assumption 7 does not hold in this example, any zero-error classifier with on the environment data would clearly make the classification error zero in , thus succeeding to solve OOD generalization.
To weaken Assumption 7, we propose the following assumption.
Assumption 8.
Let be the mixture distribution of invariant features in the training environments, and let be a hypothesis set including all linear classifiers mapping from to . , assume , where l is the 0-1 loss function and .
Clearly, under the assumption of separable invariant features (Assumption 3), for any $e \in \mathcal{E}_{ood}$, Assumption 7 implies Assumption 8, but not vice versa. Therefore, Assumption 8 is weaker than Assumption 7. In Section 5, we show that Assumption 8 can replace Assumption 7 for the success of OOD generalization with our proposed method.
4. Failures of IRM and IB-IRM
Under Spurious Correlation 1 (Assumption 4), the IB-IRM algorithm has been shown to enable OOD generalization, while IRM fails [12]. In this section, we show that both IRM and IB-IRM could fail under Spurious Correlations 2 and 3 (Assumptions 5 and 6).
4.1. Failure under Spurious Correlation 2
Example 1
(Counter-Example 1). Under Assumption 5, let with and be the generated classifier in Assumption 1. We assume two training environments and an OOD environment as:
Figure 2a shows the support points of these features in the training environments. Then, by applying any algorithm to solve the above example with , we would obtain a predictor of . Consider the prediction made by this model as (we ignore the classifier bias for convenience)
| (10) |
It is easy to verify that the of and is an invariant predictor across training environments with classification error , and it achieves the least entropy of for each training environment e. Therefore, it is a solution of both IB-IRM and IRM. However, the predictor of relies on spurious features and has the test error ; thus, it fails to solve OOD generalization.
4.2. Failure under Spurious Correlation 3
Example 2
(Counter-Example 2). Under Assumption 6, let with , be a discrete variable supported uniformly on six points among all environments, and be the generated classifier in Assumption 1. We assume two training environments and an OOD environment as:
Figure 2b shows the support points of these features in the training environments. Then, by applying any algorithm to solve the above example with , we would obtain a predictor of . Consider the prediction made by this model as (we ignore the classifier bias for convenience):
| (11) |
It is easy to verify that the of and is an invariant predictor across training environments with classification error , and it achieves the least entropy of among all highly predictive predictors for each training environment e. Therefore, it is a solution of both IB-IRM and IRM. However, the predictor of relies on spurious features and has the test error ; thus, it fails to solve OOD generalization.
4.3. Understanding the Failures
From the above simple examples, we can conclude that the invariance constraint fails to remove the spurious features because the spurious features in all training environments are strictly linearly separable by their labels. A predictor relying only on the spurious features can then achieve the minimum training error while also being invariant across training environments. Since the label set in classification problems is finite (only two values in binary classification), such a phenomenon can occur. We state this failure mode formally below.
Theorem 2.
Given any and satisfying Assumptions 2, 3, and 7, if two sets and are linearly separable and on each training environment e, then IB-IRM (and IRM, ERM, or IB-ERM) with any fails to solve the OOD generalization.
Theorem 2 is intuitive: when the spurious features in the training environments are linearly separable by their labels, there is no algorithm that can distinguish spurious features from invariant features. Although linear separability of the spurious features may seem a strong requirement for this failure, it is easily satisfied in high-dimensional spaces when the spurious dimension o is large (a common case in practice, e.g., image data). In Appendix A.3, we show one case: if the number of environments is under Assumption 6, the spurious features in the training environments are likely separable by their labels. This is because, in o-dimensional space, o randomly drawn distinct points are, with high probability, linearly separable for any split into two subsets.
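The separability argument can be checked numerically. A minimal sketch, assuming Gaussian spurious-feature vectors and o = 50 (both illustrative, as is the helper name `separating_hyperplane`): o points drawn from a continuous distribution in $\mathbb{R}^o$ form an invertible matrix with probability one, so for any labeling we can solve a linear system for a hyperplane through the origin that puts each point on the side matching its label.

```python
import numpy as np


def separating_hyperplane(points, labels):
    """Find w with sign(points @ w) matching labels, for o points in R^o.

    Solves points @ w = t with targets t in {-1, +1}; this works whenever the
    square matrix `points` is invertible (probability 1 for continuous draws).
    """
    targets = np.where(labels == 1, 1.0, -1.0)
    return np.linalg.solve(points, targets)


rng = np.random.default_rng(0)
o = 50
points = rng.standard_normal((o, o))   # o random spurious-feature vectors in R^o
labels = rng.integers(0, 2, size=o)    # an arbitrary split into two subsets
w = separating_hyperplane(points, labels)
predicted = (points @ w > 0).astype(int)
```

Since every split is realizable, a purely spurious linear predictor can fit any training labeling of such points, which is exactly the situation covered by Theorem 2.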
5. Counterfactual Supervision-Based Information Bottleneck
In the above analyses, we have shown two failure modes of IB-IRM and IRM for OOD generalization in the linear classification problem. The key reason for both failures is that the learned features rely on spurious features. To prevent such failures, we present the counterfactual supervision-based information bottleneck (CSIB) learning algorithm, which removes the spurious features progressively.
At the beginning of each iteration, CSIB applies the IB-ERM method to extract features:
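$$\min_{w, \Phi} \sum_{e \in \mathcal{E}_{tr}} h^e(\Phi) \quad \text{s.t.} \quad \frac{1}{|\mathcal{E}_{tr}|} \sum_{e \in \mathcal{E}_{tr}} R^e(w \circ \Phi) \leq r^{th}. \qquad (12)$$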
Due to the information bottleneck, only part of the information in the input X is exploited in $\Phi(X)$. If the learned features $\Phi(X)$ contain information about the spurious features, the idea of CSIB is to drop that information while retaining the causal information (represented by the invariant features $Z_{\text{inv}}$). However, achieving this goal faces two challenges: (1) how to determine whether $\Phi(X)$ contains information about $Z_{\text{spu}}$, and (2) how to remove that information.
Fortunately, owing to orthogonality in linear spaces, the features that are exploited by (denoted as ) can be disentangled from the features that are not exploited by (denoted as ) via singular value decomposition (SVD). Based on this, we can construct an SEM governing the three variables , , and X. Then, by conducting counterfactual interventions on and in , we can address the first challenge with a single supervision query on the counterfactual examples . For example, if we intervene on and find that the causal information remains in the resulting , then the extracted features must be spurious. To address the second challenge, we replace the input by , filtering out the information of , and repeat the same learning procedure from the beginning.
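A minimal sketch of the disentangling step, assuming a linear feature extractor $\Phi \in \mathbb{R}^{c \times d}$ (the dimensions, the random Φ, and the helper names are hypothetical, not taken from Algorithm 1). The SVD yields orthogonal projectors onto the input directions Φ exploits and the directions it ignores; a counterfactual input keeps the ignored part of x and swaps in the exploited part of a donor example, and filtering keeps only the ignored part. This covers the linear algebra only, not the supervision step.

```python
import numpy as np


def split_by_feature_extractor(Phi, tol=1e-10):
    """Orthogonal projectors onto the input directions a linear Phi uses / ignores.

    With Phi = U diag(s) V^T, the right singular vectors with non-zero singular
    values span the row space of Phi (exploited directions); the remaining ones
    span its null space (directions Phi cannot see).
    """
    _, s, vt = np.linalg.svd(Phi)
    rank = int(np.sum(s > tol))
    v_used = vt[:rank].T                  # d x r orthonormal basis of the row space
    p_used = v_used @ v_used.T            # projector onto exploited directions
    p_unused = np.eye(Phi.shape[1]) - p_used
    return p_used, p_unused


def counterfactual(x, x_donor, p_used, p_unused):
    """Keep the unexploited part of x, replace the exploited part with x_donor's."""
    return p_unused @ x + p_used @ x_donor


rng = np.random.default_rng(0)
d, c = 6, 2
Phi = rng.standard_normal((c, d))         # hypothetical learned linear feature extractor
p_used, p_unused = split_by_feature_extractor(Phi)
x, x_donor = rng.standard_normal(d), rng.standard_normal(d)
x_cf = counterfactual(x, x_donor, p_used, p_unused)
x_filtered = p_unused @ x                 # input with the exploited information removed
```

By construction, $\Phi(x_{cf}) = \Phi(x_{donor})$ while the unexploited component of $x_{cf}$ equals that of x, and $\Phi$ maps the filtered input to zero, so a feature extractor trained on filtered inputs must rely on different directions.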
The learning algorithm of CSIB is illustrated in Algorithm 1, and Figure 3 shows the framework of CSIB. We show in Theorem 3 that CSIB is theoretically guaranteed to succeed in solving OOD generalization.
| Algorithm 1 Counterfactual Supervision-based Information Bottleneck (CSIB) |
| Input:, , , , , and is an example randomly drawn from . |
| Output: classifier , feature extractor . |
| Begin: |
|
| End |
Figure 3.
A simplified framework for the illustration of the proposed CSIB method.
Theorem 3
(Guarantee of CSIB). Given any and satisfying Assumptions 2, 3, and 8, then for every spurious correlation of Assumptions 4, 5, and 6 (in this correlation mode, assume the spurious features are linearly separable in the training environments), the CSIB algorithm with succeeds in solving the OOD generalization.
Remark 1.
CSIB succeeds in solving OOD generalization without assuming support overlap of invariant features and applies to multiple spurious correlation modes in which IB-IRM (as well as ERM, IRM, and IB-ERM) may fail. By introducing counterfactual inference and additional supervision (usually provided by a human) over several steps, CSIB works even when data from only a single environment are available, which is significant in cases where data from multiple environments are not available.
6. Experiments
6.1. Toy Experiments on Synthetic Datasets
We perform experiments on three synthetic datasets with different spurious correlation modes to verify our method, counterfactual supervision-based information bottleneck (CSIB), and compare it with ERM, IB-ERM, IRM, and IB-IRM. We follow the same hyperparameter-tuning protocol as [8,12,17] and report the classification error for all experiments. In the following, we first briefly describe the designed datasets and then report the main results. More experimental details can be found in the Appendix.
6.1.1. Datasets
Example 1/1S. The example is a modified one from the linear unit tests introduced in [17], which generalizes the cow/camel classification task with relevant backgrounds.
The dataset of each environment is sampled from the following distribution
We set and for the first three environments, and for . The scrambling matrix S is the identity matrix in Example 1 and a random unitary matrix in Example 1S. Here, we set and for all environments so that the spurious features and the invariant features are both linearly separable and can thus be confused with each other. The experiments on different values of q and are presented in the Appendix, where we found interesting observations related to the inductive bias of neural networks.
Example 2/2S. This example extends Example 1 to show one of the failure modes of IB-IRM (as well as ERM, IRM, and IB-ERM) and how our method improves on it through intervention (counterfactual supervision). Given , each instance in the environment data is sampled by
where we set , and is the identity matrix in our experiments. We set , , , and if for different training environments. In this example, the spurious features have clearly smaller entropy than the invariant features, which is the opposite of Example 1/1S.
Example 3/3S. This example extends Example 2 and follows a construction similar to that of Example 2/2S. Let for different training environments. Each instance in the environment e is sampled by
where we set in our experiments. As in Example 2/2S, the spurious features have smaller entropy than the invariant features, but here the invariant features have a much larger margin than the spurious features, which distinguishes this example from the previous two. Table 1 summarizes the properties of the three datasets.
Table 1.
Summary of three synthetic datasets. Note that for linearly separable features, their margin levels significantly influence the final learning classifier due to the implicit bias of the gradient descent [18]. Such bias would push the standard learning (such as cross-entropy loss) to focus more on the large-margin features. The margin with respect to a dataset (or features) (each instance has a label 0 or 1) is the minimum distance between a point in and the max-margin hyperplane, which separates by its labels.
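| Datasets | Margin Relationship | Entropy Relationship |
|---|---|---|
| Example 1/1S | $Z_{\text{inv}} < Z_{\text{spu}}$ | $Z_{\text{inv}} < Z_{\text{spu}}$ |
| Example 2/2S | $Z_{\text{inv}} < Z_{\text{spu}}$ | $Z_{\text{inv}} > Z_{\text{spu}}$ |
| Example 3/3S | $Z_{\text{inv}} \gg Z_{\text{spu}}$ | $Z_{\text{inv}} > Z_{\text{spu}}$ |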
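The margin notion in the caption of Table 1 can be computed directly. The sketch below (NumPy, hypothetical 1-D features; helper names are ours) measures the margin of a dataset with respect to a given hyperplane and, for the one-dimensional case only, computes the exact max-margin threshold as the midpoint between the closest points of opposite classes. It assumes class 1 lies to the right of class 0.

```python
import numpy as np


def margin(X, y, w, b=0.0):
    """Smallest signed distance from the points to the hyperplane w . x + b = 0.

    Labels y in {0, 1} are mapped to -1/+1; the result is positive only if the
    hyperplane separates the data, and it then equals the geometric margin.
    """
    signs = np.where(y == 1, 1.0, -1.0)
    return float(np.min(signs * (X @ w + b)) / np.linalg.norm(w))


def max_margin_1d(z, y):
    """Exact max-margin threshold and margin for 1-D features, class 1 on the right."""
    lo, hi = z[y == 0].max(), z[y == 1].min()
    return (lo + hi) / 2, (hi - lo) / 2


# Hypothetical 1-D features: invariant ones far apart, spurious ones close together.
y = np.array([0, 0, 1, 1])
z_inv = np.array([-3.0, -2.0, 2.0, 3.0])
z_spu = np.array([-0.5, -0.1, 0.1, 0.4])

thr_inv, margin_inv = max_margin_1d(z_inv, y)   # threshold 0.0, margin 2.0
thr_spu, margin_spu = max_margin_1d(z_spu, y)   # threshold 0.0, margin 0.1
check = margin(z_inv[:, None], y, np.array([1.0]), b=-thr_inv)
```

Gradient descent on the logistic loss over separable data converges in direction to the max-margin classifier [18], which is why, in a setting like this toy one, standard training would tend to favor the larger-margin feature.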
6.1.2. Summary of Results
Table 2 shows the classification errors of different methods when training data come from one, three, and six environments. ERM and IRM fail to recognize the invariant features in Example 1/1S, where the invariant features have a smaller margin than the spurious features, while the information bottleneck-based methods (IB-ERM, IB-IRM, and CSIB) show improved results owing to the smaller entropy of the invariant features. In Example 1/1S, CSIB gives the same results as IB-ERM, since the invariant features are extracted in its first iteration, which verifies the effectiveness of the information bottleneck for OOD generalization. In Example 2/2S, where the invariant features have larger entropy than the spurious features, only CSIB removes the spurious features, while IB-ERM even degrades the performance of ERM by focusing more on the spurious features. In Example 3/3S, ERM performs reasonably well owing to the much larger margin of the invariant features, yet CSIB still improves on ERM in most settings by removing more spurious features. Notably, in Examples 2/2S and 3/3S, where IB-ERM and IB-IRM extract only spurious features, CSIB removes them through counterfactual supervision and then refocuses on the invariant features. The non-zero average errors and fluctuating results of CSIB in some experiments arise because entropy minimization during training is imprecise: entropy is replaced by variance for ease of optimization. Nevertheless, in Examples 2/2S and 3/3S, there is always a run in which the entropy is effectively minimized and the error reaches zero (see (min) in Table 2).
In summary, CSIB performs well across the different spurious correlation modes and is especially more effective than IB-ERM and IB-IRM when the spurious features have much smaller entropy than the invariant features.
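The substitution of variance for entropy mentioned above can be sketched as follows, assuming a binary cross-entropy risk and a penalty weight λ (the function names and the exact form of the penalty are our assumptions, not the authors' training code). For a Gaussian, differential entropy $\tfrac{1}{2}\log(2\pi e \sigma^2)$ increases monotonically with the variance $\sigma^2$, so penalizing variance pushes in the same direction as penalizing entropy; for non-Gaussian features the two can disagree, which is consistent with the imprecision noted above.

```python
import numpy as np


def gaussian_entropy(var):
    """Differential entropy (nats) of a 1-D Gaussian with variance var."""
    return 0.5 * np.log(2 * np.pi * np.e * var)


def ib_erm_surrogate_loss(features, logits, y, lam):
    """Binary cross-entropy plus lam * (sum of per-dimension feature variances).

    The variance term stands in for the entropy term h^e(Phi) = H(Phi(X^e)).
    """
    # Numerically stable cross-entropy with logits: log(1 + e^z) - y * z.
    ce = np.mean(np.logaddexp(0.0, logits) - y * logits)
    return float(ce + lam * features.var(axis=0).sum())


variances = np.array([0.25, 1.0, 4.0])
entropies = gaussian_entropy(variances)  # strictly increasing in the variance

# Constant features have zero variance, so only the cross-entropy term remains.
loss = ib_erm_surrogate_loss(np.ones((4, 2)), np.zeros(4), np.array([0, 1, 0, 1]), lam=1.0)
```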
Table 2.
Main results: #Envs means the number of training environments, and (min) reports the minimal test classification error across different running seeds.
| Dataset | #Envs | ERM (min) | IRM (min) | IB-ERM (min) | IB-IRM (min) | CSIB (min) |
|---|---|---|---|---|---|---|
| Example 1 | 1 | 0.50 ± 0.01 (0.49) | 0.50 ± 0.01 (0.49) | 0.23 ± 0.02 (0.22) | 0.31 ± 0.10 (0.25) | 0.23 ± 0.02 (0.22) |
| Example 1S | 1 | 0.50 ± 0.00 (0.49) | 0.50 ± 0.00 (0.50) | 0.46 ± 0.04 (0.39) | 0.30 ± 0.10 (0.25) | 0.46 ± 0.04 (0.39) |
| Example 2 | 1 | 0.40 ± 0.20 (0.00) | 0.50 ± 0.00 (0.49) | 0.50 ± 0.00 (0.49) | 0.46 ± 0.02 (0.45) | 0.00 ± 0.00 (0.00) |
| Example 2S | 1 | 0.50 ± 0.00 (0.50) | 0.31 ± 0.23 (0.00) | 0.50 ± 0.00 (0.50) | 0.45 ± 0.01 (0.43) | 0.10 ± 0.20 (0.00) |
| Example 3 | 1 | 0.16 ± 0.06 (0.09) | 0.18 ± 0.03 (0.14) | 0.50 ± 0.01 (0.49) | 0.40 ± 0.20 (0.01) | 0.11 ± 0.20 (0.00) |
| Example 3S | 1 | 0.17 ± 0.07 (0.10) | 0.09 ± 0.02 (0.07) | 0.50 ± 0.00 (0.50) | 0.50 ± 0.00 (0.50) | 0.21 ± 0.24 (0.00) |
| Example 1 | 3 | 0.45 ± 0.01 (0.45) | 0.45 ± 0.01 (0.45) | 0.22 ± 0.01 (0.21) | 0.23 ± 0.13 (0.02) | 0.22 ± 0.01 (0.21) |
| Example 1S | 3 | 0.45 ± 0.00 (0.45) | 0.45 ± 0.00 (0.45) | 0.41 ± 0.04 (0.34) | 0.27 ± 0.11 (0.11) | 0.41 ± 0.04 (0.34) |
| Example 2 | 3 | 0.40 ± 0.20 (0.00) | 0.50 ± 0.00 (0.50) | 0.50 ± 0.00 (0.50) | 0.33 ± 0.04 (0.25) | 0.00 ± 0.00 (0.00) |
| Example 2S | 3 | 0.50 ± 0.00 (0.50) | 0.37 ± 0.15 (0.15) | 0.50 ± 0.00 (0.50) | 0.34 ± 0.01 (0.33) | 0.10 ± 0.20 (0.00) |
| Example 3 | 3 | 0.18 ± 0.04 (0.15) | 0.21 ± 0.02 (0.20) | 0.50 ± 0.01 (0.49) | 0.50 ± 0.01 (0.49) | 0.11 ± 0.20 (0.00) |
| Example 3S | 3 | 0.18 ± 0.04 (0.15) | 0.08 ± 0.03 (0.03) | 0.50 ± 0.00 (0.50) | 0.43 ± 0.09 (0.31) | 0.01 ± 0.00 (0.00) |
| Example 1 | 6 | 0.46 ± 0.01 (0.44) | 0.46 ± 0.09 (0.41) | 0.22 ± 0.01 (0.20) | 0.37 ± 0.14 (0.17) | 0.22 ± 0.01 (0.20) |
| Example 1S | 6 | 0.46 ± 0.02 (0.44) | 0.46 ± 0.02 (0.44) | 0.35 ± 0.10 (0.23) | 0.42 ± 0.12 (0.28) | 0.35 ± 0.10 (0.23) |
| Example 2 | 6 | 0.49 ± 0.01 (0.48) | 0.50 ± 0.01 (0.48) | 0.50 ± 0.00 (0.50) | 0.30 ± 0.01 (0.28) | 0.00 ± 0.00 (0.00) |
| Example 2S | 6 | 0.50 ± 0.00 (0.50) | 0.35 ± 0.12 (0.25) | 0.50 ± 0.00 (0.50) | 0.30 ± 0.01 (0.29) | 0.20 ± 0.24 (0.00) |
| Example 3 | 6 | 0.18 ± 0.04 (0.15) | 0.20 ± 0.01 (0.19) | 0.50 ± 0.00 (0.49) | 0.37 ± 0.16 (0.16) | 0.01 ± 0.01 (0.00) |
| Example 3S | 6 | 0.18 ± 0.04 (0.14) | 0.05 ± 0.04 (0.01) | 0.50 ± 0.00 (0.50) | 0.50 ± 0.00 (0.50) | 0.11 ± 0.20 (0.00) |
6.2. Experiments on Color MNIST Dataset
In this experiment, we set up a binary classification task for digit recognition: identifying whether a digit is less than five (0–4) or not (5–9). We construct the data from a real-world dataset, the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/). Following our learning setting, we use color information as the spurious feature, which correlates strongly with the class label. By construction, the label is more strongly correlated with the color than with the digit in the training environments, but this correlation is broken in the test environment. Specifically, the three designed environments of the colored MNIST dataset (two training environments and one test environment, each containing 10,000 points) are constructed as follows. First, we assign a preliminary binary label $\tilde{y}$ to each image based on its digit: $\tilde{y} = 0$ for digits 0–4 and $\tilde{y} = 1$ for 5–9. Second, we obtain the final label y by flipping $\tilde{y}$ with probability 0.25. Then, we flip the final labels to obtain the color id, with flipping probabilities of 0.2 and 0.1 in the two training environments and 0.9 in the test environment. For better understanding, we randomly draw 20 examples for each label from each environment and visualize them in Figure 4.
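The label and color construction can be written in a few lines. A minimal sketch (NumPy), assuming the preliminary label is 0 for digits 0–4 and 1 for 5–9, and using random digit classes as a stand-in for the actual MNIST labels; coloring the images themselves (mapping the color id to a color channel) is omitted, and the helper name `color_mnist_labels` is hypothetical.

```python
import numpy as np


def color_mnist_labels(digits, color_flip_prob, rng, label_flip_prob=0.25):
    """Final labels and color ids for one environment, following the construction above.

    digits: integer array of digit classes (0-9); the images are not needed here.
    """
    preliminary = (digits >= 5).astype(int)                # 0 for digits 0-4, 1 for 5-9
    flip_label = rng.random(len(digits)) < label_flip_prob
    y = preliminary ^ flip_label                           # final label
    flip_color = rng.random(len(digits)) < color_flip_prob
    color = y ^ flip_color                                 # color id
    return y, color


rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=10_000)  # stand-in for 10,000 MNIST digit classes
envs = {
    name: color_mnist_labels(digits, p, rng)
    for name, p in [("train1", 0.2), ("train2", 0.1), ("test", 0.9)]
}
```

In the training environments, the color agrees with the label 80% and 90% of the time, while the digit agrees only 75% of the time, so color is the more predictive feature there; in the test environment, the color agrees with the label only 10% of the time.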
Figure 4.
Visualization of the colored MNIST dataset.
The classification results on the colored MNIST dataset are shown in Table 3. Both ERM and IB-ERM rely almost entirely on the color features: their test accuracy of about 10% matches the 0.9 color-flipping probability of the test environment. Although IRM and IB-IRM show some improvements over ERM, only our method performs better than random prediction, which demonstrates the effectiveness of CSIB.
Table 3.
Classification accuracy (%) on color MNIST dataset. “Oracle” in the table means that the training and test data are in the same environment.
| Methods | ERM | IRM | IB-ERM | IB-IRM | CSIB | Oracle |
|---|---|---|---|---|---|---|
| Accuracy | 9.94 ± 0.28 | 20.39 ± 2.76 | 9.94 ± 0.28 | 43.84 ± 12.48 | 60.03 ± 1.28 | 84.72 ± 0.65 |
7. Related Works
We divide the works related to OOD generalization into two categories: theory and methods, though some of them belong to both.
7.1. Theory of OOD Generalization
Based on different definitions of distributional changes, we review the corresponding theory in the following three categories.
Based on causality. Due to the close connection between distributional changes and the interventions discussed in the theory of causality [13,14], the problem of OOD generalization is usually framed within causal learning. The theory states that a response Y is directly caused only by its parent variables $\mathrm{Pa}(Y)$, and that interventions other than those on Y do not change the conditional distribution $P(Y \mid \mathrm{Pa}(Y))$. This theory inspires a popular learning principle, the invariance principle, which aims to discover a set of variables whose relationship to the response Y remains invariant across all observed environments [15,19,20]. Invariant risk minimization (IRM) [8] was then proposed to learn a feature extractor end-to-end such that the optimal classifier based on the extracted features is the same in every environment. The theory in [8] guarantees OOD generalization for IRM under some general assumptions, but only for linear regression tasks. Unlike the failure analyses of IRM for classification tasks in [21,22], where the response Y is the cause of the spurious feature, Ahuja et al. [12] analyse another scenario in which the invariant feature is the cause of the spurious feature, and show that in this case linear classification is more difficult than linear regression: the invariance principle alone is insufficient to ensure the success of OOD generalization. They also claim that the assumption of support overlap of invariant features is necessary. They then propose information bottleneck-based invariant risk minimization (IB-IRM) for linear classification, which addresses the failures of IRM by adding an information bottleneck [16] to the learning objective. In this work, we closely investigate the conditions identified in [12] and first show that support overlap of invariant features is not necessary for the success of OOD generalization.
We further show several failure cases of IB-IRM and propose improved results for it.
Recently, some works have tackled OOD generalization in the nonlinear regime [23,24]. Both use variational autoencoder (VAE)-based models [25,26] to identify the latent variables from observations in a first stage. The inferred latent variables are then separated into invariant (causal) and spurious (non-causal) parts under different assumptions. Specifically, Lu et al. [23,27] assume that the latent variables, conditioned on accessible side information such as the environment index or class label, follow exponential family distributions, while Liu et al. [24] directly disentangle the latent variables into two parts during inference and assume that the two parts are marginally independent. These assumptions, however, are rather strong in general. Even so, these solutions aim to capture latent variables such that the response given these variables is invariant across environments, which could still fail because the invariance principle alone is insufficient for OOD generalization in classification tasks, as shown in [12]. In this work, we focus on linear classification only and develop theory for a new method that addresses several OOD generalization failures in the linear setting. Our method could be extended to the nonlinear regime by combining it with disentangled representation learning [28] or causal representation learning [29]. Specifically, once the latent representations are well disentangled, i.e., the latent features are a linear transform of the causal and spurious features, our method could be applied in the latent space to filter out the spurious features so that only the causal features remain.
Based on robustness. Unlike causality-based approaches, where different distributions are generated by interventions on the same SEM and the goal is to discover causal features, robustness-based methods aim to protect the model against potential distributional shifts within an uncertainty set, which is usually constrained by an f-divergence [30] or the Wasserstein distance [31]. This line of work is theoretically formulated as distributionally robust optimization (DRO) in a minimax framework [32,33]. Recently, some works have explored the connections between causality and robustness [34]. Although these works are less relevant to ours, a well-defined measure of distribution divergence might help extract causal features effectively within the robustness framework. This would be an interesting avenue for future research.
Others. Some other works assume that the distributions (domains) are generated from a hyper-distribution and aim to minimize a bound on the average risk estimation error [35,36,37]. These works are often built on generalization theory under the independent and identically distributed (IID) assumption. The authors of [38] make no assumption on the distributional changes and study the learnability of OOD generalization in a general way. None of these theories covers the OOD generalization problem with a single training environment or domain.
7.2. Methods of OOD Generalization
Based on the invariance principle. Inspired by the invariance principle [15,19], many methods design various losses to extract features that better satisfy the principle. IRMv1 [8] is the first objective to address this end to end, by adding a gradient penalty on the classifier. Following this work, Krueger et al. [9] suggest penalizing the variance of the risks across environments, Xie et al. [39] use a similar objective but take the square root of the variance, and many other alternatives can be found in [40,41,42]. All of these methods aim to find an invariant predictor. Recently, Ahuja et al. [12] found that for classification, finding an invariant predictor is not enough to extract causal features, since the features could include spurious information that still leaves the predictor invariant across training environments; they propose IB-IRM to address this failure. Ideas similar to IB-IRM can also be found in [43,44], where different loss functions are proposed for the same purpose. Specifically, Alesiani et al. [44] also use the information bottleneck (IB) to help drop spurious correlations, but their analyses only consider the scenario in which the spurious features are independent of the causal features, which can be regarded as a special case of ours. More recently, Wang et al. [45] propose ideas similar to ours but only tackle the situation where the invariant features have the same distribution across all environments. In this work, we further show that IB-IRM can still fail in two cases because the model relies only on spurious features to solve the task of interest. We then propose a counterfactual supervision-based information bottleneck (CSIB) method to address such failures and show improved results over prior works.
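To make the two penalties above concrete, the following is a minimal NumPy sketch, assuming a binary logistic model whose logits are already computed per environment. The IRMv1 term is the squared gradient of the environment risk with respect to a dummy scalar multiplying the logits, evaluated at 1; the V-REx-style term is the variance of the per-environment risks. Function names, the toy data, and the closed-form gradient (valid only for binary cross-entropy) are our own illustrative choices, not code from [8] or [9].

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def bce(logits, y):
    # Mean binary cross-entropy, written as log(1 + e^f) - y * f for stability
    return np.mean(np.logaddexp(0.0, logits) - y * logits)


def irmv1_penalty(logits, y):
    # IRMv1: squared gradient of the environment risk with respect to a dummy
    # scalar s multiplying the logits, evaluated at s = 1. For BCE this
    # gradient is mean((sigmoid(f) - y) * f), so no autodiff is needed here.
    grad = np.mean((sigmoid(logits) - y) * logits)
    return grad ** 2


def vrex_penalty(env_risks):
    # V-REx-style penalty: variance of the per-environment risks
    return float(np.var(env_risks))


# Usage on two toy environments (synthetic logits and labels, illustration only)
rng = np.random.default_rng(0)
envs = []
for shift in (0.0, 1.0):
    f = rng.normal(size=200) + shift
    y = (rng.random(200) < sigmoid(f)).astype(float)
    envs.append((f, y))

risks = [bce(f, y) for f, y in envs]
irm_term = sum(irmv1_penalty(f, y) for f, y in envs)
rex_term = vrex_penalty(risks)
print(f"IRMv1 penalty: {irm_term:.4f}, V-REx penalty: {rex_term:.4f}")
```

In practice the IRMv1 gradient is usually obtained with automatic differentiation; the closed form is used here only to keep the sketch dependency-free.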
Based on distribution matching. Many works focus on learning domain-invariant feature representations [46,47,48]. Most of them are inspired by the seminal theory of domain adaptation [49,50]. These methods learn a feature extractor $\Phi$ such that the marginal distribution $P(\Phi(X))$ or the class-conditional distribution $P(\Phi(X) \mid Y)$ is invariant across domains. This differs from the invariance principle, whose goal is to make $P(Y \mid \Phi(X))$ (or $\mathbb{E}[Y \mid \Phi(X)]$) invariant. We refer readers to [8,51] for a detailed explanation of why these distribution-matching-based methods often fail to address OOD generalization.
Others. Other related methods vary widely, including data augmentation at the image level [52] or the feature level [53], removing spurious correlations through stable learning [54], and exploiting the inductive bias of neural networks [3,55]. Most of these methods are empirically motivated and verified on specific datasets. Recently, empirical studies [56,57] observed that the real effects of many OOD generalization (domain generalization) methods are weak, which suggests that benchmark-based evaluation criteria may be inadequate for validating OOD generalization algorithms.
8. Conclusions, Limitations and Future Work
In this paper, we focus on the OOD generalization problem of linear classification. We first revisit the fundamental assumptions and results of prior works, show that support overlap of the invariant features is not necessary for successful OOD generalization, and propose a weaker counterpart. We then show two failure cases of IB-IRM (as well as ERM, IB-ERM, and IRM) and explain their causes through theoretical analysis. We further propose a new method, the counterfactual supervision-based information bottleneck (CSIB), and theoretically prove its effectiveness under weaker assumptions. CSIB works even with data from a single environment and extends easily to multi-class problems. Finally, we design several synthetic datasets based on our examples for experimental verification. The empirical comparison with all baseline methods illustrates the effectiveness of CSIB.
Since we only consider the linear problem, with a linear representation and a linear classifier, our theoretical results do not cover nonlinear cases, in which CSIB may fail. Therefore, as for prior works (IRM [8] and IB-IRM [12]), the nonlinear setting remains an open challenge [21,22]. We believe it is well worth investigating in future work, since most real-world data are generated nonlinearly. Another fruitful direction is to design an effective algorithm for entropy minimization during CSIB training. Currently, we replace the entropy of the features with their variance during optimization. However, variance and entropy are fundamentally different, and truly effective entropy minimization is key to the success of CSIB. A further limitation of our method is that it requires additional supervision of counterfactual examples during learning, although this supervision is needed only once per step.
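A small worked example may help show why variance is only a proxy for entropy. The sketch below compares the closed-form differential entropies of a zero-mean Gaussian and a uniform variable with the same variance: the two differ by the constant $\tfrac{1}{2}\log(\pi e/6) \approx 0.18$ nats regardless of the variance, so equal variance does not imply equal entropy. The choice of these two distributions is ours, for illustration only. A related point, easy to check by hand: rescaling a discrete feature by a constant changes its variance but leaves its Shannon entropy unchanged.

```python
import math


def gaussian_entropy(sigma):
    # Differential entropy (in nats) of N(0, sigma^2): 0.5 * log(2 * pi * e * sigma^2)
    return 0.5 * math.log(2.0 * math.pi * math.e * sigma ** 2)


def uniform_entropy(sigma):
    # Uniform(-a, a) has variance a^2 / 3, so a = sqrt(3) * sigma gives variance sigma^2;
    # its differential entropy is log(2a)
    a = math.sqrt(3.0) * sigma
    return math.log(2.0 * a)


for sigma in (0.5, 1.0, 2.0):
    g, u = gaussian_entropy(sigma), uniform_entropy(sigma)
    print(f"variance={sigma ** 2:.2f}  h_gauss={g:.4f}  h_uniform={u:.4f}")
```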
Abbreviations
The following abbreviations are used in this manuscript:
| Abbreviation | Meaning |
|---|---|
| OOD | Out-of-distribution |
| ERM | Empirical risk minimization |
| IRM | Invariant risk minimization |
| IB-ERM | Information bottleneck-based empirical risk minimization |
| IB-IRM | Information bottleneck-based invariant risk minimization |
| CSIB | Counterfactual supervision-based information bottleneck |
| DAG | Directed acyclic graph |
| SEM | Structural equation model |
| SVD | Singular value decomposition |
Appendix A. Experimental Details
In this section, we provide more details on the experiments. The code to reproduce the experiments can be found at https://github.com/szubing/CSIB.
Appendix A.1. Optimization Loss of IB-ERM
The objective function of IB-ERM is as follows:
| (A1) |
Since the entropy of the extracted features is hard to estimate in a differentiable form that can be optimized by gradient descent, we follow [12] and use the variance in place of the entropy during optimization. The total loss function is given by
| (A2) |
where a hyperparameter weights the variance term.
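The following is a minimal sketch of a variance-penalized IB-ERM loss for the linear model used in our experiments. It assumes the risk is the mean binary cross-entropy over training environments and that the variance term is the per-dimension feature variance summed over dimensions and averaged over environments; these details, the function name `ib_erm_loss`, and the weight `lam` are illustrative assumptions rather than a transcription of Equation (A2) or of the released code.

```python
import numpy as np


def ib_erm_loss(Phi, w, envs, lam):
    """Hypothetical IB-ERM objective: mean BCE risk plus lam times feature variance.

    Phi  : (k, d) array, linear feature extractor
    w    : (k,) array, linear classifier on the extracted features
    envs : list of (X, y) pairs with X of shape (n, d) and y in {0, 1}
    lam  : weight of the variance term (stand-in for the entropy term)
    """
    risks, variances = [], []
    for X, y in envs:
        feats = X @ Phi.T                    # (n, k) extracted features
        logits = feats @ w
        # Mean binary cross-entropy, log(1 + e^f) - y * f
        risks.append(np.mean(np.logaddexp(0.0, logits) - y * logits))
        # Per-dimension variance of the features, summed over dimensions
        variances.append(np.sum(np.var(feats, axis=0)))
    return float(np.mean(risks) + lam * np.mean(variances))


# Usage on one synthetic environment (illustrative data only)
rng = np.random.default_rng(1)
X = rng.normal(size=(50, 3))
y = (X[:, 0] > 0).astype(float)
Phi = np.eye(2, 3)          # keep the first two input coordinates
w = np.array([1.0, 0.0])
print(ib_erm_loss(Phi, w, [(X, y)], lam=0.1))
```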
Appendix A.2. Experimental Setup
Model, hyperparameters, loss, and evaluation. In all experiments, we follow the protocol prescribed by [12,17] for model/hyperparameter selection, training, and evaluation. Unless otherwise specified, for all experiments across the three examples and five compared methods, the model is a linear feature extractor followed by a linear classifier. We use the binary cross-entropy loss for classification. All hyperparameters, including the learning rate, the IRM penalty weight, and the weight of the Var term in Equation (A2), are randomly searched and selected using 20 test samples for validation. The results reported in the main manuscript use three hyperparameter queries for each method and are averaged over five data seeds. Results with more hyperparameter queries are reported in the supplementary experiments. The search spaces of all hyperparameters follow those in [12,17]. We report classification test errors as values between 0 and 1.
Compute description. All experiments were run on a single NVIDIA GeForce GTX 1080 Ti GPU with a 6-core Intel(R) Core(TM) i7-8700 CPU @ 3.20GHz.
Existing code and datasets used. Our experiments mainly rely on the following two GitHub repositories: InvarianceUnitTests (https://github.com/facebookresearch/InvarianceUnitTests) and IB-IRM (https://github.com/ahujak/IB-IRM).
Appendix A.3. Supplementary Experiments
The first supplementary experiment shows how the results change when we increase the number of hyperparameter queries used for hyperparameter selection. Table A1 reports the results with 10 hyperparameter queries for each method. Overall, CSIB performs better and fluctuates less in Table A1 than in Table 2, and the conclusions remain essentially the same as those summarized in Section 6.1.2. This further verifies the effectiveness of the CSIB method.
Table A1.
Supplementary results with 10 hyperparameter queries. #Envs denotes the number of training environments, and (min) reports the minimal test classification error across data seeds.
| Example | #Envs | ERM (min) | IRM (min) | IB-ERM (min) | IB-IRM (min) | CSIB (min) | Oracle (min) |
|---|---|---|---|---|---|---|---|
| Example 1 | 1 | 0.50 ± 0.01 (0.49) | 0.50 ± 0.01 (0.49) | 0.23 ± 0.02 (0.22) | 0.31 ± 0.10 (0.25) | 0.23 ± 0.02 (0.22) | 0.00 ± 0.00 (0.00) |
| Example 1S | 1 | 0.50 ± 0.00 (0.49) | 0.50 ± 0.00 (0.49) | 0.09 ± 0.04 (0.04) | 0.30 ± 0.10 (0.25) | 0.08 ± 0.04 (0.04) | 0.00 ± 0.00 (0.00) |
| Example 2 | 1 | 0.40 ± 0.20 (0.00) | 0.00 ± 0.00 (0.00) | 0.50 ± 0.00 (0.49) | 0.48 ± 0.03 (0.43) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 2S | 1 | 0.50 ± 0.00 (0.50) | 0.30 ± 0.25 (0.00) | 0.50 ± 0.00 (0.50) | 0.50 ± 0.01 (0.48) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3 | 1 | 0.16 ± 0.06 (0.09) | 0.03 ± 0.00 (0.03) | 0.50 ± 0.01 (0.49) | 0.41 ± 0.09 (0.25) | 0.02 ± 0.01 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3S | 1 | 0.16 ± 0.06 (0.10) | 0.04 ± 0.01 (0.02) | 0.50 ± 0.00 (0.50) | 0.41 ± 0.12 (0.26) | 0.01 ± 0.01 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 1 | 3 | 0.44 ± 0.01 (0.44) | 0.44 ± 0.01 (0.44) | 0.21 ± 0.00 (0.21) | 0.21 ± 0.10 (0.06) | 0.21 ± 0.00 (0.21) | 0.00 ± 0.00 (0.00) |
| Example 1S | 3 | 0.45 ± 0.00 (0.44) | 0.45 ± 0.00 (0.44) | 0.09 ± 0.03 (0.05) | 0.23 ± 0.13 (0.01) | 0.09 ± 0.03 (0.05) | 0.00 ± 0.00 (0.00) |
| Example 2 | 3 | 0.13 ± 0.07 (0.00) | 0.00 ± 0.00 (0.00) | 0.50 ± 0.00 (0.50) | 0.33 ± 0.04 (0.25) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 2S | 3 | 0.50 ± 0.00 (0.50) | 0.14 ± 0.20 (0.00) | 0.50 ± 0.00 (0.50) | 0.34 ± 0.01 (0.33) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3 | 3 | 0.17 ± 0.04 (0.14) | 0.02 ± 0.00 (0.02) | 0.50 ± 0.01 (0.49) | 0.43 ± 0.08 (0.29) | 0.01 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3S | 3 | 0.17 ± 0.04 (0.13) | 0.02 ± 0.00 (0.02) | 0.50 ± 0.00 (0.50) | 0.36 ± 0.18 (0.07) | 0.01 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 1 | 6 | 0.46 ± 0.01 (0.44) | 0.46 ± 0.09 (0.41) | 0.22 ± 0.01 (0.21) | 0.41 ± 0.11 (0.26) | 0.22 ± 0.01 (0.21) | 0.00 ± 0.00 (0.00) |
| Example 1S | 6 | 0.46 ± 0.02 (0.44) | 0.46 ± 0.02 (0.44) | 0.06 ± 0.04 (0.02) | 0.45 ± 0.07 (0.41) | 0.06 ± 0.04 (0.02) | 0.00 ± 0.00 (0.00) |
| Example 2 | 6 | 0.21 ± 0.03 (0.17) | 0.00 ± 0.00 (0.00) | 0.50 ± 0.00 (0.50) | 0.36 ± 0.03 (0.31) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 2S | 6 | 0.50 ± 0.00 (0.50) | 0.10 ± 0.20 (0.00) | 0.50 ± 0.00 (0.50) | 0.19 ± 0.16 (0.01) | 0.00 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3 | 6 | 0.17 ± 0.03 (0.14) | 0.02 ± 0.00 (0.02) | 0.50 ± 0.00 (0.49) | 0.37 ± 0.16 (0.16) | 0.01 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
| Example 3S | 6 | 0.17 ± 0.03 (0.14) | 0.02 ± 0.00 (0.02) | 0.50 ± 0.00 (0.50) | 0.46 ± 0.09 (0.28) | 0.01 ± 0.00 (0.00) | 0.00 ± 0.00 (0.00) |
Observation on different settings in Example 1/1S. In our main experiments on Example 1/1S, we set the data-generating parameters so that the spurious and invariant features are both linearly separable and can thus be confused with each other. Here, we analyse how the results change when these parameters vary. Following [17], we change the corresponding parameters to make the spurious features linearly inseparable, and set q to 0 or 0.05 to make the invariant features linearly separable or inseparable, respectively. Table A2 shows the corresponding results. Interestingly, all methods except IB-IRM reach an error rate equal or close to that of the Oracle when the spurious features are linearly inseparable, even when the invariant features are also linearly inseparable (q = 0.05). Why does this happen? To investigate, we remove the linear embedding and train only the linear classifier. The results are presented in Table A3. Comparing Table A2 and Table A3, we find that the network parameterization (a linear embedding followed by a linear classifier) carries a significant inductive bias, even though the overall model is linear. Further analysis of this observation is beyond the scope of this paper, but it would be an interesting avenue for future research.
Table A2.
Results in Example 1/1S, where the learning model is a linear embedding followed by a linear classifier.
| Example | #Envs | Spurious separable? | q | ERM | IB-ERM | IB-IRM | CSIB | IRM | Oracle |
|---|---|---|---|---|---|---|---|---|---|
| Example 1 | 1 | Yes | 0 | 0.50 ± 0.01 | 0.23 ± 0.02 | 0.31 ± 0.10 | 0.23 ± 0.02 | 0.50 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 1 | Yes | 0 | 0.50 ± 0.00 | 0.46 ± 0.04 | 0.30 ± 0.10 | 0.46 ± 0.04 | 0.50 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 3 | Yes | 0 | 0.45 ± 0.01 | 0.22 ± 0.01 | 0.23 ± 0.13 | 0.22 ± 0.01 | 0.45 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 3 | Yes | 0 | 0.45 ± 0.00 | 0.41 ± 0.04 | 0.27 ± 0.11 | 0.41 ± 0.04 | 0.45 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 6 | Yes | 0 | 0.46 ± 0.01 | 0.22 ± 0.01 | 0.37 ± 0.14 | 0.22 ± 0.01 | 0.46 ± 0.09 | 0.00 ± 0.00 |
| Example 1S | 6 | Yes | 0 | 0.46 ± 0.02 | 0.35 ± 0.10 | 0.42 ± 0.12 | 0.35 ± 0.10 | 0.46 ± 0.02 | 0.00 ± 0.00 |
| Example 1 | 1 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.15 ± 0.20 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| Example 1S | 1 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.12 ± 0.19 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 3 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| Example 1S | 3 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.01 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 6 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.30 ± 0.20 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| Example 1S | 6 | No | 0 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.31 ± 0.20 | 0.00 ± 0.00 | 0.04 ± 0.06 | 0.00 ± 0.00 |
| Example 1 | 1 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.32 ± 0.22 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
| Example 1S | 1 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.19 ± 0.17 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
| Example 1 | 3 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.07 ± 0.03 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
| Example 1S | 3 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
| Example 1 | 6 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.30 ± 0.21 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
| Example 1S | 6 | No | 0.05 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.32 ± 0.19 | 0.05 ± 0.00 | 0.05 ± 0.00 | 0.05 ± 0.00 |
Table A3.
Results in Example 1/1S, where the learning model is a linear classifier without the linear embedding. Because CSIB requires a feature extractor, no CSIB results are reported.
| Example | #Envs | Spurious separable? | q | ERM | IB-ERM | IB-IRM | IRM | Oracle |
|---|---|---|---|---|---|---|---|---|
| Example 1 | 1 | Yes | 0 | 0.50 ± 0.01 | 0.25 ± 0.01 | 0.31 ± 0.10 | 0.50 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 1 | Yes | 0 | 0.50 ± 0.00 | 0.49 ± 0.01 | 0.30 ± 0.10 | 0.50 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 3 | Yes | 0 | 0.44 ± 0.01 | 0.23 ± 0.01 | 0.21 ± 0.10 | 0.44 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 3 | Yes | 0 | 0.45 ± 0.00 | 0.44 ± 0.01 | 0.42 ± 0.04 | 0.45 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 6 | Yes | 0 | 0.46 ± 0.01 | 0.27 ± 0.07 | 0.41 ± 0.11 | 0.46 ± 0.01 | 0.01 ± 0.01 |
| Example 1S | 6 | Yes | 0 | 0.46 ± 0.02 | 0.42 ± 0.08 | 0.46 ± 0.09 | 0.46 ± 0.02 | 0.01 ± 0.02 |
| Example 1 | 1 | No | 0 | 0.50 ± 0.01 | 0.00 ± 0.00 | 0.15 ± 0.20 | 0.50 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 1 | No | 0 | 0.50 ± 0.00 | 0.00 ± 0.00 | 0.13 ± 0.19 | 0.50 ± 0.00 | 0.00 ± 0.00 |
| Example 1 | 3 | No | 0 | 0.45 ± 0.01 | 0.00 ± 0.00 | 0.00 ± 0.00 | 0.45 ± 0.01 | 0.00 ± 0.00 |
| Example 1S | 3 | No | 0 | 0.45 ± 0.00 | 0.01 ± 0.02 | 0.08 ± 0.14 | 0.46 ± 0.02 | 0.00 ± 0.00 |
| Example 1 | 6 | No | 0 | 0.46 ± 0.01 | 0.10 ± 0.16 | 0.30 ± 0.20 | 0.46 ± 0.01 | 0.01 ± 0.01 |
| Example 1S | 6 | No | 0 | 0.46 ± 0.01 | 0.24 ± 0.19 | 0.41 ± 0.12 | 0.47 ± 0.03 | 0.01 ± 0.02 |
| Example 1 | 1 | No | 0.05 | 0.50 ± 0.01 | 0.05 ± 0.00 | 0.32 ± 0.22 | 0.50 ± 0.01 | 0.05 ± 0.00 |
| Example 1S | 1 | No | 0.05 | 0.50 ± 0.01 | 0.05 ± 0.01 | 0.20 ± 0.17 | 0.50 ± 0.00 | 0.05 ± 0.00 |
| Example 1 | 3 | No | 0.05 | 0.45 ± 0.01 | 0.05 ± 0.00 | 0.07 ± 0.03 | 0.47 ± 0.01 | 0.05 ± 0.00 |
| Example 1S | 3 | No | 0.05 | 0.45 ± 0.01 | 0.07 ± 0.03 | 0.11 ± 0.11 | 0.46 ± 0.01 | 0.05 ± 0.00 |
| Example 1 | 6 | No | 0.05 | 0.47 ± 0.01 | 0.14 ± 0.14 | 0.30 ± 0.21 | 0.47 ± 0.01 | 0.05 ± 0.00 |
| Example 1S | 6 | No | 0.05 | 0.47 ± 0.01 | 0.27 ± 0.18 | 0.42 ± 0.11 | 0.47 ± 0.01 | 0.05 ± 0.01 |
Observation on the linear separability of high-dimensional data. Here, we empirically show that for o-dimensional data, o randomly drawn points are, with high probability, linearly separable under any split into two labeled subsets. To verify this, we design the following random experiment: (1) for each tested dimension o, we randomly draw o points in $\mathbb{R}^o$ and assign each point a random label of 0 or 1; (2) we train a linear classifier to fit these o points and report the final training error; (3) we repeat (1) and (2) 100 times with different seeds. In all 100 runs, the training error reaches 0 for every o, which supports our conjecture. This is expected when the points are drawn from a continuous distribution: o such points in $\mathbb{R}^o$ are almost surely linearly independent, so for any labeling there is a linear classifier through the origin that fits them exactly.
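The experiment above can be sketched as follows, assuming standard Gaussian points (the sampling distribution and the tested values of o are our assumptions). Instead of training a classifier by gradient descent, the sketch solves the linear system $Xw = t$ with targets $t \in \{-1, +1\}^o$ directly, which exhibits a separating homogeneous classifier whenever $X$ is invertible.

```python
import numpy as np


def random_labeling_error(o, rng):
    """Training error of an exact linear fit to o random points in R^o with random labels."""
    X = rng.normal(size=(o, o))             # o points, one per row (Gaussian is an assumption)
    labels = rng.integers(0, 2, size=o)     # random 0/1 labels
    targets = 2.0 * labels - 1.0            # map {0, 1} -> {-1, +1}
    # X is almost surely invertible, so a homogeneous classifier w with
    # X @ w = targets exists and can be found by a linear solve
    w = np.linalg.solve(X, targets)
    preds = (X @ w > 0).astype(int)
    return float(np.mean(preds != labels))


rng = np.random.default_rng(0)
for o in (2, 8, 32, 128):
    errors = [random_labeling_error(o, rng) for _ in range(100)]
    print(f"o={o}: max training error over 100 runs = {max(errors)}")
```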
We now return to Theorem 2. For real data such as images, the dimension o of the spurious features is often high. If the spurious features in different environments are drawn at random, the above observation implies that the following event occurs with high probability: for any labeling of the data in the n training environments, provided the condition relating n and o holds (the factor 2 there is due to the binary label), a model can achieve zero training error by relying on spurious features alone. This explains why prior methods easily fail to address OOD generalization under Assumption 6.
Appendix B. Proofs
Appendix B.1. Preliminary
Before the proofs, we review some useful properties of entropy [12,58].
Entropy. For a discrete random variable $X$ with support $\mathcal{X}$, its (Shannon) entropy is defined as
$$H(X) = -\sum_{x \in \mathcal{X}} p(x) \log p(x). \tag{A3}$$
The differential entropy of a continuous random variable $X$ with support $\mathcal{X}$ is given by
$$h(X) = -\int_{\mathcal{X}} p(x) \log p(x) \, dx, \tag{A4}$$
where $p(x)$ is the probability density function of $X$. We sometimes use $H(X)$ and $h(X)$ interchangeably to denote the entropy, whether $X$ is discrete or continuous.
Lemma A1.
If X and Y are independent discrete random variables, then
$$H(X+Y) \ge \max\{H(X), H(Y)\}. \tag{A5}$$
Proof.
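Define $Z = X + Y$. Since $X \perp Y$, we have
$$H(Z \mid Y) = H(X + Y \mid Y) = H(X \mid Y) = H(X),$$
and similarly, $H(Z \mid X) = H(Y)$. Therefore,
$$H(Z) = H(Z \mid Y) + I(Z; Y) = H(X) + I(Z; Y) \ge H(X), \tag{A6}$$
$$H(Z) = H(Z \mid X) + I(Z; X) = H(Y) + I(Z; X) \ge H(Y). \tag{A7}$$
This completes the proof. □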
Lemma A2.
If X and Y are independent continuous random variables, then
$$h(X+Y) \ge \max\{h(X), h(Y)\}. \tag{A8}$$
Proof.
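Define $Z = X + Y$. Since $X \perp Y$ and differential entropy is invariant to translation, we have
$$h(Z \mid Y) = h(X + Y \mid Y) = h(X \mid Y) = h(X),$$
and similarly, $h(Z \mid X) = h(Y)$. Therefore,
$$h(Z) = h(Z \mid Y) + I(Z; Y) = h(X) + I(Z; Y) \ge h(X), \tag{A9}$$
$$h(Z) = h(Z \mid X) + I(Z; X) = h(Y) + I(Z; X) \ge h(Y). \tag{A10}$$
This completes the proof. □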
Lemma A3.
If X and Y are independent discrete random variables whose supports satisfy $2 \le |\mathcal{X}| < \infty$ and $2 \le |\mathcal{Y}| < \infty$, then
$$H(X+Y) > \max\{H(X), H(Y)\}. \tag{A11}$$
Proof.
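From Lemma A1 and by the symmetry of X and Y, it suffices to prove $H(X+Y) > H(Y)$. The proof is by contradiction. Suppose $H(X+Y) = H(Y)$; then from Equation (A7) it follows that $I(X+Y; X) = 0$, thus $X + Y \perp X$. Take two distinct points $x_1 \neq x_2$ in $\mathcal{X}$ (possible since $|\mathcal{X}| \ge 2$). The conditional distribution of $X + Y$ given $X = x_1$ is the distribution of $Y + x_1$, which differs from the distribution of $Y + x_2$ (due to $|\mathcal{Y}| < \infty$, their largest support points differ). This contradicts $X + Y \perp X$. □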
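As a quick numerical illustration of Lemmas A1 and A3 (not a proof), the sketch below computes the distribution of $X + Y$ for two independent discrete variables by convolution and compares entropies. The specific distributions are arbitrary choices of ours. The second case uses a constant Y (a single support point), where the inequality becomes an equality, showing why Lemma A3 needs at least two support points.

```python
import math
from collections import defaultdict


def entropy(dist):
    # Shannon entropy (in nats) of a distribution given as {value: probability}
    return -sum(p * math.log(p) for p in dist.values() if p > 0)


def sum_distribution(px, py):
    # Distribution of X + Y for independent X ~ px and Y ~ py (discrete convolution)
    pz = defaultdict(float)
    for x, p in px.items():
        for y, q in py.items():
            pz[x + y] += p * q
    return dict(pz)


X = {-1: 0.5, 1: 0.5}          # two support points
Y = {0: 0.2, 3: 0.8}           # two support points
Z = sum_distribution(X, Y)
print(f"H(X)={entropy(X):.4f}  H(Y)={entropy(Y):.4f}  H(X+Y)={entropy(Z):.4f}")

# With a constant Y (a single support point) the inequality becomes an equality
Z_const = sum_distribution(X, {5: 1.0})
print(f"H(X)={entropy(X):.4f}  H(X+5)={entropy(Z_const):.4f}")
```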
Lemma A4.
If X and Y are independent continuous random variables with bounded supports, then
$$h(X+Y) > \max\{h(X), h(Y)\}. \tag{A12}$$
Proof.
From Lemma A2 and due to the symmetry of X and Y, we only need to prove . The proof is by contradiction. Suppose , then from Equation (A10) it follows that , thus . For any , define an event . If occurs, then and . Thus, . However, we can always choose a that is small enough to make . This contradicts . □
Appendix B.2. Proof of Theorem 2
Proof.
The proof is straightforward. Since the two sets of spurious features specified in Theorem 2 are linearly separable, there exists a linear classifier w that relies only on spurious features and achieves zero classification error in each environment. Therefore, w is an invariant predictor across the training environments. In addition, the remaining condition of Theorem 2 makes IB-IRM prefer these spurious features. Hence, w is an optimal solution of IB-IRM, ERM, IRM, and IB-ERM. However, since w relies on spurious features, which may change arbitrarily in unseen environments, it fails to achieve OOD generalization. □
Appendix B.3. Proof of Theorem 3
Proof.
Assume and are the feature extractor and classifier learned by IB-ERM. Consider the feature variable extracted by as
(A13) We first show that or . We prove this by contradiction. Assume and . By observing that a solution of could make the average training error to q; therefore any solution returned by IB-ERM should also achieve the error no larger than q (because in the constraint of Equation (12)). Therefore .
In the case when each follows Assumption 4 of , we have
Then, for any of , we must have for any to make error no larger than q. Since is zero mean with at least two distinct points in each component, we can conclude that . Similarly, for any of , we have . From Lemma A3 or Lemma A4, we obtain . Therefore, there exists a more optimal solution to IB-ERM with zero weight to , which contradicts the assumption.
In the case when each follows Assumption 5 of , we have
From Lemma A3 or Lemma A4, we obtain . In addition, the spurious features are assumed to be linearly separable. Therefore, there exists a more optimal solution to IB-ERM with zero weight to , which contradicts the assumption.
In the case when each follows Assumption 6 of , we have
Then, for any of , we must have for any and to make error no larger than q. Since and are both zero mean variables with at least two distinct points in each component, we can conclude that ; Similarly, for any of , we have . From Lemma A3 or Lemma A4, we obtain . Therefore, there exists a more optimal solution to IB-ERM with zero weight to , which contradicts the assumption.
So far, we have proved that the feature extractor learned by IB-ERM would never extract both spurious features and invariant features together. Then, we perform singular value decomposition (SVD) to the as
(A14) Let be the orthogonal matrix. Set r as the rank of the matrix , i.e., , and let with and , and with and , then
(A15) Since contains the information either from spurious features or from invariant features, we must have or , and thus, or due to . If , then extract invariant features only. Otherwise when , we decompose the by
(A16) Since and S are both the orthogonal matrix, is also orthogonal; thus , and then (note that ). Then,
(A17) Therefore, by running the CSIB for one iteration, the rank of spurious features would be decreased by . This would result in zero weight to spurious features by finite runs of CSIB.
Then, we intend to show why the counterfactual supervision step could help to distinguish whether is or not. For a specific instance , let two new features be and , then and ; and . Back the new features and to the input space as and . If , then
and similarly we have . Therefore, the ground truths of and are the same. On the other hand, if , then , and
and similarly we have . Since and their magnitudes are large enough to make ; thus the ground truths of and would be different. Therefore, the counterfactual supervision step could help to detect whether invariant features or spurious features are extracted by using a single sample only.
Finally, when only invariant features are extracted by , the training error is minimized, i.e., . Then, by our assumption on the OOD environments (Assumption 8), i.e., , for any , we have . □
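The rank argument in the proof can be illustrated with a small linear-algebra sketch. Assuming a toy input $x = [z_{inv}, z_{sp}]$ with no mixing and an extractor that reads only spurious coordinates, projecting inputs onto the orthogonal complement of the extractor's row space (obtained from its SVD) leaves the invariant block untouched and reduces the dimension of the remaining spurious subspace by the rank r of the extractor. This is a hypothetical illustration of the linear-algebra fact only; the function name `complement_projector` and the toy setup are ours, and the sketch is not the CSIB implementation.

```python
import numpy as np


def complement_projector(Phi, tol=1e-10):
    """Projector onto the orthogonal complement of the row space of Phi, via SVD."""
    _, s, Vt = np.linalg.svd(Phi)
    r = int(np.sum(s > tol))                 # numerical rank of Phi
    V_r = Vt[:r].T                           # (d, r) orthonormal basis of the row space
    return np.eye(Phi.shape[1]) - V_r @ V_r.T, r


# Toy setting (assumption): x = [z_inv (2 dims), z_sp (3 dims)] with no mixing,
# and an extractor Phi that reads spurious coordinates only
d_inv, d_sp = 2, 3
rng = np.random.default_rng(0)
Phi = np.hstack([np.zeros((2, d_inv)), rng.normal(size=(2, d_sp))])

P, r = complement_projector(Phi)
remaining_sp_rank = np.linalg.matrix_rank(P[d_inv:, d_inv:], tol=1e-8)
print(f"rank of Phi: {r}, spurious dims before: {d_sp}, after: {remaining_sp_rank}")
```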
It is worth noting that the proof of Theorem 3 does not depend on the number of labels, so it extends easily to multi-class classification as long as the corresponding assumptions and conditions are satisfied.
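The counterfactual supervision step in the proof of Theorem 3 can be sketched on a toy model. We assume inputs $x = [z_{inv}, z_{sp}]$ whose ground-truth label depends only on $z_{inv}$, and a hypothetical `oracle_label` function standing in for the counterfactual query. A single sample is moved far along plus and minus an extracted direction: identical labels indicate that the direction does not drive the ground truth (spurious), while different labels indicate an invariant direction. The names, the magnitude, and the data model are illustrative assumptions. The sketch also presumes, as in the proof, that the magnitude is large enough and that an extracted invariant direction actually affects the label; a direction orthogonal to `w_inv` would not flip it.

```python
import numpy as np

# Toy generative model (assumption): x = [z_inv, z_sp], and the ground-truth
# label depends on the invariant coordinates z_inv only
d_inv, d_sp = 2, 3
w_inv = np.array([1.0, -2.0])


def oracle_label(x):
    # Hypothetical stand-in for the counterfactual supervision (e.g. a human labeler)
    return int(x[:d_inv] @ w_inv > 0)


def looks_spurious(x, direction, magnitude=100.0):
    """Move one sample far along +/- an extracted direction and compare oracle labels.

    Same label on both counterfactuals -> the direction does not drive the
    ground truth (treated as spurious); different labels -> invariant.
    """
    u = direction / np.linalg.norm(direction)
    x_plus = x + magnitude * u
    x_minus = x - magnitude * u
    return oracle_label(x_plus) == oracle_label(x_minus)


rng = np.random.default_rng(0)
x = rng.normal(size=d_inv + d_sp)
u_spurious = np.concatenate([np.zeros(d_inv), rng.normal(size=d_sp)])
u_invariant = np.concatenate([w_inv, np.zeros(d_sp)])
print(looks_spurious(x, u_spurious), looks_spurious(x, u_invariant))
```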
Author Contributions
Conceptualization, B.D. and K.J.; methodology, B.D.; software, B.D.; validation, B.D. and K.J.; formal analysis, B.D.; investigation, B.D.; resources, B.D.; data curation, B.D.; writing—original draft preparation, B.D.; writing—review and editing, B.D. and K.J.; visualization, B.D.; supervision, K.J. All authors have read and agreed to the published version of the manuscript.
Institutional Review Board Statement
Not applicable.
Informed Consent Statement
Not applicable.
Data Availability Statement
Data is contained within the article or supplementary material.
Conflicts of Interest
The authors declare no conflict of interest.
Funding Statement
This research (including the APC) was funded by Guangdong R&D key project of China (No.: 2019B010155001).
Footnotes
Disclaimer/Publisher’s Note: The statements, opinions and data contained in all publications are solely those of the individual author(s) and contributor(s) and not of MDPI and/or the editor(s). MDPI and/or the editor(s) disclaim responsibility for any injury to people or property resulting from any ideas, methods, instructions or products referred to in the content.
References
- 1.Szegedy C., Zaremba W., Sutskever I., Bruna J., Erhan D., Goodfellow I., Fergus R. Intriguing properties of neural networks. arXiv. 20131312.6199 [Google Scholar]
- 2.Rosenfeld A., Zemel R., Tsotsos J.K. The elephant in the room. arXiv. 20181808.03305 [Google Scholar]
- 3.Geirhos R., Rubisch P., Michaelis C., Bethge M., Wichmann F.A., Brendel W. ImageNet-trained CNNs are biased towards texture; increasing shape bias improves accuracy and robustness; Proceedings of the International Conference on Learning Representations; New Orleans, LA, USA. 6–9 May 2019. [Google Scholar]
- 4.Nguyen A., Yosinski J., Clune J. Deep neural networks are easily fooled: High confidence predictions for unrecognizable images; Proceedings of the Computer Vision and Pattern Recognition Conference; Boston, MA, USA. 7–12 June 2015; pp. 427–436. [Google Scholar]
- 5.Gururangan S., Swayamdipta S., Levy O., Schwartz R., Bowman S.R., Smith N.A. Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers) Association for Computational Linguistics; New Orleans, LA, USA: 2018. Annotation Artifacts in Natural Language Inference Data. [Google Scholar]
- 6.Geirhos R., Jacobsen J.H., Michaelis C., Zemel R., Brendel W., Bethge M., Wichmann F.A. Shortcut learning in deep neural networks. Nat. Mach. Intell. 2020;2:665–673. doi: 10.1038/s42256-020-00257-z. [DOI] [Google Scholar]
- 7.Beery S., Van Horn G., Perona P. Recognition in terra incognita; Proceedings of the European Conference on Computer Vision; Munich, Germany. 8–14 September 2018; pp. 456–473. [Google Scholar]
- 8.Arjovsky M., Bottou L., Gulrajani I., Lopez-Paz D. Invariant risk minimization. arXiv. 20191907.02893 [Google Scholar]
- 9.Krueger D., Caballero E., Jacobsen J.H., Zhang A., Binas J., Zhang D., Le Priol R., Courville A. Out-of-distribution generalization via risk extrapolation (rex); Proceedings of the International Conference on Machine Learning, PMLR; Virtual. 18–24 July 2021; pp. 5815–5826. [Google Scholar]
- 10.Ahuja K., Shanmugam K., Varshney K., Dhurandhar A. Invariant risk minimization games; Proceedings of the International Conference on Machine Learning, PMLR; Virtual. 13–18 July 2020; pp. 145–155. [Google Scholar]
- 11.Pezeshki M., Kaba O., Bengio Y., Courville A.C., Precup D., Lajoie G. Gradient starvation: A learning proclivity in neural networks; Proceedings of the Neural Information Processing Systems; Virtual. 6–14 December 2021; [Google Scholar]
- 12.Ahuja K., Caballero E., Zhang D., Gagnon-Audet J.C., Bengio Y., Mitliagkas I., Rish I. Invariance principle meets information bottleneck for out-of-distribution generalization; Proceedings of the Neural Information Processing Systems; Virtual. 6–14 December 2021; [Google Scholar]
- 13.Pearl J. Causality. Cambridge University Press; Cambridge, UK: 2009. [Google Scholar]
- 14.Peters J., Janzing D., Schölkopf B. Elements of Causal Inference: Foundations and Learning Algorithms. The MIT Press; Cambridge, MA, USA: 2017. [Google Scholar]
- 15.Peters J., Bühlmann P., Meinshausen N. Causal inference by using invariant prediction: Identification and confidence intervals. J. R. Stat. Soc. Ser. B. 2016;78:947–1012. doi: 10.1111/rssb.12167. [DOI] [Google Scholar]
- 16.Tishby N. The information bottleneck method; Proceedings of the Annual Allerton Conference on Communications, Control and Computing; Monticello, IL, USA. 22–24 September 1999; pp. 368–377. [Google Scholar]
- 17. Aubin B., Słowik A., Arjovsky M., Bottou L., Lopez-Paz D. Linear unit-tests for invariance discovery. arXiv. 2021. arXiv:2102.10867.
- 18. Soudry D., Hoffer E., Nacson M.S., Gunasekar S., Srebro N. The implicit bias of gradient descent on separable data. J. Mach. Learn. Res. 2018;19:2822–2878.
- 19. Heinze-Deml C., Peters J., Meinshausen N. Invariant causal prediction for nonlinear models. arXiv. 2018. arXiv:1706.08576. doi: 10.1515/jci-2017-0016.
- 20. Rojas-Carulla M., Schölkopf B., Turner R., Peters J. Invariant models for causal transfer learning. J. Mach. Learn. Res. 2018;19:1309–1342.
- 21. Rosenfeld E., Ravikumar P.K., Risteski A. The Risks of Invariant Risk Minimization; Proceedings of the International Conference on Learning Representations; Virtual. 3–7 May 2021.
- 22. Kamath P., Tangella A., Sutherland D., Srebro N. Does invariant risk minimization capture invariance?; Proceedings of the International Conference on Artificial Intelligence and Statistics, PMLR; San Diego, CA, USA. 13–15 April 2021; pp. 4069–4077.
- 23. Lu C., Wu Y., Hernández-Lobato J.M., Schölkopf B. Invariant Causal Representation Learning for Out-of-Distribution Generalization; Proceedings of the International Conference on Learning Representations; Virtual. 25–29 April 2022.
- 24. Liu C., Sun X., Wang J., Tang H., Li T., Qin T., Chen W., Liu T.Y. Learning causal semantic representation for out-of-distribution prediction; Proceedings of the Neural Information Processing Systems; Virtual. 6–14 December 2021.
- 25. Kingma D.P., Welling M. Auto-encoding variational Bayes. arXiv. 2013. arXiv:1312.6114.
- 26. Rezende D.J., Mohamed S., Wierstra D. Stochastic backpropagation and approximate inference in deep generative models; Proceedings of the International Conference on Machine Learning, PMLR; Beijing, China. 21–26 June 2014; pp. 1278–1286.
- 27. Lu C., Wu Y., Hernández-Lobato J.M., Schölkopf B. Nonlinear invariant risk minimization: A causal approach. arXiv. 2021. arXiv:2102.12353.
- 28. Bengio Y., Courville A., Vincent P. Representation learning: A review and new perspectives. IEEE Trans. Pattern Anal. Mach. Intell. 2013;35:1798–1828. doi: 10.1109/TPAMI.2013.50.
- 29. Schölkopf B., Locatello F., Bauer S., Ke N.R., Kalchbrenner N., Goyal A., Bengio Y. Toward causal representation learning. Proc. IEEE. 2021;109:612–634. doi: 10.1109/JPROC.2021.3058954.
- 30. Namkoong H., Duchi J.C. Stochastic gradient methods for distributionally robust optimization with f-divergences; Proceedings of the Neural Information Processing Systems; Barcelona, Spain. 5–10 December 2016.
- 31. Sinha A., Namkoong H., Volpi R., Duchi J. Certifying some distributional robustness with principled adversarial training. arXiv. 2017. arXiv:1710.10571.
- 32. Lee J., Raginsky M. Minimax statistical learning with Wasserstein distances; Proceedings of the Neural Information Processing Systems; Montreal, Canada. 3–8 December 2018.
- 33. Duchi J.C., Namkoong H. Learning models with uniform performance via distributionally robust optimization. Ann. Stat. 2021;49:1378–1406. doi: 10.1214/20-AOS2004.
- 34. Bühlmann P. Invariance, causality and robustness. Stat. Sci. 2020;35:404–426. doi: 10.1214/19-STS721.
- 35. Blanchard G., Lee G., Scott C. Generalizing from several related classification tasks to a new unlabeled sample; Proceedings of the Neural Information Processing Systems; Granada, Spain. 12–15 December 2011.
- 36. Muandet K., Balduzzi D., Schölkopf B. Domain generalization via invariant feature representation; Proceedings of the International Conference on Machine Learning, PMLR; Atlanta, GA, USA. 16–21 June 2013; pp. 10–18.
- 37. Deshmukh A.A., Lei Y., Sharma S., Dogan U., Cutler J.W., Scott C. A generalization error bound for multi-class domain generalization. arXiv. 2019. arXiv:1905.10392.
- 38. Ye H., Xie C., Cai T., Li R., Li Z., Wang L. Towards a Theoretical Framework of Out-of-Distribution Generalization; Proceedings of the Neural Information Processing Systems; Virtual. 6–14 December 2021.
- 39. Xie C., Chen F., Liu Y., Li Z. Risk variance penalization: From distributional robustness to causality. arXiv. 2020. arXiv:2006.07544.
- 40. Jin W., Barzilay R., Jaakkola T. Domain extrapolation via regret minimization. arXiv. 2020. arXiv:2006.03908.
- 41. Mahajan D., Tople S., Sharma A. Domain generalization using causal matching; Proceedings of the International Conference on Machine Learning, PMLR; Virtual. 18–24 July 2021; pp. 7313–7324.
- 42. Bellot A., van der Schaar M. Generalization and invariances in the presence of unobserved confounding. arXiv. 2020. arXiv:2007.10653.
- 43. Li B., Shen Y., Wang Y., Zhu W., Reed C.J., Zhang J., Li D., Keutzer K., Zhao H. Invariant information bottleneck for domain generalization; Proceedings of the Association for the Advancement of Artificial Intelligence; Virtual. 22 February–1 March 2022.
- 44. Alesiani F., Yu S., Yu X. Gated information bottleneck for generalization in sequential environments. Knowl. Inf. Syst. 2022:1–23. In press.
- 45. Wang H., Si H., Li B., Zhao H. Provable Domain Generalization via Invariant-Feature Subspace Recovery; Proceedings of the International Conference on Machine Learning; Baltimore, MD, USA. 17–23 July 2022.
- 46. Ganin Y., Lempitsky V. Unsupervised domain adaptation by backpropagation; Proceedings of the International Conference on Machine Learning, PMLR; Lille, France. 6–11 July 2015; pp. 1180–1189.
- 47. Li Y., Tian X., Gong M., Liu Y., Liu T., Zhang K., Tao D. Deep domain generalization via conditional invariant adversarial networks; Proceedings of the European Conference on Computer Vision; Munich, Germany. 8–14 September 2018; pp. 624–639.
- 48. Zhao S., Gong M., Liu T., Fu H., Tao D. Domain generalization via entropy regularization; Proceedings of the Neural Information Processing Systems; Virtual. 6–12 December 2020; pp. 16096–16107.
- 49. Ben-David S., Blitzer J., Crammer K., Pereira F. Analysis of representations for domain adaptation; Proceedings of the Neural Information Processing Systems; Hong Kong, China. 3–6 October 2006.
- 50. Ben-David S., Blitzer J., Crammer K., Kulesza A., Pereira F., Vaughan J.W. A theory of learning from different domains. Mach. Learn. 2010;79:151–175. doi: 10.1007/s10994-009-5152-4.
- 51. Zhao H., Des Combes R.T., Zhang K., Gordon G. On learning invariant representations for domain adaptation; Proceedings of the International Conference on Machine Learning, PMLR; Long Beach, CA, USA. 10–15 June 2019; pp. 7523–7532.
- 52. Xu Q., Zhang R., Zhang Y., Wang Y., Tian Q. A Fourier-based framework for domain generalization; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition; Nashville, TN, USA. 20–25 June 2021; pp. 14383–14392.
- 53. Zhou K., Yang Y., Qiao Y., Xiang T. Domain Generalization with MixStyle; Proceedings of the International Conference on Learning Representations; Vienna, Austria. 3–7 May 2021.
- 54. Zhang X., Cui P., Xu R., Zhou L., He Y., Shen Z. Deep stable learning for out-of-distribution generalization; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition; Nashville, TN, USA. 21–25 June 2021; pp. 5372–5382.
- 55. Wang H., Ge S., Lipton Z., Xing E.P. Learning robust global representations by penalizing local predictive power; Proceedings of the Neural Information Processing Systems; Vancouver, BC, Canada. 8–14 December 2019.
- 56. Gulrajani I., Lopez-Paz D. In Search of Lost Domain Generalization; Proceedings of the International Conference on Learning Representations; Virtual. 3–7 May 2021.
- 57. Wiles O., Gowal S., Stimberg F., Rebuffi S.A., Ktena I., Dvijotham K.D., Cemgil A.T. A Fine-Grained Analysis on Distribution Shift; Proceedings of the International Conference on Learning Representations; Virtual. 25–29 April 2022.
- 58. Cover T.M., Thomas J.A. Elements of Information Theory. Wiley-Interscience; Hoboken, NJ, USA: 2006.
Data Availability Statement
Data is contained within the article or supplementary material.




