You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

导出Coqui-TTS训练的VITS模型为ONNX时的固定输入长度问题及动态推理实现方法问询

导出Coqui-TTS训练的VITS模型为ONNX时的固定输入长度问题及动态推理实现方法问询

问题描述(来自Darko)

I am trying to export a VITS trained with coqui-tts to .onnx, but I have the following problem:
Model exports, but I have fixed input length.
when I try the model using the onnx runtime in Python, my audio is only 2 seconds long.
When I increase dummy_input_length from 100 to a larger number, the audio is longer but if text is short I am hearing noise in the rest of the audio.

我的导出代码(基于Coqui Vits.py修改):

def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
    """Export model to ONNX format for inference

    Args:
        output_path (str): Path to save the exported model.
        verbose (bool): Print verbose information. Defaults to True.
    """

    # rollback values
    _forward = self.forward
    disc = None
    if hasattr(self, "disc"):
        disc = self.disc
    training = self.training

    # set export mode
    self.disc = None
    self.eval()

    def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_dp = scales[2]
        self.noise_scale = noise_scale
        self.length_scale = length_scale
        self.noise_scale_dp = noise_scale_dp
        return self.inference(
            text,
            aux_input={
                "x_lengths": text_lengths,
                "d_vectors": None,
                "speaker_ids": sid,
                "language_ids": langid,
                "durations": None,
            },
        )["model_outputs"]

    self.forward = onnx_inference

    # set dummy inputs
    dummy_input_length = 100
    sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long)
    sequence_lengths = torch.LongTensor([sequences.size(1)])
    scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
    dummy_input = (sequences, sequence_lengths, scales)
    input_names = ["input", "input_lengths", "scales"]

    if self.num_speakers > 0:
        speaker_id = torch.LongTensor([0])
        dummy_input += (speaker_id,)
        input_names.append("sid")

    if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
        language_id = torch.LongTensor([0])
        dummy_input += (language_id,)
        input_names.append("langid")

    # export to ONNX
    torch.onnx.export(
        model=self,
        args=dummy_input,
        opset_version=18,
        f=output_path,
        verbose=verbose,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time1", 2: "time2"},
        },
        dynamo=False,
    )

    # rollback
    self.forward = _forward
    if training:
        self.train()
    if disc is not None:
        self.disc = disc

我的推理测试代码:

import onnxruntime as ort
import numpy as np
import soundfile as sf


# =========================
# 1. Učitavanje tokena
# =========================
def load_tokens(path):
    token_map = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(" ")
            if len(parts) >= 2:
                char = " ".join(parts[:-1])  # zbog space karaktera
                idx = int(parts[-1])
                token_map[char] = idx
    return token_map


# =========================
# 2. Tekst -> ID-jevi
# =========================
def text_to_ids(text, token_map, blank_id):
    ids = []

    for ch in text:
        if ch not in token_map:
            print(f"[WARNING] Nepoznat karakter: '{ch}'")
            continue

        ids.append(token_map[ch])
        ids.append(blank_id)  # blank posle svakog karaktera

    return ids


# =========================
# 3. Main
# =========================
def main():
    # PODESI OVE PUTANJE
    model_path = "model.onnx"
    tokens_path = "tokens.txt"
    output_wav = "output.wav"

    # Tekst za test
    text = "Тврд је орах воћка чудновата, не сломи га ал зубе поломи."

    # učitaj tokene
    tokens = load_tokens(tokens_path)

    if "<BLNK>" not in tokens:
        raise Exception("Nema <BLNK> tokena u tokens.txt!")

    blank_id = tokens["<BLNK>"]

    print(f"[INFO] BLANK ID: {blank_id}")

    # tekst -> id-jevi
    ids = text_to_ids(text, tokens, blank_id)

    print(f"[INFO] ID sekvenca: {ids}")

    # =========================
    # =========================
    # ONNX inference
    # =========================
    session = ort.InferenceSession(model_path)

    input_ids = np.array([ids], dtype=np.int64)
    input_lengths = np.array([len(ids)], dtype=np.int64)
    scales = np.array([0.667, 3.0, 0.8], dtype=np.float32)
    sid = np.array([0], dtype=np.int64)
    langid = np.array([0], dtype=np.int64)

    inputs = {
        "input": input_ids,
        "input_lengths": input_lengths,
        "scales": scales,
        "sid": sid,
        "langid": langid
    }

    print("[INFO] Pokrećem inferencu...")

    outputs = session.run(None, inputs)    
    # =========================
    # 5. Audio output
    # =========================
    audio = outputs[0]

    audio = np.squeeze(audio)
    audio = audio.astype(np.float32)

    # zaštita od NaN / Inf
    audio = np.nan_to_num(audio)

    # normalizacija
    if np.max(np.abs(audio)) > 0:
        audio = audio / np.max(np.abs(audio))

    sf.write("output.wav", audio, 22050, format='WAV')
    print(f"[OK] Snimljeno u {output_wav}")


if __name__ == "__main__":
    main()

