如何用Scala在Apache Spark Dataset中按ID计算平均向量?
计算Spark分组后向量的逐元素平均值
针对你给出的数据集,要按id分组计算对应向量的逐元素平均值,这里有两种实用的方法,你可以根据自己的场景选择:
方法一:固定维度下的列展开法(性能更优)
如果你的向量维度是固定且已知的(比如示例里的4维),这种方法用Spark内置函数处理,性能会更好:
Scala 代码示例
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val spark = SparkSession.builder().appName("VectorAverage").master("local[*]").getOrCreate() // 构建示例数据集 val schema = StructType(Array( StructField("id", IntegerType, nullable = false), StructField("vec", ArrayType(DoubleType), nullable = false) )) val data = Seq( (0, Array(1.0, 2.0, 3.0, 4.0)), (0, Array(2.0, 3.0, 4.0, 5.0)), (0, Array(6.0, 7.0, 8.0, 9.0)), (1, Array(1.0, 2.0, 3.0, 4.0)), (1, Array(5.0, 6.0, 7.0, 8.0)) ) val df = spark.createDataFrame(data).toDF("id", "vec") // 1. 把向量的每个元素拆成单独列 val dfWithElements = df.select( col("id"), col("vec")(0).alias("v0"), col("vec")(1).alias("v1"), col("vec")(2).alias("v2"), col("vec")(3).alias("v3") ) // 2. 按id分组,计算每个元素的平均值 val avgDf = dfWithElements.groupBy("id") .agg( avg("v0").alias("avg_v0"), avg("v1").alias("avg_v1"), avg("v2").alias("avg_v2"), avg("v3").alias("avg_v3") ) // 3. 把平均后的元素合并回向量列 val resultDf = avgDf.select( col("id"), array(col("avg_v0"), col("avg_v1"), col("avg_v2"), col("avg_v3")).alias("vec") ) resultDf.show()
Python 代码示例
from pyspark.sql import SparkSession from pyspark.sql.functions import col, avg spark = SparkSession.builder.appName("VectorAverage").master("local[*]").getOrCreate() # 构建示例数据集 data = [ (0, [1.0, 2.0, 3.0, 4.0]), (0, [2.0, 3.0, 4.0, 5.0]), (0, [6.0, 7.0, 8.0, 9.0]), (1, [1.0, 2.0, 3.0, 4.0]), (1, [5.0, 6.0, 7.0, 8.0]) ] df = spark.createDataFrame(data, ["id", "vec"]) # 1. 拆分向量元素为单独列 df_with_elements = df.select( col("id"), col("vec")[0].alias("v0"), col("vec")[1].alias("v1"), col("vec")[2].alias("v2"), col("vec")[3].alias("v3") ) # 2. 分组求平均 avg_df = df_with_elements.groupBy("id").agg( avg("v0").alias("avg_v0"), avg("v1").alias("avg_v1"), avg("v2").alias("avg_v2"), avg("v3").alias("avg_v3") ) # 3. 合并回向量 result_df = avg_df.select( col("id"), [col("avg_v0"), col("avg_v1"), col("avg_v2"), col("avg_v3")].alias("vec") ) result_df.show()
方法二:动态维度的UDF法(更灵活)
如果你的向量维度不固定,或者不想硬编码维度,可以用自定义UDF来处理,这种方法适配任意维度的向量:
Scala 代码示例
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ val spark = SparkSession.builder().appName("VectorAverage").master("local[*]").getOrCreate() // 构建示例数据集(同方法一) val data = Seq( (0, Array(1.0, 2.0, 3.0, 4.0)), (0, Array(2.0, 3.0, 4.0, 5.0)), (0, Array(6.0, 7.0, 8.0, 9.0)), (1, Array(1.0, 2.0, 3.0, 4.0)), (1, Array(5.0, 6.0, 7.0, 8.0)) ) val df = spark.createDataFrame(data).toDF("id", "vec") // 定义UDF:接收一组向量,返回逐元素平均后的向量 val vectorAverageUdf = udf((vectors: Seq[Seq[Double]]) => { if (vectors.isEmpty) null else { val dim = vectors.head.length (0 until dim).map(i => vectors.map(_(i)).sum / vectors.size).toArray } }) // 分组收集所有向量,再用UDF计算平均 val resultDf = df.groupBy("id") .agg(collect_list("vec").alias("all_vecs")) .select( col("id"), vectorAverageUdf(col("all_vecs")).alias("vec") ) resultDf.show()
Python 代码示例
from pyspark.sql import SparkSession from pyspark.sql.functions import col, collect_list, udf from pyspark.sql.types import ArrayType, DoubleType spark = SparkSession.builder.appName("VectorAverage").master("local[*]").getOrCreate() # 构建示例数据集(同方法一) data = [ (0, [1.0, 2.0, 3.0, 4.0]), (0, [2.0, 3.0, 4.0, 5.0]), (0, [6.0, 7.0, 8.0, 9.0]), (1, [1.0, 2.0, 3.0, 4.0]), (1, [5.0, 6.0, 7.0, 8.0]) ] df = spark.createDataFrame(data, ["id", "vec"]) # 定义UDF:计算一组向量的逐元素平均 def vector_average(vectors): if not vectors: return None dim = len(vectors[0]) return [sum(vec[i] for vec in vectors) / len(vectors) for i in range(dim)] vector_average_udf = udf(vector_average, ArrayType(DoubleType())) # 分组收集向量并计算平均 result_df = df.groupBy("id")\ .agg(collect_list("vec").alias("all_vecs"))\ .select( col("id"), vector_average_udf(col("all_vecs")).alias("vec") ) result_df.show()
两种方法对比
- 方法一:依赖Spark内置聚合函数,优化程度高,适合向量维度固定的场景,大数据量下性能更好。
- 方法二:无需关注向量维度,适配性强,但UDF属于自定义代码,Spark无法对其进行全量优化,超大数据量下可能不如方法一高效。
内容的提问来源于stack exchange,提问作者Bjartur Sigurbergsson




