scala – How to access individual trees in a model created by RandomForestClassifier (spark.ml version)
Published: 2020-12-16 09:28:32 | Category: Security | Source: collected from the web
How can I access the individual trees in a model produced by the Spark ML RandomForestClassifier? I am using the Scala version of RandomForestClassifier.
Solution
In fact, the fitted model has a trees attribute:

    import org.apache.spark.ml.attribute.NominalAttribute
    import org.apache.spark.ml.classification.{
      RandomForestClassificationModel, RandomForestClassifier,
      DecisionTreeClassificationModel
    }
    // Needed for the $"..." column syntax (spark-shell usually imports it already)
    import sqlContext.implicits._

    // Mark the label column as nominal (categorical) with two values
    val meta = NominalAttribute
      .defaultAttr
      .withName("label")
      .withValues("0.0", "1.0")
      .toMetadata

    val data = sqlContext.read.format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt")
      .withColumn("label", $"label".as("label", meta))

    val rf: RandomForestClassifier = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")

    // Fit the forest and narrow each tree to DecisionTreeClassificationModel
    val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
      case t: DecisionTreeClassificationModel => t
    }

As you can see, the only problem is getting the types right so that we can actually use the trees:

    trees.head.transform(data).show(3)
    // +-----+--------------------+-------------+-----------+----------+
    // |label|            features|rawPrediction|probability|prediction|
    // +-----+--------------------+-------------+-----------+----------+
    // |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
    // |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
    // |  1.0|(692,[124,125,126...| [0.0,1.0]| 1.0|
    // +-----+--------------------+-------------+-----------+----------+
    // only showing top 3 rows

Note:
If you use a Pipeline, you can extract the individual trees as well:

    import org.apache.spark.ml.Pipeline

    val model = new Pipeline().setStages(Array(rf)).fit(data)

    // There is only one stage and we know its type,
    // but let's be thorough
    val rfModelOption = model.stages.headOption match {
      case Some(m: RandomForestClassificationModel) => Some(m)
      case _ => None
    }

    val trees = rfModelOption.map {
      _.trees // ... as before
    }.getOrElse(Array.empty[DecisionTreeClassificationModel])
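
Once the individual trees have been extracted, each one can be inspected on its own. The following is only a minimal sketch, assuming Spark 2.x or later (it uses SparkSession and org.apache.spark.ml.linalg, which the snippets above do not). The object name InspectTrees, the helpers trainDemo and summarize, and the toy dataset are made up for illustration and are not part of the original answer.

```scala
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object InspectTrees {

  // Hypothetical helper: train a small forest on a made-up, in-memory dataset.
  def trainDemo(spark: SparkSession): RandomForestClassificationModel = {
    val data = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (0.0, Vectors.dense(0.1, 0.9)),
      (0.0, Vectors.dense(0.2, 0.8)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (1.0, Vectors.dense(0.9, 0.1)),
      (1.0, Vectors.dense(0.8, 0.2))
    )).toDF("label", "features")

    new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(3)
      .setSeed(42L)
      .fit(data)
  }

  // Hypothetical helper: one (index, weight, depth, numNodes) tuple per tree.
  def summarize(model: RandomForestClassificationModel): Seq[(Int, Double, Int, Int)] =
    model.trees.zip(model.treeWeights).zipWithIndex.map {
      case ((tree, weight), i) => (i, weight, tree.depth, tree.numNodes)
    }.toSeq

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("inspect-trees")
      .getOrCreate()

    val model = trainDemo(spark)

    // Print a one-line summary for every tree in the forest
    summarize(model).foreach { case (i, w, d, n) =>
      println(s"tree $i: weight=$w depth=$d nodes=$n")
    }

    // Print the full split structure of the first tree
    println(model.trees.head.toDebugString)

    spark.stop()
  }
}
```

The same summarize helper should work on a RandomForestClassificationModel pulled out of a fitted Pipeline, since it only relies on the trees and treeWeights members of the model.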