模型为塞尔维亚语训练,测试文本为塞尔维亚语。我的需求是导出支持动态输入的ONNX模型,根据给定文本生成对应长度的音频,无多余噪音。


专家解答

Hey Darko,这个VITS导出ONNX的动态长度坑我之前也踩过,核心是要让ONNX完全放开序列维度的限制,同时避免导出时固化中间层的维度。下面是具体的解决步骤:

一、问题根源分析

你的当前导出脚本有两个核心问题:

  1. 固定dummy输入导致维度固化:虽然设置了dynamic_axes,但dummy_input_length=100会让ONNX导出器默认把部分中间张量的长度固化,导致推理时即使输入短文本,模型也会生成对应100长度的输出(填充噪音)
  2. 参数传递方式错误:把noise_scale等参数设为模型属性,ONNX无法正确识别这些参数对序列长度的动态影响,反而会干扰时长预测的逻辑

二、修改ONNX导出脚本(支持完全动态长度)

替换你的export_onnx方法为以下代码:

def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
    """Export model to ONNX format with full dynamic sequence length support"""

    # Save original state
    _forward = self.forward
    disc = self.disc if hasattr(self, "disc") else None
    training_mode = self.training

    # Prepare for export
    if disc is not None:
        self.disc = None
    self.eval()

    # Redefine forward for ONNX: pass scales directly, no model attribute modification
    def onnx_dynamic_inference(text, text_lengths, noise_scale, length_scale, noise_scale_dp, sid=None, langid=None):
        return self.inference(
            text,
            aux_input={
                "x_lengths": text_lengths,
                "d_vectors": None,
                "speaker_ids": sid,
                "language_ids": langid,
                "durations": None,
            },
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_dp=noise_scale_dp,
        )["model_outputs"]

    self.forward = onnx_dynamic_inference

    # Use minimal dummy input to force ONNX to recognize dynamic dimensions
    dummy_text = torch.randint(low=0, high=self.num_chars, size=(1, 1), dtype=torch.long)
    dummy_text_lens = torch.LongTensor([dummy_text.size(1)])
    dummy_noise_scale = torch.tensor([self.inference_noise_scale], dtype=torch.float32)
    dummy_length_scale = torch.tensor([self.length_scale], dtype=torch.float32)
    dummy_noise_scale_dp = torch.tensor([self.inference_noise_scale_dp], dtype=torch.float32)

    # Build dummy inputs
    dummy_inputs = (dummy_text, dummy_text_lens, dummy_noise_scale, dummy_length_scale, dummy_noise_scale_dp)
    input_names = ["input", "input_lengths", "noise_scale", "length_scale", "noise_scale_dp"]

    # Add speaker/language inputs if needed
    if self.num_speakers > 0:
        dummy_sid = torch.LongTensor([0])
        dummy_inputs += (dummy_sid,)
        input_names.append("sid")
    if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
        dummy_langid = torch.LongTensor([0])
        dummy_inputs += (dummy_langid,)
        input_names.append("langid")

    # Export with explicit dynamic axes definitions
    torch.onnx.export(
        model=self,
        args=dummy_inputs,
        opset_version=18,
        f=output_path,
        verbose=verbose,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phoneme_seq_len"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "audio_seq_len"},
            "sid": {0: "batch_size"} if self.num_speakers > 0 else None,
            "langid": {0: "batch_size"} if hasattr(self, "num_languages") and self.num_languages > 0 else None
        },
        dynamo=False,
        do_constant_folding=True,
        keep_initializers_as_inputs=False
    )

    # Restore original model state
    self.forward = _forward
    if disc is not None:
        self.disc = disc
    if training_mode:
        self.train()

三、调整推理脚本适配新的输入格式

修改你的推理代码中ONNX输入部分:

# 替换原scales张量为三个单独的输入
noise_scale = np.array([0.667], dtype=np.float32)
length_scale = np.array([3.0], dtype=np.float32)
noise_scale_dp = np.array([0.8], dtype=np.float32)

inputs = {
    "input": input_ids,
    "input_lengths": input_lengths,
    "noise_scale": noise_scale,
    "length_scale": length_scale,
    "noise_scale_dp": noise_scale_dp,
    "sid": sid,
    "langid": langid
}

outputs = session.run(None, inputs)
audio = outputs[0].squeeze()

四、额外优化建议

  1. 用官方命令导出(推荐):如果你用的是Coqui-TTS >=1.4.0,直接用官方导出命令,自动支持动态长度:
tts --model_path path/to/your/checkpoint --export_format onnx --dynamic
  1. 避免手动添加多余:确保你的text_to_ids函数不会在文本末尾添加额外的,否则可能会导致音频末尾出现短暂噪音
  2. 检查ONNX导出日志:如果导出时出现WARNING: The shape of tensor X is fixed的提示,说明某个维度还是被固化了,可以检查对应层的代码,或者升级PyTorch到>=2.0.0

按照这个方案修改后,你的ONNX模型就能根据输入文本的实际长度动态生成对应时长的音频,不会有多余的噪音了。如果还有问题,告诉我你的Coqui-TTS和PyTorch版本,我再帮你排查~

火山引擎 最新活动