You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch Tabular结合Sklearn K折交叉验证报错的技术求助

解决PyTorch Tabular与scikit-learn cross_val_score兼容问题

问题根源分析

你遇到的两个错误本质上是同一个核心问题:PyTorch Tabular的TabularModel并不兼容scikit-learn的Estimator接口

  1. 第一个错误里,tabular_model.fit()是原地训练模型的方法,返回值为None,所以你把mymodel传入cross_val_score时,相当于传了一个空对象,自然触发类型错误。
  2. 第二个错误里,scikit-learn的交叉验证工具要求Estimator必须实现get_params()set_params()方法来克隆模型,但TabularModel并没有提供这些接口,导致无法完成模型克隆操作。

另外要注意一个逻辑矛盾:你配置模型时用了task="classification",但交叉验证却使用了回归任务的r2评分指标,这会导致后续得分计算逻辑错误,需要根据实际任务调整(比如分类任务改用accuracy/f1,回归任务把task改成"regression")。

解决方案

下面提供两种可行的解决方法:

方法1:手动实现K折交叉验证(推荐,更直观)

既然TabularModel不兼容scikit-learn接口,我们可以手动循环每个fold完成交叉验证,步骤清晰且易调试:

from sklearn.model_selection import KFold
from sklearn.metrics import r2_score  # 分类任务请替换为accuracy_score/f1_score等
import numpy as np

# 准备包含特征和目标的完整训练数据集
full_train_data = train_data.copy()

# 初始化K折拆分器,设置随机种子保证结果可复现
kf = KFold(n_splits=10, shuffle=True, random_state=8)
fold_scores = []

for train_idx, val_idx in kf.split(full_train_data):
    # 拆分当前fold的训练集和验证集
    fold_train = full_train_data.iloc[train_idx]
    fold_val = full_train_data.iloc[val_idx]
    
    # 每个fold都重新初始化模型,避免上一轮训练的参数干扰当前fold
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )
    
    # 训练当前fold的模型
    tabular_model.fit(train=fold_train, validation=fold_val)
    
    # 在验证集上预测并计算得分
    val_preds = tabular_model.predict(fold_val)
    # 回归任务取target列预测值,分类任务取prediction列的类别值
    score = r2_score(fold_val['target'], val_preds['target'])
    fold_scores.append(score)

# 输出交叉验证结果
print(f"10折交叉验证得分列表:{fold_scores}")
print(f"平均得分:{np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")

关键注意点:

  • 每个fold必须重新初始化TabularModel,因为它不支持scikit-learn的克隆机制,避免之前的训练状态污染当前fold。
  • 务必匹配任务类型与评分指标:如果是分类任务,修改model_configtask"classification",并改用分类任务对应的指标计算得分。

方法2:包装成scikit-learn兼容的Estimator

如果你想复用scikit-learn的全量工具(比如GridSearchCV调参),可以把TabularModel包装成符合scikit-learn接口的类:

from sklearn.base import BaseEstimator, RegressorMixin  # 分类任务请替换为ClassifierMixin

class SKLearnTabularModel(BaseEstimator, RegressorMixin):
    def __init__(self, data_config, model_config, optimizer_config, trainer_config):
        self.data_config = data_config
        self.model_config = model_config
        self.optimizer_config = optimizer_config
        self.trainer_config = trainer_config
        self._inner_model = None
    
    def fit(self, X, y):
        # 将特征与目标合并为PyTorch Tabular要求的格式
        train_data = X.copy()
        train_data['target'] = y
        # 初始化并训练内部模型
        self._inner_model = TabularModel(
            data_config=self.data_config,
            model_config=self.model_config,
            optimizer_config=self.optimizer_config,
            trainer_config=self.trainer_config,
        )
        self._inner_model.fit(train=train_data)
        return self  # 必须返回self,符合scikit-learn规范
    
    def predict(self, X):
        # 预测并返回numpy数组(scikit-learn预期的输出格式)
        preds = self._inner_model.predict(X)
        return preds['target'].values
    
    # 实现get_params和set_params,让scikit-learn可以克隆模型
    def get_params(self, deep=True):
        return {
            "data_config": self.data_config,
            "model_config": self.model_config,
            "optimizer_config": self.optimizer_config,
            "trainer_config": self.trainer_config
        }
    
    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

使用包装后的模型:

# 初始化兼容scikit-learn的模型
sk_compatible_model = SKLearnTabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)

# 准备特征与目标(分开传入,符合scikit-learn要求)
X = full_train_data.drop('target', axis=1)
y = full_train_data['target']

# 现在可以直接使用cross_val_score了
scores = cross_val_score(sk_compatible_model, X, y, scoring='r2', cv=10)
print(f"10折交叉验证得分:{scores}")
print(f"平均得分:{np.mean(scores):.4f} ± {np.std(scores):.4f}")

总结

  • 手动实现K折交叉验证更简单直接,适合快速验证模型效果,无需额外封装代码。
  • 包装成scikit-learn兼容的Estimator可以解锁更多scikit-learn工具的使用,适合复杂调参或大规模实验场景。
  • 记得修正任务类型与评分指标的匹配问题,避免后续出现逻辑错误。

内容的提问来源于stack exchange,提问作者zizi

火山引擎 最新活动