Which feature the current node uses as its split criterion depends on the class purity of the child datasets after the split: the purer the resulting data, the better the split. So how do we measure purity? There are three common metrics.
1. Entropy (for classification)

Entropy is defined as:

$$\mathrm{Entropy} = -\sum_{i=1}^{C} f_i \log_2 f_i$$

where $f_i$ is the frequency of the $i$-th label within the partition and $C$ is the total number of classes in the partition. Entropy measures how mixed the class labels are; the smaller the entropy, the purer the data.
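For example, a node split 50/50 between two classes has entropy $-(0.5 \log_2 0.5 + 0.5 \log_2 0.5) = 1$, the maximum for two classes, while a node containing only one class has entropy $0$.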
2. Gini (for classification)

Gini is defined as:

$$\mathrm{Gini} = \sum_{i=1}^{C} f_i (1 - f_i)$$

where $f_i$ is the frequency of the $i$-th label within the partition and $C$ is the total number of classes in the partition. The Gini index measures the probability that an instance would be misclassified; the smaller the Gini index, the purer the data.
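For example, the same 50/50 two-class node has $\mathrm{Gini} = 0.5 \times (1 - 0.5) + 0.5 \times (1 - 0.5) = 0.5$, the maximum for two classes, while a pure node has $\mathrm{Gini} = 0$.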
3. Variance (for regression)

Variance is defined as:

$$F = \frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$

where $y_i$ is the label of the $i$-th instance, $N$ is the total number of instances, and $\mu$ is the mean of all the labels. This one is easy to understand: the smaller the variance, the purer the data.
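To make the three metrics concrete, here is a minimal standalone Scala sketch (plain Scala, no Spark required; the function names are my own, not a Spark API) that computes each of them on toy inputs:

// Minimal illustrative sketch: the three impurity metrics on toy data.
object ImpurityDemo {
  // Entropy and Gini take class frequencies that sum to 1.
  def entropy(freqs: Seq[Double]): Double =
    freqs.filter(_ > 0).map(f => -f * math.log(f) / math.log(2)).sum

  def gini(freqs: Seq[Double]): Double =
    freqs.map(f => f * (1 - f)).sum

  // Variance takes raw regression labels.
  def variance(labels: Seq[Double]): Double = {
    val mu = labels.sum / labels.size
    labels.map(y => (y - mu) * (y - mu)).sum / labels.size
  }

  def main(args: Array[String]): Unit = {
    println(entropy(Seq(0.5, 0.5)))   // 1.0  (most impure two-class node)
    println(entropy(Seq(1.0)))        // 0.0  (pure node)
    println(gini(Seq(0.5, 0.5)))      // 0.5  (most impure two-class node)
    println(gini(Seq(1.0)))           // 0.0  (pure node)
    println(variance(Seq(0.0, 2.0)))  // 1.0
    println(variance(Seq(1.0, 1.0)))  // 0.0  (pure node)
  }
}

Spark computes these metrics internally during training; the sketch is only meant to show how the formulas behave at the extremes.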
III. Decision Tree in Practice

Data download: GitHub machine learning data download. The first column is the label; the remaining columns are the features. Perhaps because it is the same dataset, several different models all train to an accuracy of 93.75%.

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

object WordCount {
  def main(args: Array[String]): Unit = {
    // Build the Spark context
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val conf = new SparkConf().setAppName("HACK-AILX10").setMaster("local")
    val sc = new SparkContext(conf)

    // Load the sample data (LIBSVM format) and split it 60/40 into training and test sets
    val datapath = "C:\\study\\spark\\sample_libsvm_data.txt"
    val data = MLUtils.loadLibSVMFile(sc, datapath)
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 1L)
    val training = splits(0)
    val testing = splits(1)

    // Build the decision tree and set the training parameters:
    // 2 classes, no categorical features, Gini impurity, maxDepth = 5, maxBins = 32
    val model = DecisionTree.trainClassifier(training, 2, Map[Int, Int](), "gini", 5, 32)

    // Run the model on the test samples
    val prediction_and_label = testing.map { p =>
      (model.predict(p.features), p.label)
    }
    val print_predict = prediction_and_label.take(5)
    println("Prediction" + "\t\t\t\t\t\t" + "Label")
    for (i <- 0 to print_predict.length - 1) {
      println(print_predict(i)._1 + "\t\t\t\t\t\t" + print_predict(i)._2)
    }

    // Compute the test accuracy
    val metrics = new MulticlassMetrics(prediction_and_label)
    val accuracy = metrics.accuracy
    println("Accuracy = " + accuracy)

    // Save the model
    val modelpath = "C:\\study\\spark\\DecisionTree"
    model.save(sc, modelpath)
    println("Model saved OK")
  }
}
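If you want to reuse the saved model in a later job, it can be loaded back with MLlib's DecisionTreeModel.load. A minimal sketch, assuming the same sc and modelpath as above (the sparse feature vector below is made up purely for illustration; 692 is the feature dimension of sample_libsvm_data.txt):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Load the model that was saved above
val sameModel = DecisionTreeModel.load(sc, "C:\\study\\spark\\DecisionTree")
// Score one hand-made sparse vector (indices and values chosen arbitrarily for illustration)
val features = Vectors.sparse(692, Array(100, 200), Array(1.0, 1.0))
println("Prediction: " + sameModel.predict(features))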
That's all for this post~