原创

ThreadLocal源码详解

作用

ThreadLoacl的主要作用:

  1. 用来做数据隔离,这样数据只能在线程内访问(同时线程内还有一个inheritableThreadLocals,这个是可以线程间传递的)
  2. 在进行对象跨层传递的时候,使用ThreadLocal可以避免多次传递,打破层次间的约束。这个通常用作上下文。
  3. 数据库连接,session传递等。

使用场景

  1. 比如说Spring对事务的隔离,采用ThreadLocal的方式保证同一个线程对于数据库的操作使用的是同一个连接,业务层面上对于数据库的操作不需要关心connection对象。其代码实现在TransactionSynchronizationManager中,使用了多个ThreadLocal对象保存每个连接所要处理的事务。
private static final ThreadLocal<Map<Object, Object>> resources =
            new NamedThreadLocal<>("Transactional resources");

private static final ThreadLocal<Set<TransactionSynchronization>> synchronizations =
            new NamedThreadLocal<>("Transaction synchronizations");

private static final ThreadLocal<String> currentTransactionName =
            new NamedThreadLocal<>("Current transaction name");
  1. 比如说项目中对于日期工具类在多线程条件下的使用,以下代码:

    public class DateUtils {
    
        private static final String DEFAULT_DATE_SCHEMA = "yyyy-MM-dd";
    
        private static SimpleDateFormat sdf = new SimpleDateFormat(DEFAULT_DATE_SCHEMA);
    
        public static Date parseToDate(String str) {
            try {
                return sdf.parse(str);
            } catch (ParseException e) {
                e.printStackTrace();
                return null;
            }
        }
    }
    

    这里SimpleDateFormat.parse操作是线程不安全的,其内部有一个Calendar对象,调用SimpleDataFormat.parse()时会先调用Calendar.clear(),然后调用Calendar.add(),在多线程并发执行时,线程A先执行并已经执行了clear方法和add方法,此时线程B进来执行了clear方法,那么此时线程A再去parseDate就会获得一个错误结果。

    针对线程不安全的问题,解决的办法有几个:

    • 每个线程执行new SimpleDateFormat()创建新的SimpleDateFormat实现,这种方式在线程多的情况下,会创建很多个SimpleDateFormat对象,内存消耗比较大,且实现不优雅。

    • 在方法执行上加上synchronized。使用锁保证每个线程顺序执行,这种实现方式解决了线程不安全的问题,但是同时带来了效率低下的问题。

    • 使用ThreadLocal包装SimpleDateFormat,这样每个线程在执行时,会在线程内有一个SimpleDateForamt副本:

      public class DateUtils {
      
              private static final String DEFAULT_DATE_SCHEMA = "yyyy-MM-dd";
      
          private static ThreadLocal<SimpleDateFormat> threadLocalSdf = new ThreadLocal<>();
      
          public static Date parseToDate(String str) {
              try {
                  SimpleDateFormat sdf = threadLocalSdf.get();
                  if (sdf == null) {
                      sdf = new SimpleDateFormat(DEFAULT_DATE_SCHEMA);
                      threadLocalSdf.set(sdf);
                  }
                  return sdf.parse(str);
              } catch (ParseException e) {
                  e.printStackTrace();
                  return null;
              }
          }
      }
      

      诸如此类类似上下文参数的线程内传递的情况都可以使用ThreadLocal实现,实现逻辑:

      main() {
          init(obj);
          use();
          destroy();
      }
      
      init(Object obj) {
          threadLocal.set(obj);
      }
      
      use() {
          Object obj = threadLocal.get();
          doSomething();
      }
      
      destroy() {
          threadLocal.remove();
      }
      

源码解析

如何实现线程隔离

以上述代码为例:

public class DateUtils {

        private static final String DEFAULT_DATE_SCHEMA = "yyyy-MM-dd";

    private static ThreadLocal<SimpleDateFormat> threadLocalSdf = new ThreadLocal<>();

    public static Date parseToDate(String str) {
        try {
            SimpleDateFormat sdf = threadLocalSdf.get();
            if (sdf == null) {
                sdf = new SimpleDateFormat(DEFAULT_DATE_SCHEMA);
                threadLocalSdf.set(sdf);
            }
            return sdf.parse(str);
        } catch (ParseException e) {
            e.printStackTrace();
            return null;
        }
    }
}

