加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 综合聚焦 > 服务器 > 安全 > 正文

scala – 在Apache Spark中为RandomForestClassifier提供了无效

发布时间:2020-12-16 18:39:44 所属栏目:安全 来源:网络整理
导读:我试图使用SCALA中的随机森林分类器模型使用5倍交叉验证来找到准确度.但是我在运行时遇到以下错误: java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label,without the number of classes specifie
我试图使用SCALA中的随机森林分类器模型使用5倍交叉验证来找到准确度.但是我在运行时遇到以下错误:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label,without the number of classes specified. See StringIndexer.

在线上获得上述错误—> val cvModel = cv.fit(trainingData)

我用于使用随机森林进行数据集交叉验证的代码如下:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder,CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("exprogram/dataset.txt")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(41).toDouble,Vectors.dense(parts(0).split(',').map(_.toDouble)))
}


val splits = parsedData.randomSplit(Array(0.6,0.4),seed = 11L)
val training = splits(0)
val test = splits(1)

val trainingData = training.toDF()

val testData = test.toDF()

val nFolds: Int = 5
val NumTrees: Int = 5

val rf = new     
RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(NumTrees)

val pipeline = new Pipeline()
      .setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder()
          .build()

val evaluator = new  MulticlassClassificationEvaluator()
    .setLabelCol("label")
    .setPredictionCol("prediction")
    .setMetricName("precision") 

val cv = new CrossValidator()
   .setEstimator(pipeline)
   .setEvaluator(evaluator) 
   .setEstimatorParamMaps(paramGrid)
   .setNumFolds(nFolds)

val cvModel = cv.fit(trainingData)

val results = cvModel.transform(testData)
.select("label","prediction").collect

val numCorrectPredictions = results.map(row => 
if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _)
val accuracy = 1.0D * numCorrectPredictions / results.size

println("Test set accuracy: %.3f".format(accuracy))

任何人都可以解释上面代码中的错误.

解决方法

与许多其他ML算法相同,RandomForestClassifier需要在标签列上设置特定元数据,并将值标记为来自[0,1,2 ……,#class]的整数值,表示为双精度.通常,这由StringIndexer之类的上游变换器处理.由于您手动转换标签元数据字段未设置且分类器无法确认是否满足这些要求.

val df = Seq(
  (0.0,Vectors.dense(1,0)),(1.0,Vectors.dense(0,(2.0,1))
).toDF("label","features")

val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setNumTrees(5)

rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...

您可以使用StringIndexer重新编码标签列:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .fit(df)

rf.setLabelCol("label_idx").fit(indexer.transform(df))

或set required metadata manually:

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0","1.0","2.0")
  .toMetadata

rf.setLabelCol("label_meta").fit(
  df.withColumn("label_meta",$"label".as("",meta))
)

注意:

使用StringIndexer创建的标签取决于频率而不是值:

indexer.labels
// Array[String] = Array(2.0,0.0,1.0)

PySpark:

在Python中,元数据字段可以直接在模式上设置:

from pyspark.sql.types import StructField,DoubleType

StructField(
    "label",DoubleType(),False,{"ml_attr": {
        "name": "label","type": "nominal","vals": ["0.0","2.0"]
    }}
)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读