
scala – Spark Random Forest binary classifier metrics

How do we get the model metrics (F score, AUROC, AUPRC, etc.) when training a random forest binary classifier model in Spark MLlib?

The problem is that BinaryClassificationMetrics takes probabilities, while the RandomForest classifier's predict method returns a discrete value of 0 or 1.

See: https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#binary-classification

RandomForest.trainClassifier does not have a clearThreshold method that would make it return probabilities instead of the discrete 0 or 1 labels.
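For context, here is a minimal sketch of the RDD-based mllib flow described above; the data path and tree parameters are illustrative assumptions, not taken from the question. It shows why the metrics degenerate: predict returns hard 0/1 labels, so BinaryClassificationMetrics only ever sees a single threshold.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// Illustrative data path and parameters; adjust to your environment.
val lp = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val Array(train, holdout) = lp.randomSplit(Array(0.7, 0.3))

val rfModel = RandomForest.trainClassifier(
  train,
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),
  numTrees = 10,
  featureSubsetStrategy = "auto",
  impurity = "gini",
  maxDepth = 4,
  maxBins = 32)

// predict() yields a hard 0.0 or 1.0 label, not a probability, so the
// "score" passed to BinaryClassificationMetrics takes only two distinct
// values and the ROC/PR curves collapse to a single operating point.
val scoreAndLabels = holdout.map(p => (rfModel.predict(p.features), p.label))
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
println("areaUnderROC (degenerate) = " + metrics.areaUnderROC())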

Solution

We need to use the new DataFrame-based ml API to get the probabilities, instead of the RDD-based mllib API.
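As a quick illustration of that point (a sketch only, using Spark's default column names), the DataFrame-based RandomForestClassifier already declares probability and raw-prediction output columns, which the evaluator in the full example below consumes:

import org.apache.spark.ml.classification.RandomForestClassifier

// Default output column names of the DataFrame-based classifier; after
// transform() these columns hold the per-class probabilities and raw scores.
val rfColumns = new RandomForestClassifier()
println(rfColumns.getProbabilityCol)    // probability
println(rfColumns.getRawPredictionCol)  // rawPrediction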

Update

Below is an updated example from the Spark documentation, adapted to use BinaryClassificationEvaluator and to show the metrics area under the receiver operating characteristic curve (AUROC) and area under the precision-recall curve (AUPRC).

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load and parse the data file, converting it to a DataFrame.
val data = sqlContext.read.format("libsvm").load("D:/Sources/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model.  This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions
  .select("indexedLabel","rawPrediction","prediction")
  .show()

val binaryClassificationEvaluator = new BinaryClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setRawPredictionCol("rawPrediction")

def printlnMetric(metricName: String): Unit = {
  println(metricName + " = " + binaryClassificationEvaluator.setMetricName(metricName).evaluate(predictions))
}

printlnMetric("areaUnderROC")
printlnMetric("areaUnderPR")
