You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

使用Polars读取Parquet文件时多次读取值不一致,但Pandas无此问题,同配置仅单工作站出现异常

Polars读取Parquet文件时多次读取值不一致,但Pandas无此问题,同配置仅单工作站出现异常

我多次用Polars(Rust引擎和PyArrow后端)以及Pandas(PyArrow后端,不用fastparquet因为太慢)读取同一批Parquet文件,相关代码附在下方。

所有Parquet文件都包含一个名为backscat的列,该列的每一行都是一个包含150万个浮点数的时间序列列表。

我的脚本逻辑是:如果两次读取的DataFrame哈希值不同,就检查backscat列的数值差异,把超出rtol = 1e-7, atol = 1e-10的差异记录为"violations"。

我在两台配置完全一致的工作站上运行脚本:相同CPU、相同Poetry虚拟环境、Python 3.12.8、Polars 1.20.0、Ubuntu 22.04,内存配置也完全相同。

异常现象

  • 工作站A上,Polars读取同一文件时出现了显著差异,比如Polars Rust引擎的最大差异示例:
name_test   | data_1     | data_2     | abs_diff  | abs_relative_diff
polars_rust | -3.041666  | -38.666656 | 35.62499  | 11.712328
polars_rust | -38.666656 | -3.041666  | 35.62499  | 0.921336
polars_rust | -2.927914  | -27.423315 | 24.495401 | 8.36616
polars_rust | -27.423315 | -2.927914  | 24.495401 | 0.893233
polars_rust | -2.927876  | -27.42301  | 24.495134 | 8.366178

可以看到差异非常大,比如一次读出来是-3.041666,另一次是-38.666656!但同一台工作站A上用Pandas读取完全没有异常。

  • 工作站B上,不管是Polars还是Pandas,都没有出现任何异常。

后续排查情况

  • 我在工作站A上用Docker容器运行(Polars 1.24.0),结果还是一样有异常。
  • 我发现如果同事同时运行占满所有处理器的模拟任务,Polars就不会出现异常!看起来当Polars无法充分并行执行(CPU核心已被其他任务占满)时,错误就不会发生。
  • 我目前怀疑是硬件问题(比如工作站A的RAM损坏),但疑惑为什么只有Polars会出现这个问题,Pandas却没问题。

读取和检测差异的代码

# CODE TO READ PARQUET FILES AND DETECT DIFFERENCES
import hashlib
from pathlib import Path
from typing import Any, Callable

import numpy as np
import pandas as pd
import polars as pl
from tqdm import tqdm


def hash_dataframe_polars(df):
    """Return the SHA-256 hex digest of the ``backscat`` column of ``df``.

    The column is selected, exploded, and written to CSV text in memory
    (``write_csv(None)`` returns the text instead of writing a file); the
    digest is computed over the encoded text.
    """
    exploded = df.select("backscat").explode("backscat")
    csv_text = exploded.write_csv(None)
    digest = hashlib.sha256(csv_text.encode())
    return digest.hexdigest()


def hash_dataframe_pandas(df):
    """Return a SHA-256 hex digest covering every value in ``df["backscat"]``.

    The digest is built from the raw bytes of each cell rather than from
    ``Series.to_csv()``. CSV export stringifies array-valued cells, and NumPy
    abbreviates the text of arrays longer than its print threshold (1000
    elements by default) to the first and last few values at reduced
    precision, so a hash of that text cannot see changes in the middle of a
    long cell.
    """
    digest = hashlib.sha256()
    for cell in df["backscat"]:
        values = np.asarray(cell)
        # Mix in dtype and shape so differently shaped cells with the same
        # byte content do not collide.
        digest.update(f"{values.dtype.str}{values.shape}".encode())
        if values.dtype.hasobject:
            # tobytes() on object arrays yields pointer values, which are not
            # stable between reads; hash the full repr of the contents instead.
            digest.update(repr(values.tolist()).encode())
        else:
            digest.update(values.tobytes())
    return digest.hexdigest()


