irpas技术客

【回顾】SparkSQL 之 用户自定义函数_Riding the snail chase missiles ~

irpas 3544


文章目录 1、UDF1) 创建 DataFrame2) 注册 UDF3) 创建临时表4) 应用 UDF 2、UDAF1) 实现方式 - RDD2) 实现方式 - 累加器3) 实现方式 - UDAF - 弱类型4) 实现方式 - UDAF - 强类型Spark 3.0早期版本


1、UDF

UDF(User Defined Function):spark SQL中用户自定义函数,用法和spark SQL中的内置函数类似;是saprk SQL中内置函数无法满足要求,用户根据业务需求自定义的函数

基本使用步骤如下:


1) 创建 DataFrame scala> val df = spark.read.json("/home/data/spark/user.json") df: org.apache.spark.sql.DataFrame = [age: bigint, username: string]
2) 注册 UDF // 自定义udf函数:添加说明词,并注册 scala> spark.udf.register("addName",(x:String)=> "Name:"+x) res9: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
3) 创建临时表 scala> df.createOrReplaceTempView("people")
4) 应用 UDF // 使用的时候直接调用注册名传入参数即可! scala> spark.sql("Select addName(username),age from people").show() +---------------------+---+ |UDF:addName(username)|age| +---------------------+---+ | Namezhangsan| 20| | Namelisi| 30| | Namewangwu| 40| +---------------------+---+ scala> spark.sql("select addName(username) as newName,age from people").show +------------+---+ | newName|age| +------------+---+ |Namezhangsan| 20| | Namelisi| 30| | Namewangwu| 40| +------------+---+

注意:当spark-shell重新的启动的时候需要重新注册UDF函数,因为此时的SparkSession重新创建了,是新的入口。org.apache.spark.sql.AnalysisException: Undefined function: 'addName'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.

返回顶部


2、UDAF

强类型的 Dataset 和 弱类型的 DataFrame 都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。

除此之外,用户可以设定自己的自定义聚合函数。

通过继承 UserDefinedAggregateFunction 来实现用户自定义弱类型聚合函数。从 Spark3.0 版本后,UserDefinedAggregateFunction 已经不推荐使用了,可以统一采用强类型聚合函数 Aggregator。

需求:计算平均工资


1) 实现方式 - RDD val conf: SparkConf = new SparkConf().setAppName("app").setMaster("local[*]") val sc: SparkContext = new SparkContext(conf) val res: (Int, Int) = sc.makeRDD(List(("zhangsan", 20), ("lisi", 30),("wangw", 40))) .map { case (name, age) => { (age, 1) } } .reduce { (t1, t2) => { (t1._1 + t2._1, t1._2 + t2._2) } } println(res._1/res._2) // 关闭连接 sc.stop() 30

返回顶部


2) 实现方式 - 累加器 class MyAC extends AccumulatorV2[Int,Int]{ var sum:Int = 0 var count:Int = 0 override def isZero: Boolean = { return sum ==0 && count == 0 } override def copy(): AccumulatorV2[Int, Int] = { val newMyAc = new MyAC newMyAc.sum = this.sum newMyAc.count = this.count newMyAc } override def reset(): Unit = { sum =0 count = 0 } // 求和、计数 override def add(v: Int): Unit = { sum += v count += 1 } // 聚合 override def merge(other: AccumulatorV2[Int, Int]): Unit = { other match { case o:MyAC => { sum += o.sum count += o.count } case _ => {} } } // 计算结果 override def value: Int = sum/count }

返回顶部


3) 实现方式 - UDAF - 弱类型

自定义avgUDF函数类,继承UserDefinedAggregateFunction ,并重写方法

// 输入的数据的结构 override def inputSchema: StructType // 缓冲区的数据结构 override def bufferSchema: StructType // 函数计算结果的数据类型 override def dataType: DataType = LongType // 函数的稳定性,传入传出数据的类型保持一致 override def deterministic: Boolean = true // 缓冲区初始化 override def initialize(buffer: MutableAggregationBuffer): Unit // 根据输入的值跟新缓冲区 override def update(buffer: MutableAggregationBuffer, input: Row): Unit // 缓冲区数据合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit // 计算平均值 override def evaluate(buffer: Row): Any // 自定义聚合函数类计算年龄平均值 class avgUDF extends UserDefinedAggregateFunction { // 输入的数据的结构 override def inputSchema: StructType = { StructType(Array(StructField("age",LongType))) } // 缓冲区的数据结构 override def bufferSchema: StructType = { StructType( Array( StructField("total",LongType), StructField("count",LongType) ) ) } // 函数计算结果的数据类型 override def dataType: DataType = LongType // 函数的稳定性,传入传出数据的类型保持一致 override def deterministic: Boolean = true // 缓冲区初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { /*buffer(0) = 0L buffer(1) = 0L*/ buffer.update(0,0L) buffer.update(1,0L) } // 根据输入的值跟新缓冲区 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { // 缓冲区例的 total + 传进来的 age buffer.update(0,buffer.getLong(0)+input.getLong(0)) // count + 1 buffer.update(1,buffer.getLong(1)+1) } // 缓冲区数据合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0)) buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1)) } // 计算平均值 override def evaluate(buffer: Row): Any = { buffer.getLong(0)/buffer.getLong(1) } } // 创建视图 df.createOrReplaceTempView("people") // 调用udf函数 spark.sql("select avgAge(age) from people").show() +-----------+ |avgudf(age)| +-----------+ | 23| +-----------+

