多线程-工具类-CountDownLatch
文章目录
- 多线程-工具类-CountDownLatch
  - 简介
  - 使用示例
    - 多线程同步功能
  - 原理详解
    - `new` 初始化
    - `countDown()` 方法
      - `countDown()` 流程图
      - `countDown()` 源码解析
    - `await()` 方法
      - `await()` 流程图
      - `await()` 源码解析
简介
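CountDownLatch 是 JDK 自带的并发工具类（位于 java.util.concurrent 包），实现了类似倒计数器的功能，通过 countDown() 方法和 await() 方法实现多线程任务同步。调用 await() 方法的线程会被阻塞，直到其他线程调用足够次数的 countDown() 方法、把计数减到 0，才会解除阻塞。计数一旦减到 0 就不会再改变：之后调用 await() 会立即返回，多余的 countDown() 调用也不会产生任何效果（计数不会变成负数）。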
使用示例
多线程同步功能
import java
.util
.Random
;
import java
.util
.concurrent
.CountDownLatch
;
import java
.util
.concurrent
.TimeUnit
;
/**
 * Demonstrates {@link CountDownLatch}: the main thread blocks in {@code await()} until
 * LATCH_COUNT worker threads have called {@code countDown()}.
 *
 * <p>WORKER_COUNT workers are started, one more than the latch count. This shows that the
 * main thread is released after the fifth {@code countDown()}. A {@code countDown()} on a
 * latch that is already at zero has no effect, so the count never goes negative.
 */
public class CountDownLatchDemo {
    /** Initial latch count: number of countDown() calls needed to release the main thread. */
    private static final int LATCH_COUNT = 5;
    /** Number of worker threads started (intentionally one more than LATCH_COUNT). */
    private static final int WORKER_COUNT = 6;

    public static void main(String[] args) {
        final CountDownLatch latch = new CountDownLatch(LATCH_COUNT);
        final Random random = new Random();
        for (int i = 0; i < WORKER_COUNT; i++) {
            new Thread(new CountDownLatchRunnable(latch, random)).start();
        }
        System.out.println("main thread waiting");
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt status instead of swallowing it. The latch has not
            // reached zero, so printing "end" here would be misleading.
            Thread.currentThread().interrupt();
            System.out.println("main thread interrupted while waiting");
            return;
        }
        System.out.println("end");
    }

    /**
     * Worker that simulates 0-9 seconds of work, then counts the shared latch down once.
     */
    static class CountDownLatchRunnable implements Runnable {
        private final CountDownLatch latch;
        private final Random random;

        /**
         * @param latch  the latch to count down when the simulated work finishes
         * @param random source of the simulated work duration, in seconds
         */
        CountDownLatchRunnable(CountDownLatch latch, Random random) {
            this.latch = latch;
            this.random = random;
        }

        @Override
        public void run() {
            int time = random.nextInt(10);
            try {
                TimeUnit.SECONDS.sleep(time);
            } catch (InterruptedException e) {
                // Preserve the interrupt for whoever owns this thread. The work is only
                // simulated, so we still fall through and count down.
                Thread.currentThread().interrupt();
            } finally {
                // Count down unconditionally; a missed countDown() would block await() forever.
                latch.countDown();
            }
            // getCount() is a snapshot: other workers may count down concurrently,
            // so two workers can report the same value.
            System.out.println(Thread.currentThread().getName() + " finished, cost time: " + time
                    + ", current count:" + latch.getCount());
        }
    }
}
原理详解
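`new` 初始化：构造方法先校验 count 不能为负数（否则抛出 IllegalArgumentException），再用它创建内部的 Sync 对象，并通过 setState(count) 把初始计数保存为同步状态 state。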
/**
 * Creates a latch that opens after {@code count} calls to countDown().
 *
 * @param count initial count; must not be negative
 * @throws IllegalArgumentException if {@code count} is negative
 */
public CountDownLatch(int count) {
    // Guard clause: a negative count is rejected before any state is created.
    if (count < 0) {
        throw new IllegalArgumentException("count < 0");
    }
    this.sync = new Sync(count);
}
/**
 * Creates the synchronizer, storing the initial latch count as its state.
 *
 * @param count initial count (already validated as non-negative by the caller)
 */
Sync(int count) {
    setState(count);
}
countDown()方法
该方法是线程安全的。它通过私有静态内部类 Sync 对象，以 CAS 自旋的方式把计数器减一。当某次调用恰好把计数从 1 减到 0 时，会唤醒在 await() 上阻塞的线程；如果计数已经是 0，调用不会产生任何效果。
countDown()流程图
countDown()源码解析
/**
 * Decrements the latch count by one, through the shared-release path of {@code sync}.
 * The argument 1 is not used by this latch's tryReleaseShared, which always subtracts one.
 */
public void countDown() {
    sync.releaseShared(1);
}
/**
 * Releases in shared mode. When tryReleaseShared reports that waiters may now proceed,
 * the queued waiters are woken.
 *
 * @param arg release argument forwarded to tryReleaseShared
 * @return whatever tryReleaseShared returned
 */
public final boolean releaseShared(int arg) {
    boolean released = tryReleaseShared(arg);
    if (released) {
        doReleaseShared();
    }
    return released;
}
/**
 * Attempts to decrement the count by one, retrying on CAS contention.
 *
 * @param releases ignored; the count is always decremented by exactly one
 * @return true only for the call that moves the count from 1 to 0, which tells
 *         releaseShared to wake the waiters; false if the count was already 0,
 *         or is still positive after decrementing
 */
protected boolean tryReleaseShared(int releases) {
    for (;;) {
        int c = getState();
        // Already at zero: countDown() on an open latch is a no-op and never goes negative.
        if (c == 0)
            return false;
        int nextc = c-1;
        // Another thread may have changed the state since it was read; loop and retry on CAS failure.
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}
/**
 * Wakes the successor of the current head and keeps the shared release propagating.
 *
 * The loop exits only when head is unchanged at the end of an iteration. If a woken
 * thread installed itself as the new head in the meantime, the loop runs again for it.
 */
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        // Act only when at least one node is queued behind the head.
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                // Clear SIGNAL before unparking; if the CAS loses a race, re-read and retry.
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;
                unparkSuccessor(h);
            }
            // Status 0: mark the head PROPAGATE so a later acquirer keeps propagating.
            // Assumes PROPAGATE is negative (it is -3 in the JDK), which satisfies the
            // waitStatus < 0 check in setHeadAndPropagate. Retry if the CAS fails.
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;
        }
        // Head did not move while we worked, so nothing more to wake.
        if (h == head)
            break;
    }
}
/**
 * Unparks the thread of the first non-cancelled node after {@code node}, if one exists.
 *
 * @param node the head (from doReleaseShared) or a node being cancelled (from cancelAcquire)
 */
private void unparkSuccessor(Node node) {
    // Reset a negative status (e.g. SIGNAL) to 0. The CAS may fail; its result is deliberately ignored.
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);
    Node s = node.next;
    // next can be null or point at a cancelled node (waitStatus > 0). addWaiter/enq set prev
    // before the tail CAS and next only afterwards, so next may lag behind. In that case,
    // walk backwards from tail via prev and keep the earliest non-cancelled node found.
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    if (s != null)
        LockSupport.unpark(s.thread);
}
await()方法
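该方法同样委托给 sync 对象实现。计数为 0 时直接返回；计数不为 0 时，当前线程被封装为 SHARED 节点加入 AQS 同步队列，并通过 LockSupport.park() 阻塞。计数减到 0 时，countDown() 一侧调用 LockSupport.unpark() 将其唤醒，唤醒再沿队列依次传播给后续等待线程。该方法响应中断：调用前或等待期间线程被中断，都会抛出 InterruptedException。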
await()流程图
await()源码解析
/**
 * Blocks the calling thread until the latch count reaches zero.
 * Returns immediately if the count is already zero.
 *
 * @throws InterruptedException if the thread is interrupted before or while waiting
 */
public void await() throws InterruptedException
{
    sync.acquireSharedInterruptibly(1);
}
/**
 * Acquires in shared mode, aborting if the thread is interrupted.
 * The fast path returns immediately when tryAcquireShared succeeds; otherwise the
 * thread is enqueued and parked.
 *
 * @param arg acquire argument forwarded to tryAcquireShared
 * @throws InterruptedException if the thread is interrupted before or while waiting
 */
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // Thread.interrupted() also clears the flag, matching the thrown exception.
    if (Thread.interrupted()) {
        throw new InterruptedException();
    }
    boolean acquired = tryAcquireShared(arg) >= 0;
    if (!acquired) {
        doAcquireSharedInterruptibly(arg);
    }
}
/**
 * Reports whether the latch is open.
 *
 * @param acquires ignored by this latch
 * @return 1 when the count is zero (acquire succeeds and may propagate), -1 otherwise
 */
protected int tryAcquireShared(int acquires) {
    if (getState() == 0) {
        return 1;
    }
    return -1;
}
/**
 * Slow path of await(): enqueues the current thread as a SHARED node and parks it
 * until tryAcquireShared succeeds.
 *
 * @param arg forwarded to tryAcquireShared (ignored by this latch's implementation)
 * @throws InterruptedException if interrupted while parked; the node is then cancelled
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException
{
    final Node node = addWaiter(Node.SHARED);
    // Stays true unless we return normally, so the finally block cancels the node
    // whenever we leave by an exception (including InterruptedException).
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            // Only the node directly behind the head retries the acquire.
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    // Become the new head and pass the wake-up on to following shared waiters.
                    setHeadAndPropagate(node, r);
                    p.next = null; // unlink the old head from the queue
                    failed = false;
                    return;
                }
            }
            // Park only once the predecessor has been set to SIGNAL; an interrupt observed
            // after parking aborts the wait.
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}
/**
 * Creates a node for the current thread in the given mode and appends it to the queue tail.
 *
 * @param mode node mode (Node.SHARED when called from doAcquireSharedInterruptibly)
 * @return the newly enqueued node
 */
private Node addWaiter(Node mode) {
    Node node = new Node(Thread.currentThread(), mode);
    // Fast path: a single CAS attempt when the queue is already initialized.
    Node pred = tail;
    if (pred != null) {
        // prev is linked before the CAS, next only after it succeeds.
        node.prev = pred;
        if (compareAndSetTail(pred, node)) {
            pred.next = node;
            return node;
        }
    }
    // Slow path: empty queue or lost CAS race; enq() retries until the node is appended.
    enq(node);
    return node;
}
/**
 * Appends {@code node} to the queue, retrying until it succeeds. Lazily creates an
 * empty head node if the queue has not been initialized yet.
 *
 * @param node the node to append
 * @return the node's predecessor (the previous tail), not the node itself
 */
private Node enq(final Node node) {
    for (;;) {
        Node t = tail;
        if (t == null) {
            // Queue not initialized: install an empty head, point tail at it, then loop again.
            if (compareAndSetHead(new Node()))
                tail = head;
        } else {
            // Link prev before the CAS so a backward walk from tail (as in unparkSuccessor)
            // can reach this node via prev even before t.next is set.
            node.prev = t;
            if (compareAndSetTail(t, node)) {
                t.next = node;
                return t;
            }
        }
    }
}
/**
 * Decides whether a thread whose acquire just failed may park.
 *
 * @param pred the node's current predecessor
 * @param node the waiting node
 * @return true only if pred is already SIGNAL (release will unpark pred's successor).
 *         Otherwise the method either skips cancelled predecessors or sets pred to SIGNAL,
 *         and returns false so the caller retries the acquire once more before parking.
 */
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    int ws = pred.waitStatus;
    if (ws == Node.SIGNAL)
        return true;
    if (ws > 0) {
        // waitStatus > 0 marks a cancelled predecessor: skip all such nodes and relink.
        do {
            node.prev = pred = pred.prev;
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        // Ask pred to signal us on release. The CAS result is deliberately not checked;
        // the caller loops and re-evaluates.
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
    }
    return false;
}
/**
 * Parks the current thread, then reports and clears its interrupt status.
 * park() may also return on unpark or spuriously, so callers re-check in a loop.
 *
 * @return true if the thread was interrupted
 */
private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this);
    return Thread.interrupted();
}
/**
 * Makes {@code node} the new head after a successful shared acquire. If the release
 * may need to propagate, it also wakes the following shared waiter.
 *
 * @param node      the node that just acquired
 * @param propagate the result of tryAcquireShared; for this latch it is always 1 on success
 */
private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // remember the old head for the checks below
    setHead(node);
    // Propagate if the acquire result is positive, or if either the old head or the
    // re-read new head has a negative waitStatus (e.g. SIGNAL or PROPAGATE).
    // A null head is treated conservatively as "propagate".
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        // next may not be linked yet (null), so wake in that case too.
        if (s == null || s.isShared())
            doReleaseShared();
    }
}
/**
 * Installs {@code node} as the queue head and clears the fields a head no longer needs.
 *
 * @param node the node that just acquired
 */
private void setHead(Node node) {
    head = node;
    // The head no longer represents a waiting thread; dropping these references
    // presumably also helps GC of the old head and thread.
    node.thread = null;
    node.prev = null;
}
/**
 * Cancels an in-progress acquire: marks {@code node} CANCELLED and unlinks it from the queue.
 * If this node might otherwise have been responsible for signalling its successor, the
 * successor is woken.
 *
 * @param node the node to cancel; a null node is ignored
 */
private void cancelAcquire(Node node) {
    if (node == null)
        return;
    node.thread = null;
    // Skip over predecessors that are themselves cancelled (waitStatus > 0).
    Node pred = node.prev;
    while (pred.waitStatus > 0)
        node.prev = pred = pred.prev;
    // Expected value for the compareAndSetNext calls below.
    Node predNext = pred.next;
    node.waitStatus = Node.CANCELLED;
    // If we are the tail, remove ourselves by moving tail back to pred.
    if (node == tail && compareAndSetTail(node, pred)) {
        compareAndSetNext(pred, predNext, null);
    } else {
        int ws;
        // If pred is a live, non-head node that is (or can be made) SIGNAL, splice our
        // successor in directly after pred, so pred will wake it on release.
        if (pred != head &&
            ((ws = pred.waitStatus) == Node.SIGNAL ||
             (ws <= 0 && compareAndSetWaitStatus(pred, ws, Node.SIGNAL))) &&
            pred.thread != null) {
            Node next = node.next;
            if (next != null && next.waitStatus <= 0)
                compareAndSetNext(pred, predNext, next);
        } else {
            // Otherwise no one is guaranteed to signal the successor, so wake it now.
            unparkSuccessor(node);
        }
        node.next = node; // self-link the cancelled node, presumably to help GC
    }
}