目录

ThreadLocal、InheritableThreadLocal、TransmittableThreadLocal解析

使用场景

在我们日常 Java Web 开发中难免遇到需要把一个参数层层的传递到最内层。

例如,用户进行操作需要在拦截器中从redis等缓存中间件去获取用户信息并判断是否过期,如果接下来的的业务方法需要用到用户信息时怎么获取呢?

Java的Web项目大部分都是基于Tomcat,每次访问都是一个新的线程,这样让我们联想到了ThreadLocal,每一个线程都独享一个ThreadLocal,在接收请求的时候set特定内容,在需要的时候get这个值。

先附上本文中demo源码演示地址,有兴趣的可以看下

ThreadLocalDemo地址

ThreadLocal

Demo

模拟一个普通的用户请求(新启动一个线程)

  1. 请求先经过拦截器,拦截器中必然需要获取用户信息,同时调用ThreadLocal.set(userInfo)将用户信息塞入线程上下文中
  2. 进行业务处理(业务处理时从ThreadLocal中获取用户信息,避免参数层层传递)

ThreadLocal封装类

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
public class ThreadLocalHolder {
    /**
     * 普通THREAD_LOCAL
     */
    private static final ThreadLocal<UserInfo> THREAD_LOCAL = new ThreadLocal<>();

    public static UserInfo getUser() {
        return THREAD_LOCAL.get();
    }

    public static void setUser(UserInfo userInfo) {
        THREAD_LOCAL.set(userInfo);
    }
}

测试类,启动一个线程模拟一个普通的Web同步请求

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
public static void main(String[] args) {
    BusinessService businessService = new BusinessService();
    LoginInterceptor loginInterceptor = new LoginInterceptor();
    //模拟一个普通的同步web请求
    new Thread(() -> {
        // 模拟用户身份拦截器
        loginInterceptor.userInterceptor();
        System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalHolder.getUser());
        // 拦截器通过后 同步处理业务
        businessService.doBusiness();
    }).start();

}

模拟Web项目中的拦截器实现,从缓存中获取用户信息,塞入ThreadLocal中

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class LoginInterceptor {
    /**
     * 模拟拦截方法
     */
    public void userInterceptor() {
        UserInfo userInfo = getUserFromRedis();
        //将用户信息塞入ThreadLocal中
        ThreadLocalHolder.setUser(userInfo);
    }

    /**
     * 模拟从redis中获取信息,这里写死直接返回
     *
     * @return
     */
    public UserInfo getUserFromRedis() {
        UserInfo userInfo = new UserInfo();
        userInfo.setId(1L);
        userInfo.setUserName("chenyin");
        return userInfo;
    }
}

业务处理类,获取用户信息,再处理业务

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
public class BusinessService {
    /**
     * 模拟同步处理业务
     */
    public void doBusiness() {
        //获取用户信息,避免显示参数传递
        System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalHolder.getUser());
        //业务处理。。略去
    }

    /**
     * 模拟异步处理业务
     */
    public void doBusinessAsync() {
        new Thread(() -> {
            //获取用户信息,避免显示参数传递
            System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalHolder.getUser());
            //业务处理。。略去
        }).start();
    }
}

main方法执行结果如下 ./1.png 可以看到,同一个线程中,即无论调用层级多深,也不需要将UserInfo作为参数层层传递,直接调用ThreadLocal.get()方法即可获取用户信息

ThreadLocal存储结构

首先提出一个问题,ThreadLocal中set()方法设的值具体存储在哪里?

先看几个关键的变量定义

Thread类中变量定义

1
2
3
4
5
6
7
public class Thread implements Runnable {
    //略去其他变量定义
    //普通线程上下文存储所在Map
    ThreadLocal.ThreadLocalMap threadLocals = null;
    //InheritableThreadLocal可继承线程本地变量存储所在Map
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
}

ThreadLocal.ThreadLocalMap中变量定义,Entry中的Key(ThreadLocal类型)是个WeakReference弱引用

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /**
         * The value associated with this ThreadLocal.
         */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    private Entry[] table;
}

