scala – SPARK DataFrame:选择每个组的第一行
发布时间:2020-12-16 09:47:08 所属栏目:安全 来源:网络整理
导读:我有一个DataFrame生成如下: df.groupBy($"Hour",$"Category") .agg(sum($"value").alias("TotalValue")) .sort($"Hour".asc,$"TotalValue".desc)) 结果如下: +----+--------+----------+|Hour|Category|TotalValue|+----+--------+----------+| 0| cat26|
我有一个DataFrame生成如下:
df.groupBy($"Hour",$"Category") .agg(sum($"value").alias("TotalValue")) .sort($"Hour".asc,$"TotalValue".desc)) 结果如下: +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 0| cat13| 22.1| | 0| cat95| 19.6| | 0| cat105| 1.3| | 1| cat67| 28.5| | 1| cat4| 26.8| | 1| cat13| 12.6| | 1| cat23| 5.3| | 2| cat56| 39.6| | 2| cat40| 29.7| | 2| cat187| 27.9| | 2| cat68| 9.8| | 3| cat8| 35.6| | ...| ....| ....| +----+--------+----------+ 如您所见,DataFrame按小时按升序排序,然后按TotalValue按降序排序。 我想选择每个组的顶行,即 >从小时的组== 0 select(0,cat26,30.9) 所以期望的输出将是: +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 1| cat67| 28.5| | 2| cat56| 39.6| | 3| cat8| 35.6| | ...| ...| ...| +----+--------+----------+ 也可以方便地选择每个组的前N行。 任何帮助是高度赞赏。 解决方法
窗口函数:
这样的东西应该做的诀窍: import org.apache.spark.sql.functions.{rowNumber,max,broadcast} import org.apache.spark.sql.expressions.Window val df = sc.parallelize(Seq( (0,"cat26",30.9),(0,"cat13",22.1),"cat95",19.6),"cat105",1.3),(1,"cat67",28.5),"cat4",26.8),12.6),"cat23",5.3),(2,"cat56",39.6),"cat40",29.7),"cat187",27.9),"cat68",9.8),(3,"cat8",35.6))).toDF("Hour","Category","TotalValue") val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc) val dfTop = df.withColumn("rn",rowNumber.over(w)).where($"rn" === 1).drop("rn") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 在有严重数据偏移的情况下,该方法将是低效的。 平均SQL聚合后加入: 或者,您可以加入聚合数据框: val dfMax = df.groupBy($"hour").agg(max($"TotalValue")) val dfTopByJoin = df.join(broadcast(dfMax),($"hour" === $"max_hour") && ($"TotalValue" === $"max_value")) .drop("max_hour") .drop("max_value") dfTopByJoin.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 它将保持重复的值(如果每小时有多个类别,具有相同的总值)。您可以按如下方式删除它们: dfTopByJoin .groupBy($"hour") .agg( first("category").alias("category"),first("TotalValue").alias("TotalValue")) 使用结构排序: 整洁,虽然不是很好测试,技巧,不需要连接或窗口函数: val dfTop = df.select($"Hour",struct($"TotalValue",$"Category").alias("vs")) .groupBy($"hour") .agg(max("vs").alias("vs")) .select($"Hour",$"vs.Category",$"vs.TotalValue") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 使用DataSet API(Spark 1.6,2.0): Spark 1.6: case class Record(Hour: Integer,Category: String,TotalValue: Double) df.as[Record] .groupBy($"hour") .reduce((x,y) => if (x.TotalValue > y.TotalValue) x else y) .show // +---+--------------+ // | _1| _2| // +---+--------------+ // |[0]|[0,cat26,30.9]| // |[1]|[1,cat67,28.5]| // |[2]|[2,cat56,39.6]| // |[3]| [3,cat8,35.6]| // +---+--------------+ Spark 2.0: df.as[Record] .groupByKey(_.Hour) .reduceGroups((x,y) => if (x.TotalValue > y.TotalValue) x else y) 最后两种方法可以利用地图边组合,并且不需要完全随机播放,因此大多数时间应该表现出比窗口函数和连接更好的性能。 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |