Abstract
The ability to predict health outcomes from gene expression would catalyze a revolution in molecular diagnostics. This task is complicated because expression data are high dimensional whereas each experiment is usually small (e.g., ∼20,000 genes may be measured for ∼100 subjects). However, thousands of transcriptomics experiments with hundreds of thousands of samples are available in public repositories. Can representation learning techniques leverage these public data to improve predictive performance on other tasks? Here, we report a comprehensive analysis using different gene sets, normalization schemes, and machine learning methods on a set of 24 binary and multiclass prediction problems and 26 survival analysis tasks. Methods that combine large numbers of genes outperformed single gene methods, but neither unsupervised nor semi-supervised representation learning techniques yielded consistent improvements in out-of-sample performance across datasets. Our findings suggest that using l2-regularized regression methods applied to centered log-ratio transformed transcript abundances provide the best predictive analyses.
I. INTRODUCTION
The potential to tailor therapies for individual patients rests on the ability to accurately diagnose disease and predict outcomes under various treatment conditions. Predictors based on high-throughput ‘omics technologies hold great promise, but a number of technical challenges have limited their applicability [1]. Phenotypes may be complex—involving contributions from large numbers of genes—but ‘omics data are so high-dimensional that exploring all possible interactions is intractable. This situation is further complicated by the small sample sizes of typical biological studies and by large systematic sources of variation between experiments [2],[3]. However, recent developments in machine learning have raised hopes that new computational methods integrating data from many studies may be able to overcome these difficulties. Accurate prediction of phenotype or endpoint(s) from ‘omics data would usher in an era of molecular diagnostics [4],[5].
Machine learning methods often benefit from large datasets where learning complex relationships is feasible. Although individual biological experiments tend to be small, relatively large amounts of ‘omics data are available in public repositories. For example, hundreds of thousands of samples from human RNA sequencing (RNA-seq) experiments are available from the recount2 and ARCHS4 databases [6–8]. Still, these data cover a wide variety of tissues and diseases. Moreover, there are no specific diseases with large numbers of samples and, in many cases, the metadata are not sufficient to determine basic experimental facts like the tissue of origin [9]. As a result, leveraging these data to improve prediction tasks will require machine learning techniques that can learn from large, heterogeneous datasets.
Genes rarely act in isolation, so it is reasonable to expect that combinations of genes may be more effective than individual genes for predicting phenotypes. For example, linear models operating on RNA-seq data create predictors from a weighted combination of gene expression values. However, some of these features could reflect biological processes that are involved in multiple phenotypes. Many previous analyses have explored this possibility by creating complex features that incorporate biological knowledge from gene sets [10, 11], ontologies [12], or interaction graphs [13–15]. More recently, machine learning methods used with principal components analysis [16], autoencoders [17–20], or other neural network architectures have been developed to discover such features by analyzing large transcriptomics datasets. If these learned features capture biologically relevant processes, then predictive models built from those features should outperform models built directly from relative transcript abundances.
In this work, we present a comprehensive analysis of phenotype prediction from transcriptomics data with a particular emphasis on representation learning. Using the recount2 database [7], we systematically explored the impact of normalization techniques, gene sets, learned representations, and machine learning methods on predictive performance for a set of 24 binary and multiclass prediction problems and 26 survival analysis tasks. In total, we analyzed thousands of predictive models using 5-fold nested cross validation to rigorously assess out-of-sample performance. We found that predictors that combined multiple genes outperformed single gene predictors, that logarithmic transformations outperformed untransformed relative expression measurements, and that for survival analyses larger gene sets outperformed smaller gene sets. However, neither unsupervised nor semi-supervised representation learning techniques yielded consistent improvement on out-of-sample predictive performance across datasets. In fact, l2-regularized regression methods applied directly to the centered log-ratio transform of transcript abundances performed consistently well relative to the other methods. Therefore we recommend treating that particular combination as a baseline method for predictive analysis on RNAseq data. Throughout this text we refer to the combination of l2-regularized regression methods applied directly to the centered log-ratio transform of transcript abundances as the recommended model.
II. RESULTS
Approach
A high-level description of our quality control, data processing, and machine learning analyses is provided in Figure 1. Details of these steps are provided in the online Methods.
Briefly, we selected a subset of experiments from the recount2 database [7] that did not have sparse gene expression data and could be mapped to the same set of tissues covered in the Genotype-Tissue Expression (GTEx) project [21]. We assigned the various experiments to “training” (∼37k samples), “validation” (∼4k samples), and “test” sets (∼4k samples). All samples lacking suitable metadata for supervised learning were allocated to the training set. From metadata provided with recount2, the Gene Expression Omnibus [22], and The Cancer Genome Atlas (TCGA) Pan-Cancer Clinical Data Resource [23] we derived labels for 24 binary and multiclass and 26 survival analysis tasks. Descriptions of these tasks and their assignment to the training, validation, and test sets are provided in the online Methods.
We considered four different normalization methods to correct for variance introduced in the data collection and measurement process. We first converted the samples from counts to Transcripts Per Million (TPM) [24], a normalization which estimates relative molar concentration of transcripts in a sample. Under the operating assumption that transcript abundance is determinant of downstream biological function, TPMs should be the baseline quantification to work with from RNAseq. In contrast, raw counts (or CPMs) contain irrelevant counting bias stemming from variable transcript length. Likewise, the common alternative of FPKMs do not coherently measure relative molar concentration, because they rely on a sample-dependent normalization factor. As such, FPKMs are not a useful measure when processing samples which are not entirely technical replicates of a single tissue sample [24, 25]. Secondly we applied the centered log-ratio transformation (CLR) [26] to the TPM data to address the fact that RNA-seq data quantify relative, rather than absolute, gene expression [27],[28]. Since these two normalization methods do not account for the tissue of origin of the sample, we evaluated additional normalization methods based on differential expression with respect to normal tissue. The third normalization method converted the CLR transformed expression data from each sample to a tissue-normalized Z-score by subtracting the mean and dividing by the standard deviation of the associated tissue in GTEx. This mean and standard deviation of the CLR transformed expression data were computed across the GTEx data in recount2 for each annotated tissue type. Finally, a fourth ternarized normalization discretized the Z-scores into down-regulated (Z < −2), normal (−2 < Z < 2), or up-regulated (Z > 2) categories.
For each of these normalization approaches, we also explored three gene sets corresponding to transcription factors [29] (denoted OT), protein coding genes annotated as with biological processes or molecular functions in the Gene Ontology 12 (denoted O), and all genes provided by recount2 (denoted all). The O and OT gene sets are substantially smaller than the all gene set and allow exploration of the dependence on the number of genes. In total, we examined twelve different normalization-gene set combinations for each predictive problem.
We considered four different types of representations of the gene expression data learned by unsupervised models. First, supervised models were trained directly on the normalized expression data without a learned embedding. We also considered representations constructed with Principal Components Analysis (PCA), a Stacked Denoising Autoencoder (SDAE), and a Variational Autoencoder (VAE) trained on the 37k samples in the training set without any supervising information.
For each binary or multiclass prediction task, we trained a k-Nearest Neighbor (kNN) classifier, a Random Forest (RF), and an l2-regularized multinomial Logistic Regression (LR) on the normalized and transformed data using 5-fold nested cross validation. Using nested cross validation (Methods) is important because it accounts for performance variance that results from different hyperparameter choices (e.g., the number of nearest neighbors, the depth of the trees in the forest, or the strength of the regularization coefficient). An l2-regularized Cox proportional hazards model was used for all survival tasks, also with 5-fold nested cross validation. Binary tasks were compared using the Area Under the receiver operating characteristic Curve (AUC); multiclass tasks were compared using the accuracy, and survival tasks were compared using the concordance-index (C-index) [30],[31].
Our systematic model search covered four normalization methods, three gene sets, four representations, and three supervised algorithms totaling 144 comparison models for each of the 24 binary and multiclass tasks. For the survival tasks we used the same normalization methods, gene sets, and representations, but considered only one supervised algorithm (Cox proportional hazards). For comparison, we also trained linear predictors using the recommended method that were only allowed to use a single gene. The choice of gene was treated as a hyperparameter and optimized using 5-fold nested cross validation.
Analyses
The predictive performance assessed through 5-fold nested cross validation varied considerably across and within the predictive problems (see Figure 2). Gene expression data improved predictive performance relative to random guessing in almost all cases, indicating that RNA-seq data do contain information that is broadly useful for out-of-sample prediction. Moreover, linear predictors that used the expression data from all genes generally outperformed models that only used a single, most predictive gene. It is still common to analyze genes independently in differential expression and regression analyses; our results indicate, however, that linear combinations of genes are significantly more predictive than individual genes. Although there was sizable variance in performance across tasks, predictive performance was not correlated with any obvious dataset characteristics such as the number of subjects.
In order to compare the effects of the gene set size and transformation, it is helpful remove between-task variance and then to aggregate results across tasks. To remove the between-task variance, we defined a shifted statistic in which we subtracted the median value of all models on the same task. For example, the AUC for the random forest classifier on the STAD stage dataset was shifted by subtracting the median AUC for all of the binary classifiers trained on the STAD stage dataset. Averages of the shifted statistics across predictive problems can be easily interpreted: if the value is less than zero then the method underperformed the median, whereas the method outperformed the median if the value is greater than zero.
Within-task variance in predictive performance was partially explained by the choice of gene set and normalization method (see Figure 3). Because the number of samples in each dataset was much smaller than the number of genes annotated in recount2, we hypothesized that using prior knowledge to select small, biologically relevant gene sets based on the Gene Ontology or transcription factor activity would improve out-of-sample predictive performance by preventing overfitting. However, this hypothesis was not supported by our analyses. The choice of gene set made no difference for the classification problems, whereas the smaller gene sets underperformed on the survival tasks. The log-transformed normalization methods slightly outperformed TPMs, and the Z-score normalization performed the best, on average. Performance improvements of Z-score normalization relative to CLR were small, however, and we do not think that the small gains justify the additional complexity introduced by referencing each sample to an external dataset (i.e., GTEx).
Next, we examined differences in absolute performance between the kNN, RF, and LR models on the classification problems (only a linear Cox proportional hazards model was tested on the survival tasks). As shown in Figure 4, the kNN classifier consistently underperformed the RF and LR classifiers. The RF was the best performing method for thirteen tasks, LR for nine tasks, and kNN for two tasks, but LR was more consistent than RF and had better average performance.
Gene expression data are very high dimensional, with the number of genes ranging from ∼1.5k in the transcription factor gene set to ∼56k in the gene set consisting of all genes annotated in recount2. In contrast, the supervised task datasets typically consisted of only a few hundred samples. Moreover, it seems unlikely that genes actually coordinate in a linear fashion to generate complex phenotypes. Therefore we hypothesized that predictive performance could be improved by training predictors on lower dimensional representations derived from unsupervised analyses of the ∼37k unlabeled samples in the training set. One could also view these analyses as a type of transfer learning, in which biological knowledge derived from the analysis of one dataset is used to inform the analyses of another.
The first feature representation that we considered was a Principal Components Analysis (PCA) with 512 latent dimensions. These principal components are orthogonal linear combinations of expression values that represent the directions of largest variance in the training set. Together, the 512 principal components we used explained the majority of the variation in the transcriptional datasets (Figure 9). We found that using PCA derived representations as features decreased the out-of-sample performance of downstream predictive analyses (Figure 5). Therefore, we do not recommend using features derived from PCA of large RNA-seq compendia for predictive analyses.
Training a linear model on top of representations derived from a linear transformation like PCA is equivalent to a regularized linear model trained on the unembedded data. Deep neural network-based architectures like SDAEs and VAEs, by contrast, process an input expression vector through a series of nonlinear transformations to learn more complex features. Therefore, we also trained a 512-dimensional SDAE and VAE on the training set for each gene set-normalization combination and used the representations derived from these neural networks as features for downstream prediction tasks. Nevertheless, we found that preprocessing the expression data using these networks decreased the out-of-sample performance of downstream prediction tasks relative to just using the normalized expression data directly (Figure 5).
Semi-supervised representation learning
There are a variety of reasons that unsupervised representation learning can fail to discover features that are useful for downstream predictive tasks. For example, a small but consistent difference in the expression of a gene between two groups (e.g., healthy and diseased) can be used to train a highly accurate predictor. However, if this difference is much smaller than the variance in the expression of other non-predictive genes, then it will be ignored by most unsupervised representation learning algorithms. One way to avoid this problem is use a semi-supervised method to learn the representation.
The goal of semi-supervised representation learning is to derive a common set of features that are useful for multiple downstream predictive tasks. Our semi-supervised model consists of an autoencoder along with a number of logistic regression classifiers, one for each supervised task involved in the training set. The predictors operate on the 512-dimensional latent space embedding of the autoencoder. For any expression vector the autoencoder contributes a reconstruction loss. Furthermore, if there is a predictive task label associated to the expression vector, then the associated linear predictor contributes a classification loss as well. We trained the model to minimize a loss function that was a weighted combination of the autoencoder loss and the supervised loss averaged across each of the predictive tasks. We considered the out-of-sample predictive performance of four representations: the unembedded data, data embedded by a model trained using only autoencoder loss, data embedded by a model trained on the combined autoencoder and supervised losses, and data embedded by a model trained using only the supervised loss. More details are provided in the online Methods.
In order to test the semi-supervised model, we divided the larger labeled training datasets into two halves. The first half of the labeled training datasets were combined with the unlabeled data from the training set and used to train the semi-supervised autoencoder. The second half of the training datasets were held out as validation. We also held out the validation and testing labeled datasets as in the analyses of the representations learned by unsupervised algorithms. This strategy provided two types of validation tasks: those in which the representation was trained on similar data (e.g., from the same study), and those in which the representation had not been trained on similar data. The results are shown in Figure 6. Using the learned features slightly improved median predictive performance on the divided tasks but did not improve predictive performance on the validation and testing tasks used in the previous analyses.
III. DISCUSSION
The hypothesis that gene expression measurements can be combined into higher level features that should be useful for predicting phenotypic characteristics has intuitive appeal. Indeed, we believe that genes act together as coordinated pathways that control cellular processes. Moreover, changes in expression at the tissue level could reflect higher level changes due to differences in cellular composition. As a result, one would expect that it is possible to define useful highlevel features for expression data; this intuition has driven the development of pathway analyses [32–34], gene set analyses [11],[10], knowledge graphs [15],[14], and cell-type deconvolution approaches [35–37] to analyzing transcriptomics experiments. More recently, a number of studies have introduced deep learning methods that aim to discover useful gene, or transcript, combinations that reflect the underlying biology without imposing particular prior knowledge [4],[5],[38– 40]. In theory, these learned representations should provide better predictive performance because they are transferring biological knowledge derived from one dataset to another. In addition, they reduce the dimension of the input data and, as a result, potentially mitigate overfitting. Here, we set out to systematically and rigorously assess the impact of these representations on downstream predictive tasks.
Our key results can be summarized in a few bullet points:
Multivariate predictors outperformed predictors based on the best single gene.
Larger gene sets performed better than smaller gene sets.
CLR and tissue-specific Z-score normalization were better than TPM.
Logistic regression and random forests performed equally well.
Representations derived from unsupervised or semi-supervised methods did not improve out-of-sample performance for phenotype prediction. Based on these key results, we conclude that l2-regularized regression applied to the CLR transformed relative transcript abundances is generally the best choice for predictive analyses using transcriptomics data. The Z-score and Z-ternary normalizations generally perform comparably to CLR, but require the GTEx data as a reference and hence CLR is recommended.
Figures 2-5 present results for the evaluation of unsupervised models on supervised tasks, studying performance as different aspects of the models change. Figure 2 shows how the performance varies across supervised tasks and demonstrates that the recommended model is nearly always one of the better performing models. Figure 3 presents the relative performance for the choices of normalization and gene set, showing that using larger gene sets improves performance on survival tasks. Figure 4 presents the performance across supervised tasks for different choices of the supervised model, showing that random forest and logistic regression models perform well. Figure 5 shows the relative performance across different unsupervised models, divided by gene set, demonstrating that supervised models on unembedded data for all genes are the best performing. All supervised evaluation results are recorded and further visualized in the appendices (Results Table and Figure 7). Taken together, these motivate the choice of the recommended model.
Our first conclusion, that multivariate predictors outperform predictors based on single gene expression measurements, was the clearest cut. This has some practical consequences when combined with our other conclusion that larger gene sets are better, especially for the fitting of proportional hazards models used for survival analyses. First, using multivariate predictors on large gene sets means that the number of covariates will almost always vastly outnumber the subjects in a study. Therefore, it is absolutely necessary to regularize these models by adding penalties to the coefficients. Moreover, nested cross validation should be used for all performance assessments to mitigate overfitting to hyperparameter choices and to minimize variance in the performance metric. Second, it is often impractical —or even impossible— to fit these models using standard methods on typical computing architectures. For example, open source packages for survival analyses typically use second-order methods to optimize the objective function. This works for a single gene, but fitting the multivariate model requires computing a very large matrix of second derivatives, e.g., 56,000 × 56,000 in this study. As a result, it was necessary to implement first-order optimization methods and perform most of the matrix operations using graphical processing units to make the survival analyses in this study feasible.
Overall, we found that choices such as the normalization method, the gene set, the type of supervised prediction algorithm, and the use of a learned representation made surprisingly little impact on out-of-sample predictive performance. Moreover, we could not identify any clear trends. For example, it is not necessarily better to use smaller gene sets or other lower dimensional representations for studies with smaller sample sizes. In light of these results, it is not clear that features derived from either prior knowledge or from representation learning methods have much value in the analyses of bulk RNA-seq data. If the relationship between bulk gene expression and phenotype is not one-to-one, then there is already a limit on how well one could predict phenotype from gene expression. Relatively simple methods may be already very close to this limit. Improving ‘omics-based phenotype prediction is likely to be contingent on other factors including reduction of systematic errors in sequencing data, incorporation of other data types such as single-cell sequencing and proteomics, and improved use of prior knowledge.
IV. METHODS
The analysis presented here and depicted in Figure 1 is a multi-step procedure, starting from read counts data in the recount2 database and ending at performance metrics for various models. There are principally three stages: dataset preparation, unsupervised model training, and supervised model training.
Dataset preparation
The recount2 database [7] is a repository of transcriptomics data sourced from over 2000 independent transcriptomics experiments. The transcriptomics data from these experiments has been reprocessed using a uniform processing pipeline, forming a single dataset amenable to large scale computational analyses. Such analyses would otherwise be problematic due to systematic differences between the original processing pipelines. The data in recount2 consists of counts of gene reads as well as exon-level quantifications. Our study concerned the gene counts data exclusively.
The data comprising recount2 can be divided into three broad groups according to their sources: GTEx, TCGA, and SRA. The GTEx group was sourced from the Genotype Tissue Expression program and contains 9538 samples from healthy individuals across 30 tissue types. The TCGA group was sourced from the Cancer Genome Atlas project and contains 11284 samples from individuals with cancer across 21 tissue types. In that group samples were taken from tumor sites as well as normal tissue adjacent to tumor (NAT) sites. GTEx and TCGA are each single, large collaboration projects with high quality control standards and protocols for sample processing. Metadata for these projects is extensive. The SRA group contains 49638 samples from 2033 smaller, distinct experiments collected in the Sequence Read Archive. Metadata for experiments in SRA are sparser, with tissue labels occasionally absent.
In total, 70460 samples were available in recount2 when the database was downloaded. However, many of these samples are not ideal for representation learning with transcriptomics data. We developed a quality control (QC) pipeline to remove samples or entire SRA studies. The number of samples remaining after the QC pipeline is 39848. The QC steps are as follows:
Remove samples in which the reported cell type is a cell line. 9644 samples fit this criterion.
Remove studies in SRA from single-cell sequencing. Examining metadata from GEO, 38 studies in SRA with 5865 samples in recount2 have single-cell transcriptomic data.
Remove samples in which the reported tissue does not match any tissue in GTEx. 6824 samples fit this criterion (see Z-score normalization later).
Remove samples in SRA which have duplicate GEO accession numbers (GSMs). There were 9601 samples that met this criterion.
Remove samples in which more than 30% of genes listed in the Gene Ontology (GO) 12 under the “biological process” or “molecular function” categories have a counts value of 0. 15390 samples met this criterion.
The number of matching samples in each step are non-exclusive, meaning a sample can match more than one of the exclusion criteria. These effect of these exclusion criteria are depicted in the supplementary figures (Figure 8). In total we removed 30612 samples, approximately 43% of the total. No GTEx samples were removed; and only 521 TCGA samples were removed.
It is useful to present some detailed commentary on the duplicate GEO accession number criterion. We observed that several samples have duplicate GSMs, and that many such samples had the same number of reads (a round number, e.g., 8 million). This suggests that the individual samples could be chunks of reads from the same underlying sample. However, we could find no satisfying reason for duplicate GSMs or the round number of read counts for these samples, and therefore excluded them from the dataset.
The QC pipeline determines which samples are admitted to the final dataset; there are also choices to be made about which genes to consider in the analysis, and which normalizing procedures to apply to the expression data.
We considered three different gene sets in the analysis:
all genes: (57992 genes).
O genes: genes in GO under the “biological process” or “molecular function” categories (17970 genes at the time of dataset creation).
OT genes: O genes also labeled as transcription factors 29 (1530 genes at the time of dataset creation).
In addition, we considered four different normalizing transformations of the counts data:
TPM: The counts are transformed into transcripts per million (tpm), which account for gene length to normalize reads. The TPM value is determined in terms of the counts as,
CLR: A centered-log-ratio transform is carried out on the TPM vectors. The CLR value is determined in terms of the TPM values as,
Z-score: The Z-score transform is carried out on the CLR features. The Z-score is the CLR value standardized by the mean expression of a gene in healthy tissue, determined by the GTEx samples for the same tissue. The Z-score value is determined in terms of the CLR value as,
Z-ternary: The Z-ternary transform is carried out on the Z-score features. The Z-score values are ternarized based on their value, and the ternarization indicates whether the gene’s expression is increased, decreased, or unchanged relative to the mean expression in healthy tissue. Since the distribution of Z-score values is expected to be approximately normal for healthy tissue, any Z-score value below −2 is assigned the Z-ternary value of −1; any Z-score value above 2 is assigned the Z-ternary value of 1, and any Z-score value between −2 and 2 is assigned the Z-ternary value of 0.
We made use of the open-source python library genemunge [41] for making these normalizations and selecting the gene sets. Each of the normalizations are carried out on the expression data for all genes. Whenever a smaller gene set is used, the values of the features for the selected genes are simply taken from the data for all genes. The three gene sets and four normalizations yield twelve different datasets that are used in the analysis.
Tasks and dataset allocation
The above procedure describes the preparation of the gene expression datasets. In addition to the expression data, some samples have one or more labels suitable for predictive modeling. TCGA has rich metadata with natural label types, available in the TCGA Pan-Cancer Clinical Data Resource [23]; some SRA studies also contain useful metadata in GEO [22] relevant to human disease. From the TCGA and recount2 metadata we selected four categories of predictive tasks: binary labels for the grade of a tumor in various cancer types (8 tasks); binary labels for the stage of a tumor in various cancer types (10 tasks); times for overall survival in various cancer types (13 tasks); and times for progression free interval in various cancer types (13 tasks). From the GEO metadata we selected binary and multiclass labels for various clinical characteristics (6 tasks) [42–47]. In total there are 50 tasks for which supervised models may be built.
We divided the 50 predictive tasks into three groups, “training”, “validation”, and “test”. We then built a “training” gene expression dataset consisting of any samples with a label in the training task group, as well as any samples with no label. This dataset, which has 36794 gene expression samples, was used to train the unsupervised models. A sample’s inclusion in this dataset distinguishes the training and validation task groups. Both the training and validation tasks were used in the analysis, whereas tasks in the test group were held out until the end of the project so that no model selection criteria might influence performance on these tasks in any way. The supervised tasks are summarized in the appendix (Tables I,II,III,IV).
Unsupervised models
Three different types of unsupervised models were trained on the gene expression datasets: principal components analysis (PCA), stacked denoising autoencoders (SDAE), and variational autoencoders (VAE).
Principal components analysis (PCA)
The problem of finding the kprincipal components of a suitably large collection of vectors of dimension n admits an analytic solution. But the computation required to perform this calculation is in O(n3), making it intractable in high dimensions. Due to the high dimension of the larger gene sets (>17k), we performed the PCA analysis via stochastic gradient descent following the algorithm introduced by Arora et. al. [48] called “Stochastic Approximation.” The leading 512 principal components were retained.
Stacked denoising autoencoders (SDAE)
We employed denoising autoencoder architectures of “hourglass” shape with seven layers. The hourglass narrows to a middle layer of 512 dimensions, yielding a 512-dimensional encoder. The details of the architecture are recorded in the appendix. The models were trained with stochastic gradient descent to minimize the mean squared reconstruction loss. We found that pre-training the models layerwise before end-to-end training produced the best results. Therefore these models are best described as stacked denoising autoencoders per the original presentation [49]. The models were regularized by input noise variance and an l2 weight penalty, with these hyperparameters selected by sweeping a range.
Variational autoencoders (VAE)
We also included a deep generative model among our unsupervised model types, the variational autoencoder [50]. In particular, we employed the methods of Klambauer, et. al. [51] which make use of self-normalizing units, SNNs, for improved training dynamics and representational capability. We trained the models using the KL-annealing method of Bowman et. al. [52] during the first 100 epochs and then let training proceed with the normal loss function for the remaining epochs. The layer dimensions are recorded in the appendix. The latent encodings consist of 512 dimensions for the distributional means and 512 for the distributional log variances. Therefore the trained model’s feature encoder is the restriction to the 512 dimensions of the means variables.
The “no-embedding” model
In addition to these unsupervised models, we also employed a kind of control comparator: a “no-embedding” model which does nothing to the expression data. The dimension of the gene expression data is not reduced under the no-embedding model; the features are the normalized gene expression vectors themselves.
Computational constraints
We trained the PCA model on each of the four normalizations for each of the three gene sets. Due to computational constraints we applied the SDAE and VAE models to each of the four normalizations for the O and OT gene sets excluding the all genes set. The lack of an improvement in performance on smaller gene sets indicated the dataset with all genes was unlikely to provide quality embedding models.
Supervised models
We evaluated the ability of a unsupervised model to learn useful representations across transcriptomics data by assessing the performance of supervised models operating on the learned representations. For each unsupervised model and predictive task, we trained and evaluated supervised models using nested cross validation. The performance of these predictive models gave an indication of how well the learned representation captured features in the data useful for various kinds of phenotype prediction. Before presenting the different kinds of supervised models, we present a small primer on nested cross validation.
Nested cross validation
Nested cross validation is designed to provide a robust estimate of the expected (predictive) model performance on new data, optimizing over a set of hyperparameter values (such as the maximum depth in a random forest). In nested cross validation, there are two loops over the data, the outer and inner loop. The inner loop is used to select an optimal hyperparameter value, and the outer loop is used to estimate the performance of the model with this hyperparameter value. In the outer loop, data is divided evenly into K groups, or folds (we use K = 5). For each fold, the data for that fold is held out and the remaining K − 1 folds are used for the inner loop. In the inner loop, this data is divided into K folds, and on each fold the data for that fold is held out and the model is trained on the remaining K − 1 folds for each hyperparameter value. The held-out fold is used to estimate the model performance for each hyperparameter value, and this performance is averaged over all folds in the inner loop. The best performing hyperparameter value is selected, and the model is re-trained on all data used in the inner loop. The model performance is then evaluated on the held out data from the outer fold. This value is averaged over all folds in the outer loop, and this final average is the estimated model performance. Note that a different optimal hyperparameter may be selected for each outer fold. Nested cross validation is resistant to hyperparameter overfitting, as the model is evaluated on data completely held out from the process of selecting the optimal hyperparameter. With this robustness comes increased computational complexity —if there are N hyperparameter values tested, nested cross validation requires training K(KN + 1) individual models.
Classification tasks
For classification tasks we applied three different types of supervised models:
Logistic regression (LR). Logistic regression with an l2 penalty, trained via stochastic gradient descent. The logistic regression model is a single layer neural network with a softmax activation on the output. The hyperparameter optimized was the l2 penalty, logarithmically spaced between 10−6 and 103 in ten steps. The model was implemented in pytorch [53]. Hyperparameters and training notes are provided in the appendix.
Random forest (RF). Random forest models with 100 trees per forest. The hyperparameter optimized was the maximum depth of the random forest, logarithmically spaced between 2 and 27 in seven steps. We relied on the scikit-learn [54] implementation of random forest.
K-nearest neighbors (kNN). The hyperparameter optimized was the value of k, the number of neighbors used, taking a value of 1, 3, 5, 7, or 9.
Survival tasks
For survival tasks, in which the overall survival time or the progression free interval time were predicted, we trained a Cox proportional hazard (CPH) model. The standard solvers for CPH models use second-order methods, such as versions of Newton’s method, making them unsuitable for use with a large number of features. The computation time required for the 512-dimensional embedding, using nested cross validation, is already immense; training CPH models on data without an embedding is completely impractical. Instead, we implemented a CPH model in pytorch, and trained it via stochastic gradient descent by backpropagating through the Cox-Efron pseudolikelihood[55]. Such models can be trained with a large number of features —even all genes— and can be GPU accelerated. We regularized these models with an l2 penalty whose strength, logarithmically spaced between 10−6 and 103 in ten steps, was optimized in the inner cross-validation loop. Even with this computational speedup, evaluating the survival tasks requires the bulk of compute time. It bears noting that these models were still trained with a fixed initial learning rate which was small enough to guarantee controlled gradient descent across all tasks. It is certain that absolute performance on individual contrasts could be improved by also optimizing the learning rate in the nested cross validation. However, because the study concerns the relative performance of this algorithm across gene sets, embeddings, and normalizations, we avoided this additional multiplier on the computational time.
Single-gene comparators
All of the above supervised models are trained on features from multiple genes. In order to compare our embedding models to single-gene analysis, we also trained a set of models on single genes with no-embedding model. For these models the hyperparameter optimized in the inner cross validation loop is the gene selected for the model. We had no need to run these comparators across all transform/predictor combinations so we restricted these examples to clr-transformed data and used only the univariate logistic regression models for classification tasks. For survival tasks, single-gene CPH models were trained. No regularization term was necessary in either case because these models have so few parameters. These results provide a direct comparison to the multi-gene, CLR-transformed, no-embedding results.
Components of supervised model results
In total, we evaluated a very large number of (multi-gene) supervised models. There are five different characteristics of a single result:
Task. There are 24 classification tasks and 26 survival tasks.
Gene set. There are three gene sets, all genes, O genes, and OT genes.
Normalization. There are four data normalizations, TPM, CLR, Z-score, and Z-ternary.
Unsupervised model. There are four types of unsupervised model, PCA, SDAE, VAE, and no-embedding. SDAE and VAE were only trained on the O and OT gene sets.
Supervised model. For classification tasks, three different supervised models were trained, LR, RF, and kNN. For survival tasks, a CPH model was trained.
This amounts to 3920 results. In terms of individual models trained during nested cross validation, there are 807600 models.
Semi-supervised models
The semi-supervised models are designed to learn a feature embedding which is co-adapted to the purpose of reconstruction as well as the performance of supervised models operating on the embedding. Each of the semi-supervised models consists of a linear, single-hidden layer autoencoder coupled to a number of logistic regression predictors —one for each supervised task involved in the training dataset. The supervised predictors operate on the autoencoder’s 512-dimensional encoding. Both a schematic diagram of the model architecture and the details of the architectures are recorded in the supplementary figure (Figure 10).
Data preparation
In order to be able to assess the performance of semi-supervised representation learning within-task, we had to further subdivide some of the labeled expression data. In particular, we subdivided into two halves the binary predictive tasks within the “training set” which contained at least 200 samples. The first half was used in the training of the semi-supervised model; the second half was held out for validation. We called these sets the “divided tasks”; they were drawn from the following binary tasks:
{CESC grade, COAD stage, KIRC grade, KIRC stage, LGG grade, LIHC stage, LIHC grade, LUAD stage, SKCM stage, STAD stage, STAD grade, THCA stage, UCEC stage, UCEC grade}.
The rest of the tasks which constituted the original “validation” and “test” sets were used for validation. So the training set for each semi-supervised model consisted of all expression data from the first halves of the fourteen divided tasks along with their associated binary labels.
Training of semi-supervised models
Given any sample expression vector x from the training set we can compute the autoencoder reconstruction loss on that sample, specifically as the squared reconstruction error,
Supposing that x has a class label lx ∈ {0, 1} from the jth predictive task, we can also compute a classification error of the associated binary logistic regression classifier Pj,
We trained our semi-supervised models (via stochastic gradient descent) to minimize a convex combination of these two error terms. The constant controlling the interpolation of these two losses we called the “predictor strength,” π, which ranged from 0 to 1. Our training algorithm allowed different batch sizes for the autoencoder loss and the predictor losses; let these be denoted by BR, and BC, respectively. Let {xm}, {xn, lxn} be batches drawn randomly from the training data, the first consisting of only expression vectors, the second containing both expression vectors and paired class labels. Our loss term takes the form,
Here, J denotes the total number of predictive tasks. The last two terms are l2 weight penalties on the model parameters; these are controlled by adjustable constants λAE and λP.
We compared three scenarios for the predictor strength in our analysis,
π = 0.0, i.e. the model is an autoencoder only.
π = 0.1, the model is trained with a mixture of both losses.
π = 1.0, i.e. the model is a purely supervised shared-embedding model.
We also compared results to a no-embedding model as a kind of control.
For each gene set, data normalization, and predictor strength scenario, we performed a sweep over all 16 pairs of values for λAE and λP in the cartesian product {0, 0.1, 0.01, 0.001}s2. We selected the l2 coefficient pair which minimized average error on the held-out half of the divided contrasts.
Finally, we assessed the performance of predictive models (across all three types, LR, RF, kNN) operating on the learned data embedding to compare the effect of semi-supervised representation learning across these three scenarios. Those results are displayed in Figure 6 in the main text.
DATA AVAILABILITY
The gene expression data used to train the unsupervised models and the gene expression data and labels used to train and evaluate the supervised models are available on figshare, and can be accessed at datasets. These data include the gene expression data for all three gene sets and all four normalizations, and the accompanying supervised labels.
CODE AVAILABILITY
The code used to process the data and train the unsupervised, supervised, and semi-supervised models is a combination of open source and proprietary, licensable software. The open source software made available on GitHub, accessible at representation learning for transcriptomics, is a framework to carry out nested cross validation and includes an example to train the recommended model on a particular supervised task.
AUTHOR CONTRIBUTIONS
D.Z. and C.K.F. conceived of the analysis. A.M.S., J.R.W., and J.L. carried out the analysis. A.M.S., J.R.W., and C.K.F. wrote the manuscript. All authors participated in the design of the analysis, the interpretation of its results, and revisions to the manuscript.
COMPETING INTERESTS
A.M.S, J.R.W, and C.K.F are affiliated with Unlearn.AI, Inc., a company that creates software for clinical research, and hence may have competing financial interests. J.L., C.D., P.H., M.H., M.M., J.M., S.R., S.Z., and D.Z. are employees of Pfizer and may own stock or stock options in Pfizer.
Footnotes
↵∗ drams{at}unlearn.ai
References
- [1].↵
- [2].↵
- [3].↵
- [4].↵
- [5].↵
- [6].↵
- [7].↵
- [8].↵
- [9].↵
- [10].↵
- [11].↵
- [12].↵
- [13].↵
- [14].↵
- [15].↵
- [16].↵
- [17].↵
- [18].
- [19].
- [20].↵
- [21].↵
- [22].↵
- [23].↵
- [24].↵
- [25].↵
- [26].↵
- [27].↵
- [28].↵
- [29].↵
- [30].↵
- [31].↵
- [32].↵
- [33].
- [34].↵
- [35].↵
- [36].
- [37].↵
- [38].↵
- [39].
- [40].↵
- [41].↵
- [42].↵
- [43].
- [44].
- [45].
- [46].
- [47].↵
- [48].↵
- [49].↵
- [50].↵
- [51].↵
- [52].↵
- [53].↵
- [54].↵
- [55].↵
- [56].↵
- [57].↵
- [58].↵