Разделение данных на 3 выборки
Подскажите, пожалуйста, как разбить данные на обучающую, валидационную и тестовую выборки в пропорции 60:20:20, сейчас получается так:
df_train, df_test = train_test_split(df, test_size=0.2, random_state=12345)
features_val, features_test, target_val, target_test = train_test_split(features, target, test_size=0.25, random_state=12345)
Ответы (2 шт):
Ну, например, вот так:
df=pd.DataFrame({'A':[0,1,2,3,4,5,6,7,8,9],'B':[11,23,45,67,89,32,43,54,65,76]})
df_train, df_test1 =train_test_split(df, test_size=0.4, random_state=42)
df_test, df_val=train_test_split(df_test1, test_size=0.5, random_state=42)
print(df_train)
print(df_test)
print(df_val)
Результат:
A B
7 7 54
2 2 45
9 9 76
4 4 89
3 3 67
6 6 43
A B
8 8 65
5 5 32
A B
1 1 23
0 0 11
Как видите - пропорция 60:20:20, как просили.
Задача для произвольных размеров выборок
Заданы размеры в процентах для тренировочной и валидационной выборки, например:
train_size = 60
valid_size = 20
Тестовую выборку, в таком случае, мы получим из того куска датасета, который останется от тренировочной и валидационной: test_size = 100 - train_size - valid_size.
Решение
Используем метод numpy.split(). Он принимает на вход датафрейм и пороги срезов. Возвращает разрезанные куски датафрейма. Предварительно, данные в датафрейме требуется перемешать с помощью pandas.DataFrame.sample()
Итак:
import numpy as np
import pandas as pd
def triple_split(data, t_size, v_size):
return np.split(
data.sample(frac=1, random_state=123),
[int(len(data)*t_size/100),
int(len(data)*(t_size+v_size)/100)]
)
dic = {
'A': np.random.randint(10, 99, 20),
'B': np.random.randint(10, 99, 20),
'C': np.random.randint(10,99, 20)
}
df = pd.DataFrame.from_dict(dic)
print(df)
train_size = 60
val_size = 20
train, valid, test = triple_split(df, train_size, val_size)
print(train)
print(valid)
print(test)
Важно
Метод train_test_split() опционально поддерживает параметр стратификации, который отвечает за распределение выборки по классам. Если он будет необходим - тогда останется только применять train_test_split() два раза, как указано в соседнем ответе.