You need to enable JavaScript to run this app.
AI 数据湖服务

AI 数据湖服务

复制全文
音频分类
多语言语音分类
复制全文
多语言语音分类

算子介绍

描述

语音分类模块 - 基于 BEATs 模型的多语言语音分类解决方案
AudioBeatsClassifier 是一个基于 BEATs 模型的音频分类算子,用于识别音频中的主要声音事件,并返回概率最高的 Top K 个分类标签。

核心功能

  • 支持多类型音频:能够自动处理多种类型的音频,识别出来自 Google AudioSet 定义的 527 类声音,例如“音乐”、“语音”、“警报声”或“动物叫声”等。
  • 支持多渠道输入:无缝处理来自本地文件路径、HTTP URL、TOS/S3 对象存储或原始字节流的音频数据。
  • 自动化预处理:内置音频解码和预处理能力,自动将输入音频重采样为 16kHz 采样率的单声道格式,简化了调用流程。

支持模型

基于 BEATs 模型实施语音分类,你可以从 microsoft/unilm 仓库下载对应模型文件。

Daft 调用

算子参数

输入

输入列名

说明

audios

包含音频数据的数组,支持以下格式: - audio_url: 音频文件 URL 路径(支持 HTTP/TOS/S3 等协议 URL,以及本地文件路径); - audio_binary: 原始音频字节数据

输出

算子采用 JSON 数组对分类结果进行组织(如下所示,每个输入音频对应一个 JSON 数组对象),数组中的每个元素代表一个识别出的声音分类,包含两个字段:

  • label:分类标签的唯一标识符,遵循 Google AudioSet 的标签体系(例如 "/m/04rlf" 代表音乐)。
  • probability:该分类的置信度分数,取值范围在 0.0 到 1.0 之间。
[
    {"label": "/m/04rlf", "probability": 0.85},
    {"label": "/m/09x0r", "probability": 0.39},
    {"label": "/m/03qc9zr", "probability": 0.33},
    {"label": "/m/07sr1lc", "probability": 0.27},
    {"label": "/m/07s2xch", "probability": 0.15}
]

参数

如参数没有默认值,则为必填参数

参数名称

类型

默认值

描述

model_path

str

/opt/las/models

模型存储路径

model_name

str

BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt

BEATs 模型名称

top_k

int

5

控制返回的置信得分最高的 K 个分类标签数,默认为 5

precision

int

None

控制置信分数保留的小数位数,默认保留全部小数

调用示例

下面的代码展示了如何使用 Daft 运行算子对语音实施分类:

from __future__ import annotations

import logging
import os

import daft
from daft import col
from daft.las.functions.audio import AudioBeatsClassifier
from daft.las.functions.udf import las_udf

if __name__ == "__main__":
    if os.getenv("DAFT_RUNNER", "ray") == "ray":

        def configure_logging():
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S.%s".format(),
            )
            logging.getLogger("tracing.span").setLevel(logging.WARNING)
            logging.getLogger("daft_io.stats").setLevel(logging.WARNING)
            logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING)

        import ray

        ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging})
        daft.set_runner_ray()

    daft.set_execution_config(actor_udf_ready_timeout=600)
    daft.set_execution_config(min_cpu_per_task=0)

    tos_dir_url = os.getenv("TOS_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com")
    samples = {"audio_path": [f"https://{tos_dir_url}/public/shared_audio_dataset/参观八达岭长城。.wav"]}

    df = daft.from_pydict(samples)
    df = df.with_column(
        "classify_result",
        las_udf(
            AudioBeatsClassifier,
            construct_args={
                "model_path": "/opt/las/models",
                "model_name": "BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
                "top_k": 5,
                "precision": 2,
            },
            num_gpus=0.25,
            batch_size=8,
            concurrency=4,
        )(col("audio_path")),
    )
    df.show()
最近更新时间:2026.01.08 19:15:09
这个页面对您有帮助吗?
有用
有用
无用
无用