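How to run only the Embedding layer on the CPU to work around TensorFlow GPU memory initialization errors
Solving out-of-GPU-memory problems with the Embedding layer in TensorFlow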
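Hey there! I totally get how frustrating this VRAM bottleneck is: the RTX 3060 doesn't have enough memory to hold the whole Embedding layer on the GPU, but running the entire model on the CPU is unbearably slow. You've already tried a custom CPUEmbedding class, and that's exactly the right direction. Let's sort out the details so that only the Embedding layer runs on the CPU while all the other layers still get GPU acceleration.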
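Core fix: pin the Embedding layer to the CPU
Your earlier custom layer only created the weights on the CPU; the lookup computation itself could still be placed on the GPU by default, so TensorFlow tried to copy the CPU-side tensors over to the GPU. That is what triggered the `Dst tensor is not initialized` error you saw (in practice this error usually means the GPU ran out of memory while allocating the destination tensor). Let's adjust the custom layer so the whole Embedding computation is locked to the CPU: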
```python
import tensorflow as tf
from tensorflow.keras.layers import Embedding


class CPUEmbedding(Embedding):
    """Embedding layer whose weights and lookup are both pinned to the CPU."""

    # tf.device is TensorFlow's public device-placement API
    # (tf.keras.backend has no device(), and shape_type_conversion is not public)
    def build(self, input_shape):
        # Create the embedding matrix on the CPU instead of the GPU
        with tf.device('/CPU:0'):
            super().build(input_shape)
        print("✅ Embedding layer is running on CPU")

    def call(self, inputs):
        # Force the embedding lookup (forward computation) onto the CPU as well
        with tf.device('/CPU:0'):
            return super().call(inputs)
```
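The key change is overriding `call` so the entire embedding lookup runs on the CPU, not just the weight creation in `build`. Only the looked-up vectors (batch size × sequence length × 256) are then handed to the GPU for the downstream LSTM and Dense layers, so the full embedding matrix never has to be copied to or allocated on the GPU.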
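To get a feel for why this split is cheap, here is a back-of-the-envelope comparison of the full embedding matrix versus the slice that actually crosses to the GPU per batch. It assumes a hypothetical vocabulary of 50,000 words, a sequence length of 20 and float32 values; substitute your own `total_words` and `max_sequence_len - 1`.

```python
def tensor_mib(*shape, bytes_per_elem=4):
    """Size in MiB of a dense tensor with the given shape (float32 by default)."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem / 1024**2


# Hypothetical sizes; replace with your own total_words and max_sequence_len - 1
vocab, seq_len, emb_dim = 50_000, 20, 256

# The whole table, which would have to live on the GPU if the layer were placed there
full_matrix = tensor_mib(vocab, emb_dim)

for batch in (128, 64, 32):
    # The lookup result that is copied to the GPU on each training step
    looked_up = tensor_mib(batch, seq_len, emb_dim)
    print(f"batch={batch}: full matrix {full_matrix:.1f} MiB, "
          f"per-batch lookup {looked_up:.2f} MiB")
```

The per-batch transfer depends only on batch size and sequence length, while the matrix grows with the vocabulary, so the gap widens as `total_words` grows.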
Verify that the device placement took effect
After building the model, print the device of each layer's weights to confirm:
```python
# Print where every weight variable actually lives
# (tf.Variable objects in tf.keras 2.x expose a .device attribute)
for layer in model.layers:
    for weight in layer.weights:
        print(f"{layer.name}/{weight.name}: {weight.device}")
```
Normally the Embedding weights should show up on `/job:localhost/replica:0/task:0/device:CPU:0`, while the LSTM and Dense weights are placed on the GPU automatically (`.../device:GPU:0`).
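If you'd rather check placement at the op level, TensorFlow can log the device of every op it executes. The sketch below uses only core TensorFlow (`tf.device`, `tf.nn.embedding_lookup`) on a toy 1000×16 table rather than your model, so the sizes and the name `toy_embeddings` are illustrative. On a machine without a GPU everything will simply report CPU:0.

```python
import tensorflow as tf

# Log the device chosen for every op (printed to stderr)
tf.debugging.set_log_device_placement(True)

# A toy embedding table pinned to the CPU, like CPUEmbedding's weights
with tf.device("/CPU:0"):
    table = tf.Variable(tf.random.normal([1000, 16]), name="toy_embeddings")

ids = tf.constant([[1, 5, 9]])

# The lookup runs on the CPU, mirroring CPUEmbedding.call
with tf.device("/CPU:0"):
    vectors = tf.nn.embedding_lookup(table, ids)

# No device scope here: TensorFlow picks the GPU if one is visible, else the CPU
projected = tf.matmul(tf.reshape(vectors, [-1, 16]), tf.ones([16, 4]))

print("table: ", table.device)
print("lookup:", vectors.device)
print("matmul:", projected.device)
```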
Extra tips for saving GPU memory
Besides controlling where the Embedding layer runs, a few more tricks can save GPU memory so you can train on all 20,000 samples:
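- Reduce the batch size: if memory is still tight, drop `batch_size` from 128 to 64 or 32. Activation memory scales roughly linearly with batch size, so this is the most direct way to relieve the GPU.
- Enable mixed-precision training: this noticeably reduces activation memory, usually with little or no loss in accuracy, and takes only two lines (call it before building the model):
  ```python
  from tensorflow.keras.mixed_precision import set_global_policy
  set_global_policy('mixed_float16')
  ```
  Keep the final Dense layer in `float32` to avoid numerical problems in the softmax:
  ```python
  model.add(Dense(total_words, activation='softmax', dtype='float32'))
  ```
- Enable on-demand GPU memory allocation: let TensorFlow grow its memory use as needed instead of grabbing almost all GPU memory up front. This has to run at program start, before the GPU is initialized:
  ```python
  import tensorflow as tf

  gpus = tf.config.list_physical_devices('GPU')
  for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  tf.keras.backend.clear_session()
  ```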
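To see what the mixed-precision policy actually changes, here's a small sketch with two toy Dense layers (the sizes 8 and 4 are arbitrary, not from your model). A layer that inherits `mixed_float16` computes in float16 but keeps float32 weights, while a layer created with `dtype='float32'` stays fully float32, which is why the softmax layer is pinned that way.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

# Must be set before the layers are created
mixed_precision.set_global_policy("mixed_float16")

hidden = layers.Dense(8)                                         # inherits the global policy
output = layers.Dense(4, activation="softmax", dtype="float32")  # final layer kept in float32

x = tf.random.normal([2, 3])
h = hidden(x)   # input is auto-cast to float16, so the output is float16
y = output(h)   # input is cast back to float32, so the softmax runs in float32

print(hidden.compute_dtype, hidden.variable_dtype)
print(h.dtype, y.dtype)

# Restore the default policy so any later code is unaffected
mixed_precision.set_global_policy("float32")
```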
Complete runnable code after the changes
Putting all the optimizations together, the final version of your code looks like this:
```python
# Assumes TensorFlow 2.x with Keras 2 (TF 2.15 or earlier): Tokenizer is not part of Keras 3.
# Everything is imported from tensorflow.keras so standalone keras and tf.keras are not mixed.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout
from tensorflow.keras.mixed_precision import set_global_policy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import to_categorical

# Enable on-demand GPU memory allocation (must happen before the GPU is initialized)
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
tf.keras.backend.clear_session()

# Mixed precision: float16 compute, float32 weights (set before building the model)
set_global_policy('mixed_float16')


# Custom Embedding layer pinned to the CPU
class CPUEmbedding(Embedding):
    def build(self, input_shape):
        with tf.device('/CPU:0'):
            super().build(input_shape)
        print("✅ Embedding layer is running on CPU")

    def call(self, inputs):
        with tf.device('/CPU:0'):
            return super().call(inputs)


# Data preprocessing
tokenizer = Tokenizer()
with open('text.txt', encoding='utf-8') as f:
    data = f.read().replace('\ufeff', '')
data_list = data.lower().split("\n")
tokenizer.fit_on_texts(data_list)
total_words = len(tokenizer.word_index) + 1
print('Total words:', total_words)

input_sequences = []
for line in data_list:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i + 1]
        input_sequences.append(n_gram_sequence)

max_sequence_len = max(len(x) for x in input_sequences)
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
X, labels = input_sequences[:, :-1], input_sequences[:, -1]
Y = to_categorical(labels, num_classes=total_words)

# Build the model
model = Sequential()
model.add(CPUEmbedding(total_words, 256, input_length=max_sequence_len - 1))
model.add(LSTM(256, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(512))
model.add(Dropout(0.2))
# Keep the last layer in float32 to avoid precision loss in the softmax
model.add(Dense(total_words, activation='softmax', dtype='float32'))

# Compile the model
adam = Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc'])

# Print the model structure and verify device placement
model.summary()
for layer in model.layers:
    if isinstance(layer, CPUEmbedding):
        print(f"Embedding weights device: {layer.embeddings.device}")

# Train on all 20,000 samples
history = model.fit(x=X, y=Y, batch_size=64, epochs=1000)
```
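If it helps to see what the preprocessing loop produces, here is a pure-Python sketch of the same n-gram expansion and pre-padding on a hypothetical two-line token list (the token IDs are made up). `pad_pre` is a hypothetical stand-in for `pad_sequences(..., padding='pre')` in the case where no sequence needs truncating.

```python
def make_ngram_sequences(token_lines):
    """Expand each tokenized line into its prefix n-grams, like the script's loop."""
    seqs = []
    for tokens in token_lines:
        for i in range(1, len(tokens)):
            seqs.append(tokens[: i + 1])
    return seqs


def pad_pre(seqs, maxlen):
    """Left-pad each sequence with 0 up to maxlen (assumes no sequence is longer)."""
    return [[0] * (maxlen - len(s)) + s for s in seqs]


lines = [[4, 7, 2], [5, 9]]  # hypothetical tokenizer output for two lines
seqs = make_ngram_sequences(lines)
maxlen = max(len(s) for s in seqs)
padded = pad_pre(seqs, maxlen)

# Same split as input_sequences[:, :-1] and input_sequences[:, -1]
X = [row[:-1] for row in padded]
labels = [row[-1] for row in padded]

print(X)       # [[0, 4], [4, 7], [0, 5]]
print(labels)  # [7, 2, 9]
```

Each line of n tokens yields n − 1 training samples, so `X` and `Y` can have many more rows than the text has lines.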
The question comes from Stack Exchange; it was asked by Ne1zvestnyj.




