You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于Keras协同过滤模型,如何利用物品因子权重查找相似游戏?

嘿,我完全懂你的痛点——不想重新写一遍整个协同过滤模型的代码,只想快速把训练好的items_factors嵌入层权重捞出来,用来计算游戏相似度对吧?其实有两个挺优雅的方案,不用碰整个大模型,给你细说:

方案1:单独加载目标嵌入层(最轻量化)

直接复刻训练时items_factors层的结构,然后通过层名匹配加载权重,完全不用管其他层:

from tensorflow.keras.layers import Embedding

# 注意:这里的参数必须和你训练时的设置完全一致!
num_items = 你的游戏总数  # 比如训练时用了多少个游戏就填多少
num_item_features = 50   # 和你函数里的num_item_features参数保持一致

# 创建和训练时一模一样的items_factors嵌入层
items_embedding = Embedding(
    input_dim=num_items,
    output_dim=num_item_features,
    name='items_factors'  # 层名必须完全匹配!权重文件是按层名存储的
)

# 先构建层(Embedding层需要先初始化权重结构才能加载外部权重)
items_embedding.build(input_shape=(None, 1))
# 加载权重,by_name=True会自动找到同名层的权重
items_embedding.load_weights('weights_633-1.79.hdf5', by_name=True)

# 现在直接拿到所有游戏的嵌入向量矩阵
# shape是(num_items, num_item_features)
item_vectors = items_embedding.get_weights()[0]
方案2:构建极简模型提取嵌入层

如果你之后可能还要用这个模型做一些物品嵌入的推理(比如输入游戏ID直接拿向量),可以建一个只包含item_initems_factors的极简模型,同样通过层名加载权重:

from tensorflow.keras.layers import Input, Embedding
from tensorflow.keras.models import Model

num_items = 你的游戏总数
num_item_features = 50

# 只定义获取物品嵌入的极简模型
item_input = Input(shape=(1,), dtype='int64', name='item_in')
item_factors = Embedding(
    input_dim=num_items,
    output_dim=num_item_features,
    name='items_factors'
)(item_input)
item_embedding_model = Model(inputs=item_input, outputs=item_factors)

# 加载权重,by_name=True会精准匹配items_factors层的权重
item_embedding_model.load_weights('weights_633-1.79.hdf5', by_name=True)

# 两种方式拿嵌入向量:
# 1. 直接获取整个嵌入矩阵
item_vectors = item_embedding_model.get_layer('items_factors').get_weights()[0]
# 2. 输入单个/多个游戏ID获取对应向量
# single_item_vector = item_embedding_model.predict([[10]])
接下来计算游戏相似度

拿到item_vectors之后,用余弦相似度计算相似游戏就很简单了:

from sklearn.metrics.pairwise import cosine_similarity

# 计算所有游戏之间的余弦相似度矩阵
similarity_matrix = cosine_similarity(item_vectors)

# 示例:找游戏ID为10的Top10相似游戏(排除自己)
target_game_id = 10
# argsort倒序取前11个,然后去掉第一个(自己)
similar_game_ids = similarity_matrix[target_game_id].argsort()[::-1][1:11]
print(f"游戏ID {target_game_id} 的Top10相似游戏ID:{similar_game_ids}")

内容的提问来源于stack exchange,提问作者julka

火山引擎 最新活动