Wait the light to fall

用户定义函数

焉知非鱼

User Defined Functions

用户自定义函数

用户自定义函数(UDFs)是扩展点,用于调用常用的逻辑或自定义逻辑,这些逻辑无法在查询中以其他方式表达。

用户定义函数可以用 JVM 语言(如 Java 或 Scala)或 Python 实现。实现者可以在 UDF 中使用任意的第三方库。本页将重点介绍基于 JVM 的语言。

概述 #

目前,Flink 区分了以下几种函数。

  • 标量函数将标量值映射到一个新的标量值。
  • 表函数将标量值映射到新的行(row)。
  • 聚合函数将多行的标量值映射到新的标量值。
  • 表聚合函数将多行的标量值映射到新的行上。
  • 异步表函数是针对 table source 执行查找的特殊函数。

注意: 标量函数和表函数已经更新为基于数据类型的新类型系统。聚合函数仍然使用基于 TypeInformation 的旧类型系统。

下面的示例展示了如何创建一个简单的标量函数,以及如何在表 API 和 SQL 中调用该函数。

对于 SQL 查询,一个函数必须始终以一个名字注册。对于 Table API,函数可以被注册,也可以直接内联使用。

import org.apache.flink.table.api._
import org.apache.flink.table.functions.ScalarFunction

// define function logic
class SubstringFunction extends ScalarFunction {
  def eval(s: String, begin: Integer, end: Integer): String = {
    s.substring(begin, end)
  }
}

val env = TableEnvironment.create(...)

// call function "inline" without registration in Table API
env.from("MyTable").select(call(classOf[SubstringFunction], $"myField", 5, 12))

// register function
env.createTemporarySystemFunction("SubstringFunction", classOf[SubstringFunction])

// call registered function in Table API
env.from("MyTable").select(call("SubstringFunction", $"myField", 5, 12))

// call registered function in SQL
env.sqlQuery("SELECT SubstringFunction(myField, 5, 12) FROM MyTable")

对于交互式会话,也可以在使用或注册函数之前对其进行参数化。在这种情况下,可以使用函数实例代替函数类作为临时函数。

它要求参数是可序列化的,以便将函数实例运送到集群。

import org.apache.flink.table.api._
import org.apache.flink.table.functions.ScalarFunction

// define parameterizable function logic
class SubstringFunction(val endInclusive) extends ScalarFunction {
  def eval(s: String, begin: Integer, end: Integer): String = {
    s.substring(endInclusive ? end + 1 : end)
  }
}

val env = TableEnvironment.create(...)

// call function "inline" without registration in Table API
env.from("MyTable").select(call(new SubstringFunction(true), $"myField", 5, 12))

// register function
env.createTemporarySystemFunction("SubstringFunction", new SubstringFunction(true))

实现指南 #

注意:本节目前只适用于标量函数和表函数;在集合函数更新到新的类型系统之前,本节只适用于标量函数。

无论函数的种类如何,所有用户定义的函数都遵循一些基本的实现原则。

函数类 #

一个实现类必须从一个可用的基类(例如 org.apache.flink.table.function.ScalarFunction)中扩展出来。

这个类必须被声明为 public,而不是 abstract,并且应该是全局访问的。因此,不允许使用非静态的内部类或匿名类。

对于在持久化目录中存储用户定义的函数,该类必须有一个默认的构造函数,并且在运行时必须是可实例化的。

评估方法 #

基类提供了一组可以重写的方法,如 open()close()isDeterministic()

然而,除了这些声明的方法外,应用于每个传入记录的主要运行时逻辑必须通过专门的评估方法来实现。

根据函数种类的不同,评价方法如 eval()accumulate()retract() 会在运行时被代码生成的操作符调用。

这些方法必须声明为 public,并接受一组定义明确的参数。

常规的 JVM 方法调用语义适用。因此,可以

  • 实现重载方法,如 eval(Integer)eval(LocalDateTime)
  • 使用 var-args,如 eval(Integer...)
  • 使用对象继承,如 eval(Object),它同时接受 LocalDateTimeInteger
  • 以及上述函数的组合,如 eval(Object...),它可以接受所有类型的参数。

如果你打算在 Scala 中实现函数,请在使用变量参数时添加 scala.annotation.varargs 注解。此外,建议使用盒状基元(如用 java.lang.Integer 代替 Int)来支持 NULL。

