加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 综合聚焦 > 服务器 > 安全 > 正文

scala – 如何从交叉验证器获得训练有素的最佳模型

发布时间:2020-12-16 18:53:36 所属栏目:安全 来源:网络整理
导读:我构建了一个包含这样的DecisionTreeClassifier(dt)的管道 val pipeline = new Pipeline().setStages(Array(labelIndexer,featureIndexer,dt,labelConverter)) 然后我使用这个管道作为CrossValidator中的估算器,以获得具有这样的最佳超参数集的模型 val c_v
我构建了一个包含这样的DecisionTreeClassifier(dt)的管道

val pipeline = new Pipeline().setStages(Array(labelIndexer,featureIndexer,dt,labelConverter))

然后我使用这个管道作为CrossValidator中的估算器,以获得具有这样的最佳超参数集的模型

val c_v = new CrossValidator().setEstimator(pipeline).setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")).setEstimatorParamMaps(paramGrid).setNumFolds(5)

最后,我可以使用这个交叉验证器在训练测试中训练模型

val model = c_v.fit(train)

但问题是,我想查看具有DecisionTreeClassificationModel参数.toDebugTree的经过最佳训练的决策树模型.但模型是CrossValidatorModel.是的,您可以使用model.bestModel,但它仍然是Model类型,您不能将.toDebugTree应用于它.而且我还假设bestModel仍然是一个pipline,包括labelIndexer,labelConverter.

那么有谁知道如何从crossvalidator拟合的模型中获取decisionTree模型,我可以通过toDebugString查看实际模型?或者有没有可以查看decisionTree模型的解决方法?

解决方法

好吧,在 cases like this one中,答案总是一样的 – 具体说明类型.

首先提取管道模型,因为您要训练的是管道:

import org.apache.spark.ml.PipelineModel

val bestModel: Option[PipelineModel] = model.bestModel match {
  case p: PipelineModel => Some(p)
  case _ => None
}

然后,您需要从基础阶段提取模型.在您的情况下,它是一个决策树分类模型:

import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val treeModel: Option[DecisionTreeClassificationModel] = bestModel
  flatMap {
    _.stages.collect {
      case t: DecisionTreeClassificationModel => t
    }.headOption
  }

要打印树,例如:

treeModel.foreach(_.toDebugString)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读