You need to enable JavaScript to run this app.
导航
模型打标
最近更新时间:2025.04.22 20:22:41首次发布时间:2025.04.22 20:22:41
我的收藏
有用
有用
无用
无用

场景描述

在 AI数据湖场景中,Schema Evolution(模式演进)是一项至关重要的能力,它能够有效地处理和维护数据结构的变化,且无需停机或重建现有数据集。随着业务需求的日益复杂和数据量的持续增长,数据结构的变化成为了一种常态。这可能表现为添加新的字段以满足新的业务需求,例如记录更多的客户信息或产品特征;甚至可能是删除不再需要的字段,以优化数据存储和处理效率。
Schema Evolution 的强大之处在于,它使得数据架构能够灵活地随时调整和扩展,从而更好地适应业务的发展变化。与此同时,它还能保持对历史数据的访问和兼容性,确保过往的数据分析和处理工作不受影响,为企业的数据管理和利用提供了极大的便利和保障。
在以图片为主的场景中,我们经常会使用一些 AI 模型,对图片做一轮打标的操作。例如使用一些美学分的模型,对数据集中的图片进行一个美学分的判定,判断图像的质量。对于这种场景,我们需要数据集能通加列的能力。这是因为在图片分析和处理的过程中,可能会产生新的特征或属性需要记录,而通过加列的方式可以方便地将这些新信息整合到数据集中,为后续的分析和应用提供更全面和准确的数据支持。

当前方案

当前主流AI场景使用的是WebDataset, WebDataset通过将数据集打包成 tar 文件,并使用简单的 URL 进行访问,使得数据集的管理和使用变得更加高效和灵活。
Lance 是一个高效的列式存储格式,基于 Apache Arrow,旨在提供快速的数据存储和检索。它特别适用于大规模数据分析和机器学习任务。Lance 通过利用列式存储的优势,能够在处理大规模数据时提供高效的读写性能和压缩效果。
而WebDataset因是固定的压缩格式,无法直接将webdataset中包含的json信息等打平,无法快速的对数据集进行数据集筛选、过滤、查询等。
因此本文可以通过Ray读取WebDataset,写入lance的范式。并通过Ray对lance的数据集进行操作
Image
使用 Lance 格式后,进行模型打标会非常便捷。
Image
图像数据一般以Webdataset方式存储,计算采用Ray方式读取,并处理。

  1. 启动Ray任务会,会一行一行的读取WebDataset的数据。
  2. 加载模型,并对图片列进行推理,得到推理的标签值。
  3. 将标签值,插入到Label的json中。
  4. 将整个数据完整的写入到一个新的数据集中。

Lance方案

Image
数据采用Lance格式存储,计算使用Ray或者Spark引擎。
直接在Lance数据集中抽取Image字段,并产生新的Label的Lance File。
最后提交新生成的Lance File,打标就完成了。

方案对比

维度

当前方案

Lance方案

数据读取

  • 需要读取全部的WebDataset数据,IO数据放到十分严重
  • 如果是过滤读取的话,IO放大更加严重
  • 列式存储只读取了图片列。
  • 如果过滤读取,能够跳过刷选列。

数据写入

  • 全量写入,尤其图片字段也需要重新写入 。
  • 只写入新的Label列,字段小。

代码样例

这段代码实现了对Lance数据集中存储的图像进行分布式美学评分计算,并将结果写入数据集。主要包含以下核心功能:

  1. 图像美学评分计算:使用预训练的视觉模型对图像进行美学评分
  2. 分布式处理:利用Ray框架实现多节点并行计算
  3. 数据版本管理:通过Lance数据集的事务机制保证数据一致性
import lance
import pyarrow as pa
import pandas as pd
import numpy as np
import torch
import ray
import time
from PIL import Image
import random
from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
from typing import List, Dict, Any
from io import BytesIO
# TOS配置
ENV_AK = ''
ENV_SK = ''
storage_options = {}
lance_path = ""
REF_COLUMN_SCHEMA = pa.schema([
    pa.field("jpg", pa.binary())
])
SCORE_SCHEMA = pa.schema([
    pa.field("AESTHETIC_SCORE", pa.float32()),
    pa.field("AESTHETIC_TAG", pa.string())
])