下面的代码段显示了一个重载函数的示例。

import org.apache.flink.table.functions.ScalarFunction
import java.lang.Integer
import java.lang.Double
import scala.annotation.varargs

// function with overloaded evaluation methods
class SumFunction extends ScalarFunction {

  def eval(a: Integer, b: Integer): Integer = {
    a + b
  }

  def eval(a: String, b: String): Integer = {
    Integer.valueOf(a) + Integer.valueOf(b)
  }

  @varargs // generate var-args like Java
  def eval(d: Double*): Integer = {
    d.sum.toInt
  }
}

类型推断 #

表生态系统(类似于 SQL 标准)是一个强类型的 API。因此,函数参数和返回类型都必须映射到数据类型

从逻辑的角度来看,规划师需要关于预期类型、精度和规模的信息。从 JVM 的角度来看,规划师需要了解当调用用户定义的函数时,内部数据结构如何被表示为 JVM 对象。

验证输入参数和推导出函数的参数和结果的数据类型的逻辑被总结在类型推理这个术语下。

Flink 的用户定义函数实现了自动类型推理提取,通过反射从函数的类和它的评估方法中导出数据类型。如果这种隐式反射提取方法不成功,可以通过用 @DataTypeHint@FunctionHint 注释受影响的参数、类或方法来支持提取过程。更多关于如何注释函数的例子如下所示。

如果需要更高级的类型推理逻辑,实现者可以在每个用户定义的函数中显式覆盖 getTypeInference() 方法。然而,推荐使用注释方法,因为它将自定义类型推理逻辑保持在受影响的位置附近,并回落到其余实现的默认行为。

自动类型推断 #

自动类型推理检查函数的类和评估方法,从而得出函数的参数和结果的数据类型。@DataTypeHint@FunctionHint 注解支持自动提取。

关于可以隐式映射到数据类型的类的完整列表,请参阅数据类型提取部分。

@DataTypeHint #

在很多情况下,需要支持对函数的参数和返回类型进行在线自动提取。

下面的示例展示了如何使用数据类型提示。更多信息可以在注解类的文档中找到。

import org.apache.flink.table.annotation.DataTypeHint
import org.apache.flink.table.annotation.InputGroup
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.types.Row
import scala.annotation.varargs

// function with overloaded evaluation methods
class OverloadedFunction extends ScalarFunction {

  // no hint required
  def eval(a: Long, b: Long): Long = {
    a + b
  }

  // define the precision and scale of a decimal
  @DataTypeHint("DECIMAL(12, 3)")
  def eval(double a, double b): BigDecimal = {
    java.lang.BigDecimal.valueOf(a + b)
  }

  // define a nested data type
  @DataTypeHint("ROW<s STRING, t TIMESTAMP(3) WITH LOCAL TIME ZONE>")
  def eval(Int i): Row = {
    Row.of(java.lang.String.valueOf(i), java.time.Instant.ofEpochSecond(i))
  }

  // allow wildcard input and customly serialized output
  @DataTypeHint(value = "RAW", bridgedTo = classOf[java.nio.ByteBuffer])
  def eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object o): java.nio.ByteBuffer = {
    MyUtils.serializeToByteBuffer(o)
  }
}
@FunctionHint #

在某些场景下,一个评估方法同时处理多种不同的数据类型是可取的。此外,在某些场景中,重载的评估方法有一个共同的结果类型,应该只声明一次。

@FunctionHint 注解可以提供从参数数据类型到结果数据类型的映射。它可以为输入、累加器和结果数据类型注释整个函数类或评估方法。一个或多个注解可以在一个类的顶部声明,也可以为每个评估方法单独声明,以便重载函数签名。所有的提示参数都是可选的。如果没有定义参数,则使用默认的基于反射的提取方式。在函数类之上定义的提示参数会被所有的评估方法继承。

下面的例子展示了如何使用函数提示。更多信息可以在注解类的文档中找到。

import org.apache.flink.table.annotation.DataTypeHint
import org.apache.flink.table.annotation.FunctionHint
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.types.Row

// function with overloaded evaluation methods
// but globally defined output type
@FunctionHint(output = new DataTypeHint("ROW<s STRING, i INT>"))
class OverloadedFunction extends TableFunction[Row] {

