加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 综合聚焦 > 服务器 > 安全 > 正文

scala – 如何从CrossValidatorModel中提取最佳参数

发布时间:2020-12-16 09:26:31 所属栏目:安全 来源:网络整理
导读:我想找到ParamGridBuilder的参数,它们是Spark 1.4.x中CrossValidator中最好的模型, 在Spark文档中的Pipeline Example中,他们通过在管道中使用ParamGridBuilder添加不同的参数(numFeatures,regParam).然后通过以下代码行创建最佳模型: val cvModel = crossva
我想找到ParamGridBuilder的参数,它们是Spark 1.4.x中CrossValidator中最好的模型,

在Spark文档中的Pipeline Example中,他们通过在管道中使用ParamGridBuilder添加不同的参数(numFeatures,regParam).然后通过以下代码行创建最佳模型:

val cvModel = crossval.fit(training.toDF)

现在,我想知道ParamGridBuilder中产生最佳模型的参数(numFeatures,regParam)是什么.

我已经使用了以下命令但没有成功:

cvModel.bestModel.extractParamMap().toString()
cvModel.params.toList.mkString("(",",")")
cvModel.estimatorParamMaps.toString()
cvModel.explainParams()
cvModel.getEstimatorParamMaps.mkString("(",")")
cvModel.toString()

有帮助吗?

提前致谢,

解决方法

获取正确的ParamMap对象的一种方法是使用CrossValidatorModel.avgMetrics:Array [Double]来查找argmax ParamMap:

implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
  def bestEstimatorParamMap: ParamMap = {
    cvModel.getEstimatorParamMaps
           .zip(cvModel.avgMetrics)
           .maxBy(_._2)
           ._1
  }
}

当在管道示例中训练的CrossValidatorModel上运行时,您引用了:

scala> println(cvModel.bestEstimatorParamMap)
{
   hashingTF_2b0b8ccaeeec-numFeatures: 100,logreg_950a13184247-regParam: 0.1
}

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读