You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何使用fitctree在不平衡数据集训练更敏感的MATLAB决策树?

解决MATLAB不平衡数据集决策树全判负类的方法

嗨,这个问题太典型了——当数据集严重不平衡时,模型很容易“躺平”直接预测多数类来刷高准确率,完全忽略了你的正类样本。别担心,咱们从几个方向来调整MATLAB的决策树,让它对正类更敏感:

  • 调整损失矩阵,惩罚误判正类
    核心思路是给“把正类误判为负类”设置更高的代价,让模型不敢随便忽略正类。在fitctree里可以通过Cost参数实现:

    % 假设正类标签为1,负类为0
    % 损失矩阵规则:行=真实类别,列=预测类别
    cost_matrix = [0 10; 1 0]; % 正类误判为负类的代价设为10,远高于负类误判的代价
    tree = fitctree(features, labels, 'Cost', cost_matrix);
    

    你可以根据业务需求调整代价数值——如果正类漏判的后果越严重,就把对应的代价调得越高。

  • 给少数类加权重,提升其优先级
    通过ClassWeights参数给正类(少数类)赋予更高的权重,让模型在训练时更重视它们的样本:

    num_pos = sum(labels == 1);
    num_neg = sum(labels == 0);
    % 按样本数反比设置权重:正类权重 = 负类样本数/正类样本数
    class_weights = [num_neg/num_pos, 1];
    tree = fitctree(features, labels, 'ClassWeights', class_weights);
    

    你也可以用Prior参数调整先验概率,强制模型认为正类出现的概率更高,效果类似:

    prior_probs = [num_pos/(num_pos+num_neg), num_neg/(num_pos+num_neg)];
    tree = fitctree(features, labels, 'Prior', prior_probs);
    
  • 数据集层面的平衡处理
    如果模型层面的调整效果不够,还可以直接调整数据集的分布:

    • 过采样正类:复制正类样本,让其数量接近负类(注意可能导致过拟合)
      pos_idx = labels == 1;
      neg_idx = labels == 0;
      % 复制正类样本14次(350000/25000≈14)
      oversampled_pos = repmat(features(pos_idx,:), 14, 1);
      oversampled_labels = repmat(labels(pos_idx), 14, 1);
      % 合并新数据集
      new_features = [oversampled_pos; features(neg_idx,:)];
      new_labels = [oversampled_labels; labels(neg_idx)];
      tree = fitctree(new_features, new_labels);
      
    • 欠采样负类:随机删除部分负类样本,只保留和正类数量相当的样本(注意可能丢失负类信息)
      neg_idx = labels == 0;
      % 随机选择和正类数量相同的负类样本
      sampled_neg_idx = neg_idx(randperm(sum(neg_idx), sum(labels==1)));
      new_features = [features(labels==1,:); features(sampled_neg_idx,:)];
      new_labels = [labels(labels==1); labels(sampled_neg_idx)];
      tree = fitctree(new_features, new_labels);
      
  • 放弃准确率,改用合适的评估指标
    准确率在不平衡数据集里毫无意义!你应该关注召回率(敏感度)、F1分数、ROC-AUC这些指标,它们能更准确反映模型对正类的识别能力:

    [pred_labels, scores] = predict(tree, test_features);
    % 计算混淆矩阵
    conf_mat = confusionmat(test_labels, pred_labels);
    % 正类的召回率(抓到的正类占所有真实正类的比例)
    recall = conf_mat(1,1)/(conf_mat(1,1)+conf_mat(1,2));
    % 计算ROC曲线和AUC
    [roc_obj, auc_score] = rocmetrics(test_labels, scores(:,2), 1);
    plot(roc_obj)
    
  • 调整决策树的分裂准则
    默认的基尼指数对不平衡数据不够友好,你可以尝试改用增益率作为分裂准则,配合前面的方法一起使用:

    tree = fitctree(features, labels, 'SplitCriterion', 'gainratio');
    

最后提醒一下:优先尝试损失矩阵和类权重这两个模型层面的调整,它们不需要修改原始数据,风险更低。如果效果还是不理想,再考虑数据集采样的方法。

内容的提问来源于stack exchange,提问作者user3470496

火山引擎 最新活动