Polars滚动聚合性能疑惑:元素级变换内联与预计算的运行时为何近乎一致?
Polars滚动聚合性能疑惑:元素级变换内联与预计算的运行时为何近乎一致?
嘿,这个问题我太有共鸣了!当初我也直觉认为预计算sin/cos再做滚动聚合肯定更快,结果实测下来和内联写法耗时几乎一致,甚至数据量大了还反超,后来翻了Polars的优化逻辑才搞明白,给你拆解一下:
核心原因:Polars查询优化器在「偷偷做优化」
Polars的查询优化器会自动分析你的表达式逻辑,把元素级变换操作(比如sin、cos)从滚动聚合的窗口逻辑中提前提取出来。也就是说,你写的第一种写法:
X = frame.rolling(index_column="date", group_by="group", period="360d").agg( pl.col("value").sin().sum().alias("sin(value)"), pl.col("value").cos().sum().alias("cos(value)"), pl.col("value").sum() )
优化器会自动将其重写成和第二种写法几乎完全一致的执行逻辑:先对整个value列计算sin和cos,再在滚动窗口内做sum聚合。两种写法本质上跑的是同一段优化后的执行计划,自然耗时相差无几!
为什么大数据量下第二种写法反而更慢?
当数据量(尤其是日期数)变大时,第二种写法会通过with_columns先创建两个额外的列sin(value)和cos(value),这会占用更多内存空间。当内存压力上升时,数据在内存与缓存/磁盘之间的交换开销(或缓存命中率下降)会抵消掉「避免重复计算」带来的收益,甚至反过来拖慢整体运行速度——这就是你观察到的:随着日期数增加,第二种版本耗时反超的原因。
验证优化器作用的小技巧
你可以用.explain()方法把两种写法的执行计划打印出来对比,会发现它们优化后的计划几乎完全一致,Polars已经帮你把重复计算的问题解决了,不用手动预计算~
附上实验代码方便复现
import datetime import itertools import time import numpy as np import polars as pl import polars.testing def run_experiment(): start = datetime.date.fromisoformat("1991-01-01") result = {"num_dates": [], "num_groups": [], "version1": [], "version2": [], } for n_dates in [1000, 2000, 5000, 10000]: end = start + datetime.timedelta(days=(n_dates - 1)) dates = pl.date_range(start, end, eager=True) for m_groups in [10, 20, 50, 100, 200, 500, 1000]: groups = [f"g_{i + 1}" for i in range(m_groups)] groups_, dates_ = list(zip(*itertools.product(groups, dates))) frame = pl.from_dict({"group": groups_, "date": dates_, "value": np.random.rand(n_dates * m_groups)}) t0 = time.time() X = frame.rolling(index_column="date", group_by="group", period="360d").agg( pl.col("value").sin().sum().alias("sin(value)"), pl.col("value").cos().sum().alias("cos(value)"), pl.col("value").sum() ) t1 = time.time() - t0 t0 = time.time() Y = frame.with_columns( pl.col("value").sin().alias("sin(value)"), pl.col("value").cos().alias("cos(value)") ).rolling(index_column="date", group_by="group", period="360d").agg( pl.col("sin(value)").sum(), pl.col("cos(value)").sum(), pl.col("value").sum() ) t2 = time.time() - t0 polars.testing.assert_frame_equal(X, Y) result["num_dates"].append(n_dates) result["num_groups"].append(m_groups) result["version1"].append(t1) result["version2"].append(t2) return pl.from_dict(result)
备注:内容来源于stack exchange,提问作者Benjamin Trendelkamp-Schroer