  def eval(a: Int, b: Int): Unit = {
    collect(Row.of("Sum", Int.box(a + b)))
  }

  // overloading of arguments is still possible
  def eval(): Unit = {
    collect(Row.of("Empty args", Int.box(-1)))
  }
}

// decouples the type inference from evaluation methods,
// the type inference is entirely determined by the function hints
@FunctionHint(
  input = Array(new DataTypeHint("INT"), new DataTypeHint("INT")),
  output = new DataTypeHint("INT")
)
@FunctionHint(
  input = Array(new DataTypeHint("BIGINT"), new DataTypeHint("BIGINT")),
  output = new DataTypeHint("BIGINT")
)
@FunctionHint(
  input = Array(),
  output = new DataTypeHint("BOOLEAN")
)
class OverloadedFunction extends TableFunction[AnyRef] {

  // an implementer just needs to make sure that a method exists
  // that can be called by the JVM
  @varargs
  def eval(o: AnyRef*) = {
    if (o.length == 0) {
      collect(Boolean.box(false))
    }
    collect(o(0))
  }
}

自定义类型推断 #

对于大多数情况下,@DataTypeHint@FunctionHint 应该足以为用户定义的函数建模。然而,通过覆盖 getTypeInference() 中定义的自动类型推理,实现者可以创建任意的函数,这些函数的行为就像内置的系统函数一样。

下面这个用 Java 实现的例子说明了自定义类型推理逻辑的潜力。它使用一个字符串文字参数来确定一个函数的结果类型。该函数需要两个字符串参数:第一个参数代表要解析的字符串,第二个参数代表目标类型。

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.types.Row;

public static class LiteralFunction extends ScalarFunction {
  public Object eval(String s, String type) {
    switch (type) {
      case "INT":
        return Integer.valueOf(s);
      case "DOUBLE":
        return Double.valueOf(s);
      case "STRING":
      default:
        return s;
    }
  }

  // the automatic, reflection-based type inference is disabled and
  // replaced by the following logic
  @Override
  public TypeInference getTypeInference(DataTypeFactory typeFactory) {
    return TypeInference.newBuilder()
      // specify typed arguments
      // parameters will be casted implicitly to those types if necessary
      .typedArguments(DataTypes.STRING(), DataTypes.STRING())
      // specify a strategy for the result data type of the function
      .outputTypeStrategy(callContext -> {
        if (!callContext.isArgumentLiteral(1) || callContext.isArgumentNull(1)) {
          throw callContext.newValidationError("Literal expected for second argument.");
        }
        // return a data type based on a literal
        final String literal = callContext.getArgumentValue(1, String.class).orElse("STRING");
        switch (literal) {
          case "INT":
            return Optional.of(DataTypes.INT().notNull());
          case "DOUBLE":
            return Optional.of(DataTypes.DOUBLE().notNull());
          case "STRING":
          default:
            return Optional.of(DataTypes.STRING());
        }
      })
      .build();
  }
}

运行时集成 #

有时可能需要用户自定义函数在实际工作前获取全局运行时信息或做一些设置/清理工作。用户自定义函数提供了 open()close() 方法,这些方法可以被重写,并提供与 DataStream API 的 RichFunction 中的方法类似的功能。

open() 方法在评估方法之前被调用一次。close() 方法在最后一次调用评估方法后调用。

open() 方法提供了一个 FunctionContext,该 FunctionContext 包含了用户定义函数执行的上下文信息,如度量组、分布式缓存文件或全局作业参数。

通过调用 FunctionContext 的相应方法,可以获得以下信息。

方法 描述
getMetricGroup() 该并行子任务的度量组。
getCachedFile(name) 分布式缓存文件的本地临时文件副本。
getJobParameter(name, defaultValue) 与给定键相关联的全局作业参数值。

下面的示例片段展示了如何在标量函数中使用 FunctionContext 来访问全局工作参数。

import org.apache.flink.table.api._
import org.apache.flink.table.functions.FunctionContext
import org.apache.flink.table.functions.ScalarFunction

class HashCodeFunction extends ScalarFunction {

  private var factor: Int = 0

