scala – OneHotEncoder on a Spark DataFrame in a Pipeline
I've been trying to get an example running in Spark and Scala with the adult dataset.
I'm using Scala 2.11.8 and Spark 1.6.1. The problem (for now) is the number of categorical features in this dataset: all of them have to be encoded as numbers before the Spark ML algorithms can do their job. So far I have this:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.OneHotEncoder
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.{SparkConf, SparkContext}

    object Adult {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("Adult example").setMaster("local[*]")
        val sparkContext = new SparkContext(conf)
        val sqlContext = new SQLContext(sparkContext)

        val data = sqlContext.read
          .format("com.databricks.spark.csv")
          .option("header", "true")      // Use first line of all files as header
          .option("inferSchema", "true") // Automatically infer data types
          .load("src/main/resources/adult.data")

        // (name, type) pairs of all string columns
        val categoricals = data.dtypes filter (_._2 == "StringType")
        val encoders = categoricals map (cat =>
          new OneHotEncoder().setInputCol(cat._1).setOutputCol(cat._1 + "_encoded"))
        val features = data.dtypes filterNot (_._1 == "label") map (tuple =>
          if (tuple._2 == "StringType") tuple._1 + "_encoded" else tuple._1)

        val lr = new LogisticRegression()
          .setMaxIter(10)
          .setRegParam(0.01)

        val pipeline = new Pipeline()
          .setStages(encoders ++ Array(lr))

        val model = pipeline.fit(data)
      }
    }

However, this doesn't work. When I call pipeline.fit, the original string features are still there, so it throws an exception.

I chose this approach because I have an extensive background in Python and Pandas, but I'm trying to learn Scala and Spark.

Solution
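One thing that can be confusing if you are used to higher-level frameworks: you have to index the features before you can use the encoder. As explained in the API docs, OneHotEncoder maps a column of category indices to a column of binary vectors, so each string column first has to go through a StringIndexer: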
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
    import sqlContext.implicits._ // for toDF; spark-shell already provides sqlContext and this import

    val df = Seq((1L, "foo"), (2L, "bar")).toDF("id", "x")

    // names of all string columns
    val categoricals = df.dtypes.filter(_._2 == "StringType") map (_._1)

    // string -> category index (Double)
    val indexers = categoricals.map(
      c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
    )

    // category index -> one-hot (sparse) vector
    val encoders = categoricals.map(
      c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc")
    )

    val pipeline = new Pipeline().setStages(indexers ++ encoders)

    val transformed = pipeline.fit(df).transform(df)
    transformed.show

    // +---+---+-----+-------------+
    // | id|  x|x_idx|        x_enc|
    // +---+---+-----+-------------+
    // |  1|foo|  1.0|    (1,[],[])|
    // |  2|bar|  0.0|(1,[0],[1.0])|
    // +---+---+-----+-------------+

As you can see, there is no need to remove the string columns from the pipeline. In practice, OneHotEncoder accepts a numeric column with a NominalAttribute, a BinaryAttribute, or no type attribute at all.
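To make the index-then-encode behaviour concrete, here is a plain-Scala sketch (no Spark dependency) that mimics what the two stages do to the `x` column above. `IndexThenEncode`, `fitIndex`, `encode` and `Sparse` are hypothetical helpers for illustration, not Spark API. The sketch assumes indices are assigned by descending frequency with ties broken alphabetically; as far as I know, Spark 1.6's StringIndexer does not define an order for ties, so that tie rule is only chosen here to reproduce the output shown.

```scala
// Plain-Scala sketch (not Spark code) of StringIndexer followed by
// OneHotEncoder with its default dropLast = true, applied to one column.
object IndexThenEncode {

  // Hypothetical sparse vector (size, indices, values), printed like Spark's SparseVector.
  final case class Sparse(size: Int, indices: Seq[Int], values: Seq[Double]) {
    override def toString: String =
      s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
  }

  // "Fit" the indexer: most frequent label gets index 0.
  // Assumption: ties are broken alphabetically.
  def fitIndex(labels: Seq[String]): Map[String, Int] =
    labels
      .groupBy(identity)
      .map { case (label, occurrences) => label -> occurrences.size }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  // One-hot encode a category index. With dropLast the vector has
  // numCategories - 1 slots and the last category becomes all zeros.
  def encode(index: Int, numCategories: Int, dropLast: Boolean = true): Sparse = {
    val size = if (dropLast) numCategories - 1 else numCategories
    if (index < size) Sparse(size, Seq(index), Seq(1.0))
    else Sparse(size, Seq.empty, Seq.empty)
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(1L -> "foo", 2L -> "bar")
    val index = fitIndex(rows.map(_._2))
    rows.foreach { case (id, x) =>
      val idx = index(x)
      println(s"$id $x ${idx.toDouble} ${encode(idx, index.size)}")
    }
    // prints:
    // 1 foo 1.0 (1,[],[])
    // 2 bar 0.0 (1,[0],[1.0])
  }
}
```

The `dropLast` default explains why `foo`, the last index, ends up as an empty vector of size 1: with two categories a single slot is enough, because an all-zero vector already identifies the remaining category.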