再看下set方法实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
public void set(T value) {
    //获取当前线程对象t
    Thread t = Thread.currentThread();
    //获取t中的 ThreadLocal.ThreadLocalMap变量
    ThreadLocalMap map = getMap(t);
    //往ThreadLocalMap中的tables中加入数据,key为当前ThreadLocal对象,value为用户传入的值
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

ThreadLocal之所以能做不到不同线程之间的隔离性,就是因为set方法设的值不是存在我们定义的ThreadLocal变量中,而是存储在每个线程的变量(ThreadLocal.ThreadLocalMap)中

再提出一个问题,ThreadLocalMap.Entry中的key值为什么是ThreadLocal类型?

假设有如下场景,同一个线程中同时使用了2个ThreadLocal

1
2
3
4
5
6
public static void main(String[] args) {
    ThreadLocal<Integer> threadLocalA = new ThreadLocal<>();
    ThreadLocal<Integer> threadLocalB = new ThreadLocal<>();
    threadLocalA.set(1);
    threadLocalB.set(2);
}

同一线程中可能定义了不同的ThreadLocal变量,这些ThreadLocal实例共享一个table数组,然后每个ThreadLocal实例在table中的索引i是不同的,因此Key为ThreadLocal能够根据ThreadLocal中的hashCode唯一确定其value在table中的下标

关键API

1
2
3
4
5
6
//从线程上下文中获取值
public T get() ;
//将值设入线程上下文中,供同一线程后续使用
public void set(T value) ;
//清除线程上下文
public void remove() ;

set方法实现

源码如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
public void set(T value) {
    //获取当前线程对象t
    Thread t = Thread.currentThread();
    //获取线程t实例对象中的 ThreadLocal.ThreadLocalMap变量
    ThreadLocalMap map = getMap(t);
    //往ThreadLocalMap中的tables中加入数据,key为当前ThreadLocal对象,value为用户传入的值
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

实现核心为java.lang.ThreadLocal.ThreadLocalMap#set方法,看下实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    //根据当前ThreadLocal变量的hashCode与数组长度做位运算得到在Entry[] tab数组中的存储下标
    int i = key.threadLocalHashCode & (len - 1);
    //e != null说明hash冲突,下标往后+1
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }
        // 可能threadLocal对象已经被gc回收,此时key为null,清除无效的entry
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //找不到对应entry,新建一个Entry,塞入下标为i的槽位处
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

get方法实现

基本思路与get类似,先获取当前调用线程对象t,再获取其ThreadLocalMap对象,再调用 ThreadLocal.ThreadLocalMap#getEntry方法获取值

java.lang.ThreadLocal#get

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T) e.value;
            return result;
        }
    }
    return setInitialValue();
}

ThreadLocal.ThreadLocalMap#getEntry

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    //找到key,直接返回
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
    //hash冲突时,遍历tab,直到key值相等
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            //清楚key为null的无效entry
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

缺点

内存泄露

为什么ThreadLocal会出现内存泄露?

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
static class Entry extends WeakReference<ThreadLocal<?>> {
    /**
     * The value associated with this ThreadLocal.
     */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

java.lang.ThreadLocal.ThreadLocalMap.Entry的定义中,ThreadLocal是个WeakReference弱引用类型,什么是WeakReference(弱引用)?

  1. 强引用:我们平时使用的最多的引用,是最普遍的引用。JVM不会回收这些引用,即使当内存空间不足时,JVM宁可抛出OutOfMemoryError异常也不会回收这些对象
1
2
3
String str = "abc";
List<String> list = new Arraylist<String>();
list.add(str);
  1. 弱引用:当垃圾回收器进行线程扫描时,无论此时内存空间是否充足,都会将其回收掉,即弱引用生命周期只在一次GC周期中

再回到问题中来,假设定义如下变量

1
2
ThreadLocal<Integer> threadLocalA = new ThreadLocal<>();
threadLocalA.set(1);

此时,ThreadLocalMap中的Entry[] table的数据存储情况如下,外部的引用threadLocalA与Entry[] table中key的引用都指向一个threadLocal实例。

./2.png 假设此时执行

1
threadLocalA = null;

堆区threadLocal的实例对象有2个引用链