set方法开始看起:

public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 从当前线程中获得ThreadLocalMap,默认为null
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // 如果不为空
        map.set(this, value);
    else
        // 如果为空,创建一个ThreadLocalMap并赋值给当前线程的threadLocals属性
        createMap(t, value);
}

这个ThreadLocalMap是从当前线程中获得的,是线程对象的一个属性:

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

Thread类:

// 从这里可以看到ThreadLocalMap实际上是ThreadLocal的一个内部类
ThreadLocal.ThreadLocalMap threadLocals = null;
// 可继承的ThreadLocalMap
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

注意

这里我们基本上可以找到ThreadLocal数据隔离的真相了,每个线程Thread都维护了自己的threadLocals变量,所以在每个线程创建ThreadLocal的时候,实际上数据是存在自己线程的threadLocals变量里面的,别人没办法拿到,从而实现了隔离。

ThreadLocalMap底层结构

上面我们了解了ThreadLocal是通过每个线程维护一个ThreadLocalMap,而数据久保存在线程内部的ThreadLocalMap中,所以这样其他线程就没有办法获得当前线程ThreadLocal中的对象了,实现了线程间的隔离。

那么同一个线程内多个ThreadLocal是怎么保存的呢,ThreadLocalMap的结构是怎样的呢?

当前线程threadLocals为空时

顺着前面的思路我们继续向下看,上面我们说到了在第一次执行ThreadLocal.set方法时,得到的map对象为空,然后会执行createMap(t, value);

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

在这里实例化了一个ThreadLocalMap并赋值给t.threadLocals属性,顺便我们来看看ThreadLocalMap的结构:

static class ThreadLocalMap {

    // 注意看这个Entry类是弱引用类型的,弱引用类型是在GC时不考虑是否还有内存空闲,都会被GC回收掉的
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;

        // 因为这里弱引用的是key,所以当key被回收掉的时候,value就有内存泄漏的危险,在使用ThreadLocal过程中应该要注意这个使用
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    // 默认的存储空间是16个槽
    private static final int INITIAL_CAPACITY = 16;

    // 用来作为实际保存数据的entry数组
    private Entry[] table;

    // 记录entry数组的size
    private int size = 0;

    // resize阈值
    private int threshold;

    // 初始化方法
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        // 计算初始值所在的index,这里注意的是即使value是null,也会分配一个槽位,原因是hashCode与value没有关系
          // 后面再解决hash冲突时会详细解释
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        // 计算resize阈值  threshold = len * 2 / 3;
        setThreshold(INITIAL_CAPACITY);
    }

    // ...
}

这里我们可以总结以下几点:

  1. 初始化entry数组的大小是16。
  2. 计算resize的阈值是当前数组最大值的2/3。

ThreadLocalMap

对于ThreadLocalMap的结构,我们可能有下面两种疑问:

  1. 计算数组中索引的算法是firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1),而且这里我们看到了此处的ThreadLocalMap并不像HashMap一样是在Entry数组之上再加入链表来解决hash冲突的,那么对于ThreadLocalMap是如何解决hash冲突的?
  2. ThreadLocalMapEntry的key为ThreadLocal对象,且对于ThreadLocal的存储是弱引用,而value是外界传入的,如果外界传入的value是强引用(比如说new Object()创建的对象),会存在内存泄漏的危险?

首先我们先来看下第点,ThreadLocalMap是如何解决hash冲突的,这里我们首先要理解为什么要使用数组去保存Entry(key为ThreadLocal,value为实际保存的值)。

这里是因为在程序运行期间,比如说同一个上下文中,我们需要保存多个类型的数据,那么就需要在一个线程中同时保存多个ThreadLocal,而我们知道,ThreadLocal是保存在线程的threadLocals(ThreadLocalMap的实例)中的,所以ThreadLocalMap也是需要像数组这种结构来存储的,那么在什么时候会遇到hash冲突呢?这就需要来看看ThreadLocalMap不为空时是如何操作的。

