深入理解 CAS

1 Java层面探究

Java中的原子类Atomic底层的实现原理是CAS,本文就让我们一起来深入探究CAS。
下面是AtomicInteger的测试代码 ,执行完毕后会发现原子类实例ai最终是精确的10000,而普通变量bi的值是一个小于10000的不固定的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package com.hw.review2022.concurrent;

import org.junit.Test;

import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class TestCAS {
/**
* 被多个线程访问的变量
*/
static AtomicInteger ai = new AtomicInteger(0);
static int bi = 0;

/**
* 对ai和bi执行1000次自增
*/
class Task implements Runnable {
//任务id
private int id;

public Task(int i) {
this.id = i;
}

@Override
public void run() {
for(int i = 0; i < 1000; i++) {
try{
if(i % 100 == 0) {
System.out.println("Thread name : " + Thread.currentThread().getName() + ", Task id : " + id + ", loop : " + i);
}
Thread.sleep(2);
ai.addAndGet(1);
bi++;
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}

@Test
public void test() throws Exception {
/**
* 创建含有3个线程的线程池
*/
ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(3);
/**
* 在线程池中执行10个任务
*/
for(int i = 0; i < 10; i++) {
executor.execute(new Task(i));
}

/**
* 等待线程池中的任务执行完毕后才去打印ai和bi的值
*/
executor.shutdown();
executor.awaitTermination(60, TimeUnit.SECONDS);
System.out.println("ai is right : " + ai);
System.out.println("bi is wrong : " + bi);
}
}

上述代码的核心是 AtomicInteger.addAndGet方法

1
2
3
4
5
6
7
8
9
10
11
12
private volatile int value;
private static final long valueOffset;

/**
* Atomically adds the given value to the current value.
*
* @param delta the value to add
* @return the updated value
*/
public final int addAndGet(int delta) {
return unsafe.getAndAddInt(this, valueOffset, delta) + delta;
}

首先我们看一下AtomicInteger类实例对象的内存布局

this指向了对象的起始地址,通过this + valueOffset(=12)我们就可以获得value字段的内存地址(即C/C++中的指向value的指针),进而读写该value值。

然后我们接着看Unsafe.getAndAddInt方法

1
2
3
4
5
6
7
8
9
//                             原子实例ai  valueOffset  delta(增量)
public final int getAndAddInt(Object var1, long var2, int var4) {
int var5;
do {
var5 = this.getIntVolatile(var1, var2); // var5: oldValue
} while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4)); // var5 + var4: updateValue

return var5;
}

var5这个变量是通过ai的起始地址 + valueOffset偏移值获取到value的内存地址,进而获取到的value的值,我们把这个值称为oldValue

var5+var4就是我们期望更新的值,我们把它叫做updateValue

接下来发现Unsafe.compareAndSwapInt是一个 native 方法(就是在Java虚拟机中用 C/C++ 实现的方法)

1
public final native boolean compareAndSwapInt(Object var1, long var2, int var4, int var5);

2 C/C++层面探究

现在Java语言层面我们已经分析完了,完全看不到CAS具体是怎么实现的,接下来我们继续去探究JDK源码[1]

Unsafe.compareAndSwapInt的源码所在目录为hotspot/src/share/vm/prims/unsafe.cpp

这个方法的前两个参数不用去了解(JNI是Java调用C的方式),后四个参数和Unsafe.compareAndSwapInt方法的一一对应。

1
2
3
4
5
6
7
8
9
10
11
UNSAFE_ENTRY(jboolean, Unsafe_CompareAndSwapInt(JNIEnv *env, jobject unsafe, jobject obj, jlong offset, jint e, jint x))
UnsafeWrapper("Unsafe_CompareAndSwapInt");
oop p = JNIHandles::resolve(obj);
jint* addr = (jint *) index_oop_from_field_offset_long(p, offset);
/**
* x代表了期望更新的值updateValue
* addr代表了value这个字段的内存地址
* e代表了oldValue
*/
return (jint)(Atomic::cmpxchg(x, addr, e)) == e;
UNSAFE_END

继续往下分析来到了最核心的部分

1
2
3
4
5
6
7
8
inline jint     Atomic::cmpxchg    (jint     exchange_value, volatile jint*     dest, jint     compare_value) {
int mp = os::is_MP(); // MP means multiprocessor,多处理器系统需要给cmpxchg指令加上lock前缀
__asm__ volatile (LOCK_IF_MP(%4) "cmpxchgl %1,(%3)"
: "=a" (exchange_value)
: "r" (exchange_value), "a" (compare_value), "r" (dest), "r" (mp)
: "cc", "memory");
return exchange_value;
}

cmpxchgl是一个汇编指令(最后一个字符l代表了cmpxchg指令的参数类型为jint,对于jlong类型最后一个字符是q,不同操作系统可能会不同),所以我们需要去理解cmpxchg这条汇编指令。

3 汇编层面探究

下面给出我在C中实现的CAS来帮助大家理解CAS和cmpxchg指令。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <unistd.h>
#include "stdio.h"
#include "pthread.h"

int cmp(int oldValue, int* addr, int updateValue);
void atomic_add(int* addr, int delta);
void task();
void atomic_task();

/**
* 被多个线程访问的共享变量
*/
int atomic_count = 0, count = 0;


/**
* 测试代码
*/
int main() {
pthread_t tid, tid2, tid3, tid4, tid5, tid6;
/**
* 创建了6个子线程,tid为子线程id,task和atomic_task是子线程的任务函数
*/
pthread_create(&tid, NULL, (void *) task, NULL);
pthread_create(&tid2, NULL, (void *) task, NULL);
pthread_create(&tid3, NULL, (void *) task, NULL);
pthread_create(&tid4, NULL, (void *) atomic_task, NULL);
pthread_create(&tid5, NULL, (void *) atomic_task, NULL);
pthread_create(&tid6, NULL, (void *) atomic_task, NULL);

/**
* 主线程等待6个子线程执行完才会往下执行printf
*/
pthread_join(tid, NULL);
pthread_join(tid2, NULL);
pthread_join(tid3, NULL);
pthread_join(tid4, NULL);
pthread_join(tid5, NULL);
pthread_join(tid6, NULL);

printf("count is %d\n", count); //预期最后的结果小于900000
printf("atomic_count is %d\n", atomic_count); //预期最后的结果精确等于900000
return 0;
}

/**
*
* @param oldValue 最开始获取到的value值
* @param addr 多线程访问的变量的地址
* @param updateValue 期望更新的新值
* @return
*/
int cmp(int oldValue, int* addr, int updateValue) {
int resValue;
__asm__ __volatile__ ("lock\n\t"
"cmpxchgl %2, (%3)"
: "=a"(resValue)
: "a"(oldValue), "r"(updateValue), "r"(addr)
: "cc", "memory");
return resValue;
}

/**
*
* @param addr 多线程访问变量的地址
* @param delta 这个变量想要增加的值
* @return
*/
void atomic_add(int* addr, int delta) {
int oldValue;
do {
oldValue = *addr;
} while ( cmp(oldValue, addr, oldValue + delta) != oldValue );
}

void atomic_task() {
for(int i = 0; i < 300000; i++) {
atomic_add(&atomic_count, 1);
usleep(60);
}
}

void task() {
for(int i = 0; i < 300000; i++) {
count += 1;
usleep(60); //0.06ms
}
}

重点讲解一下cmp函数

1
2
3
4
5
6
7
8
9
int cmp(int oldValue, int* addr, int updateValue) {
int resValue;
__asm__ __volatile__ ("lock\n\t" //保证原子性
"cmpxchgl %2, (%3)" //核心指令
: "=a"(resValue) //输出列表
: "a"(oldValue), "r"(updateValue), "r"(addr) //输入列表
: "cc", "memory"); //表明这条内联汇编会更改内存值和flag寄存器
return resValue;
}

其实也就是我自己实现的JDK中的Atomic::cmpxchg函数,只不过更容易看懂一些。

oldValue就是我们先前获取到的value值,addr就是value字段的内存地址,在cmpxchg指令中会用到这个地址,updateValue就是我们期望更新的value值(oldValue+增量delta)

以下关于内联汇编的知识大家可以去参考GCC-Inline-Assembly-HOWTO[2]

输入列表中:

"a"(oldValue):a代表EAX寄存器,意思是将变量oldValue的值输入到EAX寄存器

"r"(updateValue) "r"(addr):意思是将变量updateValue和地址addr也放到寄存器中,r(register)代表一组寄存器,也就是从这一组寄存器中随便选一个存updateValue,随便选一个存addr

输出列表中:

"=a"(resValue):表示内联汇编执行完后,将EAX寄存器的值存到resValue这个变量中。

我们在输出列表和输入列表中声明了许多变量,这些变量从%0开始依次往下标号,所以%0表示resValue%1代表了oldValue%2代表了updateValue%3代表了addr

现在我们结合的cmpxchg指令的功能,来理解一下这段汇编到底在做什么。汇编语言有Intel和AT&T两种语法,一般我们用的都是AT&T这种,下面的讲解也是基于该语法的:

cmpxchg 指令有两个操作数,同时还使用了EAX 寄存器。首先,它将第二个操作数和EAX寄存器相比较,如果相同则把第一个操作数赋值给第二个操作数,否则将第一个操作数赋值给EAX 寄存器

1
cmpxchgl %2, (%3) 

%2是第一个操作数,即updateValue,也就是无冲突的时候我们期望更新的值

(%3)是第二个操作数,即(addr)()表示取值操作(相当于C语言中的*),因为addrvalue变量的地址,所以该操作数是在取此刻value的值curValue

1
int *a = 5; //代表了a是一个指针,指向一个int变量,即a是该int变量的地址,*a表示取该int变量的值5

首先将oldValue存储到EAX寄存器中,然后用第二个操作数curValueoldValue行比较,如果相等,则说明从得到oldValue到现在执行cmpxchg这条指令这段时间内,value没有被其他线程改写(抛开ABA问题不谈),没有发生冲突,所以我们就可以直接把我们希望更新的新值updateValue写入到value中,那我们再来看看cmpxchg这条指令干了啥,如果相同则把第一个操作数赋值给第二个操作数,也就是将updateValue赋值给value对象(成功更新value的值)。现在这条指令就执行完毕了,EAX寄存器中存储的还是oldValue,所以最终cmp函数返回的也就是oldValue

1
2
3
4
5
6
void atomic_add(int* addr, int delta) {
int oldValue;
do {
oldValue = *addr;
} while ( cmp(oldValue, addr, oldValue + delta) != oldValue );
}

这个函数也就可以返回了,对应的就是Java中的Unsafe.getAndAddInt方法

1
2
3
4
5
6
7
8
9

public final int getAndAddInt(Object var1, long var2, int var4) {
int var5;
do {
var5 = this.getIntVolatile(var1, var2);
} while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4));

