Cox-nnet: an artificial neural network method for prognosis prediction on high-throughput omics data

Artificial neural networks (ANN) are computing architectures with massively parallel interconnections of simple neurons and have been applied to biomedical fields such as imaging analysis and diagnosis. We have developed a new ANN framework called Cox-nnet to predict patient prognosis from high throughput transcriptomics data. Across 10 TCGA RNA-Seq datasets, Cox-nnet achieves a statistically significant increase in predictive accuracy compared to three other methods: Cox-proportional hazards (Cox-PH), Random Forests Survival and CoxBoost. Cox-nnet also reveals richer biological information, at both the pathway and gene levels. The outputs from the hidden layer nodes can provide a new approach for survival-sensitive dimension reduction. In summary, we have developed a new method for more accurate and efficient prognosis prediction on high throughput data, with functional biological insights. The source code is freely available at github.com/lanagarmire/cox-nnet.


Introduction
First, in comparison with the other methods mentioned above (Cox-PH, RF-S and CoxBoost), Cox-nnet has better overall predictive accuracy. It is also optimized on graphics processing unit (GPU) with at least an order of magnitude of computational speed-up over the central processing unit (CPU), making it a compelling new tool to predict disease prognosis in the era of precision medicine. Second, Cox-nnet utilizes feature importance scores based on the partial derivatives of gene features selected by the model, so that the relative importance of the genes to prognosis outcome can be directly assessed. Third, the hidden layer node structure in ANN can be harnessed to reveal much richer information on feature genes and biological pathways, compared to the Cox-PH method. Overall, Cox-nnet is a desirable survival analysis method with both excellent predictive power and the ability to yield insight into biological functions related to prognosis.

The Cox model
The Cox-PH model is a log-linear model that estimates individual hazard, i.e., an instantaneous measure of the likelihood of an event, based on a set of features. The hazard for patient $i$ with feature vector $x_i$ is given by the equations:

$$h(t \mid x_i) = h_0(t)\exp(\theta_i) \tag{1}$$

$$\theta_i = x_i^{\top}\beta \tag{2}$$

where $h_0(t)$ is the baseline hazard, $\beta$ is the coefficient vector and $\theta_i$ is the log hazard ratio for patient $i$. The partial likelihood is represented by the following formula:

$$pl(\beta) = \prod_{C(i)=1} \frac{\exp(\theta_i)}{\sum_{t_j \ge t_i} \exp(\theta_j)} \tag{3}$$

where $t_i$ is the survival time of patient $i$ and $C(i)$ is the censoring status of the patient, with $C(i) = 0$ if the patient was censored or 1 if the patient died or had a recurrence event, etc. The negative partial log-likelihood is used as the cost function:

$$\mathrm{Cost}(\beta) = -\log pl(\beta) = -\sum_{C(i)=1} \left( \theta_i - \log \sum_{t_j \ge t_i} \exp(\theta_j) \right) \tag{4}$$

In a Cox model with L2 ridge regression, a penalty term is added which is proportional to the L2 norm of the coefficients. The cost function is minimized to find the best coefficients for the model:

$$\mathrm{Cost}(\beta) = -\log pl(\beta) + \lambda \lVert \beta \rVert_2 \tag{5}$$

where the tuning parameter $\lambda$ is chosen to maximize the cross-validated (CV) performance metric.

The cross-validated performance metric may be Harrell's concordance index (C-index) [10] or the "cross-validated partial likelihood" [11]. Since the contribution of each patient in the partial likelihood is determined only in the context of all the other patients, the cross-validated partial likelihood is calculated by subtracting the partial likelihood of the training sub-samples from that of the full dataset in the CV. In the k-th iteration of a K-fold CV, the optimal coefficients $\hat{\beta}_{-k}$ are found by minimizing the cost function on the training sub-samples. If $l_{-k}(\hat{\beta}_{-k})$ is the partial log-likelihood of the training sub-samples, and $l(\hat{\beta}_{-k})$ is the partial log-likelihood of the full dataset, then the cross-validated partial likelihood is the sum of differences:

$$\mathrm{CVPL} = \sum_{k=1}^{K} \left[ l(\hat{\beta}_{-k}) - l_{-k}(\hat{\beta}_{-k}) \right] \tag{6}$$

ANN extension of Cox regression
The ANN extension of Cox regression (Cox-nnet) is a neural network whose output layer is replaced by a Cox model. In a Cox-nnet model with one input layer of $J$ input features and one hidden layer composed of $H$ hidden nodes, the linear predictor is replaced by the outputs of the hidden layer:

$$\theta_i = G(W x_i + b)^{\top}\beta \tag{7}$$

where $W$ is the coefficient weight matrix between the input and hidden layer with the size $H \times J$, $b$ is the bias term for each hidden node and $G$ is the activation function (applied element-wise on a vector). Subsequently, the ridge regression cost function is modified to:

$$\mathrm{Cost}(\beta, W) = -\log pl(\beta, W) + \lambda \left( \lVert \beta \rVert_2 + \lVert W \rVert_2 \right) \tag{8}$$

In this manuscript, the tanh activation function is used, as it results in faster training time compared to the sigmoid activation [12]. The tanh function is:

$$\tanh(z) = \frac{e^{z} - e^{-z}}{e^{z} + e^{-z}} \tag{9}$$

In addition to ridge regularization, we also employ dropout regularization [13]. In this approach, nodes are removed during each training iteration with probability 1-p. During evaluation, outputs from the nodes are multiplied by p. The optimal dropout parameter, p, is determined through cross-validation on the training set. Dropout regularization has been shown to reduce overfitting and improve performance over other regularization schemes [13].

The source code of cox-nnet can be found at: https://github.com/lgarmire/cox-nnet, and can be installed through the Python Package Index (PyPI). Documentation of the package can be found at http://lgarmire.github.io/cox-nnet/docs.
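The forward pass of a one-hidden-layer Cox-nnet with tanh activation and dropout can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the package's Theano implementation: the function name `cox_nnet_forward` and its arguments are hypothetical, and `p_keep` plays the role of the dropout parameter p (nodes are dropped with probability 1-p during training, and node outputs are multiplied by p at evaluation).

```python
import numpy as np

def cox_nnet_forward(X, W, b, beta, p_keep=1.0, train=False, rng=None):
    """Return the log hazard ratio (PI) of each patient.

    X: (N, J) expression matrix, W: (H, J) hidden weights,
    b: (H,) hidden biases, beta: (H,) Cox output-layer coefficients.
    All names are illustrative, not the cox-nnet package API.
    """
    hidden = np.tanh(X @ W.T + b)  # (N, H) hidden node outputs
    if train:
        if rng is None:
            rng = np.random.default_rng()
        # Keep each node with probability p_keep, drop it otherwise.
        mask = rng.random(hidden.shape) < p_keep
        hidden = hidden * mask
    else:
        # At evaluation time, scale node outputs by p_keep instead.
        hidden = hidden * p_keep
    return hidden @ beta  # (N,) log hazard ratios


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))          # 5 patients, 8 genes (toy data)
W = rng.normal(scale=0.1, size=(3, 8))
b = np.zeros(3)
beta = rng.normal(size=3)
theta = cox_nnet_forward(X, W, b, beta, p_keep=0.5)
```

Scaling by p at evaluation keeps the expected input to the Cox layer the same as during training, which is why the two branches differ.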

We implement Cox-nnet using a feed-forward, back-propagation network with gradient descent. The partial log likelihood is usually written as a double conditional sum (equation 4). To avoid the computational inefficiency of calculating it with two nested for loops, we convert it into a formulation of matrix operations and basic sums. First we define an indicator matrix $R$ with elements:

$$R_{ij} = \begin{cases} 1 & \text{if } t_j \ge t_i \\ 0 & \text{otherwise} \end{cases} \tag{10}$$

We also define an indicator vector $C$ with elements given by the censoring status of each patient. An operation using $R$ replaces the conditional sum over $t_j \ge t_i$, and an operation using $C$ replaces the conditional sum over $C(i) = 1$ in equation 4. In Theano, the partial log likelihood is:

    pl = T.sum((theta - T.log(T.sum(T.exp(theta) * R, axis=1))) * C)    (11)
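To make the equivalence between the nested-loop form and the matrix form concrete, the following NumPy sketch computes the partial log likelihood both ways on toy data. It assumes the indicator matrix marks, for each patient i, the patients j whose survival time is at least that of patient i; the function names and the `time` and `event` arrays are illustrative and not part of the cox-nnet package.

```python
import numpy as np

def partial_loglik_loops(theta, time, event):
    """Partial log likelihood as a double conditional sum (nested loops)."""
    total = 0.0
    for i in range(len(theta)):
        if event[i] == 1:                      # sum only over observed events
            risk = 0.0
            for j in range(len(theta)):
                if time[j] >= time[i]:         # patients still at risk
                    risk += np.exp(theta[j])
            total += theta[i] - np.log(risk)
    return total

def partial_loglik_matrix(theta, time, event):
    """Same quantity using an indicator matrix R and indicator vector C."""
    # R[i, j] = 1 if t_j >= t_i, else 0
    R = (time[None, :] >= time[:, None]).astype(float)
    # exp(theta) broadcasts along rows, so row i sums exp(theta_j) over the risk set of i
    return np.sum((theta - np.log(np.sum(np.exp(theta) * R, axis=1))) * event)


rng = np.random.default_rng(0)
theta = rng.normal(size=50)                    # toy log hazard ratios
time = rng.exponential(scale=10.0, size=50)    # toy survival times
event = rng.integers(0, 2, size=50)            # 1 = event, 0 = censored
pl_loops = partial_loglik_loops(theta, time, event)
pl_matrix = partial_loglik_matrix(theta, time, event)
```

The matrix form trades O(N^2) memory for the removal of Python-level loops, which is also what makes it straightforward to express in a tensor library.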

To evaluate the performance of all methods in comparison, we trained each model on 80% of the samples for each dataset (chosen randomly) and evaluated the performance on the 20% holdout test set. The outputs of Cox-PH, Cox-nnet and CoxBoost are the log hazard ratios (i.e., Prognosis Index, or PI) for each patient. The hazard ratio describes the relative risk of a patient compared to a non-parametric baseline. On the other hand, the output of RF-S is an estimation of the survival time for each patient.

We use the C-index and the log-rank p-value based on dichotomization of the holdout test data to measure the performance of each model. The C-index is a measure of how well the model prediction corresponds to the ranking of the survival data [14]. It is calculated for censored survival data and takes a value between 0 and 1, with 0.5 equivalent to a random process. The C-index can be computed as a summation over all events in the dataset, whereby patients with a higher survival time and lower log hazard ratios (and conversely patients with a lower survival time but higher log hazard ratios) are considered concordant. To calculate the log-rank p-value, a PI cutoff threshold is used to dichotomize the patients in the dataset into higher and lower risk groups, similar to our earlier reports [15,16]. A log-rank p-value is then computed to differentiate the Kaplan-Meier survival curves between the higher vs. lower risk groups. In this report, we used the median log hazard ratio as the cutoff threshold.
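The C-index described above can be sketched as a direct pairwise count. This is a simple O(N^2) illustration that assumes the usual convention for censored data (a pair is comparable only when the patient with the shorter time had an observed event, and ties in predicted risk count one half); the function name is hypothetical and this is not necessarily the implementation used in the evaluation.

```python
import numpy as np

def concordance_index(time, event, theta):
    """Fraction of comparable patient pairs ranked consistently by the model.

    time: survival or censoring times, event: 1 = event, 0 = censored,
    theta: predicted log hazard ratios (higher means higher risk).
    Raises ZeroDivisionError if no pair is comparable.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if event[i] != 1:
            continue                       # censored patients cannot anchor a pair
        for j in range(n):
            if time[j] > time[i]:          # j outlived i, whose event was observed
                comparable += 1
                if theta[i] > theta[j]:
                    concordant += 1.0      # shorter survival, higher predicted risk
                elif theta[i] == theta[j]:
                    concordant += 0.5      # tie in predicted risk
    return concordant / comparable


time = np.array([2.0, 5.0, 7.0, 11.0])
event = np.array([1, 1, 0, 1])
c_perfect = concordance_index(time, event, np.array([4.0, 3.0, 2.0, 1.0]))
c_reversed = concordance_index(time, event, np.array([1.0, 2.0, 3.0, 4.0]))
```

A model that ranks patients at random would score close to 0.5 on this measure, matching the interpretation given in the text.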

Feature evaluation
For computing the importance of a feature in Cox-nnet, we use a method of partial derivatives (PaD) [17,18]. For each patient, we compute the partial derivative of the linear output of the model (i.e., the log hazard ratio) with respect to each input. The average of the partial derivatives for each input across all patient samples is calculated as the feature score.
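For a one-hidden-layer network with tanh activation, these partial derivatives have a closed form, which the following NumPy sketch uses. It is an illustration under that architectural assumption; the function name `pad_feature_scores` and the toy parameters are hypothetical, and the package itself may obtain the gradients differently (for example through automatic differentiation).

```python
import numpy as np

def pad_feature_scores(X, W, b, beta):
    """Average partial derivative of the log hazard ratio w.r.t. each input.

    X: (N, J) inputs, W: (H, J) hidden weights, b: (H,) biases,
    beta: (H,) Cox output-layer coefficients.
    """
    hidden = np.tanh(X @ W.T + b)              # (N, H)
    # d theta / d z_h = beta_h * (1 - tanh(z_h)^2), by the chain rule
    grad_hidden = (1.0 - hidden ** 2) * beta   # (N, H)
    # d theta / d x_j = sum_h grad_hidden[:, h] * W[h, j]
    grads = grad_hidden @ W                    # (N, J), one row per patient
    return grads.mean(axis=0)                  # (J,) feature scores


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))    # toy data: 6 patients, 4 genes
W = rng.normal(size=(3, 4))
b = rng.normal(size=3)
beta = rng.normal(size=3)
scores = pad_feature_scores(X, W, b, beta)
```

Because the derivatives are signed, a feature whose effect changes direction across patients can average toward zero, so the sign and magnitude of a score should be read together.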

In order to evaluate the performance of Cox-nnet, we analyzed 10 TCGA datasets which were combined into a pan-cancer dataset. The TCGA datasets included the following cancer types: Bladder Urothelial [...] and censoring information were extracted from the clinical follow-up data. Raw count data were normalized using the DESeq2 R package [20] and then log-transformed. Datasets were selected from TCGA based on the following criteria: > 300 samples with both RNA-Seq and survival data and > 50 survival events. In total, 5031 patient samples were used (see Table S1 for a patient tabulation by individual dataset).

Cox-nnet structure and optimization
Cox-nnet is the neural network extension of the Cox-PH model. We created a package suitable for high dimensional datasets using the Theano math library in Python. The neural network model used in this paper is shown in Figure 1 and an overview of modules in the Cox-nnet package is shown in Figure S1. As a proof of concept, the current ANN architecture is composed of three layers: one input layer, one fully connected hidden layer and an output "Cox regression" layer. The output layer of Cox-nnet replaces the linear predictors in the standard Cox-PH model. Several functions are implemented to improve the usability of the package, including CVSearch, CVProfile, CrossValidation, and TrainCoxMlp. CVSearch, CVProfile and CrossValidation are methods that perform CV to find the optimal regularization parameter. TrainCoxMlp performs optimization of coefficients on the regularized partial likelihood function. The optimization strategies include momentum gradient descent [21], Nesterov accelerated gradient [22] and AdaDelta [23]. A comparison of these descent methods is shown in Figure S2A, where the Nesterov accelerated gradient method achieved the best efficiency based on TCGA kidney renal clear cell carcinoma (KIRC) data. Moreover, this package can be run on multiple threads or a Graphics Processing Unit (GPU), and it achieves slightly faster training time compared to Random Forest and CoxBoost (Figure S2B). Thus, Cox-nnet is a modern software implementation that can achieve efficient computational time.
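As a rough illustration of the Nesterov accelerated gradient update mentioned above, the sketch below applies one common formulation of it to a toy quadratic cost. The function name, the learning rate and the momentum value are assumptions chosen for the example, and this is not the update code of the Cox-nnet package.

```python
import numpy as np

def nesterov_descent(grad_fn, w0, lr=0.1, momentum=0.9, n_iter=200):
    """Minimize a cost given its gradient, using Nesterov momentum.

    grad_fn: function returning the gradient at a parameter vector.
    """
    w = np.asarray(w0, dtype=float).copy()
    v = np.zeros_like(w)                       # velocity
    for _ in range(n_iter):
        lookahead = w + momentum * v           # evaluate gradient ahead of w
        v = momentum * v - lr * grad_fn(lookahead)
        w = w + v
    return w


# Toy cost: 0.5 * ||w - target||^2, whose gradient is (w - target).
target = np.array([1.0, -2.0])
w_opt = nesterov_descent(lambda w: w - target, np.zeros(2))
```

The only difference from plain momentum gradient descent is that the gradient is evaluated at the look-ahead point rather than at the current parameters.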

We compared four methods (Cox-nnet, Cox-PH, CoxBoost and RF-S) on 10 datasets from The Cancer Genome Atlas (TCGA), which were selected based on having at least 50 death events (Table S1). For each dataset, we trained the model on 80% of the randomly selected samples and determined the regularization parameter using 5-fold CV on the training set. We used two types of regularization, L2 ridge regularization (also known as weight decay) and dropout regularization. We evaluated the performance on the remaining 20% holdout test set. Two metrics are used to evaluate the performance of the model. The first one is Harrell's concordance index (C-index) calculated for censored survival data [10,24]. It evaluates the relative ordering of the samples and ranges between 0 and 1, with 0.5 equivalent to a random process. The second metric is the log-rank p-value from Kaplan-Meier survival curves of two different survival risk groups. This is done by using the median threshold of the Prognosis Index (PI), the output of Cox-nnet, to dichotomize the patients into higher and lower risk groups, similar to our earlier reports [15,16,24]. A log-rank p-value is then computed to differentiate the Kaplan-Meier survival curves from these two groups.

The comparison of C-indices among the four methods over the 10 TCGA datasets is shown in Figure 2A. Overall, Cox-nnet has higher predictive accuracy than the other three methods, regardless of the regularization method. Cox-PH performs the second best, followed by CoxBoost and RF-S in descending order (Figure 2B). The comparison of log-rank p-values on the dichotomized survival risk groups is shown in Figure S3. Generally, log-rank p-values in the 10 TCGA datasets are more significant in Cox-nnet, compared to other methods. However, the dichotomization of patients ignores the differences within each dichotomized group, thus the resulting log-rank p-values are less consistent than C-indices on the same data.
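The median dichotomization and two-group log-rank test can be sketched as follows. This is a self-contained illustration of the standard log-rank statistic with a chi-square approximation on one degree of freedom; the function names are hypothetical, the data are synthetic, and the analysis in the paper may have relied on an existing statistical package instead.

```python
import math
import numpy as np

def dichotomize_by_median(pi):
    """Return True for patients whose PI is above the median (higher risk)."""
    pi = np.asarray(pi, dtype=float)
    return pi > np.median(pi)

def logrank_p_value(time, event, high_risk):
    """Two-group log-rank test p-value (chi-square, 1 degree of freedom)."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    high_risk = np.asarray(high_risk, dtype=bool)
    observed = expected = variance = 0.0
    for t in np.unique(time[event == 1]):      # loop over distinct event times
        at_risk = time >= t
        n = at_risk.sum()                      # patients at risk overall
        n1 = (at_risk & high_risk).sum()       # at risk in the high-risk group
        died = (time == t) & (event == 1)
        d = died.sum()                         # events at time t overall
        d1 = (died & high_risk).sum()          # events in the high-risk group
        observed += d1
        expected += d * n1 / n
        if n > 1:
            variance += d * (n1 / n) * (1 - n1 / n) * (n - d) / (n - 1)
    # variance is zero (and this fails) if one group is empty at every event time
    chi2 = (observed - expected) ** 2 / variance
    return math.erfc(math.sqrt(chi2 / 2.0))    # survival function of chi-square(1)


# Synthetic example: higher PI corresponds to shorter survival.
time = np.arange(1.0, 21.0)
event = np.ones(20, dtype=int)
pi = -time
groups = dichotomize_by_median(pi)
p_value = logrank_p_value(time, event, groups)
```

Because only group membership enters the test, two patients on the same side of the median contribute identically regardless of how different their PI values are, which is the loss of information noted in the text.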

To explore the biological relevance of the hidden nodes of Cox-nnet, we used the TCGA KIRC dataset as an example. We first extracted the contribution of each hidden node to the PI score for each patient (Figure 3A). The contribution was calculated as the output value of each hidden node weighted by the corresponding coefficient at the Cox regression output layer. As expected, the values of the hidden nodes strongly correlated with the PI score. However, there is still significant heterogeneity among the nodes, suggesting that individual nodes may reflect different biological processes. We hypothesize that the top nodes may serve as surrogate features to discriminate patient survival. To explore this idea, we selected the top 20 nodes with the highest variances, and presented the patients' PI scores using t-SNE, a popular method to enhance the separation among samples [25]. The nodes represent a dimension reduction of the original data and clearly discriminate samples by their PI scores (Figure 3B). In contrast, the top 20 principal components obtained from principal component analysis (PCA) in combination with t-SNE fail to separate the patient samples (Figure 3B). This drastic difference demonstrates that the nodes in Cox-nnet effectively capture the survival information, and the top node PI scores can be used as features for dimension reduction in survival analysis.
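The node contributions and the selection of high-variance nodes can be sketched in NumPy as below. The sketch assumes a tanh hidden layer and assumes that the variance is taken over each node's contribution across patients; the function names and toy parameters are hypothetical, and the subsequent t-SNE embedding is omitted.

```python
import numpy as np

def node_contributions(X, W, b, beta):
    """Per-patient contribution of each hidden node to the PI score.

    Each column is a hidden node output weighted by its Cox-layer
    coefficient, so each row sums to that patient's PI.
    """
    return np.tanh(X @ W.T + b) * beta         # (N, H)

def top_variance_nodes(contrib, k=20):
    """Indices of the k nodes whose contributions vary most across patients."""
    order = np.argsort(contrib.var(axis=0))[::-1]   # descending variance
    return order[:k]


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 12))      # toy data: 30 patients, 12 genes
W = rng.normal(size=(8, 12))       # 8 hidden nodes
b = rng.normal(size=8)
beta = rng.normal(size=8)
contrib = node_contributions(X, W, b, beta)
top_nodes = top_variance_nodes(contrib, k=3)
reduced = contrib[:, top_nodes]    # low-dimensional, survival-oriented features
```

The `reduced` matrix is the kind of input that could then be passed to an embedding method such as t-SNE for visualization.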

To further explore the biological relevance of the top 20 hidden nodes, we conducted Gene Set Enrichment Analysis (GSEA) [26] using KEGG pathways [27]. We calculated significantly enriched pathways using gene correlation to the output score of each node (Figure 3C and Table S2), and compared these enriched pathways to those from GSEA of the Cox-PH model (Table S3). To calculate statistical significance of the pathways, we performed 10,000 permutations, followed by multiple hypothesis testing with Benjamini-Hochberg adjustment. A total of 110 (out of 187) significantly enriched pathways (Table S2) were identified in at least one node, including seven pathways enriched in all 20 nodes that were not found by the Cox-PH method (Table 1). In contrast, Cox-PH only identified 30 significantly enriched pathways using the same significance threshold. Among the seven pathways, the P53 signaling pathway stands out as an important biologically relevant pathway (Figure 4 and Figure S4), since it was shown to be highly prognostic of patient survival in kidney cancer [28].

Next, we estimated the predictive accuracies of the leading edge genes (LEG) enriched in the KEGG pathways from Cox-nnet vs. those enriched in the Cox-PH model. We used the C-index of each LEG, obtained from single-variable analysis (Figure 4). Collectively, LEGs from Cox-nnet have significantly higher C-index scores (p = 5.79e-05) than those from Cox-PH, suggesting that Cox-nnet has selected more informative features. In order to visualize these gene level and pathway level differences between Cox-nnet and Cox-PH, we reconstructed a bipartite graph between LEGs (for Cox-nnet) or feature genes (for Cox-PH) and their corresponding enriched pathways (Figure 5). Besides the P53 pathway mentioned earlier that is specific to Cox-nnet, several other pathways, such as the insulin signaling pathway, endocytosis and adherens junction, also have many more genes enriched in Cox-nnet. Among them, some have been previously reported to be relevant to renal carcinoma development and prognosis, such as CASP9 [29], TGFBR2 [30] and KDR (VEGFR) [31]. These results demonstrate that the Cox-nnet model reveals richer biological information than Cox-PH.

To further examine the importance of each gene relative to the survival outcome, we calculated the averaged partial derivative (PaD) of the linear output of the model (i.e., the log hazard ratio) with respect to each input gene feature over all patients. As demonstrated by the LEGs in seven common pathways of all nodes in Cox-nnet, the feature importance scores produce stronger biological insight (Figure S4). For example, the feature importance for the BAI1 gene in the P53 pathway is much higher in the Cox-nnet model compared to the Cox-PH model. Corresponding to our finding, the BAI gene family was found to be involved in several types of cancers including renal cancer [32-35]. BAI1 acts as an inhibitor of angiogenesis and is transcriptionally regulated by P53 [36]. Its expression level was significantly decreased in tumor vs. normal kidney tissue, and was even lower in advanced stage renal carcinoma [35]. Mouse kidney cancer models treated with BAI1 showed slower tumor growth and proliferation [37]. Additionally, the MAPK1 gene (also known as ERK2) has a much higher feature importance score in Cox-nnet compared to Cox-PH, and is annotated in the Adherens Junction pathway as well as the Insulin Signalling Pathway found by Cox-nnet. MAPK1 is one of the key kinases in intra-cellular transduction, and was found constitutively activated in renal cell carcinoma [38].

In this report, we have implemented Cox-nnet, a new non-linear ANN method, to predict patient survival from high throughput omics data. Cox-nnet is an improved, modern alternative to the standard Cox-PH regression, as demonstrated by increased performance for survival prediction and the capability to explore the biological information more deeply.

First, through in-depth comparison on 10 TCGA RNA-Seq datasets, Cox-nnet achieves overall statistically significant improvements over Cox-PH in predictive accuracy, as measured by C-indices. Interestingly, the ensemble-based method RF-S consistently ranks worse than Cox-nnet and Cox-PH. Because RF-S bootstraps both samples and features for individual trees, many uninformative features in each tree may be chosen for node splitting, particularly in high dimensional datasets, leading to a decrease in overall accuracy [40]. In contrast, the dropout and L2-regularization approach used by both Cox-nnet and Cox-PH can prune out uninformative features.

Second, Cox-nnet can reveal much richer biological information than Cox-PH. This is manifested both at the pathway and gene levels.

As a promising new predictive method for prognosis, the current Cox-nnet implementation has some limitations. Its architecture is a three-layer ANN, and it is possible to incorporate other more sophisticated architectures into the model, such as including more layers of neurons. A convolutional neural network approach using convolutional and pooling layers could also be used, as reported for processing imaging or other types of positional data [41]. Additionally, it is possible to embed a priori biological pathway information into the network architecture, e.g., by connecting genes in a pathway to a common node in the next hidden layer of neurons. In the future, we plan to further analyze how different [...]
Leading edge genes found uniquely in Cox-nnet are labeled in orange, and genes found in both Cox-nnet and Cox-PH are labeled in blue.

Figure S1. An overview of the structure, methods and classes in the Cox-nnet package.