文本安全性评分器 - 基于ShieldLM-6B-chatglm3的安全性评估
输入列名 | 说明 |
|---|---|
texts | pyarrow.Array,元素类型为str |
每个元素为结构体或 None:
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
lang | str | zh | 语种 描述:需要评估的文本的语种 可选值:["en", "zh"] 默认值:"zh" |
model_path | str | /opt/las/models | 模型文件所在的基础路径 默认值:"/opt/las/models" |
model_name | str | thu-coai/ShieldLM-6B-chatglm3 | 模型名称 默认值:"thu-coai/ShieldLM-6B-chatglm3" |
batch_size | int | 1 | 批处理大小 描述:控制模型推理时的批量处理大小 影响:较大的batch_size可以提高GPU利用率和吞吐量,但会增加显存占用 建议:ShieldLM-6B模型显存占用较大,默认设置为1以确保稳定性和兼容性 调优:在显存充足的环境下可适当增加到2-4以提升处理效率 默认值:1 |
rank | int or None | GPU编号 描述:指定使用的GPU编号,None表示自动选择 默认值:None |
下面的代码展示了如何使用 daft 运行算子评估文本安全性,输出safe、unsafe、controversial三分类概率。
from __future__ import annotations import os import daft from daft import col from daft.las.functions.text.text_safety_scorer import TextSafetyScorer from daft.las.functions.udf import las_udf if __name__ == "__main__": if os.getenv("DAFT_RUNNER", "native") == "ray": import logging import ray def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) samples = { "text": [ "爱与和平是世界的主旋律。", "我讨厌所有人,最好都去死。", None, ] } model_path = os.getenv("MODEL_PATH", "/opt/las/models") model_name = "thu-coai/ShieldLM-6B-chatglm3" lang = "zh" batch_size = 3 rank = 0 ds = daft.from_pydict(samples) ds = ds.with_column( "safety_scores", las_udf( TextSafetyScorer, construct_args={ "lang": lang, "model_path": model_path, "model_name": model_name, "batch_size": batch_size, "rank": rank, }, num_gpus=1, batch_size=3, concurrency=1, )(col("text")), ) ds.show() # ╭─────────────────────────────┬────────────────────────────────────────────────────────────────╮ # │ text ┆ safety_scores │ # │ --- ┆ --- │ # │ Utf8 ┆ Struct[safe: Float64, unsafe: Float64, controversial: Float64] │ # ╞═════════════════════════════╪════════════════════════════════════════════════════════════════╡ # │ 爱与和平是世界的主旋律。 ┆ {safe: 0.8632398843765259, │ # │ ┆ unsafe: 0.04434245824813843, │ # │ ┆ controversial: 0.09241761267185211} │ # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ # │ 我讨厌所有人,最好都去死。 ┆ {safe: 0.0004654618678614497, │ # │ ┆ unsafe: 0.9993947744369507, │ # │ ┆ controversial: 0.00013975700130686164} │ # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ # │ None ┆ None │ # ╰─────────────────────────────┴────────────────────────────────────────────────────────────────╯