SparkSql笔记

    技术2022-07-14  78

    文章目录

    3种结构的关系自定义函数UDAF-弱类型UDAF-强类型SparkSQL通用的读取SparkSQL通用的保存CSVMySQL读数据写数据 Hive本地hive操作hive外连接案例:造表 导入数据需求:各区域热门商品 Top3

    3种结构的关系

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql01_Test { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() //导入隐式转换,这里的spark其实是环境对象的名称 //要求这个对象使用val声明 import spark.implicits._ //最好用不用上都加上 //逻辑操作 val jsonDF: DataFrame = spark.read.json("input/user.json") //SQL //将df转换为临时视图 jsonDF.createOrReplaceTempView("user") spark.sql("select * from user").show() //DSL //如果查询列名采用单引号,那么需要隐式转换 jsonDF.select("name", "age").show jsonDF.select($"name", $"age").show jsonDF.select('name, 'age).show val rdd = spark.sparkContext.makeRDD(List( (1, "zhangsan", 30), (2, "lisi", 20), (3, "wangwu", 40), )) //RDD<=>DataFrame val df: DataFrame = rdd.toDF("id", "name", "age") val dfToRDD1: RDD[Row] = df.rdd dfToRDD1.foreach( row=>{ println(row(0)) }) //RDD<=>DataSet val userRDD: RDD[User] = rdd.map { case (id, name, age) => { User(id, name, age) } } val userDS: Dataset[User] = userRDD.toDS() val dsToRDD: RDD[User] = userDS.rdd //DataFram <=>DataSet val dsToDS: Dataset[User] = df.as[User] //type DataFrame = Dataset[Row] Dataset就是特殊类型的DataFrame val dsToDF: DataFrame = dsToDS.toDF() rdd.foreach(println) df.show() userDS.show() //释放对象 spark.stop() } case class User(id: Int, name: String, age: Int) }

    自定义函数

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql02_Test { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() //导入隐式转换,这里的spark其实是环境对象的名称 //要求这个对象使用val声明 import spark.implicits._ //最好用不用上都加上 //逻辑操作 val rdd = spark.sparkContext.makeRDD(List( (1, "zhangsan", 30), (2, "lisi", 20), (3, "wangwu", 40), )) //RDD<=>DataSet // val userRDD: RDD[User] = rdd.map { // case (id, name, age) => { // User(id, name, age) // } // } //val userDS: Dataset[User] = userRDD.toDS() //sparkSql封装的对象提供了大量的方法进行处理,类似于RDD的算子操作 //userDS.join() //error //val df: DataFrame = rdd.toDF("id", "name", "age") // val ds: Dataset[Row] = df.map(row => { // val id: Any = row(0) // val name: Any = row(1) // val age: Any = row(3) // Row(id, "name" + name, age) // }) val userRDD: RDD[User] = rdd.map { case (id, name, age) => { User(id, name, age) } } val userDS: Dataset[User] = userRDD.toDS() val newDS: Dataset[User] = userDS.map(user => { User(user.id, "name:" + user.name, user.age) }) newDS.show() //使用自定义函数在SQL中完成数据的转换操作 val df = rdd.toDF("id", "name", "age") df.createOrReplaceTempView("user") spark.udf.register("addName", (x: String) => "Name:" + x) spark.udf.register("changeAge", (x: Int) => 18) spark.sql("select addName(name),changeAge(age) from user").show spark.stop() } case class User(id: Int, name: String, age: Int) }

    UDAF-弱类型

    (用户定义聚合函数)

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType} import org.apache.spark.sql.{Row, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql03_UDAF { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() import spark.implicits._ //最好用不用上都加上 val rdd = spark.sparkContext.makeRDD(List( (1, "zhangsan", 30L), (2, "lisi", 20L), (3, "wangwu", 40L), )) val df = rdd.toDF("id", "name", "age") df.createOrReplaceTempView("user") //创建UDAF函数 val udaf = new MyAvgAgeUDAF //注册到SparkSQL中 spark.udf.register("avgAge",udaf) //在SQL中使用聚合函数 //定义用户的自定义函数 spark.sql("select avgAge(age) from user").show spark.stop() } //自定义聚合函数 //1.继承UserDefinedAggregateFunction //2.重写方法 //totalage,count class MyAvgAgeUDAF extends UserDefinedAggregateFunction { //输入数据的结构信息:年龄信息 override def inputSchema: StructType = { StructType(Array(StructField("age", LongType))) } //缓冲区的数据结构信息:年龄的总和,人的数量 override def bufferSchema: StructType = { StructType(Array( StructField("totalage", LongType), StructField("count", LongType) )) } //聚合函数返回的结果类型 override def dataType: DataType = LongType //函数稳定性 override def deterministic: Boolean = true //函数缓冲区初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L buffer(1) = 0L } //更新缓冲区数据, override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { buffer(0) = buffer.getLong(0) + input.getLong(0) buffer(1) = buffer.getLong(1) + 1 } //合并缓冲区 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) } //函数的计算 override def evaluate(buffer: Row): Any = { buffer.getLong(0) / buffer.getLong(1) } } }

    UDAF-强类型

    自定义聚合函数 - 强类型

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql04_UDAF_Class { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() import spark.implicits._ //最好用不用上都加上 val rdd = spark.sparkContext.makeRDD(List( (1, "zhangsan", 30L), (2, "lisi", 20L), (3, "wangwu", 40L), )) val df = rdd.toDF("id", "name", "age") val ds: Dataset[User] = df.as[User] //创建UDAF函数 val udaf = new MyAvgAgeUDAFClass //在SQL中使用聚合函数 //因为聚合函数是强类型,那么sql中没有类型的概念,所以无法使用 //可以采用DSL语法进行访问 //将聚合函数转换为查询的列让DataSet访问 ds.select(udaf.toColumn).show spark.stop() } case class User(id: Int, name: String, age: Long) case class AvgBuffer(var totalage: Long, var count: Long) //自定义聚合函数 - 强类型 //1.继承Aggregator,敌营泛型 //IN:输入数据的类型User // BUF:缓冲区的数据类型 AvgBuffer //OUT:输出的数据类型Long //2.重写方法 class MyAvgAgeUDAFClass extends Aggregator[User, AvgBuffer, Long] { //缓冲区的初始值 override def zero: AvgBuffer = { AvgBuffer(0L, 0L) } //聚合数据 override def reduce(buffer: AvgBuffer, user: User): AvgBuffer = { buffer.totalage = buffer.totalage + user.age buffer.count = buffer.count + 1 buffer } //合并缓冲区 override def merge(buffer1: AvgBuffer, buffer2: AvgBuffer): AvgBuffer = { buffer1.totalage = buffer1.totalage + buffer2.totalage buffer1.count = buffer1.count + buffer2.count buffer1 } //计算函数结果 override def finish(reduction: AvgBuffer): Long = { reduction.totalage / reduction.count } //编解码器,用于序列化 固定写法 override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product override def outputEncoder: Encoder[Long] = Encoders.scalaLong } }

    SparkSQL通用的读取

    user.json

    格式不符合 json要求,符合spark要求,不能有“,”

    {"name": "zhangsan","age": "20"} {"name": "lisi","age": "30"} {"name": "wangwu","age": "40"} package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql05_LoadSave { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() //sparkSQL通用的读取和保存 //通用的读取 //RuntimeException: file:xxxx/input/user.json is not a Parquet file. //SparkSQL通用读取的数据格式为Parquet列式存储格式 //val frame: DataFrame = spark.read.load("input/user.json") //如果想要改变读取文件的格式,需要使用特殊的操作 //如果读取的文件格式为JSON格式,Spark对JSON文件的格式有要求 //JSON => JavaScrip Object Notation //JSON文件的格式要求整个文件满足JSON的语法规则 //Spark读取文件默认是以行为单位来读取 //Spark读取JSON文件时,要求文件中的每一行符合JSON的格式要求 //如果文件格式不正确,那么不会发生错误,但是解析结果不正确 val frame: DataFrame = spark.read.format("json").load("input/user.json") //通用的 //spark.read.json() //特殊的 frame.show() spark.stop() } }

    另一种读取方式更简单

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql07_LoadSave { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() spark.sql("select * from json.`input/user.json`").show() spark.stop() } }

    SparkSQL通用的保存

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql06_LoadSave { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() //sparkSQL通用的读取和保存 //通用的保存 val df = spark.read.format("json").load("input/user.json") //sparksql默认通用保存的文件格式为parquet //如果想要保存的格式是指定的格式,比如json,那么需要进行对应的格式化操作 //如果路径已经存在,那么执行保存操作会发生错误 df.write.format("json").save("output1") //如果非得想要路径已经存在的情况下,保存数据,那么可以使用保存模式 //df.write.mode("overwrite").format("json").save("output") df.write.mode("append").format("json").save("output") spark.stop() } }

    没有“,”,默认字典序

    CSV

    第一行写数据类型

    name;age zhangsan;30 wangwu;40 lisi;20 package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql08_Load_CSV { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() val frame: DataFrame = spark.read.format("csv") .option("sep", ";") .option("inferSchema", "true") .option("header", "true") .load("input/user.csv") frame.show() spark.stop() } }

    MySQL

    添加依赖

    <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>5.1.27</version> </dependency>

    读数据

    通用方法

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql09_Load_MySQL { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() spark.read.format("jdbc") .option("url", "jdbc:mysql://hadoop130:3306/spark-sql") .option("driver", "com.mysql.jdbc.Driver") .option("user", "root") .option("password", "123456") .option("dbtable", "user") .load().show spark.stop() } }

    写数据

    通用方法

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql10_Save_MySQL { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 val spark = SparkSession.builder().config(sparkConf).getOrCreate() val frame: DataFrame = spark.read.format("jdbc") .option("url", "jdbc:mysql://hadoop130:3306/spark-sql") .option("driver", "com.mysql.jdbc.Driver") .option("user", "root") .option("password", "123456") .option("dbtable", "user") .load() frame.write.format("jdbc") .option("url", "jdbc:mysql://hadoop130:3306/spark-sql") .option("driver", "com.mysql.jdbc.Driver") .option("user", "root") .option("password", "123456") .option("dbtable", "user1") .mode(SaveMode.Append) //可以选定模式在原表中追加 .save() spark.stop() } }

    Hive

    添加依赖

    <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-hive_2.12</artifactId> <version>2.4.5</version> </dependency> <dependency> <groupId>org.apache.hive</groupId> <artifactId>hive-exec</artifactId> <version>3.1.2</version> </dependency>

    本地hive操作

    spark内嵌hive

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql11_Load_Hive { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 //默认情况下SparkSQL支持本地Hive操作的,执行前需要启用Hive的支持 //调用enableHiveSupport方法 val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate() //可以使用基本的sql访问hive中的内容 spark.sql("create table aa(id int)") spark.sql("show tables").show() spark.sql("load data local inpath'input/id.txt' into table aa") spark.sql("select * from aa").show spark.stop() } }

    hive外连接

    添加resource文件

    hive-site.xml

    注意取消tez的配置 与spark冲突

    <?xml version="1.0"?> <?xml-stylesheet type="text/xsl" href="configuration.xsl"?> <configuration> <!-- jdbc连接的URL --> <property> <name>javax.jdo.option.ConnectionURL</name> <value>jdbc:mysql://hadoop130:3306/metastore?useSSL=false</value> </property> <!-- jdbc连接的Driver--> <property> <name>javax.jdo.option.ConnectionDriverName</name> <value>com.mysql.jdbc.Driver</value> </property> <!-- jdbc连接的username--> <property> <name>javax.jdo.option.ConnectionUserName</name> <value>root</value> </property> <property> <name>javax.jdo.option.ConnectionPassword</name> <value>123456</value> </property> <!-- Hive默认在HDFS的工作目录 --> <property> <name>hive.metastore.warehouse.dir</name> <value>/user/hive/warehouse</value> </property> <!-- Hive元数据存储版本的验证 --> <property> <name>hive.metastore.schema.verification</name> <value>false</value> </property> <!-- 指定存储元数据要连接的地址 --> <property> <name>hive.metastore.uris</name> <value>thrift://hadoop130:9083</value> </property> <!-- 指定hiveserver2连接的端口号 --> <property> <name>hive.server2.thrift.port</name> <value>10000</value> </property> <!-- 指定hiveserver2连接的host --> <property> <name>hive.server2.thrift.bind.host</name> <value>hadoop130</value> </property> <!-- 元数据存储授权 --> <property> <name>hive.metastore.event.db.notification.api.auth</name> <value>false</value> </property> <property> <name>hive.cli.print.header</name> <value>true</value> <description>Whether to print the names of the columns in query output.</description> </property> <property> <name>hive.cli.print.current.db</name> <value>true</value> <description>Whether to include the current database in the Hive prompt.</description> </property> </configuration> package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql12_Load_Hive { def main(args: Array[String]): Unit = { //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //builder构建,创建 //访问外置的hive val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate() //可以使用基本的sql访问hive中的内容 spark.sql("show databases").show() spark.stop() } }

    案例:

    地区商品名称点击次数城市备注华北商品A100000北京21.2%,天津13.2%,其他65.6%华北商品P80200北京63.0%,太原10%,其他27.0%华北商品M40000北京63.0%,太原10%,其他27.0%东北商品J92000大连28%,辽宁17.0%,其他 55.0%
    造表 导入数据
    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql13_Req_Mock { def main(args: Array[String]): Unit = { System.setProperty("HADOOP_USER_NAME", "vanas") //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //访问外置的Hive val spark = SparkSession.builder() .enableHiveSupport() .config(sparkConf).getOrCreate() spark.sql("use bigdata0213") spark.sql( """ |CREATE TABLE `user_visit_action`( | `date` string, | `user_id` bigint, | `session_id` string, | `page_id` bigint, | `action_time` string, | `search_keyword` string, | `click_category_id` bigint, | `click_product_id` bigint, | `order_category_ids` string, | `order_product_ids` string, | `pay_category_ids` string, | `pay_product_ids` string, | `city_id` bigint) |row format delimited fields terminated by '\t' |""".stripMargin) spark.sql( """ |load data local inpath 'input1/user_visit_action.txt' into table bigdata0213.user_visit_action |""".stripMargin) spark.sql( """ |CREATE TABLE `product_info`( | `product_id` bigint, | `product_name` string, | `extend_info` string) |row format delimited fields terminated by '\t' |""".stripMargin).show spark.sql( """ |load data local inpath 'input1/product_info.txt' into table bigdata0213.product_info |""".stripMargin) spark.sql( """ |CREATE TABLE `city_info`( | `city_id` bigint, | `city_name` string, | `area` string) |row format delimited fields terminated by '\t' |""".stripMargin) spark.sql( """ |load data local inpath 'input1/city_info.txt' into table bigdata0213.city_info |""".stripMargin) spark.sql( """ |select * from city_info |""".stripMargin).show(10) spark.stop() } }
    需求:各区域热门商品 Top3
    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql14_Req { def main(args: Array[String]): Unit = { System.setProperty("HADOOP_USER_NAME", "vanas") //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //访问外置的Hive val spark = SparkSession.builder() .enableHiveSupport() .config(sparkConf).getOrCreate() spark.sql("use bigdata0213") // spark.sql( // """ // |select // | * // |from ( // | select // | *, // | rank() over( partition by area order by clickCount desc ) as rank // | from ( // | select // | area, // | product_name, // | count(*) as clickCount // | from ( // | select // | a.*, // | c.area, // | p.product_name // | from user_visit_action a // | join city_info c on c.city_id = a.city_id // | join product_info p on p.product_id = a.click_product_id // | where a.click_product_id > -1 // | ) t1 group by area, product_name // | ) t2 // |) t3 // |where rank <= 3 // """.stripMargin).show spark.sql( """ |select * |from( |select *, |rank() over(distribute by area order by sum_click desc) rank |from( |select area ,product_name,count(click_product_id) sum_click |from user_visit_action a |join city_info c on a.city_id = c.city_id |join product_info p on p.product_id = a.click_product_id |where click_product_id > -1 |group by area ,product_name |)t1 |)t2 |where rank <=3 |""".stripMargin).show() spark.stop() } }

    这里的热门商品是从点击量的维度来看的,计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示

    package com.vanas.bigdata.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SparkSession} /** * @author Vanas * @create 2020-06-10 4:32 下午 */ object SparkSql15_Req { def main(args: Array[String]): Unit = { System.setProperty("HADOOP_USER_NAME", "vanas") //创建环境对象 val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL") //访问外置的Hive val spark = SparkSession.builder() .enableHiveSupport() .config(sparkConf).getOrCreate() spark.sql("use bigdata0213") //创建自定义聚合函数 val udaf = new CityRemarkUDAF //注册聚合函数 spark.udf.register("cityReamark", udaf) //从hive表中/获取满足条件的数据 //将数据根据区域进行分组,统计商品点击的数量 spark.sql( """ |select area ,product_name,count(click_product_id) sum_click,cityReamark(city_name) |from user_visit_action a |join city_info c on a.city_id = c.city_id |join product_info p on p.product_id = a.click_product_id |where click_product_id > -1 |group by area ,product_name |""".stripMargin).createOrReplaceTempView("t1") //将统计结果数量进行排序(降序) spark.sql( """ |select *, |rank() over(distribute by area order by sum_click desc) rank |from t1 |""".stripMargin).createOrReplaceTempView("t2") //将组内排序后的结果取前三名 spark.sql( """ |select * |from t2 |where rank <=3 |""".stripMargin).show() spark.stop() } //北京,上海,北京,深圳 //in:cityname:String //out:remark:String //buffer :2结构,(total,map) //(商品点击总和,每个城市点击总和) //(商品点击总和,Map(城市,点击sum)) //城市点击sum/商品点击总和% //自定义城市备注聚合函数 class CityRemarkUDAF extends UserDefinedAggregateFunction { //输入的数据其实就是城市名称 override def inputSchema: StructType = { StructType(Array(StructField("cityName", StringType))) } //缓冲区中的数据应该为:totalcnt,Map[cityname,cnt] override def bufferSchema: StructType = { StructType(Array( StructField("cityName", LongType), StructField("cityMap", MapType(StringType, LongType)) )) } //返回城市备注的字符串 override def dataType: DataType = StringType override def deterministic: Boolean = true //缓冲区的初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L //buffer.update(0,0L) buffer(1) = Map[String, Long]() } //更新缓冲区 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { val cityName: String = input.getString(0) //点击总和需要增加 buffer(0) = buffer.getLong(0) + 1 //城市点击增加 val cityMap: Map[String, Long] = buffer.getAs[Map[String, Long]](1) val newClickCount = cityMap.getOrElse(cityName, 0L) + 1 buffer(1) = cityMap.updated(cityName, newClickCount) } //合并缓冲区 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { //合并点击数量总和 buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) //合并城市点击map val map1 = buffer1.getAs[Map[String, Long]](1) val map2 = buffer2.getAs[Map[String, Long]](1) buffer1(1) = map1.foldLeft(map2) { case (map, (k, v)) => { map.updated(k, map.getOrElse(k, 0L) + v) } } } //对缓冲区进行计算并返回备注信息 override def evaluate(buffer: Row): Any = { val totalcnt: Long = buffer.getLong(0) val citymap: collection.Map[String, Long] = buffer.getMap[String, Long](1) val cityToCountList: List[(String, Long)] = citymap.toList.sortWith( (left, right) => left._2 > right._2 ).take(2) //val hasRest = citymap.size > 2 var rest = 0L val s = new StringBuilder cityToCountList.foreach { case (city, cnt) => { val r = (cnt * 100 / totalcnt) s.append(city + " " + r + "%,") rest = rest + r } } s.toString() + "其他" + (100 - rest) + "%" // if (hasRest) { // s.toString() + "其他" + (100 - rest) + "%" // } else { // toString // } } } }
    Processed: 0.017, SQL: 9