scala – 如何从CrossValidatorModel中提取最佳参数
发布时间:2020-12-16 09:26:31 所属栏目:安全 来源:网络整理
导读:我想找到ParamGridBuilder的参数,它们是Spark 1.4.x中CrossValidator中最好的模型, 在Spark文档中的Pipeline Example中,他们通过在管道中使用ParamGridBuilder添加不同的参数(numFeatures,regParam).然后通过以下代码行创建最佳模型: val cvModel = crossva
我想找到ParamGridBuilder的参数,它们是Spark 1.4.x中CrossValidator中最好的模型,
在Spark文档中的Pipeline Example中,他们通过在管道中使用ParamGridBuilder添加不同的参数(numFeatures,regParam).然后通过以下代码行创建最佳模型: val cvModel = crossval.fit(training.toDF) 现在,我想知道ParamGridBuilder中产生最佳模型的参数(numFeatures,regParam)是什么. 我已经使用了以下命令但没有成功: cvModel.bestModel.extractParamMap().toString() cvModel.params.toList.mkString("(",",")") cvModel.estimatorParamMaps.toString() cvModel.explainParams() cvModel.getEstimatorParamMaps.mkString("(",")") cvModel.toString() 有帮助吗? 提前致谢, 解决方法
获取正确的ParamMap对象的一种方法是使用CrossValidatorModel.avgMetrics:Array [Double]来查找argmax ParamMap:
implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) { def bestEstimatorParamMap: ParamMap = { cvModel.getEstimatorParamMaps .zip(cvModel.avgMetrics) .maxBy(_._2) ._1 } } 当在管道示例中训练的CrossValidatorModel上运行时,您引用了: scala> println(cvModel.bestEstimatorParamMap) { hashingTF_2b0b8ccaeeec-numFeatures: 100,logreg_950a13184247-regParam: 0.1 } (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |