You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何仅使用CPU运行Embedding层以规避TensorFlow GPU内存初始化错误

解决TensorFlow中Embedding层GPU内存不足的问题

Hey there! 我完全懂你遇到的显存瓶颈糟心事儿——RTX 3060的显存扛不住把整个Embedding层加载到GPU,全模型跑CPU又慢到没法接受。你已经尝试自定义CPUEmbedding类了,这个方向特别对,咱们再把细节捋顺,确保只有Embedding层在CPU运算,其他层正常用GPU加速。

核心解决方案:精准绑定Embedding层到CPU

你之前的自定义层只把权重放在了CPU,但前向/反向传播的运算逻辑还是默认跑GPU,这就导致TensorFlow尝试把CPU上的张量拷贝到GPU,触发了你遇到的Dst tensor is not initialized错误(本质是显存不够初始化目标张量)。咱们调整一下自定义层的实现,把整个Embedding的运算都锁在CPU:

from tensorflow.keras.layers import Embedding
from tensorflow.keras import backend as ops
from tensorflow.keras import utils as tf_utils

class CPUEmbedding(Embedding):
    def call(self, inputs):
        # 强制Embedding的前向计算在CPU执行
        with ops.device('cpu:0'):
            return super().call(inputs)
    
    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        # 把Embedding权重也放在CPU
        with ops.device('cpu:0'):
            super().build(input_shape)
        print("✅ Embedding layer is running on CPU")

这个修改的关键是重写call方法,让整个Embedding的运算逻辑都在CPU完成,结果再传到GPU给后续的LSTM、Dense层处理,彻底避免了不必要的跨设备张量拷贝。

验证设备分配是否生效

你可以在模型构建后,打印每一层的设备信息确认:

for layer in model.layers:
    if isinstance(layer, CPUEmbedding):
        print(f"Embedding weights device: {layer.embeddings.device}")
    print(f"{layer.name} layer default device: GPU (auto-assigned)")

正常情况下,Embedding的权重会显示在/job:localhost/replica:0/task:0/device:CPU:0,而LSTM、Dense层会自动分配到GPU。

额外的显存优化小技巧

除了Embedding层的设备控制,还有几个小技巧能帮你进一步节省GPU显存,顺利用上全部20000条训练数据:

  • 调整batch size:如果还是显存紧张,把batch_size从128降到64或者32,这是最直接的显存减压方式
  • 开启混合精度训练:在不损失精度的前提下大幅减少显存占用,只需要添加两行代码:
    from tensorflow.keras.mixed_precision import set_global_policy
    set_global_policy('mixed_float16')
    
    注意最后一层Dense要保持float32精度,避免softmax运算出现精度问题:
    model.add(Dense(total_words, activation='softmax', dtype='float32'))
    
  • 启用显存按需分配:让TensorFlow根据需要占用显存,而不是一开始就占满所有显存:
    import tensorflow as tf
    tf.keras.backend.clear_session()
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    

修改后的完整可运行代码

把所有优化整合到你的代码里,最终版本如下:

from keras.preprocessing.sequence import pad_sequences
from keras.layers import Embedding, LSTM, Dense, Dropout
from keras.preprocessing.text import Tokenizer
from keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import backend as ops
from tensorflow.keras import utils as tf_utils
import numpy as np
import tensorflow as tf

# 开启显存按需分配,避免初始占满显存
tf.keras.backend.clear_session()
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

# 自定义CPU绑定的Embedding层
class CPUEmbedding(Embedding):
    def call(self, inputs):
        with ops.device('cpu:0'):
            return super().call(inputs)
    
    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        with ops.device('cpu:0'):
            super().build(input_shape)
        print("✅ Embedding layer is running on CPU")

# 数据预处理
tokenizer = Tokenizer()
with open('text.txt', encoding='utf-8') as f:
    data = f.read().replace('\ufeff', '')
data_list = data.lower().split("\n")

tokenizer.fit_on_texts(data_list)
total_words = len(tokenizer.word_index) + 1
print('Total words:', total_words)

input_sequences = []
for line in data_list:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i + 1]
        input_sequences.append(n_gram_sequence)

max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))

X, labels = input_sequences[:, :-1], input_sequences[:, -1]
Y = to_categorical(labels, num_classes=total_words)

# 构建模型
model = Sequential()
model.add(CPUEmbedding(total_words, 256, input_length=max_sequence_len - 1))
model.add(LSTM(256, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(512))
model.add(Dropout(0.2))
# 最后一层保持float32精度,避免softmax运算精度损失
model.add(Dense(total_words, activation='softmax', dtype='float32'))

# 编译模型
adam = Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc'])

# 打印模型结构和设备验证信息
model.summary()
for layer in model.layers:
    if isinstance(layer, CPUEmbedding):
        print(f"Embedding weights device: {layer.embeddings.device}")

# 用全部20000条数据训练
history = model.fit(x=X, y=Y, batch_size=64, epochs=1000)

内容的提问来源于stack exchange,提问作者Ne1zvestnyj

火山引擎 最新活动