You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

scikit-learn中如何正确将TimeSeriesSplit传入cross_val_score?

解决TimeSeriesSplit传入cross_val_score时的TypeError问题

这个错误的核心原因主要有两个:模块导入错误版本兼容性问题,咱们一步步来解决:

1. 先修正模块导入

你错误栈里显示用的是sklearn.cross_validation模块,这个模块在scikit-learn 0.20版本之后就被完全废弃了,所有交叉验证相关的工具都迁移到了sklearn.model_selection模块下。首先把导入语句改成正确的:

from sklearn.model_selection import TimeSeriesSplit, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

2. 两种正确传入cv参数的方式

方式一:直接传入TimeSeriesSplit对象(推荐,适用于scikit-learn 0.19+)

从scikit-learn 0.19版本开始,cross_val_score已经支持直接传入交叉验证生成器实例(比如TimeSeriesSplit对象),不需要手动生成迭代器。修正后的完整代码如下:

# 初始化时间序列交叉验证生成器
tss = TimeSeriesSplit(max_train_size=None, n_splits=10)
accuracy_scores = []
neighbor_values = [1,3,5,7,9,11,13,12,23,19,18]

for k in neighbor_values:
    knn = KNeighborsClassifier(n_neighbors=k, algorithm='brute')
    # 直接将tss传入cv参数
    fold_scores = cross_val_score(knn, X1, y1, cv=tss, scoring='accuracy')
    accuracy_scores.append(fold_scores.mean())

方式二:手动生成索引迭代器(兼容旧版本)

如果你使用的是0.19之前的旧版本scikit-learn,需要手动调用TimeSeriesSplitsplit()方法生成训练/测试索引对的迭代器,再传给cv参数:

tss = TimeSeriesSplit(max_train_size=None, n_splits=10)
# 基于你的特征集X1生成交叉验证索引
cv_splits = tss.split(X1)
accuracy_scores = []
neighbor_values = [1,3,5,7,9,11,13,12,23,19,18]

for k in neighbor_values:
    knn = KNeighborsClassifier(n_neighbors=k, algorithm='brute')
    fold_scores = cross_val_score(knn, X1, y1, cv=cv_splits, scoring='accuracy')
    accuracy_scores.append(fold_scores.mean())

为什么之前会报错?

旧版本的cross_val_score(来自废弃的sklearn.cross_validation模块)无法直接识别TimeSeriesSplit对象,它需要的是一个可迭代的索引对集合,而新版本的sklearn.model_selection.cross_val_score已经做了适配,可以直接接受交叉验证生成器实例。

内容的提问来源于stack exchange,提问作者Dhruv Bhardwaj

火山引擎 最新活动