加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Java > 正文

K-means算法(Spark Demo)

发布时间:2020-12-15 00:28:06 所属栏目:Java 来源:网络整理
导读:今天PHP站长网 52php.cn把收集自互联网的代码分享给大家,仅供参考。 import java.util.Randomimport spark.SparkContextimport spark.SparkContext._import spark.examples.Vector._ object SparkKMeans { /** * line -

以下代码由PHP站长网 52php.cn收集自互联网

现在PHP站长网小编把它分享给大家,仅供参考

import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._
 
object SparkKMeans {
    /**
     * line -> vector
     */
def parseVector (line: String) : Vector = {
        return new Vector (line.split (' ').map (_.toDouble) )
    }
 
    /**
     * 计算该节点的最近中心节点
     */
def closestCenter (p: Vector,centers: Array[Vector]) : Int = {
        var bestIndex = 0
        var bestDist = p.squaredDist (centers (0) ) //差平方之和
        for (i < - 1 until centers.length) {
            val dist = p.squaredDist (centers (i) )
            if (dist < bestDist) {
                bestDist = dist
                bestIndex = i
            }
        }
        return bestIndex
    }
 
def main (args: Array[String]) {
        if (args.length < 3) {
            System.err.println ("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
            System.exit (1)
        }
        val sc = new SparkContext (args (0),"SparkKMeans")
        val lines = sc.textFile (args (1),args (5).toInt)
                    val points = lines.map (parseVector (_) ).cache() //文本中每行为一个节点,再将每个节点转换成Vector
                                 val dimensions = args (2).toInt //节点的维度
                                         val k = args (3).toInt //聚类个数
                                                 val iterations = args (4).toInt //迭代次数
 
                                                         // 随机初始化k个中心节点
                                                         val rand = new Random (42)
        var centers = new Array[Vector] (k)
        for (i < - 0 until k)
            centers (i) = Vector (dimensions,_ => 2 * rand.nextDouble - 1)
                          println ("Initial centers: " + centers.mkString (",") )
                          val time1 = System.currentTimeMillis()
            for (i < - 1 to iterations) {
                println ("On iteration " + i)
 
                // Map each point to the index of its closest center and a (point,1) pair
                // that we will use to compute an average later
                val mappedPoints = points.map { p => (closestCenter (p,centers),(p,1) ) }
 
                val newCenters = mappedPoints.reduceByKey {
                case ( (sum1,count1),(sum2,count2) ) => (sum1 + sum2,count1 + count2) //(向量相加,计数器相加)
                    } .map {
                case (id,(sum,count) ) => (id,sum / count) //根据前面的聚类,重新计算中心节点的位置
                    } .collect
 
                // 更新中心节点
                for ( (id,value) < - newCenters) {
                    centers (id) = value
                }
            }
                       val time2 = System.currentTimeMillis()
                                   println ("Final centers: " + centers.mkString (",") + ",time: " + (time2 - time1) )
    }
}

以上内容由PHP站长网【52php.cn】收集整理供大家参考研究

如果以上内容对您有帮助,欢迎收藏、点赞、推荐、分享。

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读