Parallel Matrix Multiplication in Java 6

2019-01-25 19:40发布

Yesterday I asked a question about parallel matrix multiplication in Java 7 using the fork/join framework here. With the help of axtavt I got my example program to work. Now I’m implementing an equivalent program using Java 6 functionality only. I get the same problem as yesterday, dispite applying the the feedback axtavt gave me (I think). Am I overlooking something? Code:

package algorithms;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class Java6MatrixMultiply implements Algorithm {

    private static final int SIZE = 1024;
    private static final int THRESHOLD = 64;
    private static final int MAX_THREADS = Runtime.getRuntime().availableProcessors();

    private final ExecutorService executor = Executors.newFixedThreadPool(MAX_THREADS);

    private float[][] a = new float[SIZE][SIZE];
    private float[][] b = new float[SIZE][SIZE];
    private float[][] c = new float[SIZE][SIZE];

    @Override
    public void initialize() {
        init(a, b, SIZE);
    }

    @Override
    public void execute() {
        MatrixMultiplyTask task =  new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
        task.split();

        executor.shutdown();    
        try {
            executor.awaitTermination(Integer.MAX_VALUE, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            System.out.println("Error: " + e.getMessage());
        }
    }

    @Override
    public void printResult() {
        check(c, SIZE);

        for (int i = 0; i < SIZE && i <= 10; i++) {
            for (int j = 0; j < SIZE && j <= 10; j++) {         
                if(j == 10) {
                    System.out.print("...");
                }
                else {
                    System.out.print(c[i][j] + " ");
                }
            }

            if(i == 10) {
                System.out.println();
                for(int k = 0; k < 10; k++) System.out.print(" ... ");
            }   

            System.out.println();
        }       

        System.out.println();
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                    //System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
                }
            }
        }       
    }   

    public class Seq implements Runnable {

        private final MatrixMultiplyTask a;
        private final MatrixMultiplyTask b;

        public Seq(MatrixMultiplyTask a, MatrixMultiplyTask b, int size) {
            this.a = a;
            this.b = b;

            if (size <= THRESHOLD) {
                executor.submit(this);
            } else {            
                a.split();
                b.split();
            }
        }

        public void run() {
            a.multiplyStride2();
            b.multiplyStride2();
        }   
    }

    private class MatrixMultiplyTask {
        private final float[][] A; // Matrix A
        private final int aRow; // first row of current quadrant of A
        private final int aCol; // first column of current quadrant of A

        private final float[][] B; // Similarly for B
        private final int bRow;
        private final int bCol;

        private final float[][] C; // Similarly for result matrix C
        private final int cRow;
        private final int cCol;

        private final int size;

        MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
                int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {

            this.A = A;
            this.aRow = aRow;
            this.aCol = aCol;
            this.B = B;
            this.bRow = bRow;
            this.bCol = bCol;
            this.C = C;
            this.cRow = cRow;
            this.cCol = cCol;
            this.size = size;
        }   

        public void split() {
            int h = size / 2;

            new Seq(new MatrixMultiplyTask(A,
                    aRow, aCol, // A11
                    B, bRow, bCol, // B11
                    C, cRow, cCol, // C11
                    h),

            new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                    B, bRow + h, bCol, // B21
                    C, cRow, cCol, // C11
                    h), h);

            new Seq(new MatrixMultiplyTask(A,
                    aRow, aCol, // A11
                    B, bRow, bCol + h, // B12
                    C, cRow, cCol + h, // C12
                    h),

            new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                    B, bRow + h, bCol + h, // B22
                    C, cRow, cCol + h, // C12
                    h), h);

            new Seq(new MatrixMultiplyTask(A, aRow
                    + h, aCol, // A21
                    B, bRow, bCol, // B11
                    C, cRow + h, cCol, // C21
                    h),

            new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                    B, bRow + h, bCol, // B21
                    C, cRow + h, cCol, // C21
                    h), h);

            new Seq(new MatrixMultiplyTask(A, aRow
                    + h, aCol, // A21
                    B, bRow, bCol + h, // B12
                    C, cRow + h, cCol + h, // C22
                    h),

            new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                    B, bRow + h, bCol + h, // B22
                    C, cRow + h, cCol + h, // C22
                    h), h);
        }

        public void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow + i];
                    float[] a1 = A[aRow + i + 1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow + k];

                        s00 += a0[aCol + k] * b0[bCol + j];
                        s10 += a1[aCol + k] * b0[bCol + j];
                        s01 += a0[aCol + k] * b0[bCol + j + 1];
                        s11 += a1[aCol + k] * b0[bCol + j + 1];

                        float[] b1 = B[bRow + k + 1];

                        s00 += a0[aCol + k + 1] * b1[bCol + j];
                        s10 += a1[aCol + k + 1] * b1[bCol + j];
                        s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
                        s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
                    }

                    C[cRow + i][cCol + j] += s00;
                    C[cRow + i][cCol + j + 1] += s01;
                    C[cRow + i + 1][cCol + j] += s10;
                    C[cRow + i + 1][cCol + j + 1] += s11;
                }
            }           
        }
    }
}

