2016-03-14 5 views
1

У меня есть RDD в следующем формате и хотели бы, чтобы преобразовать его в LabeledPoint РДУ, чтобы обработать его в mllib:Преобразование [(Int, Seq [Double])] РДД к LabeledPoint

Test: RDD[(Int, Seq[Double])] = Array((1,List(1.0,3.0,8.0),(2,List(3.0, 3.0,8.0),(1,List(2.0,3.0,7.0),(1,List(5.0,5.0,9.0)) 

Я попытался с картой

import org.apache.spark.mllib.linalg.{Vector, Vectors} 
import org.apache.spark.mllib.regression.LabeledPoint 
Test.map(x=> LabeledPoint(x._1, Vectors.sparse(x._2))) 

, но я получаю эту ошибку

mllib.linalg.Vector cannot be applied to (Seq[scala.Double]) 

Так предположительно элемент Seq должен быть преобразован пихта но я не знаю, в чем.

ответ

1

Есть несколько проблем здесь:

  • этикетка должна быть Double не Int
  • SparseVector требует количества элементов, индексы и значение
  • ни один из векторных конструкторов не принимает список Double
  • ваши данные выглядят плотно не редкими

Одно из возможных решений:

val rdd = sc.parallelize(Array(
    (1, List(1.0,3.0,8.0)), 
    (2, List(3.0, 3.0,8.0)), 
    (1, List(2.0,3.0,7.0)), 
    (1, List(5.0,5.0,9.0)))) 

rdd.map { case (k, vs) => 
    LabeledPoint(k.toDouble, Vectors.dense(vs.toArray)) 
} 

и другое:

rdd.collect { case (k, v::vs) => 
    LabeledPoint(k.toDouble, Vectors.dense(v, vs: _*)) } 
1

Как вы можете заметить в LabeledPoint's documentation его конструктор получает Double в качестве метки и Vector как особенности (DenseVector или SparseVector). Однако, если вы посмотрите на оба конструктора наследуемых классов, они получат Array, поэтому вам необходимо преобразовать Seq в Array.

import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseVector} 
import org.apache.spark.mllib.regression.LabeledPoint 

val rdd = sc.parallelize(Array((1, Seq(1.0,3.0,8.0)), 
           (2, Seq(3.0, 3.0,8.0)), 
           (1, Seq(2.0,3.0, 7.0)), 
           (1, Seq(5.0, 5.0, 9.0)))) 
val x = rdd.map{ 
    case (a: Int, b:Seq[Double]) => LabeledPoint(a, new DenseVector(b.toArray)) 
} 

x.take(2).foreach(println) 

//(1.0,[1.0,3.0,8.0]) 
//(2.0,[3.0,3.0,8.0]) 
Смежные вопросы