用户自定义聚合函数

聚合

内置的 DataFrame 函数提供常见的聚合,如 count()countDistinct()avg()max()min() 等。虽然这些函数是为 DataFrame 设计的,但 Spark SQL 也有类型安全的版本其中一些在 Scala 和 Java 中使用强类型数据集。此外,用户不限于预定义的聚合函数,并且可以创建自己的聚合函数。

无类型用户自定义函数

用户必须扩展 UserDefinedAggregateFunction 抽象类以实现自定义无类型聚合函数。例如,用户定义的平均值可能看起来像这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object MyAverage extends UserDefinedAggregateFunction {

// 该聚合函数输入参数的数据类型
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

// 聚合缓冲区中值的数据类型
def bufferSchema: StructType = {
StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}

// 返回值的数据类型
def dataType: DataType = DoubleType

// 对于相同输入该函数是否始终返回相同输出
def deterministic: Boolean = true

// 初始化给定的聚合缓冲区。缓冲区本身是一个 "Row",除了诸如在索引处检索值(例如,get(),getBoolean()) 标准方法之外,提供了更新其值的机会。
// 请注意,缓冲区内的数组和 map 映射仍然是不可变的
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L // sum 值
buffer(1) = 0L // 元素个数
}

// 使用来自 `input` 的新输入数据更新给定的聚合缓冲区 `buffer`
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
buffer(0) = buffer.getLong(0) + input.getLong(0) // 更新 sum 值
buffer(1) = buffer.getLong(1) + 1 // 更新元素个数
}
}

// 合并两个聚合缓冲区并将更新后的缓冲区值存储回 `buffer1`
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}

// 计算最终结果
def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1) // 计算出平均值: sum/count
}

// 注册该函数以访问它
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+

完整的例子代码, 请参见 Spark 仓库中的

examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala。

类型安全的用户自定义函数

强类型数据集的用户定义聚合围绕 Aggregator 抽象类。例如,类型安全的用户定义平均值可能如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {

// 该聚合的零值。 它应该满足对于任何 b 值: b + 0 = b
def zero: Average = Average(0L, 0L)

// 合并两个值以生成新值。为了提高性能,该函数可以修改 `buffer` 并返回它而不是构造一个新对象
def reduce(buffer: Average, employee: Employee): Average = {
buffer.sum += employee.salary
buffer.count += 1
buffer
}

// 合并两个中间值
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}

// 变换归约后的输出
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

// 指定中间值类型的编码器
def bufferEncoder: Encoder[Average] = Encoders.product

// 为最终输出值类型指定编码器
def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+

// 将函数转换为 `TypedColumn` 并为其命名
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+

完整的例子代码, 请参见 Spark 仓库中的

examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala。