如何从TensorFlow Dataset中获取批量大小?
嘿,针对你想从TensorFlow Dataset或者它生成的迭代器里获取批量大小的需求,我结合你给出的输入流水线代码,整理了几个实用的方案:
首先先确认下你的基础流水线代码是这样的:
import tensorflow as tf # 指定数据集 dataset = tf.data.Dataset.from_tensor_slices((features, labels)) # 打乱数据集 dataset = dataset.shuffle(buffer_size=1e5) # 指定批量大小 dataset = dataset.batch(128) # 创建迭代器 iterator = dataset.make_one_shot_iterator() # 获取下一批数据 next_batch = iterator.get_next()
一、从Dataset对象直接获取设置的批量大小
方法1:提前存变量(最推荐!)
其实最简单的方式就是把batch size定义成一个变量,后续不管哪里需要直接调用就行,完全不用额外解析:
# 提前定义好批量大小 BATCH_SIZE = 128 # 用这个变量来设置batch dataset = dataset.batch(BATCH_SIZE) # 之后需要用的时候直接拿这个变量 print(f"设置的批量大小是:{BATCH_SIZE}")
这种方式的好处是后续修改批量大小的时候只改一处就行,避免遗漏,也不会出现解析错误的情况。
方法2:从Dataset的element_spec解析
如果是接手别人的代码,或者不想额外定义变量,也可以通过Dataset的element_spec属性来提取你设置的固定批量大小:
# 从特征张量的形状第0维获取设置的batch size batch_size_from_dataset = dataset.element_spec[0].shape[0] print(f"从Dataset解析出的批量大小:{batch_size_from_dataset}")
⚠️ 注意:这种方法拿到的是你设置的固定值,如果数据集最后一批样本数量不足设置的大小,这个值不会反映真实的批量数。
二、从迭代器/批次张量中获取真实批量大小
如果想拿到每一批实际的样本数量(比如最后一批可能比设置的batch size小),可以用tf.shape()来动态获取:
# 先把批次数据拆分成特征和标签 features_batch, labels_batch = next_batch # 动态获取当前批次的真实大小 actual_batch_size = tf.shape(features_batch)[0] # 在会话中运行就能拿到具体数值 with tf.Session() as sess: try: while True: current_batch_size = sess.run(actual_batch_size) print(f"当前批次的真实样本数:{current_batch_size}") except tf.errors.OutOfRangeError: print("数据集已经遍历完啦")
另外补充下TensorFlow 2.x的情况,TF2.x里迭代器的用法更简洁,直接遍历Dataset就行,这时候获取真实批量大小更简单:
# TF2.x 环境下的写法 for features_batch, labels_batch in dataset: # 直接转成numpy数值拿到真实批量大小 actual_batch_size = tf.shape(features_batch)[0].numpy() print(f"当前批次真实样本数:{actual_batch_size}")
内容的提问来源于stack exchange,提问作者Miladiouss




