UDAF Functions in Spark

This morning I introduced Spark's user-defined aggregate functions using the example from the official documentation, which goes through Spark SQL. This afternoon let's look at a slightly more involved example, this time calling a custom aggregate function inside agg.

Our program reads JSON data in which vin is the vehicle identification number, timestamp is the time the signal was generated (one signal per second), and PtcConsmp is a Long value.

{"vin":"KLSAP5S09K1214720","timestamp":1543766400000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766401000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766402000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766403000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766405000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766406000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766407000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766408000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766409000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766410000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766411000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766412000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766424000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766425000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766426000,"PtcConsmp":1650}
...

The rules are:

  • If PtcConsmp changes value from one record to the next within a continuous run of signals, count it as one PTC switch.
  • If a second is missing from the timestamp sequence (a gap in the signal), also count it as one PTC switch.

The task is to compute the number of PtcConsmp switches for each vin.
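Before involving Spark at all, the counting rule itself can be sketched in plain Scala: slide a window of two over consecutive (timestamp, PtcConsmp) pairs and count every pair that shows a time gap or a value change. The rows below are a subset of the sample data above; this is only an illustration of the rule, not the UDAF implementation.

```scala
// Consecutive (timestamp, PtcConsmp) pairs from the sample data.
val records: Seq[(Long, Long)] = Seq(
  (1543766400000L, 5000L), (1543766401000L, 5000L), (1543766402000L, 5000L),
  (1543766403000L, 5000L), (1543766405000L, 5000L), // 1543766404000 is missing
  (1543766406000L, 850L),  (1543766407000L, 850L),  (1543766408000L, 850L),
  (1543766409000L, 850L),  (1543766410000L, 850L),  (1543766411000L, 850L),
  (1543766412000L, 850L),
  (1543766424000L, 1650L), (1543766425000L, 1650L)  // 12-second gap before these
)

// A pair counts as one switch if the timestamps are more than 1 second
// apart (a gap) or the PTC value changed.
val switches = records.sliding(2).count {
  case Seq((t1, p1), (t2, p2)) => (t2 - t1 > 1000) || (p1 != p2)
  case _                       => false
}
println(switches)
```

For this subset the rule finds three switches: the missing second, the 5000 → 850 change, and the 12-second gap before the 1650 values.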

We implement this requirement with a Spark UDAF. First we define a PTCPwrAgg object that extends UserDefinedAggregateFunction, so PTCPwrAgg must implement the following members:

  • inputSchema — the schema of the input rows
  • bufferSchema — the schema of the intermediate aggregation state
  • dataType — the type of the return value
  • deterministic — whether the function always returns the same result for the same input
  • initialize — initializes the aggregation buffer
  • update — updates the buffer with a new input row
  • merge — merges two aggregation buffers
  • evaluate — computes the final result from the buffer
package ohmysummer

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object PTCPwrAgg extends UserDefinedAggregateFunction {

  // Input: the signal timestamp and the PTC consumption value
  def inputSchema: StructType =
    StructType(StructField("timestamp", LongType) :: StructField("PtcConsmp", LongType) :: Nil)

  // Aggregation buffer: sum is the number of PTC switches so far,
  // lastTime is the previous timestamp, lastPtc is the previous PTC value
  def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("lastTime", LongType) :: StructField("lastPtc", LongType) :: Nil)

  // Return type
  def dataType: DataType = LongType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // initial switch count
    buffer(1) = 0L // previous timestamp
    buffer(2) = 0L // previous PTC value
  }

  // Updates the aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Read the fields only after the null check, otherwise getLong would throw on null input
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      val newTime = input.getLong(0)      // timestamp
      val newPtcConsmp = input.getLong(1) // PTC value

      if (newTime - buffer.getLong(1) > 1000) {
        buffer(0) = buffer.getLong(0) + 1 // signal gap: one PTC switch
      } else if (newPtcConsmp != buffer.getLong(2)) {
        buffer(0) = buffer.getLong(0) + 1 // value changed: one PTC switch
      }
      buffer(1) = newTime      // remember the current timestamp
      buffer(2) = newPtcConsmp // remember the current PTC value
    }
  }

  // Merges two aggregation buffers by summing their switch counts
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  // Calculates the final result
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
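Because the buffer is just three Long slots, the whole lifecycle can be traced without Spark. The sketch below is a hypothetical, simplified model (plain vars instead of a MutableAggregationBuffer) that replays initialize, update per row, and evaluate over the sample rows:

```scala
// Plain-Scala model of the aggregation buffer: sum, lastTime, lastPtc.
var sum = 0L; var lastTime = 0L; var lastPtc = 0L // initialize

// Mirrors PTCPwrAgg.update for one (timestamp, PtcConsmp) row.
def update(newTime: Long, newPtc: Long): Unit = {
  if (newTime - lastTime > 1000) sum += 1 // gap; the first row also lands here, since lastTime starts at 0
  else if (newPtc != lastPtc) sum += 1    // value changed within a continuous run
  lastTime = newTime
  lastPtc = newPtc
}

val rows = Seq(
  1543766400000L -> 5000L, 1543766401000L -> 5000L, 1543766402000L -> 5000L,
  1543766403000L -> 5000L, 1543766405000L -> 5000L, // 1543766404000 is missing: a gap
  1543766406000L -> 850L,  1543766407000L -> 850L,  1543766408000L -> 850L,
  1543766409000L -> 850L,  1543766410000L -> 850L,  1543766411000L -> 850L,
  1543766412000L -> 850L,
  1543766424000L -> 1650L, 1543766425000L -> 1650L, 1543766426000L -> 1650L // 12-second gap before these
)
rows.foreach { case (t, p) => update(t, p) }
println(sum) // evaluate: prints 4 for these rows
```

For these fifteen rows the count is 4: one for the very first record (lastTime is initialized to 0, so the first update always takes the gap branch) plus three genuine switches.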

In PTCStartCount.scala we first read the JSON file into a DataFrame and then group it by vin, since an aggregate function can only be applied to grouped data.

package ohmysummer

import org.apache.spark.sql.SparkSession

object PTCStartCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    import spark.implicits._
    val df = spark.read.json("/Users/ohmycloud/work/cihon/3to2.json")
      .groupBy("vin")
      .agg(PTCPwrAgg($"timestamp", $"PtcConsmp") as "ptcCount")

    df.show()
    // +-----------------+--------+
    // |              vin|ptcCount|
    // +-----------------+--------+
    // |KLSAP5S09K1214720|       4|
    // +-----------------+--------+
  }
}

If the process is still unclear, add print statements inside the update and merge functions and look at the intermediate output; watching the buffer evolve row by row makes it much easier to understand what is going on.