  override def open(context: FunctionContext): Unit = {
    // access the global "hashcode_factor" parameter
    // "12" would be the default value if the parameter does not exist
    factor = context.getJobParameter("hashcode_factor", "12").toInt
  }

  def eval(s: String): Int = {
    s.hashCode * factor
  }
}

val env = TableEnvironment.create(...)

// add job parameter
env.getConfig.addJobParameter("hashcode_factor", "31")

// register the function
env.createTemporarySystemFunction("hashCode", classOf[HashCodeFunction])

// use the function
env.sqlQuery("SELECT myField, hashCode(myField) FROM MyTable")

标量函数 #

用户定义的标量函数可以将零、一或多个标量值映射到一个新的标量值。数据类型一节中列出的任何数据类型都可以作为一个评估方法的参数或返回类型。

为了定义一个标量函数,必须扩展 org.apache.flink.table.function 中的基类 ScalarFunction,并实现一个或多个名为 eval(...) 的评估方法。

下面的例子展示了如何定义自己的哈希码函数并在查询中调用它。更多细节请参见实施指南。

import org.apache.flink.table.annotation.InputGroup
import org.apache.flink.table.api._
import org.apache.flink.table.functions.ScalarFunction

class HashFunction extends ScalarFunction {

  // take any data type and return INT
  def eval(@DataTypeHint(inputGroup = InputGroup.ANY) o: AnyRef): Int {
    return o.hashCode();
  }
}

val env = TableEnvironment.create(...)

// call function "inline" without registration in Table API
env.from("MyTable").select(call(classOf[HashFunction], $"myField"))

// register function
env.createTemporarySystemFunction("HashFunction", classOf[HashFunction])

// call registered function in Table API
env.from("MyTable").select(call("HashFunction", $"myField"))

// call registered function in SQL
env.sqlQuery("SELECT HashFunction(myField) FROM MyTable")

如果你打算用 Python 实现或调用函数,请参考 Python Scalar Functions 文档了解更多细节。

表函数 #

与用户定义的标量函数类似,用户定义的表函数将零、一个或多个标量值作为输入参数。然而,与标量函数不同的是,它可以返回任意数量的行(或结构化类型)作为输出,而不是单个值。返回的记录可能由一个或多个字段组成。如果一条输出记录只由一个字段组成,则可以省略结构化记录,并发出一个标量值。它将被运行时包装成一个隐式行。

为了定义一个表函数,必须扩展 org.apache.flink.table.function 中的基类 TableFunction,并实现一个或多个名为 eval(...) 的评估方法。与其他函数类似,输入和输出数据类型也是使用反射自动提取的。这包括类的通用参数 T,用于确定输出数据类型。与标量函数不同的是,评价方法本身不能有返回类型,相反,表函数提供了一个 collect(T) 方法,可以在每个评价方法内调用,用于发出零、一条或多条记录。

在表 API 中,表函数的使用方法是 .joinLateral(...).leftOuterJoinLateral(...)。joinLateral 运算符(cross)将外表(运算符左边的表)的每条记录与表值函数产生的所有记录(表值函数在运算符的右边)连接起来。leftOuterJoinLateral 操作符将外表(操作符左边的表)的每一条记录与表值函数产生的所有记录(它在操作符的右边)连接起来,并且保留那些表函数返回空表的外表。

在 SQL 中,使用 LATERAL TABLE(<TableFunction>) 与 JOIN 或 LEFT JOIN 与 ON TRUE 连接条件。

下面的示例展示了如何定义自己的拆分函数并在查询中调用它。更多细节请参见《实现指南》。

import org.apache.flink.table.annotation.DataTypeHint
import org.apache.flink.table.annotation.FunctionHint
import org.apache.flink.table.api._
import org.apache.flink.table.functions.TableFunction
import org.apache.flink.types.Row

@FunctionHint(output = new DataTypeHint("ROW<word STRING, length INT>"))
class SplitFunction extends TableFunction[Row] {

  def eval(str: String): Unit = {
    // use collect(...) to emit a row
    str.split(" ").foreach(s => collect(Row.of(s, Int.box(s.length))))
  }
}

val env = TableEnvironment.create(...)

