加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 综合聚焦 > 服务器 > 安全 > 正文

scala – 如何访问由RandomForestClassifier(spark.ml-version)

发布时间:2020-12-16 09:28:32 所属栏目:安全 来源:网络整理
导读:如何访问Spark ML RandomForestClassifier生成的模型中的单个树?我使用的是Scala版本的RandomForestClassifier. 解决方法 实际上它有树属性: import org.apache.spark.ml.attribute.NominalAttributeimport org.apache.spark.ml.classification.{ RandomFo
如何访问Spark ML RandomForestClassifier生成的模型中的单个树?我使用的是Scala版本的RandomForestClassifier.

解决方法

实际上它有树属性:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
  RandomForestClassificationModel,RandomForestClassifier,DecisionTreeClassificationModel
}

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0","1.0")
  .toMetadata

val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")
  .withColumn("label",$"label".as("label",meta))

val rf: RandomForestClassifier = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
  case t: DecisionTreeClassificationModel => t
}

正如您所看到的,唯一的问题是使类型正确,以便我们可以实际使用这些:

trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label|            features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// |  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
// |  1.0|(692,[158,159,160...|   [0.0,59.0]|  [0.0,1.0]|       1.0|
// |  1.0|(692,[124,125,126...|   [0.0,1.0]|       1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows

注意:

如果您使用管道,您也可以提取单个树:

import org.apache.spark.ml.Pipeline

val model = new Pipeline().setStages(Array(rf)).fit(data)

// There is only one stage and know its type 
// but lets be thorough
val rfModelOption = model.stages.headOption match {
  case Some(m: RandomForestClassificationModel) => Some(m)
  case _ => None
}

val trees = rfModelOption.map {
  _.trees //  ... as before
}.getOrElse(Array())

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读