2016-01-08 2 views
8

Для того, чтобы построить MultiClass классификатор NaiveBayes, я использую CrossValidator, чтобы выбрать лучшие параметры в моем трубопроводе:СПАРК, ML, Тюнинг, CrossValidator: доступ метрики

val cv = new CrossValidator() 
     .setEstimator(pipeline) 
     .setEstimatorParamMaps(paramGrid) 
     .setEvaluator(new MulticlassClassificationEvaluator) 
     .setNumFolds(10) 

val cvModel = cv.fit(trainingSet) 

Трубопровод содержит обычные трансформаторы и оценки в следующем порядке: Tokenizer, StopWordsRemover, HashingTF, IDF и, наконец, NaiveBayes.

Возможно ли получить доступ к метрикам, рассчитанным для лучшей модели?

В идеале я хотел бы получить доступ к метрикам всех моделей, чтобы увидеть, как изменение параметров изменяет качество классификации. Но на данный момент лучшая модель достаточно хороша.

FYI, я использую Спарк 1.6.0

ответ

6

Вот как я это делаю:

val pipeline = new Pipeline() 
    .setStages(Array(tokenizer, stopWordsFilter, tf, idf, word2Vec, featureVectorAssembler, categoryIndexerModel, classifier, categoryReverseIndexer)) 

... 

val paramGrid = new ParamGridBuilder() 
    .addGrid(tf.numFeatures, Array(10, 100)) 
    .addGrid(idf.minDocFreq, Array(1, 10)) 
    .addGrid(word2Vec.vectorSize, Array(200, 300)) 
    .addGrid(classifier.maxDepth, Array(3, 5)) 
    .build() 

paramGrid.size // 16 entries 

... 

// Print the average metrics per ParamGrid entry 
val avgMetricsParamGrid = crossValidatorModel.avgMetrics 

// Combine with paramGrid to see how they affect the overall metrics 
val combined = paramGrid.zip(avgMetricsParamGrid) 

... 

val bestModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel] 

// Explain params for each stage 
val bestHashingTFNumFeatures = bestModel.stages(2).asInstanceOf[HashingTF].explainParams 
val bestIDFMinDocFrequency = bestModel.stages(3).asInstanceOf[IDFModel].explainParams 
val bestWord2VecVectorSize = bestModel.stages(4).asInstanceOf[Word2VecModel].explainParams 
val bestDecisionTreeDepth = bestModel.stages(7).asInstanceOf[DecisionTreeClassificationModel].explainParams 
+1

почтовый индекс работает, но я действительно не так как он предполагает внутренние знания о том, как работает CrossValidator. Они могли бы изменить способ создания массива показателей, чтобы он был в другом порядке для следующей версии, и вы закрыты, но не знаете, что использовали, потому что ваш код все еще работает. Я хотел бы иметь параметры для модели, возвращенной с ее метрикой. Я также хотел бы видеть сводную статистику, а не просто среднюю. Насколько полезно среднее значение без стандартного отклонения? – Turbo

0
cvModel.avgMetrics 

работа в pyspark 2.2.0

Смежные вопросы