ForkJoinPool

ForkJoinPool的作用

ThreadPoolExecutor每个任务都是单独线程处理的。如果某个任务耗时很大,就可能出现其他线程都在等着这个线程结束的情况。

为了处理这种问题,ForkJoinPool将一个大任务拆分为多个小任务,使用fork可以将小任务分发给其他线程同时处理,使用join可以将多个线程处理的结果进行汇总。

bda5cd74gy1ft9simzhklj20cc0etgm1

ForkJoinPool的原理

SoWkIImgAStDuVBCAqajIajCJbLmhKXDBYt9By8Y59nQL9QPd8ea4voSaPfIcfEQWguaCQcuf2WnkKGXEIUnk5Z14Sf5LmRZAzZewc9OWyO7gK5OI51nZQukJcjXubWrXMGKfIQc1EJdvy2ayQjtoo_AzihFp0FfmAiIuNWiLi3bIDRfa9gN0en20000

ForkJoinPool中每个线程都有自己的双端列表用于存储任务。这个双端列表对于工作窃取算法非常重要。

public class ForkJoinWorkerThread extends Thread {
    final ForkJoinPool pool;                // 工作线程所在的线程池
    final ForkJoinPool.WorkQueue workQueue; // 线程的工作队列(这个双端队列是work-stealing机制的核心)
    ...
}

工作窃取算法

  • 每个线程都有自己的WorkQueue,该工作队列是一个双端列表;
  • 队列支持push、pop、poll;
  • push/pop只能被队列所有者线程调用,poll可以被其他线程调用;
  • 划分的子任务调用fork时,都会被push到自己的队列中;
  • 默认情况下,工作线程从自己的双端列表获取任务并执行;
  • 当自己的队列为空时,线程随机从另一个线程的队列末尾调用poll方法窃取任务。

(PS:poll是队列数据结构实现类的方法,从队首获取元素,同时获取的这个元素将从原队列删除;
pop是栈结构的实现类的方法,表示返回栈顶的元素,同时该元素从栈中删除,当栈中没有元素时,调用该方法会发生异常)

bda5cd74gy1fvadx7bjxzj20di08p0t6

创建ForkJoinPool对象

  1. 使用Executors工具类
// parallelism定义并行级别
public static ExecutorService newWorkStealingPool(int parallelism);
// 默认并行级别为JVM可用的处理器个数
// Runtime.getRuntime().availableProcessors()
public static ExecutorService newWorkStealingPool();
  1. 使用ForkJoinPool内部已经初始化好的commonPool
public static ForkJoinPool commonPool();
// 类静态代码块中会调用makeCommonPool方法初始化一个commonPool
  1. 使用构造器创建
public ForkJoinPool() {
    this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
         defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism) {
    this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism,
                    ForkJoinWorkerThreadFactory factory,
                    UncaughtExceptionHandler handler,
                    boolean asyncMode) {
    this(checkParallelism(parallelism),
         checkFactory(factory),
         handler,
         asyncMode ? FIFO_QUEUE : LIFO_QUEUE, // 队列工作模式
         "ForkJoinPool-" + nextPoolId() + "-worker-");
    checkPermission();
}
  • paralleism:并行级别,通常默认为JVM可用处理器个数。
  • factory:用于创建ForkJoinPool中使用的线程
public static interface ForkJoinWorkerThreadFactory {
    public ForkJoinWorkerThread newThread(ForkJoinPool pool);
}

ForkJoinPool管理的线程均是扩展自Thread类的ForkJoinWorkerThread类型(里面包含了一个双端列表)

  • handler:用于处理工作线程未处理的异常,默认为null;
  • asyncMode:用于控制WorkQueue的工作模式
// asyncMode用于控制WorkQueue取任务模式
final ForkJoinTask<?> peek() {
    ForkJoinTask<?>[] a = array; int m;
    if (a == null || (m = a.length - 1) < 0)
        return null;
    // 如果是FIFO_QUEUE从base取任务,LIFO_QUEUE从top取任务
    int i = (config & FIFO_QUEUE) == 0 ? top - 1 : base;
    int j = ((i & m) << ASHIFT) + ABASE;
    return (ForkJoinTask<?>)U.getObjectVolatile(a, j);
}
final void execLocalTasks() {
    int b = base, m, s;
    ForkJoinTask<?>[] a = array;
    if (b - (s = top - 1) <= 0 && a != null &&
        (m = a.length - 1) >= 0) {
        if ((config & FIFO_QUEUE) == 0) {
            // 从队列top端取任务执行
        }
        else // 从队列base端取任务执行
            pollAndExecAll();
    }
}
final void pollAndExecAll() { // 从队列base端取任务执行
    for (ForkJoinTask<?> t; (t = poll()) != null;)
        t.doExec();
}