当前线程threadLocals不为空时

ThreadLocal.ThreadLocalMap.set

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;

    /**
     * ThreadLocal.threadLocalHashCode属性对应的是一个均匀增大的数字
     * 随着线程内ThreadLocal对象的增多,这个数字会均匀的增大(0, 1640531527, 3281063054...)
     * 关于每次增大的跨度为什么是1640531527下面会详细解析
     */
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         // 这里是解决hash冲突问题的关键
         // 如果遇到hash冲突,就修改索引,一直到不发生冲突位置,形象的理解就是将纵向的链表旋转90度到数组中来解决
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        // key存在且是同一个ThreadLocal,替换最新值
        if (k == key) {
            e.value = value;
            return;
        }
        // entry不为空,key为空。这里为什么会发生?
        if (k == null) {
            // 出现这种情况说明已经发生了内存泄漏,这里会构造一个新的Entry(ThreadLocal, value)替换掉这个值
            // 同时ThreadLocal还会自发的向下继续轮询(到下一个空entry为止),发现和解决(删除)发生内存泄漏的entry
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 找到不为空的槽位,构造一个Entry放入
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 在插入之后需要判断是否rehash
    // 这里的判断条件有两个,一个是传统的size和threashold的判断
    // 还有一个就是判断是否有slots被清除,如果有,那么就不会进行resize
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

从这里我们看到,在进行set(threadLocal, value)的过程中,在获得了一个index之后,如果发生了hash冲突,是通过对index累加(环状)的形式解决的,详细的解决方式如下。

解决hash冲突

首先我们先看看计算槽位的逻辑:

第n次:int i = key.threadLocalHashCode & (len-1);

public class ThreadLocal<T> {

    // 以当前的ThreadLocal为key的hashCode计算,具体实现看下面方法
    private final int threadLocalHashCode = nextHashCode();

    // 这里的hashCode计算不是像HashMap一样计算key的hash值
      // 而是类似轮询的策略,每一个新的key(ThreadLocal)插入时,会先获取当前nextHashCode的值并返回
      // 然后将nextHashCode+=0x61c88647
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    // 初始化为0
    private static AtomicInteger nextHashCode = new AtomicInteger();

    // 这种计算方式是的结果均匀分布,原因下面详解
    private static final int HASH_INCREMENT = 0x61c88647;

    // ...
}

这个魔数的选取与斐波那契散列有关,0x61c88647对应的十进制为1640531527

斐波那契散列的乘数可以用(long) ((1L << 31) * (Math.sqrt(5) - 1))可以得到2654435769,如果把这个值给转为带符号的int,则会得到-1640531527。换句话说(1L << 32) - (long) ((1L << 31) * (Math.sqrt(5) - 1))得到的结果就是1640531527也就是0x61c88647

通过理论与实践,当我们用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。

ThreadLocalMap使用的是线性探测法,均匀分布的好处在于很快就能探测到下一个临近的可用slot,从而保证效率,所以该魔法数字的选择就是为了优化效率。

从这里我们可以知道,ThreadLocalMap中通过hash计算槽位的过程与value没有关系,纯粹就是threadLocal维护一个AtomicInteger类型的nextHashCode,对这个值进行累加后与当前ThreadLocalMap的size做位与。

上面说了key.threadLocalHashCode会随着ThreadLocal对象的变多而增大,他是不会冲突的,但是数组的大小如果没有达到resize阈值的前提下是固定的,所以就会造成hash冲突问题。从这里可以看到ThreadLocalMap解决hash冲突的手段就是不断地递增索引,形象的理解就是将纵向的链表旋转90度放到数组中解决。

for (Entry e = tab[i];
     e != null;
     // 这里是解决hash冲突问题的关键
     // 如果遇到hash冲突,就修改索引,一直到不发生冲突位置,形象的理解就是将纵向的链表旋转90度到数组中来解决
     e = tab[i = nextIndex(i, len)]) {
    // ...
}
// 这里就是索引递增,如果达到ThreadLocalMap的最大值就在从头开始,环状
private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

ThreadLocalMap解决hash冲突问题

这里解决了hash冲突,那么从`ThreadLocal如何查询数据呢?

public T get() {
    // 依然是获取当前线程,并从当前线程中获得ThreadLocalMap
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);

    if (map != null) {
        // 从ThreadLocalMap中获得value,这里正好对应set方法解决hash冲突问题
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            T result = (T)e.value;
            return result;
        }
    }
    // set(this, null);
    return setInitialValue();
}
private Entry getEntry(ThreadLocal<?> key) {
    // 这里获得的hashCode依然是递增的
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];

    if (e != null && e.get() == key)
        // 如果直接命中了就返回
        return e;
    else
        // 如果没有命中就按照nextIndex方法向后查找
        return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    // 这里的 i 就是通过hash(假的hash,递增的数字)得到的结果
    // 所以看到实际上这里和set方法相似,只不过set方法中使用了for循环,而这里使用了while循环
    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            // 如果命中就返回
            return e;
        if (k == null)
            // 这里当threadlocal自身发现内存泄漏时,所做的优化处理,get和set方法都有
            expungeStaleEntry(i);
        else
            // 如果没有命中和set方法相同都是向后继续遍历
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

从这里可以看到虽然ThreadLocalMap解决了hash冲突问题,但他的实现很简单,就是对数组下标进行递增,这样当线程内使用的ThreadLocal对象多的情况下,实际上查询效率会比较低。

内存泄漏问题

这里我们首先要明白一点,就是ThreadLocal是存储在哪里的---比如说以下代码

public class StringThreadLocal {

    private ThreadLocal<String> threadLocalSdf = new ThreadLocal<>();

    public void set(String str) {
        threadLocalSdf.set(str);
    }

      public String get() {
        return threadLocalSdf.get();
    }
}

当前线程执行StringThreadLocal执行完成后,因为栈中已经没有StringThreadLocal对象,那么如果此事发生GC,threadLocalSdf对象是否回收要看线程的ThreadLocalMap对这个对象的引用关系了,前面我们说了ThreadLocalMap对于ThreadLocal的引用关系是弱引用的,那么发生GC,threadLocalSdf就会被回收,ThreadLocalMap对于value的引用是强引用的,所以此处value是不会被回收到,而此处value被回收的前提是当前线程也被终结。

在我们的业务代码中对于多线程的应用往往是使用线程池的,而线程池对于线程的应用是复用的形式,那么线程往往会一直活动的,那么此处的value就会得不到释放(因为会一直强引用),所以,如果每次处理上面这段代码时,就不不停的产生threadLocalSdf-value的键值对,而threadLocalSdf被GC之后,value得不到释放会越来越多,造成内存的泄漏。

ThreadLocal内存泄漏问题

解决内存泄漏方式一

实际上解决内存泄漏很简单,只要在每次执行ThreadLocal的动作完成之后,调用remove方法清除即可:

ThreadLocal<String> localName = new ThreadLocal();
try {
    localName.set("张三");
    ……
} finally {
    localName.remove();
}
public void remove() {
    // thread.threadLocals
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        // 删除ThreadLocalMap中key为this的Entry
        m.remove(this);
}

这样可以确保这个entry不会发生内存的泄漏。

弱引用ThreadLocal是否是造成内存泄漏问题的原因

看到这里我们会对ThreadLocalMapThreadLocal的引用关系写的是弱引用提出了质疑,是否换成强引用就不会有内存泄漏的问题了?

实则不然,如果ThreadLocalMap中对ThreadLocal的引用更换为强引用,那么在GC时,因为强引用的关系,ThreadLocal对象和value都无法被GC,反而会造成更大量的内存泄漏。

解决内存泄漏方式二

方式一是要求程序员在代码开发环节保证每次对于ThreadLocal的使用都有remove方法,以确保不会出现内存泄漏,这里我们还可以使用更高级的ThreadLocal用法,可以确保内存不被泄漏:

public class ThreadLocalUtils {

    public static final ThreadLocal<Map<String, Object>> threadLocal = new ThreadLocal();

    public static Map<String, Object> getAll() {
        return threadLocal.get();
    }

    public static Object getObject(String key) {
        Map<String, Object> all = getAll();
        return all != null ? all.get(key) : null;
    }

    public static void setObject(String key, Object value) {
        Map<String, Object> all = getAll();
        if (all == null) {
            all = new HashMap<String, Object>();
            threadLocal.set(all);
        }
        all.put(key, value);
    }
}

分析:

根据前面的分析,我们知道造成内存泄漏的原因是ThreadLocal对象被GC了而value得不到GC,所以在线程一直存活的情况下,会造成没有释放的value越来越多,从而泄漏。

那么是不是我们只要让ThreadLocal不被释放,也就是所有对于ThreadLocal的处理使用同一个对象就好了,如上代码,因为threadLocal对象是static final修饰的,它产生的对象是要放到方法区中去的,那么他就不会被GC,而每个线程在使用threadLocal对象的时候是相对隔离的,也不会产生任何的问题(因为ThreadLocalMap不同,所以即便ThreadLocal对象相同,也是在不同的Map中的,而ThreadLocal对象本身没有什么实际意义,所以互相不产生影响)。

所以在这种情况下,即便没有清楚value,在线程复用时,因为使用了同一个ThreadLocal对象,新的value会覆盖旧的value,如此反复,虽然永远会有一个value的残留,但是不会越来越多,也解决了内存泄漏的问题。

源码中自己解决内存泄漏

在JDK8之后实际上我们可以不用考虑内存泄漏问题,因为在源码中已经自我解决了该问题。先来看前面ThreadLocalMap.set方法中:

for (Entry e = tab[i];
     e != null;
     // 这里是解决hash冲突问题的关键
     // 如果遇到hash冲突,就修改索引,一直到不发生冲突位置,形象的理解就是将纵向的链表旋转90度到数组中来解决
     e = tab[i = nextIndex(i, len)]) {
    ThreadLocal<?> k = e.get();

    // key存在且是同一个ThreadLocal,替换最新值
    if (k == key) {
        e.value = value;
        return;
    }
    // entry不为空,key为空。这里为什么会发生?
    if (k == null) {
        // 出现这种情况说明已经发生了内存泄漏,这里会构造一个新的Entry(ThreadLocal, value)替换掉这个值
        // 同时ThreadLocal还会自发的向下继续轮询(到下一个空entry为止),发现和解决(删除)发生内存泄漏的entry
        replaceStaleEntry(key, value, i);
        return;
    }
}
// 找到不为空的槽位,构造一个Entry放入
tab[i] = new Entry(key, value);
int sz = ++size;
// 在插入之后需要判断是否rehash
// 这里的判断条件有两个,一个是传统的size和threashold的判断
// 还有一个就是判断是否有slots被清除,如果有,那么就不会进行resize
if (!cleanSomeSlots(i, sz) && sz >= threshold)
    rehash();

在遍历查找空槽期间,如果发生了entry.key为空且entry不为空的情况下,说明此处发生了内存泄漏,那么他是如何操作的呢?

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    int slotToExpunge = staleSlot;

    // 从当前位置向前遍历entry数组,如果发现泄漏,标记slotToExpunge=i,一直到遇到第一个空的槽位
    // prevIndex和nextIndex正好相反,对index进行自减,到第一个元素后再回到最后一个元素开始向前遍历
    for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    /*
     * 从当前位置向后遍历,如果发现threadLocal已经存在,替换当前位置和已经存在threadLocal的位置,并将新的value更新
     * 那么之前threadLocal的位置上就是个已经发生泄漏的entry,标记slotToExpunge=i并清除
     */
    for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // 清理泄漏的entry
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 先释放value,然后根据key,value生成新的entry赋值
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    if (slotToExpunge != staleSlot)
        // 如果slotToExpunge != staleSlot,说明数组中必然还有泄漏的entry存在,就再次清理泄露的entry
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
/**
 * 启发式地清理slot,
 * i对应entry是非无效(指向的ThreadLocal没被回收,或者entry本身为空)
 * n是用于控制控制扫描次数的
 * 正常情况下如果log n次扫描没有发现无效slot,函数就结束了
 * 但是如果发现了无效的slot,将n置为table的长度len,做一次连续段的清理
 * 再从下一个空的slot开始继续扫描
 * 
 * 这个函数有两处地方会被调用,一处是插入的时候可能会被调用,另外一处是在替换无效slot的时候可能会被调用,
 * 区别是前者传入的n为元素个数,后者为table的容量
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    // 从当前位置向后移动,每发现一个内存泄漏就将其剔除
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            // 这里也会操作i的值,依然是向后遍历,返回的是第一个为空的槽的位置
            // 并将该过程中遇到的所有泄漏的entry清理掉,这个过程中还会伴有entry位置的变化
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}
/**
 * 这个函数是ThreadLocal中核心清理函数,它做的事情很简单:
 * 就是从staleSlot开始遍历,将无效(弱引用指向对象被回收)清理,即对应entry中的value置为null,将指向这个entry的table[i]置为null,直到扫到空entry。
 * 另外,在过程中还会对非空的entry作rehash(数组位置的交换)
 * 可以说这个函数的作用就是从staleSlot开始清理连续段中的slot(断开强引用,rehash slot等)
 */
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 首先确认当前位置已经发生了内存泄漏
    // 删除该位置上的引用关系,发生GC会清理掉
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    // 清理之后更新size
    size--;

    Entry e;
    int i;
    // 这里继续向后遍历,如果发现了内存泄漏,也会删除引用,一直到空的槽位为止
    for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 对于还没有被回收的情况,需要做一次rehash。
            // 如果对应的ThreadLocal的ID对len取模出来的索引h不为当前位置i,
            // 则从h向后线性探测到第一个空的slot,把当前的entry给挪过去。
            // 这样rehash的原因是尽可能的让不为空的槽连续在一起
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                /*
                 * 在原代码的这里有句注释值得一提,原注释如下:
                 *
                 * Unlike Knuth 6.4 Algorithm R, we must scan until
                 * null because multiple entries could have been stale.
                 *
                 * 这段话提及了Knuth高德纳的著作TAOCP(《计算机程序设计艺术》)的6.4章节(散列)
                 * 中的R算法。R算法描述了如何从使用线性探测的散列表中删除一个元素。
                 * R算法维护了一个上次删除元素的index,当在非空连续段中扫到某个entry的哈希值取模后的索引
                 * 还没有遍历到时,会将该entry挪到index那个位置,并更新当前位置为新的index,
                 * 继续向后扫描直到遇到空的entry。
                 *
                 * ThreadLocalMap因为使用了弱引用,所以其实每个slot的状态有三种也即
                 * 有效(value未回收),无效(value已回收),空(entry==null)。
                 * 正是因为ThreadLocalMap的entry有三种状态,所以不能完全套高德纳原书的R算法。
                 *
                 * 因为expungeStaleEntry函数在扫描过程中还会对无效slot清理将之转为空slot,
                 * 如果直接套用R算法,可能会出现具有相同哈希值的entry之间断开(中间有空entry)。
                 */
                while (tab[h] != null)
                    h = nextIndex(h, len);

                tab[h] = e;
            }
        }
    }
    // 返回staleSlot之后第一个空的slot索引
    return i;
}

在调用ThreadLocal.get方法时,同样会有清理泄漏entry的做法:

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;


    // 基于线性探测法不断向后探测直到遇到空entry。
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 找到目标
        if (k == key) {
            return e;
        }
        if (k == null) {
            // 该entry对应的ThreadLocal已经被回收,调用expungeStaleEntry来清理无效的entry
            expungeStaleEntry(i);
        } else {
            // 环形意义下往后面走
            i = nextIndex(i, len);
        }
        e = tab[i];
    }
    return null;
}

通过对源码的研究,我们发现实际上在使用ThreadLocal的过程中,即使是线程池的情况下,我们依然不需要太多的考虑内存泄漏的问题,因为JDK源码中已经对这种情况做了很多的保障(再一次get或者set过程中多次调用了清理泄漏的entry的方法,这也是为了应对数组中非空且有效的entry可能不连续的情况)。当然在代码中确保最后对ThreadLocal执行remove方法是个很好的习惯,这样也会提高getset方法的执行效率,只需要存值,而不需要考虑清理数组。

ThreadLocalMap的rehash

在上面ThreadLocalMap.set方法的最后会调用rehash方法:

if (!cleanSomeSlots(i, sz) && sz >= threshold)
    rehash();

在代用之前先判断在清理了无效(泄漏)的entry之后,所真正使用了的数组slot数是否大于阈值,如果大于阈值就会调用rehash方法:

private void rehash() {
    // 做一次全量清理
    expungeStaleEntries();

    /*
     * 因为做了一次清理,所以size很可能会变小。
     * ThreadLocalMap这里的实现是调低阈值来判断是否需要扩容,
     * threshold默认为len*2/3,所以这里的threshold - threshold/4相当于len/2
     */
    if (size >= threshold - threshold / 4) {
        resize();
    }
}
// 全量清理一次
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null) {
            // 个人感觉这里还可以进行优化,可以修改为
              // j = expungeStaleEntry(j); 因为expungeStaleEntry方法会将一个连续段内的泄漏entry都清理掉
            expungeStaleEntry(j);
        }
    }
}
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                // 这里应该不会再出现了,如果确实出现,就将value的引用切断
                e.value = null; 
            } else {
                // 线性探测来存放entry
                int h = k.threadLocalHashCode & (newLen - 1);
                // 环形向后找到空槽
                while (newTab[h] != null) {
                    h = nextIndex(h, newLen);
                }
                newTab[h] = e;
                count++;
            }
        }
    }
    // 重新计算rehash阈值
    setThreshold(newLen);
    size = count;
    table = newTab;
}

InheritableThreadLocal实现

前面我们介绍了ThreadLocal本身是线程隔离的,而InheritableThreadLocal提供了一种父子线程之间的数据共享机制。

如果父线程中创建了InheritableThreadLocal的对象,那么在创建、初始化新线程时,会将父线程的inheritableThreadLocals属性复制到子线程中去:

if (inheritThreadLocals && parent.inheritableThreadLocals != null)
    this.inheritableThreadLocals =
        ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
              // 这里得到的ThreadLocal是InheritableThreadLocal
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
              // 不为空的key才复制给子线程
            if (key != null) {
                  // InheritableThreadLocal.childValue返回的就是e.value
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

这里需要注意的是只有不为空的key(排除掉泄漏的entry)才会复制给子线程,且只是在初始化子线程的时候执行一次(有且只有一次),所以在子线程创建好之后,父线程再往inHeritableThreadLocals中添加新的InheritableThreadLocal对象,子线程也接收不到了。

总结

  1. 每个线程维护了一个ThreadLocal.ThreadLocalMap的实例变量threadLocals,那么也就是说每个线程操作ThreadLocal是隔离的。
  2. ThreadLocalMap的存储结构也是通过键值对的形式保存的,key是到业务代码所创建的每一个ThreadLocal对象的弱引用,value就是所要保存的value,通常是new Object创建出的对象,即强引用。那么就会有内存泄漏的危险。在ThreadLocalMap的源码中调用setget方法期间,会尽可能的清理无效(泄漏)的数据,同时也可以依靠良好的使用习惯(remove的使用)和更高级的使用(static final ThreadLocal)来减少内存泄露的可能性。
  3. ThreadLocalMap因为没有使用链表,其解决hash冲突的方式使用线性探测的方式,将entry数组看作一个环形,向后遍历一直找到空的slot做存储,所以在发生大量hash冲突的条件下,效率会比较低。

本文我们对ThreadLocal进行了深入的探索,从其使用到源码的解析,再到内存泄漏问题发现和解决以及InheritableThreadLocal的实现,进一步的了解了Josh Bloch和Doug Lea两位大师对于ThreadLocal线程隔离和针对内存泄漏所做出的优化是十分优雅的。在业务开发过程中,如果对ThreadLocal使用得当,会提高效率,减少重复工作,尤其是在作为上下文属性传递的工作中表现优异,但同样也会带来类与类之间的耦合度(比如说需要知道在前序步骤中向Context中传入了什么属性,如果继续增加的话需要避免重复等)。并且尤其注意在使用ThreadLocal作为上下文属性传递的过程中如果再使用多线程的话,会造成数据的丢失。

注:
该文中使用的图片均来自于敖丙

正文到此结束