2条回答
Luminary・发光体
2楼-- · 2019-01-25 20:03

After reading this this question I decided to adapt my program. My new program works very well without synchronization. Thanks for your thoughts Peter.

New code:

package algorithms;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;

public class Java6MatrixMultiply implements Algorithm {

    private static final int SIZE = 2048;
    private static final int THRESHOLD = 64;
    private static final int MAX_THREADS = Runtime.getRuntime().availableProcessors();

    private final ExecutorService executor = Executors.newFixedThreadPool(MAX_THREADS);

    private float[][] a = new float[SIZE][SIZE];
    private float[][] b = new float[SIZE][SIZE];
    private float[][] c = new float[SIZE][SIZE];

    @Override
    public void initialize() {
        init(a, b, SIZE);
    }

    @Override
    public void execute() {
        MatrixMultiplyTask mainTask =  new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
        Future future = executor.submit(mainTask);  

        try {
            future.get();
        } catch (Exception e) {
            System.out.println("Error: " + e.getMessage());
        }
    }

    @Override
    public void printResult() {
        check(c, SIZE);

        for (int i = 0; i < SIZE && i <= 10; i++) {
            for (int j = 0; j < SIZE && j <= 10; j++) {         
                if(j == 10) {
                    System.out.print("...");
                }
                else {
                    System.out.print(c[i][j] + " ");
                }
            }

            if(i == 10) {
                System.out.println();
                for(int k = 0; k < 10; k++) System.out.print(" ... ");
            }   

            System.out.println();
        }       

        System.out.println();
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                    //System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
                }
            }
        }       
    }   

    public class Seq implements Runnable {

        private final MatrixMultiplyTask a;
        private final MatrixMultiplyTask b;

        public Seq(MatrixMultiplyTask a, MatrixMultiplyTask b) {
            this.a = a;
            this.b = b;     
        }

        public void run() {
            a.run();
            b.run();
        }   
    }

    private class MatrixMultiplyTask implements Runnable {
        private final float[][] A; // Matrix A
        private final int aRow; // first row of current quadrant of A
        private final int aCol; // first column of current quadrant of A

        private final float[][] B; // Similarly for B
        private final int bRow;
        private final int bCol;

        private final float[][] C; // Similarly for result matrix C
        private final int cRow;
        private final int cCol;

        private final int size;

        public MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
                int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {

            this.A = A;
            this.aRow = aRow;
            this.aCol = aCol;
            this.B = B;
            this.bRow = bRow;
            this.bCol = bCol;
            this.C = C;
            this.cRow = cRow;
            this.cCol = cCol;
            this.size = size;
        }   

        public void run() {

            //System.out.println("Thread: " + Thread.currentThread().getName());

            if (size <= THRESHOLD) {
                multiplyStride2();
            } else {

                int h = size / 2;

                        Seq seq1 = new Seq(new MatrixMultiplyTask(A,
                                aRow, aCol, // A11
                                B, bRow, bCol, // B11
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol, // B21
                                C, cRow, cCol, // C11
                                h));

                        Seq seq2 = new Seq(new MatrixMultiplyTask(A,
                                aRow, aCol, // A11
                                B, bRow, bCol + h, // B12
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol + h, // B22
                                C, cRow, cCol + h, // C12
                                h));

                        Seq seq3 = new Seq(new MatrixMultiplyTask(A, aRow
                                + h, aCol, // A21
                                B, bRow, bCol, // B11
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol, // B21
                                C, cRow + h, cCol, // C21
                                h));

                        Seq seq4 = new Seq(new MatrixMultiplyTask(A, aRow
                                + h, aCol, // A21
                                B, bRow, bCol + h, // B12
                                C, cRow + h, cCol + h, // C22
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol + h, // B22
                                C, cRow + h, cCol + h, // C22
                                h));            



                final FutureTask s1Task = new FutureTask(seq2, null);
                final FutureTask s2Task = new FutureTask(seq3, null);
                final FutureTask s3Task = new FutureTask(seq4, null);

                executor.execute(s1Task);
                executor.execute(s2Task);
                executor.execute(s3Task);

                seq1.run();
                s1Task.run();
                s2Task.run();
                s3Task.run();

                try {
                    s1Task.get();
                    s2Task.get();
                    s3Task.get();
                } catch (Exception e) {
                    System.out.println("Error: " + e.getMessage());
                    executor.shutdownNow();
                }       
            }       
        }       

        public void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow + i];
                    float[] a1 = A[aRow + i + 1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow + k];

                        s00 += a0[aCol + k] * b0[bCol + j];
                        s10 += a1[aCol + k] * b0[bCol + j];
                        s01 += a0[aCol + k] * b0[bCol + j + 1];
                        s11 += a1[aCol + k] * b0[bCol + j + 1];

                        float[] b1 = B[bRow + k + 1];

                        s00 += a0[aCol + k + 1] * b1[bCol + j];
                        s10 += a1[aCol + k + 1] * b1[bCol + j];
                        s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
                        s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
                    }

                    C[cRow + i][cCol + j] += s00;
                    C[cRow + i][cCol + j + 1] += s01;
                    C[cRow + i + 1][cCol + j] += s10;
                    C[cRow + i + 1][cCol + j + 1] += s11;
                }
            }           
        }
    }
}
查看更多
做个烂人
3楼-- · 2019-01-25 20:15

