ViT 图像语义嵌入处理器,适用于图像相似性搜索、内容检索等场景。
google/vit-base-patch16-224-in21k(768维)google/vit-large-patch16-224-in21k(1024维)facebook/dinov2-base(768维)facebook/dinov2-large(1024维)输入列名 | 说明 |
|---|---|
images | 包含图像数据的数组,元素类型可为图像URL、 Base64编码或二进制数据 |
包含特征向量的数组,每个元素为float类型的嵌套数组,
数组维度由模型输出决定
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
image_src_type | str | image_url | 输入图像的格式类型,支持: - tos/http 地址(image_url) - base64 编码(image_base64) - 二进制流(image_binary) 可选值:["image_url", "image_base64", "image_binary"] 默认值:"image_url" |
dtype | str | float32 | 模型推理精度选择: - bfloat16: 平衡精度与速度(TPU上更快) - float16: 更快的推理速度 - float32: 最高精度但显存消耗最大 可选值:["bfloat16", "float16", "float32"] 默认值:"float16" |
batch_size | int | 32 | 批处理大小 默认值: 32 |
model_path | str | /opt/las/models | 模型文件存储路径 默认值: "/opt/las/models" |
model_name | str | facebook/dinov2-large | 使用的图像向量模型名称 可选值: [ "google/vit-base-patch16-224-in21k", "google/vit-large-patch16-224-in21k", "facebook/dinov2-base", "facebook/dinov2-large" ] 默认值: "facebook/dinov2-large" |
use_cls_token_embedding | bool | True | 是否使用CLS Token特征 默认值: True |
rank | int | 0 | 指定使用的GPU设备编号(多卡环境有效)。例如:0表示第一张GPU,1表示第二张GPU 默认值:0 |
下面的代码展示了如何使用 daft 运行算子计算图片的 embedding。
from __future__ import annotations import logging import os import ray import daft from daft import col from daft.las.functions.image.embedding.image_vit_embedding import ImageViTEmbedding from daft.las.functions.udf import las_udf if __name__ == "__main__": if os.getenv("DAFT_RUNNER", "ray") == "ray": def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) import ray ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) tos_dir_url = os.getenv("TOS_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com") samples = { "image": [ f"https://{tos_dir_url}/public/shared_image_dataset/cat_ip_adapter.jpeg" ] } image_src_type = "image_url" batch_size = 64 model_path = os.getenv("MODEL_PATH", "/opt/las/models") model_name = "google/vit-base-patch16-224-in21k" dtype = "float32" use_cls_token_embedding = True rank = 0 num_gpus = 1 ds = daft.from_pydict(samples) ds = ds.with_column( "embedding", las_udf( ImageViTEmbedding, construct_args={ "image_src_type": image_src_type, "batch_size": batch_size, "model_path": model_path, "model_name": model_name, "dtype": dtype, "use_cls_token_embedding": use_cls_token_embedding, "rank": rank, }, num_gpus=num_gpus, batch_size=1, )(col("image")), ) ds.show() # ╭────────────────────────────────┬────────────────────────────────╮ # │ image ┆ embedding │ # │ --- ┆ --- │ # │ Utf8 ┆ List[Float32] │ # ╞════════════════════════════════╪════════════════════════════════╡ # │ tos://las-cn-beijing-public-o… ┆[-0.011575016, -0.019808339, … │ # ╰────────────────────────────────┴────────────────────────────────╯