如何基于rank_group列按70%/20%/10%比例拆分数据集为训练集、验证集和测试集
如何基于rank_group列按70%/20%/10%比例拆分数据集为训练集、验证集和测试集
问题背景
你手头有一份带rank_group分组标记的数据集,希望按70%训练集、20%验证集、10%测试集的比例拆分,且要求同一个rank_group下的所有行必须完整落在同一个数据集中(比如你的示例里,训练集要包含rank_group=1.0~5.0的所有行,验证集包含6.0~7.0,测试集包含8.0)。
先看一下你提供的原始数据集:
date user f1 f2 rank rank_group counts 0 09/09/2021 USER100 59.0 3599.9 1 1.0 3 1 10/09/2021 USER100 75.29 80790.0 2 1.0 3 2 11/09/2021 USER100 75.29 80790.0 3 1.0 3 1 10/09/2021 USER100 75.29 80790.0 2 2.0 3 2 11/09/2021 USER100 75.29 80790.0 3 2.0 3 3 12/09/2021 USER100 75.29 80790.0 4 2.0 3 2 11/09/2021 USER100 75.29 80790.0 3 3.0 3 3 12/09/2021 USER100 75.29 80790.0 4 3.0 3 4 13/09/2021 USER100 75.29 80790.0 5 3.0 3 3 12/09/2021 USER100 75.29 80790.0 4 4.0 3 4 13/09/2021 USER100 75.29 80790.0 5 4.0 3 5 14/09/2021 USER100 75.29 80790.0 6 4.0 3 4 13/09/2021 USER100 75.29 80790.0 5 5.0 3 5 14/09/2021 USER100 75.29 80790.0 6 5.0 3 6 15/09/2021 USER100 71.24 28809.9 7 5.0 3 5 14/09/2021 USER100 75.29 80790.0 6 6.0 3 6 15/09/2021 USER100 71.24 28809.9 7 6.0 3 7 16/09/2021 USER100 71.31 79209.9 8 6.0 3 6 15/09/2021 USER100 71.24 28809.9 7 7.0 3 7 16/09/2021 USER100 71.31 79209.9 8 7.0 3 8 17/09/2021 USER100 70.43 82809.9 9 7.0 3 7 16/09/2021 USER100 71.31 79209.9 8 8.0 3 8 17/09/2021 USER100 70.43 82809.9 9 8.0 3 9 18/09/2021 USER100 68.65 82809.9 10 8.0 3
现有方法的问题
先说说你尝试的两种方法存在的问题:
- 方法I(np.split):这个函数是按行的索引位置直接拆分,完全不考虑
rank_group分组。会导致同一个rank_group的行被拆到训练、验证或测试集里,完全不符合你“同组行必须在同一个集合”的要求。 - 方法II(自定义逻辑):逻辑过于复杂,而且计算和范围判断容易出错(比如
validation_number的计算、range的起止值),后期维护起来很麻烦。
简洁可靠的解决方案
其实我们可以换个思路:先把rank_group的唯一值按顺序划分好,再根据分组筛选数据,逻辑清晰且不易出错。以下是基于Pandas的实现(因为你的数据是表格格式,Pandas处理最方便):
import pandas as pd # 1. 读取或构造你的数据集(这里假设已经是DataFrame格式) # user_dataset = pd.read_csv("your_data.csv") # 实际场景可以用这个读取数据 # 2. 获取排序后的唯一rank_group列表 unique_groups = sorted(user_dataset['rank_group'].unique()) total_groups = len(unique_groups) print(f"所有rank_group:{unique_groups},共{total_groups}个") # 3. 按比例计算各集合对应的group数量(这里和你的期望完全匹配) train_group_count = 5 # 8*0.7=5.6,取整为5,对应group1-5 val_group_count = 2 # 8*0.2=1.6,取整为2,对应group6-7 test_group_count = 1 # 剩下的1个,对应group8 # 如果你想通用化(适配不同数量的group),可以用下面的计算方式: # train_group_count = int(total_groups * 0.7) # val_group_count = int(total_groups * 0.2) # test_group_count = total_groups - train_group_count - val_group_count # 4. 划分各集合对应的group列表 train_groups = unique_groups[:train_group_count] val_groups = unique_groups[train_group_count:train_group_count+val_group_count] test_groups = unique_groups[train_group_count+val_group_count:] # 5. 根据group列表筛选数据 train_set = user_dataset[user_dataset['rank_group'].isin(train_groups)] validation_set = user_dataset[user_dataset['rank_group'].isin(val_groups)] test_set = user_dataset[user_dataset['rank_group'].isin(test_groups)] # 验证结果 print(f"训练集包含rank_group:{train_groups},共{len(train_set)}行") print(f"验证集包含rank_group:{val_groups},共{len(validation_set)}行") print(f"测试集包含rank_group:{test_groups},共{len(test_set)}行")
代码说明
- 先把所有
rank_group去重并排序,保证分组顺序正确; - 按比例计算每个集合要包含的
rank_group数量,你可以根据实际需求调整取整方式(比如四舍五入、向上/向下取整); - 用
isin()方法筛选数据,确保同一个rank_group的所有行都被分到同一个集合里; - 逻辑简单直观,后期修改比例或者调整分组数量都很方便。
备注:内容来源于stack exchange,提问作者Carlo Allocca




