80行代码手写mini线程池

eve2333 发布于 4 小时前 2 次阅读


这是我们的自定义线程池,里面只有一个方法叫做 `execute`,它接收一个参数,类型是 `Runnable` 接口。这个参数表示我们希望线程池帮我们执行的任务命令。

我们写一个 `main` 函数,先 `new` 一个我们自定义的 `MyThreadPool`,然后调用 `execute` 方法,打印一下当前执行任务的线程名。这就是我们向线程池提交的一个任务。

当然,这个任务是可以有延迟的。我们可以在任务内部模拟延迟,比如在 `execute` 提交的 `Runnable` 中,`try-catch` 包裹 `Thread.sleep(1000)`,表示这个任务会在一秒之后打印“执行这个任务的线程”。当然,这个任务不能阻塞主线程,所以我们可以在主线程中紧接着打印一句“主线程没有被阻塞”,来验证主线程是继续执行的。

测试代码写完后,我们可以加一个循环,比如循环五次,每次提交一个任务到线程池中。这样我们就有了一个完整的测试类。

接下来我们实现 `execute` 方法。首先我们知道,我们需要用一些线程来执行用户提交的这些 `Runnable` 任务。这时我们要思考两个问题:

第一,线程什么时候创建?  
第二,线程的 `Runnable` 是什么?是我们提交的 `command` 吗?

如果直接把 `command` 作为线程的 `Runnable`,我们可以这样写:每次调用 `execute` 时都 `new Thread(command).start()`。这种方式其实也能实现功能,但有两个严重问题。


public class MyThreadPool {

    // 1. 线程什么时候创建
    // 2 线程的runnable是什么?是我们提交command吗?
    List<Thread> threadList = new ArrayList<>();

    void execute(Runnable command) {
        Thread thread = new Thread(command);
        threadList.add(thread);
        thread.start();
    }
}

第一,既然我们叫“线程池”,目的就是避免频繁创建和销毁线程,因为线程的创建是非常消耗资源的。如果我们每次提交任务都新建线程,那就失去了线程池的意义。

第二,我们无法管理这些线程。虽然我们可以把每次创建的线程对象提出来,存入一个集合中,比如 `List<Thread> threadList`,然后每次创建完线程就加入这个集合进行管理,但这仍然没有解决线程复用的问题。

我们进一步思考:这个线程可以复用吗?比如当线程执行完一个任务后,能不能再执行下一个任务?

答案是:不能。因为一个线程的生命周期是从 `start()` 开始,到它执行完自己的 `run()` 方法结束。一旦任务执行完毕,线程就会终止,无法再次用来执行另一个任务。

所以我们需要换一种思路。我们简化一下场景:假设我们的线程池只有一个线程。那么这个线程应该做什么?

我们可以设计一个任务队列,比如 `List<Runnable> commandList`,用来存放所有用户提交的任务。每次调用 `execute` 方法时,就把 `command` 添加到这个列表中。然后那个唯一的线程写一个死循环,不断检查这个列表是否为空,如果不为空,就取出第一个任务执行。

比如:

public class MyThreadPool {

    List<Runnable> commandList = new ArrayList<>();

    Thread thread = new Thread(() ->{
        while (true){
            if (!commandList.isEmpty()) {
                Runnable command = commandList.remove(0);
                command.run();
            }
        }
    });

    void execute(Runnable command) {
        commandList.add(command);
    }
}

这样,这个线程就可以不断从队列中取任务执行,实现了线程的复用。

但你现在应该已经发现问题了:如果 `commandList` 是空的,这个 `while(true)` 就会一直空转,白白消耗 CPU 资源。

那有没有一种容器,可以在为空时自动阻塞线程,有元素时自动唤醒线程取出元素?答案是:**阻塞队列**(BlockingQueue)。

我们现在把 `List` 换成 `BlockingQueue<Runnable> blockingQueue`。因为 `BlockingQueue` 是接口,我们需要指定一个实现类,比如 `ArrayBlockingQueue`,并传入一个容量。

这样我们就不再需要手动判断队列是否为空。我们使用 `blockingQueue.take()` 方法,这个方法在队列为空时会自动阻塞当前线程,直到有任务被放入队列。

当然,`take()` 方法会抛出 `InterruptedException` 异常。这个异常在多线程中非常常见:当线程在阻塞或等待状态时,如果被其他线程调用了 `interrupt()`,就会抛出这个异常。我们必须手动处理它。

比如:
try {
    Runnable command = blockingQueue.take();
    command.run();
} catch (InterruptedException e) {
    // 处理中断
    Thread.currentThread().interrupt(); // 重新设置中断标志
}

