2016-05-24 3 views
0

Я использую конвейер для группировки текстовых документов. Последний этап в конвейере - ml.clustering.KMeans, который дает мне DataFrame со столбцом предсказаний кластера. Я хотел бы добавить кластерные центры как столбцы. Я понимаю, что я могу выполнить Vector[] clusterCenters = kmeansModel.clusterCenters();, а затем преобразовать результаты в DataFrame и присоединиться указанных результатов к другому DataFrame однако я надеялся найти способ сделать это аналогично коду Kmeans ниже:искровой трубопровод KMeansModel clusterCenters

KMeans kMeans = new KMeans() 
       .setFeaturesCol("pca") 
       .setPredictionCol("kmeansclusterprediction") 
       .setK(5) 
       .setInitMode("random") 
       .setSeed(43L) 
       .setInitSteps(3) 
       .setMaxIter(15); 

pipeline.setStages(... 

Я смог расширить KMeans и вызвать метод fit через конвейер, но мне не удастся расширить KMeansModel ... для конструктора требуется String uid и KMeansModel, но я не знаю, как пройти в модели при определении этапы и вызов метода setStages.

Я также изучил расширение KMeans.scala, однако, как разработчик Java, я понимаю только половину кода таким образом, я надеюсь, что у кого-то может быть более легкое решение, прежде чем я решу это. В конце концов, я хотел бы закончить с DataFrame следующим образом:

+--------------------+-----------------------+--------------------+ 
|    docid|kmeansclusterprediction|kmeansclustercenters| 
+--------------------+-----------------------+--------------------+ 
|2bcbcd54-c11a-48c...|      2|  [-0.04, -7.72]| 
|0e644620-f5ff-40f...|      3|  [0.23, 1.08]| 
|665c1c2b-3065-4e8...|      3|  [0.23, 1.08]| 
|598c6268-e4b9-4c9...|      0|  [-15.81, 0.01]| 
+--------------------+-----------------------+--------------------+ 

Любая помощь или советы очень ценится. Спасибо

ответ

0

Отвечая на мой собственный вопрос ... это было на самом деле легко ... Я расширил KMeans и KMeansModel ... расширенный метод Kmeans fit должен вернуть расширенный KMeansModel. Например:

public class AnalyticsKMeansModel extends KMeansModel ... 


public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans { ... 

public AnalyticsKMeansModel fit(DataFrame dataset) { 

    JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>(){ 
     private static final long serialVersionUID = -4588981547209486909L; 

     @Override 
     public Vector call(Row row) throws Exception { 
      Object point = row.getAs("pca"); 
      Vector vector = (Vector)point; 
      return vector; 
     } 

    }); 

    RDD<Vector> rdd = JavaRDD.toRDD(javaRDD); 
    org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans().setK(BoxesRunTime.unboxToInt(this.$((Param<?>)this.k()))).setInitializationMode((String)this.$(this.initMode())).setInitializationSteps(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.initSteps()))).setMaxIterations(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.maxIter()))).setSeed(BoxesRunTime.unboxToLong((Object)this.$((Param<?>)this.seed()))).setEpsilon(BoxesRunTime.unboxToDouble((Object)this.$((Param<?>)this.tol()))); 
    org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd); 
    AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel); 
    return (AnalyticsKMeansModel) this.copyValues((Params)model, this.copyValues$default$2()); 
} 

Как только я изменил метод подгонки, чтобы вернуть мой расширенный класс KMeansModel, все работало, как ожидалось.

+0

Эй Я пытаюсь сделать то же самое, но не могу понять ваш код, вы можете отправить полный код. – DnA

+0

@DnA, я добавил три класса выше ... Надеюсь, это поможет. Я не дотронулся до кода за долгое время, но считаю, что он был в рабочем состоянии, когда последний раз касался. Код использовался для учебных целей и никогда не использовался в производственной среде. – tullm

