Exporting a Coqui-TTS-trained VITS model to ONNX: fixed input length and how to get dynamic-length inference
Problem description (from Darko)
I am trying to export a VITS model trained with coqui-tts to .onnx, but I have the following problem:
The model exports, but the input length is fixed.
When I run the model with ONNX Runtime in Python, the audio is only 2 seconds long.
When I increase dummy_input_length from 100 to a larger number, the audio gets longer, but if the text is short I hear noise in the rest of the audio.
My export code (modified from Coqui's vits.py):
def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
    """Export model to ONNX format for inference

    Args:
        output_path (str): Path to save the exported model.
        verbose (bool): Print verbose information. Defaults to True.
    """
    # rollback values
    _forward = self.forward
    disc = None
    if hasattr(self, "disc"):
        disc = self.disc
    training = self.training

    # set export mode
    self.disc = None
    self.eval()

    def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_dp = scales[2]
        self.noise_scale = noise_scale
        self.length_scale = length_scale
        self.noise_scale_dp = noise_scale_dp
        return self.inference(
            text,
            aux_input={
                "x_lengths": text_lengths,
                "d_vectors": None,
                "speaker_ids": sid,
                "language_ids": langid,
                "durations": None,
            },
        )["model_outputs"]

    self.forward = onnx_inference

    # set dummy inputs
    dummy_input_length = 100
    sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long)
    sequence_lengths = torch.LongTensor([sequences.size(1)])
    scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
    dummy_input = (sequences, sequence_lengths, scales)
    input_names = ["input", "input_lengths", "scales"]

    if self.num_speakers > 0:
        speaker_id = torch.LongTensor([0])
        dummy_input += (speaker_id,)
        input_names.append("sid")

    if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
        language_id = torch.LongTensor([0])
        dummy_input += (language_id,)
        input_names.append("langid")

    # export to ONNX
    torch.onnx.export(
        model=self,
        args=dummy_input,
        opset_version=18,
        f=output_path,
        verbose=verbose,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time1", 2: "time2"},
        },
        dynamo=False,
    )

    # rollback
    self.forward = _forward
    if training:
        self.train()
    if disc is not None:
        self.disc = disc
My inference test code:
import onnxruntime as ort
import numpy as np
import soundfile as sf

# =========================
# 1. Load tokens
# =========================
def load_tokens(path):
    token_map = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(" ")
            if len(parts) >= 2:
                char = " ".join(parts[:-1])  # rejoin, because the space character is itself a token
                idx = int(parts[-1])
                token_map[char] = idx
    return token_map

# =========================
# 2. Text -> IDs
# =========================
def text_to_ids(text, token_map, blank_id):
    ids = []
    for ch in text:
        if ch not in token_map:
            print(f"[WARNING] Unknown character: '{ch}'")
            continue
        ids.append(token_map[ch])
        ids.append(blank_id)  # blank after every character
    return ids

# =========================
# 3. Main
# =========================
def main():
    # SET THESE PATHS
    model_path = "model.onnx"
    tokens_path = "tokens.txt"
    output_wav = "output.wav"

    # Test text
    text = "Тврд је орах воћка чудновата, не сломи га ал зубе поломи."

    # load tokens
    tokens = load_tokens(tokens_path)
    if "<BLNK>" not in tokens:
        raise Exception("No <BLNK> token in tokens.txt!")
    blank_id = tokens["<BLNK>"]
    print(f"[INFO] BLANK ID: {blank_id}")

    # text -> IDs
    ids = text_to_ids(text, tokens, blank_id)
    print(f"[INFO] ID sequence: {ids}")

    # =========================
    # 4. ONNX inference
    # =========================
    session = ort.InferenceSession(model_path)
    input_ids = np.array([ids], dtype=np.int64)
    input_lengths = np.array([len(ids)], dtype=np.int64)
    scales = np.array([0.667, 3.0, 0.8], dtype=np.float32)
    sid = np.array([0], dtype=np.int64)
    langid = np.array([0], dtype=np.int64)
    inputs = {
        "input": input_ids,
        "input_lengths": input_lengths,
        "scales": scales,
        "sid": sid,
        "langid": langid,
    }
    print("[INFO] Running inference...")
    outputs = session.run(None, inputs)

    # =========================
    # 5. Audio output
    # =========================
    audio = outputs[0]
    audio = np.squeeze(audio)
    audio = audio.astype(np.float32)
    # guard against NaN / Inf
    audio = np.nan_to_num(audio)
    # peak normalization
    if np.max(np.abs(audio)) > 0:
        audio = audio / np.max(np.abs(audio))
    sf.write(output_wav, audio, 22050, format="WAV")
    print(f"[OK] Saved to {output_wav}")

if __name__ == "__main__":
    main()
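The model was trained on Serbian, and the test text is Serbian. What I need is an ONNX model with dynamic input that generates audio whose length matches the given text, with no extra noise.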
Expert answer
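Hey Darko, I've hit this dynamic-length pitfall when exporting VITS to ONNX before. The key is to leave the sequence dimensions unconstrained in the ONNX graph and to keep the exporter from baking intermediate dimensions into the graph as constants. Here are the concrete steps: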
I. Root cause analysis
Your current export script has two core problems:
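- Fixed dummy input freezes dimensions: although you set dynamic_axes, tracing with dummy_input_length=100 can leave some intermediate tensor lengths recorded as constants. At inference time, even a short input then produces output sized for the 100-token trace, and the surplus is filled with noise.
- Scales are bound to the wrong attributes: onnx_inference assigns self.noise_scale and self.noise_scale_dp, but Vits.inference() reads self.inference_noise_scale and self.inference_noise_scale_dp. Only self.length_scale matches. As a result, two of the three values in the scales input have no effect in the exported graph, which uses the floats stored on the model at export time instead.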
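The following rough numpy sketch shows why the output length has to stay data-dependent. It models how VITS inference turns predicted log-durations into an audio length, following Coqui's Vits.inference: each token gets exp(logw) * length_scale frames, rounded up, and the frame counts are summed. The hop_length of 256 samples per frame is an assumption (a common setting for 22.05 kHz configs), so read yours from the model config. expected_num_samples is a hypothetical helper, not a Coqui API.

```python
import numpy as np


def expected_num_samples(logw, length_scale=1.0, hop_length=256):
    """Hypothetical helper: estimate the waveform length VITS produces.

    logw: predicted log-durations, one per input token.
    length_scale: multiplies every duration (larger = slower speech).
    hop_length: audio samples per decoder frame (256 is an assumed value).
    """
    w = np.exp(np.asarray(logw, dtype=np.float64)) * length_scale  # frames per token
    w_ceil = np.ceil(w)                                             # whole frames per token
    num_frames = max(int(w_ceil.sum()), 1)                          # at least one frame
    return num_frames * hop_length


# Three tokens of one frame each, at two speaking rates
print(expected_num_samples([0.0, 0.0, 0.0]))                    # 768
print(expected_num_samples([0.0, 0.0, 0.0], length_scale=3.0))  # 2304
```

The frame count depends on the predicted durations, so the decoder's output length is only known at run time. If the exported graph fixes it instead, short sentences get padded out to the traced length, which matches the symptom described. The example also shows that length_scale multiplies every duration. The 3.0 used in the test script therefore makes speech roughly three times slower than 1.0.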
II. Modified ONNX export script (full dynamic-length support)
Replace your export_onnx method with the following:
def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
    """Export model to ONNX format with dynamic sequence-length support."""
    # Save original state, including the scale attributes overwritten below
    _forward = self.forward
    disc = self.disc if hasattr(self, "disc") else None
    training_mode = self.training
    saved_scales = (self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp)

    # Prepare for export
    if disc is not None:
        self.disc = None
    self.eval()

    # Vits.inference() takes only (x, aux_input) and reads the scales from
    # self.inference_noise_scale, self.length_scale and self.inference_noise_scale_dp,
    # so the three graph inputs are bound to exactly those attributes.
    def onnx_dynamic_inference(text, text_lengths, noise_scale, length_scale, noise_scale_dp, sid=None, langid=None):
        self.inference_noise_scale = noise_scale
        self.length_scale = length_scale
        self.inference_noise_scale_dp = noise_scale_dp
        return self.inference(
            text,
            aux_input={
                "x_lengths": text_lengths,
                "d_vectors": None,
                "speaker_ids": sid,
                "language_ids": langid,
                "durations": None,
            },
        )["model_outputs"]

    self.forward = onnx_dynamic_inference

    # Dummy inputs: the length is arbitrary, since dynamism comes from dynamic_axes
    # and shape-dependent ops, not from the dummy size. Avoid a length of 1,
    # because size-1 dimensions may be specialized during tracing.
    dummy_len = 50
    dummy_text = torch.randint(low=0, high=2, size=(1, dummy_len), dtype=torch.long)
    dummy_text_lens = torch.LongTensor([dummy_text.size(1)])
    dummy_noise_scale = torch.tensor([saved_scales[0]], dtype=torch.float32)
    dummy_length_scale = torch.tensor([saved_scales[1]], dtype=torch.float32)
    dummy_noise_scale_dp = torch.tensor([saved_scales[2]], dtype=torch.float32)

    dummy_inputs = (dummy_text, dummy_text_lens, dummy_noise_scale, dummy_length_scale, dummy_noise_scale_dp)
    input_names = ["input", "input_lengths", "noise_scale", "length_scale", "noise_scale_dp"]
    # The output is [batch, 1, num_samples], so axis 2 is the audio length
    dynamic_axes = {
        "input": {0: "batch_size", 1: "phoneme_seq_len"},
        "input_lengths": {0: "batch_size"},
        "output": {0: "batch_size", 2: "audio_seq_len"},
    }

    # Add speaker/language inputs only if the model uses them, keeping names and axes in sync
    if self.num_speakers > 0:
        dummy_inputs += (torch.LongTensor([0]),)
        input_names.append("sid")
        dynamic_axes["sid"] = {0: "batch_size"}
    if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
        dummy_inputs += (torch.LongTensor([0]),)
        input_names.append("langid")
        dynamic_axes["langid"] = {0: "batch_size"}

    try:
        torch.onnx.export(
            model=self,
            args=dummy_inputs,
            f=output_path,
            opset_version=18,
            verbose=verbose,
            input_names=input_names,
            output_names=["output"],
            dynamic_axes=dynamic_axes,
            dynamo=False,
            do_constant_folding=True,
            keep_initializers_as_inputs=False,
        )
    finally:
        # Restore the original model state even if the export fails
        self.forward = _forward
        self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp = saved_scales
        if disc is not None:
            self.disc = disc
        if training_mode:
            self.train()
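After re-exporting, you can check which dimensions ended up fixed before listening to any audio. ONNX Runtime reports each input and output shape through session.get_inputs() / session.get_outputs(). A symbolic (dynamic) dimension typically appears there as a string or None, and a fixed one as an int. The sketch below uses a hypothetical fixed_axes helper to flag the integer dimensions. The expected output layout [batch, 1, samples] is an assumption based on the decoder producing a single audio channel.

```python
def fixed_axes(shape):
    """Hypothetical helper: return {axis: size} for every dimension that is a
    concrete integer, i.e. baked into the graph rather than symbolic."""
    return {axis: dim for axis, dim in enumerate(shape) if isinstance(dim, int)}


def report(session):
    """Print the fixed axes of every input and output of an onnxruntime session."""
    for kind, args in (("input", session.get_inputs()), ("output", session.get_outputs())):
        for arg in args:
            print(f"{kind} {arg.name}: shape={arg.shape} fixed={fixed_axes(arg.shape)}")


# Usage (requires onnxruntime and the exported model):
# import onnxruntime as ort
# report(ort.InferenceSession("coqui_vits.onnx"))
#
# In a correctly exported model, "input" should show no fixed axes, and "output"
# should fix at most axis 1 (the single audio channel). The scale inputs
# legitimately keep a fixed shape of [1].
```

If the output's axis 2 shows up as an integer, the audio length was baked in during tracing and the export needs another look.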
III. Adapt the inference script to the new input format
Change the ONNX input section of your inference code:
# Replace the single scales tensor with three separate inputs
noise_scale = np.array([0.667], dtype=np.float32)
length_scale = np.array([3.0], dtype=np.float32)
noise_scale_dp = np.array([0.8], dtype=np.float32)
inputs = {
    "input": input_ids,
    "input_lengths": input_lengths,
    "noise_scale": noise_scale,
    "length_scale": length_scale,
    "noise_scale_dp": noise_scale_dp,
    "sid": sid,
    "langid": langid,
}
outputs = session.run(None, inputs)
audio = outputs[0].squeeze()
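If you export several variants (with or without speaker or language embeddings), build the feed dict from the input names the model actually declares. ONNX Runtime rejects feed entries whose names are not model inputs. The following is a minimal sketch. It assumes the input names used in the export script above, and build_feed is a hypothetical helper. Its default scale values mirror the ones in your test script.

```python
import numpy as np


def build_feed(model_input_names, ids, noise_scale=0.667, length_scale=1.0,
               noise_scale_dp=0.8, sid=0, langid=0):
    """Hypothetical helper: build an onnxruntime feed dict for the exported VITS
    model, keeping only the inputs the model declares (e.g. no "sid" for a
    single-speaker model)."""
    candidates = {
        "input": np.array([ids], dtype=np.int64),
        "input_lengths": np.array([len(ids)], dtype=np.int64),
        "noise_scale": np.array([noise_scale], dtype=np.float32),
        "length_scale": np.array([length_scale], dtype=np.float32),
        "noise_scale_dp": np.array([noise_scale_dp], dtype=np.float32),
        "sid": np.array([sid], dtype=np.int64),
        "langid": np.array([langid], dtype=np.int64),
    }
    missing = [name for name in model_input_names if name not in candidates]
    if missing:
        raise KeyError(f"model expects inputs this helper does not build: {missing}")
    return {name: candidates[name] for name in model_input_names}


# Usage with a real session:
# names = [i.name for i in session.get_inputs()]
# audio = session.run(None, build_feed(names, ids, length_scale=3.0))[0].squeeze()
```

Failing loudly on an unexpected name, such as the old "scales" input, makes it obvious when the script is pointed at a model exported with the previous code.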
IV. Additional suggestions
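- Check for an official export path: the export_onnx() method on Coqui's Vits class, which your code is based on, is the built-in route. A CLI shortcut like the one below is sometimes suggested, but export options differ between releases and these flags may not exist in yours, so confirm them with tts --help before relying on it:
tts --model_path path/to/your/checkpoint --export_format onnx --dynamic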
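- Match the blank layout used in training: make sure your text_to_ids function inserts <BLNK> exactly the way the training tokenizer did. With add_blank enabled, Coqui's TTSTokenizer puts a blank before, between and after the characters, whereas your function only appends one after each character. A mismatched blank at the start or end of the sequence may show up as a short burst of noise at the edges of the audio.
- Check the ONNX export log: a TracerWarning such as "Converting a tensor to a Python integer might cause the trace to be incorrect" means that value was recorded as a constant, which is exactly how an intermediate length ends up fixed. Inspect the code of the layer the warning points to, or upgrade PyTorch to >= 2.0.0 if you are on an older release.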
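The sketch below reproduces the blank layout of Coqui's TTSTokenizer with add_blank enabled: blank before, between and after the characters, as in its intersperse_blank_char method. It assumes the same token_map / blank_id as your script and that your training config uses add_blank=True. Check both, and compare the result against the IDs the Coqui tokenizer produces for the same sentence before relying on it. text_to_ids_interspersed is a hypothetical name.

```python
def text_to_ids_interspersed(text, token_map, blank_id):
    """Map characters to IDs, then intersperse blank_id around every ID:
    [blank, c1, blank, c2, ..., cN, blank] (length 2 * N + 1)."""
    char_ids = []
    for ch in text:
        if ch not in token_map:
            print(f"[WARNING] Unknown character: '{ch}'")
            continue
        char_ids.append(token_map[ch])
    ids = [blank_id] * (2 * len(char_ids) + 1)
    ids[1::2] = char_ids  # odd positions hold the characters
    return ids


# Example with a toy token map (IDs are made up for illustration)
toy_map = {"a": 1, "b": 2}
print(text_to_ids_interspersed("ab", toy_map, blank_id=0))  # [0, 1, 0, 2, 0]
```

Filling a list of blanks and then assigning the characters to the odd slots keeps the layout correct by construction, so the sequence always starts and ends with a blank.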
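With these changes, your ONNX model should generate audio whose length follows the actual input text, without trailing noise. If problems remain, tell me your Coqui-TTS and PyTorch versions and I'll help you dig further.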




