You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

新版Torchtext中TranslationDataset、Field、BucketIterator的替代实现方案咨询

新版Torchtext中TranslationDataset、Field、BucketIterator的替代实现方案咨询

我完全懂你的困扰!新版Torchtext确实做了大刀阔斧的重构,把旧版的TranslationDatasetFieldBucketIterator这些核心组件都移除了,不想退回到legacy版本的话,咱们可以用新版的原生PyTorch兼容组件来替代,下面是具体的实现思路和代码示例:

一、替代旧版Field:用Vocab+Transform Pipeline

旧版Field负责分词、词汇表构建、文本转张量等全流程,新版把这些拆分成了更灵活的组件:

  • get_tokenizer获取分词器(支持spaCy、BPE等多种方式)
  • build_vocab_from_iterator从数据迭代器构建词汇表
  • SequentialTransforms把分词、转索引、添加特殊符号、转张量、补全这些步骤串成流水线

示例代码:

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.transforms import SequentialTransforms, VocabTransform, ToTensorTransform, PadTransform

# 1. 定义源语言和目标语言的分词器
src_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")  # 英文分词
tgt_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")  # 德文分词

# 2. 定义生成词汇表的辅助函数
def yield_tokens(data_iter, tokenizer, text_idx):
    # data_iter的每个元素是(源文本, 目标文本),text_idx取0是源文本,1是目标文本
    for sample in data_iter:
        yield tokenizer(sample[text_idx])

# 3. 从训练数据构建词汇表(假设train_iter是你的训练数据迭代器)
src_vocab = build_vocab_from_iterator(
    yield_tokens(train_iter, src_tokenizer, 0),
    specials=["<unk>", "<pad>", "<bos>", "<eos>"]  # 特殊符号:未知、填充、句首、句尾
)
tgt_vocab = build_vocab_from_iterator(
    yield_tokens(train_iter, tgt_tokenizer, 1),
    specials=["<unk>", "<pad>", "<bos>", "<eos>"]
)
# 设置未知词的默认索引
src_vocab.set_default_index(src_vocab["<unk>"])
tgt_vocab.set_default_index(tgt_vocab["<unk>"])

# 4. 构建文本转张量的流水线
MAX_SRC_LEN = 50  # 源文本最大长度
MAX_TGT_LEN = 60  # 目标文本最大长度

src_transform = SequentialTransforms(
    src_tokenizer,
    VocabTransform(src_vocab),
    lambda tokens: [src_vocab["<bos>"]] + tokens + [src_vocab["<eos>"]],  # 添加句首句尾符号
    ToTensorTransform(),
    PadTransform(max_length=MAX_SRC_LEN, pad_value=src_vocab["<pad>"])  # 补全到最大长度
)

tgt_transform = SequentialTransforms(
    tgt_tokenizer,
    VocabTransform(tgt_vocab),
    lambda tokens: [tgt_vocab["<bos>"]] + tokens + [tgt_vocab["<eos>"]],
    ToTensorTransform(),
    PadTransform(max_length=MAX_TGT_LEN, pad_value=tgt_vocab["<pad>"])
)

二、替代旧版TranslationDataset:自定义Dataset或用内置数据集加载器

新版Torchtext不再提供专门的TranslationDataset,而是推荐用PyTorch原生的Dataset自定义,或者直接用内置的数据集加载函数:

1. 自定义翻译数据集

如果你的数据是本地的文本文件(每行一条样本),可以这样写:

from torch.utils.data import Dataset

class CustomTranslationDataset(Dataset):
    def __init__(self, src_file_path, tgt_file_path, src_transform, tgt_transform):
        # 加载源语言和目标语言数据
        self.src_texts = [line.strip() for line in open(src_file_path, encoding="utf-8") if line.strip()]
        self.tgt_texts = [line.strip() for line in open(tgt_file_path, encoding="utf-8") if line.strip()]
        self.src_transform = src_transform
        self.tgt_transform = tgt_transform

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        # 对单条样本应用转换流水线
        src_tensor = self.src_transform(self.src_texts[idx])
        tgt_tensor = self.tgt_transform(self.tgt_texts[idx])
        return src_tensor, tgt_tensor

2. 用内置翻译数据集

如果用的是公开数据集(比如IWSLT、WMT),新版Torchtext提供了直接的加载函数:

from torchtext.datasets import IWSLT2016
# 加载英德翻译的训练集
train_iter = IWSLT2016(split='train', language_pair=('en', 'de'))

这个迭代器可以直接用来构建词汇表,或者转换成Dataset后使用。

三、替代旧版BucketIterator:BucketSampler+DataLoader

旧版BucketIterator的核心是按样本长度分组批量,新版可以用BucketSampler配合PyTorch原生的DataLoader实现:

from torch.utils.data import DataLoader
from torchtext.data.utils import BucketSampler

# 定义排序键函数:返回源文本的分词长度,用于BucketSampler分组
def sort_key(sample):
    # 这里假设sample是(源文本, 目标文本)的原始样本,如果是已经转换后的张量,可以取size(0)
    src_text = sample[0]
    return len(src_tokenizer(src_text))

# 创建BucketSampler
bucket_sampler = BucketSampler(
    dataset=your_custom_dataset,  # 这里填你的自定义Dataset
    batch_size=32,  # 批量大小
    sort_key=sort_key,
    shuffle=True  # 每个epoch打乱分组
)

# 构建DataLoader,替代BucketIterator
train_dataloader = DataLoader(
    dataset=your_custom_dataset,
    batch_sampler=bucket_sampler,
    # 自定义collate_fn,把batch里的样本堆叠成张量
    collate_fn=lambda batch: (
        torch.stack([item[0] for item in batch]),
        torch.stack([item[1] for item in batch])
    )
)

一些额外注意点

  • 新版Torchtext更强调和PyTorch原生生态的兼容,灵活性更高,但需要自己拼接各个组件
  • 构建词汇表时,可以通过min_freq参数过滤低频词,比如build_vocab_from_iterator(..., min_freq=2)
  • 如果需要动态调整序列长度,可以不用固定max_length,而是在collate_fn里动态补全到当前batch的最大长度

备注:内容来源于stack exchange,提问作者виктор ивнов

火山引擎 最新活动