Python实现决策树的预剪枝与后剪枝

网站建设5年前发布
19 0 0

决策树是一种用于分类和回归任务的非参数监督学习算法。它是一种分层树形结构,由根节点、分支、内部节点和叶节点组成。,202303060051587133fdd32e2fef4bfc11389a53a13433c12fbb879,从上图中可以看出,决策树从根节点开始,根节点没有任何传入分支。然后,根节点的传出分支为内部节点(也称为决策节点)提供信息。两种节点都基于可用功能执行评估以形成同类子集,这些子集由叶节点或终端节点表示。叶节点表示数据集内所有可能的结果。,Hunt 算法于 20 世纪 60 年代提出,起初用于模拟心理学中的人类学习,为许多常用的决策树算法奠定了基础,例如:,ID3:该算法的开发归功于 Ross Quinlan,全称为"迭代二叉树 3 代" ("Iterative Dichotomiser 3")。该算法利用信息熵与信息增益作为评估候选拆分的指标。,C4.5:该算法是 ID3 的后期扩展,同样由 Quinlan 开发。它可以使用信息增益或增益率来评估决策树中的切分点。,CART:术语 "CART" 的全称是"分类和回归",提出者是 Leo Breiman。该算法通常利用"基尼不纯度"来确定要拆分的理想属性。基尼不纯度衡量随机选择的属性被错误分类的频率。使用该评估方法时,基尼不纯度越小越理想。,详细的构建过程可以参考:决策树的构建原理,20230306005159d4f9ec178a44a2c0bba339e350c72c5960ca28440,泰坦尼克号数据集,20230306005159e83285125b5f5a93c055274ef27ce0a9a74c74767,数据处理后的数据集,2023030600534785a41b505d3f3fea3305466e6b7e4b89370a48631,幸存者统计,20230306005347e6232f2733933a64392506d5a061a1c3dae790576,观察上图所示的模型结构可以发现,该模型是非常复杂的决策树模型,而且决策树的层数远远超过了10层,从而使用该决策树获得的规则会非常的复杂。通过模型的可视化进一步证明了获得的决策树模型具有严重的过拟合问题,需要对模型进行剪枝,精简模型。,模型在训练集上有74个错误样本,而在测试集上存在50个错误样本。,202303060052021563b0c32c1cacf96f61886e71671328866a2c611,训练数据集混淆矩阵,20230306005202b2bdd60605462416cf5055f5d6333d09a48bc0809,测试数据集混淆矩阵,观察图1所示的模型结构可以发现,该模型是非常复杂的决策树模型,而且决策树的层数远远超过了10层,从而使用该决策树获得的规则会非常的复杂。通过模型的可视化进一步证明了获得的决策树模型具有严重的过拟合问题,需要对模型进行剪枝,精简模型。,决策树学习采用"一一击破"的策略,执行贪心搜索 (greedy search) 来识别决策树内的最佳分割点。然后以自上而下的回归方式重复此拆分过程,直到所有或者大多数记录都标记为特定的类别标签。是否将所有数据点归为同类子集在很大程度上取决于决策树的复杂性。较小的决策树更容易获得无法分裂的叶节点,即单个类别中的数据点。然而,决策树的体量越来越大,就会越来越难保持这种纯度,并且通常会导致落在给定子树内的数据过少。这种情况被称为数据碎片,通常会引起数据过拟合。,因此通常选择小型决策树,这与奥卡姆剃刀原理的"简单有效原理"相符,即"如无必要,勿增实体"。换句话说,我们应该只在必要时增加决策树的复杂性,因为最简单的往往是最好的。为了降低复杂性并防止过拟合,通常采用剪枝算法;这一过程会删除不太重要的特征的分支。然后,通过交叉验证评估模型的拟合。另一种保持决策树准确性的方法是使用随机森林算法形成一个集合;这种分类法可以得到更加准确的预测结果,特别是在决策树分支彼此不相关的情况下。,决策树的剪枝有两种思路:,预剪枝就是在构造决策树的过程中,先对每个结点在划分前进行估计,如果当前结点的划分不能带来决策树模型泛化性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。,所有决策树的构建方法,都是在无法进一步降低熵的情况下才会停止创建分支的过程,为了避免过拟合,可以设定一个阈值,熵减小的数量小于这个阈值,即使还可以继续降低熵,也停止继续创建分支。但是这种方法实际中的效果并不好。,决策树模型的剪枝操作主要会用到DecisionTreeClassifier()函数中的,这里使用参数网格搜索的方式,对该模型中的四个参数进行搜索,并通过该在验证集上的预测精度为准测,获取较合适的模型参数组合。,20230306005203d3ab46790acea2a6ece634e73cb6167790189b903,预剪枝后决策树,从剪枝后决策树模型中可以发现:该模型和未剪枝的模型相比已经大大的简化了。模型在训练集上有95个错误样本,但在测试集上只存在47个错误样本。,202303060052032623f6274813528ed8076108addc0b2ee842e4221,训练数据集混淆矩阵,2023030600520327965e4460eac08c9a47260b8a00456be2e126301,测试数据集混淆矩阵,决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝是目前最普遍的做法。,后剪枝的剪枝过程是删除一些子树,然后用其叶子节点代替,这个叶子节点所标识的类别通过大多数原则(majority class criterion)确定。所谓大多数原则,是指剪枝过程中, 将一些子树删除而用叶节点代替,这个叶节点所标识的类别用这棵子树中大多数训练样本所属的类别来标识,所标识的类 称为majority class ,(majority class 在很多英文文献中也多次出现)。,后剪枝算法有很多种,这里简要总结如下:,Reduced-Error Pruning (REP,错误率降低剪枝),这个思路很直接,完全的决策树不是过度拟合么,我再搞一个测试数据集来纠正它。对于完全决策树中的每一个非叶子节点的子树,我们尝试着把它替换成一个叶子节点,该叶子节点的类别我们用子树所覆盖训练样本中存在最多的那个类来代替,这样就产生了一个简化决策树,然后比较这两个决策树在测试数据集中的表现,如果简化决策树在测试数据集中的错误比较少,那么该子树就可以替换成叶子节点。该算法以bottom-up的方式遍历所有的子树,直至没有任何子树可以替换使得测试数据集的表现得以改进时,算法就可以终止。,Pessimistic Error Pruning (PEP,悲观剪枝),PEP剪枝算法是在C4.5决策树算法中提出的, 把一颗子树(具有多个叶子节点)用一个叶子节点来替代(我研究了很多文章貌似就是用子树的根来代替)的话,比起REP剪枝法,它不需要一个单独的测试数据集。,PEP算法首先确定这个叶子的经验错误率(empirical)为(E+0.5)/N,0.5为一个调整系数。对于一颗拥有L个叶子的子树,则子树的错误数和实例数都是就应该是叶子的错误数和实例数求和的结果,则子树的错误率为e,然后用一个叶子节点替代子树,该新叶子节点的类别为原来子树节点的最优叶子节点所决定,J为这个替代的叶子节点的错判个数,但是也要加上0.5,即KJ+0.5。最终是否应该替换的标准为,被替换子树的错误数-标准差 > 新叶子错误数,出现标准差,是因为子树的错误个数是一个随机变量,经过验证可以近似看成是二项分布,就可以根据二项分布的标准差公式算出标准差,就可以确定是否应该剪掉这个树枝了。子树中有N的实例,就是进行N次试验,每次实验的错误的概率为e,符合 B(N,e) 的二项分布,根据公式,均值为Ne,方差为Ne(1-e),标准差为方差开平方。,Cost-Complexity Pruning(CCP,代价复杂度剪枝),在决策树中,这种剪枝技术是由代价复杂性参数ccp_alpha来参数化的。ccp_alpha值越大,剪枝的节点数就越多。简单地说,代价复杂性是一个阈值。只有当模型的整体不纯度改善了一个大于该阈值的值时,该模型才会将一个节点进一步拆分为其子节点,否则将停止。,当CCP值较低时,即使不纯度减少不多,该模型也会将一个节点分割成子节点。随着树的深度增加,这一点很明显,也就是说,当我们沿着决策树往下走时,我们会发现分割对模型整体不纯度的变化没有太大贡献。然而,更高的分割保证了类的正确分类,即准确度更高。,当CCP值较低时,会创建更多的节点。节点越高,树的深度也越高。,下面的代码(Scikit Learn)说明了如何对alpha进行调整,以获得更高精度分数的模型。,20230306005347342e6d546465cdc35916187519ade40c2fdd3b956,2023030600571452c2090112600209daa062c0820f35ddca56f3196,从结果可知如果alpha设置为0.04得到的测试集精度最好,我们将从新训练模型。,20230306005205014bb276495bc58a8cc0763f4a6d4c4039b5be306,后剪枝后决策树,可以看到,模型深度非常浅,也能达到很好的效果。模型在训练集上有140个错误样本,但在测试集上只存在54个错误样本。,20230306005350654adf0712742a93ee983366aa63d458d18f53360,训练数据集混淆矩阵,20230306005206c6aabf73936a97c9c5e519c5e7ca38950d2d69706,测试数据集混淆矩阵

© 版权声明

相关文章