基于 BGE-M3 的文本嵌入模型,支持稀疏/稠密/token 三级向量生成
FP16 量化与 GPU 并行计算输入列名 | 说明 |
|---|---|
texts | 包含待处理文本的数组,元素类型为str。 |
处理后的数组,包含以下字段:
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
is_output_token_vec | bool | False | 是否输出 token 向量 默认值:False |
dtype | str | float32 | 模型精度,支持 float32 和 float16 可选值:["float32", "float16"] 默认值:"float32" |
batch_size | int | 512 | 模型推理时的批处理大小 默认值:512 |
model_path | str | /opt/las/models | 模型文件所在的路径 默认值:"/opt/las/models" |
model_name | str | BAAI/bge-m3 | 模型名称 可选值:["BAAI/bge-m3"] 默认值:"BAAI/bge-m3" |
rank | int or None | GPU 编号 默认值:None |
下面的代码展示了如何使用 daft 运行算子基于bge-m3模型计算文本dense embedding、sparse embedding以及token embedding。
from __future__ import annotations import os import daft from daft import col from daft.las.functions.text.embedding.bge_sparse_dense_embedding import BgeSparseDenseEmbedding from daft.las.functions.udf import las_udf if __name__ == "__main__": if os.getenv("DAFT_RUNNER", "native") == "ray": import logging import ray def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) samples = {"text": ["Hello World!", None]} is_output_token_vec = True dtype = "float16" batch_size = 512 model_path = os.getenv("MODEL_PATH", "/opt/las/models") model_name = "BAAI/bge-m3" rank = 0 ds = daft.from_pydict(samples) ds = ds.with_column( "embeddings", las_udf( BgeSparseDenseEmbedding, construct_args={ "is_output_token_vec": is_output_token_vec, "dtype": dtype, "batch_size": batch_size, "model_path": model_path, "model_name": model_name, "rank": rank, }, num_gpus=1, batch_size=1, concurrency=1, )(col("text")), ) ds.show() # ╭──────────────┬───────────────────────────────────────────────────────────────────────────────────────────────╮ # │ text ┆ embeddings │ # │ --- ┆ --- │ # │ Utf8 ┆ Struct[dense_embedding: List[Float32], sparse_embedding: Map[Utf8: Float32], token_embedding: │ # │ ┆ List[List[Float32]]] │ # ╞══════════════╪═══════════════════════════════════════════════════════════════════════════════════════════════╡ # │ Hello World! ┆ {dense_embedding: [-0.0420532… │ # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ # │ None ┆ {dense_embedding: None, │ # │ ┆ spars… │ # ╰──────────────┴───────────────────────────────────────────────────────────────────────────────────────────────╯