
UDAF Functions in Spark

焉知非鱼

In the morning post I went over Spark's user-defined aggregate functions, which was really just the example from the official docs, using Spark SQL. This afternoon let's look at a slightly more involved example; this time we use a custom aggregate function inside agg.

Our program reads JSON data in which vin is the vehicle identification number, timestamp is the time the signal occurred (a signal arrives once per second), and PtcConsmp is a Long value.

{"vin":"KLSAP5S09K1214720","timestamp":1543766400000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766401000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766402000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766403000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766405000,"PtcConsmp":5000}
{"vin":"KLSAP5S09K1214720","timestamp":1543766406000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766407000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766408000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766409000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766410000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766411000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766412000,"PtcConsmp":850}
{"vin":"KLSAP5S09K1214720","timestamp":1543766424000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766425000,"PtcConsmp":1650}
{"vin":"KLSAP5S09K1214720","timestamp":1543766426000,"PtcConsmp":1650}
...
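
It is worth a quick sanity check of the schema that spark.read.json infers for this data. A minimal sketch, assuming an existing SparkSession named spark; the path is a placeholder, and the commented output is what the inference should produce (fields sorted by name):

val raw = spark.read.json("path/to/signals.json")
raw.printSchema()
// root
//  |-- PtcConsmp: long (nullable = true)
//  |-- timestamp: long (nullable = true)
//  |-- vin: string (nullable = true)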

We adopt the following rules:

  • If the value of PtcConsmp changes within a continuous stretch of time, that counts as one PTC switch.
  • If a second of data is missing (a time gap), that also counts as one PTC switch.

The task is to compute the number of PtcConsmp switches for each vin.
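
Walking through the fragment above: the missing second at 1543766404000 makes the row at 1543766405000 one switch, the value change from 5000 to 850 at 1543766406000 is a second, and the 12-second jump from 1543766412000 to 1543766424000 is a third (when a time gap and a value change coincide, only one switch is counted). So the fragment shown contributes three switches.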

We implement this with a Spark UDAF. First we define a PTCPwrAgg object that extends UserDefinedAggregateFunction, so PTCPwrAgg must implement the following members:

  • inputSchema: the schema of the input data
  • bufferSchema: the schema of the intermediate state
  • dataType: the type of the return value
  • deterministic: whether identical input always produces identical output (true here)
  • initialize: initialize the aggregation buffer
  • update: update the buffer with a new input row
  • merge: merge two buffers
  • evaluate: compute the final result

package ohmysummer

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object PTCPwrAgg extends UserDefinedAggregateFunction {

  // The input columns: the signal timestamp and the PTC consumption value
  def inputSchema: StructType = StructType(StructField("timestamp", LongType) :: StructField("PtcConsmp", LongType) :: Nil)

  // The types of the values in the aggregation buffer: sum is the number of PTC
  // switches so far, lastTime the previous timestamp, lastPtc the previous PTC value
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("lastTime", LongType) :: StructField("lastPtc", LongType) :: Nil)
  }

  // The type of the returned value
  def dataType: DataType = LongType

  // Identical input always produces identical output
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // the PTC switch count starts at 0
    buffer(1) = 0L // previous timestamp
    buffer(2) = 0L // previous PTC value
  }

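  // Note: because lastTime starts at 0, the very first row of every group makes
  // the time-gap check in update fire, so it is counted as one switch.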
  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val newTime      = input.getLong(0) // signal timestamp
      val newPtcConsmp = input.getLong(1) // PTC value

      if (newTime - buffer.getLong(1) > 1000) {
        buffer(0) = buffer.getLong(0) + 1 // a time gap counts as one PTC switch
      } else if (newPtcConsmp != buffer.getLong(2)) {
        buffer(0) = buffer.getLong(0) + 1 // a value change counts as one PTC switch
      }
      buffer(1) = newTime      // remember the current timestamp
      buffer(2) = newPtcConsmp // remember the current PTC value
    }
  }

  // Merge two buffers by summing their switch counts. A switch falling exactly on
  // a partition boundary is not detected, so this assumes each vin's rows are
  // processed in timestamp order within a single partition.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  }

  // Calculates the final result
  def evaluate(buffer: Row): Long = buffer.getLong(0)
}
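
The same object can also be registered under a name and called from Spark SQL. A minimal sketch, assuming df is the raw DataFrame read from the JSON file and signals is a view name chosen here for illustration:

spark.udf.register("ptc_pwr_agg", PTCPwrAgg)
df.createOrReplaceTempView("signals")
spark.sql("SELECT vin, ptc_pwr_agg(timestamp, PtcConsmp) AS ptcCount FROM signals GROUP BY vin").show()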

In PTCStartCount.scala we first read the JSON file into a DataFrame and then group it by vin, because an aggregate function can only be applied to grouped data.

package ohmysummer

import org.apache.spark.sql.SparkSession

object PTCStartCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate()

    import spark.implicits._
    val df = spark.read.json("/Users/ohmycloud/work/cihon/3to2.json")
      .groupBy("vin")
      .agg(PTCPwrAgg($"timestamp", $"PtcConsmp") as "ptcCount")

    df.show()
    // +-----------------+--------+
    // |              vin|ptcCount|
    // +-----------------+--------+
    // |KLSAP5S09K1214720|       4|
    // +-----------------+--------+

    spark.stop()
  }
}

If the process is still unclear, add print statements to the update and merge functions and look at the intermediate output; it makes the flow much easier to follow.
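
For instance, one line at the top of each function is enough to trace every state transition (a sketch; both Row and the aggregation buffer print readably):

// first line inside update:
println(s"update: input=$input, buffer=$buffer")

// first line inside merge:
println(s"merge: buffer1=$buffer1, buffer2=$buffer2")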