// call function "inline" without registration in Table API
env
  .from("MyTable")
  .joinLateral(call(classOf[SplitFunction], $"myField")
  .select($"myField", $"word", $"length")
env
  .from("MyTable")
  .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField"))
  .select($"myField", $"word", $"length")

// rename fields of the function in Table API
env
  .from("MyTable")
  .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField").as("newWord", "newLength"))
  .select($"myField", $"newWord", $"newLength")

// register function
env.createTemporarySystemFunction("SplitFunction", classOf[SplitFunction])

// call registered function in Table API
env
  .from("MyTable")
  .joinLateral(call("SplitFunction", $"myField"))
  .select($"myField", $"word", $"length")
env
  .from("MyTable")
  .leftOuterJoinLateral(call("SplitFunction", $"myField"))
  .select($"myField", $"word", $"length")

// call registered function in SQL
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable, LATERAL TABLE(SplitFunction(myField))");
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) ON TRUE")

// rename fields of the function in SQL
env.sqlQuery(
  "SELECT myField, newWord, newLength " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) AS T(newWord, newLength) ON TRUE")

如果你打算在 Scala 中实现函数,不要将表函数实现为 Scala 对象。Scala 对象是单子,会导致并发问题。

如果你打算用 Python 实现或调用函数,请参考 Python 表函数文档了解更多细节。

聚合函数 #

用户自定义聚合函数(UDAGG)将一个表(一个或多个具有一个或多个属性的行)聚合成一个标量值。

img

上图显示了一个聚合的例子。假设你有一个包含饮料数据的表。该表由 id、名称和价格三列和 5 行组成。想象一下,你需要找到表中所有饮料的最高价格,即执行 max() 聚合。你需要对 5 行中的每一行进行检查,结果将是一个单一的数值。

用户定义的聚合函数是通过扩展 AggregateFunction 类来实现的。AggregateFunction 的工作原理如下。首先,它需要一个累加器,它是存放聚合中间结果的数据结构。通过调用 AggregateFunction 的 createAccumulator() 方法创建一个空的累加器。随后,函数的 accumulate() 方法对每一条输入行进行调用,以更新累加器。一旦所有的行都被处理完毕,函数的 getValue() 方法就会被调用来计算并返回最终结果。

以下方法是每个 AggregateFunction 必须使用的。

  • createAccumulator()
  • accumulate()
  • getValue()

Flink 的类型提取设施可能无法识别复杂的数据类型,例如,如果它们不是基本类型或简单的 POJOs。所以与 ScalarFunction 和 TableFunction 类似,AggregateFunction 提供了指定结果类型(通过 AggregateFunction#getResultType())和累加器类型(通过 AggregateFunction#getAccumulatorType())的方法。

除了上述方法外,还有一些签约方法可以选择实现。这些方法中的一些方法可以让系统更高效地执行查询,而另一些方法则是某些用例所必须的。例如,如果聚合函数应该在会话组窗口的上下文中应用,那么 merge() 方法是强制性的(当观察到有一行 “连接 “它们时,需要将两个会话窗口的累加器连接起来)。

AggregateFunction 的以下方法是根据用例需要的。

  • retract() 对于有界 OVER 窗口上的聚合是需要的。
  • merge() 是许多批次聚合和会话窗口聚合所需要的。
  • resetAccumulator() 是许多批处理聚合所需要的。

AggregateFunction 的所有方法都必须声明为 public,而不是 static,并且命名与上述名称完全一致。方法 createAccumulator、getValue、getResultType 和 getAccumulatorType 是在 AggregateFunction 抽象类中定义的,而其他方法则是合同方法。为了定义一个聚合函数,必须扩展基类 org.apache.flink.table.function.AggregateFunction,并实现一个(或多个)accumulate 方法。方法 accumulate 可以用不同的参数类型重载,并支持变量参数。

下面给出了 AggregateFunction 所有方法的详细文档。

/**
  * Base class for user-defined aggregates and table aggregates.
  *
  * @tparam T   the type of the aggregation result.
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this (table)aggregate function.
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC // MANDATORY

  /**
    * Returns the TypeInformation of the (table)aggregate function's result.
    *
    * @return The TypeInformation of the (table)aggregate function's result or null if the result
    *         type should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
    * Returns the TypeInformation of the (table)aggregate function's accumulator.
    *
    * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
  * Base class for aggregation functions. 
  *
  * @tparam T   the type of the aggregation result
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  *             AggregateFunction represents its state using accumulator, thereby the state of the
  *             AggregateFunction must be put into the accumulator.
  */
abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
    * Processes the input values and update the provided accumulator instance. The method
    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
    * requires at least one accumulate() method.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
    * Retracts the input values from the accumulator instance. The current design assumes the
    * inputs are the values that have been previously accumulated. The method retract can be
    * overloaded with different custom types and arguments. This function must be implemented for
    * datastream bounded over aggregate.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
    * Merges a group of accumulator instances into one accumulator instance. This function must be
    * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which will keep the merged aggregate results. It should
    *                     be noted that the accumulator may contain the previous aggregated
    *                     results. Therefore user should not replace or clean this instance in the
    *                     custom merge method.
    * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
    *                     merged.
    */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL
  
  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T // MANDATORY

  /**
    * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
    * dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which needs to be reset
    */
  def resetAccumulator(accumulator: ACC): Unit // OPTIONAL

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false // PRE-DEFINED
}

