2015-02-20 4 views
1

Я пытаюсь выполнить поиск по сетке (grid search) для классификатора случайного леса (RF), где в качестве функции оценки используется precision_score из модуля sklearn.metrics. Вот код.

Проблема со scorer в GridSearchCV в sklearn

from sklearn.metrics import precision_score 

# Candidate values for each RandomForestClassifier hyperparameter; this dict
# is handed to GridSearchCV as its ``param_grid`` argument.
param_grid = {'n_estimators': [51, 101, 201, 301, 501], 
       'max_depth': [3, 5, 10, None], 
       'min_samples_split': [2, 5, 10], 
       'criterion': ['gini', 'entropy'], 
       'bootstrap': [True, False]} 

def fit_gridCV_RFclassifier(param_grid): 
    """Run a 5-fold grid search over a random forest and return the search.

    Args:
        param_grid: Passed unchanged to ``GridSearchCV(param_grid=...)``.

    Returns:
        The ``GridSearchCV`` object after ``fit(train_X, train_y)``.

    Note:
        ``GridSearchCV``, ``train_X`` and ``train_y`` are not defined in this
        snippet; they must already exist in the enclosing scope.
    """
    from sklearn.ensemble import RandomForestClassifier 
    rf = RandomForestClassifier() 
    # NOTE(review): ``precision_score`` is a metric with the signature
    # (y_true, y_pred, ...), but ``scoring`` expects a string or a callable
    # ``scorer(estimator, X, y)``. Passing the raw metric here is what raises
    # the ValueError in the traceback below; use ``scoring="precision"``.
    clf = GridSearchCV(estimator=rf, param_grid=param_grid, 
         cv=5, scoring=precision_score, 
         refit=True) 
    clf.fit(train_X, train_y) 
    return clf 

# Run the search with the grid defined above; this is the call that fails.
gridsearch_rf = fit_gridCV_RFclassifier(param_grid) 

При выполнении функции я получаю следующую ошибку:

--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-34-6f91362a017c> in <module>() 
----> 1 gridsearch_rf = fit_gridCV_RFclassifier(param_grid) 

<ipython-input-33-974d026d5dc8> in fit_gridCV_RFclassifier(param_grid) 
    11      scoring=precision_score, 
    12      cv=5, refit=True) 
---> 13  clf.fit(train_X, train_y) 
    14  return clf 

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit(self, X, y) 
    594 
    595   """ 
--> 596   return self._fit(X, y, ParameterGrid(self.param_grid)) 
    597 
    598 

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable) 
    376          train, test, self.verbose, parameters, 
    377          self.fit_params, return_parameters=True) 
--> 378    for parameters in parameter_iterable 
    379    for train, test in cv) 
    380 

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable) 
    651    self._iterating = True 
    652    for function, args, kwargs in iterable: 
--> 653     self.dispatch(function, args, kwargs) 
    654 
    655    if pre_dispatch == "all" or n_jobs == 1: 

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs) 
    398   """ 
    399   if self._pool is None: 
--> 400    job = ImmediateApply(func, args, kwargs) 
    401    index = len(self._jobs) 
    402    if not _verbosity_filter(index, self.verbose): 

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs) 
    136   # Don't delay the application, to avoid keeping the input 
    137   # arguments in memory 
--> 138   self.results = func(*args, **kwargs) 
    139 
    140  def get(self): 

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters) 
    1238  else: 
    1239   estimator.fit(X_train, y_train, **fit_params) 
-> 1240  test_score = _score(estimator, X_test, y_test, scorer) 
    1241  if return_train_score: 
    1242   train_score = _score(estimator, X_train, y_train, scorer) 

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _score(estimator, X_test, y_test, scorer) 
    1294   score = scorer(estimator, X_test) 
    1295  else: 
-> 1296   score = scorer(estimator, X_test, y_test) 
    1297  if not isinstance(score, numbers.Number): 
    1298   raise ValueError("scoring must return a number, got %s (%s) instead." 

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_score(y_true, y_pred, labels, pos_label, average, sample_weight) 
    1883             average=average, 
    1884             warn_for=('precision',), 
-> 1885             sample_weight=sample_weight) 
    1886  return p 
    1887 

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_recall_fscore_support(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight) 
    1667   raise ValueError("beta should be >0 in the F-beta score") 
    1668 
-> 1669  y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred) 
    1670 
    1671  label_order = labels # save this for later 

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in _check_clf_targets(y_true, y_pred) 
    107  y_pred : array or indicator matrix 
    108  """ 
--> 109  y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True) 
    110  type_true = type_of_target(y_true) 
    111  type_pred = type_of_target(y_pred) 

/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in check_arrays(*arrays, **options) 
    252   if size != n_samples: 
    253    raise ValueError("Found array with dim %d. Expected %d" 
--> 254        % (size, n_samples)) 
    255 
    256   if not allow_lists or hasattr(array, "shape"): 

ValueError: Found array with dim 317760. Expected 51 

Похоже, что ошибка возникает в функции оценки (scoring). Буду благодарен за любую помощь. Спасибо.

Моя версия scikit-learn: 0.15.2

+0

Вы можете выложить свой код с игрушечным набором данных? Пока трудно сказать, что именно происходит. – JAB

ответ

4

Параметр scoring принимает (docs):

scoring: строка, вызываемый объект или None, необязательный, по умолчанию: None

A string (see model evaluation documentation) or a scorer callable object/function with signature scorer(estimator, X, y). 

У функции precision_score другая сигнатура: precision_score(y_true, y_pred, ...). Вам нужно просто передать строку, например "precision" — это одна из встроенных метрик (docs):

# Pass the metric by name instead of the raw ``precision_score`` function.
clf = GridSearchCV(estimator=rf, param_grid=param_grid, 
        cv=5, scoring="precision", 
        refit=True) 
+0

А... Спасибо, это сработало. – Nitin

Смежные вопросы