2016-12-23 1 views
1

Я работаю над проблемой анализа настроений для твитов, используя Spark с Scala. У меня есть рабочий вариант с использованием модели логистической регрессии следующим образом:Почему Naive Bayes не работает в Spark MLlib Pipeline, как Logistic Regression?

import org.apache.spark.mllib.regression.LabeledPoint 
import org.apache.spark.mllib.feature.HashingTF 
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD 
import org.apache.spark.sql.SQLContext 
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}; 
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} 
import org.apache.spark.mllib.util.MLUtils 
import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer, StopWordsRemover} 
import org.apache.spark.sql.functions._ 
import org.apache.spark.ml.classification.LogisticRegression 
import org.apache.spark.ml.Pipeline 
import org.apache.spark.ml.feature.Word2Vec 
import org.apache.spark.mllib.evaluation.RegressionMetrics 
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics 

val sqlContext = new SQLContext(sc) 

// Sentiment140 training corpus 
val trainFile = "s3://someBucket/training.1600000.processed.noemoticon.csv" 
val swFile = "s3://someBucket/stopwords.txt" 
val tr = sc.textFile(trainFile) 
val stopwords: Array[String] = sc.textFile(swFile).flatMap(_.stripMargin.split("\\s+")).collect ++ Array("rt") 

val parsed = tr.filter(_.contains("\",\"")).map(_.split("\",\"").map(_.replace("\"", ""))).filter(row => row.forall(_.nonEmpty)).map(row => (row(0).toDouble, row(5))).filter(row => row._1 != 2).map(row => (row._1/4, row._2)) 
val pDF = parsed.toDF("label","tweet") 
val tokenizer = new RegexTokenizer().setGaps(false).setPattern("\\p{L}+").setInputCol("tweet").setOutputCol("words") 
val filterer = new StopWordsRemover().setStopWords(stopwords).setCaseSensitive(false).setInputCol("words").setOutputCol("filtered") 
val countVectorizer = new CountVectorizer().setInputCol("filtered").setOutputCol("features") 

val lr = new LogisticRegression().setMaxIter(50).setRegParam(0.2).setElasticNetParam(0.0) 
val pipeline = new Pipeline().setStages(Array(tokenizer, filterer, countVectorizer, lr)) 

val lrModel = pipeline.fit(pDF) 

// Now model is made. Lets get some test data... 

val testFile = "s3://someBucket/testdata.manual.2009.06.14.csv" 
val te = sc.textFile(testFile) 
val teparsed = te.filter(_.contains("\",\"")).map(_.split("\",\"").map(_.replace("\"", ""))).filter(row => row.forall(_.nonEmpty)).map(row => (row(0).toDouble, row(5))).filter(row => row._1 != 2).map(row => (row._1/4, row._2)) 
val teDF = teparsed.toDF("label","tweet") 

val res = lrModel.transform(teDF) 
val restup = res.select("label","prediction").rdd.map(r => (r(1).asInstanceOf[Double], r(0).asInstanceOf[Double])) 
val metrics = new BinaryClassificationMetrics(restup) 

metrics.areaUnderROC() 

Используя логистическую регрессию, это возвращает совершенно нормальное АУК. Однако при переходе от логистической регрессии к Валу нб = новые NaiveBayes(), я получаю следующее сообщение об ошибке:

found : org.apache.spark.mllib.classification.NaiveBayes 
required: org.apache.spark.ml.PipelineStage 
    val pipeline = new Pipeline().setStages(Array(tokenizer, filterer, countVectorizer, nb)) 

В консультации с API документации на MLlib PipelineStage списках как логистической регрессии и наивного байесовского оба перечисленных как подклассы. Итак, почему LR работает, а не NB?

ответ

2

Это не работает, потому что вы используете неправильный класс. С Трубопроводы используют:

org.apache.spark.ml.NaiveBayes 

и консультации the documentation для правильного синтаксиса.

+0

Ах. Конвейеры не работают для более старых пакетов .mllib (которые я использовал с NB из некоторого унаследованного кода), но работают для пакетов .ml (которые я использовал для модели LR). Одна небольшая коррекция выше ... это org.apache.spark.ml.classification.NaiveBayes. –

+0

Интересно, почему это займет мое неряшливое использование пакетов с HashingTF, также выходящее из mllib вместо ml? Ах хорошо. :) –

Смежные вопросы