音频质量评分模块 - 使用 DNSMOS 模型评估音频质量
输入列名 | 说明 |
|---|---|
audio_col | 包含音频二进制数据的数组,每个元素应为一段完整的音频内容 |
结构化结果数组,其中每个元素包含以下字段:
处理失败的音频返回包含 null 值的结构
如参数没有默认值,则为必填参数
参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
model_path | str | /opt/las/models | 本地 DNSMOS 模型文件所在目录 默认值:"/opt/las/models" |
device | str | cpu | 运行模型的设备('cuda' 或 'cpu') 默认值:"cpu" |
is_personalized_mos | bool | False | 是否使用个性化 MOS 评分 默认值:False |
下面的代码展示了如何使用 daft 运行算子对音频进行质量评分。
from __future__ import annotations import os import daft from daft import col from daft.las.functions.audio import AudioQualityScore from daft.las.functions.udf import las_udf if __name__ == "__main__": TOS_TEST_DIR_URL = os.getenv("TOS_TEST_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com") model_path = os.getenv("MODEL_PATH", "/opt/las/models") if os.getenv("DAFT_RUNNER", "native") == "ray": import logging import ray def configure_logging(): logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logging.getLogger("tracing.span").setLevel(logging.WARNING) logging.getLogger("daft_io.stats").setLevel(logging.WARNING) logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING) logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING) ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging}) daft.context.set_runner_ray() daft.set_execution_config(actor_udf_ready_timeout=600) daft.set_execution_config(min_cpu_per_task=0) samples = {"audio_path": [f"https://{TOS_TEST_DIR_URL}/public/archive/audio_quality_score/test_audio.wav"]} df = daft.from_pydict(samples) df = df.with_column( "audio_quality_score", las_udf( AudioQualityScore, construct_args={"model_path": model_path, "device": "cuda"}, num_gpus=1, batch_size=8, concurrency=1, )(col("audio_path")), ) df.show() # ╭────────────────────────────────┬───────────────────────────────────────────────────╮ # │ audio_path ┆ audio_quality_score │ # │ --- ┆ --- │ # │ String ┆ Struct[ovrl: Float64, sig: Float64, bak: Float64] │ # ╞════════════════════════════════╪═══════════════════════════════════════════════════╡ # │ https://las-public-data-qa.to… ┆ {ovrl: 1.7469981067293439, │ # │ ┆ si… │ # ╰────────────────────────────────┴───────────────────────────────────────────────────╯