21 - CountDownLatch 让线程等待其他线程完成

    技术2025-05-06  14

    CountDownLatch 让线程等待其他线程完成

    1. CountDownLatch 使用1.1 使用背景1.2 利用并行优化1.3 用 CountDownLatch 实现线程等待 2. 源码分析2.1 类结构2.2 await2.3 countDown 3. 总结

      并发编程中常遇到这种情况,一个线程需要等待另外多个线程执行后再执行。遇到这种情况你一般怎么做呢?今天就介绍一种 JDk 提供的解决方案来优雅的解决这一问题,那就是倒计时器 CountDownLatch。

      

    1. CountDownLatch 使用

    1.1 使用背景

      公司里对账系统最近越来越慢了,老板要求能不能快速优化一下。我了解了对账系统的业务后,发现还是挺简单的,用户通过在线商城下单,会生成电子订单,保存在订单库;之后物流会生成派送单给用户发货,派送单保存在派送单库。为了防止漏派送或者重复派送,对账系统每天还会校验是否存在异常订单。

      对账系统的处理逻辑很简单,你可以参考下面的对账系统流程图。目前对账系统的处理逻辑是首先查询订单,然后查询派送单,之后对比订单和派送单,将差异写入差异库。

      对账系统的代码抽象之后,也很简单,核心代码如下,就是在一个单线程里面循环查询订单、派送单,然后执行对账,最后将写入差异库。

    while(存在未对账订单){ // 查询未对账订单 pos = getPOrders(); // 查询派送单 dos = getDOrders(); // 执行对账操作 diff = check(pos, dos); // 差异写入差异库 save(diff); }

      

    1.2 利用并行优化

      老板要我优化性能,那我就首先要找到这个对账系统的瓶颈所在。

      目前的对账系统,由于订单量和派送单量巨大,所以查询未对账订单 getPOrders() 和查询派送单 getDOrders() 相对较慢,那有没有办法快速优化一下呢?目前对账系统是单线程执行的。对于串行化的系统,优化性能首先想到的是能否利用多线程并行处理。

      所以,这里你应该能够看出来这个对账系统里的瓶颈:查询未对账订单 getPOrders() 和查询派送单 getDOrders() 是否可以并行处理呢?显然是可以的,因为这两个操作并没有先后顺序的依赖。这两个最耗时的操作并行之后,执行过程如下图所示。对比一下单线程的执行示意图,你会发现同等时间里,并行执行的吞吐量近乎单线程的 2 倍,优化效果还是相对明显的。

         思路有了,下面我们再来看看如何用代码实现。在下面的代码中,我们创建了两个线程 T1 和 T2,并行执行查询未对账订单 getPOrders() 和查询派送单 getDOrders() 这两个操作。在主线程中执行对账操作 check() 和差异写入 save() 两个操作。不过需要注意的是:主线程需要等待线程 T1 和 T2 执行完才能执行 check() 和 save() 这两个操作,为此我们通过调用 T1.join() 和 T2.join() 来实现等待,当 T1 和 T2 线程退出时,调用 T1.join() 和 T2.join() 的主线程就会从阻塞态被唤醒,从而执行之后的 check() 和 save()。   

    while (存在未对账订单) { // 查询未对账订单 Thread T1 = new Thread(() -> { pos = getPOrders(); }); T1.start(); // 查询派送单 Thread T2 = new Thread(() -> { dos = getDOrders(); }); T2.start(); // 等待T1、T2结束 T1.join(); T2.join(); // 执行对账操作 diff = check(pos, dos); // 差异写入差异库 save(diff); }

      

    1.3 用 CountDownLatch 实现线程等待

      经过上面的优化之后,基本上可以跟老板汇报收工了,但还是有点美中不足,相信你也发现了,while 循环里面每次都会创建新的线程,而创建线程可是个耗时的操作。所以最好是创建出来的线程能够循环利用,估计这时你已经想到线程池了,是的,线程池就能解决这个问题。   

      而下面的代码就是用线程池优化后的:我们首先创建了一个固定大小为 2 的线程池,之后在 while 循环里重复利用。一切看上去都很顺利,但是有个问题好像无解了,那就是主线程如何知道 getPOrders() 和 getDOrders() 这两个操作什么时候执行完。前面主线程通过调用线程 T1 和 T2 的 join() 方法来等待线程 T1 和 T2 退出,但是在线程池的方案里,线程根本就不会退出,所以 join() 方法已经失效了。   

      那如何解决这个问题呢?你可以开动脑筋想出很多办法,最直接的办法是弄一个计数器,初始值设置成 2,当执行完pos = getPOrders();这个操作之后将计数器减 1,执行完dos = getDOrders();之后也将计数器减 1,在主线程里,等待计数器等于 0;当计数器等于 0 时,说明这两个查询操作执行完了。等待计数器等于 0 其实就是一个条件变量,用管程实现起来也很简单。   

      不过我并不建议你在实际项目中去实现上面的方案,因为 Java 并发包里已经提供了实现类似功能的工具类:CountDownLatch,我们直接使用就可以了。下面的代码示例中,在 while 循环里面,我们首先创建了一个 CountDownLatch,计数器的初始值等于 2,之后在pos = getPOrders();和dos = getDOrders();两条语句的后面对计数器执行减 1 操作,这个对计数器减 1 的操作是通过调用 latch.countDown(); 来实现的。在主线程中,我们通过调用 latch.await() 来实现对计数器等于 0 的等待。

    // 创建2个线程的线程池 Executor executor = Executors.newFixedThreadPool(2); while (存在未对账订单) { // 计数器初始化为2 CountDownLatch latch = new CountDownLatch(2); // 查询未对账订单 executor.execute(() -> { pos = getPOrders(); latch.countDown(); }); // 查询派送单 executor.execute(() -> { dos = getDOrders(); latch.countDown(); }); // 等待两个查询操作结束 latch.await(); // 执行对账操作 diff = check(pos, dos); // 差异写入差异库 save(diff); }

      经过上面的重重优化之后,长出一口气,终于可以交付了。

      CountDownLatch 的作用是让线程等待其它线程完成一组操作后才能执行,否则就一直等待。总结 CountDownLatch 的使用步骤(比如线程A需要等待线程B和线程C执行后再执行):

    创建CountDownLatch对象,设置要等待的线程数N(这里是2);等待线程A 调用 await() 挂起;线程B执行后调用 countDown(),使 N-1;线程C 执行后调用 countDown(),使N-1;调用 countDown() 后检查 N=0 了,唤醒线程A,在 await() 挂起的位置继续执行。

      

    2. 源码分析

      CountDownLatch 是通过一个计数器来实现的,当我们在 new 一个CountDownLatch 对象的时候需要带入该计数器值,该值就表示了线程的数量。每当一个线程完成自己的任务后,计数器的值就会减1。当计数器的值变为0 时,就表示所有的线程均已经完成了任务,然后就可以恢复等待的线程继续执行了。   

    2.1 类结构

      CountDownLatch 只有一个属性 Sync,Sync 是继承了 AQS 的内部类。创建 CountDownLatch 时传入一个 count 值,count 值被赋值给AQS.state。

      CountDownLatch 是通过 AQS 共享锁实现的,AQS 这篇文章中详细讲解了 AQS 独占锁的原理,AQS 共享锁和独占锁原理只有很细微的区别,这里大致介绍下:

    线程调用 acquireSharedInterruptibly() 方法获取不到锁时,线程被构造成结点进入AQS阻塞队列;当有线程调用 releaseShared() 方法将当前线程持有的锁彻底释放后,会唤醒AQS 阻塞队列中等锁的线程,如果 AQS 阻塞队列中有连续N 个等待共享锁的线程,就将这N 个线程依次唤醒。 public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } } private final Sync sync; public CountDownLatch(int count) { if (count < 0) throw new IllegalArgumentException("count < 0"); this.sync = new Sync(count); } }

      

    2.2 await

      await() 是将当前线程阻塞,理解 await() 的原理就是要弄清楚 await()是如何将线程阻塞的。

      await() 调用的就是 AQS 获取共享锁的方法。当 AQS.state=0 时才能获取到锁,由于创建 CountDownLatch 时设置了state=count,此时是获取不到锁的,所以调用 await() 的线程挂起并构造成结点进入 AQS 阻塞队列。

    创建 CountDownLatch 时设置 AQS.state=count,可以理解成锁被重入了count 次。await() 方法获取锁时锁被占用了,只能阻塞。

    /** * CountDownLatch.await()调用的就是AQS获取共享锁的方法acquireSharedInterruptibly() */ public void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); } /** * 获取共享锁 * 如果获取锁失败,就将当前线程挂起,并将当前线程构造成结点加入阻塞队列 * 判断是否获取锁成功的方法由CountDownLatch的内部类Sync实现 */ public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); // 尝试获取锁的方法由CountDownLatch的内部类Sync实现 if (tryAcquireShared(arg) < 0) // 获取锁失败,就将当前线程挂起,并将当前线程构造成结点加入阻塞队列 doAcquireSharedInterruptibly(arg); } /** * CountDownLatch.Sync实现AQS获取锁的方法 * 只有AQS.state=0时获取锁成功。 * 创建CountDownLatch时设置了state=count,调用await()时state不为0, * 返回-1,表示获取锁失败。 */ protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; }

      

    2.3 countDown

      countDown() 方法是将 count-1,如果发现 count=0 了,就唤醒阻塞的线程。countDown() 调用 AQS 释放锁的方法,每次将 state 减1。当 state 减到0时是无锁状态了,就依次唤醒 AQS 队列中阻塞的线程来获取锁,继续执行逻辑代码。

    /** * CountDownLatch.await()调用的就是AQS释放共享锁的方法releaseShared() */ public void countDown() { sync.releaseShared(1); } /** * 释放锁 * 如果锁被全部释放了,依次唤醒AQS队列中等待共享锁的线程 * 锁全部释放指的是同一个线程重入了N次需要N次解锁,最终将state变回0 * 具体释放锁的方法由CountDownLatch的内部类Sync实现 */ public final boolean releaseShared(int arg) { if (tryReleaseShared(arg)) { // 释放锁,由CountDownLatch的内部类Sync实现 doReleaseShared(); // 锁全部释放之后,依次唤醒等待共享锁的线程 return true; } return false; } /** * CountDownLatch.Sync实现AQS释放锁的方法 * 释放一次,将state减1 * 如果释放之后state=0,表示当前是无锁状态了,返回true */ protected boolean tryReleaseShared(int releases) { for (;;) { int c = getState(); if (c == 0) return false; // state每次减1 int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0;// state=0时,无锁状态,返回true } }

      

    3. 总结

      CountDownLatch 用于一个线程A 需要等待另外多个线程(B、C)执行后再执行的情况。

      创建 CountDownLatch 时设置一个计数器 count,表示要等待的线程数量。线程A 调用 await() 方法后将被阻塞,线程B 和线程C 调用 countDown() 之后计数器 count 减1。当计数器的值变为0 时,就表示所有的线程均已经完成了任务,然后就可以恢复等待的线程A 继续执行了。

      CountDownLatch 是由 AQS 实现的,创建 CountDownLatch 时设置计数器 count 其实就是设置 AQS.state=count,也就是重入次数。await() 方法调用获取锁的方法,由于 AQS.state=count 表示锁被占用且重入次数为 count,所以获取不到锁线程被阻塞并进入 AQS 队列。countDown() 方法调用释放锁的方法,每释放一次 AQS.state 减1,当 AQS.state 变为0 时表示处于无锁状态了,就依次唤醒AQS 队列中阻塞的线程来获取锁,继续执行逻辑代码。

    Processed: 0.014, SQL: 9