图像安全性(NSFW)检测器——支持多源输入与批量推理
输入列名 | 说明 |
|---|---|
images | 包含输入图像的数组,支持 URL、Base64 或二进制格式。 |
返回一个包含检测结果的数组,其中每个元素是对应图像的 NSFW(Not Safe for Work)置信度分数(浮点类型)。如果检测失败,则该元素为 None。
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
image_src_type | str | "image_url" | 输入图像的格式类型。可选值为 ["image_url", "image_base64", "image_binary"]。 |
model_path | str | "/home/ray/workdir/models" | 预训练模型在本地的根目录路径。 |
model_name | str | "Falconsai/nsfw_image_detection" | 在 model_path 下的具体模型目录名称。可选值为 ["Falconsai/nsfw_image_detection"]。 |
dtype | str | "float16" | 模型推理精度选择。float16 速度更快,float32 精度更高但显存占用也更大。可选值为 ["float16", "float32"]。 |
batch_size | int | 16 | 每次送入模型进行推理的图片数量。批量越大吞吐越高,但显存占用也更高。 |
rank | int | 0 | 推理所使用的 GPU 编号。当使用 CPU 进行推理时,此参数可保持为 0。 |
下面的代码展示了如何使用 Daft 和 LAS UDF 对图像进行 NSFW 检测。
from __future__ import annotations import os import daft from daft import col from daft.las.functions.image.image_nsfw_detect import ImageNsfwDetect from daft.las.functions.udf import las_udf if __name__ == "__main__": if os.getenv("DAFT_RUNNER", "native") == "ray": import logging import ray def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) tos_dir_url = os.getenv("TOS_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com") samples = { "image": [ f"https://{tos_dir_url}/public/shared_image_dataset/cat_ip_adapter.jpeg" ] } image_src_type = "image_url" model_path = os.getenv("MODEL_PATH", "/opt/las/models") model_name = "Falconsai/nsfw_image_detection" rank = 0 num_gpus = 0 batch_size = 1 ds = daft.from_pydict(samples) ds = ds.with_column( "nsfw_detect", las_udf( ImageNsfwDetect, construct_args={ "image_src_type": image_src_type, "batch_size": batch_size, "model_path": model_path, "model_name": model_name, "rank": rank, }, num_gpus=num_gpus, batch_size=1, )(col("image")), ) ds.show() # ╭────────────────────────────────┬────────────────────────╮ # │ image ┆ nsfw_detect │ # │ --- ┆ --- │ # │ Utf8 ┆ Float64 │ # ╞════════════════════════════════╪════════════════════════╡ # │ https://las-cn-beijing-public… ┆ 0.000114 │ # ╰────────────────────────────────┴────────────────────────╯