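How to use a trained Python Gaussian Naive Bayes model to predict results for new user-input data
Solving the new-data prediction problem for a Gaussian Naive Bayes model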
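Hey, I see you've already built the Gaussian Naive Bayes model. Predicting on new data really comes down to one thing: the new data must be preprocessed exactly the way the training data was. Otherwise the model cannot interpret the input correctly. Let's work through it step by step: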
Step 1: Fix a small issue in the training code and save the key parameters
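Your original code has an easy-to-miss pitfall: it reuses a single LabelEncoder for every categorical feature. Each call to fit_transform refits the encoder, so afterwards it only remembers the categories of the last column it was fitted on (native_country). The string-to-integer mappings for workclass, education and the other columns are lost, and you can no longer use that encoder to encode those columns in new data. Let's fix this first. At the same time, we'll save the encoders, the standardization parameters and the trained model so they can be reused at prediction time.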
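You can see the pitfall by refitting one encoder on two toy columns. The second fit_transform call replaces the learned classes, so the shared encoder rejects values from the first column. One encoder per column keeps both mappings. The column values below are made-up toy data, not rows from your dataset.

```python
from sklearn.preprocessing import LabelEncoder

# Toy stand-ins for two categorical columns (illustrative values only)
workclass = ["Federal-gov", "Private", "Self-emp-not-inc"]
education = ["Bachelors", "HS-grad", "Some-college"]

# One shared encoder: every fit_transform call overwrites classes_
shared = LabelEncoder()
shared.fit_transform(workclass)
shared.fit_transform(education)
print(shared.classes_.tolist())  # only the education categories remain

try:
    shared.transform(["Federal-gov"])  # a workclass value is now unknown
    shared_failed = False
except ValueError:
    shared_failed = True
print("shared encoder rejects workclass values:", shared_failed)

# One encoder per column keeps each mapping intact (classes are sorted alphabetically)
per_column = {
    "workclass": LabelEncoder().fit(workclass),
    "education": LabelEncoder().fit(education),
}
print(per_column["workclass"].transform(["Federal-gov"]).tolist())   # [0]
print(per_column["education"].transform(["Some-college"]).tolist())  # [2]
```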
Corrected training code snippet
```python
import pickle

from sklearn import preprocessing
from sklearn.naive_bayes import GaussianNB

# Replaces the original LabelEncoder section: one separate encoder per categorical feature
le_workclass = preprocessing.LabelEncoder()
workclass_cat = le_workclass.fit_transform(adult_df.workclass)
le_education = preprocessing.LabelEncoder()
education_cat = le_education.fit_transform(adult_df.education)
le_marital = preprocessing.LabelEncoder()
marital_cat = le_marital.fit_transform(adult_df.marital_status)
le_occupation = preprocessing.LabelEncoder()
occupation_cat = le_occupation.fit_transform(adult_df.occupation)
le_relationship = preprocessing.LabelEncoder()
relationship_cat = le_relationship.fit_transform(adult_df.relationship)
le_race = preprocessing.LabelEncoder()
race_cat = le_race.fit_transform(adult_df.race)
le_sex = preprocessing.LabelEncoder()
sex_cat = le_sex.fit_transform(adult_df.sex)
le_native_country = preprocessing.LabelEncoder()
native_country_cat = le_native_country.fit_transform(adult_df.native_country)

# The subsequent column adding, dropping and standardization logic is the same as in your original code; omitted here...

# Train the model
clf = GaussianNB()
clf.fit(features_train, target_train)

# Save the encoders, standardization parameters and model (serialized with pickle)
# Save the encoder of every categorical feature
encoders = {
    'workclass': le_workclass,
    'education': le_education,
    'marital_status': le_marital,
    'occupation': le_occupation,
    'relationship': le_relationship,
    'race': le_race,
    'sex': le_sex,
    'native_country': le_native_country,
}
with open('encoders.pkl', 'wb') as f:
    pickle.dump(encoders, f)

# Save the mean and standard deviation used for standardization
with open('scaled_features.pkl', 'wb') as f:
    pickle.dump(scaled_features, f)

# Save the trained GNB model
with open('gnb_model.pkl', 'wb') as f:
    pickle.dump(clf, f)
```
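Before relying on the saved files, you can check that pickling the artifacts reproduces the trained pipeline. Round-trip them and compare predictions. This sketch makes several assumptions for illustration only:

- It uses a tiny synthetic dataset and an in-memory io.BytesIO buffer instead of the .pkl files.
- It bundles everything into a single dict, which is a hypothetical layout, not the three files above.
- It stores scaled_features as (mean, std) per column, the same shape preprocess_new_data unpacks.

```python
import io
import pickle

import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy training data: one categorical and one numeric feature
sex = np.array(["Male", "Female", "Male", "Female", "Male", "Female"])
hours = np.array([45.0, 38.0, 50.0, 30.0, 42.0, 35.0])
target = np.array([">50K", "<=50K", ">50K", "<=50K", ">50K", "<=50K"])

le_sex = LabelEncoder()
sex_cat = le_sex.fit_transform(sex).astype(float)

# (mean, std) per feature; ddof=1 matches pandas' default Series.std()
raw = {"sex_cat": sex_cat, "hours_per_week": hours}
scaled_features = {name: (v.mean(), v.std(ddof=1)) for name, v in raw.items()}
X = np.column_stack([(raw[n] - m) / s for n, (m, s) in scaled_features.items()])

clf = GaussianNB().fit(X, target)

# Pickle all artifacts into one buffer (a file in practice) and load them back
buffer = io.BytesIO()
pickle.dump({"encoders": {"sex": le_sex}, "scaled_features": scaled_features, "model": clf}, buffer)
buffer.seek(0)
restored = pickle.load(buffer)

same_predictions = bool((restored["model"].predict(X) == clf.predict(X)).all())
print("restored model agrees with original:", same_predictions)
print(restored["encoders"]["sex"].classes_.tolist())  # ['Female', 'Male']
```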
Step 2: Write a preprocessing function for new data
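New data has to go through exactly the same processing as the training data:

- fill missing values with the modes from training,
- encode the categorical features,
- standardize all 14 features (including the encoded categorical ones) with the training mean and standard deviation,
- keep the features in the same order as during training.

Let's wrap this logic in a function: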
```python
import pickle

import pandas as pd

# Original categorical columns
CATEGORICAL_COLUMNS = ['workclass', 'education', 'marital_status', 'occupation',
                       'relationship', 'race', 'sex', 'native_country']

# Model features, in exactly the order used during training
FEATURE_COLUMNS = ['age', 'workclass_cat', 'fnlwgt', 'education_cat', 'education_num',
                   'marital_cat', 'occupation_cat', 'relationship_cat', 'race_cat', 'sex_cat',
                   'capital_gain', 'capital_loss', 'hours_per_week', 'native_country_cat']


def preprocess_new_data(new_sample):
    # Load the encoders and standardization parameters saved at training time
    with open('encoders.pkl', 'rb') as f:
        encoders = pickle.load(f)
    with open('scaled_features.pkl', 'rb') as f:
        scaled_features = pickle.load(f)

    # Wrap the single new sample in a DataFrame for easier processing
    df = pd.DataFrame([new_sample])

    # Replace '?' with the training-set mode (most frequent value) of each column.
    # This reads the training DataFrame adult_df, so it must be loaded in the same session.
    # Series.mode()[0] replaces describe(include='all')[col][2], which relies on
    # deprecated positional indexing of a label-indexed Series.
    mode_values = {col: adult_df[col].mode()[0] for col in CATEGORICAL_COLUMNS}
    for col, mode in mode_values.items():
        df[col] = df[col].replace('?', mode)

    # Encode categorical features with the encoders fitted during training
    # (transform raises ValueError for a category never seen in training)
    df['workclass_cat'] = encoders['workclass'].transform(df['workclass'])
    df['education_cat'] = encoders['education'].transform(df['education'])
    df['marital_cat'] = encoders['marital_status'].transform(df['marital_status'])
    df['occupation_cat'] = encoders['occupation'].transform(df['occupation'])
    df['relationship_cat'] = encoders['relationship'].transform(df['relationship'])
    df['race_cat'] = encoders['race'].transform(df['race'])
    df['sex_cat'] = encoders['sex'].transform(df['sex'])
    df['native_country_cat'] = encoders['native_country'].transform(df['native_country'])

    # Drop the original categorical columns
    df = df.drop(CATEGORICAL_COLUMNS, axis=1)

    # Standardize every feature with the training mean and standard deviation
    for feature in FEATURE_COLUMNS:
        mean, std = scaled_features[feature]
        df[feature] = (df[feature] - mean) / std

    # Make sure the column order matches training exactly
    df = df.reindex(columns=FEATURE_COLUMNS)

    # Return the array format the model expects
    return df.values
```
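The function above still reads adult_df to look up the modes. If your prediction script runs separately from training, one option is to compute the modes once at training time and store them alongside the encoders. This is an assumption, not part of the original code. The sketch uses a toy DataFrame in place of adult_df and keeps the pickled bytes in memory. In practice you would write them to a file, for example a hypothetical mode_values.pkl.

```python
import pickle

import pandas as pd

# Toy stand-in for the training DataFrame (illustrative values only)
train_df = pd.DataFrame({
    "workclass": ["Private", "Private", "Self-emp-not-inc", "?"],
    "sex": ["Male", "Female", "Male", "Male"],
})
categorical_columns = ["workclass", "sex"]

# Training time: most frequent value per column, ignoring the '?' placeholder
mode_values = {
    col: train_df.loc[train_df[col] != "?", col].mode()[0]
    for col in categorical_columns
}
saved_bytes = pickle.dumps(mode_values)  # write to a file such as mode_values.pkl in practice

# Prediction time: load the modes and fill '?' in new data without needing adult_df
loaded_modes = pickle.loads(saved_bytes)
new_df = pd.DataFrame([{"workclass": "?", "sex": "Female"}])
for col, mode in loaded_modes.items():
    new_df[col] = new_df[col].replace("?", mode)

print(new_df.to_dict("records"))  # [{'workclass': 'Private', 'sex': 'Female'}]
```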
Step 3: Load the model and predict on new data
Now you can feed the preprocessed new data to the model and get a prediction:
```python
import pickle

# Load the trained GNB model
with open('gnb_model.pkl', 'rb') as f:
    trained_clf = pickle.load(f)

# Example new data (same format as your dataset)
new_data = {
    'age': 37,
    'workclass': 'Federal-gov',
    'fnlwgt': 29054,
    'education': 'Some-college',
    'education_num': 10,
    'marital_status': 'Married-civ-spouse',
    'occupation': 'Adm-clerical',
    'relationship': 'Husband',
    'race': 'White',
    'sex': 'Male',
    'capital_gain': 0,
    'capital_loss': 0,
    'hours_per_week': 42,
    'native_country': 'United-States',
}

# Preprocess the new data
processed_data = preprocess_new_data(new_data)

# Get the prediction
prediction = trained_clf.predict(processed_data)
print(f"Predicted income for this sample: {prediction[0]}")
```
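Besides predict, a fitted GaussianNB also has predict_proba. It returns one probability per class, in the order of clf.classes_. This can be handy if you want to show the user how confident the prediction is. Below is a self-contained sketch on made-up, already standardized features. The data is hypothetical; with your model you would pass processed_data instead.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# Hypothetical standardized training features and income labels
X_train = np.array([[-1.0, -0.8], [-0.9, -1.1], [1.0, 0.9], [1.2, 1.1]])
y_train = np.array(["<=50K", "<=50K", ">50K", ">50K"])
clf = GaussianNB().fit(X_train, y_train)

# Preprocessed new samples: a 2D array with one row per sample
X_new = np.array([[1.1, 1.0], [-1.0, -0.9]])
labels = clf.predict(X_new)
probabilities = clf.predict_proba(X_new)

for label, probs in zip(labels, probabilities):
    # classes_ is sorted, so probs[i] is the probability of clf.classes_[i]
    confidence = dict(zip(clf.classes_.tolist(), probs.round(3).tolist()))
    print(label, confidence)
```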
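Additional tips
- To process several new samples in a batch, pass a list of dicts as `new_sample` and build the DataFrame with `pd.DataFrame(new_samples)` directly in the preprocessing function, without the extra brackets.
- The `reindex_axis` call in the original code was deprecated in pandas 0.21 and removed in pandas 1.0; use `reindex(columns=...)` instead.
- Make sure the new data has exactly the same features as the training data. For example, if a numeric feature such as `education_num` is missing, preprocessing fails with a KeyError.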
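The batch and column-order tips can be combined in one small helper. In this sketch, to_feature_frame is a hypothetical name and the column list is shortened for illustration; the real list has all 14 training features. The helper:

- builds one DataFrame from a list of dicts,
- fails early with a readable message if a required column is missing,
- uses reindex(columns=...) to put the columns in training order and drop extra ones.

```python
import pandas as pd

# Shortened, illustrative column list; the real one has all 14 training features
REQUIRED_COLUMNS = ["age", "education_num", "hours_per_week"]


def to_feature_frame(samples):
    """Build a DataFrame from a list of dicts and put columns in training order."""
    df = pd.DataFrame(samples)
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"new data is missing required columns: {missing}")
    # reindex(columns=...) replaces the removed reindex_axis and drops extra columns
    return df.reindex(columns=REQUIRED_COLUMNS)


batch = [
    {"hours_per_week": 42, "age": 37, "education_num": 10},
    {"age": 52, "education_num": 13, "hours_per_week": 50, "note": "extra"},
]
frame = to_feature_frame(batch)
print(frame.columns.tolist())  # ['age', 'education_num', 'hours_per_week']
print(frame.shape)             # (2, 3)

try:
    to_feature_frame([{"age": 30, "hours_per_week": 40}])
except ValueError as err:
    error_message = str(err)
    print(error_message)
```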
This question comes from Stack Exchange; it was asked by AbhiExtreme.