return var5;
}

如果EAX中保存的oldValue和我们执行cmpxchg指令时获取的curValue不同,说明value的值被其他线程改写了,那此刻将第一个操作数updateValue赋值给 EAX 寄存器(我们并没有更新value的值),所以cmp函数返回的值为updateValueatomic_add函数中cmp返回的updateValueoldValue不同,所以要重新执行do while循环去自旋,直到没有冲突发生。

4 动画展示与总结

该动画使用Python manim[3]制作,动画源码

最后我们再用高级语言解释一下cmpxchg这条汇编指令的功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
* 被多个线程访问的value对象
*/
int value;

/**
* 注意下面的逻辑是一条汇编指令完成的
*
* @param addr value对象的内存地址
* @param oldValue 执行cmpxchg指令前当前线程获取到的value对象的值
* @param updateValue 期望更新的值
* @return
*/
int cmpxchg(int* addr, int oldValue, int updateValue)
{
int curValue = *addr; //获取最新的value的值
if (curValue == oldValue) { //无冲突则更新value值为updateValue并返回oldValue
*addr == updateValue;
return oldValue;
} else { //发生冲突,直接返回updateValue,value对象的值没有更新哦
return updateValue;
}
}

5 引用


深入理解 CAS
https://hwollin.github.io/2022/08/19/cas/
作者
Wei Han
发布于
2022年8月19日
许可协议