语音分类模块 - 基于 BEATs 模型的多语言语音分类解决方案
AudioBeatsClassifier 是一个基于 BEATs 模型的音频分类算子,用于识别音频中的主要声音事件,并返回概率最高的 Top K 个分类标签。
基于 BEATs 模型实施语音分类,你可以从 microsoft/unilm 仓库下载对应模型文件。
输入列名 | 说明 |
|---|---|
audios | 包含音频数据的数组,支持以下格式: - audio_url: 音频文件 URL 路径(支持 HTTP/TOS/S3 等协议 URL,以及本地文件路径); - audio_binary: 原始音频字节数据 |
算子采用 JSON 数组对分类结果进行组织(如下所示,每个输入音频对应一个 JSON 数组对象),数组中的每个元素代表一个识别出的声音分类,包含两个字段:
label:分类标签的唯一标识符,遵循 Google AudioSet 的标签体系(例如 "/m/04rlf" 代表音乐)。probability:该分类的置信度分数,取值范围在 0.0 到 1.0 之间。[ {"label": "/m/04rlf", "probability": 0.85}, {"label": "/m/09x0r", "probability": 0.39}, {"label": "/m/03qc9zr", "probability": 0.33}, {"label": "/m/07sr1lc", "probability": 0.27}, {"label": "/m/07s2xch", "probability": 0.15} ]
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
model_path | str | /opt/las/models | 模型存储路径 |
model_name | str | BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt | BEATs 模型名称 |
top_k | int | 5 | 控制返回的置信得分最高的 K 个分类标签数,默认为 5 |
precision | int | None | 控制置信分数保留的小数位数,默认保留全部小数 |
下面的代码展示了如何使用 Daft 运行算子对语音实施分类:
from __future__ import annotations import logging import os import daft from daft import col from daft.las.functions.audio import AudioBeatsClassifier from daft.las.functions.udf import las_udf if __name__ == "__main__": if os.getenv("DAFT_RUNNER", "ray") == "ray": def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S.%s".format(), ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) import ray ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) tos_dir_url = os.getenv("TOS_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com") samples = {"audio_path": [f"https://{tos_dir_url}/public/shared_audio_dataset/参观八达岭长城。.wav"]} df = daft.from_pydict(samples) df = df.with_column( "classify_result", las_udf( AudioBeatsClassifier, construct_args={ "model_path": "/opt/las/models", "model_name": "BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt", "top_k": 5, "precision": 2, }, num_gpus=0.25, batch_size=8, concurrency=4, )(col("audio_path")), ) df.show()