bda5cd74gy1fvbdlabm6qj20zk0k0wfc

提交任务到ForkJoinPool

// 提交没有返回值的任务
public void execute(ForkJoinTask<?> task) {
    if (task == null)
        throw new NullPointerException();
    externalPush(task);
}
public void execute(Runnable task) {
    if (task == null)
        throw new NullPointerException();
    ForkJoinTask<?> job;
    if (task instanceof ForkJoinTask<?>) // 避免二次包装
        job = (ForkJoinTask<?>) task;
    else
        job = new ForkJoinTask.RunnableExecuteAction(task); // 包装成ForkJoinTask
    externalPush(job);
}
// 提交有返回值的任务
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    if (task == null)
        throw new NullPointerException();
    externalPush(task);
    return task;
}
public <T> ForkJoinTask<T> submit(Callable<T> task) {
    // 包装成ForkJoinTask
    ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
    externalPush(job);
    return job;
}
public <T> ForkJoinTask<T> submit(Runnable task, T result) {
    // 包装成ForkJoinTask
    ForkJoinTask<T> job = new ForkJoinTask.AdaptedRunnable<T>(task, result);
    externalPush(job);
    return job;
}
public ForkJoinTask<?> submit(Runnable task) {
    if (task == null)
        throw new NullPointerException();
    ForkJoinTask<?> job;
    if (task instanceof ForkJoinTask<?>) // 避免二次包装
        job = (ForkJoinTask<?>) task;
    else
        job = new ForkJoinTask.AdaptedRunnableAction(task); // 包装成ForkJoinTask
    externalPush(job);
    return job;
}
// 同步提交,阻塞等结果
public <T> T invoke(ForkJoinTask<T> task) {
    if (task == null)
        throw new NullPointerException();
    externalPush(task);
    return task.join(); // 等待任务完成
}

ForkJoinTask

大多数情况下,都是直接提交ForkJoinTask对象到ForkJoinPool中。

因为ForkJoinTask有以下三个核心方法:

  1. fork():在执行过程中将大任务划分为多个小的子任务,调用子任务的fork()可以将任务放到线程池中异步调度;
  2. join():调用子任务的join()等待任务返回的结果,这个方法类似于Thread.join(),区别在于前者不受线程中断机制的影响;
  3. invoke():在当前线程同步执行该任务,该方法不受中断机制影响。

ForkJoinTask中join()、invoke()方法都不受中断机制影响,内部调用externalAwaitDone()实现。
如果是在ForkJoinTask内部调用get(),本质和join()方法一样是调用externalAwaitDone()。在外部的话,会受中断机制影响,因为外部是调用的是externalInterruptibleAwaitDone()方法实现的。

public final V get() throws InterruptedException, ExecutionException {
    int s = (Thread.currentThread() instanceof ForkJoinWorkerThread) ?
        doJoin() : externalInterruptibleAwaitDone();
    ...
}

ForkJoinTask由上面三个方法衍生出了几个静态方法:

public static void invokeAll(ForkJoinTask<?> t1, ForkJoinTask<?> t2);
public static void invokeAll(ForkJoinTask<?>... tasks);
public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks);

上面几个方法都是让第一个任务同步执行,其他任务异步执行(注意:其他任务先fork,第一个任务再invoke)。

任务状态

ForkJoinTask内部维护的四个状态:

/** The run status of this task */
volatile int status; // 默认等于0
static final int DONE_MASK   = 0xf0000000;  // NORMAL|CANCELLED|EXCEPTIONAL掩码
// NORMAL,CANCELLED,EXCEPTIONAL均小于0
static final int NORMAL      = 0xf0000000;  // must be negative
static final int CANCELLED   = 0xc0000000;  // must be < NORMAL
static final int EXCEPTIONAL = 0x80000000;  // must be < CANCELLED
static final int SIGNAL      = 0x00010000;  // must be >= 1 << 16

static final int SMASK       = 0x0000ffff;  // short bits for tags

SoWkIImgAStDuUAArefLqDMrK_3p3_9rzB5Apiyjo4ajITKequp9KoWipKmjoQbqXWbgmfKxE_evk_hukBfO9IVc9QVcQ6JcbQGM5PKMb23Kk4OukmQ4UFhx8PdhMf6SMb2IcP-Nc9DJgP6AK1piwGvIfv2Ldvcd0R8xFRL4GrC1oCZCAylF1jaDSI9ODL0N5zm9S164XGXL87EnQh9IUB9xzzEkXSywbxzOsFDYV_lp5TrF

RecursiveAction与RecursiveTask

