UDAF Functions in Spark
— 焉知非鱼

This morning's post introduced Spark's user-defined aggregate functions, but that was essentially the example from the official documentation, using Spark SQL. This afternoon we look at a slightly more involved example: this time we call a custom aggregate function inside agg.
Our program reads JSON data in which vin is the vehicle identification number; timestamp is the time the signal was emitted, with one signal per second; and PtcConsmp is a Long value.
{"vin":"KLSAP5S09K1214720","timestamp":1543766400000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766401000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766402000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766403000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766405000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766406000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766407000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766408000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766409000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766410000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766411000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766412000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766424000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766425000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766426000,"PtcConsmp":1650}
...
We define the rules as follows:
- If, within a continuous stretch of time, the value of PtcConsmp changes, count it as one PTC switch.
- If a timestamp goes missing (the one-second cadence is broken), also count it as one PTC switch.
The task is to compute, for each vin, the number of PtcConsmp switches.
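Before turning to Spark, it may help to see the counting rule in plain Scala. The sketch below is for illustration only; countSwitches is a hypothetical helper, not part of the final program. Like the UDAF we are about to write, it starts from a previous timestamp of 0, so the very first row also registers as a switch:

// A minimal sketch of the switch-counting rule over (timestamp, ptc) pairs.
// countSwitches is a hypothetical helper used only for illustration.
def countSwitches(rows: Seq[(Long, Long)]): Long =
  rows.foldLeft((0L, 0L, 0L)) { case ((sum, lastTime, lastPtc), (time, ptc)) =>
    if (time - lastTime > 1000) (sum + 1, time, ptc) // time gap: one switch
    else if (ptc != lastPtc)    (sum + 1, time, ptc) // value change: one switch
    else                        (sum, time, ptc)
  }._1

On the fifteen rows shown above this returns 4: one for the first row, one for the missing second before 1543766405000, one for the 5000 → 850 change, and one for the gap before 1543766424000.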
We implement this with a Spark UDAF. First we define a PTCPwrAgg object that extends UserDefinedAggregateFunction, which means PTCPwrAgg must implement the following members:
- inputSchema: the schema of the input data
- bufferSchema: the schema of the intermediate aggregation state
- dataType: the type of the returned value
- deterministic: whether the same input always yields the same result (true here)
- initialize: initializes the aggregation buffer
- update: updates the buffer with a new input row
- merge: merges two aggregation buffers
- evaluate: computes the final result
package ohmysummer
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
object PTCPwrAgg extends UserDefinedAggregateFunction {
  // The input columns: the signal timestamp and the PTC consumption value
  def inputSchema: StructType = StructType(StructField("timestamp", LongType) :: StructField("PtcConsmp", LongType) :: Nil)

  // The schema of the aggregation buffer: sum is the switch count so far,
  // lastTime the previous timestamp, lastPtc the previous PTC value
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("lastTime", LongType) :: StructField("lastPtc", LongType) :: Nil)
  }

  // The type of the returned value
  def dataType: DataType = LongType

  // The same input always produces the same result
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // the switch count starts at 0
    buffer(1) = 0L // previous timestamp
    buffer(2) = 0L // previous PTC value
  }

  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val newTime = input.getLong(0)      // signal timestamp
      val newPtcConsmp = input.getLong(1) // PTC value
      if (newTime - buffer.getLong(1) > 1000) {
        buffer(0) = buffer.getLong(0) + 1 // gap in the timestamps: one PTC switch
      } else if (newPtcConsmp != buffer.getLong(2)) {
        buffer(0) = buffer.getLong(0) + 1 // PTC value changed: one PTC switch
      }
      buffer(1) = newTime       // remember the current timestamp
      buffer(2) = newPtcConsmp  // remember the current PTC value
    }
  }

  // Merges two aggregation buffers. Note that only the counts are combined;
  // lastTime and lastPtc are not reconciled, so switches at a partition
  // boundary are only approximated. This is a simplification.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // total switch count
  }

  // Calculates the final result
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
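As a side note, a UserDefinedAggregateFunction can also be registered for use from Spark SQL. The sketch below is an illustration, not part of the original program: the function name ptc_pwr_agg, the view name signals, and the file path are all made up.

// Hypothetical: register the UDAF under a SQL-callable name.
spark.udf.register("ptc_pwr_agg", PTCPwrAgg)

// Hypothetical temp view over the same kind of JSON data; the path is a placeholder.
spark.read.json("/path/to/signals.json").createOrReplaceTempView("signals")

spark.sql(
  "SELECT vin, ptc_pwr_agg(timestamp, PtcConsmp) AS ptcCount FROM signals GROUP BY vin"
).show()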
In PTCStartCount.scala we first read the JSON file into a DataFrame and then group it by vin, since we want a switch count per vehicle; the custom aggregate is then applied to the grouped data through agg.
package ohmysummer
import org.apache.spark.sql.SparkSession
object PTCStartCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    import spark.implicits._

    // Read the JSON signals, group by vin, and apply the custom aggregate
    val df = spark.read.json("/Users/ohmycloud/work/cihon/3to2.json")
      .groupBy("vin")
      .agg(PTCPwrAgg($"timestamp", $"PtcConsmp") as "ptcCount")

    df.show()
    // +-----------------+--------+
    // |              vin|ptcCount|
    // +-----------------+--------+
    // |KLSAP5S09K1214720|       4|
    // +-----------------+--------+
  }
}
If the process is still not clear, add print statements to the update and merge functions and watch the intermediate output; seeing the buffer evolve row by row makes the mechanics much easier to follow. Below is a sketch of what such statements might look like; the exact format is just a suggestion:
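// Hypothetical trace line for update, placed inside the null check:
println(s"update: sum=${buffer.getLong(0)} lastTime=${buffer.getLong(1)} " +
        s"lastPtc=${buffer.getLong(2)} input=(${input.getLong(0)}, ${input.getLong(1)})")

// Hypothetical trace line for merge, placed before the counts are combined:
println(s"merge: left sum=${buffer1.getLong(0)}, right sum=${buffer2.getLong(0)}")

With master("local[*]") the output shows up in the driver console; on a real cluster, executor-side println output goes to the executor logs instead.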