0
 import java.util.ArrayList; 
     import java.util.Arrays; 
     import java.util.List; 

     import org.apache.spark.api.java.JavaRDD; 
     import org.apache.spark.api.java.JavaSparkContext; 
     import org.apache.spark.api.java.function.Function; 
     import org.apache.spark.ml.clustering.KMeansModel; 
     import org.apache.spark.mllib.linalg.Vector; 
     import org.apache.spark.sql.DataFrame; 
     import org.apache.spark.sql.Row; 
     import org.apache.spark.sql.RowFactory; 
     import org.apache.spark.sql.types.DataTypes; 
     import org.apache.spark.sql.types.StructField; 
     import org.apache.spark.sql.types.StructType; 

     import AnalyticsCluster; 

     public class AnalyticsKMeansModel extends KMeansModel { 
      private static final long serialVersionUID = -8893355418042946358L; 

      public AnalyticsKMeansModel(String uid, org.apache.spark.mllib.clustering.KMeansModel parentModel) { 
       super(uid, parentModel); 
      } 

      public DataFrame transform(DataFrame dataset) { 

       Vector[] clusterCenters = super.clusterCenters(); 

       List<AnalyticsCluster> analyticsClusters = new ArrayList<AnalyticsCluster>(); 

       for (int i=0; i<clusterCenters.length;i++){ 
        Integer clusterId = super.predict(clusterCenters[i]); 
        Vector vector = clusterCenters[i]; 
        double[] point = vector.toArray(); 
        AnalyticsCluster analyticsCluster = new AnalyticsCluster(clusterId, point, 0L); 
        analyticsClusters.add(analyticsCluster); 
       } 

       JavaSparkContext jsc = JavaSparkContext.fromSparkContext(dataset.sqlContext().sparkContext()); 

       JavaRDD<AnalyticsCluster> javaRDD = jsc.parallelize(analyticsClusters); 

       JavaRDD<Row> javaRDDRow = javaRDD.map(new Function<AnalyticsCluster, Row>() { 
        private static final long serialVersionUID = -2677295862916670965L; 

        @Override 
        public Row call(AnalyticsCluster cluster) throws Exception { 
         Row row = RowFactory.create(
          String.valueOf(cluster.getID()), 
          String.valueOf(Arrays.toString(cluster.getCenter())) 
         ); 
         return row; 
        } 

       }); 

       List<StructField> schemaColumns = new ArrayList<StructField>(); 
       schemaColumns.add(DataTypes.createStructField(this.getPredictionCol(), DataTypes.StringType, false)); 
       schemaColumns.add(DataTypes.createStructField("clusterpoint", DataTypes.StringType, false)); 

       StructType dataFrameSchema = DataTypes.createStructType(schemaColumns); 

       DataFrame clusterPointsDF = dataset.sqlContext().createDataFrame(javaRDDRow, dataFrameSchema); 

       //SOMETIMES "K" IS SET TO A VALUE GREATER THAN THE NUMBER OF ACTUAL ROWS OF DATA ... GET DISTINCT VALUES 
       clusterPointsDF.registerTempTable("clusterPoints"); 
       DataFrame clustersDF = clusterPointsDF.sqlContext().sql("select distinct " + this.getPredictionCol()+ ", clusterpoint from clusterPoints"); 
       clustersDF.cache(); 
       clusterPointsDF.sqlContext().dropTempTable("clusterPoints"); 

       DataFrame transformedDF = super.transform(dataset); 
       transformedDF.cache(); 

       DataFrame df = transformedDF.join(clustersDF, 
         transformedDF.col(this.getPredictionCol()).equalTo(clustersDF.col(this.getPredictionCol())), "inner") 
          .drop(clustersDF.col(this.getPredictionCol())); 

       return df; 
      } 
     } 





    import org.apache.spark.api.java.JavaRDD; 
    import org.apache.spark.api.java.function.Function; 
    import org.apache.spark.ml.param.Param; 
    import org.apache.spark.ml.param.Params; 
    import org.apache.spark.mllib.linalg.Vector; 
    import org.apache.spark.rdd.RDD; 
    import org.apache.spark.sql.DataFrame; 
    import org.apache.spark.sql.Row; 

    import scala.runtime.BoxesRunTime; 

    public class AnalyticsKMeans extends org.apache.spark.ml.clustering.KMeans { 
     private static final long serialVersionUID = 8943702485821267996L; 
     private static String uid = null; 

     public AnalyticsKMeans(String uid){ 
      AnalyticsKMeans.uid= uid; 
     } 


     public AnalyticsKMeansModel fit(DataFrame dataset) { 

      JavaRDD<Vector> javaRDD = dataset.select(this.getFeaturesCol()).toJavaRDD().map(new Function<Row, Vector>(){ 
       private static final long serialVersionUID = -4588981547209486909L; 

       @Override 
       public Vector call(Row row) throws Exception { 
        Object point = row.getAs("pca"); 
        Vector vector = (Vector)point; 
        return vector; 
       } 

      }); 

      RDD<Vector> rdd = JavaRDD.toRDD(javaRDD); 
      org.apache.spark.mllib.clustering.KMeans algo = new org.apache.spark.mllib.clustering.KMeans().setK(BoxesRunTime.unboxToInt(this.$((Param<?>)this.k()))).setInitializationMode((String)this.$(this.initMode())).setInitializationSteps(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.initSteps()))).setMaxIterations(BoxesRunTime.unboxToInt((Object)this.$((Param<?>)this.maxIter()))).setSeed(BoxesRunTime.unboxToLong((Object)this.$((Param<?>)this.seed()))).setEpsilon(BoxesRunTime.unboxToDouble((Object)this.$((Param<?>)this.tol()))); 
      org.apache.spark.mllib.clustering.KMeansModel parentModel = algo.run(rdd); 
      AnalyticsKMeansModel model = new AnalyticsKMeansModel(this.uid(), parentModel); 
      return (AnalyticsKMeansModel) this.copyValues((Params)model, this.copyValues$default$2()); 
     } 

    } 




