
scala – How to define and use a user-defined aggregate function in Spark SQL?

I know how to write a UDF in Spark SQL:

def belowThreshold(power: Int): Boolean = {
  power < -40
}

sqlContext.udf.register("belowThreshold", belowThreshold _)

Can I do something similar to define an aggregate function? How is this done?

For context, I want to run the following SQL query:

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")

It should return something like

Row(span1, false, T0)

I want the aggregate function to tell me whether any values of opticalReceivePower in the groups defined by span and timestamp were below the threshold. Do I need to write my UDAF differently from the UDF I pasted above?

Solution

Supported methods

Spark >= 2.0 (optionally 1.6, but with a slightly different API):

It is possible to use Aggregators on typed Datasets:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false
  def reduce(acc: Boolean, x: I) = acc | f(x)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
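
For completeness, here is a minimal sketch of running the typed aggregation end to end; it assumes a SparkSession named spark and reuses the belowThreshold column defined above:

import spark.implicits._

// Hypothetical sample data: two groups, one of which contains a value below -40
val sampleDF = Seq(("a", 0), ("a", 1), ("b", 30), ("b", -50)).toDF("group", "power")

sampleDF.as[(String, Int)]
  .groupByKey(_._1)
  .agg(belowThreshold)
  .collect()
// expected: Array((a,false), (b,true))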

Spark >= 1.5:

In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
  // Schema you get as an input
  def inputSchema = new StructType().add("power", IntegerType)
  // Schema of the row which is used for aggregation
  def bufferSchema = new StructType().add("ind", BooleanType)
  // Returned type
  def dataType = BooleanType
  // Self-explaining
  def deterministic = true
  // zero value
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
  }
  // Similar to combOp in aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))
  }
  // Called on exit to get return value
  def evaluate(buffer: Row) = buffer.getBoolean(0)
}

Example usage:

df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
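
To call it from SQL, as in the original query, the UDAF can also be registered by name. A sketch, assuming ifDF has been registered as a temporary table:

sqlContext.udf.register("belowThreshold", belowThreshold)

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")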

Spark 1.4 workaround:

I am not sure if I correctly understand your requirements, but as far as I can tell, plain old aggregation should be enough here:

val df = sc.parallelize(Seq(
    ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
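
The same workaround can be expressed directly in SQL against the original query. This is only a sketch: it assumes ifDF is registered as a temporary table, and whether the Spark 1.4 SQL dialect accepts this exact boolean-to-integer cast is an assumption:

val aggDF = sqlContext.sql("""SELECT span, SUM(CAST(opticalReceivePower < -40 AS INT)) != 0 AS belowThreshold, timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")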

Spark <= 1.4:

As far as I know, at this moment (Spark 1.4.1) there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).

Unsupported / internal methods

Internally, Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.

These are intended for internal use and may change without further notice, so they are probably not something you want to use in your production code, but just for completeness, BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression)
    extends DeclarativeAggregate {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left,belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
}

It should be further wrapped with an equivalent of withAggregateFunction.
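
A minimal sketch of such a wrapper, assuming the internal Column(expr) constructor and toAggregateExpression() are available (both are internal APIs and may change between versions):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// Wraps the Catalyst aggregate function in a Column so it can be used with agg
def belowThreshold(power: Column, threshold: Column): Column =
  new Column(BelowThreshold(power.expr, threshold.expr).toAggregateExpression())

df.groupBy($"group").agg(belowThreshold($"power", lit(-40)))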