def find_precision_violations(
    paths_parquet_files: list[Path],
    name_test: str,
    rtol: float,
    atol: float,
    hash_fun: Callable[[Any], str],
    read_parquet_file: Callable[[Path], Any],
    num_reps_per_file: int,
    output_dir: Path = Path("./results_check_ram_or_cpu_catch"),  # Default output directory
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Read each file repeatedly and record values that differ between reads.

    Every file is read ``num_reps_per_file`` times. Each successful read is
    hashed with ``hash_fun`` and compared with the previous successful read of
    the same file. When the hashes differ, the ``backscat`` column is compared
    row by row with ``np.isclose`` and the elements outside tolerance are
    recorded.

    Args:
        paths_parquet_files: Files to check; a file is identified in the
            results by its position in this list.
        name_test: Label stored in every record and used in output file names.
        rtol: Relative tolerance passed to ``np.isclose``.
        atol: Absolute tolerance passed to ``np.isclose``.
        hash_fun: Maps a loaded frame to a hash string.
        read_parquet_file: Loads one file and returns a frame exposing
            ``shape`` and ``["backscat"][i]``.
        num_reps_per_file: Number of reads per file.
        output_dir: Directory for the Parquet result files; created if missing.

    Returns:
        ``(precision_violated, errors_reading_files)``. Violation records have
        keys ``name_test``, ``data_1`` (current read), ``data_2`` (previous
        read), ``i`` (row), ``rep`` and ``file_index``. Error records have keys
        ``name_test``, ``file_index``, ``rep`` and ``error``; besides read
        exceptions they also report row-count or cell-shape changes between
        reads, which cannot be compared element-wise.
    """
    precision_violated: list[dict[str, Any]] = []
    errors_reading_files: list[dict[str, Any]] = []

    def record_error(file_index: int, rep: int, message: str) -> None:
        errors_reading_files.append(
            {"name_test": name_test, "file_index": file_index, "rep": rep, "error": message}
        )

    # Ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Use tqdm to monitor file processing
    for file_index, path_file in enumerate(tqdm(paths_parquet_files, desc=f"Processing {name_test} Files")):
        previous_hash = None
        previous_df = None

        # Use tqdm to monitor repetitions per file
        for rep in tqdm(range(num_reps_per_file), desc=f"Reps for File {file_index}", leave=False):
            try:
                df = read_parquet_file(path_file)
            except Exception as e:
                # Deliberately broad: any reader failure is logged and the
                # run continues with the next repetition.
                record_error(file_index, rep, str(e))
                continue

            current_hash = hash_fun(df)
            if previous_hash is not None and current_hash != previous_hash:
                tqdm.write(f"{name_test}, hashes differ for file {path_file}")

                num_rows = df.shape[0]
                num_rows_previous = previous_df.shape[0]
                if num_rows != num_rows_previous:
                    record_error(
                        file_index,
                        rep,
                        f"row count changed between reads: {num_rows_previous} -> {num_rows}",
                    )

                # Only rows present in both reads can be compared.
                for i in range(min(num_rows, num_rows_previous)):
                    data_1 = np.asarray(df["backscat"][i])
                    data_2 = np.asarray(previous_df["backscat"][i])

                    if data_1.shape != data_2.shape:
                        # np.isclose would raise (or silently broadcast) here.
                        record_error(
                            file_index,
                            rep,
                            f"row {i}: shape changed between reads: {data_2.shape} -> {data_1.shape}",
                        )
                        continue

                    mask = ~np.isclose(data_1, data_2, rtol=rtol, atol=atol)
                    num_violations = int(mask.sum())

                    if num_violations:
                        # tqdm.write keeps the message from corrupting the bars.
                        tqdm.write(f"Precision violation found for {num_violations} elements")

                        precision_violated.append(
                            {
                                "name_test": name_test,
                                "data_1": data_1[mask].tolist(),
                                "data_2": data_2[mask].tolist(),
                                "i": i,
                                "rep": rep,
                                "file_index": file_index,
                            }
                        )

            previous_hash = current_hash
            previous_df = df

    # Convert results to Polars DataFrame and save as Parquet
    if precision_violated:
        df_precision_violated = pl.DataFrame(precision_violated)
        df_precision_violated.write_parquet(output_dir / f"precision_violations_{name_test}.parquet")

    if errors_reading_files:
        df_errors = pl.DataFrame(errors_reading_files)
        df_errors.write_parquet(output_dir / f"errors_reading_files_{name_test}.parquet")

    return precision_violated, errors_reading_files


def _read_polars_rust(path_file):
    """Load ``date`` and ``backscat`` with Polars, ``use_pyarrow=False``."""
    return pl.read_parquet(path_file, columns=["date", "backscat"], use_pyarrow=False)


def _read_pandas_pyarrow(path_file):
    """Load ``date`` and ``backscat`` with pandas, ``engine="pyarrow"``."""
    return pd.read_parquet(path_file, columns=["date", "backscat"], engine="pyarrow")


def _read_polars_pyarrow(path_file):
    """Load ``date`` and ``backscat`` with Polars, ``use_pyarrow=True``."""
    return pl.read_parquet(path_file, columns=["date", "backscat"], use_pyarrow=True)


# One entry per reader under test, run in this order. Each entry pairs a
# loader with the hash function matching the frame type it returns.
# A "pandas_fastparquet" entry (pd.read_parquet with engine="fastparquet",
# hashed with hash_dataframe_pandas) is intentionally not enabled.
list_set_up = [
    {
        "name_test": "polars_rust",
        "read_parquet_file": _read_polars_rust,
        "hash_fun": hash_dataframe_polars,
    },
    {
        "name_test": "pandas_pyarrow",
        "read_parquet_file": _read_pandas_pyarrow,
        "hash_fun": hash_dataframe_pandas,
    },
    {
        "name_test": "polars_pyarrow",
        "read_parquet_file": _read_polars_pyarrow,
        "hash_fun": hash_dataframe_polars,
    },
]

# Input files. glob() yields paths in arbitrary, filesystem-dependent order and
# the result records identify a file only by its index in this list, so sort
# the paths to make file_index reproducible across runs and machines.
dir_parquet_files = Path("cache/212_artificial_data_set/parquet_files")
paths_parquet_files = sorted(dir_parquet_files.glob("*.parquet"))
if not paths_parquet_files:
    # The path is relative to the working directory; an empty match would
    # otherwise look like a clean run with no violations.
    print(f"Warning: no parquet files found in {dir_parquet_files}")

# Tolerances for np.isclose and number of reads per file.
rtol = 1e-7
atol = 1e-10
num_reps_per_file = 10

# Run every configured reader over all files and pool the results.
all_violations = []
all_errors_reading_files = []
for set_up in list_set_up:
    violations, errors_reading_files = find_precision_violations(
        paths_parquet_files=paths_parquet_files,
        name_test=set_up["name_test"],
        rtol=rtol,
        atol=atol,
        hash_fun=set_up["hash_fun"],
        read_parquet_file=set_up["read_parquet_file"],
        num_reps_per_file=num_reps_per_file,
    )

    all_violations.extend(violations)
    all_errors_reading_files.extend(errors_reading_files)


# %% save violations to csv save errors reading files to csv
violations_df = pd.DataFrame(all_violations)
print(violations_df)
violations_df.to_csv("all_violations.csv")

errors_reading_files_df = pd.DataFrame(all_errors_reading_files)
print(errors_reading_files_df)
errors_reading_files_df.to_csv("all_errors_reading_files.csv")

分析脚本代码

# ANALYSIS SCRIPT
import ast

import matplotlib.pyplot as plt
import polars as pl

# Load the CSV results of the detection run.
violations_df = pl.read_csv("all_violations.csv")
errors_reading_files_df = pl.read_csv("all_errors_reading_files.csv")

# data_1 / data_2 arrive as text; parse each cell back into a list of floats.
violations_df = violations_df.with_columns(
    pl.col("data_1").map_elements(ast.literal_eval, return_dtype=pl.List(pl.Float64)),
    pl.col("data_2").map_elements(ast.literal_eval, return_dtype=pl.List(pl.Float64)),
)

# Per-row list lengths. Expr.len() would give the number of rows in the
# column instead, so the list namespace is required here.
violations_df = violations_df.with_columns(
    pl.col("data_1").list.len().alias("len_data_1"),
    pl.col("data_2").list.len().alias("len_data_2"),
)

# Both lists in a row should have the same length; report if any row disagrees.
if violations_df.filter(pl.col("len_data_1") != pl.col("len_data_2")).height > 0:
    print("Lengths differ")

violations_df = violations_df.drop("len_data_1").rename({"len_data_2": "num_violations"})

# Total number of violating elements per reader, smallest first.
tot_violations = (
    violations_df.group_by("name_test")
    .agg(pl.col("num_violations").sum().alias("total_violations"))
    .sort("total_violations")
)

print(tot_violations)

# %%
for name_test in violations_df["name_test"].unique():
    # 原代码此处内容不完整,保留原样

备注:内容来源于stack exchange,提问作者newandlost

火山引擎 最新活动