Spark: Custom UDAF Functions


    Weakly typed user-defined UDAF function

    Extend the UserDefinedAggregateFunction class.

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}

    object SparkSQL05_UDAF {
      def main(args: Array[String]): Unit = {
        // Create the configuration object
        val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL01_Demo")
        val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

        // Import the implicit conversion rules before performing any conversions.
        // "spark" here is not a package name but the name of the SparkSession object.
        import spark.implicits._

        // Create the aggregate function object
        val udaf = new MyAgeAvgFunction
        // Register the aggregate function
        spark.udf.register("avgAge", udaf)

        // Use the aggregate function
        val frame: DataFrame = spark.read.json("in/user.json")
        frame.createOrReplaceTempView("user")
        spark.sql("select avgAge(age) from user").show()

        // Release resources
        spark.stop()
      }
    }

    // TODO Declare the user-defined aggregate function:
    // 1) extend UserDefinedAggregateFunction
    // 2) implement its methods
    class MyAgeAvgFunction extends UserDefinedAggregateFunction {
      // Schema of the input data
      override def inputSchema: StructType = {
        new StructType().add("age", LongType)
      }

      // Schema of the aggregation buffer used during the computation
      override def bufferSchema: StructType = {
        new StructType().add("sum", LongType).add("count", LongType)
      }

      // Data type returned by the function
      override def dataType: DataType = DoubleType

      // Whether the function is deterministic
      override def deterministic: Boolean = true

      // Initialize the buffer before the computation starts
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
      }

      // Update the buffer with each input row
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }

      // Merge the buffers coming from multiple nodes
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        // sum
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        // count
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }

      // Compute the final result
      override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble / buffer.getLong(1)
      }
    }
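    Once registered, the function can also be invoked through the DataFrame API instead of a SQL string. A minimal sketch, assuming it runs inside main after the spark.udf.register call above; callUDF looks the function up by its registered name:

    import org.apache.spark.sql.functions.callUDF

    // Same aggregation as the SQL query, expressed with the DataFrame API.
    // "avgAge" is the name the UDAF was registered under.
    frame.select(callUDF("avgAge", $"age").as("avgAge")).show()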

    Strongly typed user-defined UDAF function

    Extend the Aggregator class.

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql._

    object SparkSQL06_UDAF_Class {
      def main(args: Array[String]): Unit = {
        // Create the configuration object
        val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL01_Demo")
        val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

        // Import the implicit conversion rules before performing any conversions.
        // "spark" here is not a package name but the name of the SparkSession object.
        import spark.implicits._

        // Create the aggregate function object
        val udaf = new MyAgeAvgClassFunction
        // Turn the aggregate function into a column that can be queried
        val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")

        val frame: DataFrame = spark.read.json("in/user.json")
        val userDS: Dataset[UserBean] = frame.as[UserBean]

        // Apply the function
        userDS.select(avgCol).show()

        // Release resources
        spark.stop()
      }
    }

    case class UserBean(name: String, age: BigInt)
    case class AvgBuffer(var sum: BigInt, var count: Int)

    // TODO Declare the user-defined aggregate function (strongly typed):
    // 1) extend Aggregator and fix its type parameters
    // 2) implement its methods
    class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {
      // Initial (zero) value of the buffer
      override def zero: AvgBuffer = AvgBuffer(0, 0)

      // Fold one input record into the buffer
      override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
        b.sum = b.sum + a.age
        b.count = b.count + 1
        b
      }

      // Merge two partial buffers
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
        b1
      }

      // Compute the final result from the buffer
      override def finish(reduction: AvgBuffer): Double = {
        reduction.sum.toDouble / reduction.count
      }

      // Encoder for the intermediate (buffer) type
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

      // Encoder for the final output type
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
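    Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0; the recommended replacement is to wrap an Aggregator with functions.udaf, which also makes a strongly typed function callable from SQL. Below is a minimal sketch assuming Spark 3.0 or later; MyAgeAvgLong and the registered name avgAge2 are illustrative names, and the input type is Long so the function can take a single column directly (the UserBean-typed version above consumes a whole object per row):

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders, functions}

    // A Long-input variant of the same average, reusing AvgBuffer from above.
    object MyAgeAvgLong extends Aggregator[Long, AvgBuffer, Double] {
      override def zero: AvgBuffer = AvgBuffer(0, 0)
      override def reduce(b: AvgBuffer, a: Long): AvgBuffer = {
        b.sum = b.sum + a
        b.count += 1
        b
      }
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count += b2.count
        b1
      }
      override def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }

    // Inside main: register the wrapped Aggregator and call it from SQL,
    // exactly like the weakly typed version.
    spark.udf.register("avgAge2", functions.udaf(MyAgeAvgLong))
    frame.createOrReplaceTempView("user")
    spark.sql("select avgAge2(age) from user").show()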

    The data file used by the code, in/user.json:

    {"name":"123","age":20}
    {"name":"456","age":30}
    {"name":"789","age":40}

    Running the code produces the result below; (20 + 30 + 40) / 3 = 30.0, as expected.

    +------+
    |avgAge|
    +------+
    |  30.0|
    +------+