Implementation of k-means clustering algorithm

Posted 2019-04-02 12:08

In my program, I'm using k=2 for the k-means algorithm, i.e. I want only 2 clusters. I have implemented it in a very simple and straightforward way, but I still can't understand why my program gets into an infinite loop. Can anyone please show me where I'm making a mistake?

For simplicity, I have hard-coded the input in the program itself. Here is my code:

import java.io.*;
import java.lang.*;
class Kmean
{
public static void main(String args[])
{
int N=9;
int arr[]={2,4,10,12,3,20,30,11,25};    // initial data
int i,m1,m2,a,b,n=0;
boolean flag=true;
float sum1=0,sum2=0;
a=arr[0];b=arr[1];
m1=a; m2=b;
int cluster1[]=new int[9],cluster2[]=new int[9];
for(i=0;i<9;i++)
    System.out.print(arr[i]+ "\t");
System.out.println();

do
{
 n++;
 int k=0,j=0;
 for(i=0;i<9;i++)
 {
    if(Math.abs(arr[i]-m1)<=Math.abs(arr[i]-m2))
    {   cluster1[k]=arr[i];
        k++;
    }
    else
    {   cluster2[j]=arr[i];
        j++;
    }
 }
    System.out.println();
    for(i=0;i<9;i++)
        sum1=sum1+cluster1[i];
    for(i=0;i<9;i++)
        sum2=sum1+cluster2[i];
    a=m1;
    b=m2;
    m1=Math.round(sum1/k);
    m2=Math.round(sum2/j);
    if(m1==a && m2==b)
        flag=false;
    else
        flag=true;

    System.out.println("After iteration "+ n +" , cluster 1 :\n");    //printing the clusters of each iteration
    for(i=0;i<9;i++)
        System.out.print(cluster1[i]+ "\t");

    System.out.println("\n");
    System.out.println("After iteration "+ n +" , cluster 2 :\n");
    for(i=0;i<9;i++)
        System.out.print(cluster2[i]+ "\t");

}while(flag);

    System.out.println("Final cluster 1 :\n");            // final clusters
    for(i=0;i<9;i++)
        System.out.print(cluster1[i]+ "\t");

    System.out.println();
    System.out.println("Final cluster 2 :\n");
    for(i=0;i<9;i++)
        System.out.print(cluster2[i]+ "\t");
 }
}

4 answers
孤傲高冷的网名
Reply #2 · 2019-04-02 12:43
public class KMeansClustering {

public static void main(String args[]) {
    int arr[] = {2, 4, 10, 12, 3, 20, 30, 11, 25};    // initial data
    int i, m1, m2, a, b, n = 0;
    boolean flag;
    float sum1, sum2;
    a = arr[0];
    b = arr[1];
    m1 = a;
    m2 = b;
    int cluster1[] = new int[arr.length], cluster2[] = new int[arr.length];
    do {
        sum1 = 0;
        sum2 = 0;
        cluster1 = new int[arr.length];
        cluster2 = new int[arr.length];
        n++;
        int k = 0, j = 0;
        for (i = 0; i < arr.length; i++) {
            if (Math.abs(arr[i] - m1) <= Math.abs(arr[i] - m2)) {
                cluster1[k] = arr[i];
                k++;
            } else {
                cluster2[j] = arr[i];
                j++;
            }
        }
        System.out.println();
        for (i = 0; i < k; i++) {
            sum1 = sum1 + cluster1[i];
        }
        for (i = 0; i < j; i++) {
            sum2 = sum2 + cluster2[i];
        }
        // print the current centroids (means)
        System.out.println("m1=" + m1 + "   m2=" + m2);
        a = m1;
        b = m2;
        m1 = Math.round(sum1 / k);
        m2 = Math.round(sum2 / j);
        flag = !(m1 == a && m2 == b);

        System.out.println("After iteration " + n + " , cluster 1 :\n");    //printing the clusters of each iteration
        for (i = 0; i < cluster1.length; i++) {
            System.out.print(cluster1[i] + "\t");
        }

        System.out.println("\n");
        System.out.println("After iteration " + n + " , cluster 2 :\n");
        for (i = 0; i < cluster2.length; i++) {
            System.out.print(cluster2[i] + "\t");
        }

    } while (flag);

    System.out.println("Final cluster 1 :\n");            // final clusters
    for (i = 0; i < cluster1.length; i++) {
        System.out.print(cluster1[i] + "\t");
    }

    System.out.println();
    System.out.println("Final cluster 2 :\n");
    for (i = 0; i < cluster2.length; i++) {
        System.out.print(cluster2[i] + "\t");
    }
}

}

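This is working code. Compared with the question, it resets sum1, sum2, cluster1 and cluster2 at the start of every iteration, sums only the first k and j entries, and builds sum2 from sum2 instead of sum1.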
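A possible tidier variant, sketched under a few assumptions of my own: it handles any number of clusters, keeps the means as double instead of rounding them, collects each cluster in a List so there are no padding zeros, stops when no point changes cluster, and leaves a mean unchanged if its cluster becomes empty. (In the code above, an empty cluster makes sum1 / k evaluate 0.0f / 0, which is NaN, and Math.round then returns 0.) The names KMeans1D, cluster and maxIterations are illustrative only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class KMeans1D {

    /**
     * Clusters 1-D data around the given initial means until no point changes cluster.
     * Ties go to the lower-numbered cluster, like the <= test above.
     * An empty cluster keeps its previous mean instead of dividing by zero.
     */
    static List<List<Integer>> cluster(int[] data, double[] initialMeans, int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be at least 1");
        }
        double[] means = initialMeans.clone();
        int[] labels = new int[data.length];
        Arrays.fill(labels, -1);                      // no point assigned yet
        for (int iter = 0; iter < maxIterations; iter++) {
            // Assignment step: move each point to its nearest mean.
            boolean changed = false;
            for (int i = 0; i < data.length; i++) {
                int best = 0;
                for (int c = 1; c < means.length; c++) {
                    if (Math.abs(data[i] - means[c]) < Math.abs(data[i] - means[best])) {
                        best = c;
                    }
                }
                if (labels[i] != best) {
                    labels[i] = best;
                    changed = true;
                }
            }
            if (!changed) {
                break;                                // converged: no assignment changed
            }
            // Update step: sums and counts start from zero on every iteration.
            double[] sums = new double[means.length];
            int[] counts = new int[means.length];
            for (int i = 0; i < data.length; i++) {
                sums[labels[i]] += data[i];
                counts[labels[i]]++;
            }
            for (int c = 0; c < means.length; c++) {
                if (counts[c] > 0) {
                    means[c] = sums[c] / counts[c];
                }
            }
        }
        // Collect the members of each cluster; no padding zeros are needed.
        List<List<Integer>> clusters = new ArrayList<>();
        for (int c = 0; c < means.length; c++) {
            clusters.add(new ArrayList<>());
        }
        for (int i = 0; i < data.length; i++) {
            clusters.get(labels[i]).add(data[i]);
        }
        return clusters;
    }

    public static void main(String[] args) {
        int[] arr = {2, 4, 10, 12, 3, 20, 30, 11, 25};
        System.out.println(cluster(arr, new double[] {arr[0], arr[1]}, 100));
    }
}
```

With the question's data and initial means 2 and 4, main prints [[2, 4, 10, 12, 3, 11], [20, 30, 25]]. Because the means are not rounded, other data sets could in principle split differently from the int version.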

来,给爷笑一个
Reply #3 · 2019-04-02 12:48

The only possible infinite loop is the do-while.

if(m1==a && m2==b)
    flag=false;
else
    flag=true;

You only exit the loop when flag is false. Set a breakpoint on the if statement and check why flag is never set to false, i.e. why m1 and m2 never both stay the same from one iteration to the next. Maybe add some debug print statements as well.
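To follow that advice without a debugger, a sketch like the following prints the old and new means on every pass and adds an iteration cap, so a remaining bug cannot hang the program. It assumes the bugs described in the other answers are already fixed; MAX_ITERATIONS, KMeansDebug and run are hypothetical additions, not part of the question's code. Because the do-while continues while flag is true, it stops only once neither m1 nor m2 changes.

```java
public class KMeansDebug {

    static final int MAX_ITERATIONS = 100;   // hypothetical safety cap, not in the original code

    /** Runs the fixed k = 2 loop on arr and returns {iterations, m1, m2}. */
    static int[] run(int[] arr) {
        int m1 = arr[0], m2 = arr[1];
        int n = 0;
        boolean flag;
        do {
            n++;
            float sum1 = 0, sum2 = 0;        // declared inside the loop, so reset every iteration
            int k = 0, j = 0;
            for (int x : arr) {
                if (Math.abs(x - m1) <= Math.abs(x - m2)) {
                    sum1 += x;
                    k++;
                } else {
                    sum2 += x;
                    j++;
                }
            }
            int a = m1, b = m2;
            // keep the old mean if a cluster is empty, which avoids 0/0
            if (k > 0) m1 = Math.round(sum1 / k);
            if (j > 0) m2 = Math.round(sum2 / j);
            // debug print: the loop can only end once both means stop moving
            System.out.println("iteration " + n + ": m1 " + a + " -> " + m1 + ", m2 " + b + " -> " + m2);
            flag = m1 != a || m2 != b;       // keep looping while either mean changed
        } while (flag && n < MAX_ITERATIONS);
        return new int[] {n, m1, m2};
    }

    public static void main(String[] args) {
        int[] result = run(new int[] {2, 4, 10, 12, 3, 20, 30, 11, 25});
        System.out.println("stopped after " + result[0] + " iterations");
    }
}
```

On the question's data it stops after 5 iterations with m1 = 7 and m2 = 25. If the printed means keep drifting instead, the update step is still wrong.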

Ridiculous、
Reply #4 · 2019-04-02 12:51
package k;

/**
 *
 * @author Anooj.k.varghese
 */

import java.io.FileNotFoundException;
import java.io.File;
import java.util.Scanner;
public class K {


    /**
     * @param args the command line arguments
     */
    //GLOBAL VARIABLES
    //data_set[][] -------------the dataset is stored in the data_set[][] array
    //intial_centroid[][]-------the k initial centroids are stored in intial_centroid[][];
    //                          the values are assigned in the 'first_itration()' function
    private static double[][] arr;
    static int num = 0;
    static Double data_set[][]=new Double[20000][100];
    static Double diff[][]=new Double[20000][100];
    static Double intial_centroid[][]=new Double[300][400];
    static Double center_mean[][]=new Double[20000][100];
    static Double total_mean[]=new Double[200000];
    static int cnum;
    static int it=1;
    static int checker=1;
    static int row=4;//number of attribute columns per record (4 for the Iris dataset used here)
     /////////////////////////////////reading the file/////////////////////////////////////
     // readFile(): reads the dataset text file, one record per line
    private static void readFile() throws FileNotFoundException
        {
        Scanner scanner = new Scanner(new File("E:/aa.txt"));//Dataset path
        int lineNo = 0;
            while (scanner.hasNextLine())              // handles both \n and \r\n line endings
             {
                String line = scanner.nextLine().trim();
                if (line.isEmpty())
                    continue;                          // skip blank lines, e.g. at the end of the file
                parseLine(line,lineNo);
                lineNo++;
                System.out.println();
             }
             // System.out.println("total"+num); PRINT THE TOTAL
     scanner.close();
        }
    //parseLine(): copies the values of one comma-separated line into data_set
    public static void parseLine(String line,int lineNo)
      { 
        Scanner lineScanner = new Scanner(line);
        lineScanner.useDelimiter(",");
          for(int col=0;col<row;col++)
              {
                  Double arry=lineScanner.nextDouble();
                  data_set[num][col]=arry;                          // store the parsed value in data_set
               }
         num++;

        }
      public static void first_itration()
    {   double a = 0;
         System.out.println("ENTER CLUSTER NUMBER");
         Scanner sc=new Scanner(System.in);      
         cnum=sc.nextInt();   //enter the number of centroids (k)

         int result[]=new int[cnum];
        double re=0;

         System.out.println("centroid");
         for(int i=0;i<cnum;i++)
         {
            for(int j=0;j<row;j++)
                {
                    intial_centroid[i][j]=data_set[i][j];                  //// the first cnum records become the initial centroids
                    System.out.print(intial_centroid[i][j]);      
                }
            System.out.println();
         }
       System.out.println("------------");

       int counter1=0;
       for(int i=0;i<num;i++)
       {
            for(int j=0;j<row;j++)
                {
                      //System.out.println("hii");
                 System.out.print(data_set[i][j]);

                 }
       counter1++;
       System.out.println();
       }
           System.out.println("total="+counter1);                             //print the total number of data
           //----------------------------------

           ///////////////////EUCLIDEAN DISTANCE////////////////////////////////////
                                                                                /// find the Euclidean Distance
        for(int i=0;i<num;i++)
        {
                for(int j=0;j<cnum;j++)       
                {
                    re=0;
                     for(int k=0;k<row;k++)
                     {
                            a= (intial_centroid[j][k]-data_set[i][k]);
                            //System.out.println(a);
                             a=a*a;
                             re=re+a;                                                 // store the row sum

                        }

                         diff[i][j]= Math.sqrt(re);// take the square root

        }
        }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////FIND THE SMALLEST VALUE////////////////////////////////////////////////
   double aaa;
   double counter;
     int ccc=1;
   for(int i=0;i<num;i++)
   {
         int c=1;
         counter=c;
         aaa=diff[i][0];
         for(int j=0;j<cnum;j++)
         {
          //System.out.println(diff[i][j]);

            if(aaa>=diff[i][j] )                                                //change
                {
                    aaa=diff[i][j];
                    counter=j;


                    // the index j of the closest centroid so far is stored in counter
               //   System.out.println(counter);
               }


         }

            data_set[i][row]=counter;                                        //assign the counter to last position of data set

            //System.out.println("--");
      }                                                                  //print the first iteration
            System.out.println("**FIRST ITRATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row;j++)
                      {
                      //System.out.println("hii");
                              System.out.print(data_set[i][j]+ " ");
                       }
                  System.out.println();
              }

    it++;
    }


    public static void calck_mean()
    { 
        for(int i=0;i<20000;i++)
        {
            for(int j=0;j<100;j++)
            {
                center_mean[i][j]=0.0;
            }
        }


  double c = 0; 
     int a=0;
     int p;
     int abbb = 0;
        if(it%2==0)
         {
             abbb=row;
         }
        else if(it%2==1)
         {
             abbb=row+1;
          }
        for(int k=0;k<cnum;k++)
            {
                    double counter = 0;    
                    for(int i=0;i<num;i++)
                     {
                        for(int j=0;j<=row;j++)
                        {               
                            if(data_set[i][abbb]==a)
                            {
                            System.out.print(data_set[i][j]);
                            center_mean[k][j] += data_set[i][j];

                            }

                          }
                        System.out.println();
                      if(data_set[i][abbb]==a)
                        {
                            counter++;
                        }
                  System.out.println();
              }

         a++;
         total_mean[k]=counter;

         }
         for(int i=0;i<cnum;i++)
            {
            System.out.println("\n");
            for(int j=0;j<row;j++)
            {
              if(total_mean[i]==0)
              {
                   center_mean[i][j]=0.0;
              }
              else
              {
                center_mean[i][j]=center_mean[i][j]/total_mean[i];
              }
              }
        }
        for(int k=0;k<cnum;k++)
        {
            for(int j=0;j<row;j++)
            {
              //System.out.print(center_mean[k][j]);
            }
            System.out.println();

        }
       /* for(int j=0;j<cnum;j++)
        {
            System.out.println(total_mean[j]);
        }*/

    }
public static void kmeans1()
    {
       double  a = 0;
       int result[]=new int[cnum];
       double re=0;

  //// the new centroids (the means in center_mean) are copied into intial_centroid
         System.out.println(" new centroid");
            for(int i=0;i<cnum;i++)
            {
                for(int j=0;j<row;j++)
                {
                    intial_centroid[i][j]=center_mean[i][j];
                    System.out.print(intial_centroid[i][j]);
                }
             System.out.println();
            }

   //----------------------------------------------JUST PRINT THE data_set

           //----------------------------------
        for(int i=0;i<num;i++)
        {
            for(int j=0;j<cnum;j++)
            {
             re=0;
             for(int k=0;k<row;k++)
             {

               a=(intial_centroid[j][k]-data_set[i][k]);
                 //System.out.println(a);
                a=a*a;        
               re=re+a;

                }

             diff[i][j]= Math.sqrt(re);
             //System.out.println(diff[i][j]);
            }
        }
   double aaa;
    double counter;
     for(int i=0;i<num;i++)
     {

         int c=1;
         counter=c;
          aaa=diff[i][0];
         for(int j=0;j<cnum;j++)
         {
            // System.out.println(diff[i][j]);
            if(aaa>=diff[i][j])                                                  //change
            {
               aaa=diff[i][j];
                counter=j;
               //   System.out.println(counter);
            }


         }


         if(it%2==0)
            {
        // abbb=4;
                data_set[i][row+1]=counter;
            }
         else if(it%2==1)
            {
                data_set[i][row]=counter;
      //   abbb=4;
            }


        //System.out.println("--");
     }
     System.out.println(it+" ITRATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row+1;j++)
                  {
                      //System.out.println("hii");
                      System.out.print(data_set[i][j]+" ");
                  }
                  System.out.println();
              }

    it++;
    }
public static void check()
{
    checker=0;
    for(int i=0;i<num;i++)
    {
         //System.out.println("hii");
        if(Double.compare(data_set[i][row],data_set[i][row+1]) != 0)
        {
            checker=1;
            //System.out.println("hii " + i  + " " + data_set[i][4]+ " "+data_set[i][4]);
            break;
        }
        System.out.println();
    }

}
public static void dispaly()
{

      System.out.println(it+" ITRATION**");

      for(int i=0;i<num;i++)
              {
                  for(int j=0;j<=row+1;j++)
                  {
                      //System.out.println("hii");
                      System.out.print(data_set[i][j]+" ");
                  }
                  System.out.println();
              }
}


 public static void print()
    {
        System.out.println();
         System.out.println();
          System.out.println();
        System.out.println("----OUTPUT----");
        int c=0;
        int a=0;
        for(int i=0;i<cnum;i++)
        {
            System.out.println("---------CLUSTER-"+i+"-----");
         a=0;
            for(int j=0;j<num;j++)
            {
                 if(data_set[j][row]==i)
                 {a++;
                for(int k=0;k<row;k++)
                {

                    System.out.print(data_set[j][k]+"  ");
                }
                c++;
                System.out.println();
                }
                 //System.out.println(num);

            }
               System.out.println("CLUSTER INSTANCES="+a);


        }
        System.out.println("TOTAL INSTANCE"+c);
    }


    public static void main(String[] args) throws FileNotFoundException 
    {
    readFile();
    first_itration();

    while(checker!=0)
            {
            calck_mean();
            kmeans1();
            check();
            } 
  dispaly();
  print();
    }




}
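For comparison, here is a compact sketch of the same procedure for multi-dimensional data. It keeps the points in a plain double[][] and the cluster labels in a separate int[], instead of preallocating Double arrays for 20000 rows and alternating the label column between data_set[i][row] and data_set[i][row+1]. Like first_itration(), it seeds the centroids with the first k points, and like check(), it stops once no label changes. It compares squared distances (skipping Math.sqrt does not change which centroid is nearest), breaks ties toward the lower-numbered centroid (the code above, with >=, picks the higher one), and keeps an empty cluster's old centroid. Those choices and the names KMeansND, squaredDistance and cluster are my own assumptions.

```java
import java.util.Arrays;

public class KMeansND {

    /** Squared Euclidean distance between two points of equal dimension. */
    static double squaredDistance(double[] p, double[] q) {
        double s = 0;
        for (int d = 0; d < p.length; d++) {
            double diff = p[d] - q[d];
            s += diff * diff;
        }
        return s;
    }

    /** Returns the cluster label of each point; centroids are seeded with the first k points. */
    static int[] cluster(double[][] points, int k, int maxIterations) {
        int dims = points[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[c].clone();
        }
        int[] labels = new int[points.length];
        Arrays.fill(labels, -1);                   // no point assigned yet
        for (int iter = 0; iter < maxIterations; iter++) {
            // Assignment step: nearest centroid by squared Euclidean distance.
            boolean changed = false;
            for (int i = 0; i < points.length; i++) {
                int best = 0;
                for (int c = 1; c < k; c++) {
                    if (squaredDistance(points[i], centroids[c]) < squaredDistance(points[i], centroids[best])) {
                        best = c;
                    }
                }
                if (labels[i] != best) {
                    labels[i] = best;
                    changed = true;
                }
            }
            if (!changed) {
                break;                             // same test as check(): no label changed
            }
            // Update step: each centroid becomes the mean of its points.
            double[][] sums = new double[k][dims];
            int[] counts = new int[k];
            for (int i = 0; i < points.length; i++) {
                for (int d = 0; d < dims; d++) {
                    sums[labels[i]][d] += points[i][d];
                }
                counts[labels[i]]++;
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) {               // an empty cluster keeps its old centroid
                    for (int d = 0; d < dims; d++) {
                        centroids[c][d] = sums[c][d] / counts[c];
                    }
                }
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        double[][] points = {{1.0, 1.0}, {9.0, 9.0}, {1.5, 2.0}, {8.0, 9.5}, {1.0, 0.5}, {9.5, 8.5}};
        System.out.println(Arrays.toString(cluster(points, 2, 100)));
    }
}
```

For the six 2-D points in main it prints [0, 1, 0, 1, 0, 1]. Reading a file such as the Iris data into the double[][] is left to code like readFile() and parseLine() above.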


别忘想泡老子
Reply #5 · 2019-04-02 12:55

You have a bunch of errors:

  1. At the start of your do loop you should reset sum1 and sum2 to 0.

  2. You should loop only up to k and j respectively when calculating sum1 and sum2 (or clear cluster1 and cluster2 at the start of your do loop).

  3. In the calculation of sum2 you accidentally use sum1.

When I make those fixes the code runs fine, yielding the output:

Final cluster 1 :   
2   4   10   12  3   11  0   0   0

Final cluster 2 :
20  30  25   0   0   0   0   0   0
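The three fixes can be isolated in one helper, sketched here as a hypothetical meanOfFirst method. The sum is a local variable, so it starts from 0 on every call (fix 1); the loop stops at the cluster's element count, so stale entries left over from an earlier iteration are ignored (fix 2); and because it is called once per cluster, cluster 2 can never pick up sum1 by accident (fix 3). Keeping the previous mean for an empty cluster is my own assumption, not something the question's code does.

```java
public class MeanUpdate {

    /**
     * Mean of the first count entries of cluster, rounded like the question's code.
     * Entries at index >= count are leftovers from an earlier iteration and are ignored.
     */
    static int meanOfFirst(int[] cluster, int count, int previousMean) {
        if (count == 0) {
            return previousMean;            // empty cluster: keep the old mean (assumption)
        }
        float sum = 0;                      // local, so it starts at 0 on every call (fix 1)
        for (int i = 0; i < count; i++) {  // stop at count, not cluster.length (fix 2)
            sum += cluster[i];
        }
        return Math.round(sum / count);
    }

    public static void main(String[] args) {
        // cluster1 still holds 10 and 12 from an earlier, larger assignment
        int[] cluster1 = {2, 3, 10, 12, 0, 0, 0, 0, 0};
        int[] cluster2 = {4, 10, 12, 20, 30, 11, 25, 0, 0};
        int m1 = meanOfFirst(cluster1, 2, 2);
        int m2 = meanOfFirst(cluster2, 7, 4);   // each cluster gets its own sum (fix 3)
        System.out.println("m1=" + m1 + " m2=" + m2);
    }
}
```

main prints m1=3 m2=16, which are the means the fixed program computes after its first iteration.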

My general advice: learn how to use a debugger. Stack Overflow is not meant for questions like this: you are expected to find your own bugs and only come here when everything else fails...