不只是 `take()`,`Thread.sleep()`、`CountDownLatch.await()` 等几乎所有需要等待的 API 都会抛出 `InterruptedException`。唯一的例外是 `LockSupport.park()`,它不会抛出异常,而是通过中断状态来标记,这也是为什么在写底层多线程代码时,很多人更喜欢用 `LockSupport`。

我们继续完善代码。从 `blockingQueue` 中拿到 `command` 后,就执行 `command.run()`。我们把这个阻塞队列命名为 `blockingQueue`。

public class MyThreadPool {

    BlockingQueue<Runnable> blockingQueue = new ArrayBlockingQueue<>(1024);

    Thread thread = new Thread(() ->{
        while (true){
            try {
                Runnable command = blockingQueue.take();
                command.run();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    });

    void execute(Runnable command) {
        blockingQueue.add(command);
    }
}

在添加任务时,我们尽量不要用 `add()` 方法,而要用 `offer()`。虽然 `add()` 和 `offer()` 都是往队列中添加元素,但语义不同:`add()` 在队列满时会抛出 `IllegalStateException`,而 `offer()` 会返回一个布尔值,表示是否添加成功。我们未来会用到这个返回值来做判断,所以统一使用 `offer()`。

其实,一个单线程的线程池到这里已经基本完成了。但我们只创建了线程对象,还没有启动它。我们需要手动调用 `thread.start()`,并给这个线程起个名字,比如“唯一线程”。

回到主函数执行程序:我们循环五次提交任务,主线程打印“主线程没有被阻塞”,说明主线程没有卡住。五个任务由同一个线程依次执行,每个任务睡眠一秒后打印线程名,所以输出是五次相同的线程名,间隔一秒。

但程序没有终止,我们只能手动结束。为什么?因为那个唯一线程一直在 `take()` 上阻塞,等待新任务。只要没有任务,它就不会退出。

这就是我们当前线程池的逻辑:单线程处理所有任务,无限等待。

但一个线程显然不能满足实际需求,我们需要更多线程。于是我们把线程执行的这段逻辑抽出来,封装成一个 `Runnable`,叫做 `task`。这个 `task` 就是所有线程都要执行的公共逻辑:不断从阻塞队列取任务并执行。

然后我们把线程也变成一个集合,`List<Thread> threadList`,表示线程池中的所有工作线程。

public class MyThreadPool {

    BlockingQueue<Runnable> blockingQueue = new ArrayBlockingQueue<>(1024);

    private final Runnable task = () -> {
        while (true) {
            try {
                Runnable command = blockingQueue.take();
                command.run();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    };

    List<Thread> threadList = new ArrayList<>();

    void execute(Runnable command) {
        boolean offer = blockingQueue.offer(command);
    }
}

现在问题来了:`threadList` 里该放多少个线程合适?有人说6个,有人说10个,有人说12个。我们不硬编码,而是用一个变量控制,叫做 `corePoolSize`,比如设为10,表示我们希望核心线程数是10个。

我们定义 `execute` 的流程:

1. 如果当前 `threadList.size() < corePoolSize`,就创建新线程;
2. 创建线程时,把上面那个 `task` 作为 `Runnable` 传入,这样每个线程都会不断从队列取任务;
3. 把新线程加入 `threadList` 并启动;
4. 如果线程数已达到 `corePoolSize`,就不再创建线程,直接把任务 `offer` 到阻塞队列。

public class MyThreadPool {

    BlockingQueue<Runnable> blockingQueue = new ArrayBlockingQueue<>(1024);

    private final Runnable task = () -> {
        while (true) {
            try {
                Runnable command = blockingQueue.take();
                command.run();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
    };

    private int corePoolSize = 10;
    // 我们的线程池 应该有多少个线程?
    List<Thread> threadList = new ArrayList<>();

    // 我们判断thread list中 有多少个元素 如果没到core pool size 那么我们就创建线程
    void execute(Runnable command) {
        if (threadList.size() < corePoolSize) {
            Thread thread = new Thread(task);
            threadList.add(thread);
            thread.start();
        }
        boolean offer = blockingQueue.offer(command);
    }
}

但这里有个问题:如果 `offer` 返回 `false`,说明队列满了,任务提交失败。此时我们是否还能创建更多线程来处理任务?

当然可以。这就是“扩展线程”的概念。我们把原来的 `threadList` 改名为 `coreThreadList`,表示这是核心线程集合。当队列满时,我们可以创建“辅助线程”来帮忙,这些线程不属于核心线程,我们把它们放在另一个集合 `supportThreadList` 中。

流程调整为:

- 如果核心线程不足,创建核心线程;
- 如果核心线程已满,尝试将任务加入队列;
- 如果队列满,且辅助线程未达到上限,则创建辅助线程;
- 如果辅助线程也满了,执行拒绝策略。

我们引入 `maxPoolSize` 表示最大线程数。判断是否还能创建辅助线程的条件是:`coreThreadList.size() + supportThreadList.size() < maxPoolSize`。还有就是辅助线程满了怎么办呢?

为了代码清晰,我们用 `if-else` 结构优化逻辑:

```java
if (核心线程不足) {
    创建核心线程并启动
} else if (成功将任务放入队列) {
    返回
} else if (辅助线程未满) {
    创建辅助线程并启动
} else {
    执行拒绝策略
}
```

private int corePoolSize = 10;
private int maxSize = 16;
List<Thread> coreList = new ArrayList<>();
List<Thread> supportList = new ArrayList<>();

void execute(Runnable command) {
    if (coreList.size() < corePoolSize) {
        Thread thread = new Thread(task);
        coreList.add(thread);
        thread.start();
    }
    if (blockingQueue.offer(command)) {
        return;
    }
    if (coreList.size() + supportList.size() < maxSize) {
        Thread thread = new Thread(task);
        supportList.add(thread);
        thread.start();
    }
}

但这里存在线程安全问题:判断集合大小和添加线程不是原子操作,在多线程环境下可能出错。真实线程池会用 `AtomicInteger` 或加锁解决,这里我们只是演示,暂不处理。

public class MyThreadPool {
    List<Thread> coreList = new ArrayList<>();
    List<Thread> supportList = new ArrayList<>();

    void execute(Runnable command) {
        if (coreList.size() < corePoolSize) {
            Thread thread = new Thread(task);
            coreList.add(thread);
            thread.start();
        }
        if (blockingQueue.offer(command)) {
            return;
        }
        if (coreList.size() + supportList.size() < maxSize) {
            Thread thread = new Thread(task);
            supportList.add(thread);
            thread.start();
        }
        if (!blockingQueue.offer(command)) {
            throw new RuntimeException("阻塞队列满了!");
        }
    }
}

还有一个问题:辅助线程如何结束?我们不希望它们一直存在。可以让辅助线程在空闲一段时间后自动退出。我们设置一个超时时间还没有阻塞队列拿到任务

我们修改辅助线程的逻辑:它不再使用 `take()`,而是使用 `poll(timeout, unit)`,比如 `poll(1, TimeUnit.SECONDS)`。如果在1秒内没有拿到任务,`poll` 返回 `null`,线程就可以退出循环,结束生命周期。

我们把这个超时时间 `timeout` 和 `TimeUnit` 也作为参数暴露出来。

最终,我们将所有参数放在构造函数中:`corePoolSize`、`maxPoolSize`、`keepAliveTime`、`timeUnit`、`blockingQueue`、`rejectedExecutionHandler`。

我们把这四个参数分别写在构造函数上,好的,我们现在把参数放在了构造函数上,我们可以把这些任务重新封装成一个类,核心线程它继承了一个线程,然后让它实现一个run函数,这个run函数就是我们之前的那个核心线程,执行的任务就是这个死循环,OK然后我们在创建这个线程的时候,我们就只需要创建一个核心线程就可以了,然后辅助线程也是同理,我们把这一段代码粘到我们的辅助线程里面,OK我们在创建的时候直接创建辅助线程,好这样我们的代码就整齐一些,然后在等待的过程中,我们把参数换成time out,参数换成time unit,Ok,好我们在慢方法的时候,我们可以传一些,假如说我们有两个核心线程,最大线程是四,要等待一秒之后,等待一秒之后,我们的辅助线程就会自己结束,那么这个阻塞队列的容量,我们自己写的是1024,那么这个容量填多少合适呢,于是我们也可以把这个阻塞队列放在我们的构造函数上,这里还要传一个阻塞队列,比如说我们创建一个阻塞队列,让它的最大容量是二,让它来模拟阻塞队列已经满了的场景,于是我们来分析一下这段代码,我们有两个核心线程,有四个最大线程,也就是说我们有两个核心线程,有两个辅助线程,核心线程分别是零和一,辅助线程分别是二和三,当最开始零和一创建的时候,他发现阻塞队列满了,于是我们创建了二和三辅助线程,来完成我们比较忙碌的任务,当辅助线程一秒之后没有拿到任务的时候,二和三就已经结束了,零和一它是核心线程,仍然在线程池中保持状态,等待着阻塞队列中被添加任务。

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * @Author gongxuanzhangmeit@gmail.com
 */
public class Main {
    public static void main(String[] args) {
        MyThreadPool myThreadPool = new MyThreadPool(2, 2, 4, 1, TimeUnit.SECONDS, new ArrayBlockingQueue<>(2));
        for (int i = 0; i < 10; i++) {
            myThreadPool.execute(() -> {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                System.out.println(Thread.currentThread().getName());
            });
        }
        System.out.println("主线程没有被阻塞");
    }
}

我们可以把这些线程封装成类,比如 `Worker` 类继承 `Thread`,在 `run` 方法中写取任务执行的逻辑。核心线程和辅助线程可以共用同一个 `Worker` 类,只是创建时机不同。

下面是极端场景:在测试时,我们可以设置 `corePoolSize=2`,`maxPoolSize=4`,`queueCapacity=2`。当提交6个任务时,前2个由核心线程处理,队列放2个,剩下2个触发创建2个辅助线程。辅助线程执行完任务后,空闲1秒就自动结束。

public class MyThreadPool {
    private final int corePoolSize;
    private final int maxSize;
    private final int timeout;
    private final TimeUnit timeUnit;
    private final BlockingQueue<Runnable> blockingQueue;
    private final RejectHandle rejectHandle;

    public MyThreadPool(int corePoolSize, int maxSize, int timeout, TimeUnit timeUnit, BlockingQueue<Runnable> blockingQueue, RejectHandle rejectHandle) {
        this.corePoolSize = corePoolSize;
        this.maxSize = maxSize;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
        this.blockingQueue = blockingQueue;
        this.rejectHandle = rejectHandle;
    }

    List<Thread> coreList = new ArrayList<>();
    List<Thread> supportList = new ArrayList<>();

    void execute(Runnable command) {
        if (coreList.size() < corePoolSize) {
            Thread thread = new CoreThread();
            coreList.add(thread);
            thread.start();
        }
        if (blockingQueue.offer(command)) {
            return;
        }
        if (coreList.size() + supportList.size() < maxSize) {
            Thread thread = new SupportThread();
            supportList.add(thread);
            thread.start();
        }
        if (!blockingQueue.offer(command)) {
            rejectHandle.reject(command, this);
        }
    }
}

如果队列满且线程数已达上限,我们不能再简单抛异常,而应该提供更灵活的处理方式。于是我们定义一个拒绝策略接口:RejectHandle和ThrowRejectHandle

public class ThrowRejectHandle implements RejectHandle {
    @Override
    public void reject(Runnable rejectCommand, MyThreadPool threadPool) {
        throw new RuntimeException("阻塞队列满了!");
    }
}
public interface RejectHandle{
    void reject(Runnable rejectCommand,MyThreadPool threadPool);
}

我们吧这个拒绝策略导到main中:
MyThreadPool myThreadPool = new MyThreadPool(2, 2, 4, 1, TimeUnit.SECONDS, new ArrayBlockingQueue<>(2),new ThrowRejectHandle());

参数包括被拒绝的任务 `command` 和线程池实例,方便策略中获取上下文。

我们可以实现多种策略:

- `AbortPolicy`:直接抛异常;
- `DiscardPolicy`:直接丢弃任务;
- `DiscardOldestPolicy`:丢弃队列中最老的任务,然后重试提交;
- `CallerRunsPolicy`:让提交任务的线程自己执行任务。

public化final BlockingQueue<Runnable> blockingQueue;

比如 `DiscardOldestRejectedHandler` 实现如下:

```java
public void reject(Runnable command, MyThreadPool threadPool) {
    threadPool.blockingQueue.poll(); // 丢弃队首任务
    threadPool.execute(command);     // 重新提交当前任务
}
```

这样就不会抛异常,但可能丢失旧任务。

最后,我们思考几个问题:

1. 如何给线程池增加 `shutdown()` 功能?  
   可以遍历所有线程并调用 `interrupt()`,同时设置一个标志位,让工作线程在捕获到中断或异常后退出循环。

2. 面试官问“你怎么理解拒绝策略”?  
   可以回答:拒绝策略是线程池在资源耗尽时的兜底机制,体现了系统的容错和降级设计。不同的业务场景应选择不同的策略,比如关键任务用 `Abort`,非关键任务可用 `Discard` 或 `CallerRuns`。

3. JDK 线程池还有一个参数叫 `ThreadFactory`,它是干嘛的?  
   `ThreadFactory` 用于创建线程,可以统一设置线程名称、优先级、是否为守护线程等。通过自定义 `ThreadFactory`,可以更好地监控和调试线程池中的线程。 ​