class FragmentScoreColumns:
    def __init__(self, storage_options, lance_path):
        self.ds = lance.dataset(uri=lance_path,
                                storage_options=storage_options)
        self.model, self.preprocessor = convert_v2_5_from_siglip(
            predictor_name_or_path="aesthetic_predictor_v2_5.pth",
            encoder_model_name="siglip-so400m-patch14-384",
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        self.model = self.model.to(torch.float32)

    def get_score_tag(self, score: float) -> str:
        if score > 7.1:
            return 'very aesthetic'
        elif 5.5 < score <= 7.1:
            return 'aesthetic'
        elif 3.7 < score <= 5.5:
            return 'displeasing'
        else:
            return 'very displeasing'

    def process_image(self, image_bytes: bytes) -> tuple[float, str]:
        try:
            # 从字节流创建图像
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            pixel_values = self.preprocessor(
                images=image,
                return_tensors="pt"
            ).pixel_values
            with torch.inference_mode():
                score = float(self.model(
                    pixel_values).logits.squeeze().numpy())
            return score, self.get_score_tag(score)
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None, "error"
        finally:
            if 'image' in locals():
                image.close()

    def __call__(self, batch):
        fragment_id = batch["item"]
        fragment = self.ds.get_fragment(fragment_id)
        new_fragment, new_schema = fragment.merge_columns(
            value_func=self.generate_aesthetic_scores,
            columns=REF_COLUMN_SCHEMA.names
        )
        return {"fragments": new_fragment, "schema": new_schema}

    ## 根据jpg的列,应用模型,计算得到score和tag列
    def generate_aesthetic_scores(self, batch: pa.RecordBatch) -> pa.RecordBatch:
        i = random.randint(0, 10)
        # 获取图像数据列
        image_bytes = batch.column('jpg').to_pylist()
        scores = []
        tags = []
        for img in image_bytes:
            score, tag = self.process_image(img)
            scores.append(score if score is not None else float('nan'))
            tags.append(tag)
        # 创建DataFrame并转换为RecordBatch
        df = pd.DataFrame({
            'AESTHETIC_SCORE': scores,
            'AESTHETIC_TAG': tags
        })
        return pa.RecordBatch.from_pandas(df, schema=SCORE_SCHEMA)


def process_lance_dataset(num_workers: int = 10):
    """处理Lance数据集并添加美学评分"""
    print("Loading Lance dataset...")
    lance_ds = lance.dataset(lance_path, storage_options=storage_options)
    print(f"Original schema: {lance_ds.schema}")
    # 获取所有fragment IDs
    fragment_ids = [f.fragment_id for f in lance_ds.get_fragments()]
    print(f"Processing {len(fragment_ids)} fragments...")
    # 使用Ray进行分布式处理
    ray_ds = ray.data.from_items(fragment_ids).map(
        FragmentScoreColumns,
        fn_constructor_args=(storage_options, lance_path),
        concurrency=num_workers
    ).take_all()
    print("Committing results to Lance dataset...")
    # 提交到Lance
    merged_fragments = [item["fragments"] for item in ray_ds]
    schema = ray_ds[0]["schema"] if ray_ds else None
    operation = lance.LanceOperation.Merge(merged_fragments, schema)
    dataset = lance.LanceDataset.commit(
        lance_path,
        operation,
        read_version=lance_ds.version,
        storage_options=storage_options
    )
    print("Processing completed!")
    return dataset


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description="Process Lance dataset for aesthetic scoring")
    parser.add_argument("--num-workers", type=int,
                        default=10, help="Number of workers")
    args = parser.parse_args()
    start_time = time.time()
    ray.init()
    process_lance_dataset(args.num_workers)

    print(f"Total processing time: {time.time() - start_time:.2f} seconds")