  1. threadLocal引用->堆区threadLocal实例(由于执行threadLocalA=null被切断)
  2. thread引用->堆区thread实例->threadLocalMap->entryTable->entry->找到threadLocal实例的弱引用key值(由于key是弱引用,下次GC后会被回收)

GC后,threadLocal实例有可能被JVM回收,Entry[] table中的key就会存在为null的情况,因此该entry永远不能被访问到。

但此时key对应的value存在如下引用链:栈区线程对象引用(threadRef)->thread实例->ThreadLocalMap对象->Entry[] table数组->entry对象->value对象,因此Value可达,GC时不会回收

entry中key为null导致value不能被访问+value不会被回收是造成内存泄露的主要原因

目前源码中针对key为null的情况已有优化方案,set(),get(),remove()中的replaceStaleEntry、cleanSomeSlots、expungeStaleEntry即为清除key为null的方法

那key为什么不使用强引用?

和上面分析value对象可达的引用链路类似,如果key使用强引用,即使调用threadLocalA = null,此时线程中threadLocalMap中仍然持有threadLocal实例的引用,threadLocalA实例仍然不会被GC回收,造成异常情况

value为什么不使用弱引用? value只存在thread引用->堆区thread实例->threadLocalMap->entryTable->entry->value这一条引用链,假设value为弱引用,则GC后会被回收,再也无法通过ThreadLocal.get()方法获取value值

父子线程传值问题

修改测试代码如下,doBusinessAsync方法又启动了一个子线程来执行业务(模拟异步处理)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
public static void main(String[] args) {
    BusinessService businessService = new BusinessService();
    LoginInterceptor loginInterceptor = new LoginInterceptor();
    //模拟一个普通的异步web请求
    new Thread(() -> {
        // 模拟用户身份拦截器
        loginInterceptor.userInterceptor();
        System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalHolder.getUser());
        // 拦截器通过后 异步处理业务
        businessService.doBusinessAsync();
    }).start();
}

输出结果如下 ./3.png 子线程中无法获取到ThreadLocal中的value,从上面的存储原理分析中,已经很明白了,子线程拥有自己的ThreadLocalMap,自然无法获取父线程ThreadLocalMap中的值。

但往往很多操作是需要异步操作的,因此父子线程直接共享ThreadLocal中的值是有必要的,下面介绍以下InheritableThreadLocal,看下它是如何实现父子线程之间共享线程上下文的。

InheritableThreadLocal

其实现原理就是在创建子线程将父线程当前存在的本地线程变量拷贝到子线程的本地线程变量中

线程上下文复制

重新回顾下Thread中定义的两个变量 ./4.png

其中inheritableThreadLocals即为存储InheritableThreadLocal的Map变量,下面称为InheritableThreadLocalMap

先看下线程的创建过程

  1. step1:java.lang.Thread#Thread(java.lang.Runnable)
  2. step2:java.lang.Thread#init(java.lang.ThreadGroup, java.lang.Runnable, java.lang.String, long)
  3. step3:java.lang.Thread#init(java.lang.ThreadGroup, java.lang.Runnable, java.lang.String, long, java.security.AccessControlContext, boolean)

其中有如下实现 ./5.png 判断:如果当前线程(父线程)中有inheritableThreadLocals变量,则子线程的InheritableThreadLocalMap对象由ThreadLocal.createInheritedMap方法产生

再看下ThreadLocal.createInheritedMap方法,最终调用ThreadLocalMap并传入父线程中的inheritableThreadLocals完成拷贝复制

1
2
3
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}

既然子线程创建时被赋值的也是InheritableThreadLocalMap变量,那么通过InheritableThreadLocal获取线程上下文时也应该操作的是线程中的InheritableThreadLocalMap对象,因此InheritableThreadLocal重写了几个有关ThreadLocalMap获取和赋值的方法

实现如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    //返回子线程的InheritableThreadLocalMap
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    //为InheritableThreadLocalMap执行初始化
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

Demo

使用InheritableThreadLocal,在线程中再new一个线程,模拟异步方法执行

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
public class InheritableThreadLocalDemo {

    private static final InheritableThreadLocal<UserInfo> INHERITABLE_THREAD_LOCAL = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        LoginInterceptor loginInterceptor = new LoginInterceptor();

        //模拟一个普通的异步web请求
        new Thread(() -> {
            UserInfo userInfo = loginInterceptor.getUserFromRedis();
            // 模拟用户身份拦截器
            INHERITABLE_THREAD_LOCAL.set(userInfo);
            System.out.println(Thread.currentThread().getName() + ":" + INHERITABLE_THREAD_LOCAL.get());
            new Thread(() -> {
                //获取用户信息,避免显示参数传递
                System.out.println(Thread.currentThread().getName() + ":" + INHERITABLE_THREAD_LOCAL.get());
                //业务处理。。略去
            }).start();
        }).start();
    }
}

结果输出如下 ./6.png 由此可见,子线程中也能正常获取父线程中线程上下文的数据

缺点

InheritableThreadLocal的核心思想即:创建线程的时候将父线程中的线程上下文变量值复制到子线程 ,在平时开发中,不可能每一个异步请求都new一个单独的子线程来处理(内存会被撑爆),因此需要使用到线程池,线程池中即存在线程复用的情况,假设线程池中后面创建的线程中的上下文数据否都来自线程池中被复用的线程,这就出现父子线程的上下文变量复制混乱的情况。