下面的例子说明了如何进行

  • 定义一个 AggregateFunction,用于计算给定列的加权平均值。
  • 在 TableEnvironment 中注册该函数,并且
  • 在查询中使用该函数。

为了计算加权平均值,累加器需要存储所有已积累的数据的加权和和计数。在我们的例子中,我们定义了一个类 WeightedAvgAccum 作为累加器。累积器由 Flink 的检查点机制自动备份,并在故障时恢复,以保证精确的唯一性语义。

我们 WeightedAvg AggregateFunction 的 accumulate() 方法有三个输入。第一个是 WeightedAvgAccum 累加器,另外两个是用户自定义的输入:输入值 ivalue 和输入的权重 iweight。虽然 retract()merge()resetAccumulator() 方法对于大多数聚合类型来说并不是强制性的,但我们在下面提供它们作为例子。请注意,我们在 Scala 示例中使用了 Java 基元类型,并定义了 getResultType()getAccumulatorType() 方法,因为 Flink 类型提取对于 Scala 类型并不十分有效。

import java.lang.{Long => JLong, Integer => JInteger}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.AggregateFunction

/**
 * Accumulator for WeightedAvg.
 */
class WeightedAvgAccum extends JTuple1[JLong, JInteger] {
  sum = 0L
  count = 0
}

/**
 * Weighted Average user-defined aggregate function.
 */
class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] {

  override def createAccumulator(): WeightedAvgAccum = {
    new WeightedAvgAccum
  }
  
  override def getValue(acc: WeightedAvgAccum): JLong = {
    if (acc.count == 0) {
        null
    } else {
        acc.sum / acc.count
    }
  }
  
  def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum += iValue * iWeight
    acc.count += iWeight
  }

  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum -= iValue * iWeight
    acc.count -= iWeight
  }
    
  def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
    val iter = it.iterator()
    while (iter.hasNext) {
      val a = iter.next()
      acc.count += a.count
      acc.sum += a.sum
    }
  }

  def resetAccumulator(acc: WeightedAvgAccum): Unit = {
    acc.count = 0
    acc.sum = 0L
  }

  override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
    new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)
  }

  override def getResultType: TypeInformation[JLong] = Types.LONG
}

// register function
val tEnv: StreamTableEnvironment = ???
tEnv.registerFunction("wAvg", new WeightedAvg())

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")

表聚合函数 #

用户定义表聚合函数(UDTAGGs)将一个表(具有一个或多个属性的一行或多行)聚合到一个具有多行和多列的结果表。

img

上图显示了一个表聚合的例子。假设你有一个包含饮料数据的表。该表由 id、名称和价格三列和 5 行组成。设想你需要找到表中所有饮料中价格最高的前 2 名,即执行 top2() 表聚合。你需要对 5 行中的每一行进行检查,结果将是一个具有前 2 个值的表。

用户定义的表聚合函数是通过扩展 TableAggregateFunction 类来实现的。TableAggregateFunction 的工作原理如下。首先,它需要一个累加器,它是存放聚合中间结果的数据结构。通过调用 TableAggregateFunction 的 createAccumulator() 方法创建一个空的累加器。随后,对每一条输入行调用函数的 accumulate() 方法来更新累加器。一旦所有的行都被处理完毕,函数的 emitValue() 方法就会被调用来计算并返回最终结果。

