Abstract
A challenge unique to classification model development is imbalanced data. In a binary classification problem, class imbalance occurs when one class, the minority group, contains significantly fewer samples than the other class, the majority group. In imbalanced data, the minority class is often the class of interest (e.g., patients with disease). However, when trained on imbalanced data, a classifier will exhibit bias towards the majority class and, in extreme cases, may ignore the minority class completely. A common strategy for addressing class imbalance is data augmentation. However, traditional data augmentation methods are associated with overfitting, where the model is fit to the noise in the data. In this tutorial we introduce an advanced method for data augmentation: Generative Adversarial Networks (GANs). The advantages of GANs over traditional data augmentation methods are illustrated using the Breast Cancer Wisconsin study. To promote the adoption of GANs for data augmentation, we present an end-to-end pipeline that encompasses the complete life cycle of a machine learning project, along with alternatives and good practices, both in the paper and in a separate video tutorial. Our code, data, full results and video tutorial are publicly available in the paper's GitHub repository.
Keywords: class imbalance, classification, data augmentation, generative adversarial networks, machine learning
1. Introduction
Machine Learning uses computers to find patterns in data and make predictions that can assist with decisions in our everyday lives. In medical research, machine learning algorithms have been used widely to diagnose diseases and predict patient outcomes1–4. To develop a machine learning model for prediction, we first train the model on one set of data (the training data), and then use this model to predict the Target (a variable of interest) in a new set of data (the test data). When the target is a discrete outcome variable, we usually call each outcome category a Class. Predicting the discrete classes of a target is often referred to as Classification. For example, in the Breast Cancer Wisconsin (BCW) dataset the target is breast cancer diagnosis, where the two classes are Benign and Malignant. Since the target only has two classes, we call this a Binary Classification problem. While the Generative Adversarial Networks approach introduced in this paper can handle targets with any number of classes5, for the sake of simplicity we only discuss binary classification hereafter.
A major challenge in classification is Class Imbalance (CI), where the number of samples in one class is much greater than the number in the other class. We usually call the class with more samples the Majority Class and the class with fewer samples the Minority Class. CI is common in many medical diagnosis and patient outcome prediction applications such as cancer and HIV detection. Without loss of generality, the majority class is assumed to be the non-diseased class, while the minority class is assumed to be the diseased class which is usually of primary interest. In the BCW dataset, the benign class outnumbers the malignant class (count/percentage: 357/63% vs. 212/37%, respectively). Figure 1 shows a scatter plot of the two classes produced by t-SNE6, a popular tool for visualizing high-dimensional data (see details in Section 3.1). The CI problem should be carefully addressed, as traditional machine learning algorithms are often biased toward the majority class, and in extreme cases the minority class may be ignored altogether. In the case of the imbalanced BCW dataset, a model trained using a traditional machine learning algorithm would be more likely to predict a sample in the malignant class to be benign, causing a delay in diagnosis.
Figure 1.

Scatter plot of Breast Cancer Wisconsin dataset (where x1 and x2 are the two dimensions produced by t-SNE). Here, dots denote the benign class, and triangles denote the malignant class. The circles highlight samples of the malignant class that are not distinguishable from samples of the benign class.
One of the strategies for addressing CI is Data Augmentation. This strategy is based on the idea that adding new samples to the minority class can produce balanced data, where both classes have the same number of samples. The simplest data augmentation method is the Random Oversampler (RO)7, which duplicates samples in the minority class; in other words, the new data are exact replications of the original data. Another data augmentation approach uses the average of some samples in the minority class to generate new samples for the class. The primary methods using this approach include the Synthetic Minority Over-sampling Technique (SMOTE) and its many variations8–14. A major limitation of RO and SMOTE is that they tend to add new samples to the minority class that are indistinguishable from the original majority class samples15. To illustrate, in Figure 1 we have circled samples of the malignant class (the triangles) that are indistinguishable from samples of the benign class (the dots). RO and SMOTE tend to generate new malignant samples in the overlapping regions represented by these circles, where neither the new nor the original malignant samples are separable from the benign samples. Models trained on data with many instances of indistinguishable majority and minority class samples will become excessively complex as they attempt to correctly classify these samples. This process leads to overfitting7, where a model is fit too closely to the training data and thus generalizes poorly to new data. Therefore, a model trained on augmented (balanced) data may not perform better, or may even perform worse, than a model trained on the original (imbalanced) data, achieving the opposite of data augmentation's goal of improving model performance. However, advanced data augmentation methods based on deep learning have been developed to reduce overfitting. The most popular among these advanced methods are Generative Adversarial Networks (GANs)5, which have largely been applied to imaging data in computer vision applications. For instance, DCGAN16 uses deep convolutional networks to implement GANs. SDGAN17 uses D2 adversarial loss and cycle consistency loss to better train GANs. AGGAN18 adopts a simulated-annealing-based evolutionary training process to augment minority classes in image data. ACGAN19 trains GANs using both the majority and minority classes, where the discriminator has two outputs: one to discriminate between real and fake images and one to classify the problem-specific class labels. BAGAN20 extends ACGAN by limiting the discriminator to a single output, which solves the contradictory optimization problem of ACGAN on imbalanced data. ciGAN21 uses a class-conditional GAN with mask infilling for data augmentation. Recently, Covid-GAN22 extended ACGAN by stacking the generator on top of the discriminator and was applied to balance chest X-ray data. GAMO23 and Polarity-GAN24 extend GANs by adding a third classifier, which pushes the distribution learned by the generator towards the periphery of the minority class, addressing the problem of class imbalance. A thorough review of GAN-based image data augmentation can be found in25. While GANs have been utilized for traditional tabular data26, 27, to the best of our knowledge there is no off-the-shelf, accessible software readily available to medical researchers to use GANs for data augmentation.
The goal of this tutorial is to introduce GANs to a wider audience, as well as promote their adoption by medical researchers through introduction of a ready-to-use, end-to-end pipeline. The rest of the paper is organized as follows. In Section 2, we introduce GANs and discuss why the data they generate can significantly improve downstream classification performance compared to data generated by traditional data augmentation methods. In Section 3, we first demonstrate our end-to-end pipeline in a step-by-step fashion through the example of cancer diagnosis in the BCW dataset. We then compare empirical results between GANs and traditional data augmentation methods. In Section 4, we conclude the paper and discuss future work.
2. Methods
GANs consist of two components, a generator and a discriminator. The generator tries to generate data similar to the real data, whereas the discriminator tries to discriminate between the real data and the generated data. GANs training is based on an adversarial game between these two components. On one hand, the generator improves (by deceiving the discriminator) only when the discriminator improves (by debunking the generator). On the other hand, the discriminator improves only when the generator improves. Figure 2 illustrates how GANs work to augment the malignant class in the BCW dataset, using the cell nucleus area feature as an example. Here, the generator generates samples to mimic the tumors in the malignant class, whereas the discriminator decides whether the generated tumors belong to the malignant class. In Stage 1, the discriminator finds that the samples generated by the generator have cell sizes much smaller than the real samples in the malignant class. Thus, the discriminator identifies these new samples as belonging to the benign class. This insight is later shared with the generator. As a result, in Stage 2 the generator generates samples with larger cells that the discriminator identifies as belonging to the malignant class. After Stage 2, training terminates and the generator is ready to generate realistic cancerous tumor samples that will be used to augment the malignant class.
Figure 2.

A simplified diagram of how GANs were trained on the Breast Cancer Wisconsin data to generate realistic samples to augment the malignant class. In Stage 1, the discriminator finds that the samples generated by the generator have cell sizes smaller than the real samples in the malignant class, and thus classifies them as benign. This information is later shared with the generator. Based on this insight, in Stage 2 the generator generates samples with cell sizes similar to real samples in the malignant class, such that the discriminator identifies the generated samples as members of the malignant class. After Stage 2, training terminates and the generator is ready to generate samples with realistic cell sizes to augment the malignant class.
GANs augment data using algorithms that are fundamentally different from traditional augmentation methods. While traditional methods sample data directly from the minority class, GANs approximate the underlying distribution of the minority class. During GANs training, the generator learns the real data distribution so that the generated data look similar to the real samples. In a simultaneous process, the discriminator learns to classify samples as real versus generated. After enough training iterations, the generator and discriminator will jointly converge to a point where the discriminator cannot distinguish the generated data from the real data. In the BCW data, while the two classes are mostly separable, the malignant class contains a few noisy samples (represented by the triangles inside the circles in Figure 1) that are not distinguishable from samples in the benign class. As discussed earlier, traditional methods tend to augment this noise, leading to overfitting (where a classifier tries to fit the noise). In contrast, GANs tend to learn the general distribution of the minority class and ignore the noise. Thus, malignant class samples generated by GANs will be similar to each other but dissimilar from samples in the benign class. In other words, unlike traditional data augmentation methods, GANs do not tend to augment the noise. As a result, GANs are much less likely to cause overfitting, in turn improving trained classifier performance. However, unlike many traditional methods, as far as we know there is no off-the-shelf package that uses GANs for data augmentation. Moreover, since GANs are essentially deep neural networks, implementing GANs that can be trained efficiently and effectively may not be straightforward, because the speed and accuracy of GAN training largely depend on the architecture and parameters of the model. For this reason, we provide a step-by-step discussion of good practices for building, compiling and training GANs, both in this paper and in a separate tutorial video publicly available in our paper's GitHub repository.
3. Applications
In this section we demonstrate the advantages of GANs over traditional methods with respect to downstream classification accuracy using the BCW study. The main goal of this application is to familiarize audiences with an end-to-end pipeline that encompasses the complete life cycle of a machine learning project, including data preprocessing, data augmentation using GANs or other methods, training / validating / testing classifiers, and interpreting the results using visualization.
The code for this pipeline (named gan_classification.ipynb) and its helper functions (named pmlm_utilities_shallow.ipynb), along with the data and results, are publicly available in our paper's GitHub repository. The code was executed using Google Colaboratory (Colab hereafter), a hosted notebook environment that allows you to write, run and share Python code on Google Drive. Instructions (named google_colab_instruction.ipynb) for setting up and using Colab can also be found in the GitHub repository. Code cell numbers are provided throughout the procedure to help readers locate the code corresponding to each step in the pipeline.
3.1. An end-to-end pipeline for diagnosing cancer in BCW
The pipeline, as illustrated in Figure 3, can be summarized in the pseudocode below.
Figure 3.

The pipeline for diagnosing cancer in the Breast Cancer Wisconsin data. Step 1: We preprocess the imbalanced data and separate the preprocessed data into imbalanced training, validation and test data. Step 2: We feed the imbalanced training data to three different data augmentation methods (RO / SMOTE / GANs), which separately augment the malignant class and generate the balanced training data (where the number of samples in each class is the same). Step 3: We first feed the balanced training data and imbalanced validation data to a classifier to train the classifier and fine-tune its hyperparameters. Next, we select the classifier that corresponds to the best hyperparameter setting (i.e., which leads to the best cancer diagnosis performance on the validation data). Last, we feed the test data to the best classifier to generate the test score (cancer diagnosis performance on the test data). Step 4: We use the balanced training data (augmented in Step 2) to interpret the results.
Algorithm: An end-to-end pipeline for diagnosing cancer in BCW
Input: BCW
Output: Cancer diagnoses and their accuracy
- Step 1: Data preprocessing
- Step 1a. Loading the data
- Step 1b. Splitting the data
- Step 1c. Handling patient identifiers
- Step 1d. Encoding the data
- Step 1e. Splitting the feature and target
- Step 1f. Scaling the features
- Step 2: Data augmentation using RO, SMOTE and GANs
- Step 2a. Obtaining the number of samples for each class and identifying the minority class
- Step 2b. Applying RO and SMOTE
- Step 2c. Developing GANs
- Obtaining the training data from the malignant class
- Building GANs
- Compiling GANs
- Training GANs
- Augmenting the malignant class in the training data
- Step 3: Training, validating and testing the classifiers
- Step 4: Interpreting the results using visualization
Here we explain each step in the pseudocode in detail.
Step 1: Data preprocessing
The BCW dataset has 30 features generated from digitized images of a breast mass. These variables describe 10 characteristics of the cell nuclei present in the images, including the radius, texture, perimeter, area, smoothness, compactness, concavity, concave points, symmetry and fractal dimension (more details on these features can be found on the Kaggle website). In this study, we used these features to diagnose breast cancer. Next, we describe data preprocessing in BCW.
Step 1a. Loading the data (see code in cells 7–9). Here we load the BCW data into memory.
Step 1b. Splitting the data (cells 10–13). Here we divide the data into training (50%), validation (30%) and test (20%) data, where the training data are used for training the models, the validation data are used for fine-tuning the hyperparameters (i.e., parameters whose values must be predetermined) of the models, and the test data are used for evaluating how well the model generalizes to new data. It is worth noting that we can also use other ratios when splitting the data (e.g., training 60%, validation 20% and test 20%). A good practice is that the larger the sample size, the higher the proportion allocated to training. For instance, if the dataset has 1 million samples, we can use, say, 90% for training, 5% for validation and 5% for testing, as 5% of a million samples (i.e., 50K) is usually sufficient for validation and testing.
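For readers following along outside the notebook, the sketch below shows one way to produce this 50/30/20 split with scikit-learn's train_test_split. It is not the notebook's exact code; the file name and the diagnosis column used for stratification are assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical file name; the notebook loads the BCW data in cells 7-9.
df = pd.read_csv('breast_cancer_wisconsin.csv')

# First hold out the test set (20%), then split the remainder into
# training (50% of the whole) and validation (30% of the whole).
df_rest, df_test = train_test_split(
    df, test_size=0.20, random_state=42, stratify=df['diagnosis'])
df_train, df_val = train_test_split(
    df_rest, test_size=0.30 / 0.80, random_state=42, stratify=df_rest['diagnosis'])
```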
Step 1c. Handling patient identifiers (cells 14–19). We identify and remove patient identifiers from the data. Since identifiers have no predictive power for the target, it is good practice to remove them from the data. Removing identifiers could speed up training and improve classification accuracy.
Step 1d. Encoding the data (cells 20–27). We transform non-numerical data into integers. A good practice is to use One-Hot Encoding to encode a categorical feature into k feature-value pairs (where k is the number of unique values of the feature). For a categorical target (e.g., the two classes, benign/malignant, in BCW), on the other hand, we simply encode its values as different integers (e.g., 0/1).
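A minimal sketch of this encoding step follows, assuming the data frames from the previous sketch and a diagnosis column with values 'B' and 'M' (as in the Kaggle version of BCW); it is not the notebook's exact code.

```python
import pandas as pd

# One-hot encode any non-numerical feature columns (the BCW features are
# already numeric, so this is effectively a no-op here).
categorical_cols = [c for c in df_train.columns
                    if df_train[c].dtype == 'object' and c != 'diagnosis']
df_train = pd.get_dummies(df_train, columns=categorical_cols)

# Encode the categorical target as integers: benign -> 0, malignant -> 1.
df_train['diagnosis'] = df_train['diagnosis'].map({'B': 0, 'M': 1})
```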
Step 1e. Splitting the feature and target (cell 28). Here we create the feature matrix and target vector and convert them into a Numpy array, a data structure used by most machine learning methods. Splitting the feature and target is necessary since 1) in the next step (Step 1f) we will be scaling only the features and 2) most machine learning methods require the features and target to be input separately.
Step 1f. Scaling the features (cell 29). Features are normalized into [0, 1], a range that makes training GANs much easier.
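A minimal sketch of the scaling step, assuming X_train, X_val and X_test are the feature matrices created in Step 1e (not the notebook's exact code):

```python
from sklearn.preprocessing import MinMaxScaler

# Fit the scaler on the training features only, then apply the same
# transformation to the validation and test features to avoid data leakage.
scaler = MinMaxScaler(feature_range=(0, 1))
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)
```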
Step 2: Data augmentation using RO, SMOTE and GANs
In this step we use three methods, RO, SMOTE, and GANs, to separately augment the malignant class in the training data.
Step 2a. Obtaining the number of samples for each class and identifying the minority class (cells 30–31).
Step 2b. Applying RO and SMOTE (cells 32–35). We use RO and SMOTE as implemented in the imblearn package to augment the malignant class in the training data. Specifically, RO uses the fit_resample function to duplicate the original malignant class samples, whereas SMOTE uses the fit_resample function to generate new malignant class samples by interpolating between sampled malignant class samples and their nearest neighbors.
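The sketch below illustrates this step with imblearn's RandomOverSampler and SMOTE; it assumes the numeric feature matrix X_train and target vector y_train from Step 1, and is not the notebook's exact code.

```python
from imblearn.over_sampling import RandomOverSampler, SMOTE

# RO: duplicate minority-class samples at random until the classes are balanced.
ro = RandomOverSampler(random_state=42)
X_train_ro, y_train_ro = ro.fit_resample(X_train, y_train)

# SMOTE: create synthetic minority-class samples by interpolating between
# existing minority-class samples and their nearest neighbors.
smote = SMOTE(random_state=42)
X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)
```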
Step 2c. Developing GANs (cells 37–44). We discuss the key steps in implementing GANs for data augmentation.
Obtaining the training data from the malignant class (cell 37). We will later train GANs only on this dataset. In other words, the benign class will not be used for training GANs.
Creating the directory for the saved models (cell 38). Since GANs are Deep Neural Networks, training such models on a large dataset can be computationally expensive. Therefore, it is highly recommended to periodically save the models so that the training process can be resumed if it terminates unexpectedly.
Building GANs (cell 38).
To build the generator in GANs (lines 8–15):
- Specify that the generator should be a fully connected feedforward neural network (line 8).
- Add the input layer consisting of 100 perceptrons (line 9).
- Add four hidden layers consisting of 200, 300, 400 and 500 perceptrons, respectively (lines 10–13).
- Add the output layer where the number of perceptrons is the same as the number of features (line 14).
- The input and hidden layers in the generator use selu as the activation function (lines 9–13). In contrast, the output layer uses sigmoid as the activation function (line 14).
To build the discriminator in GANs (lines 18–26):
- Specify that the discriminator should be a fully connected feedforward neural network (line 18).
- Add the input layer where the number of perceptrons is the same as the number of features (line 19).
- Add four hidden layers consisting of 500, 400, 300 and 200 perceptrons, respectively (lines 20–24).
- Add the output layer consisting of only one perceptron (line 25).
- Identical to the generator, the hidden layers in the discriminator use selu as the activation function (lines 20–24) whereas the output layer uses sigmoid as the activation function (line 25).
Complete a GAN by combining the generator and discriminator built earlier (line 29).
Below are some good practices for building GANs (a minimal code sketch follows this list):
- The generator usually gets wider as it gets deeper. In other words, the closer a layer is to the output layer, the larger the number of perceptrons in that layer (as shown between lines 9 and 13).
- The discriminator, on the other hand, usually gets narrower as it gets deeper. In other words, the closer a layer is to the output layer, the smaller the number of perceptrons in that layer (as shown between lines 20 and 24).
- While we can tweak the number of perceptrons on some layers, this number must be fixed on the others. Concretely, the number on the output layer of the generator and input layer of the discriminator must be the same as the number of features (as shown in lines 14 and 19). The number on the input layer of the generator must be the same as the dimension of the noise (i.e., coding_size, as shown in line 9), whereas the number for the output layer of the discriminator must be 1 (as shown in line 25).
- Using sigmoid on the output layer of the generator could make training the generator much easier, since it then only generates data in a narrow range, [0, 1]. The reason we can use sigmoid on the output layer is that we normalized the data into [0, 1] (as discussed in Step 1f of data preprocessing).
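As a minimal Keras sketch of the architecture described above (layer sizes and activations follow the text; the variable names and the use of tensorflow.keras are assumptions, not the notebook's exact code):

```python
from tensorflow import keras

coding_size = 100                 # dimension of the noise; also the generator's input size
n_features = X_train.shape[1]     # number of features in the (scaled) training data

# Generator: gets wider as it gets deeper; selu in the hidden layers and a sigmoid
# output so that generated features stay in [0, 1] like the scaled real data.
generator = keras.Sequential([
    keras.Input(shape=(coding_size,)),
    keras.layers.Dense(200, activation='selu'),
    keras.layers.Dense(300, activation='selu'),
    keras.layers.Dense(400, activation='selu'),
    keras.layers.Dense(500, activation='selu'),
    keras.layers.Dense(n_features, activation='sigmoid'),
])

# Discriminator: gets narrower as it gets deeper; a single sigmoid output giving
# the probability that a sample is real rather than generated.
discriminator = keras.Sequential([
    keras.Input(shape=(n_features,)),
    keras.layers.Dense(500, activation='selu'),
    keras.layers.Dense(400, activation='selu'),
    keras.layers.Dense(300, activation='selu'),
    keras.layers.Dense(200, activation='selu'),
    keras.layers.Dense(1, activation='sigmoid'),
])

# The combined GAN stacks the generator and discriminator: noise in, real/fake score out.
gan = keras.Sequential([generator, discriminator])
```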
Compiling GANs (cell 40):
- Compile the discriminator by specifying its loss and optimizer (lines 2–3).
- Freeze the discriminator so that the discriminator and generator can be trained alternately (line 5).
- Compile the generator by specifying its loss and optimizer (lines 8–9).
Below are some good practices for compiling GANs (a minimal code sketch follows this list):
- While the default learning rate of the Adam optimizer is 10⁻³, we found that a smaller value (e.g., 10⁻⁴, see lines 3 and 9) could work even better.
- We recommend using binary_crossentropy as the loss for both the discriminator (see line 2) and the generator (line 8), since both output values in the range [0, 1].
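Continuing the names from the building sketch above, compiling might look as follows (the 10⁻⁴ learning rate and binary cross-entropy follow the good practices above; a sketch, not the notebook's exact code):

```python
from tensorflow import keras

# Compile the discriminator with its own loss and optimizer.
discriminator.compile(loss='binary_crossentropy',
                      optimizer=keras.optimizers.Adam(learning_rate=1e-4))

# Freeze the discriminator inside the combined model so that training the GAN
# updates only the generator's weights.
discriminator.trainable = False

# Compile the combined GAN (generator + frozen discriminator).
gan.compile(loss='binary_crossentropy',
            optimizer=keras.optimizers.Adam(learning_rate=1e-4))
```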
Training GANs (cell 41) in 10 iterations:
- Shuffle the data at the beginning of each iteration (line 19).
- Train the discriminator in an iteration (lines 24–49)
- Train the generator in an iteration (lines 52–61)
- Save the trained GANs after each iteration (line 64).
Below are some good practices for training GANs (a minimal training-loop sketch follows this list):
- Normally batch_size (the number of samples in each minibatch) should be no larger than 32 (see line 8 in cell 41)28.
- In each epoch, we should shuffle the data prior to the minibatch gradient descent (see line 19).
- We should first train the discriminator (see lines between 24 and 49) then the generator (lines between 52 and 61).
- The dimension of noise used for training GANs (lines 34 and 52 in cell 41) must be the same as the dimension of noise used for building GANs (line 9 in cell 39).
- We should save the trained GANs periodically (e.g., after each epoch, see line 64).
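A minimal training loop in the same spirit, continuing the names from the sketches above (the malignant label is assumed to be encoded as 1; a sketch rather than the notebook's exact code):

```python
import numpy as np

n_epochs, batch_size = 10, 32
X_malignant = X_train[y_train == 1]          # minority-class training samples

for epoch in range(n_epochs):
    np.random.shuffle(X_malignant)           # shuffle before minibatch gradient descent
    for i in range(0, len(X_malignant), batch_size):
        X_real = X_malignant[i:i + batch_size]
        m = len(X_real)

        # Phase 1: train the discriminator on a mix of real and generated samples.
        noise = np.random.normal(size=(m, coding_size))
        X_fake = generator.predict(noise, verbose=0)
        X_batch = np.concatenate([X_real, X_fake])
        y_batch = np.concatenate([np.ones((m, 1)), np.zeros((m, 1))])
        discriminator.trainable = True
        discriminator.train_on_batch(X_batch, y_batch)

        # Phase 2: train the generator through the frozen discriminator, pushing
        # generated samples towards being classified as real.
        noise = np.random.normal(size=(m, coding_size))
        discriminator.trainable = False
        gan.train_on_batch(noise, np.ones((m, 1)))

    # Save periodically so training can resume if it terminates unexpectedly.
    generator.save(f'generator_epoch_{epoch}.keras')
```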
Augmenting the malignant class in the training data (cells 42–44):
- Load the GANs saved earlier (cell 42).
- Generate malignant class data using the trained generator (cell 43).
- Augment the malignant class in the training data by incorporating the data generated earlier (cell 44).
Below are some good practices for augmenting the minority class (a minimal sketch follows):
- The dimension of noise used for generating data (see line 12 in cell 43) must be the same as the dimension of noise used for building GANs (line 9 in cell 39) and for training GANs (lines 34 and 52 in cell 41).
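Continuing the same sketch, generation and augmentation might look like this (assumed names; benign encoded as 0 and malignant as 1):

```python
import numpy as np

# Generate enough synthetic malignant samples to balance the two classes,
# using noise of the same dimension as in building and training the GAN.
n_to_generate = int(np.sum(y_train == 0) - np.sum(y_train == 1))
noise = np.random.normal(size=(n_to_generate, coding_size))
X_generated = generator.predict(noise, verbose=0)

# Append the generated samples, labelled as malignant, to the training data.
X_train_gan = np.concatenate([X_train, X_generated])
y_train_gan = np.concatenate([y_train, np.ones(n_to_generate, dtype=int)])
```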
Step 3: Training, validating and testing the classifiers
Classification
To evaluate the performance of the three data augmentation methods, we assessed their impact on downstream cancer diagnosis performance. While in theory any classifier can be used to diagnose cancer, in this study we applied Gradient Boosting Machines (GBMs), which are among the most popular machine learning models29. Specifically, we used the latest GBM implementation in sklearn, named HistGradientBoostingClassifier (HGBC). We first trained HGBC on the original (imbalanced) and augmented (balanced) training data and fine-tuned two key hyperparameters on the validation data, namely learning_rate (which determines the speed of training HGBC) and min_samples_leaf (which determines the complexity of HGBC). For learning_rate, we tried five values: 10⁻³, 10⁻², 10⁻¹, 1 and 10 (see line 2). For min_samples_leaf, we tried three values: 1, 20 and 100 (see line 5). We then selected the HGBC that achieved the best classification performance on the validation data. Lastly, we tested the best HGBC on the test data. Because the data are imbalanced, we use f1 as the scoring metric when fine-tuning the hyperparameters (see line 43). A minimal sketch of this procedure follows.
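The sketch below mirrors this procedure with scikit-learn's HistGradientBoostingClassifier; the variable names (e.g., X_train_gan from the augmentation sketch) are assumptions and this is not the notebook's exact code.

```python
from itertools import product
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import f1_score

learning_rates = [1e-3, 1e-2, 1e-1, 1, 10]
min_samples_leaves = [1, 20, 100]

best_f1, best_model = -1.0, None
for lr, msl in product(learning_rates, min_samples_leaves):
    model = HistGradientBoostingClassifier(
        learning_rate=lr, min_samples_leaf=msl, random_state=42)
    model.fit(X_train_gan, y_train_gan)             # balanced training data
    f1 = f1_score(y_val, model.predict(X_val))      # score on the imbalanced validation data
    if f1 > best_f1:
        best_f1, best_model = f1, model

# Evaluate the selected model once on the held-out test data.
test_predictions = best_model.predict(X_test)
```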
Below are some good practices for training, validating and testing the classifiers:
If computational cost is not a top concern, readers may consider fine-tuning more hyperparameters and trying more values for each hyperparameter. This will usually lead to more accurate models.
Besides f1, other metrics such as roc_auc can also be used when fine-tuning the hyperparameters on imbalanced data. However, we do not recommend using accuracy when the data are imbalanced, since a useless model that simply predicts every sample as the non-diseased class can achieve an accuracy close to 1. A more detailed discussion of a wide range of metrics can be found in30.
Evaluation
We evaluated the ability of our trained GBMs to diagnose cancer using multiple metrics, including precision, recall, F1-score, and Area Under the ROC Curve (AUC). Precision, also known as positive predictive value, is the number of subjects who are correctly predicted as diseased (e.g., malignant in BCW) divided by the number of subjects who are predicted as diseased. Recall, also known as sensitivity, is the number of subjects who are correctly predicted as diseased divided by the number of subjects who have the disease. The F1-score combines precision and recall and can be interpreted as a weighted average of the two measures. AUC is the area under the ROC curve, which shows the true positive rate (i.e., sensitivity) and false positive rate (i.e., 1 − specificity) as a function of the prediction threshold. Both the F1-score and AUC are particularly useful when data are imbalanced, as they are partially based on recall of the diseased class. All of these metrics range from 0 to 1 (where 0 / 1 means the model predicts every sample incorrectly / correctly), and the higher the value, the better the performance of the classifier. We calculated all of these metrics for both the malignant and benign classes in BCW. In addition, we conducted a visual comparison of the different augmentation methods by generating plots (Figures 4 to 6) displaying the distributions of the original and generated data for each augmentation method. These plots were created using t-SNE6, a tool for Dimensionality Reduction, which can transform high-dimensional data into two dimensions. Code for generating these figures can be found in cells 55 to 61 in gan_classification.ipynb.
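As a small illustration, the metrics above can be computed with scikit-learn (continuing the names from the classifier sketch; not the notebook's exact code):

```python
from sklearn.metrics import classification_report, roc_auc_score

# Precision, recall and F1-score for both classes (0 = benign, 1 = malignant).
print(classification_report(y_test, test_predictions,
                            target_names=['benign', 'malignant']))

# AUC uses the predicted probability of the malignant class, not the hard labels.
y_score = best_model.predict_proba(X_test)[:, 1]
print('AUC:', roc_auc_score(y_test, y_score))
```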
Figure 4.

Scatter plot of the training data augmented by RO (where x1 and x2 are the two dimensions produced by t-SNE). Here, dots indicate the original samples in the benign class, triangles indicate the original samples in the malignant class, and squares denote the data generated by RO to augment the malignant class. The circles highlight generated samples of the malignant class that are not distinguishable from samples of the benign class.
Figure 6.

Scatter plot of the training data augmented by GANs (where x1 and x2 are the two dimensions produced by t-SNE). Here, dots indicate the original data of the benign class, triangles indicate the original data of the malignant class, and squares denote the generated data of the malignant class.
Below are some good practices for interpreting results using visualization (a minimal t-SNE sketch follows these notes):
Since we want to scale the data down to 2D, we set the parameter n_components to 2 (see notebook pmlm_utilities_shallow.ipynb, cell 6, line 28). Readers can change this value to 3 if a 3D visualization is needed.
We highly recommend setting the parameter random_state to a fixed value (see notebook pmlm_utilities_shallow.ipynb, cell 6, line 28) so that the t-SNE results are reproducible.
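A minimal sketch of such a plot, using assumed names from the earlier sketches (the notebook's own plotting code lives in pmlm_utilities_shallow.ipynb):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Project the original training data plus the GAN-generated malignant samples
# into two dimensions; fix random_state so the embedding is reproducible.
X_all = np.concatenate([X_train, X_generated])
X_2d = TSNE(n_components=2, random_state=42).fit_transform(X_all)

n = len(X_train)
benign, malignant = (y_train == 0), (y_train == 1)
plt.scatter(X_2d[:n][benign, 0], X_2d[:n][benign, 1], marker='.', label='benign (original)')
plt.scatter(X_2d[:n][malignant, 0], X_2d[:n][malignant, 1], marker='^', label='malignant (original)')
plt.scatter(X_2d[n:, 0], X_2d[n:, 1], marker='s', label='malignant (generated)')
plt.xlabel('x1')
plt.ylabel('x2')
plt.legend()
plt.show()
```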
3.2. Results
Table 1 summarizes the model classification performance, where HGBC was applied to the test data after it was trained either on the original (imbalanced) training data (Baseline), or on the augmented (balanced) training data generated by RO, SMOTE, or GANs, respectively. Across all metrics (precision, recall, F1-score and AUC), RO and SMOTE had exactly the same results as the Baseline, indicating that augmenting data using RO and SMOTE may not improve downstream classification performance. In contrast, the performance of GANs was equal or superior to that of the best-performing alternatives. Specifically, the recall for the malignant class increased from 0.906 (other methods) to 0.953 (GANs), corresponding to a 50% reduction in error rate. This improvement is particularly meaningful as it indicates that half of the malignant patients who were incorrectly classified as benign can be correctly classified with the help of GANs. It is worth noting that GANs improved recall without decreasing precision. In fact, GANs obtained the highest precision possible (i.e., 1.0). With the highest recall and precision of all methods tested, GANs had the highest F1-score and AUC. Last but not least, GANs improved classification performance on the malignant class without compromising performance on the benign class. In fact, GANs increased the precision for the benign class from 0.947 to 0.973 (corresponding to a 49% reduction in error rate), while the recall remained fixed at 1.0. It is also worth noting that the significant improvement in classification accuracy brought about by GANs does not come at the expense of significantly greater computational cost, as the run times for RO, SMOTE and GANs were 0.002, 0.007 and 5 seconds, respectively.
Table 1.
Precision, recall, F1-score and AUC of HGBC on the test data, where HGBC was trained either on the original imbalanced training data (Baseline), or on the balanced training data augmented by RO, SMOTE or GANs, respectively.
| Method | Precision (benign) | Precision (malignant) | Recall (benign) | Recall (malignant) | F1-score (benign) | F1-score (malignant) | AUC |
|---|---|---|---|---|---|---|---|
| Baseline | 0.947 | 1.0 | 1.0 | 0.906 | 0.973 | 0.951 | 0.953 |
| RO | 0.947 | 1.0 | 1.0 | 0.906 | 0.973 | 0.951 | 0.953 |
| SMOTE | 0.947 | 1.0 | 1.0 | 0.906 | 0.973 | 0.951 | 0.953 |
| GANs | 0.973 | 1.0 | 1.0 | 0.953 | 0.986 | 0.976 | 0.977 |
While the underlying distribution of the BCW data can largely be distinguished by cancer status, there are a few occasions where malignant (triangles) and benign (dots) samples are not separable (see circles in Figure 1). These occasions are 'noise' that can be augmented by RO (circles in Figure 4) and SMOTE (circles in Figure 5) if they are sampled. As a result, additional noise is introduced to the training data, leading to classifier overfitting. In contrast, GANs can learn the underlying distribution of the BCW data (where the malignant samples are close to each other but far away from benign samples), and are therefore less likely to be affected by noise. As a result, the data generated by GANs (squares in Figure 6) are clustered and separated from the benign class (dots), and thus much less likely to cause overfitting.
Figure 5.

Scatter plot of the training data augmented by SMOTE (where x1 and x2 are the two dimensions produced by t-SNE). Here, dots indicate the original data of the benign class, triangles indicate the original data of the malignant class, and squares denote the generated data of the malignant class. The circles highlight generated samples of the malignant class that are not distinguishable from samples of the benign class.
It is worth noting that, while Figures 1 and 4 to 6 are all based on the same BCW data, the scatter plots look quite different. This is because Figure 1 contains only the original BCW data (i.e., with no augmentation), whereas Figures 4 to 6 also contain the data generated by RO, SMOTE and GANs, respectively. In other words, the data added to BCW differ across the four figures, resulting in different scatter plots.
4. Discussion
Through the introduction and application of three data augmentation methods, we were able to show that GANs can effectively address class imbalance and, in turn, significantly improve downstream classification performance. It is worth noting that we do not argue GANs are consistently better than traditional methods; such a claim would require either rigorous theoretical proof or comprehensive empirical evaluation, both of which are beyond the scope of this tutorial paper. As we mentioned in the Introduction, the goal of this paper is simply to introduce GANs to a wider audience as a potentially useful tool, and to promote their adoption by medical researchers through a ready-to-use, end-to-end pipeline.
We also want to point out that, like many others, in this paper we define an imbalanced dataset as one in which the number of samples in one class is much higher than in the other class. However, to the best of our knowledge there is currently no consensus with respect to 1) when data augmentation is necessary, or 2) which type of data augmentation method should be used for a given dataset. We believe that when classification accuracy is favored over computational cost (as is usually the case), data augmentation should be performed when there is any degree of class imbalance. In order to achieve optimal classification performance, we also recommend conducting a sensitivity analysis by comparing multiple data augmentation methods.
In this paper we implemented GANs using fully connected deep neural networks. In the future, we will extend this work by 1) allowing more advanced architectures such as deep convolutional neural networks, and 2) adopting more accurate and efficient training methodologies such as those proposed in31. These extensions could further improve the performance of GANs. Since the target outcome in healthcare data is usually binary32, our end-to-end pipeline currently only permits binary targets. However, we are extending this pipeline to allow arbitrary numbers of classes. We are also working on open source software which will take imbalanced data as input, preprocess the data, select the best data augmentation method, and output the balanced data augmented by the selected method.
Acknowledgements
This research was supported by the National Institute on Minority Health and Health Disparities of the National Institutes of Health under Award Number R01MD013901 and the George Washington University Facilitating Fund (UFF) FY21.
Data Availability Statement
Our code, data, full results and video tutorial are publicly available in the paper’s github repository*.
*https://github.com/yuxiaohuang/research/tree/master/gwu/accepted/sam_2021
References
- 1. Shipp MA, Ross KN, Tamayo P, Weng AP, Kutok JL, Aguiar RC, Gaasenbeek M, Angelo M, Reich M, Pinkus GS and Ray TS. Diffuse large B-cell lymphoma outcome prediction by gene-expression profiling and supervised machine learning. Nature Medicine. 2002;8(1):68–74.
- 2. Weng SF, Reps J, Kai J, Garibaldi JM and Qureshi N. Can machine-learning improve cardiovascular risk prediction using routine clinical data? PLoS ONE. 2017;12(4).
- 3. Rastegar DA, Ho N, Halliday GM and Dzamko N. Parkinson's progression prediction using machine learning and serum cytokines. NPJ Parkinson's Disease. 2019;5(1):1–8.
- 4. Menni C, Valdes AM, Freidin MB, Sudre CH, Nguyen LH, Drew DA, Ganesh S, Varsavsky T, Cardoso MJ, Moustafa JSES and Visconti A. Real-time tracking of self-reported symptoms to predict potential COVID-19. Nature Medicine. 2020;1–4.
- 5. Goodfellow I, Pouget-Abadie J, Mirza M, Xu B, Warde-Farley D, Ozair S, Courville A and Bengio Y. Generative adversarial nets. In Advances in Neural Information Processing Systems. 2014;2672–2680.
- 6. Maaten LVD and Hinton G. Visualizing data using t-SNE. Journal of Machine Learning Research. 2008;9:2579–2605.
- 7. Douzas G and Bacao F. Effective data generation for imbalanced learning using conditional generative adversarial networks. Expert Systems with Applications. 2018;91:464–471.
- 8. Chawla NV, Bowyer KW, Hall LO and Kegelmeyer WP. SMOTE: synthetic minority over-sampling technique. Journal of Artificial Intelligence Research. 2002;16:321–357.
- 9. Batista GE, Prati RC and Monard MC. A study of the behavior of several methods for balancing machine learning training data. ACM SIGKDD Explorations Newsletter. 2004;6(1):20–29.
- 10. Bunkhumpornpat C, Sinapiromsaran K and Lursinsap C. Safe-Level-SMOTE: Safe-level-synthetic minority over-sampling technique for handling the class imbalanced problem. In Pacific-Asia Conference on Knowledge Discovery and Data Mining. 2009;475–482.
- 11. Han H, Wang WY and Mao BH. Borderline-SMOTE: a new over-sampling method in imbalanced data sets learning. In International Conference on Intelligent Computing. 2005;878–887.
- 12. Barua S, Islam MM, Yao X and Murase K. MWMOTE—Majority Weighted Minority Oversampling Technique for imbalanced data set learning. IEEE Transactions on Knowledge and Data Engineering. 2012;26(2):405–425.
- 13. He H, Bai Y, Garcia EA and Li S. ADASYN: Adaptive synthetic sampling approach for imbalanced learning. In 2008 IEEE International Joint Conference on Neural Networks (IEEE World Congress on Computational Intelligence). 2008;1322–1328.
- 14. Tang B and He H. KernelADASYN: Kernel based adaptive synthetic data generation for imbalanced learning. In 2015 IEEE Congress on Evolutionary Computation (CEC). 2015;664–671.
- 15. He H and Garcia EA. Learning from imbalanced data. IEEE Transactions on Knowledge and Data Engineering. 2009;21(9):1263–1284.
- 16. Shoohi LM and Saud JH. DCGAN for handling imbalanced malaria dataset based on over-sampling technique and using CNN. Medico Legal Update. 2020;20(1):1079–1085.
- 17. Niu S, Li B, Wang X and Lin H. Defect image sample generation with GAN for improving defect recognition. IEEE Transactions on Automation Science and Engineering. 2020;17(3):1611–1622.
- 18. Hao J, Wang C, Zhang H and Yang G. Annealing genetic GAN for minority oversampling. arXiv preprint arXiv:2008.01967. 2020.
- 19. Odena A, Olah C and Shlens J. Conditional image synthesis with auxiliary classifier GANs. In International Conference on Machine Learning. 2017;2642–2651.
- 20. Mariani G, Scheidegger F, Istrate R, Bekas C and Malossi C. BAGAN: Data augmentation with balancing GAN. arXiv preprint arXiv:1803.09655. 2018.
- 21. Wu E, Wu K, Cox D and Lotter W. Conditional infilling GANs for data augmentation in mammogram classification. In Image Analysis for Moving Organ, Breast, and Thoracic Images. Springer, Cham. 2018;98–106.
- 22. Waheed A, Goyal M, Gupta D, Khanna A, Al-Turjman F and Pinheiro PR. CovidGAN: data augmentation using auxiliary classifier GAN for improved COVID-19 detection. IEEE Access. 2020;8:91916–91923.
- 23. Mullick SS, Datta S and Das S. Generative adversarial minority oversampling. In Proceedings of the IEEE International Conference on Computer Vision. 2019;1695–1704.
- 24. Deepshikha K and Naman A. Removing class imbalance using Polarity-GAN: an uncertainty sampling approach. arXiv preprint arXiv:2012.04937. 2020.
- 25. Sampath V, Maurtua I, Martín JJA and Gutierrez A. A survey on generative adversarial networks for imbalance problems in computer vision tasks. Journal of Big Data. 2021;8(1):1–59.
- 26. Engelmann J and Lessmann S. Conditional Wasserstein GAN-based oversampling of tabular data for imbalanced learning. Expert Systems with Applications. 2021;174:114582.
- 27. Ren J, Liu Y and Liu J. EWGAN: Entropy-based Wasserstein GAN for imbalanced learning. In Proceedings of the AAAI Conference on Artificial Intelligence. 2019;33(01):10011–10012.
- 28. Masters D and Luschi C. Revisiting small batch training for deep neural networks. arXiv preprint arXiv:1804.07612. 2018.
- 29. Chollet F. Deep Learning with Python. Manning Publications. 2017.
- 30. Géron A. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. O'Reilly Media. 2019.
- 31. Gulrajani I, Ahmed F, Arjovsky M, Dumoulin V and Courville AC. Improved training of Wasserstein GANs. In Advances in Neural Information Processing Systems. 2017;5767–5777.
- 32. Ferraro KF and Wilmoth JM. Measuring morbidity: disease counts, binary variables, and statistical power. The Journals of Gerontology Series B: Psychological Sciences and Social Sciences. 2000;55(3):S173–S189.
