How to rewrite a DataFrame-based grouped rolling aggregation function as a Polars expression
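I understand the confusion: the current DataFrame-based implementation runs, but its column handling is inflexible and it can end up processing unrelated columns. An expression-style version is more concise and fits into the Polars expression ecosystem. Let's go step by step through converting the existing logic into an expression-based implementation.
First, the core requirement: take the rolling aggregation that currently receives a DataFrame and relies on column names, and turn it into a function that receives Polars expressions (both the group columns and the value column as pl.Expr) and returns an expression, while keeping the custom rolling function, the every-N-rows sampling (gather_every), and the grouping.
Your current DataFrame version (for comparison)
This is the implementation that works for you today; the custom rolling function and the DataFrame handling are the core of it: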
from typing import Callable, Sequence

import numpy as np
import polars as pl
from numba import guvectorize


@guvectorize(['(float64[:], int64, float64[:])'], '(n),()->(n)')
def rolling_func(input_array, window_size, out):
    """Example for a custom rolling function with a specified window size."""
    n = len(input_array)
    for i in range(n):
        start = max(i - window_size + 1, 0)
        out[i] = np.mean(input_array[start:i+1])


def apply_rolling_gathered_agg(
        df,
        func: Callable,
        window_size: int,
        *func_args,
        group_col: str | list[str] | None = None,
        value_col: str | None = None,
        result_col: str = 'result',
        every_nth: int = 1,
        window_buffer: int = 0,
        return_dtype: pl.DataType = pl.Float64) -> pl.DataFrame:
    """
    Apply a custom rolling aggregation function to a DataFrame, with grouping
    and every nth value selection.

    This function performs a rolling aggregation on a specified value column in
    a Polars DataFrame. It allows grouping by one or more columns, gathering
    every nth value, and applying a custom aggregation function (e.g.,
    `rolling_func`) with a specified window size and optional buffer.

    Args:
        df (pl.DataFrame): The DataFrame to operate on.
        func (Callable): The aggregation function to apply to each rolling window.
        window_size (int): The size of the window over which to apply the
            aggregation function.
        *func_args: Additional arguments to pass to the custom function.
        group_col (str | list[str] | None, optional): The column(s) to group by.
            If `None`, the first column is used.
        value_col (str | None, optional): The column to apply the rolling
            function to. If `None`, the last column is used.
        result_col (str, optional): The name of the result column in the output
            DataFrame. Default is 'result'.
        every_nth (int, optional): The step size for gathering values within
            each group. Default is 1.
        window_buffer (int, optional): A buffer to add around the rolling window,
            extending the window on both ends. Default is 0.
        return_dtype (pl.DataType, optional): The desired data type for the
            result column. Default is `pl.Float64`.

    Returns:
        pl.DataFrame: A DataFrame containing the results of the rolling
            aggregation, with one row per group.

    Example:
        # Create a sample DataFrame with two groups 'A' and 'B', and values from 0 to 99
        df = pl.DataFrame({
            'group': np.repeat(['A', 'B'], 100),  # Repeat 'A' and 'B' for each group
            'value': np.tile(np.arange(100), 2)   # Tile the values 0 to 99 for each group
        })
        func_args = []
        res = apply_rolling_gathered_agg(
            df,
            func=rolling_func,
            window_size=3,
            *func_args,
            group_col='group',
            value_col='value',
            every_nth=10,
            window_buffer=0,
            return_dtype=pl.Float64,
        )
        print(res)
        res_pd = res.to_pandas()
    """
    # Handle cases where group_col or value_col might not be passed
    cols = df.columns
    group_col = group_col or cols[0]
    value_col = value_col or cols[-1]

    # If group_col is a list, ensure it is processed correctly
    if isinstance(group_col, list):
        group_by = group_col
    else:
        group_by = [group_col]

    # Temporary index column for rolling aggregation
    index_col = '_index'

    # Calculate the total window size
    total_window = every_nth * (window_size + window_buffer)
    period = f'{total_window}i'

    # Apply rolling aggregation
    result = (
        df
        .with_row_index(name=index_col)
        .rolling(index_column=index_col, period=period, group_by=group_by)
        .agg(
            pl.all().last(),  # pass the last element of all present columns
            pl.col(value_col)
            .reverse().gather_every(every_nth).reverse()
            .map_batches(lambda batch: func(batch, window_size, *func_args),
                         return_dtype=return_dtype)
            .last().alias(result_col))  # This is the desired expression
        .drop(index_col)
    )
    return result
The expression-style prototype you are after
You want something along these lines, where arguments and the return value are all Polars expressions:
def expr_apply_rolling_gathered_agg(
        group_expr: pl.Expr | Sequence[pl.Expr],  # Single or list of group column expressions
        value_expr: pl.Expr,                      # Expression for the value column (series/column)
        func: Callable,                           # The rolling aggregation function
        window_size: int,                         # Size of the rolling window
        *func_args,                               # Additional arguments for the rolling function
        every_nth: int = 1,                       # Step size for gathering values
        window_buffer: int = 0,                   # Buffer size around the window
        return_dtype: pl.DataType = pl.Float64    # Output data type
) -> pl.Expr:
    pass
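About your rolling_map question
You mentioned trying rolling_map and not getting the result you expected. That is because rolling_map is designed for per-window aggregation: its callback receives the values of one window and returns a single scalar. Your rolling_func instead takes a whole array and returns an array of the same length, where each position holds the result for the window ending there. The input/output shapes do not match, so rolling_map is not a direct fit for this case, and we take a different route.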
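To make the shape mismatch concrete, here is a minimal plain-Python sketch. The two helpers (window_mean and rolling_mean_array) are hypothetical stand-ins for a rolling_map-style callback and for your rolling_func; they are not Polars or numba calls.

```python
from statistics import mean


def window_mean(window: list[float]) -> float:
    """rolling_map-style callback: one window in, one scalar out."""
    return mean(window)


def rolling_mean_array(values: list[float], window_size: int) -> list[float]:
    """rolling_func-style function: whole array in, equal-length array out."""
    out = []
    for i in range(len(values)):
        start = max(i - window_size + 1, 0)
        out.append(mean(values[start:i + 1]))
    return out


values = [1.0, 2.0, 3.0, 4.0, 5.0]
scalar_result = window_mean(values[2:5])      # a single float for one window
array_result = rolling_mean_array(values, 3)  # one float per input position
print(scalar_result, array_result)
```

Only the last element of the array-style result corresponds to what a per-window callback would return for the same window, which is why the DataFrame version ends its chain with .last().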
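The full expression-style solution
Below is the converted function. It returns a pl.Expr, keeps the core behavior of your original logic, and only works with the columns you pass in: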
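from typing import Callable, Sequence

import polars as pl


def expr_apply_rolling_gathered_agg(
        group_expr: pl.Expr | Sequence[pl.Expr],
        value_expr: pl.Expr,
        func: Callable,
        window_size: int,
        *func_args,
        every_nth: int = 1,
        window_buffer: int = 0,
        return_dtype: pl.DataType = pl.Float64
) -> pl.Expr:
    """
    Expression version of the grouped rolling aggregation; usable directly
    inside `select` / `with_columns`.

    Args:
        group_expr: Expression(s) for the group column(s), single or a sequence.
        value_expr: Expression for the numeric column to compute on.
        func: Custom rolling function (must accept a whole array, e.g. a
            function decorated with numba `guvectorize`).
        window_size: Size of the rolling window.
        *func_args: Additional arguments passed to the custom function.
        every_nth: Step size for sampling every nth row.
        window_buffer: Buffer size used to extend the window.
        return_dtype: Data type of the result column.

    Returns:
        An expression producing the rolling aggregation result.
    """
    # Normalize the group expressions to a list
    if isinstance(group_expr, pl.Expr):
        group_expr_list = [group_expr]
    else:
        group_expr_list = list(group_expr)

    # Total window size, matching `period` in the DataFrame version
    total_window = every_nth * (window_size + window_buffer)
    period = f"{total_window}i"

    # Fixed internal field names, so user column names cannot clash
    index_name, value_name, result_name = "_temp_idx", "_value", "_result"
    group_names = [f"_group_{i}" for i in range(len(group_expr_list))]

    def _rolling_on_struct(packed: pl.Series) -> pl.Series:
        # Expr.rolling takes no group_by argument and has no .agg step, so the
        # grouped rolling runs on a small frame unnested from the struct.
        out = (
            packed.struct.unnest()
            .rolling(index_column=index_name, period=period, group_by=group_names)
            .agg(
                pl.col(value_name)
                # Sample every nth row, counted back from the window's last row
                .reverse().gather_every(every_nth).reverse()
                # Call the custom rolling function on the sampled values
                .map_batches(lambda batch: func(batch, window_size, *func_args),
                             return_dtype=return_dtype)
                .last()  # keep the value for the window's last row
                .alias(result_name)
            )
            # Grouped rolling may reorder rows by group; restore the input order
            .sort(index_name)
        )
        return out[result_name]

    # Pack only the needed columns into one struct; other columns are untouched
    packed_expr = pl.struct(
        value_expr.alias(value_name),
        pl.int_range(0, pl.len()).alias(index_name),  # temporary row index
        *[g.alias(n) for g, n in zip(group_expr_list, group_names)],
    )
    return packed_expr.map_batches(_rolling_on_struct, return_dtype=return_dtype)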
How do you use this expression function?
Here is a scenario that matches your original test case:
import numpy as np
import polars as pl
from numba import guvectorize


# Keep your original custom rolling function
@guvectorize(['(float64[:], int64, float64[:])'], '(n),()->(n)')
def rolling_func(input_array, window_size, out):
    n = len(input_array)
    for i in range(n):
        start = max(i - window_size + 1, 0)
        out[i] = np.mean(input_array[start:i+1])


# Test DataFrame
df = pl.DataFrame({
    'group': np.repeat(['A', 'B'], 100),
    'value': np.tile(np.arange(100), 2)
})

# Build the result column with the expression function
result_df = df.with_columns(
    expr_apply_rolling_gathered_agg(
        group_expr=pl.col("group"),
        value_expr=pl.col("value"),
        func=rolling_func,
        window_size=3,
        every_nth=10,
        window_buffer=0,
        return_dtype=pl.Float64
    ).alias("result")
)
print(result_df.head(20))
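Key details
- Replacing the temporary index: pl.int_range(0, pl.len()) (pl.count() in older Polars releases) builds the temporary row index at the expression level, so the original DataFrame does not need to be modified.
- Binding the group logic: pl.struct packs the value column, the temporary index, and the group columns together, so the rolling window stays tied to the group information.
- Keeping the sampling logic: reverse().gather_every(every_nth).reverse() from the original is carried over unchanged, so the sampling order matches the original logic.
- Avoiding unrelated columns: only the group and value expressions you pass in are used, so other columns are never touched, which addresses the original implementation's handling of unrelated columns.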
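As a rough check of the window and sampling arithmetic, the sketch below mimics one window in plain Python. The helper sampled_positions is hypothetical, and it assumes (as in the sample frame for group A) that each value equals its row position and that the rolling period covers the last total_window rows up to and including the current one.

```python
def sampled_positions(window_rows: list[int], every_nth: int) -> list[int]:
    """Mimic reverse().gather_every(every_nth).reverse() on one window."""
    reversed_rows = window_rows[::-1]      # reverse
    gathered = reversed_rows[::every_nth]  # gather_every: offsets 0, n, 2n, ...
    return gathered[::-1]                  # reverse back to ascending order


window_size, window_buffer, every_nth = 3, 0, 10
total_window = every_nth * (window_size + window_buffer)  # 30, i.e. period "30i"

last_row = 29
window_rows = list(range(last_row - total_window + 1, last_row + 1))  # rows 0..29
picked = sampled_positions(window_rows, every_nth)

# Mean over the last window_size sampled values, as rolling_func(...).last() would give
result = sum(picked[-window_size:]) / window_size
print(total_window, picked, result)  # 30 [9, 19, 29] 19.0
```

Reversing before gather_every anchors the sampling on the window's last row, so the current row is always among the sampled values; sampling from the front would instead anchor on the oldest row in the window.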
Note: content sourced from Stack Exchange; question asked by Olibarer.




