A Summary of Custom UDAF Functions in Spark SQL


Having finished the custom-function part of Spark SQL, I'm writing up a summary. I hope it helps others; if you spot mistakes, corrections are welcome. As we already saw when learning Hive, when the built-in functions cannot meet the needs of business processing, we can turn to user-defined functions (UDF: user defined function).

User-defined functions fall into the following three categories:

1) UDF: takes one row of input and returns one result (one-to-one). Code first:

// Create a DataFrame
scala> val df = spark.read.json("data/user.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]

// Register the UDF
scala> spark.udf.register("addName", (x: String) => "Name:" + x)
res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

// Create a temporary view
scala> df.createOrReplaceTempView("people")

// Apply the UDF (the column is username, per the schema inferred above)
scala> spark.sql("select addName(username), age from people").show()
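Besides SQL, the same function can be used from the DataFrame DSL by wrapping it with functions.udf. A minimal sketch, assuming the username column from the schema above:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the Scala function as a column expression instead of registering it under a name
val addName = udf((x: String) => "Name:" + x)

// Apply it to the username column of the DataFrame created above
df.select(addName(col("username")).as("name"), col("age")).show()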

2) UDTF: takes one row of input and returns multiple rows (one-to-many). Spark SQL has no user-defined version of this, because flatMap in Spark already covers that functionality, for example:
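A minimal sketch of the flatMap route (the sample strings are made up; a SparkSession named spark is assumed to be in scope, as in the examples above), splitting one input row into several output rows:

import spark.implicits._

// One input row per line of text; flatMap plays the role a UDTF would play in Hive
val lines = Seq("hello world spark", "hive on spark").toDS()
val words = lines.flatMap(_.split(" "))
words.show()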

3) UDAF (the focus here): takes multiple rows of input and returns one row; the A stands for aggregate. When the business logic gets complex, you have to implement the aggregation function yourself.

Spark DataFrames already provide generic aggregation methods such as count(), countDistinct(), avg(), max(), min(), and so on (a quick sketch follows the list below). These functions are designed for DataFrames; Spark SQL also has type-safe versions, with both Java and Scala interfaces, which apply to strongly typed Datasets. Spark offers two interfaces for custom aggregate functions:

1) UserDefinedAggregateFunction (untyped, for DataFrames)
2) Aggregator (strongly typed, for Datasets)
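Before turning to those two interfaces, here is a quick sketch of the built-in aggregations mentioned above, run against the df created earlier (column names assumed from that schema):

import org.apache.spark.sql.functions._

// Built-in DataFrame aggregations: no custom code is needed for these
df.agg(
  count(lit(1)).as("rows"),
  countDistinct("username").as("distinct_users"),
  avg("age").as("avg_age"),
  max("age").as("max_age"),
  min("age").as("min_age")
).show()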

I. UserDefinedAggregateFunction

The class UserDefinedAggregateFunction, found in udaf.scala, is the base class for implementing a user-defined aggregate function (UDAF). Let's look at the source first:

abstract class UserDefinedAggregateFunction extends Serializable {

  // The StructType describes the types of this aggregate function's input arguments.
  // For example, a UDAF that takes two input arguments of types DoubleType and LongType
  // would use a StructType like:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  // The UDAF will only accept input data of this shape.
  def inputSchema: StructType

  // The StructType describes the types of the values in the aggregation buffer.
  // For example, a UDAF whose buffer holds two values of types DoubleType and LongType
  // would use:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  // and only buffers of that shape apply.
  def bufferSchema: StructType

  // The DataType of the value returned by this UDAF.
  def dataType: DataType

  // Returns true if this function is deterministic, i.e. the same input always
  // produces the same output.
  def deterministic: Boolean

  // Initializes the aggregation buffer, e.g. fills it with zero values.
  // Calling merge on two initial buffers should return the initial buffer itself,
  // i.e. merge(initialBuffer, initialBuffer) should equal initialBuffer.
  def initialize(buffer: MutableAggregationBuffer): Unit

  // Updates the given aggregation buffer with new input data; called once per input row.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  // Merges two aggregation buffers and writes the updated values back into buffer1.
  // Called when merging two partially aggregated datasets.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  // Computes the final result of this UDAF from the given aggregation buffer.
  def evaluate(buffer: Row): Any

  // Creates a Column for this UDAF using the given Columns as input arguments.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  // Creates a Column for this UDAF using the distinct values of the given Columns as input.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * A `Row` representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}

Now a concrete implementation (the untyped, non-type-safe version):

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// TODO custom aggregate function (untyped): DataFrame
// 1) Extend UserDefinedAggregateFunction
class MyAvgFunt extends UserDefinedAggregateFunction {

  // Schema of the input data
  override def inputSchema: StructType = {
    // new StructType().add("age", LongType)
    StructType(Array(StructField("age", LongType)))
  }

  // Schema of the buffer used during the computation
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Type of the returned value
  override def dataType: DataType = DoubleType

  // Whether the function is stable: the same input always yields the same output
  override def deterministic: Boolean = true

  // Initialize the buffer: buffer(0) holds sum, buffer(1) holds count
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update sum and count as each input row is aggregated
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1L
  }

  // Merge buffers produced by different partial aggregations
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

Calling it from the main method:

// Create an instance of the aggregate function
val uadf = new MyAvgFunt
// Register the aggregate function
spark.udf.register("avgAge", uadf)
// Create the DataFrame
val dataFrame: DataFrame = spark.read.json("in/user.json")
// Create a temporary view named user so SQL can be used
dataFrame.createOrReplaceTempView("user")
spark.sql("select avgAge(age) from user").show()
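For reference, in/user.json is assumed here to be a JSON Lines file whose records carry a name and an age, matching both this example and the Emp case class used in the next section, for instance:

{"name": "zhangsan", "age": 20}
{"name": "lisi", "age": 30}
{"name": "wangwu", "age": 40}

With that sample data, avgAge(age) would return 30.0.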

II. Aggregator

Aggregator is the base class for user-defined aggregations that can be used in Dataset operations, taking all the elements of a group and reducing them to a single value. Again, the source first:

/**
 * :: Experimental ::
 * A base class for user-defined aggregations, which can be used in `Dataset` operations to take
 * all of the elements of a group and reduce them to a single value.
 *
 * For example, the following aggregator extracts an `int` from a specific class and adds them up:
 * {{{
 *   case class Data(i: Int)
 *
 *   val customSummer = new Aggregator[Data, Int, Int] {
 *     def zero: Int = 0
 *     def reduce(b: Int, a: Data): Int = b + a.i
 *     def merge(b1: Int, b2: Int): Int = b1 + b2
 *     def finish(r: Int): Int = r
 *   }.toColumn()
 *
 *   val ds: Dataset[Data] = ...
 *   val aggregated = ds.select(customSummer)
 * }}}
 *
 * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
 *
 * @tparam IN The input type for the aggregation.
 * @tparam BUF The type of the intermediate value of the reduction.
 * @tparam OUT The type of the final output result.
 * @since 1.6.0
 */
@Experimental
@InterfaceStability.Evolving
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

  // The zero value of this aggregation. It must satisfy b + zero = b for any input b.
  def zero: BUF

  // Combines two values to produce a new value. For performance, this may modify b
  // and return it instead of constructing a new object.
  def reduce(b: BUF, a: IN): BUF

  // Merges two intermediate values.
  def merge(b1: BUF, b2: BUF): BUF

  // Transforms the output of the reduction.
  def finish(reduction: BUF): OUT

  // Encoder for the intermediate value type.
  def bufferEncoder: Encoder[BUF]

  // Encoder for the final output value type.
  def outputEncoder: Encoder[OUT]

  // Returns this aggregator as a TypedColumn so that it can be used in Dataset operations.
  def toColumn: TypedColumn[IN, OUT] = {
    implicit val bEncoder = bufferEncoder
    implicit val cEncoder = outputEncoder

    val expr =
      AggregateExpression(
        TypedAggregateExpression(this),
        Complete,
        isDistinct = false)

    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
  }
}

A concrete implementation:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// TODO custom aggregate function (strongly typed): Dataset
// Extend Aggregator with the proper type parameters and override its methods

// Matches the record layout of user.json
case class Emp(name: String, age: Long)
case class BufferUDAF(var total: Long, var count: Long)

// IN, BUF, OUT
class ageAvgUDAF extends Aggregator[Emp, BufferUDAF, Double] {

  // Initial (zero) buffer
  override def zero: BufferUDAF = BufferUDAF(0L, 0L)

  // Aggregate one input row into the buffer
  override def reduce(buffer: BufferUDAF, emp: Emp): BufferUDAF = {
    buffer.total = buffer.total + emp.age
    buffer.count = buffer.count + 1L
    buffer
  }

  // Merge buffers from different partial aggregations
  override def merge(b1: BufferUDAF, b2: BufferUDAF): BufferUDAF = {
    b1.total = b1.total + b2.total
    b1.count = b1.count + b2.count
    b1
  }

  // Produce the final result
  override def finish(reduction: BufferUDAF): Double =
    reduction.total.toDouble / reduction.count

  override def bufferEncoder: Encoder[BufferUDAF] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

Calling it from the main method:

// The strongly typed version is used quite differently: SQL has no notion of these types,
// so we cannot simply write SQL as we did with the DataFrame
import spark.implicits._ // needed for .as[Emp]

val userdf: DataFrame = spark.read.json("in/user.json")
val userds: Dataset[Emp] = userdf.as[Emp]

// Create the aggregate function
val avgUDAF = new ageAvgUDAF
// Turn the aggregate function into a column that can be selected
val column: TypedColumn[Emp, Double] = avgUDAF.toColumn

// Only the DSL syntax works here, not SQL
userds.select(column).show()

OK, that's it for this summary of Spark SQL custom functions; I'll extend it once I've gone deeper.

Reference: https://blog.csdn.net/rlnLo2pNEfx9c/article/details/80972447
