Java并发编程——CountDownLatch原理剖析

PunkLu 2020年01月15日 116次浏览
CountDownLatch原理剖析

CountDownLatch原理剖析

案例

在开发中经常会遇到需要在主线程中开启多个线程去并行执行任务,并且主线程需要等待所有子线程执行完毕后再进行汇总的场景。在CountDownLatch出现之前一般都使用线程的join()方法来实现,但是join方法不够灵活,不能满足不同场景的需要。

public class CountDownLatchDemo {

    // 创建一个CountDownLatch实例
    private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

    private static ReentrantLock reentrantLock = new ReentrantLock();

    public static void main(String[] args) throws InterruptedException{

        Thread threadOne = new Thread(new Runnable() {
            public void run() {
                try{
                    // reentrantLock.lock();
                    Thread.sleep(1000);
                }catch (InterruptedException e){
                    e.printStackTrace();
                }finally {
                    //System.out.println("One1:" + countDownLatch.getCount());
                    countDownLatch.countDown();
                    //System.out.println("One2:" + countDownLatch.getCount());
                    //reentrantLock.unlock();
                }

                System.out.println("child threadOne over!");
            }
        });

        Thread threadTwo = new Thread(new Runnable() {
            public void run() {
                try {
                    // reentrantLock.lock();
                    Thread.sleep(1000);
                }catch (InterruptedException e){
                    e.printStackTrace();
                }finally {
                    //System.out.println("Two1:" + countDownLatch.getCount());
                    countDownLatch.countDown();
                    //System.out.println("Two2:" + countDownLatch.getCount());
                    // reentrantLock.unlock();
                }
                System.out.println("child threadTwo over!");
            }
        });

        // 启动子线程
        threadOne.start();
        threadTwo.start();

        System.out.println("wait all child thread over!");

        // 等待子线程执行完毕,返回
        countDownLatch.await();

        System.out.println("all child thread over!");
    }
}

运行结果:

wait all child thread over!
child threadOne over!
all child thread over!
child threadTwo over!

在如上代码中,创建了一个CountDownLatch实例,因为有两个子线程所以构造函数的传参为2。主线程调用countDownLatch.await()方法后会被阻塞。子线程执行完毕后调用countDownLatch.countDown()方法让countDownLatch内部的计数器减1,所有线程执行完毕并调用countDown()方法后计数器会变成0,这时候主线程的await()方法才会返回。

上面的代码还不够优雅,在项目实践中一般都避免直接操作线程,而是使用ExecutorService线程池来管理。使用ExecutorService时传递的参数是Runable或者Callable对象,这时候没法直接调用这些线程的join()方法,这就需要使用CountDownLatch了,将上面的代码修改如下:

public class CountDownLatchDemo2 {

    // 创建一个CountDownLatch实例
    private static CountDownLatch countDownLatch = new CountDownLatch(2);

    public static void main(String[] args) throws InterruptedException{
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 将线程A添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try {
                    Thread.sleep(1000);
                }catch (InterruptedException e){
                    e.printStackTrace();
                }finally {
                    countDownLatch.countDown();
                }
                System.out.println("child threadOne over!");
            }
        });

        // 将线程B添加到线程池
        executorService.submit(new Runnable() {
            public void run() {
                try {
                    Thread.sleep(1000);
                }catch (InterruptedException e){
                    e.printStackTrace();
                }finally {
                    countDownLatch.countDown();
                }
                System.out.println("child threadTwo over!");
            }
        });

        System.out.println("wait all child thread over!");
        // 等待子线程执行完毕,返回
        countDownLatch.await();
        System.out.println("all child thread over!");
        executorService.shutdown();
    }
}

实现原理

CountDownLatch是使用AQS实现的。通过下面的构造函数把计数器的值赋给了AQS的状态变量state,也就是这里使用AQS的状态值来表示计数器值。

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

Sync(int count) {
    setState(count);
}

1、void await()方法

当线程调用CountDownLatch对象的await方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:

  1. 当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计数器的值为0时
  2. 其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程抛出异常,然后返回。
public void await() throws InterruptedException {
     sync.acquireSharedInterruptibly(1);
}

从以上代码可以看到,await()方法委托sync调用了AQS的acquireSharedInterruptibly方法,后者的代码如下:

// AQS获取共享资源时可被中断的方法
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // 如果线程被中断则抛出异常
    if (Thread.interrupted())
        throw new InterruptedException();
    // 查看当前计数器值是否为0,为0则直接返回,否则进入AQS的队列等待
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

// sync类实现的AQS的接口
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

由如上代码可知,该方法的特点是线程获取资源时可以被中断,并且获取的资源是共享资源。acquireSharedInterruptibly首先判断当前线程是否已被中断,若是则抛出异常,否则调用sync实现的tryAcquireShared方法查看当前状态值(计数器值)是否为0,是则当前线程的await()方法直接返回,否则调用AQS的doAcquireSharedInterruptibly方法让当前线程阻塞。

2、boolean await(long timeout,TimeUnit unit)方法

当线程调用了CountDownLatch对象的该方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计数器值为0时,这时候会返回true;设置的timeout时间到了,因为超时而返回false;其他线程调用了当前线程的interrupt()方法中断了当前线程,当前线程会抛出InterruptedException异常,然后返回。

public boolean await(long timeout,TimeUnit unit) throws InterruptedException{
    return sync.tryAcquireSharedNanos(1,unit.toNanos(timeout));
}

3、void countDown() 方法

线程调用该方法后,计数器的值递减,递减后如果计数器值为0则唤醒所有因调用await方法而被阻塞的线程,否则什么都不做。下面看下countDown方法是如何让调用AQS的方法的。

// countDownLatch的countDown()方法
public void countDown(){
    // 委托sync调用AQS的方法
    sync.releaseShared(1);
}

由以上代码可知,CountDownLatch的countDown方法委托了sync调用了AQS的releaseShared方法,后者的代码如下:

// AQS的代码
public final boolean releaseShared(int arg){
    // 调用sync实现的tryReleaseShared
    if(tryReleaseShared(arg)){
        // AQS的释放资源方法
        doReleaseShared();
        return true;
    }
    return false;
}

在如上代码中,releaseShared首先调用了sync实现的AQS的tryReleaseShared方法,其代码如下:

// sync的方法
protected boolean tryReleaseShared(int releases){
    // 循环进行CAS,直到当前线程成功完成CAS使计数器值(状态值state)减1并更新到state
    for(;;){
        int c = getState();
        // 1、如果当前状态值为0则直接返回
        if(c == 0)
            return false;
        // 2、使用CAS让计数器值减1
        int nextc = c-1;
        if(compareAndSetState(c,nextc))
            return nextc == 0;
    }
}

如上代码首先获取当前状态值(计数器值)。代码1判断如果当前状态值为0则直接返回false,从而countDown()方法直接返回;否则执行代码2使用CAS将计数器值减1,CAS失败则循环重试,否则如果当前计数器值为0则返回true,返回true说明是最后一个线程调用的countDown方法,那么该线程除了让计数器值减1外,还需要唤醒因调用CountDownLatch的await方法而被阻塞的线程。这里代码1貌似是多余的,其实不然,之所以添加代码1是为了防止当计数器值为0后,其他线程又调用了countDown方法,如果没有代码1,状态值就可能会变为负数。

4、long getCount()方法

获取当前计数器的值,也就是AQS的state的值,一般在测试时使用该方法。下面看代码:

public long getCount(){
    return sync.getCount();
}

int getCount(){
    return getState();
}

由如上代码可知,在其内部还是调用了AQS的getState方法来获取state的值(计数器当前值)。