You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何从TensorFlow Dataset中获取批量大小?

嘿,针对你想从TensorFlow Dataset或者它生成的迭代器里获取批量大小的需求,我结合你给出的输入流水线代码,整理了几个实用的方案:

首先先确认下你的基础流水线代码是这样的:

import tensorflow as tf

# 指定数据集
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# 打乱数据集
dataset = dataset.shuffle(buffer_size=1e5)
# 指定批量大小
dataset = dataset.batch(128)
# 创建迭代器
iterator = dataset.make_one_shot_iterator()
# 获取下一批数据
next_batch = iterator.get_next()

一、从Dataset对象直接获取设置的批量大小

方法1:提前存变量(最推荐!)

其实最简单的方式就是把batch size定义成一个变量,后续不管哪里需要直接调用就行,完全不用额外解析:

# 提前定义好批量大小
BATCH_SIZE = 128
# 用这个变量来设置batch
dataset = dataset.batch(BATCH_SIZE)

# 之后需要用的时候直接拿这个变量
print(f"设置的批量大小是:{BATCH_SIZE}")

这种方式的好处是后续修改批量大小的时候只改一处就行,避免遗漏,也不会出现解析错误的情况。

方法2:从Dataset的element_spec解析

如果是接手别人的代码,或者不想额外定义变量,也可以通过Dataset的element_spec属性来提取你设置的固定批量大小:

# 从特征张量的形状第0维获取设置的batch size
batch_size_from_dataset = dataset.element_spec[0].shape[0]
print(f"从Dataset解析出的批量大小:{batch_size_from_dataset}")

⚠️ 注意:这种方法拿到的是你设置的固定值,如果数据集最后一批样本数量不足设置的大小,这个值不会反映真实的批量数。

二、从迭代器/批次张量中获取真实批量大小

如果想拿到每一批实际的样本数量(比如最后一批可能比设置的batch size小),可以用tf.shape()来动态获取:

# 先把批次数据拆分成特征和标签
features_batch, labels_batch = next_batch

# 动态获取当前批次的真实大小
actual_batch_size = tf.shape(features_batch)[0]

# 在会话中运行就能拿到具体数值
with tf.Session() as sess:
    try:
        while True:
            current_batch_size = sess.run(actual_batch_size)
            print(f"当前批次的真实样本数:{current_batch_size}")
    except tf.errors.OutOfRangeError:
        print("数据集已经遍历完啦")

另外补充下TensorFlow 2.x的情况,TF2.x里迭代器的用法更简洁,直接遍历Dataset就行,这时候获取真实批量大小更简单:

# TF2.x 环境下的写法
for features_batch, labels_batch in dataset:
    # 直接转成numpy数值拿到真实批量大小
    actual_batch_size = tf.shape(features_batch)[0].numpy()
    print(f"当前批次真实样本数:{actual_batch_size}")

内容的提问来源于stack exchange,提问作者Miladiouss

火山引擎 最新活动