You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

基于分类数据的预测问题:编码后训练数据与预测结果不一致

搞定cat.codes编码导致训练/预测结果不匹配的坑

老哥,我太懂你这个问题了——用cat.codes直接处理分类数据时,很容易踩「编码不统一」的坑!本质原因是:cat.codes的编码规则完全依赖于当前数据集的类别集合和顺序。如果你的测试集(或者要预测的新数据)里的分类变量和训练集的类别不完全一致,哪怕只是顺序变了,编码都会错位,模型输出自然就不对了。比如你提到的user_agent字段,假设训练集里它的类别是['Chrome', 'Firefox', 'Safari'],编码是0、1、2,但测试集突然冒出来个Edge,或者类别顺序变成了['Firefox', 'Chrome', 'Safari'],那编码直接就乱套了。

给你几个实用的解决办法,按稳妥程度排序:

1. 用训练集的类别锁死编码规则(最稳妥)

训练的时候先把分类列的类别固定下来,预测时强制用训练集的类别来转换,从根源上避免编码差异:

# 训练阶段:处理分类列并保存训练集的类别
train_df['user_agent'] = train_df['user_agent'].astype('category')
# 把训练集的类别存下来,相当于一个"编码字典"
user_agent_train_cats = train_df['user_agent'].cat.categories

# 预测阶段:用训练集的类别来标准化测试数据
test_df['user_agent'] = test_df['user_agent'].astype('category').cat.set_categories(user_agent_train_cats)
# 再生成编码,此时测试集里没见过的类别会被设为NaN,你可以后续填充或单独处理
test_df['user_agent_code'] = test_df['user_agent'].cat.codes

你可以用dataset.groupby(['user_agent']).size()对比训练集和测试集的类别分布,看看是不是有训练集没见过的类别,这是排查问题的好办法。

2. 用Scikit-learn的LabelEncoder自动管理映射

LabelEncoder会自动记住训练时的类别-编码映射,预测时直接复用,比手动处理更省心:

from sklearn.preprocessing import LabelEncoder

# 训练阶段:拟合训练集的类别并生成编码
le = LabelEncoder()
train_df['user_agent_code'] = le.fit_transform(train_df['user_agent'])

# 预测阶段:先处理训练集没见过的类别(可选,不然会报错)
test_df['user_agent'] = test_df['user_agent'].apply(lambda x: x if x in le.classes_ else 'unknown')
# 新版本Scikit-learn支持handle_unknown参数,可以直接忽略未知类别:
# le = LabelEncoder(handle_unknown='ignore')
test_df['user_agent_code'] = le.transform(test_df['user_agent'])

3. 别用cat.codes了!换更合适的编码方式(进阶优化)

如果你的分类变量是无序的(比如不同的user_agent之间没有优先级/顺序关系),cat.codes生成的数值会给模型传递错误的“顺序信号”,反而影响效果。这时候更推荐:

  • 独热编码:把每个类别变成一个二进制特征,适合类别数量不多的情况,用pd.get_dummies()OneHotEncoder实现:
    from sklearn.preprocessing import OneHotEncoder
    # drop='first'避免多重共线性问题
    ohe = OneHotEncoder(sparse_output=False, drop='first')
    train_ohe_features = ohe.fit_transform(train_df[['user_agent']])
    test_ohe_features = ohe.transform(test_df[['user_agent']])
    
  • 目标编码:用类别对应的目标变量均值来编码,适合类别数量多的场景,但要注意用交叉验证防止过拟合。

内容的提问来源于stack exchange,提问作者gogasca

火山引擎 最新活动