以下方法是每个 TableAggregateFunction 必须使用的。

  • createAccumulator()
  • accumulate()

Flink 的类型提取设施可能无法识别复杂的数据类型,例如,如果它们不是基本类型或简单的 POJOs。因此,与 ScalarFunction 和 TableFunction 类似,TableAggregateFunction 提供了指定结果类型(通过 TableAggregateFunction#getResultType())和累积器类型(通过 TableAggregateFunction#getAccumulatorType())的方法。

除了上述方法外,还有一些签约方法可以选择实现。这些方法中的一些方法可以让系统更高效地执行查询,而另一些方法则是某些用例所必须的。例如,如果聚合函数应该在会话组窗口的上下文中应用,那么 merge() 方法是强制性的(当观察到有一条记录"连接"它们时,需要将两个会话窗口的累加器连接起来)。

TableAggregateFunction 的以下方法是需要的,这取决于用例。

  • retract() 对于有界 OVER 窗口上的聚合是需要的。
  • merge() 是许多批次聚合和会话窗口聚合所需要的。
  • resetAccumulator() 是许多批处理聚合所需要的。
  • emitValue() 是批处理和窗口聚合所需要的。

TableAggregateFunction 的以下方法用于提高流作业的性能。

  • emitUpdateWithRetract() 用于发射在伸缩模式下更新的值。

对于 emitValue 方法,则是根据累加器来发射完整的数据。以 TopN 为例,emitValue 每次都会发射所有前 n 个值。这可能会给流式作业带来性能问题。为了提高性能,用户也可以实现 emitUpdateWithRetract 方法来提高性能。该方法以回缩模式增量输出数据,即一旦有更新,我们必须在发送新的更新记录之前回缩旧记录。如果在表聚合函数中都定义了该方法,那么该方法将优先于 emitValue 方法使用,因为 emitUpdateWithRetract 被认为比 emitValue 更有效率,因为它可以增量输出值。

TableAggregateFunction 的所有方法都必须声明为 public,而不是 static,并完全按照上面提到的名字命名。方法 createAccumulator、getResultType 和 getAccumulatorType 是在 TableAggregateFunction 的父抽象类中定义的,而其他方法则是收缩的方法。为了定义一个表聚合函数,必须扩展基类 org.apache.flink.table.function.TableAggregateFunction,并实现一个(或多个)accumulate 方法。积累方法可以用不同的参数类型重载,并支持变量参数。

下面给出了 TableAggregateFunction 所有方法的详细文档。

/**
  * Base class for user-defined aggregates and table aggregates.
  *
  * @tparam T   the type of the aggregation result.
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this (table)aggregate function.
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC // MANDATORY

  /**
    * Returns the TypeInformation of the (table)aggregate function's result.
    *
    * @return The TypeInformation of the (table)aggregate function's result or null if the result
    *         type should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
    * Returns the TypeInformation of the (table)aggregate function's accumulator.
    *
    * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
  * Base class for table aggregation functions. 
  *
  * @tparam T   the type of the aggregation result
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  *             TableAggregateFunction represents its state using accumulator, thereby the state of
  *             the TableAggregateFunction must be put into the accumulator.
  */
abstract class TableAggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
    * Processes the input values and update the provided accumulator instance. The method
    * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
    * requires at least one accumulate() method.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
    * Retracts the input values from the accumulator instance. The current design assumes the
    * inputs are the values that have been previously accumulated. The method retract can be
    * overloaded with different custom types and arguments. This function must be implemented for
    * datastream bounded over aggregate.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
    * Merges a group of accumulator instances into one accumulator instance. This function must be
    * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which will keep the merged aggregate results. It should
    *                     be noted that the accumulator may contain the previous aggregated
    *                     results. Therefore user should not replace or clean this instance in the
    *                     custom merge method.
    * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
    *                     merged.
    */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL
  
  /**
    * Called every time when an aggregation result should be materialized. The returned value
    * could be either an early and incomplete result  (periodically emitted as data arrive) or
    * the final result of the  aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @param out         the collector used to output data
    */
  def emitValue(accumulator: ACC, out: Collector[T]): Unit // OPTIONAL

  /**
    * Called every time when an aggregation result should be materialized. The returned value
    * could be either an early and incomplete result (periodically emitted as data arrive) or
    * the final result of the aggregation.
    *
    * Different from emitValue, emitUpdateWithRetract is used to emit values that have been updated.
    * This method outputs data incrementally in retract mode, i.e., once there is an update, we
    * have to retract old records before sending new updated ones. The emitUpdateWithRetract
    * method will be used in preference to the emitValue method if both methods are defined in the
    * table aggregate function, because the method is treated to be more efficient than emitValue
    * as it can outputvalues incrementally.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @param out         the retractable collector used to output data. Use collect method
    *                    to output(add) records and use retract method to retract(delete)
    *                    records.
    */
  def emitUpdateWithRetract(accumulator: ACC, out: RetractableCollector[T]): Unit // OPTIONAL
 
  /**
    * Collects a record and forwards it. The collector can output retract messages with the retract
    * method. Note: only use it in `emitRetractValueIncrementally`.
    */
  trait RetractableCollector[T] extends Collector[T] {
    
    /**
      * Retract a record.
      *
      * @param record The record to retract.
      */
    def retract(record: T): Unit
  }
}