返回顶部


4) 实现方式 - UDAF - 强类型 Spark 3.0

Spark 3.0 中 使用Aggregator替代了原来的 UserDefinedAggregateFunction,具体使用如下:

自定义聚合类继承 Aggregator => 定义泛型 // 样例类 case class Buff( var sum:Long, var cnt:Long ) // 自定义聚合类 class MyAvgAgeUDAF extends Aggregator[Long, Buff, Double]{ // 初始值 override def zero: Buff = Buff(0,0) // 根据输入数据跟新缓冲区的数据 override def reduce(b: Buff, a: Long): Buff = { b.sum += a b.cnt += 1 b } // 聚合:合并缓冲区 override def merge(b1: Buff, b2: Buff): Buff = { b1.sum += b2.sum b1.cnt += b2.cnt b1 } // 计算结果 override def finish(reduction: Buff): Double = { reduction.sum.toDouble/reduction.cnt } // 网络传输缓冲区编码 自定义的类型就选product override def bufferEncoder: Encoder[Buff] = Encoders.product // 网络传输缓冲区解码 spark原有的就选相应的 override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } // TODO 创建 UDAF 函数 val udaf = new MyAvgAgeUDAF // TODO 注册到 SparkSQL 中 spark.udf.register("avgAge", functions.udaf(new MyAvgAgeUDAF())) // TODO 在 SQL 中使用聚合函数 spark.sql("select avgAge(age) from people").show() +-----------+ |avgudf(age)| +-----------+ | 23| +-----------+

那么如果在早期版本中使用强类型的UDAF,该怎样使用呢?

返回顶部


早期版本

早期版本在使用 Aggregator 的时候基本步骤不变定义泛型的时候指定为User类型。

// 样例类 case class User(username:String, age:Long) case class Buff( var total:Long, var count:Long ) class MyAvgAgeUDAF extends Aggregator[User, Buff, Long]{ // z & zero : 初始值或零值 // 缓冲区的初始化 override def zero: Buff = { Buff(0L,0L) } // 根据输入的数据更新缓冲区的数据 override def reduce(buff: Buff, in: User): Buff = { buff.total = buff.total + in.age buff.count = buff.count + 1 buff } // 合并缓冲区 override def merge(buff1: Buff, buff2: Buff): Buff = { buff1.total = buff1.total + buff2.total buff1.count = buff1.count + buff2.count buff1 } //计算结果 override def finish(buff: Buff): Long = { buff.total / buff.count } // 缓冲区的编码操作 override def bufferEncoder: Encoder[Buff] = Encoders.product // 输出的编码操作 override def outputEncoder: Encoder[Long] = Encoders.scalaLong }

使用UDFA函数的时候操作有所不同,需要使用DSL语法操作

将udfa函数转为查询列的对象,进行查询

// TODO:早期版本的强类型 UDAF 使用DSL语法操作 // 读取数据 val df = spark.read.json("data/user.json") val ds = df.as[User] // 将udfa函数转换为查询的列对象 val udafColumn = new MyAvgAgeUDAF().toColumn ds.select(udafColumn).show() +------------------------------------------------------+ |MyAvgAgeUDAF(test02_UDF.Spark02_sql_UDF_avgAge04$User)| +------------------------------------------------------+ | 23| +------------------------------------------------------+

返回顶部



1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,会注明原创字样,如未注明都非原创,如有侵权请联系删除!;3.作者投稿可能会经我们编辑修改或补充;4.本站不提供任何储存功能只提供收集或者投稿人的网盘链接。

标签: #回顾SparkSQL # #用户自定义函数 #文章目录1UDF1 #创建 #DataFrame2 #注册 #UDF3