# Fork/Join 框架

本文不会对源码部分涉及太多,主要是讲几个小例子来演示如何使用。

Fork/Join 框架是 Java 并发工具包中的一种可以将一个大任务拆分为很多小任务来异步执行的工具,该技术就是分治算法的并行实现,自 JDK1.7 引入。

该框架主要包含三个模块:

  • 任务对象: ForkJoinTask
  • 执行 Fork/Join 任务的线程: ForkJoinWorkerThread
  • 线程池: ForkJoinPool

三者关系: ForkJoinPool 可以通过池中的 ForkJoinWorkerThread 来处理 ForkJoinTask 任务。

ForkJoinPool 只接收 ForkJoinTask 任务 (在实际使用中,也可以接收 Runnable/Callable 任务,但在真正运行时,也会把这些任务封装成 ForkJoinTask 类型的任务)。

# 工作窃取算法

另一个核心思想就是分治算法,本文不加以赘述。

工作窃取算法(work-stealing): 线程池内的所有工作线程都尝试找到并执行已经提交的任务,或者是被其他活动任务创建的子任务 (如果不存在就阻塞等待)

在 ForkJoinPool 中,线程池中每个工作线程 (ForkJoinWorkerThread) 都对应一个任务队列 (WorkQueue),工作线程优先处理来自自身队列的任务 (LIFO 或 FIFO 顺序,参数 mode 决定),然后以 FIFO 的顺序随机窃取其他队列中的任务。

具体思路为:

  • 每个线程都有自己的一个 WorkQueue ,该工作队列是一个双端队列。

  • 队列支持三个功能 push、pop、poll, push/pop 只能被队列的所有者线程调用,而 poll 可以被其他线程调用。

  • 划分的子任务调用 fork 时,都会被 push 到自己的队列中。

  • 默认情况下,工作线程从自己的双端队列获出任务并执行。

  • 当自己的队列为空时,线程随机从另一个线程的队列末尾调用 poll 方法窃取任务。

# ForkJoinPool

外部程序ForkJoinPool 提交任务有三种方式:

  • invoke() 会等待任务计算完毕并返回计算结果;

  • execute() 是直接向池提交一个任务来异步执行,无返回结果;

  • submit() 也是异步执行,但是会返回提交的任务,在适当的时候可通过 task.get() 获取执行结果。

子任务提交是由 fork() 方法完成的,任务被分割 (fork) 之后调用了 ForkJoinPool.WorkQueue.push() 方法直接把任务放到队列中等待被执行(该任务可能是当前线程执行,也可能被其他线程窃取)。

获取任务结果:使用 join()/invoke() 方法, join 只有任务在队列 top 位(双端队列)时才会执行,所以会使当前线程阻塞,直到对应的子任务完成允许并返回执行结果; invoke 会直接执行当前任务。

所以两个子任务同时需要提交,一般流程是:

task1().fork();
task2().fork();
task2.join(); //task2 最晚加进去,在栈顶,此时直接调用 join 就会直接执行
task1.join();

compute() 方法其实比调用 fork 效率更高,它实际上会在当前工作线程进行计算(线程重用),这比” 将子任务提交到工作队列,线程又将工作队列中拿任务快得多 “。

# Fork/Join 陷阱

避免不必要的 fork ():划分两个子任务后,不要同时调用两个子任务的 fork() ,这当然是可以的,但是其中一个子任务调用 compute() 的效率更高。当一个大人物被划分位两个以上的子任务时,尽可能使用 invokeAll() ,它可以避免不必要的 fork()invokeAll 会把传入的任务的第一个交给当前线程来执行,其他的任务都 fork 加入工作队列,这样等于利用当前线程也执行任务了。

注意 fork、compute、join 顺序

// 正常调用
R.fork(); // 右边任务加入队列,等待计算
long LAns = L.compute(); // 当前线程计算左边任务
long RAns = R.join(); // (需要 compute 计算完后才会执行) 拿到右边的结果
return LAns + RAns;
// 错误调用 1
R.fork();
long RAns = R.join(); // 阻塞,直到拿到 R 的结果
long LAns = L.compute();
return LAns = RAns;
// 错误调用 2
long RAns = R.compute();
L.fork(); // 需要 RAns 拿到结果才会再执行 fork
return RAns + L.join();

最后两个实际上都没有并行计算

# 例子 1:异步计算 1~n 的累加和

实际运用中,我们一般会继承 RecursiveTaskRecursiveActionCountedCompleter 来实现我们的业务需求,而不会直接继承 ForkJoinTask 类。

  • RecursiveTask :是 ForkJoinTask 的子类,是一个可以递归执行的 ForkJoinTask
  • RecursiveAction :是一个无返回值的 RecursiveTask
  • CountedCompleter :在任务完成执行后会触发执行一个自定义的钩子函数。
public class Test {
    // 继承 RecursiveTask
	static final class SumTask extends RecursiveTask<Integer> {
		private static final long serialVersionUID = 1L;
		
		final int start; // 开始计算的数
		final int end; // 最后计算的数
		
		SumTask(int start, int end) {
			this.start = start;
			this.end = end;
		}
		@Override
		protected Integer compute() {
			// 如果计算量小于 1000,那么分配一个线程执行 if 中的代码块,并返回执行结果
			if(end - start < 1000) {
				System.out.println(Thread.currentThread().getName() + " 开始执行: " + start + "-" + end);
				int sum = 0;
				for(int i = start; i <= end; i++)
					sum += i;
				return sum;
			}
			// 如果计算量大于 1000,那么拆分为两个任务
			SumTask task1 = new SumTask(start, (start + end) / 2);
			SumTask task2 = new SumTask((start + end) / 2 + 1, end);
			// 执行任务
			task1.fork();
			task2.fork();
            // 也可以不使用 fork, 直接使用 invokeAll ()
            // invokeAll(task1,task2);
			
            // 获取任务执行的结果
			return task2.join() + task1.join();
		}
	}
	
	public static void main(String[] args) throws InterruptedException, ExecutionException {
		ForkJoinPool pool = new ForkJoinPool();
		ForkJoinTask<Integer> task = new SumTask(1, 10000);
		pool.submit(task);
		System.out.println(task.get());
	}
}

需要重写 compute 方法,并递归下去。

# 例子 2:斐波那契数列

public static void main(String[] args) {
    ForkJoinPool forkJoinPool = new ForkJoinPool(4); // 最大并发数 4
    Fibonacci fibonacci = new Fibonacci(20);
    long startTime = System.currentTimeMillis();
    Integer result = forkJoinPool.invoke(fibonacci);
    long endTime = System.currentTimeMillis();
    System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}
// 以下为官方 API 文档示例
static  class Fibonacci extends RecursiveTask<Integer> {
    final int n;
    Fibonacci(int n) {
        this.n = n;
    }
    @Override
    protected Integer compute() {
        if (n <= 1) {
            return n;
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork(); 
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join(); 
    }
}

# 参考

https://www.liaoxuefeng.com/wiki/1252599548343744/1306581226487842

https://pdai.tech/md/java/thread/java-thread-x-juc-executor-ForkJoinPool.html