Operator ID: daft.las.functions.audio.audio_ctc_aligner.AudioCTCAligner
The multilingual audio CTC alignment operator (AudioCTCAligner) is built on the MMS Forced Aligner model (CTC: Connectionist Temporal Classification). It precisely aligns audio content with its corresponding text transcript along the time dimension.
Audio CTC alignment is performed with the MMS Forced Aligner model; you can download the corresponding model file here.
| Input column | Description |
|---|---|
| audios | Array containing the audio data; the following formats are supported: |
| text | List of texts corresponding to the audio content |
| lang | Language of the audio text; currently only "en" (English) and "zh" (Chinese) are supported |
The operator organizes the alignment results as a JSON array (as shown below; each input audio maps to one JSON array). Each element of the array contains word (the aligned word), score (confidence), start (start timestamp), and end (end timestamp); timestamps are in milliseconds.
```json
[
  {"word": "i", "score": 1.0, "start": 644, "end": 664},
  {"word": "had", "score": 0.98, "start": 704, "end": 845},
  {"word": "that", "score": 1.0, "start": 885, "end": 1026},
  {"word": "curiosity", "score": 1.0, "start": 1086, "end": 1790},
  {"word": "beside", "score": 0.97, "start": 1871, "end": 2314},
  {"word": "me", "score": 1.0, "start": 2334, "end": 2414},
  {"word": "at", "score": 1.0, "start": 2495, "end": 2575},
  {"word": "this", "score": 1.0, "start": 2595, "end": 2756},
  {"word": "moment", "score": 1.0, "start": 2837, "end": 3138}
]
```
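As a small illustration of consuming this output, the sketch below parses a truncated copy of the sample JSON above and derives per-word durations and the total aligned span. This is plain standard-library code, not part of the operator's API.

```python
import json

# A truncated copy of the sample alignment output shown above (first three words)
alignment_json = """
[
  {"word": "i", "score": 1.0, "start": 644, "end": 664},
  {"word": "had", "score": 0.98, "start": 704, "end": 845},
  {"word": "that", "score": 1.0, "start": 885, "end": 1026}
]
"""

words = json.loads(alignment_json)

# Duration of each word in milliseconds (end - start)
durations = {w["word"]: w["end"] - w["start"] for w in words}

# Total aligned span: from the first word's start to the last word's end
total_span_ms = words[-1]["end"] - words[0]["start"]
```

Because timestamps are in milliseconds, divide by 1000 if you need seconds.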
Parameters without a default value are required.
| Parameter | Type | Default | Description |
|---|---|---|---|
| model_path | str | /opt/las/models | Path where the model is stored |
| model_name | str | MMS/ctc_alignment_mling_uroman_model.pt | Name of the model used for CTC alignment |
The following code shows how to use Daft to run the operator and perform CTC alignment on speech and text.
```python
from __future__ import annotations

import logging
import os

import daft
from daft import col
from daft.las.functions.audio.audio_ctc_aligner import AudioCTCAligner
from daft.las.functions.udf import las_udf

if __name__ == "__main__":
    if os.getenv("DAFT_RUNNER", "ray") == "ray":

        def configure_logging():
            # Quiet down noisy internal loggers on every Ray worker
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
            logging.getLogger("tracing.span").setLevel(logging.WARNING)
            logging.getLogger("daft_io.stats").setLevel(logging.WARNING)
            logging.getLogger("DaftStatisticsManager").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaScheduler").setLevel(logging.WARNING)
            logging.getLogger("DaftFlotillaDispatcher").setLevel(logging.WARNING)

        import ray

        ray.init(dashboard_host="0.0.0.0", runtime_env={"worker_process_setup_hook": configure_logging})
        daft.set_runner_ray()
        daft.set_execution_config(actor_udf_ready_timeout=600)
        daft.set_execution_config(min_cpu_per_task=0)

    tos_dir_url = os.getenv("TOS_DIR_URL", "las-cn-beijing-public-online.tos-cn-beijing.volces.com")

    # One sample: an audio file URL, its transcript, and the transcript language
    samples = {
        "audio_path": [f"https://{tos_dir_url}/public/shared_audio_dataset/参观八达岭长城。.wav"],
        "text": ["参观八达岭长城"],
        "lang": ["zh"],
    }
    df = daft.from_pydict(samples)

    # Run the aligner as a UDF; each row yields a JSON array of word-level timestamps
    df = df.with_column(
        "ctc_result",
        las_udf(
            AudioCTCAligner,
            construct_args={
                "model_path": "/opt/las/models",
                "model_name": "MMS/ctc_alignment_mling_uroman_model.pt",
            },
            num_gpus=0.1,
            batch_size=16,
            concurrency=8,
        )(col("audio_path"), col("text"), col("lang")),
    )
    df.show()
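A common follow-up is turning the word-level alignment into subtitle cues. The sketch below converts the millisecond timestamps from the operator's JSON output into SRT-format cues; `ms_to_srt` and `alignment_to_srt` are hypothetical helper names, not part of the operator, and the only assumption is that the output matches the JSON schema shown earlier.

```python
def ms_to_srt(ms: int) -> str:
    """Convert milliseconds to the SRT timestamp format HH:MM:SS,mmm."""
    h, rem = divmod(ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, milli = divmod(rem, 1_000)
    return f"{h:02d}:{m:02d}:{s:02d},{milli:03d}"


def alignment_to_srt(words: list[dict]) -> str:
    """Render one SRT cue per aligned word from the operator's output schema."""
    cues = []
    for i, w in enumerate(words, start=1):
        cues.append(f"{i}\n{ms_to_srt(w['start'])} --> {ms_to_srt(w['end'])}\n{w['word']}\n")
    return "\n".join(cues)


# A single word from the sample output above
sample = [{"word": "i", "score": 1.0, "start": 644, "end": 664}]
srt = alignment_to_srt(sample)
```

For word-by-word subtitles you would typically merge adjacent words into phrase-level cues first; the per-word form here just demonstrates the timestamp conversion.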