This is a training scenario from an intelligent-driving (autonomous driving) customer. The core problem the customer wants to solve is horizontal scaling of the training dataset at ultra-large scale.
The customer's dataset is very large. Each record is called a "frame"; a frame has several hundred columns and typically contains camera images, LiDAR point clouds, and other sensor readings. A single frame is about 20 MB, and there are usually several frames per second. The data is currently stored on Volcano Engine vePFS, which has already hit a bottleneck: vePFS has to be scaled out repeatedly to keep up with data growth. The customer plans to store 2 billion frames, which pushes the total volume toward 40 PB, and holding data at that scale on vePFS would be extremely expensive.
The customer's current storage scheme is LMDB + pickle. Each frame's Python object is serialized with pickle into the value of a key-value pair and written to LMDB, and the LMDB database files are stored on vePFS. To access a frame, the corresponding database file is opened with LMDB, the serialized bytes are read by key, and pickle deserializes them back into a Python object.
At small data volumes this is a mature and lightweight choice.
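For reference, here is a minimal sketch of that read/write path; the database path, frame key, and frame contents below are placeholders, not the customer's actual layout.

```python
import lmdb
import pickle

DB_PATH = "/mnt/vepfs/frames.lmdb"   # hypothetical LMDB database on vePFS
FRAME_KEY = b"CD701_000001"          # hypothetical frame key

# Write: serialize one frame (a plain Python dict here) and store it by key.
frame = {"camera": b"<jpeg bytes>", "lidar": b"<point cloud bytes>", "label": b"<label bytes>"}
env = lmdb.open(DB_PATH, map_size=1 << 40)  # map_size must cover the database; 1 TiB as an example
with env.begin(write=True) as txn:
    txn.put(FRAME_KEY, pickle.dumps(frame))

# Read: look up the key and deserialize back into a Python object.
with env.begin(write=False) as txn:
    restored = pickle.loads(txn.get(FRAME_KEY))
env.close()
```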
Volcano Engine EMR proposed a multi-modal data lake solution based on the Lance format. Compared with the customer's current embedded LMDB storage, it supports ultra-large-scale storage while serving both random-access and full-scan workloads with good performance. Lance is also only loosely coupled to the underlying storage service, so users can pick the storage backend that fits each scenario.
In addition, Lance supports ZSTD compression encoding, which achieves a high compression ratio on binary data; this further reduces the storage footprint and also lowers network bandwidth usage.
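To make the two access patterns concrete, here is a minimal sketch of random access and full scan against a Lance dataset; the dataset URI, credentials, and the `process` callback are placeholders, not part of the customer's setup.

```python
import lance

# Hypothetical dataset location on an S3-compatible endpoint (e.g. TOS) and credentials.
uri = "s3://my-bucket/frames.lance"
storage_options = {
    "access_key_id": "...",
    "secret_access_key": "...",
    "aws_region": "cn-beijing",
}

ds = lance.dataset(uri, storage_options=storage_options)

# Random access: fetch specific rows by offset, reading only the needed columns.
sample = ds.take([0, 1024, 65536], columns=["frame_key", "label"])

# Full scan: stream the whole dataset in batches, e.g. for offline preprocessing.
for batch in ds.to_batches(columns=["frame_key", "sensor"], batch_size=256):
    process(batch)  # hypothetical downstream processing
```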
| Dimension | Current scheme (LMDB + pickle) | Lance scheme |
| --- | --- | --- |
| Storage method | Embedded LMDB database files stored on vePFS | Lance datasets, loosely coupled to the underlying storage (e.g. TOS object storage) |
| Schema changes | | |
| Distributed-compute friendliness | Embedded, single-node key-value store | Fragment-based writes work with distributed engines such as Ray |
| Compression | Raw pickled bytes, no built-in compression | ZSTD compression encoding on binary columns |
| Management cost | | |
| Dimension | Current scheme (LMDB + pickle) | Lance scheme |
| --- | --- | --- |
| Data shuffle | | Random access by row offset / row address supports sampling and shuffling (see the examples below) |
| Stability | | |
The customer's schema looks like this:
```python
import pyarrow as pa

schema = pa.schema([
    pa.field("dat_name", pa.string()),
    pa.field("frame_key", pa.string()),
    pa.field("label", pa.binary(), metadata={"lance-encoding:compression": "zstd"}),
    pa.field("sensor", pa.binary(), metadata={"lance-encoding:compression": "zstd"}),
    pa.field("src_label", pa.binary(), metadata={"lance-encoding:compression": "zstd"}),
])
```
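One way to materialize this schema ahead of the backfill is to create an empty Lance dataset with it, so later appends only need to agree on the same schema. This is a sketch; the URI and storage_options values are placeholders (the real ones appear in the conversion script below).

```python
import lance
import pyarrow as pa

# Hypothetical target location and credentials.
uri = "s3://emrtest/frames.lance"
storage_options = {
    "access_key_id": "...",
    "secret_access_key": "...",
    "aws_region": "cn-beijing",
}

# Writing an empty table with the schema initializes the dataset (version 1).
empty = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(empty, uri, schema=schema, storage_options=storage_options)
```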
Converting LMDB to Lance:
```python
import lmdb
import os
import shutil
import lance
import pyarrow as pa
import ray
from tosfs.core import TosFileSystem

# TOS credentials for reading the source LMDB files.
ENV_AK = ""
ENV_SK = ""

fs = TosFileSystem(
    endpoint_url="https://tos-cn-beijing.ivolces.com",
    key=ENV_AK,
    secret=ENV_SK,
    region="cn-beijing",
)

# Source layout on TOS: <root>/label/<dat>_label.lmdb and <root>/sensor/<dat>/<dat>.lmdb
root_path = ""
sensor_path = root_path + "sensor/"
label_path = root_path + "label/"
local_tmp_path = "/tmp/"
LANCE_PATH = ""

# Enumerate the dat names from the label LMDB directories.
dat_names = []
for path in fs.ls(label_path):
    dat_names.append(path.split("/")[-1].replace("_label.lmdb", ""))

schema = pa.schema([
    pa.field("label", pa.large_binary()),
    pa.field("sensor", pa.large_binary()),
])

# Credentials and S3-compatible endpoint for writing the Lance dataset to TOS.
LANCE_AK = ""
LANCE_SK = ""
storage_options = {
    "access_key_id": LANCE_AK,
    "secret_access_key": LANCE_SK,
    "aws_region": "cn-beijing",
    "aws_endpoint": "https://emrtest.tos-s3-cn-beijing.ivolces.com",
    "virtual_hosted_style_request": "true",
    "timeout": "600s",
}

# One item per dat name; 100 blocks controls the Ray parallelism.
ds = ray.data.from_items(dat_names, override_num_blocks=100)


def gen_pa_table_by_dat_name(parent_dir: str, dat_name: str):
    """Read one pair of LMDB files and build an Arrow table of frames."""
    label_lmdb = parent_dir + "/" + dat_name + "_label.lmdb"
    sensor_lmdb = parent_dir + "/" + dat_name + ".lmdb"
    label_env = lmdb.open(label_lmdb, readonly=True, lock=False)
    label_ctx = label_env.begin(write=False)
    sensor_env = lmdb.open(sensor_lmdb, readonly=True, lock=False)
    sensor_ctx = sensor_env.begin(write=False)

    # Frame keys of interest all start with the "CD" prefix.
    frame_keys = []
    for key, _ in label_ctx.cursor():
        if key.startswith("CD".encode()):
            frame_keys.append(key)

    schema = pa.schema([
        pa.field("label", pa.large_binary()),
        pa.field("sensor", pa.large_binary()),
    ])
    data = {"label": [], "sensor": []}
    for frame_key in frame_keys:
        label_key = frame_key.decode("utf-8")
        src_label_key = "src_" + label_key
        label_data = label_ctx.get(frame_key)
        sensor_data = sensor_ctx.get(frame_key)
        # src_label is read here but not written in this two-column example.
        src_label_data = label_ctx.get(src_label_key.encode())
        data["label"].append(label_data)
        data["sensor"].append(sensor_data)

    table = pa.Table.from_arrays(
        [
            pa.array(data["label"], type=pa.large_binary()),
            pa.array(data["sensor"], type=pa.large_binary()),
        ],
        schema=schema,
    )
    label_env.close()
    sensor_env.close()
    return table


def duplicate_row(row: dict):
    """Download one dat's LMDB files, convert them, and write a Lance fragment."""
    dat_name = row["item"]
    local_dat_dir = local_tmp_path + dat_name
    local_label_dat_dir = local_dat_dir + "/" + dat_name + "_label.lmdb"
    local_sensor_dat_dir = local_dat_dir + "/" + dat_name + ".lmdb"
    label_tos_lmdb = label_path + dat_name + "_label.lmdb/"
    sensor_tos_lmdb = sensor_path + dat_name + "/" + dat_name + ".lmdb/"
    print(f"local label = {local_label_dat_dir}")
    print(f"local sensor = {local_sensor_dat_dir}")
    print(f"tos label = {label_tos_lmdb}")
    print(f"tos sensor = {sensor_tos_lmdb}")
    os.makedirs(os.path.dirname(local_dat_dir), exist_ok=True)
    os.makedirs(os.path.dirname(local_label_dat_dir), exist_ok=True)
    os.makedirs(os.path.dirname(local_sensor_dat_dir), exist_ok=True)

    # Each Ray worker creates its own TOS client.
    fs = TosFileSystem(
        endpoint_url="https://tos-cn-beijing.ivolces.com",
        key=ENV_AK,
        secret=ENV_SK,
        region="cn-beijing",
    )
    fs.get(label_tos_lmdb, local_label_dat_dir, recursive=True)
    print(f"Downloaded {label_tos_lmdb} from tos to {local_label_dat_dir}")
    fs.get(sensor_tos_lmdb, local_sensor_dat_dir, recursive=True)
    print(f"Downloaded {sensor_tos_lmdb} from tos to {local_sensor_dat_dir}")

    table = gen_pa_table_by_dat_name(local_dat_dir, dat_name)
    # Write the table as an uncommitted Lance fragment; the commit happens once on the driver.
    fragment = lance.fragment.LanceFragment.create(LANCE_PATH, table, storage_options=storage_options)

    # Clean up the local temporary directory.
    if os.path.exists(local_dat_dir):
        try:
            print(f"Removing directory {local_dat_dir}.")
            shutil.rmtree(local_dat_dir)
            print(f"Directory {local_dat_dir} removed.")
        except Exception as e:
            print(f"Failed to remove directory {local_dat_dir}: {e}")
    else:
        print(f"Directory {local_dat_dir} does not exist.")
    return {"fragment": fragment}


# Convert all dat names in parallel, then commit every fragment in a single Append operation.
fragments = ds.map(duplicate_row).take_all()
operation = lance.LanceOperation.Append([fragment["fragment"] for fragment in fragments])
dataset = lance.dataset(LANCE_PATH, storage_options=storage_options)
dataset = lance.LanceDataset.commit(
    LANCE_PATH,
    operation,
    read_version=dataset.latest_version,
    storage_options=storage_options,
)
```
Data sampling and shuffling at training time:
columns=["dat_name", "frame_key"] rowids = [] index = 0 for fragment in ds.get_fragments(): rowids.append(index) index += fragment.count_rows() ds.take(rowids, columns=columns)
columns=["dat_name", "frame_key"] rowids = [] fragids = [] for fragment in ds.get_fragments(): fragids.append(fragment.fragment_id) rowids.append(fragment.fragment_id << 32) table = ds._take_rows(rowids, columns=columns) # 将Fragment id写入到新列去 new_column_data = pa.array(fragids) new_table = table.append_column('frag_id', new_column_data) import pyarrow as pa import pyarrow.compute as pc filter_values = ["tag13_seq29_CD701_LS6C3E191PA250083_2024-12-06_10-22-32","tag13_seq30_CD701_LS6C3E191PA250083_2024-12-06_10-22-32","tag13_seq31_CD701_LS6C3E191PA250083_2024-12-06_10-22-32"] # 按照dat_name过滤 filter_fn = pc.is_in(new_table.column('dat_name'), pa.array(filter_values)) filtered_table = pc.filter(new_table, filter_fn) print(filtered_table.to_pandas())