2016-11-26 1 views
4

Я использовал StratifiedKFold из scikit-learn, но теперь мне нужно также посмотреть «группы». Есть хорошая функция GroupKFold, но мои данные очень зависят от времени. Точно так же, как и в справке, т.е. количество недель - это индекс группировки. Но каждую неделю нужно только в один раз.Scikit-learn, GroupKFold с перетасовкой групп?

Предположим, мне нужно 10 складок. Мне нужно сначала перетасовать данные, прежде чем я смогу использовать GroupKFold.

Перемешивание в групповом режиме - так что все группы должны перемещаться между собой.

Есть ли способ сделать это с помощью scikit-learn elegant как-то? Кажется, GroupKFold надежно перенести данные сначала.

Если нет способа сделать это с помощью scikit, может ли кто-нибудь написать эффективный код этого? У меня большие наборы данных.

матрица, этикетки, группы в качестве входов

ответ

4

Я думаю, используя sklearn.utils.shuffle элегантное решение!

Для данных в X, Y и группы:

from sklearn.utils import shuffle 
X_shuffled, y_shuffled, groups_shuffled = shuffle(X, y, groups, random_state=0) 

Затем используйте X_shuffled, y_shuffled и groups_shuffled с GroupKFold:

from sklearn.model_selection import GroupKFold 
group_k_fold = GroupKFold(n_splits=10) 
splits = group_k_fold.split(X_shuffled, y_shuffled, groups_shuffled) 

Конечно, вы, вероятно, хотите, чтобы перетасовать несколько раз и сделать перекрестная проверка с каждой перетасовкой. Вы можете положить все это в цикл - вот полный пример с 5 перетасовками (и только 3 раза вместо ваших 10):

X = np.arange(20).reshape((10, 2)) 
y = np.arange(10) 
groups = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7] 

n_shuffles = 5 
group_k_fold = GroupKFold(n_splits=3) 

for i in range(n_shuffles): 
    X_shuffled, y_shuffled, groups_shuffled = shuffle(X, y, groups, random_state=i) 
    splits = group_k_fold.split(X_shuffled, y_shuffled, groups_shuffled) 
    # do something with splits here, I'm just printing them out 
    print 'Shuffle', i 
    print 'groups_shuffled:', groups_shuffled 
    for train_idx, val_idx in splits: 
     print 'Train:', train_idx 
     print 'Val:', val_idx 
Смежные вопросы