Java并发编程——CyclicBarrier原理

PunkLu 2020年01月16日 63次浏览
回环屏障CyclicBarrier原理探究

回环屏障CyclicBarrier

CountDownLatch在解决多个线程同步方面相对于调用线程的join方法已经有了不少优化,但是CountDownLatch是一次性的,也就是等到计数器值变为0后,再调用CountDownLatch的await方法和countdown方法都会立刻返回,这就起不到线程同步的效果了。所以为了满足计数器可以重置的需要,JDK开发组提供了CyclicBarrier类,它可以让一组线程全部达到一个状态后再全部同时执行。

案例

在下面的例子中,使用两个线程去执行一个被分解的任务A,当两个线程把自己的任务都执行完毕后再对他们的结果进行汇总处理。

public class CycleBarrierTest1 {

    // 创建一个CyclicBarrier实例,添加一个所有子线程全部到达屏障后执行的任务
    private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2, new Runnable() {
        public void run() {
            System.out.println(Thread.currentThread() + " task1 merge result");
        }
    });

    public static void main(String[] args) throws InterruptedException{
        // 创建一个线程个数固定为2的线程池
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 将线程A添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + " task1-1");
                    System.out.println(Thread.currentThread() + " enter in barrier");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + " enter out barrier");
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        });

        // 将线程B添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + " task1-2");
                    System.out.println(Thread.currentThread() + "enter in barrier");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + " enter out barrier");
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        });

        // 关闭线程池
        executorService.shutdown();
    }
}

如上代码创建了一个CyclicBarrier对象,其第一个参数为计数器初始值,第二个参数Runable是当计数器值为0时需要执行的任务。在main函数里面首先创建了一个大小为2的线程池,然后添加两个子任务到线程池,每个子任务在执行完自己的逻辑后会调用await()方法。一开始计数器值为2,当第一个线程调用await方法时,计数器值会递减为1。由于此时计数器值不为0,所以当前线程就到了屏障点而被阻塞。然后第二个线程调用await时,会进入屏障,计数器值也会递减,现在计数器值为0,这时就会去执行CyclicBarrier构造函数中的任务,执行完毕后退出屏障点,并且唤醒被阻塞的第二个线程,这时候第一个线程也会退出屏障点继续向下运行。

对于上面的例子,使用CountDownLatch也可以得到类似的输出结果。下面的例子可说明CyclicBarrier的可复用性:

public class CyclicBarrierTest2 {
    // 创建一个CyclicBarrier实例
    private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2);

    public static void main(String[] args) throws InterruptedException{
        ExecutorService executorService = Executors.newFixedThreadPool(2);

        // 将线程A添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + " step1");
                    cyclicBarrier.await();

                    System.out.println(Thread.currentThread() + " step2");
                    cyclicBarrier.await();

                    System.out.println(Thread.currentThread() + " step3" +" "+ cyclicBarrier.getParties());
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        });

        // 将线程B添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + " step1");
                    cyclicBarrier.await();

                    System.out.println(Thread.currentThread() + " step2");
                    cyclicBarrier.await();

                    System.out.println(Thread.currentThread() + " step3" +" "+ cyclicBarrier.getParties());
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        });

        // 关闭线程池
        executorService.shutdown();
    }
}

运行结果:

Thread[pool-1-thread-1,5,main] step1
Thread[pool-1-thread-2,5,main] step1
Thread[pool-1-thread-2,5,main] step2
Thread[pool-1-thread-1,5,main] step2
Thread[pool-1-thread-1,5,main] step3 2
Thread[pool-1-thread-2,5,main] step3 2

在如上代码中,每个子线程在执行完阶段1后都调用了await方法,等到所有线程都到达屏障点后才会一起往下执行,这就保证了所有线程都完成了阶段1后才会开始执行阶段2。然后在阶段2后面调用了await方法,这保证了所有线程都完成了阶段2后,才能开始阶段3的执行,这个功能使用单个CountDownLatch是无法完成的。

实现原理探究

CyclicBarrier基于独占锁实现,本质底层还是基于AQS的。parties用来记录线程个数,这里表示多少线程调用await后,所有线程才会冲破屏障继续往下运行。而count一开始等于parties,每当有线程调用await方法就递减1,当count为0时就表示所有线程都到达了屏障点。使用两个变量的原因是为了确保复用。这两个变量是在构造CyclicBarrier对象时传递的,如下:

public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties;
    this.count = parties;
    this.barrierCommand = barrierAction;
}

