You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何基于rank_group列按70%/20%/10%比例拆分数据集为训练集、验证集和测试集

如何基于rank_group列按70%/20%/10%比例拆分数据集为训练集、验证集和测试集

问题背景

你手头有一份带rank_group分组标记的数据集,希望按70%训练集、20%验证集、10%测试集的比例拆分,且要求同一个rank_group下的所有行必须完整落在同一个数据集中(比如你的示例里,训练集要包含rank_group=1.0~5.0的所有行,验证集包含6.0~7.0,测试集包含8.0)。

先看一下你提供的原始数据集:

date      user   f1     f2       rank   rank_group  counts
   0  09/09/2021  USER100  59.0  3599.9    1         1.0       3
   1  10/09/2021  USER100  75.29 80790.0   2         1.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         1.0       3
   1  10/09/2021  USER100  75.29 80790.0   2         2.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         2.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         2.0       3
   2  11/09/2021  USER100  75.29 80790.0   3         3.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         3.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         3.0       3
   3  12/09/2021  USER100  75.29 80790.0   4         4.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         4.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         4.0       3
   4  13/09/2021  USER100  75.29 80790.0   5         5.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         5.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         5.0       3
   5  14/09/2021  USER100  75.29 80790.0   6         6.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         6.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         6.0       3
   6  15/09/2021  USER100  71.24 28809.9   7         7.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         7.0       3
   8  17/09/2021  USER100  70.43 82809.9   9         7.0       3
   7  16/09/2021  USER100  71.31 79209.9   8         8.0       3
   8  17/09/2021  USER100  70.43 82809.9   9         8.0       3
   9  18/09/2021  USER100  68.65 82809.9   10        8.0       3

现有方法的问题

先说说你尝试的两种方法存在的问题:

  • 方法I(np.split):这个函数是按行的索引位置直接拆分,完全不考虑rank_group分组。会导致同一个rank_group的行被拆到训练、验证或测试集里,完全不符合你“同组行必须在同一个集合”的要求。
  • 方法II(自定义逻辑):逻辑过于复杂,而且计算和范围判断容易出错(比如validation_number的计算、range的起止值),后期维护起来很麻烦。

简洁可靠的解决方案

其实我们可以换个思路:先把rank_group的唯一值按顺序划分好,再根据分组筛选数据,逻辑清晰且不易出错。以下是基于Pandas的实现(因为你的数据是表格格式,Pandas处理最方便):

import pandas as pd

# 1. 读取或构造你的数据集(这里假设已经是DataFrame格式)
# user_dataset = pd.read_csv("your_data.csv")  # 实际场景可以用这个读取数据

# 2. 获取排序后的唯一rank_group列表
unique_groups = sorted(user_dataset['rank_group'].unique())
total_groups = len(unique_groups)
print(f"所有rank_group:{unique_groups},共{total_groups}个")

# 3. 按比例计算各集合对应的group数量(这里和你的期望完全匹配)
train_group_count = 5  # 8*0.7=5.6,取整为5,对应group1-5
val_group_count = 2    # 8*0.2=1.6,取整为2,对应group6-7
test_group_count = 1   # 剩下的1个,对应group8

# 如果你想通用化(适配不同数量的group),可以用下面的计算方式:
# train_group_count = int(total_groups * 0.7)
# val_group_count = int(total_groups * 0.2)
# test_group_count = total_groups - train_group_count - val_group_count

# 4. 划分各集合对应的group列表
train_groups = unique_groups[:train_group_count]
val_groups = unique_groups[train_group_count:train_group_count+val_group_count]
test_groups = unique_groups[train_group_count+val_group_count:]

# 5. 根据group列表筛选数据
train_set = user_dataset[user_dataset['rank_group'].isin(train_groups)]
validation_set = user_dataset[user_dataset['rank_group'].isin(val_groups)]
test_set = user_dataset[user_dataset['rank_group'].isin(test_groups)]

# 验证结果
print(f"训练集包含rank_group:{train_groups},共{len(train_set)}行")
print(f"验证集包含rank_group:{val_groups},共{len(validation_set)}行")
print(f"测试集包含rank_group:{test_groups},共{len(test_set)}行")

代码说明

  • 先把所有rank_group去重并排序,保证分组顺序正确;
  • 按比例计算每个集合要包含的rank_group数量,你可以根据实际需求调整取整方式(比如四舍五入、向上/向下取整);
  • isin()方法筛选数据,确保同一个rank_group的所有行都被分到同一个集合里;
  • 逻辑简单直观,后期修改比例或者调整分组数量都很方便。

备注:内容来源于stack exchange,提问作者Carlo Allocca

火山引擎 最新活动