基于Keras协同过滤模型,如何利用物品因子权重查找相似游戏?
嘿,我完全懂你的痛点——不想重新写一遍整个协同过滤模型的代码,只想快速把训练好的items_factors嵌入层权重捞出来,用来计算游戏相似度对吧?其实有两个挺优雅的方案,不用碰整个大模型,给你细说:
方案1:单独加载目标嵌入层(最轻量化)
直接复刻训练时items_factors层的结构,然后通过层名匹配加载权重,完全不用管其他层:
from tensorflow.keras.layers import Embedding # 注意:这里的参数必须和你训练时的设置完全一致! num_items = 你的游戏总数 # 比如训练时用了多少个游戏就填多少 num_item_features = 50 # 和你函数里的num_item_features参数保持一致 # 创建和训练时一模一样的items_factors嵌入层 items_embedding = Embedding( input_dim=num_items, output_dim=num_item_features, name='items_factors' # 层名必须完全匹配!权重文件是按层名存储的 ) # 先构建层(Embedding层需要先初始化权重结构才能加载外部权重) items_embedding.build(input_shape=(None, 1)) # 加载权重,by_name=True会自动找到同名层的权重 items_embedding.load_weights('weights_633-1.79.hdf5', by_name=True) # 现在直接拿到所有游戏的嵌入向量矩阵 # shape是(num_items, num_item_features) item_vectors = items_embedding.get_weights()[0]
方案2:构建极简模型提取嵌入层
如果你之后可能还要用这个模型做一些物品嵌入的推理(比如输入游戏ID直接拿向量),可以建一个只包含item_in和items_factors的极简模型,同样通过层名加载权重:
from tensorflow.keras.layers import Input, Embedding from tensorflow.keras.models import Model num_items = 你的游戏总数 num_item_features = 50 # 只定义获取物品嵌入的极简模型 item_input = Input(shape=(1,), dtype='int64', name='item_in') item_factors = Embedding( input_dim=num_items, output_dim=num_item_features, name='items_factors' )(item_input) item_embedding_model = Model(inputs=item_input, outputs=item_factors) # 加载权重,by_name=True会精准匹配items_factors层的权重 item_embedding_model.load_weights('weights_633-1.79.hdf5', by_name=True) # 两种方式拿嵌入向量: # 1. 直接获取整个嵌入矩阵 item_vectors = item_embedding_model.get_layer('items_factors').get_weights()[0] # 2. 输入单个/多个游戏ID获取对应向量 # single_item_vector = item_embedding_model.predict([[10]])
接下来计算游戏相似度
拿到item_vectors之后,用余弦相似度计算相似游戏就很简单了:
from sklearn.metrics.pairwise import cosine_similarity # 计算所有游戏之间的余弦相似度矩阵 similarity_matrix = cosine_similarity(item_vectors) # 示例:找游戏ID为10的Top10相似游戏(排除自己) target_game_id = 10 # argsort倒序取前11个,然后去掉第一个(自己) similar_game_ids = similarity_matrix[target_game_id].argsort()[::-1][1:11] print(f"游戏ID {target_game_id} 的Top10相似游戏ID:{similar_game_ids}")
内容的提问来源于stack exchange,提问作者julka