还有一个变量barrierCommand也通过构造函数传递,这是一个任务,这个任务的执行时机是当所有线程都到达屏障点后。使用lock首先保证了更新计数器count的原子性。另外使用lock的条件变量trip支持线程间使用await和signal操作进行同步。

在变量generation内部有一个变量broken,其用来记录当前屏障是否被打破。

private static class Generation{
	boolean broken = false;
}

1、int await()方法

当前线程调用CyclicBarrier的该方法时会被阻塞,直到满足下面条件之一才会返回:

  1. parties个线程都调用了await()方法,也就是线程都到达了屏障点;
  2. 其他线程调用了当前线程的interrupt()方法中断了当前线程,则当前线程会抛出InterruptedException异常而返回;
  3. 与当前屏障点关联的Generation对象的broken标志被设置为true时,会抛出BrokenBarrierException异常,然后返回。

由如下代码可知,在内部调用了dowait方法。第一个参数为false则说明不设置超时时间,这时候第二个参数没有意义。

public int await() throws InterruptedException,BrokenBarrierException{
    try{
        return dowait(false,0L);
    }catch(TimeoutException toe){
        throw new Error(toe);
    }
}

2、boolean await(long timeout,TimeUnit unit)方法

当前线程调用CyclicBarrier的该方法时会被阻塞,直到满足下面条件之一才会返回:

  1. parties个线程都调用了await()方法,也就是线程都到了屏障点,这时候返回true;
  2. 设置的超时时间到了后返回false;
  3. 其他线程调用当前线程的interrupt()方法中断了当前线程,则当前线程会抛出InterruptException异常然后返回;
  4. 与当前屏障点关联的Generation对象的broken标志被设置为true时,会抛出BrokenBarrierException异常,然后返回。

由以下代码可知,在内部调用了dowait方法。第一个参数为true则说明设置了超时时间,这时候第二个参数是超时时间。

public int await(long timeout,TimeUnit unit) throws InterruptedException,BrokenBarrierException,TimeoutException{
    return dowait(true,unit.toNanos(timeout));
}

3、int dowait(boolean timed,long nanos)方法

该方法实现了CyclicBarrier的核心功能,其代码如下:

private int dowait(boolean timed,long nanos) throws InterruptedException,BrokenBarrierException,TimeoutException{
    final ReentrantLock lock = this.lock;
    lock.lock();
    try{
        ...
        // 1、如果index==0则说明所有线程都到了屏障点,此时执行初始化时传递的任务
        int index = --count;
        if(index == 0){
            boolean ranAction = false;
            try{
                final Runnable command = barrierCommand;
                // 2、执行任务
                if(command != null)
                    command.run();
                ranActive = true;
                // 3、激活其他因调用await方法而被阻塞的线程,并重置CyclicBararrier
                nextGeneration();
                //  返回
                return 0;
            }finally{
                if(!ranAction)
                    breakBarrier();
            }
        }
        
        // 4、如果index!=0
        for(;;){
            try{
                // 5、没有设置超时时间
                if(!timed)
                    trip.await();
                // 6、设置了超时时间
                else if(nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            }catch(InterruptedException ie){
                ...
            }
            ...
        }
    }finally{
        lock.unlock();
    }
}

private void nextGeneration(){
    // 7、唤醒条件队列里面的阻塞线程
    trip.signalAll();
    count = parties;
    generation = new Generation();
}

以上是dowait方法的主干代码。当一个线程调用了dowait方法后,首先会获取独占锁lock,如果创建CyclicBarrier时传递的参数为10,那么后面9个调用线程会被阻塞。然后当前获取到锁的线程会对计数器count进行递减操作,递减后count = index=9,因为index!=0所以当前代码会执行代码4。如果当前线程调用的是无参数的await()方法,则这里timed=false,所以当前线程会被放入条件变量trip的条件阻塞队列,当前线程会被挂起并释放获取的lock锁。如果调用的是有参数的await方法则timed=true,然后当前线程也会被放入条件变量的条件队列并释放锁资源,不同的是当前线程会在指定时间超时后自动被激活。

当第一个获取锁的线程由于被阻塞释放锁后,被阻塞的9个线程中有一个会竞争到lock锁,然后执行与第一个线程一样的操作,直到最后一个线程获取到lock锁,此时已经有9个线程被放入了条件变量trip的条件队列里边。最后count=index等于0,所以执行代码2,如果创建CyclicBarrier时传递了任务,则在其他线程被唤醒前先执行这个任务,任务执行完毕后再执行代码3,唤醒其他9个线程,并重置CyclicBarrier,然后这10个线程就可以继续向下执行了。