You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

为何在caret包中使用不同交叉验证配置训练lm模型会得到相同结果?

为何在caret包中使用不同交叉验证配置训练lm模型会得到相同结果?

你发现的这个现象完全是正常的,这其实是caret包的核心设计逻辑导致的,我来给你掰扯清楚:

首先要明确:交叉验证在caret::train()里的核心作用是评估模型性能,或者帮你选择最优的超参数,而不是用来拟合最终交付使用的模型

对于像lm(线性回归)这种没有可调超参数的模型来说,train()函数不管你设置了多少折交叉验证,最终都会用全部的训练数据来拟合最终的finalModel——交叉验证只是跑一遍,帮你算出模型在不同数据子集上的表现(比如RMSE、R²这些指标),但根本不会改变最终模型的拟合数据来源。

回到你的例子里:

  • 你设置了2折CV和10折CV,这俩配置只会影响交叉验证过程中计算的性能指标(你可以看看model_2$resultsmodel_10$results,里面的评估指标会有细微差别),但最终的model_2$finalModelmodel_10$finalModel都是用你生成的全部300行数据拟合的线性回归模型,所以它们的系数完全一致,identical()返回TRUE也就不足为奇了。

如果换成有超参数的模型(比如随机森林、支持向量机SVM),情况就不一样了:这时候交叉验证会用来挑选最优的超参数(比如树的数量、核函数类型),不同的CV配置可能会选出不同的超参数,最终的模型结果也就会有差异。但lm没有需要调整的超参数,所以CV配置对最终模型的系数完全没有影响。

你可以自己验证一下:看看model_2$finalModel$callmodel_10$finalModel$call,会发现它们都是用整个数据集来拟合的,和CV的折数完全无关。

附上你的验证代码(方便对照):

library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice

{
set.seed(123)
Xs <- matrix(rnorm(300*20),nrow = 300)
Y <- rnorm(300)
data <- cbind(Xs,Y) |> as.data.frame()
}

ctrlspecs_2 <- trainControl(method="cv", number=2)
ctrlspecs_10 <- trainControl(method="cv", number=10)

set.seed(123)
model_2 <- train(Y~.,
                 data = data,
                 method = "lm",
                 trControl = ctrlspecs_2)

set.seed(123)
model_10 <- train(Y~.,
                 data = data,
                 method = "lm",
                 trControl = ctrlspecs_10)

summary(model_2)
#> 
#> Call:
#> lm(formula = .outcome ~ ., data = dat)
#> 
#> Residuals:
#>     Min      1Q  Median      3Q     Max 
#> -3.5934 -0.6277 -0.0082  0.7448  2.2594 
#> 
#> Coefficients:
#>              Estimate Std. Error t value Pr(>|t|)   
#> (Intercept) -0.044073   0.060499  -0.728  0.46692   
#> V1          -0.129567   0.065772  -1.970  0.04984 * 
#> V2          -0.002505   0.061859  -0.040  0.96773   
#> V3          -0.046897   0.059486  -0.788  0.43115   
#> V4           0.044195   0.061427   0.719  0.47245   
#> V5           0.086981   0.064085   1.357  0.17579   
#> V6           0.014166   0.061001   0.232  0.81653   
#> V7          -0.077959   0.060911  -1.280  0.20165   
#> V8           0.017661   0.065486   0.270  0.78759   
#> V9          -0.096562   0.060567  -1.594  0.11200   
#> V10          0.164024   0.060858   2.695  0.00746 **
#> V11         -0.028008   0.060869  -0.460  0.64577   
#> V12          0.034027   0.062118   0.548  0.58428   
#> V13         -0.066028   0.066681  -0.990  0.32294   
#> V14          0.142444   0.061319   2.323  0.02090 * 
#> V15         -0.129046   0.060109  -2.147  0.03267 * 
#> V16         -0.020873   0.061512  -0.339  0.73462   
#> V17          0.046835   0.063381   0.739  0.46056   
#> V18          0.035570   0.066567   0.534  0.59353   
#> V19         -0.016253   0.060039  -0.271  0.78682   
#> V20         -0.082083   0.060843  -1.349  0.17840   
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> Residual standard error: 1.033 on 279 degrees of freedom
#> Multiple R-squared:  0.1041, Adjusted R-squared:  0.03986 
#> F-statistic: 1.621 on 20 and 279 DF,  p-value: 0.04731
summary(model_10)
#> 
#> Call:
#> lm(formula = .outcome ~ ., data = dat)
#> 
#> Residuals:
#>     Min      1Q  Median      3Q     Max 
#> -3.5934 -0.6277 -0.0082  0.7448  2.2594 
#> 
#> Coefficients:
#>              Estimate Std. Error t value Pr(>|t|)   
#> (Intercept) -0.044073   0.060499  -0.728  0.46692   
#> V1          -0.129567   0.065772  -1.970  0.04984 * 
#> V2          -0.002505   0.061859  -0.040  0.96773   
#> V3          -0.046897   0.059486  -0.788  0.43115   
#> V4           0.044195   0.061427   0.719  0.47245   
#> V5           0.086981   0.064085   1.357  0.17579   
#> V6           0.014166   0.061001   0.232  0.81653   
#> V7          -0.077959   0.060911  -1.280  0.20165   
#> V8           0.017661   0.065486   0.270  0.78759   
#> V9          -0.096562   0.060567  -1.594  0.11200   
#> V10          0.164024   0.060858   2.695  0.00746 **
#> V11         -0.028008   0.060869  -0.460  0.64577   
#> V12          0.034027   0.062118   0.548  0.58428   
#> V13         -0.066028   0.066681  -0.990  0.32294   
#> V14          0.142444   0.061319   2.323  0.02090 * 
#> V15         -0.129046   0.060109  -2.147  0.03267 * 
#> V16         -0.020873   0.061512  -0.339  0.73462   
#> V17          0.046835   0.063381   0.739  0.46056   
#> V18          0.035570   0.066567   0.534  0.59353   
#> V19         -0.016253   0.060039  -0.271  0.78682   
#> V20         -0.082083   0.060843  -1.349  0.17840   
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> Residual standard error: 1.033 on 279 degrees of freedom
#> Multiple R-squared:  0.1041, Adjusted R-squared:  0.03986 
#> F-statistic: 1.621 on 20 and 279 DF,  p-value: 0.04731

identical(model_2$finalModel$coefficients,model_10$finalModel$coefficients)
#> [1] TRUE

备注:内容来源于stack exchange,提问作者Juan P FZ

火山引擎 最新活动