下面的例子说明了如何进行

  • 定义一个 TableAggregateFunction,用于计算给定列上的前 2 个值。
  • 在 TableEnvironment 中注册该函数,并且
  • 在 Table API 查询中使用该函数(TableAggregateFunction 仅由 Table API 支持)。

为了计算前 2 名的值,累加器需要存储所有已积累的数据中最大的 2 个值。在我们的例子中,我们定义了一个类 Top2Accum 作为累加器。累积器会被 Flink 的检查点机制自动备份,并在故障时恢复,以保证精确的 once 语义。

我们 Top2 TableAggregateFunction 的 accumulate() 方法有两个输入。第一个是 Top2Accum 累加器,另一个是用户定义的输入:输入值 v,虽然 merge() 方法对于大多数表聚合类型来说不是强制性的,但我们在下面提供它作为例子。请注意,我们在 Scala 示例中使用了 Java 基元类型,并定义了 getResultType()getAccumulatorType() 方法,因为 Flink 类型提取对 Scala 类型的效果并不好。

import java.lang.{Integer => JInteger}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def merge(acc: Top2Accum, its: JIterable[Top2Accum]): Unit = {
    val iter = its.iterator()
    while (iter.hasNext) {
      val top2 = iter.next()
      accumulate(acc, top2.first)
      accumulate(acc, top2.second)
    }
  }

  def emitValue(acc: Top2Accum, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
    // emit the value and rank
    if (acc.first != Int.MinValue) {
      out.collect(JTuple2.of(acc.first, 1))
    }
    if (acc.second != Int.MinValue) {
      out.collect(JTuple2.of(acc.second, 2))
    }
  }
}

// init table
val tab = ...

// use function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)

下面的例子展示了如何使用 emitUpdateWithRetract 方法来只发送更新。在我们的例子中,为了只发出更新,累加器同时保留新旧 top2 的值。注意:如果 topN 的 N 很大,那么同时保留新旧值的效率可能很低。解决这种情况的方法之一是在累加方法中把输入的记录存储到累加器中,然后在 emitUpdateWithRetract 中进行计算。

import java.lang.{Integer => JInteger}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
  var oldFirst: JInteger = _
  var oldSecond: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc.oldFirst = Int.MinValue
    acc.oldSecond = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def emitUpdateWithRetract(
    acc: Top2Accum,
    out: RetractableCollector[JTuple2[JInteger, JInteger]])
  : Unit = {
    if (acc.first != acc.oldFirst) {
      // if there is an update, retract old value then emit new value.
      if (acc.oldFirst != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldFirst, 1))
      }
      out.collect(JTuple2.of(acc.first, 1))
      acc.oldFirst = acc.first
    }
    if (acc.second != acc.oldSecond) {
      // if there is an update, retract old value then emit new value.
      if (acc.oldSecond != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldSecond, 2))
      }
      out.collect(JTuple2.of(acc.second, 2))
      acc.oldSecond = acc.second
    }
  }
}

// init table
val tab = ...

// use function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)

原文链接: https://ci.apache.org/projects/flink/flink-docs-release-1.11/dev/table/functions/udfs.html