举个例子

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/**
 * 演示InheritableThreadLocal的缺陷
 *
 * @author: chenyin
 * @date: 2019-10-22 13:13
 */
public class InheritableThreadLocalWeaknessDemo {

    private static final InheritableThreadLocal<Integer> INHERITABLE_THREAD_LOCAL = new InheritableThreadLocal<>();
    //模拟业务线程池
    private static final ExecutorService threadPool = Executors.newFixedThreadPool(5);

    public static void main(String[] args) throws InterruptedException {
        //模拟同时10个web请求,一个请求一个线程
        for (int i = 0; i < 10; i++) {
            new TomcatThread(i).start();
        }

        Thread.sleep(3000);
        threadPool.shutdown();
    }

    static class TomcatThread extends Thread {
        //线程下标
        int index;

        public TomcatThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            String parentThreadName = Thread.currentThread().getName();
            //父线程中将index值塞入线程上下文变量
            System.out.println(parentThreadName + ":" + index);
            INHERITABLE_THREAD_LOCAL.set(index);

            threadPool.submit(new BusinessThread(parentThreadName));
        }
    }

    static class BusinessThread implements Runnable {
        //父进程名称
        private String parentThreadName;

        public BusinessThread(String parentThreadName) {
            this.parentThreadName = parentThreadName;
        }

        @Override
        public void run() {
            System.out.println("parent:" + parentThreadName + ":" + INHERITABLE_THREAD_LOCAL.get());
        }
    }
}

代码模拟了同时有10个web请求(启动10个线程),每个线程内部又向业务线程池中提交一个异步任务。执行结果如下图所示

./7.png

子线程中输出的父线程名称与下标index无法一一对应,即ThreadLocal线程上下文变量出现混乱的情况,应用需要的实际上是把 任务提交给线程池时的ThreadLocal值传递到 任务执行时

这种情况就需要使用阿里开源的TransmittableThreadLocal来解决了

TransmittableThreadLocal

TransmittableThreadLocal能将任务提交给线程池时的ThreadLocal值传递到任务执行时。

Demo

使用TransmittableThreadLocal代替InheritableThreadLocal,同时提交线程时结合TtlRunnable使用,使用TtlRunnable.get()来提交一个TtlRunnable到线程池中执行。

Demo中的TransmittableThreadLocal版本如下

1
2
3
4
5
6
<dependency>
	<groupId>com.alibaba</groupId>
	<artifactId>transmittable-thread-local</artifactId>
	<version>2.11.0</version>
	<scope>compile</scope>
</dependency>
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
public class TransmittableThreadLocalDemo {

    private static final TransmittableThreadLocal<Integer> INHERITABLE_THREAD_LOCAL = new TransmittableThreadLocal<>();
    //模拟业务线程池
    private static final ExecutorService threadPool = Executors.newFixedThreadPool(5);

    public static void main(String[] args) throws InterruptedException {
        //模拟同时10个web请求,一个请求一个线程
        for (int i = 0; i < 10; i++) {
            new TomcatThread(i).start();
        }

        Thread.sleep(3000);
        threadPool.shutdown();
    }

    static class TomcatThread extends Thread {
        //线程下标
        int index;

        public TomcatThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            String parentThreadName = Thread.currentThread().getName();
            //父线程中将index值塞入线程上下文变量
            System.out.println(parentThreadName + ":" + index);
            INHERITABLE_THREAD_LOCAL.set(index);

            threadPool.submit(TtlRunnable.get(new BusinessThread(parentThreadName)));
        }
    }

    static class BusinessThread implements Runnable {
        //父进程名称
        private String parentThreadName;

        public BusinessThread(String parentThreadName) {
            this.parentThreadName = parentThreadName;
        }

        @Override
        public void run() {
            System.out.println("parent:" + parentThreadName + ":" + INHERITABLE_THREAD_LOCAL.get());
        }
    }
}

执行结果如下,子线程中输出内容与父线程一致,没有出现线程上下文变量复制混乱的情况 ./8.png

原理

TransmittableThreadLocal实现的核心思想有两点

1、线程(TtlRunnable)提交时从父线程中捕获(复制一份)TransmittableThreadLocal上下文对象 2、TtlRunnable重写Run方法,在run方法执行时,根据捕获的线程上下文重新执行TransmittableThreadLocal#set方法达到父子线程

TransmittableThreadLocal#holder

