2016-11-01 2 views
2

Я пытаюсь реализовать схему перекрестной проверки по сгруппированным данным. Я надеялся использовать метод GroupKFold, но я все время получаю сообщение об ошибке. Что я делаю не так? Код (немного отличается от той, которую я использовал - я имел различные данные, так что я имел большие n_splits, но everythign еще одно и то же)Sklearn: перекрестная проверка для сгруппированных данных

from sklearn import metrics 
import matplotlib.pyplot as plt 
import numpy as np 
from sklearn.model_selection import GroupKFold 
from sklearn.grid_search import GridSearchCV 
from xgboost import XGBRegressor 
#generate data 
x=np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13]) 
y= np.array([1,2,3,4,5,6,7,1,2,3,4,5,6,7]) 
group=np.array([1,0,1,1,2,2,2,1,1,1,2,0,0,2)] 
#grid search 
gkf = GroupKFold(n_splits=3).split(x,y,group) 
subsample = np.arange(0.3,0.5,0.1) 
param_grid = dict(subsample=subsample) 
rgr_xgb = XGBRegressor(n_estimators=50) 
grid_search = GridSearchCV(rgr_xgb, param_grid, cv=gkf, n_jobs=-1) 
result = grid_search.fit(x, y) 

ошибка:

Traceback (most recent call last): 

File "<ipython-input-143-11d785056a08>", line 8, in <module> 
result = grid_search.fit(x, y) 

File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 813, in fit 
return self._fit(X, y, ParameterGrid(self.param_grid)) 

File "/home/student/anaconda/lib/python3.5/site-packages/sklearn/grid_search.py", line 566, in _fit 
n_folds = len(cv) 

TypeError: object of type 'generator' has no len() 

изменяя линия

gkf = GroupKFold(n_splits=3).split(x,y,group) 

в

gkf = GroupKFold(n_splits=3) 

также не работает. Сообщение об ошибке затем:

'GroupKFold' object is not iterable 
+1

Какую версию 'sklearn' у вас есть? Параметр 'cv'' GridSearchCV' обычно должен принимать генератор. –

ответ

11

split функция GroupKFoldурожайности индексы подготовки и тестирования одну пару за один раз. Вы должны позвонить list по значению раздельным, чтобы получить их в виде списка, так что длина может быть вычислена:

gkf = list(GroupKFold(n_splits=3).split(x,y,group)) 
+1

Я искал возрастов для этого ответа, спасибо! – Archie

Смежные вопросы