Polars returns inconsistent values across repeated reads of the same Parquet file, Pandas does not, and only one of two identically configured workstations is affected
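I read the same set of Parquet files multiple times with Polars (both the Rust engine and the PyArrow backend) and with Pandas (PyArrow backend only; I skip fastparquet because it is too slow). The code is attached below.
Every Parquet file contains a column named backscat, and each row of that column is a list holding a time series of 1.5 million floats.
The script works as follows: if the hashes of two DataFrames read from the same file differ, it compares the values in the backscat column and records every difference beyond rtol = 1e-7, atol = 1e-10 as a "violation".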
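For reference, `np.isclose` documents its test as `|a - b| <= atol + rtol * |b|`. The plain-Python sketch below mirrors that formula with the thresholds above. The helper name `violates_tolerance` is made up for illustration and is not part of the script, and the sketch ignores the NaN and infinity handling that `np.isclose` also performs.

```python
def violates_tolerance(a: float, b: float, rtol: float = 1e-7, atol: float = 1e-10) -> bool:
    """Return True when a and b differ by more than the np.isclose formula allows."""
    return abs(a - b) > atol + rtol * abs(b)


# A pair taken from the results table is far outside the tolerance ...
print(violates_tolerance(-3.041666, -38.666656))  # True
# ... while a difference at the 1e-9 level is accepted.
print(violates_tolerance(1.0, 1.0 + 1e-9))  # False
```

With these thresholds the check is dominated by the relative term for values of order 1 or larger, so only differences beyond roughly the seventh significant digit are reported.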
I ran the script on two identically configured workstations: same CPU, same Poetry virtual environment, Python 3.12.8, Polars 1.20.0, Ubuntu 22.04, and the same RAM configuration.
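One way to double-check that the two software environments really match is to fingerprint them on each machine and compare the results. This is a minimal sketch using only the standard library; `environment_fingerprint` is a hypothetical helper, and it covers only the interpreter, the platform string, and installed package versions, not BIOS settings, microcode, or the RAM itself.

```python
import hashlib
import platform
import sys
from importlib import metadata


def environment_fingerprint() -> str:
    """Hash the interpreter version, platform string, and installed package versions."""
    packages = sorted(
        f"{dist.metadata['Name']}=={dist.version}" for dist in metadata.distributions()
    )
    lines = [sys.version, platform.platform(), platform.machine(), *packages]
    return hashlib.sha256("\n".join(lines).encode()).hexdigest()


# Run this inside the same virtual environment on both workstations and compare.
print(environment_fingerprint())
```

Because `platform.platform()` includes the kernel release, a differing fingerprint can come from a kernel update as well as from a package mismatch, so a mismatch only says that something differs, not what.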
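Observed anomalies
- On workstation A, Polars returns significantly different values when reading the same file repeatedly. For example, these are the largest differences seen with the Polars Rust engine: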
| name_test | data_1 | data_2 | abs_diff | abs_relative_diff |
|---|---|---|---|---|
| polars_rust | -3.041666 | -38.666656 | 35.62499 | 11.712328 |
| polars_rust | -38.666656 | -3.041666 | 35.62499 | 0.921336 |
| polars_rust | -2.927914 | -27.423315 | 24.495401 | 8.36616 |
| polars_rust | -27.423315 | -2.927914 | 24.495401 | 0.893233 |
| polars_rust | -2.927876 | -27.42301 | 24.495134 | 8.366178 |
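The differences are huge: one read returns -3.041666 and another returns -38.666656! Yet reading with Pandas on the same workstation A shows no anomalies at all.
- On workstation B, neither Polars nor Pandas shows any anomalies.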
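The `abs_relative_diff` column appears to be `abs_diff / |data_1|`. This is an inference from the numbers in the table, not something stated in the script, and it would explain why the same pair of values gives 11.71 in one row and 0.92 in the row where the two reads are swapped. A small check against the first two rows, with `diff_metrics` as an illustrative name:

```python
def diff_metrics(data_1: float, data_2: float) -> tuple[float, float]:
    """Return (absolute difference, difference relative to |data_1|)."""
    abs_diff = abs(data_1 - data_2)
    return abs_diff, abs_diff / abs(data_1)


# First two rows of the table: the same pair of values in both orders.
for data_1, data_2 in [(-3.041666, -38.666656), (-38.666656, -3.041666)]:
    abs_diff, rel_diff = diff_metrics(data_1, data_2)
    print(f"{data_1:>11.6f} {data_2:>11.6f} {abs_diff:.5f} {rel_diff:.6f}")
```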
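Further investigation
- I ran the script in a Docker container on workstation A (Polars 1.24.0) and got the same anomalies.
- I noticed that when a colleague is simultaneously running simulations that occupy all processors, Polars shows no anomalies! It looks like the error does not occur when Polars cannot use multiple processors in parallel.
- I currently suspect a hardware problem (for example, faulty RAM in workstation A), but I do not understand why only Polars is affected while Pandas is fine.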
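Two of the observations above can be probed separately. The sketch below is an assumption-laden starting point, not a diagnosis. `raw_file_digests` hashes the raw bytes of a file repeatedly without any Parquet decoder involved, and `run_with_polars_threads` reruns a script in a fresh interpreter with the `POLARS_MAX_THREADS` environment variable set, since Polars reads that variable when it is first imported. Both helper names and the paths in the usage comments are hypothetical.

```python
import hashlib
import os
import subprocess
import sys
from pathlib import Path


def raw_file_digests(path: Path, reps: int = 10) -> set[str]:
    """Hash the raw bytes of a file several times, bypassing any Parquet decoder."""
    return {hashlib.sha256(path.read_bytes()).hexdigest() for _ in range(reps)}


def run_with_polars_threads(script: Path, num_threads: int) -> subprocess.CompletedProcess:
    """Run a script in a fresh interpreter with POLARS_MAX_THREADS set."""
    env = {**os.environ, "POLARS_MAX_THREADS": str(num_threads)}
    return subprocess.run(
        [sys.executable, str(script)], env=env, capture_output=True, text=True
    )


# Hypothetical usage; both paths are placeholders:
# digests = raw_file_digests(Path("some_file.parquet"))
# print(len(digests))  # more than one digest means the raw bytes differed between reads
# result = run_with_polars_threads(Path("check_reads.py"), num_threads=1)
# print(result.stdout)
```

If the violations disappear with a single Polars thread, that would be consistent with the observation about the colleague's simulations. Identical raw digests do not rule out faulty RAM, because repeated reads are usually served from the page cache and the decoded data lives in different memory.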
Code to read the files and detect differences
# CODE TO READ PARQUET FILES AND DETECT DIFFERENCES
import hashlib
from pathlib import Path
from typing import Any, Callable

import numpy as np
import pandas as pd
import polars as pl
from tqdm import tqdm


def hash_dataframe_polars(df):
    return hashlib.sha256(df.select("backscat").explode("backscat").write_csv(None).encode()).hexdigest()


def hash_dataframe_pandas(df):
    return hashlib.sha256(df["backscat"].to_csv().encode()).hexdigest()


def find_precision_violations(
    paths_parquet_files: list[Path],
    name_test: str,
    rtol: float,
    atol: float,
    hash_fun: Callable[[Any], str],
    read_parquet_file: Callable[[Path], Any],
    num_reps_per_file: int,
    output_dir: Path = Path("./results_check_ram_or_cpu_catch"),  # Default output directory
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    precision_violated = []
    errors_reading_files = []

    # Ensure output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Use tqdm to monitor file processing
    for file_index, path_file in enumerate(tqdm(paths_parquet_files, desc=f"Processing {name_test} Files")):
        previous_hash = None
        previous_df = None

        # Use tqdm to monitor repetitions per file
        for rep in tqdm(range(num_reps_per_file), desc=f"Reps for File {file_index}", leave=False):
            try:
                df = read_parquet_file(path_file)
            except Exception as e:
                errors_reading_files.append(
                    {"name_test": name_test, "file_index": file_index, "rep": rep, "error": str(e)}
                )
                continue

            hash = hash_fun(df)
            if hash != previous_hash and previous_hash is not None:
                tqdm.write(f"{name_test}, hashes differ for file {path_file}")
                for i in range(df.shape[0]):
                    data_1 = np.asarray(df["backscat"][i])
                    data_2 = np.asarray(previous_df["backscat"][i])  # type: ignore
                    mask = ~np.isclose(data_1, data_2, rtol=rtol, atol=atol)
                    if np.any(mask):
                        print(f"Precision violation found for {sum(mask)} elements")
                        precision_violated.append(
                            {
                                "name_test": name_test,
                                "data_1": data_1[mask].tolist(),
                                "data_2": data_2[mask].tolist(),
                                "i": i,
                                "rep": rep,
                                "file_index": file_index,
                            }
                        )
            previous_hash = hash
            previous_df = df

    # Convert results to Polars DataFrame and save as Parquet
    if precision_violated:
        df_precision_violated = pl.DataFrame(precision_violated)
        df_precision_violated.write_parquet(output_dir / f"precision_violations_{name_test}.parquet")
    if errors_reading_files:
        df_errors = pl.DataFrame(errors_reading_files)
        df_errors.write_parquet(output_dir / f"errors_reading_files_{name_test}.parquet")

    return precision_violated, errors_reading_files


list_set_up = [
    {
        "name_test": "polars_rust",
        "read_parquet_file": lambda path_file: pl.read_parquet(
            path_file, columns=["date", "backscat"], use_pyarrow=False
        ),
        "hash_fun": hash_dataframe_polars,
    },
    {
        "name_test": "pandas_pyarrow",
        "read_parquet_file": lambda path_file: pd.read_parquet(
            path_file, columns=["date", "backscat"], engine="pyarrow"
        ),
        "hash_fun": hash_dataframe_pandas,
    },
    # {
    #     "name_test": "pandas_fastparquet",
    #     "read_parquet_file": lambda path_file: pd.read_parquet(
    #         path_file, columns=["date", "backscat"], engine="fastparquet"
    #     ),
    #     "hash_fun": hash_dataframe_pandas,
    # },
    {
        "name_test": "polars_pyarrow",
        "read_parquet_file": lambda path_file: pl.read_parquet(
            path_file, columns=["date", "backscat"], use_pyarrow=True
        ),
        "hash_fun": hash_dataframe_polars,
    },
]

dir_parquet_files = Path("cache/212_artificial_data_set/parquet_files")
paths_parquet_files = list(dir_parquet_files.glob("*.parquet"))

rtol = 1e-7
atol = 1e-10
num_reps_per_file = 10

all_violations = []
all_errors_reading_files = []
for set_up in list_set_up:
    violations, errors_reading_files = find_precision_violations(
        paths_parquet_files=paths_parquet_files,
        name_test=set_up["name_test"],
        rtol=rtol,
        atol=atol,
        hash_fun=set_up["hash_fun"],
        read_parquet_file=set_up["read_parquet_file"],
        num_reps_per_file=num_reps_per_file,
    )
    all_violations += violations
    all_errors_reading_files += errors_reading_files

# %% save violations to csv
# save errors reading files to csv
violations_df = pd.DataFrame(all_violations)
print(violations_df)
violations_df.to_csv("all_violations.csv")

errors_reading_files_df = pd.DataFrame(all_errors_reading_files)
print(errors_reading_files_df)
errors_reading_files_df.to_csv("all_errors_reading_files.csv")
Analysis script
# ANALYSIS SCRIPT
import ast

import matplotlib.pyplot as plt
import polars as pl

violations_df = pl.read_csv("all_violations.csv")
errors_reading_files_df = pl.read_csv("all_errors_reading_files.csv")

# The lists were written to CSV as strings, so parse them back into lists of floats
violations_df = violations_df.with_columns(
    pl.col("data_1").map_elements(lambda x: ast.literal_eval(x), return_dtype=pl.List(pl.Float64)),
    pl.col("data_2").map_elements(lambda x: ast.literal_eval(x), return_dtype=pl.List(pl.Float64)),
)

# Use .list.len() for the per-row list length; a bare .len() returns the row count of the column
violations_df = violations_df.with_columns(
    pl.col("data_1").list.len().alias("len_data_1"),
    pl.col("data_2").list.len().alias("len_data_2"),
)
if violations_df.filter(pl.col("len_data_1") != pl.col("len_data_2")).height > 0:
    print("Lengths differ")

violations_df = violations_df.drop("len_data_1").rename({"len_data_2": "num_violations"})

tot_violations = (
    violations_df.group_by("name_test")
    .agg(pl.col("num_violations").sum().alias("total_violations"))
    .sort("total_violations")
)
print(tot_violations)

# %%
for name_test in violations_df["name_test"].unique():
    ...  # The original code is incomplete here; the loop body is missing and left as-is
Note: this content comes from Stack Exchange; the question was asked by newandlost.




