You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

Spark/PySpark中ML预处理数据缓存及自动失效实现问询

如何在Spark机器学习流水线中缓存转换后的数据并自动失效?

我期望搭建如下机器学习(ML)流水线:

raw_data = spark.read....()
data = time_consuming_data_transformation(raw_data, preprocessing_params)
model = fit_model(data)
evaluate(model, data)

请问是否可在步骤2完成后对data进行缓存/持久化,以避免再次运行Spark应用时重复执行耗时的数据转换操作?理想状态下,当原始数据或转换代码(计算图、preprocessing_params)变更时,缓存能自动失效。

当然可以!缓存转换后的DataFrame绝对能帮你省下重复执行耗时数据转换的开销,不过Spark默认的缓存机制没法自动感知原始数据、参数或者转换代码的变化,得咱们自己加一点逻辑来实现自动失效。下面给你拆解具体的实现思路和操作:

一、基础缓存操作

先来说最直接的缓存方式,在步骤2之后加上缓存调用就行:

from pyspark.storagelevel import StorageLevel

data = time_consuming_data_transformation(raw_data, preprocessing_params)
# 选择合适的存储级别:数据量小用MEMORY_ONLY,大的话优先选带磁盘的级别避免OOM
data.persist(StorageLevel.MEMORY_AND_DISK)
# 触发实际计算(Spark缓存是懒加载的,不触发的话第一次用data还是会重新跑转换)
data.count()

cache()其实是persist(StorageLevel.MEMORY_ONLY)的简写,用persist()能更灵活地选择存储策略。

二、实现自动失效的核心思路

Spark原生缓存只会在数据血缘失效、手动调用unpersist()或者集群资源不足时清除,没法自动感知原始数据、参数或代码的变化,所以咱们得自己做版本控制和校验:

1. 针对原始数据变更的失效逻辑

你可以监控原始数据的源信息来判断是否需要失效缓存:

  • 如果是读取文件(Parquet/CSV等):检查文件的修改时间或文件列表的哈希值;
  • 如果是读取Hive表:检查表的分区更新时间或元数据版本。

每次运行流水线前先做检查,发现原始数据变化就清除旧缓存:

# 示例:检查HDFS原始文件的修改时间(伪代码,需适配你的存储系统)
from hdfs import InsecureClient

hdfs_client = InsecureClient("http://your-nn-host:50070")
raw_data_path = "/path/to/raw_data"
# 获取HDFS文件的最新修改时间
file_stats = hdfs_client.list(raw_data_path, status=True)
last_modified = max(stat["modificationTime"] for path, stat in file_stats)

# 假设我们把上次的修改时间存在本地文件或数据库中
def get_saved_last_modified():
    try:
        with open("/tmp/raw_data_last_modified.txt", "r") as f:
            return int(f.read())
    except FileNotFoundError:
        return 0

def save_last_modified(timestamp):
    with open("/tmp/raw_data_last_modified.txt", "w") as f:
        f.write(str(timestamp))

if last_modified > get_saved_last_modified():
    # 清除旧缓存(如果存在)
    if data.is_cached:
        data.unpersist()
    # 更新记录的修改时间
    save_last_modified(last_modified)
    # 重新执行转换并缓存
    data = time_consuming_data_transformation(raw_data, preprocessing_params)
    data.persist(StorageLevel.MEMORY_AND_DISK)
    data.count()

2. 针对参数/转换代码变更的失效逻辑

这里的核心是给缓存加一个版本标识,版本由preprocessing_params和转换代码的哈希值共同决定:

  • preprocessing_params做哈希:把参数转成有序JSON字符串,再计算MD5/SHA哈希(取前几位做版本号即可);
  • 对转换代码做哈希:读取time_consuming_data_transformation函数的源码字符串再计算哈希。

然后用版本标识管理缓存,比如创建带版本号的临时表:

import hashlib
import json
import inspect

def get_version_hash(params, func):
    # 计算参数的哈希(保证有序,避免参数顺序不同导致哈希变化)
    params_str = json.dumps(params, sort_keys=True)
    params_hash = hashlib.md5(params_str.encode()).hexdigest()[:8]
    # 计算函数代码的哈希
    func_code = inspect.getsource(func)
    func_hash = hashlib.md5(func_code.encode()).hexdigest()[:8]
    return f"{params_hash}_{func_hash}"

# 获取当前版本哈希
current_version = get_version_hash(preprocessing_params, time_consuming_data_transformation)
cache_table_name = f"transformed_data_{current_version}"

# 检查是否存在当前版本的缓存表
if spark.catalog.tableExists(cache_table_name):
    # 直接读取缓存
    data = spark.table(cache_table_name)
else:
    # 清除所有旧版本的缓存表(可选,避免冗余占用资源)
    for table in spark.catalog.listTables():
        if table.name.startswith("transformed_data_") and table.name != cache_table_name:
            spark.catalog.uncacheTable(table.name)
    # 执行转换并缓存
    data = time_consuming_data_transformation(raw_data, preprocessing_params)
    data.createOrReplaceTempView(cache_table_name)
    spark.catalog.cacheTable(cache_table_name)
    data.count()

三、注意事项

  • 懒加载问题:一定要调用count()show()等触发计算的操作,不然缓存只是标记了要缓存,第一次使用data时还是会重新执行转换;
  • 存储级别选择:根据数据量和集群资源调整,大优先选DISK_ONLYMEMORY_AND_DISK
  • 资源清理:长期运行的流水线要定期清理旧缓存,避免占用过多存储资源;
  • 分布式适配:如果在分布式集群上运行,监控原始数据的逻辑要适配对应的存储系统(比如用HDFS API代替本地文件操作)。

内容的提问来源于stack exchange,提问作者Lefty

火山引擎 最新活动