Keras增大Batch Size时GPU显存占用未增加的问题及优化咨询
首先得明确一个核心点:模型参数占用的显存是固定的(你看到的1017MiB就是模型权重、偏置等参数的显存占用),而预测阶段的额外显存占用来自输入张量、中间计算的激活值等。你遇到的「batch size增大但显存占用不变」,主要是因为TensorFlow动态图模式下的显存分配逻辑,加上allow_growth=True的特性——它会按需分配显存,处理完一个batch后可能释放临时显存(或放入内部缓存,但nvidia-smi显示的是进程总占用,缓存未被复用的话就看不出变化)。
下面是几个优雅的方案来利用闲置显存,同时提升预测效率:
1. 用静态图包装预测逻辑(推荐)
TensorFlow默认的动态图模式在预测时会逐batch动态分配显存,容易出现碎片化和显存复用率低的问题。用tf.function包装预测函数,让TensorFlow构建静态计算图,它会预先优化显存分配,处理大batch时会持续占用足够的显存:
import tensorflow as tf from keras.models import load_model # 加载模型 model = load_model(model_h5_path) # 用tf.function包装,指定输入形状帮助TensorFlow优化 @tf.function(input_signature=[tf.TensorSpec(shape=(None, *model.input_shape[1:]), dtype=tf.float32)]) def optimized_predict(input_batch): return model(input_batch, training=False) # training=False确保用推理模式 # 测试大batch输入 # 假设你的模型输入是(224,224,3)的图像,生成1000张测试图 large_test_batch = tf.random.normal((1000, *model.input_shape[1:])) predictions = optimized_predict(large_test_batch) # 查看实际显存使用(比nvidia-smi更准确) mem_info = tf.config.experimental.get_memory_info('GPU:0') print(f"当前已用显存: {round(mem_info['used'] / 1024**2, 2)} MiB") print(f"峰值显存: {round(mem_info['peak'] / 1024**2, 2)} MiB")
静态图模式下,TensorFlow会为大batch预分配合适的显存空间,你会看到显存占用明显提升,同时预测速度也会更快(减少了动态图的开销)。
2. 构建高效的GPU数据管道
用tf.data.Dataset构建数据管道,让数据预加载并缓存到GPU,这样GPU在处理当前batch时,下一个batch已经准备好,显存会持续占用更大的空间(因为数据管道会缓存更多数据在GPU显存/内存中):
import tensorflow as tf from keras.models import load_model model = load_model(model_h5_path) # 假设你的图像路径存在file_paths列表中 file_paths = ["path/to/img1.jpg", "path/to/img2.jpg", ...] # 定义图像加载预处理函数 def load_and_preprocess_image(img_path): img = tf.io.read_file(img_path) img = tf.image.decode_jpeg(img, channels=3) img = tf.image.resize(img, model.input_shape[1:3]) # 匹配模型输入尺寸 # 替换成你模型对应的预处理函数 img = tf.keras.applications.resnet50.preprocess_input(img) return img # 构建数据管道 dataset = tf.data.Dataset.from_tensor_slices(file_paths) # 并行加载预处理 dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE) # 设置大batch size dataset = dataset.batch(1000) # 预加载数据到GPU,让GPU不等待数据 dataset = dataset.prefetch(tf.data.AUTOTUNE) # 执行批量预测 for batch in dataset: preds = model.predict(batch, verbose=0)
这种方式不仅能提升显存利用率,还能大幅减少数据加载的等待时间,让GPU持续处于工作状态。
3. 调整显存分配策略(可选)
如果你确定要固定占用部分显存,而不是动态增长,可以关闭allow_growth,设置固定比例的显存预分配:
import tensorflow as tf from keras.backend.tensorflow_backend import set_session config = tf.ConfigProto() # 预分配90%的GPU显存(根据你的需求调整比例) config.gpu_options.per_process_gpu_memory_fraction = 0.9 set_session(tf.Session(config=config)) # 后续加载模型和预测 model = load_model(model_h5_path)
这种方式下,TensorFlow会一开始就占用指定比例的显存,处理大batch时直接使用预分配的空间,避免动态分配的开销,但缺点是其他进程无法使用这部分显存。
为什么之前的显存没变化?
allow_growth=True时,TensorFlow只会在需要时分配显存,处理完一个batch后,临时显存(输入、中间激活)可能被释放,nvidia-smi显示的是进程总占用,所以看起来显存没涨,但实际计算时的峰值显存是随batch size增大而增加的(你可以用tf.config.experimental.get_memory_info查看峰值)。- 模型参数的显存是固定的,这部分不会随batch size变化,所以基础占用始终是1017MiB。
内容的提问来源于stack exchange,提问作者nick




