scala – RandomForestClassifier was given input with invalid label column in Apache Spark
I am trying to compute the accuracy of a random forest classifier model in Scala using 5-fold cross-validation, but at runtime I get the following error:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified.
The error above is thrown on the line val cvModel = cv.fit(trainingData). The code I use for cross-validation of the dataset with a random forest is the following:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("exprogram/dataset.txt")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(41).toDouble, Vectors.dense(parts(0).split(',').map(_.toDouble)))
}

val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val trainingData = training.toDF()
val testData = test.toDF()

val nFolds: Int = 5
val NumTrees: Int = 5

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(NumTrees)

val pipeline = new Pipeline()
  .setStages(Array(rf))

val paramGrid = new ParamGridBuilder()
  .build()

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("precision")

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val cvModel = cv.fit(trainingData)

val results = cvModel.transform(testData)
  .select("label", "prediction").collect

val numCorrectPredictions = results.map(row => if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _)
val accuracy = 1.0D * numCorrectPredictions / results.size

println("Test set accuracy: %.3f".format(accuracy))

Can anyone explain the error in the above code?

Solution
Like many other ML algorithms, RandomForestClassifier requires specific metadata to be set on the label column, and the label values must be integers from [0, 1, 2, ..., #classes), represented as doubles. Typically this is handled by an upstream transformer such as StringIndexer. Because you convert the labels manually, the metadata fields are not set, and the classifier cannot confirm that these requirements are satisfied.
val df = Seq(
  (0.0, Vectors.dense(1, 0, 0, 0)),
  (1.0, Vectors.dense(0, 1, 0, 0)),
  (2.0, Vectors.dense(0, 0, 1, 0))
).toDF("label", "features")

val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setNumTrees(5)

rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...

You can either re-encode the label column using StringIndexer:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .fit(df)

rf.setLabelCol("label_idx").fit(indexer.transform(df))

or set the required metadata manually:

import org.apache.spark.ml.attribute.NominalAttribute

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0", "2.0")
  .toMetadata

rf.setLabelCol("label_meta").fit(
  df.withColumn("label_meta", $"label".as("", meta))
)

Note:

Labels created with StringIndexer depend on the frequency, not the value:

indexer.labels
// Array[String] = Array(2.0, 0.0, 1.0)

PySpark:

In Python, the metadata fields can be set directly on the schema:

from pyspark.sql.types import StructField, DoubleType

StructField(
  "label", DoubleType(), False,
  {"ml_attr": {
    "name": "label",
    "type": "nominal",
    "vals": ["0.0", "1.0", "2.0"]
  }}
)
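Applied back to the code in the question, a minimal sketch (assuming the same trainingData DataFrame with "label" and "features" columns, and keeping the question's evaluator settings) would place the StringIndexer as the first stage of the Pipeline, so that every fold fitted by CrossValidator sees a label column carrying the required class metadata:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}

// Index the raw label column; the indexer attaches the nominal attribute
// metadata (number of classes) that RandomForestClassifier checks for.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setNumTrees(5)

// Indexer runs before the classifier inside the pipeline, so it is refit
// on each training fold together with the model.
val pipeline = new Pipeline().setStages(Array(labelIndexer, rf))

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("precision") // metric name taken from the question; newer Spark versions use "accuracy"

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(new ParamGridBuilder().build())
  .setNumFolds(5)

val cvModel = cv.fit(trainingData)

Keep in mind that, as noted above, StringIndexer assigns indices by label frequency, so the "prediction" column produced by cvModel.transform refers to the indexed labels rather than the original label values.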