I tried adding synchronized as I suggested and this fixed the problem. ;)

I tried

  • synchronizing each row 299 ms.
  • swapping the loops in mutliplyStride so that it goes by column instead of by row. 253 ms
  • assumed one lock for each pair of rows (i.e. I locked one row for both updates. 216 ms
  • Disable biased locking -XX:-UseBiasedLocking 207 ms
  • use 2x the number of processors for threads. 199 ms.
  • same except using double instead of float 237 ms.
  • no synchronization at all. 174 ms.

As you can see the fifth option is less than 10% slower than no synchronization. If you want further gains I suggest you alter the way the data is accessed to make them more cache friendly.

In summary I suggest

private final ExecutorService executor = Executors.newFixedThreadPool(MAX_THREADS*2);

public void multiplyStride2() {
    for (int i = 0; i < size; i += 2) {
        for (int j = 0; j < size; j += 2) {

        // code as is......

            synchronized (C[cRow + i]) {
                C[cRow + i][cCol + j] += s00;
                C[cRow + i][cCol + j + 1] += s01;

                C[cRow + i + 1][cCol + j] += s10;
                C[cRow + i + 1][cCol + j + 1] += s11;
            }

Interestingly, if I calculate a block of 2x4 instaed of 2x2 the average times drops to 172 ms. (faster than the previous result with no synchronization) ;)

查看更多
登录 后发表回答