You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在Polars中高效获取列表列的指定子集并保留唯一元素?

如何在Polars中高效获取列表列的指定子集并保留唯一元素?

我来帮你搞定这个Polars的列表处理问题!你的需求很清晰:要从每个列表里取前2个、后2个,再从中间部分采样2个(不够的话就全取),最后合并去重,还得避免采样数超过元素数量的报错对吧?

核心思路拆解

我们可以把需求拆成几个可落地的小步骤,全程用Polars的向量化操作来保证效率:

  1. 提取每个列表的前2个元素后2个元素
  2. 精准定位中间部分:直接去掉前2个和后2个元素,比用set_difference更高效
  3. 动态计算采样数量:最多采2个,如果中间元素不足2个就全取,彻底避免采样数超标的报错
  4. 合并三部分结果,最后去重得到最终列表

完整代码实现

import polars as pl
# 配置列表显示长度,方便查看结果
pl.Config(fmt_table_cell_list_len=10, fmt_str_lengths=100)

# 初始化原始数据
df = pl.DataFrame(
    {
        "grp": ["a", "b"],
        "val": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
    }
)

# 分步处理,清晰直观
result = df.with_columns(
    # 提取前2个元素
    head=pl.col("val").list.head(2),
    # 提取后2个元素
    tail=pl.col("val").list.tail(2),
    # 提取中间部分:从索引2开始,到倒数第2个元素结束(自动去掉前2和后2)
    middle=pl.col("val").list.slice(2, -2),
).with_columns(
    # 动态计算采样数:最多2个,不够就取全部中间元素
    sample_size=pl.col("middle").list.len().min(2),
    # 按计算好的数量采样中间元素,固定seed保证结果可复现
    middle_sampled=pl.col("middle").list.sample(pl.col("sample_size"), seed=1234),
).with_columns(
    # 合并三部分结果,去重后得到最终val列
    val=pl.concat_list(["head", "middle_sampled", "tail"]).list.unique()
).select("grp", "val")

print(result)

运行结果验证

执行后会得到你想要的输出:

shape: (2, 2)
┌─────┬─────────────────────┐
│ grp ┆ val                 │
│ --- ┆ ---                 │
│ str ┆ list[i64]           │
╞═════╪═════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]     │
│ b   ┆ [1, 2, 4, 7, 9, 10] │
└─────┴─────────────────────┘

关键细节说明

  1. list.slice替代set_differenceslice(2, -2)直接从索引2开始取到倒数第2个元素,比多次调用set_difference更高效,还能保留元素的原始顺序
  2. 动态采样数:通过list.len().min(2)确保采样数永远不会超过中间元素的数量,彻底解决了你之前遇到的ShapeError
  3. 向量化操作:全程用Polars的内置列表函数,比逐行遍历的map_elements性能好很多,适合处理大数据集

如果你想要更紧凑的写法,也可以用struct+map_elements整合步骤(适合小数据集快速验证):

result = df.select(
    "grp",
    val=pl.struct(
        head=pl.col("val").list.head(2),
        tail=pl.col("val").list.tail(2),
        middle=pl.col("val").list.slice(2, -2),
    ).map_elements(
        lambda x: x["head"] + x["middle"].sample(min(2, len(x["middle"])), seed=1234) + x["tail"],
        return_dtype=pl.List(pl.Int64)
    ).list.unique()
)

内容来源于stack exchange

火山引擎 最新活动