You need to enable JavaScript to run this app.
AI 数据湖服务

AI 数据湖服务

复制全文
文本处理
文本安全性评分器
复制全文
文本安全性评分器

算子介绍

描述

文本安全性评分器 - 基于ShieldLM-6B-chatglm3的安全性评估

核心功能

  • 多语言支持:支持中文和英文文本安全性评估
  • 三分类评估:输出safe、unsafe、controversial三类概率
  • 批量处理:支持批量文本安全性评估,提升处理效率

技术实现

  • 模型核心:基于ShieldLM-6B-chatglm3
  • 推理优化:支持GPU加速和批量推理

应用场景

  • 内容安全审核
  • 文本风险评估
  • 多语言安全过滤

Daft 调用

算子参数

输入

输入列名

说明

texts

pyarrow.Array,元素类型为str

输出

每个元素为结构体或 None:

  • 如果对应输入为 None,输出为 None;
  • 否则输出 Struct,包括以下字段:
  • safe: Float64,模型预测文本为“安全”的概率
  • unsafe: Float64,模型预测文本为“不安全”的概率
  • controversial: Float64,模型预测文本为“有争议”的概率

参数

如参数没有默认值,则为必填参数

参数名称

类型

默认值

描述

lang

str

zh

语种 描述:需要评估的文本的语种 可选值:["en", "zh"] 默认值:"zh"

model_path

str

/opt/las/models

模型文件所在的基础路径 默认值:"/opt/las/models"

model_name

str

thu-coai/ShieldLM-6B-chatglm3

模型名称 默认值:"thu-coai/ShieldLM-6B-chatglm3"

batch_size

int

1

批处理大小 描述:控制模型推理时的批量处理大小 影响:较大的batch_size可以提高GPU利用率和吞吐量,但会增加显存占用 建议:ShieldLM-6B模型显存占用较大,默认设置为1以确保稳定性和兼容性 调优:在显存充足的环境下可适当增加到2-4以提升处理效率 默认值:1

rank

int or None

GPU编号 描述:指定使用的GPU编号,None表示自动选择 默认值:None

调用示例

下面的代码展示了如何使用 daft 运行算子评估文本安全性,输出safe、unsafe、controversial三分类概率。

from __future__ import annotations

import os

import daft
from daft import col
from daft.las.functions.text.text_safety_scorer import TextSafetyScorer
from daft.las.functions.udf import las_udf

if __name__ == "__main__":

    if os.getenv("DAFT_RUNNER", "native") == "ray":
        import logging

        import ray

        def configure_logging():
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S.%s".format(),
            )
            logging.getLogger("tracing.span").setLevel(logging.WARNING)
            logging.getLogger("daft_io.stats").setLevel(logging.WARNING)
            logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING)

        ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging})
        daft.context.set_runner_ray()
    daft.set_execution_config(actor_udf_ready_timeout=600)
    daft.set_execution_config(min_cpu_per_task=0)

    samples = {
        "text": [
            "爱与和平是世界的主旋律。",
            "我讨厌所有人,最好都去死。",
            None,
        ]
    }

    model_path = os.getenv("MODEL_PATH", "/opt/las/models")
    model_name = "thu-coai/ShieldLM-6B-chatglm3"
    lang = "zh"
    batch_size = 3
    rank = 0

    ds = daft.from_pydict(samples)
    ds = ds.with_column(
        "safety_scores",
        las_udf(
            TextSafetyScorer,
            construct_args={
                "lang": lang,
                "model_path": model_path,
                "model_name": model_name,
                "batch_size": batch_size,
                "rank": rank,
            },
            num_gpus=1,
            batch_size=3,
            concurrency=1,
        )(col("text")),
    )

    ds.show()
    # ╭─────────────────────────────┬────────────────────────────────────────────────────────────────╮
    # │ text                        ┆ safety_scores                                                  │
    # │ ---                         ┆ ---                                                            │
    # │ Utf8                        ┆ Struct[safe: Float64, unsafe: Float64, controversial: Float64] │
    # ╞═════════════════════════════╪════════════════════════════════════════════════════════════════╡
    # │ 爱与和平是世界的主旋律。        ┆ {safe: 0.8632398843765259,                                     │
    # │                             ┆ unsafe: 0.04434245824813843,                                   │
    # │                             ┆ controversial: 0.09241761267185211}                            │
    # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    # │ 我讨厌所有人,最好都去死。      ┆ {safe: 0.0004654618678614497,                                  │
    # │                             ┆ unsafe: 0.9993947744369507,                                    │
    # │                             ┆ controversial: 0.00013975700130686164}                         │
    # ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
    # │ None                        ┆ None                                                           │
    # ╰─────────────────────────────┴────────────────────────────────────────────────────────────────╯
最近更新时间:2026.01.08 19:14:22
这个页面对您有帮助吗?
有用
有用
无用
无用