
How to rewrite a DataFrame-based grouped rolling aggregation function as a Polars expression


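I understand the confusion. The current DataFrame-based implementation runs, but its column handling is inflexible and it can accidentally process unrelated columns. An expression-based version is more concise and fits into the rest of the Polars expression API. Let's go step by step through converting the existing logic into an expression-based implementation.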

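First, the core requirement: turn the current rolling aggregation, which takes a DataFrame and relies on column names, into one that takes Polars expressions (pl.Expr for both the group columns and the value column) and returns an expression, while keeping the custom rolling function, sampling every N rows (gather_every), and grouping.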

Your current DataFrame version (for comparison)

This is the implementation that currently runs for you; the custom rolling function and the DataFrame handling are the core of it:

from typing import Callable, Sequence
import numpy as np
import polars as pl
from numba import guvectorize

@guvectorize(['(float64[:], int64, float64[:])'], '(n),()->(n)')
def rolling_func(input_array, window_size, out):
    """Example for a custom rolling function with a specified window size."""
    n = len(input_array)
    for i in range(n):
        start = max(i - window_size + 1, 0)
        out[i] = np.mean(input_array[start:i+1])


def apply_rolling_gathered_agg(
        df,
        func: Callable,
        window_size: int,
        *func_args,
        group_col: str | list[str] | None = None,
        value_col: str | None = None,
        result_col: str = 'result',
        every_nth: int = 1,
        window_buffer: int = 0,
        return_dtype: pl.DataType = pl.Float64) -> pl.DataFrame:
    """
    Apply a custom rolling aggregation function to a DataFrame, with grouping and every nth value selection.

    This function performs a rolling aggregation on a specified value column in a Polars DataFrame. It allows
    grouping by one or more columns, gathering every nth value, and applying a custom aggregation function
    (e.g., `rolling_func`) with a specified window size and optional buffer.

    Args:
        df (pl.DataFrame): The DataFrame to operate on.
        func (Callable): The aggregation function to apply to each rolling window.
        window_size (int): The size of the window over which to apply the aggregation function.
        *func_args: Additional arguments to pass to the custom function.
        group_col (str | list[str] | None, optional): The column(s) to group by. If `None`, the first column is used.
        value_col (str | None, optional): The column to apply the rolling function to. If `None`, the last column is used.
        result_col (str, optional): The name of the result column in the output DataFrame. Default is 'result'.
        every_nth (int, optional): The step size for gathering values within each group. Default is 1.
        window_buffer (int, optional): Extra sampled points of look-back added to the window; the window spans every_nth * (window_size + window_buffer) rows. Default is 0.
        return_dtype (pl.DataType, optional): The desired data type for the result column. Default is `pl.Float64`.

    Returns:
        pl.DataFrame: A DataFrame containing the results of the rolling aggregation, with one row per input row.


    Example:
        # Create a sample DataFrame with two groups 'A' and 'B', and values from 0 to 99
        df = pl.DataFrame({
            'group': np.repeat(['A', 'B'], 100),  # Repeat 'A' and 'B' for each group
            'value': np.tile(np.arange(100), 2)   # Tile the values 0 to 99 for each group
        })
        func_args = []
        res = apply_rolling_gathered_agg(
            df,
            func=rolling_func,
            window_size=3,
            *func_args,
            group_col='group',
            value_col='value',
            every_nth=10,
            window_buffer=0,
            return_dtype=pl.Float64,
        )
        print(res)
        res_pd = res.to_pandas()
    """
    # Handle cases where group_col or value_col might not be passed
    cols = df.columns
    group_col = group_col or cols[0]
    value_col = value_col or cols[-1]

    # If group_col is a list, ensure it is processed correctly
    if isinstance(group_col, list):
        group_by = group_col
    else:
        group_by = [group_col]

    # Temporary index column for rolling aggregation
    index_col = '_index'

    # Calculate the total window size
    total_window = every_nth * (window_size + window_buffer)
    period = f'{total_window}i'

    # Apply rolling aggregation
    result = (
        df
        .with_row_index(name=index_col)
        .rolling(index_column=index_col, period=period, group_by=group_by)
        .agg(
            pl.all().last(),  # pass the last element of all present columns
            pl.col(value_col)
            .reverse().gather_every(every_nth).reverse()
            .map_batches(lambda batch: func(batch, window_size, *func_args), return_dtype=return_dtype)
            .last().alias(result_col))  # This is the desired expression
        .drop(index_col)
    )
    return result

The expression-style function prototype you want

You want to change it into something like this, where the arguments and the result are Polars expressions:

def expr_apply_rolling_gathered_agg(
        group_expr: pl.Expr | Sequence[pl.Expr],  # Single or list of group column expressions
        value_expr: pl.Expr,  # Expression for the value column (series/column)
        func: Callable,       # The rolling aggregation function
        window_size: int,     # Size of the rolling window
        *func_args,           # Additional arguments for the rolling function
        every_nth: int = 1,   # Step size for gathering values
        window_buffer: int = 0,  # Buffer size around the window
        return_dtype: pl.DataType = pl.Float64  # Output data type
) -> pl.Expr:
    pass

About your rolling_map question

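You mentioned that you tried rolling_map and the result was not what you expected. That is because rolling_map is designed for per-window aggregation: it receives the Series for one window and returns a single scalar. Your rolling_func instead receives an array and returns an array of the same length, where each position holds the result for one rolling window. The input/output contracts do not match, so using rolling_map directly does not fit this case and a different approach is needed.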


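Complete expression-based solution

Here is the converted function. It returns a pl.Expr, keeps the core behavior of your original logic, and only touches the columns you pass in: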

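from typing import Callable, Sequence
import polars as pl

def expr_apply_rolling_gathered_agg(
        group_expr: pl.Expr | Sequence[pl.Expr],
        value_expr: pl.Expr,
        func: Callable,
        window_size: int,
        *func_args,
        every_nth: int = 1,
        window_buffer: int = 0,
        return_dtype: pl.DataType = pl.Float64
) -> pl.Expr:
    """
    Expression version of the grouped rolling aggregation, usable directly in `select`/`with_columns`.

    Expr.rolling takes no group_by argument, so the inputs are packed into a struct and the
    grouped rolling runs on a small temporary DataFrame inside map_batches.

    Args:
        group_expr: Expression(s) for the group column(s), single or a sequence.
        value_expr: Expression for the value column to compute on.
        func: Custom rolling function (must accept a whole array, e.g. a numba guvectorize function).
        window_size: Size of the rolling window.
        *func_args: Additional arguments passed to the custom function.
        every_nth: Step size for sampling every Nth row.
        window_buffer: Extra look-back added to the window.
        return_dtype: Data type of the result column.

    Returns:
        An expression yielding one result per input row, in the original row order.
    """
    # Normalize the group expressions to a list
    if isinstance(group_expr, pl.Expr):
        group_expr_list = [group_expr]
    else:
        group_expr_list = list(group_expr)

    # Total window size, matching `period` in the original logic
    total_window = every_nth * (window_size + window_buffer)
    period = f"{total_window}i"

    # Temporary row index built as an expression (replaces DataFrame.with_row_index)
    temp_index = pl.int_range(0, pl.len()).alias("_temp_idx")

    # Fixed internal field names, so the caller's column names cannot collide
    group_field_names = [f"_group_{i}" for i in range(len(group_expr_list))]
    group_fields = [e.alias(n) for e, n in zip(group_expr_list, group_field_names)]

    def _grouped_rolling(packed: pl.Series) -> pl.Series:
        # Unpack the struct into a temporary frame holding only the needed columns
        result = (
            packed.struct.unnest()
            .rolling(index_column="_temp_idx", period=period, group_by=group_field_names)
            .agg(
                pl.col("_value")
                # Sample every Nth row; reverse twice so sampling is anchored at the newest row
                .reverse().gather_every(every_nth).reverse()
                # Call the custom rolling function
                .map_batches(lambda batch: func(batch, window_size, *func_args), return_dtype=return_dtype)
                .last()  # keep the last value of the window result, as in the original logic
                .alias("_result")
            )
            # The rolling output is ordered by group; restore the original row order
            .sort("_temp_idx")
        )
        return result.get_column("_result")

    # Pack the index, value, and group columns into one struct and hand it to the helper
    return (
        pl.struct([temp_index, value_expr.alias("_value")] + group_fields)
        .map_batches(_grouped_rolling, return_dtype=return_dtype)
    )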

How do you use this expression function?

Here is a scenario that matches your original test case exactly:

import numpy as np
import polars as pl
from numba import guvectorize

# Keep your original custom rolling function
@guvectorize(['(float64[:], int64, float64[:])'], '(n),()->(n)')
def rolling_func(input_array, window_size, out):
    n = len(input_array)
    for i in range(n):
        start = max(i - window_size + 1, 0)
        out[i] = np.mean(input_array[start:i+1])

# Test DataFrame
df = pl.DataFrame({
    'group': np.repeat(['A', 'B'], 100),
    'value': np.tile(np.arange(100), 2)
})

# Build the result column with the expression function
result_df = df.with_columns(
    expr_apply_rolling_gathered_agg(
        group_expr=pl.col("group"),
        value_expr=pl.col("value"),
        func=rolling_func,
        window_size=3,
        every_nth=10,
        window_buffer=0,
        return_dtype=pl.Float64
    ).alias("result")
)

print(result_df.head(20))

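Key details

  1. Replacing the temporary index: pl.int_range builds the temporary row index at the expression level, so the original DataFrame does not need to be modified.
  2. Binding the grouping logic: pl.struct packs the value column, the temporary index, and the group columns together, so the rolling window can be tied to the correct group.
  3. Keeping the sampling logic: the reverse().gather_every(every_nth).reverse() chain from the original is carried over unchanged, so the sampling order matches the original logic.
  4. Avoiding unrelated columns: only the group and value expressions you pass in are used, so other columns are never touched, which addresses the problem in the original implementation.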
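The window arithmetic and the reverse/sample/reverse step can be checked by hand with plain Python lists. This is only a sketch: `sample_window` and `trailing_mean` are hypothetical helpers standing in for the Polars chain and for the last output of rolling_func, and the numbers come from the test case above (every_nth=10, window_size=3, window_buffer=0).

```python
def sample_window(window: list[float], every_nth: int) -> list[float]:
    # reverse -> take every nth -> reverse: anchors the sampling at the newest row
    return window[::-1][::every_nth][::-1]


def trailing_mean(values: list[float], window_size: int) -> float:
    # Stand-in for the last output of rolling_func: mean of the last window_size values
    tail = values[-window_size:]
    return sum(tail) / len(tail)


every_nth, window_size, window_buffer = 10, 3, 0
total_window = every_nth * (window_size + window_buffer)  # 30 rows of look-back

group_values = [float(v) for v in range(100)]  # values of one group in the test frame
window = group_values[-total_window:]          # rows 70..99: the window ending at the last row

sampled = sample_window(window, every_nth)
print(sampled)                              # [79.0, 89.0, 99.0]
print(trailing_mean(sampled, window_size))  # 89.0
```

Without the two reversals, sampling would start at the oldest row of the window (70, 80, 90) and the newest value would be dropped.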

Note: content sourced from Stack Exchange; question asked by Olibarer