先看下TransmittableThreadLocal中的holder实现,有几个关键点需要注意

  1. holder是个InheritableThreadLocal,本身是个线程上下文
  2. holder中value是WeakHashMap类型(防止内存泄露)
  3. WeakHashMap中的key是TransmittableThreadLocal对象(之所以做key是因为可能存在多个TransmittableThreadLocal实例),其value是null值(WeakHashMap允许value为null)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
    private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

TransmittableThreadLocal#set

1
2
3
4
5
6
7
8
@Override
public final void set(T value) {
    //向InheritableThreadLocal中写入value
    super.set(value);
    // may set null to remove value
    if (null == value) removeValue();
    else addValue();
}

holder其实更多的是个set的作用,存储了当前线程中设有ttlValue的TransmittableThreadLocal的引用

1
2
3
4
5
6
private void addValue() {
    //以当前TransmittableThreadLocal为key,塞入holder中
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null);
    }
}

TransmittableThreadLocal#set首先往super.set(value)中写入value,后调用addValue方法将当前TransmittableThreadLocal塞入了holder中

至于为什么要塞到holder中,用处是:

当用户向线程池中提交包装后的Runnable对象(TtlRunnble)时,TtlRunnble能从holder中捕获(获取)父线程中TransmittableThreadLocal上下文存储的值

捕获与重放

先看下官方介绍中的流程时序图

https://github.com/alibaba/transmittable-thread-local

./9.png

根据时序图的步骤来说明

  1. createTtl()、setTtlValue()其实就是调用TransmittableThreadLocal的线程上下文值,ttlValue就是上下文中的值
  2. createBizTaskRunnable就是执行业务的线程,createTtlRunnableWrapper(Runnable)就是使用TtlRunnable.get()来封装了Runnable,捕获操作captureAllTtlValues就是发生这里 看下TtlRunnable的关键实现 变量如下
1
2
3
4
//捕获的父线程的存储了ttlvalue的上下文对象,并存储在capturedRef引用指向的对象中
private final AtomicReference<Object> capturedRef;
//业务中执行的线程
private final Runnable runnable;

初始化方法

1
2
3
4
5
6
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    //调用TransmittableThreadLocal.Transmitter父线程中捕获上下文对象
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

在调用TtlRunnable.get()时会触发上述初始化方法,capture()方法最终调用的是TransmittableThreadLocal.Transmitter#capture方法

  1. 下面就进入到了时序图中captureAllTtlValues、get()、copy(T value)的实现。看下TransmittableThreadLocal.Transmitter#capture方法的具体实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    //调用holder.get().keySet()获取当前线程中存在ttlValue的TransmittableThreadLocal引用列表
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        //将TransmittableThreadLocal实例为key,TransmittableThreadLocal的value为值塞入WeakHashMap<TransmittableThreadLocal<Object>, Object>中,最终作为Snapshot的一个属性返回给子线程 TtlRunnable
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

get()对应holder.get() copy(T value)对应threadLocal.copyValue()

  1. submitTtlRunnableToThreadPool、run()对应线程池开始执行任务

看下TtlRunnable中重写的run方法

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
public void run() {
    //捕获的父线程的线程上下文
    Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    //在子线程中重放(调用TransmittableThreadLocal#set方法)重新设置线程上下文
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        //恢复线程中子线程原先的本地线程变量,避免被父线程的线程上下文污染
        restore(backup);
    }
}

其实现的关键在于replay方法

  1. 接下来就进入到了时序图中的beforeExecute、replayCapturedTtlValues()方法,对应代码中的TransmittableThreadLocal.Transmitter#replay
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

@NonNull
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    //注意iterator中的TransmittableThreadLocal<Object>存储的value此时是父线程中的线程上下文值
    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // backup 备份子线程中的TransmittableThreadLocal线程上下文变量,供后续恢复restore子线程上下文使用
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // 调用threadLocal.set方法重新塞入子线程的上下文中(父子线程之间不共享ThreadLocalMap)
    setTtlValuesTo(captured);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        //将捕获的父线程上下文的value设置到子线程的上下文中
        threadLocal.set(entry.getValue());
    }
}
  1. 时序图中run、useValueInTtl即对应到业务Runnable中的实现,因为此时已经完成重放操作,子线程中可以使用父线程的ttlValue
  2. 后面就是使用备份的子线程上下变量backup来恢复子线程的上下文环境,避免因为重放导致子线程的上下文环境被污染。对应到时序图中的restoreTtlValueBeforeReplay,afterExecute

源码如下,有兴趣的可以自己去看了

1
2
3
4
5
public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}