import java.io.Serializable; 
import java.util.Arrays; 

public class AnalyticsCluster implements Serializable { 
    private static final long serialVersionUID = 6535671221958712594L; 

    private final int id; 
    private volatile double[] center; 
    private volatile long count; 

    public AnalyticsCluster(int id, double[] center, long initialCount) { 
    //  Preconditions.checkArgument(center.length > 0); 
    //  Preconditions.checkArgument(initialCount >= 1); 
     this.id = id; 
     this.center = center; 
     this.count = initialCount; 
    } 

    public int getID() { 
     return id; 
    } 

    public double[] getCenter() { 
     return center; 
    } 

    public long getCount() { 
     return count; 
    } 

    public synchronized void update(double[] newPoint, long newCount) { 
     int length = center.length; 
    //  Preconditions.checkArgument(length == newPoint.length); 
     double[] newCenter = new double[length]; 
     long newTotalCount = newCount + count; 
     double newToTotal = (double) newCount/newTotalCount; 
     for (int i = 0; i < length; i++) { 
      double centerI = center[i]; 
      newCenter[i] = centerI + newToTotal * (newPoint[i] - centerI); 
     } 
     center = newCenter; 
     count = newTotalCount; 
    } 

    @Override 
    public synchronized String toString() { 
     return id + " " + Arrays.toString(center) + " " + count; 
    } 

// public static void main(String[] args) { 
//  double[] point = new double[2]; 
//  point[0] = 0.10150532938119154; 
//  point[1] = -0.23734759238651829; 
//  
//  Cluster cluster = new Cluster(1,point, 10L); 
//  System.out.println("cluster: " + cluster.toString()); 
// } 

} 
+0

Добавлен код выше (3 класса). Я больше не касался кода более года, и это была моя первая попытка изучить искру. Я считаю, что использовал искру 1.6 – tullm

Смежные вопросы