首页 » 排名链接 » Spark MLlib 决策树算法(节点决策树数据切分向量)

Spark MLlib 决策树算法(节点决策树数据切分向量)

乖囧猫 2024-10-24 01:31:55 0

扫一扫用手机浏览

文章目录 [+]

二、决策树理论

在当前节点使用哪个特征作为切分判定,取决于切分后节点数据集合中的类别纯度。
切分后的数据越纯,那么当前切分就越合理。
那么如何衡量类别的纯度呢?这里有3个指标。

1、熵(针对分类)

Spark MLlib 决策树算法(节点决策树数据切分向量) 排名链接
(图片来自网络侵删)
信息量: Ie=−log2pi信息熵:信息量的期望 H(x)=E(I(x))=∑i=1np(xi)I(xi)信息增益:分类前,熵大;分类后,熵小;信息增益表达熵的变化。
特征 A 对训练集 D 的信息增益 g(D,A)=H(D)−H(D|A) ,这是还是很好理解的啦, H(D) 就是集合 D 的熵, H(D|A) 就是特征 A 条件下集合 D 的熵

2、基尼(针对分类)

基尼Gini定义: ∑i=1Cfi(1−fi)fi 是某个分区内第 i 个标签的频率C 是该分区中的类别总数Gini系数计算的是类型被分错的可能性Gini系数越小,数据越纯

3、方差(针对回归)

方差的定义: F=1N∑i=1N(yi−μ)2yi 是某个实例标签N 是实例的总数μ 是所有实例的均值这个很好理解,方差越小,数据越纯三、决策树实战数据下载地址:github机器学习数据下载第一列是标签,后面是特征

可能是同样的数据集,好几个模型,训练的准确率总是93.75%

import org.apache.log4j.{Level, Logger}import org.apache.spark.SparkContext._import org.apache.spark.SparkContextimport org.apache.spark.SparkConfimport org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, NaiveBayesModel, SVMWithSGD}import org.apache.spark.mllib.evaluation.MulticlassMetricsimport org.apache.spark.mllib.util.{KMeansDataGenerator, LinearDataGenerator, LogisticRegressionDataGenerator, MLUtils}import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}import org.apache.spark.mllib.tree.DecisionTree//向量import org.apache.spark.mllib.linalg.Vector//向量集import org.apache.spark.mllib.linalg.Vectors//稀疏向量import org.apache.spark.mllib.linalg.SparseVector//稠密向量import org.apache.spark.mllib.linalg.DenseVector//实例import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}//矩阵import org.apache.spark.mllib.linalg.{Matrix, Matrices}//索引矩阵import org.apache.spark.mllib.linalg.distributed.RowMatrix//RDDimport org.apache.spark.rdd.RDDobject WordCount { def main(args: Array[String]) { // 构建Spark 对象 Logger.getLogger("org.apache.spark").setLevel(Level.ERROR) val conf = new SparkConf().setAppName("HACK-AILX10").setMaster("local") val sc = new SparkContext(conf) // 读取样本数据 val datapath = "C:\\study\\spark\\sample_libsvm_data.txt" val data = MLUtils.loadLibSVMFile(sc,datapath) val splits = data.randomSplit(Array(0.6,0.4),seed=1L) val training = splits(0) val testing = splits(1) //新建决策树,并设置训练参数 val model = DecisionTree.trainClassifier(training,2,Map[Int,Int](),"gini",5,32) //对样本进行测试 val prediction_and_label = testing.map { p => (model.predict(p.features), p.label) } val print_predict = prediction_and_label.take(5) println("预测结果" + "\t\t\t\t\t\t" + "标签") for (i <- 0 to print_predict.length -1){ println(print_predict(i)._1 + "\t\t\t\t\t\t" + print_predict(i)._2) } //计算测试误差 val metrics = new MulticlassMetrics(prediction_and_label) val accuracy = metrics.accuracy println("准确率=" + accuracy) // 模型保存 val modelpath = "C:\\study\\spark\\DecisionTree" model.save(sc,modelpath) println("模型保存 ok") }}

本篇完~

标签:

相关文章