基于 BGE 系列的文本嵌入模型,支持稠密向量生成
FP16 量化与 GPU 并行计算输入列名 | 说明 |
|---|---|
texts | 包含待处理文本的数组,元素类型为str。 |
处理后的数组,包含每个文本对应的稠密嵌入向量
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
dtype | str | float32 | 模型精度,支持 float32 和 float16 可选值:["float32", "float16"] 默认值:"float32" |
batch_size | int | 512 | 模型推理时的批处理大小 默认值:512 |
model_path | str | /opt/las/models | 模型文件所在的路径 默认值:"/opt/las/models" |
model_name | str | BAAI/bge-m3 | 模型名称 可选值:["BAAI/bge-m3", "BAAI/bge-large-zh-v1.5", "BAAI/bge-large-en-v1.5", "BAAI/bge-multilingual-gemma2"] 默认值:"BAAI/bge-m3" |
rank | int or None | GPU 编号 默认值:None |
下面的代码展示了如何使用 daft 运行算子基于 bge-m3 模型计算文本embedding。
from __future__ import annotations import logging import os import ray import daft from daft import col from daft.las.functions.text.embedding.bge_embedding import BgeEmbedding from daft.las.functions.udf import las_udf if __name__ == "__main__": def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) samples = {"text": ["Hello World!", None]} dtype = "float16" batch_size = 512 model_path = os.getenv("MODEL_PATH", "/opt/las/models") model_name = "BAAI/bge-m3" rank = 0 ds = daft.from_pydict(samples) ds = ds.with_column( "embeddings", las_udf( BgeEmbedding, construct_args={ "dtype": dtype, "batch_size": batch_size, "model_path": model_path, "model_name": model_name, "rank": rank, }, num_gpus=1, batch_size=1, concurrency=1, )(col("text")), ) df = ds.to_pandas() ds.show() # ╭──────────────┬────────────────────────────────╮ # │ text ┆ embeddings │ # │ --- ┆ --- │ # │ Utf8 ┆ List[Float32] │ # ╞══════════════╪════════════════════════════════╡ # │ Hello World! ┆ [-0.042053223, 0.02178955, -0… │ # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ # │ None ┆ None │ # ╰──────────────┴────────────────────────────────╯