通常我们不会直接使用ForkJoinTask,而是使用它的两个抽象子类:

  • RecursiveAction:没有返回值的任务
  • RecursiveTask:有返回值的任务
    SoWkIImgAStDuVBCAqajIajCJbLmAoqfBKhbIamgBYbAJ2vHSCv9B2vMSCilolRApymBIIpEHfSBIaqkBIhEB4jrJ2x9pC_3AGtM2p5UmQP6LnVLK6GEM1hTN4mLQ4OxfEQb0Fq20000

使用RecursiveAction

public class RecursiveActionTest {
    static class Sorter extends RecursiveAction {
        public static void sort(long[] array) {
            ForkJoinPool.commonPool().invoke(new Sorter(array, 0, array.length));
        }

        private final long[] array;
        private final int lo, hi;

        private Sorter(long[] array, int lo, int hi) {
            this.array = array;
            this.lo = lo;
            this.hi = hi;
        }

        private static final int THRESHOLD = 1000;

        protected void compute() {
            // 数组长度小于1000直接排序
            if (hi - lo < THRESHOLD)
                Arrays.sort(array, lo, hi);
            else {
                int mid = (lo + hi) >>> 1;
                // 数组长度大于1000,将数组平分为两份
                // 由两个子任务进行排序
                Sorter left = new Sorter(array, lo, mid);
                Sorter right = new Sorter(array, mid, hi);
                invokeAll(left, right);
                // 排序完成后合并排序结果
                merge(lo, mid, hi);
            }
        }

        private void merge(int lo, int mid, int hi) {
            long[] buf = Arrays.copyOfRange(array, lo, mid);
            for (int i = 0, j = lo, k = mid; i < buf.length; j++) {
                if (k == hi || buf[i] < array[k]) {
                    array[j] = buf[i++];
                } else {
                    array[j] = array[k++];
                }
            }
        }
    }

    public static void main(String[] args) {
        long[] array = new Random().longs(100_0000).toArray();
        Sorter.sort(array);
        System.out.println(Arrays.toString(array));
    }
}

使用RecursiveTask

public class RecursiveTaskTest {
    static class Sum extends RecursiveTask<Long> {
        public static long sum(int[] array) {
            return ForkJoinPool.commonPool().invoke(new Sum(array, 0, array.length));
        }

        private final int[] array;
        private final int lo, hi;

        private Sum(int[] array, int lo, int hi) {
            this.array = array;
            this.lo = lo;
            this.hi = hi;
        }

        private static final int THRESHOLD = 600;

        @Override
        protected Long compute() {
            if (hi - lo < THRESHOLD) {
                return sumSequentially();
            } else {
                int middle = (lo + hi) >>> 1;
                Sum left = new Sum(array, lo, middle);
                Sum right = new Sum(array, middle, hi);
                right.fork();
                long leftAns = left.compute();
                long rightAns = right.join();
                // 注意subTask2.fork要在subTask1.compute之前
                // 因为这里的subTask1.compute实际上是同步计算的
                return leftAns + rightAns;
            }
        }

        private long sumSequentially() {
            long sum = 0;
            for (int i = lo; i < hi; i++) {
                sum += array[i];
            }
            return sum;
        }
    }

    public static void main(String[] args) {
        int[] array = IntStream.rangeClosed(1, 100_0000).toArray();
        Long sum = Sum.sum(array);
        System.out.println(sum);
    }
}

动态的划分子任务:

public class DirectoryTask extends RecursiveTask {
    protected Long compute() {
        File[] files = dir.listFiles();
        List<RecursiveTask> tasks = new ArrayList<>(files.length);
        for (File f : files) {
            if (f.isDirectory()) {
                tasks.add(new DirectoryTask(f));
            } else {
                tasks.add(new FileTask(f));
            }
        }
        long sum = 0;
        for (RecursiveTask task : invokeAll(tasks)) {
            // exception handling omitted
            sum += task.get();
        }
        return sum;
    }
}

Fork/Join的陷阱与注意事项

避免不必要的fork()

划分为两个子任务后,不要同时调用两个子任务的fork()方法。

划分为两个子任务后,直接调用compute()效率更高,因为直接调用子任务的compute()方法实际上就是在当前的工作线程进行了计算(线程重用),这比“将子任务提交到工作任务,线程又从工作任务中拿任务”更快。
bda5cd74gy1fuixyr4pcfj20bw0fa74h
直接用三个衍生的invokeAll()方法,可以避免不必要的fork()。

注意fork()、compute()、join()的顺序

right.fork(); // 计算右边的任务
long leftAns = left.compute(); // 计算左边的任务(同时右边任务也在计算)
long rightAns = right.join(); // 等待右边的结果
return leftAns + rightAns;

选择合适的子任务粒度

果任务太大,则无法提高并行的吞吐量;如果任务太小,子任务的调度开销可能会大于并行计算的性能提升

官方给出的粗略经验是:任务应该执行100~10000个基本的计算步骤

comments powered by Disqus