论文 | An Efficient Convolutional Neural Network for Coronary Heart Disease Prediction

    技术2024-06-12  75

    文章目录

    论文信息ABSTRACT问题解决方法效果 1. Introduction2. Data Processing3. Proposed Architecture3.1 LASSO Shrinkage and Majority Voting3.2 CNN Architecture3.2.1 Training Schedule 4. Results4.1 Summary Statistics4.2 Model Results4.3. Comparison of ML models4.3.1. Comparison with state-of-the-art ML models4.3.2. Our LASSO-CNN vs vanilla CNN 3.3 Data Augementation3.3.1 Data augmentation3.3.2 Data undersampling4.3.3 Data oversampling strategies4.3.4. Data undersampling strategies 4.4. Validation on stroke data4.5 Notes on the resilience to data imbalance 5. Conclusion, Limitations and Future Research5.1 Conclusion5.2 Limitation5.3 Future Research

    论文信息

    Title :An Efficient Convolutional Neural Network for Coronary Heart Disease Prediction Journal : Expert Systems with Applications Year: 2020 Author : Aniruddha Dutta, Tamal Batabyal, Meheli Basu, Scott T. Acton

    ABSTRACT

    问题

    目前大多数的机器学习模型对类不平衡数据很敏感,即使调整了特定类的权重。

    解决

    本研究提出一种带卷积层的神经网络模型,对类不平衡的临床数据 - 冠心病进行分类。

    方法

    特征选择:使用基于最小绝对收缩和选择算子(LASSO)进行特征权重评估,并基于多数投票法对重要特征识别。

    模型训练:模型训练过程中,通过使用 fully connected layer 来均质化重要的特征,这是将层的输出传递到连续卷积层之前的关键步骤。

    此外还提出每个 epoch 的 training schedule,类似于模拟退火过程,以提高分类精度。

    效果

    NHANES 数据集存在较高的类别不平衡问题,本文提出的CNN体系结构在正确分类存在冠心病方面的分类能力为77%,在测试数据上准确分类冠心病病例的能力为81.8%,占总数据集的85.70%。 这一结果表明,本文建议的体系结构可以推广到具有类似特征和不平衡顺序的医疗保健领域的其他研究。

    1. Introduction

    Our architecture is simple in design, elegant in concept, sophisticated in training schedule, effective in outcome with far-reaching applicability in problems with unbalanced datasets.

    Contributions:

    our model uses a variable elimination technique using LASSO and feature voting as preprocessing steps;we leverage a shallow neural network with convolutional layers, which improves CHD prediction rates compared to existing models with comparable subjects (the ‘shallowness’ is dictated by the scarcity of class-specific data to prevent overfitting of the network during training);in conjunction with the architecture, we propose a simulated annealing-like training schedule that is shown to minimize the generalization error between train and test losses.

    2. Data Processing

    数据集来自1999-2000年至2015-2016年的NHANES数据。 由37,079名 (冠心病-1300人,非冠心病-35,779人) 的人口统计、检查、实验室和问卷数据组合而成,如图1所示。 Fig. 1 Data compilation from National Health and Nutritional Survey (NHANES). The data is acquired from 1999 to 2016 in three categories – Demography, Examination and Laboratory. Based on the nature of the factors that are considered, the dataset contains both the quantitative and the qualitative variables.

    总共使用了 30 个连续变量和 6 个分类变量来预测冠心病。

    详细列出的变量有:性别、年龄、家庭年收入、贫困家庭收入比、60秒脉率、收缩压、舒张压、体重、身高、体重指数、白细胞、淋巴细胞、单核细胞、嗜酸性粒细胞、嗜碱性粒细胞、红细胞、血红蛋白、平均细胞体积、血红蛋白平均浓度、血小板计数、血小板平均体积、中性粒细胞、红细胞压积、红细胞宽度、白蛋白、碱性磷酸酶(Alp)。乳酸脱氢酶(LDH)、磷、胆红素、蛋白质、尿酸、甘油三酯、总胆固醇、高密度脂蛋白(HDL)、糖化血红蛋白、剧烈运动、适度运动、健康保险、糖尿病、血液相关糖尿病和血液相关中风。

    3. Proposed Architecture

    3.1 LASSO Shrinkage and Majority Voting

    LASSO 或最小绝对收缩和选择算子是一种回归技术,用于变量选择和正则化,以提高其产生的统计模型的预测精度和可解释性。

    LASSO 是一个二项问题,目标是最小化如下目标函数:

    ∑ i = 1 n ( y i − ∑ j x i j γ j ) 2 + λ ∑ j = 1 p ∣ γ j ∣ \sum_{i=1}^n(y_i - \sum_j x_{ij} \gamma_j)^2 + \lambda \sum_{j=1}^p |\gamma_j| i=1n(yijxijγj)2+λj=1pγj

    λ \lambda λ 是收缩量的调整参数,控制正则化惩罚的强度。 λ = 0 \lambda =0 λ=0,不会消除任何参数。随着 λ \lambda λ 的增加,更多系数被设置为0,并消除。 λ \lambda λ 增加,偏差增加, λ \lambda λ 减小,方差增加。变量(因子)的 γ \gamma γ 值可以解释为变量的重要性,即该变量对数据中潜在变化的贡献。 γ \gamma γ为零的变量被认为不重要。

    为减轻不平衡的影响,采用了对数据集进行随机细分采样和多次迭代 LASSO 的策略。对该组 γ \gamma γ 值执行多数投票,以标识在主要迭代次数中非零的变量。假设在 N N N个随机二次抽样数据集上执行 LASSO N N N 次,其中每个 instance 在CHD和非CHD情况下具有相等数量的samples。

    LASSO的第 i i i 个 instance,得到 γ i = [ γ i , 1 γ i , 2 . . . . . . γ i , 45 ] \gamma_i = [\gamma_{i,1}\gamma_{i,2}......\gamma_{i,45}] γi=[γi,1γi,2......γi,45]

    对于任何变量 𝑐,计算变量为非零的instances的个数,并使用手动设置的阈值来决定选择该变量进行进一步分析。

    3.2 CNN Architecture

    Fig. 2 Proposed convolutional neural network architecture. The ‘Input’ is a 1D numerical array corresponding to all the factors/variables from LASSO-Majority Voting preprocessing stage. The ‘Dense’ layer, immediately after the ‘Input’, combines all the factors and each neuron (computing node) at the output of ‘Dense’ layer is a weighted combination of all the variables, indicating a homogeneous mix of different variable types. The next two convolution layers seek representation of the input variables via the ‘Dense’ layer. The next two ‘Dense’ layers are followed by the ‘Softmax’ layer. The last two ‘Dense’ layers (before the ‘Softmax’ layer) can be retrained for transfer learning in case new data is obtained.

    3.2.1 Training Schedule

    在训练过程中,将类别权重比(class weight ratio)定义为CHD数据集与非CHD数据集的比值,并将其调整为类别不平衡的惩罚因子。什么意思呢?比如分类权重比为10:1的意思是说分错一个 CHD 训练样本要比分错一个非CHD样本多出10倍的惩罚。这个惩罚体现在每一次epoch计算error过程中。

    为了进一步减少可能的过度拟合,我们也使用了这个 training schedule。即,对于足够大的epochs,首先以 1:Ν 的权重比训练模型,然后,随着epochs数的稳步下降,逐渐增大权重比。假设实际的类别权重比为 ρ 0 : 1 \rho_0:1 ρ0:1

    4. Results

    4.1 Summary Statistics

    Fig. 3 Correlation table for the independent predictor variables. In this table, moderately strong correlations among few pairs are observed (Glucose and Glycohemoglobin, Red blood cells and Hemoglobin, ALT and AST, Weight and Body-Mass-Index). Rest of the pairs show fairly low correlation values, implying the variables after the LASSO-Majority voting stage are sufficiently decorrelated.

    30个连续型变量的相关关系如下:

    血清丙氨酸氨基转移酶(ALT)和天冬氨酸转氨酶(AST)之间的相关性很高(0.77)。研究表明AST是预测冠心病的主要危险因素,可作为预测冠心病严重程度的生化标志物。体重指数与体重的相关系数为0.89,接近正常。血红蛋白与红细胞的相关系数为0.74。糖化血红蛋白和葡萄糖之间的高相关值为0.79,虽然血红蛋白与临床公认的冠心病的关联研究有限(Chonchol,2008),但文献(Madjid,2013)中对红细胞在冠心病中的作用进行了很好的研究。已有研究表明,非糖尿病患者的高血糖水平会显著增加患冠心病的风险(Neilson,2006)。蛋白质和白蛋白之间的相关性为0.46。据报道,较低的血清白蛋白水平与心血管死亡率和冠心病的增加水平有关(Shaper,2004),而蛋白质水平较高则会增加冠心病的风险(Clifton,2011)。血清乳酸脱氢酶(LDH)与谷草转氨酶(AST)呈正相关(相关系数为0.41),这与以往的研究结果一致,即活跃人群中LDH的升高与冠心病的低风险有关(Kopel Et)。al.,2012)。

    由于一些相关变量的危险因素及其与冠心病的关联的重要性,LASSO回归正确地确定预测变量。

    4.2 Model Results

    Fig. 4 Model accuracy as a function of majority voting threshold. The threshold value of majority voting affects the classification accuracy of CHD as the selection of this value controls the number of variables that are to be channeled to our CNN model. The smaller is the threshold value, the larger is the set of variables. Based on the training loss, training accuracy and test accuracy, the threshold value between 16.67 (100/6) – 20 (100/5) combining 100 instances of LASSO appears suitable for obtaining balanced per class (CHD and Non-CHD) classification accuracy.

    为了达到训练CNN结构的最佳特征数,特征投票的阈值保持在2到8之间,训练网络的最高准确率为83.17%,训练损失为0.489,阈值特征为6,如图4所示,相应的最高测试准确率为82.32%。LASSO 正则化将三个连续预测变量(体重指数、血糖和丙氨酸氨基转移酶)和一个分类变量的系数降低到零,这被确定为高度相关。在阈值为6的情况下,对不同的二次采样数据集分别训练CNN体系结构。二次抽样以不同的比例进行,从1300:13000(冠心病:非冠心病)开始,增加到1300:4000,如表II所示。相应的测试准确率报告为82.32%。

    不断调整majority voting 阈值得到的模型精度的一张图。多数投票的阈值影响冠心病的分类准确性,因为该值的选择控制了要导入到我们的CNN模型的变量的数量。阈值越小,变量集越大。基于训练损失、训练准确率和测试准确率,结合100个套索实例,阈值在16.67(100/6)-20(100/5)之间似乎适合于获得均衡的每类(CHD和非CHD)分类精度。

    The highest accuracy obtained from training the network is 83.17% with a training loss of 0.489 with a threshold feature of 6 。

    Fig. 5 Training and test accuracies with varying misclassification penalties for class 1 and 0. The minimum difference between training and test accuracies is obtained with a class weight of 3:1 (CHD: Non-CHD) and a training loss = 0.489. The model is trained with a constant optimized learning rate of 0.006 and 60 epochs.

    真阳率(recall)= TP / (TP + FN) = 161 / (161 + 47 ) = 77 % 真阴率(sensitivity)= TN / ( FP + TN ) = 25828 / ( 5743 + 25828 ) = 81 %

    在目前的研究中,我们的目标是让我们的分类器比以前的研究更准确地预测冠心病的存在。 正确预测得冠心病的真阳率为 77%, 而无冠心病的真阴率的为 81%。 本方法证实了我们提出的CNN结构在测试数据上对冠心病病例的正确分类能力为77%,占总数据集的85.70%

    4.3. Comparison of ML models

    4.3.1. Comparison with state-of-the-art ML models

    4.3.2. Our LASSO-CNN vs vanilla CNN

    Vanilla是神经网络领域的常见词汇,比如Vanilla Neural Networks、Vanilla CNN等。Vanilla本意是香草,在这里基本等同于raw。

    Our experimental set up is generalized in the sense that setting N α \frac{N}{\alpha} αN as 0 (or, equivalently α = − ∞ \alpha = -\infty α=) will select all the variables into consideration. Thus, applying vanilla CNN is equivalent to applying the model LASSO (∞) - CNN.

    Using the same subsampled dataset with 3:1 ratio of class samples, vanilla CNN yields 79.42% test accuracy, which is approximately 2% less than the average test accuracy that we obtain on an average by applying LASSO-CNN.

    Although, it seems a marginal improvement, the number of samples in excess that are accurately labeled by our model is 635 (≈31779*0.02).

    这部分的内容是说全部变量进 CNN,得到的 test accuracy 为 79.42%,虽然 LASSO-CNN 只比它高了 2%,但是仍然多正确标记了 635 个样本。

    3.3 Data Augementation

    这部分主要是看不同的过采样和欠采样的方法得到的数据,放到 LASSO-CNN 中,模型的性能会如何变化。

    3.3.1 Data augmentation

    Data augmentation demands attention in the context of data imbalance.

    random oversampling (ROS)synthetic minority over-sampling technique (SMOTE) (Chawla, 2002)adaptive synthetic sampling (ADASYN)

    3.3.2 Data undersampling

    edited nearest neighbor (EDN) (Wilson, 1972),instance hardness threshold (IHT) (Smith et al., 2014)three versions of near- miss (NM-v1, v2 and v3) (Mani, 2003).

    4.3.3 Data oversampling strategies

    Fig. 6 Results using three oversampling techniques for the data augmentation of the minority class. For each technique, the results provide training accuracy, test accuracy, training loss, CHD accuracy (class-specific) and no-CHD accuracy (class-specific) over a number of epochs, added with t-SNE low-dimensional embedding for data visualization in 3D. (a) t-SNE visualization of 90% of the original data used for training. 10% of the data is reserved for testing. (b) The results using random oversampling (ROS). Note that we did not provide the t-SNE visualization for ROS as, in ROS, data samples from the minority class are randomly picked and added to the data, thereby maintaining the same data with redundant samples. So, the visualization is same as the original data in (a). ©-(d) Results using SMOTE with visualization. (e)-(f) Results using ADASYN with visualization.

    4.3.4. Data undersampling strategies

    Fig. 7 t-SNE visualization of the five undersampling techniques for the data reduction of the majority class. (a), (e) and (f) Near-miss using k-nearest neighbor (version-1,2 and 3). (b) Random subsampling with 3:1 as no-CHD: CHD data samples (one instance). © Edited nearest neighbor (EDN). (d) Instance hardness threshold (IHT).

    Fig. 8 Results using the undersampling algorithms from fig. 6 (a)-(e). (a) CHD detection accuracies over epochs using the algorithms. (b) Detection accuracies of no-CHD test data. Among all the competitive undersampling strategies that we compare our results with, near-miss (version 3) works better in improving both class-specific accuracies.

    4.4. Validation on stroke data

    验证提出的模型在类似的类不平衡的医疗数据上的预测效果。采用临床数据的中风数据集。

    It might appear that sequentially arranged, multiple convolutional layers in our proposed model offer resistance to data imbalance only for the CHD data provided by NHANES. To check whether our model is resilient to other imbalanced datasets containing 1D measurement variables, we apply our network on a similar dataset on Stroke, which is also compiled by NHANES. The Stroke dataset contains 37,177 subjects and 36 mixed-type measurement variables. Out of 37177 subjects, there are 1269 subjects who reported that they had strokes. After applying LASSO, we found 34 variables which are important for further processing.

    4.5 Notes on the resilience to data imbalance

    表七表明,如果训练得当,MLP对多数类的预测准确率有所提高,但这不幸地影响了少数类的准确性。MLP-I的分类精度相差27.14%。与MLP-I相比,MLP-II具有更深的层,其准确率差异似乎有5.73%。然而,这是在仔细训练之后实现的,而且有很大的机会过度拟合,因为MLP-II的可训练参数的数量为45,442个,这大约是输入数据量的9倍。

    我们架构的参数总数为32066个,明显高于Conv-I、II和III,但略低于MLP-I和II,如此多的参数是由于密集层(64和512)的存在。

    C2-C4试图最小化分类精度之间的差异,而 Dense 层则试图提高总体测试精度。

    5. Conclusion, Limitations and Future Research

    5.1 Conclusion

    针对 class-imbalanced 的数据集提出浅层的卷积神经网络模型。效果是得到了比现有的机器学习模型更好的分类精度。模型特点:“simple in concept, modular in design, and offers moderate resilience to data imbalance.”应用LASSO,确定变量的重要性,采用 Majority Voting 进行特征选择。实验结果表明 shallow convolutional neural network 在类不平衡数据上表现较好。在37,079例冠心病与非冠心病存在高度失衡的样本中,是冠心病病例的预测准确率为77.3%,预测非冠心病病例的准确率为81.8%。

    5.2 Limitation

    A potential problem might arise while using LASSO due to the linear nature of the estimator.

    LASSO利用因素之间的偏相关(partial correlation)来满足输出响应的相关预测,而互相关(cross correlation)计算每对因素之间的边际相关性,这可能不总是单调和线性的。

    输入因素和输出标签之间的线性关系的假设可能有一些相应的限制,并且因素的数量可能比当前数据中存在的数量大得多。

    进一步改进的方法将是两步非线性的降维,其中,可以使用诸如确定独立筛选(SIS)、条件SIS或graphical LASSO之类的技术来近似因素之间的偏相关/协方差。合适的 threshold 可以利用 LASSO 来降维。

    5.3 Future Research

    未来可能方向:将NHANES记录的营养和饮食数据作为冠心病预测的额外预测变量。

    饮食因素在冠心病的发生中起着重要作用(Masironi,1970,Bupathiraju,2011),可以通过纳入额外的饮食变量来探索预测冠心病的准确性。

    可以用作迁移学习模型,最后两个 Dense 层可以针对新数据进行重新训练。

    未来一个重要的研究方向将是将CNN用于类似临床数据集的这些存在类不平衡的数据进行预测。

    Processed: 0.012, SQL: 10