手写HashMap(不含红黑树)

eve2333 发布于 23 天前 57 次阅读


好的,这里是一个新的工程。
今天我们实现一个哈希Map
我们这里只实现三个函数:putgetremove

我们先来看一下这三个函数的定义。
根据JDK的Map的定义,put是有一个返回值的,返回值是一个value。
它的意思是:如果当前的Map里面已经有了key对应的value,那么我们就用当前传入的value去把它覆盖掉,同时把之前的value返回;如果当前的HashMap里面没有这个key对应的value,我们就返回null。
也就是说,如果我们传进来的这个key和value把之前Map里的key-value对“挤出去”了,那么就把被挤出的value返回;如果没有被挤出去(即原本不存在),那么就返回null。

同样地,remove也是有返回值的:
如果我们当前的Map里面有这个key对应的value,那我们就把这个value返回;如果没有这个key对应的value,我们就直接返回null就可以了。


我们来思考一下:Map到底是什么?
Map是一个容器,容器里装着什么?
容器里装着的是key和value。
key和value是一对,它们一定是成对出现的,不可能单独存在一个key或单独一个value。

所以我们第一步应该把key和value关联起来。
那么我们怎么样把key和value关联起来呢?

我们可以写一个类,叫做Node(当然这个名字可以任意取,我这边先叫它Node)。
这个Node应该有什么成员变量呢?
它应该有一个key,也应该有一个value
当它存了一个key和一个value之后,我们就可以直接操作这个node,而不需要独立地对key和value进行操作。
我们把key和value组合在一起——这个Node就是装着一对key-value的容器。
而整个Map,则是装着一大堆这样的Node的容器。

现在我们通过一个Node类,成功地把传进来的key和value关联了起来。

接下来,我们需要一个容器把这些Node装起来。
为什么?因为我们在后续调用get的时候需要能够找到这些节点。

说到容器,我们自然而然就想到了数组。
于是我们可以在这定义一个数组:Node[],我们把它叫做table
那么这个数组我们定义多长呢?我们先随便写一个10。

但随着不断地put操作,我们的数据会越来越多,数组就需要动态扩容。
我们知道JDK中的ArrayList本质上就是一个动态数组。
所以为了简化实现,我们暂时先用ArrayList<Node>来替代数组,也叫做table

List<Node> table = new ArrayList<>();

OK,我们现在用这个ArrayList来代替固定长度的数组。

public class MyHashMap<K, V> {
    List<Node<K, V>> table = new ArrayList<>();

    public V put(K key, V value) {
        return null;
    }

    public V get(K key) {
        return null;
    }

    public V remove(K key) {
        return null;
    }

    class Node<K, V> {
        K key;
        V value;

        public Node(K key, V value) {
            this.key = key;
            this.value = value;
        }
    }
}

现在我们开始填充这三个核心函数。

第一步:实现 put

put的时候,我们首先要知道当前这个table里面有没有我们这个key。
如果有,只需要把这个value改成新的value即可;如果没有,就需要向table中新增一个Node

好的,第一步:检查当前table中是否已有该key。

我们可以遍历一下这个table
对于每一个Node,如果它的key和我们要插入的key相同,说明已经存在,我们就将它的value更新为新值,并返回旧的value。

但是注意一点:如果原来的table中已经存在这个key,那么我们要把旧的value返回回去。

所以我们在这里需要记录一下旧的value:

V oldValue = node.value;
node.value = value;
return oldValue;

如果遍历结束都没有找到匹配的key,说明原来不存在,我们就需要创建一个新的Node对象:

Node newNode = new Node(key, value);
table.add(newNode);
return null; // 表示之前没有这个key

这样,put函数就写完了。

public V put(K key, V value) {
    for (Node<K, V> kvNode = table) {
        if (kvNode.key.equals(key)) {
            V oldValue = kvNode.value;
            kvNode.value = value;
            return oldValue;
        }
    }
    Node<K, V> newNode = new Node<>(key, value);
    table.add(newNode);
    return null;
}

第二步:实现 get

get的操作逻辑其实和上面遍历的逻辑是一样的。
我们直接复用刚才的思路:

遍历table中的每个Node,判断其key是否与传入的key相等。
如果相等,就直接返回这个Nodevalue

如果遍历完都没有找到,就返回null

代码如下:

for (Node node : table) {
    if (node.key.equals(key)) {
        return node.value;
    }
}
return null;

第三步:实现 remove

remove也是类似的道理。
我们需要从数组中找到那个key与传入key相同的Node

但这一次我们不能只拿到Node本身,还需要知道它的索引(角标),因为我们是要从ArrayList中移除它。

于是我们可以使用table.size()来作为循环条件,通过table.get(i)逐个访问元素:

for (int i = 0; i < table.size(); i++) {
    Node node = table.get(i);
    if (node.key.equals(key)) {
        // 找到了,执行 remove 操作
        Node removedNode = table.remove(i); // remove 返回被删除的元素
        return removedNode.value;
    }
}
return null;

这里补充一点关于ArrayListremove机制的知识。

假设我们有一个长度为5的数组:[1, 2, 3, 4, 5],初始时后面都是null
但实际上ArrayList内部维护了一个size字段,表示当前有效元素的数量。

比如我们只添加了三个元素 [1, 2, 3],那么虽然底层数组长度可能是5,但size=3,只有前三个位置是可访问的。这就是ArrayList的本质。

当我们调用remove(1)(也就是删除值为2的那个元素)时,不仅仅是把那个位置设为null那么简单,而是要把后面的每一个元素都往前移动一位:

原数组:[1, 2, 3, 4, null] → 删除后变成 [1, 3, 4, 4, null],然后size--变为2。

不过还有一个小技巧:只要size改小了,size之后的元素即使不置为null也没关系,因为我们只能访问前size个元素。

举个例子:数组是 [1, 2, 3, 4, null]size=4
如果我们删除index=1的元素(即2),则:

  • 把index=2的元素(3)移到index=1;
  • 把index=3的元素(4)移到index=2;
  • 最后size--变成3。

此时index=3的位置是否为null无所谓,因为超出size的部分不会被访问。

这部分关于动态数组的移动细节,我们之后可以在手写ArrayList的视频中详细展开。
这里只是稍微提一下。


现在回到我们的remove实现:

我们通过索引i找到目标节点,调用table.remove(i),这个方法会返回被移除的Node,然后我们返回它的value即可。

OK,remove完成。

public V remove(K key) {
    for (int i = 0; i < table.size(); i++) {
        Node<K, V> kvNode = table.get(i);
        if (kvNode.key.equals(key)) {
            Node<K, V> removedNode = this.table.remove(i);
            return removedNode.value;
        }
    }
    return null;
}

写完这三个函数之后,我们应该再给我们的HashMap提供一个查询当前键值对数量的功能。
也就是说,增加一个返回int类型的函数,叫做size()

这个函数很简单,直接返回当前tablesize()即可:

public int size() {
    return table.size();
}

这样我们就对外提供了一个API,可以让外部知道当前HashMap中有多少个键值对。


为了验证正确性,我写了一个测试用例。
我们先看一下这段代码:

class MyHashMapTest {
    @Test
    void testApi() {
        MyHashMap<String, String> myHashMap = new MyHashMap<>();
        int count = 10;
        for (int i = 0; i < count; i++) {
            myHashMap.put(String.valueOf(i), String.valueOf(i));
        }

        assertEquals(expected: 10, myHashMap.size());

        for (int i = 0; i < count; i++) {
            assertEquals(String.valueOf(i), myHashMap.get(String.valueOf(i)));
        }

        myHashMap.remove("8");
        assertNull(myHashMap.get("8"));

        assertEquals(expected: count - 1, myHashMap.size());
    }
}

如果这个测试通过了,说明我们的代码基本是正确的。

我们运行一下测试……发现出了问题!

报错提示是remove失败了。
我们点进去看remove方法:

发现问题出在这里:我们写了 node.equals(key),但实际上应该是 node.key.equals(key)!

我们比较的是Node的key和传入的key是否相等,而不是拿整个Node去和key比较。

事实上早已经修改了

修改为:

if (node.key.equals(key))

然后再重新跑测试——OK,这次通过了!

说明我们的基础逻辑是没有问题的。


现在我们把count改成1万:

int count = 10000;

再跑一次测试——仍然能通过,但明显变慢了。

我们再加一个数量级,变成10万:

int count = 100000;

这次执行效率非常差,甚至长时间无法结束。

我们手动终止程序。

这说明我们的性能存在严重问题。

我们重新审视一下逻辑:

每次put时都需要遍历整个table,以查找是否存在相同的key。
这意味着put的时间复杂度是O(n),其中n是当前table中元素的个数。

同理,getremove也都需要遍历,时间复杂度也是O(n)。

而且,这个类明明叫“HashMap”,却和“哈希”没有任何关系!
我们应该叫它ArrayMap更合适一些。

真正的HashMap应该利用数组支持随机访问(Random Access)的特性:
如果我们可以通过key快速计算出它在数组中的位置(角标),就不需要遍历整个数组了。


如何做到这一点?

我们可以使用key的hashCode()来决定它在table中的存放位置。

例如,假设当前数组长度是10,某个key的hashCode()是11,
我们用 11 % 10 = 1 得到余数1,于是就把这个key-value放到数组下标为1的位置。

这就是哈希函数的基本思想:通过哈希值定位存储位置

但这里有个问题:不同的key可能产生相同的余数,比如:

  • hashCode=11 → 11 % 10 = 1
  • hashCode=21 → 21 % 10 = 1
  • hashCode=31 → 31 % 10 = 1

这就产生了哈希冲突

你可能会想:那我们把数组弄得很大不就行了?比如长度50、100?

但这不可靠,原因有两个:

  1. hashCode()本身是int类型(32位),最大约42亿种不同值,但对象是无限的,根据鸽巢原理,必然会出现冲突。
  2. 即使容量足够大,也无法避免不同对象hashCode()相同的情况(如字符串"FB"和"EA"的hashCode()可能相同)。

因此,靠扩大数组容量来解决冲突是不可行的。

我们需要一种处理冲突的方法——拉链法(Separate Chaining)


下面我用一个动画来解释什么是拉链法:

想象一个长度为10的数组,角标从0到9。

来了一个key="Tom",它的hashCode()经过取模运算后得到1,于是放在table[1]

接着key="Jerry",计算得角标为2,放在table[2]

再来一个key="Leo",它的哈希取模结果也是1,发生冲突。

这时我们不再覆盖,而是让table[1]成为一个链表头,把"Tom"和"Leo"串成一个链表。

查找时,先算出角标1,再沿着链表依次比较key,直到找到匹配项。

这种方式就是拉链法:每个数组位置挂一个链表,用来存储所有哈希到该位置的键值对。


我们回到代码,实现这个思路。

首先,table不再需要用ArrayList<Node>了,而是改为真正的数组:

Node[] table;

初始化时创建一个Node数组:

table = new Node[10]; // 初始长度设为10

然后我们重写put方法。


put之前,我们需要一个辅助函数,用来计算key对应在数组中的角标:

private int indexOf(Object key) {
    return Math.abs(key.hashCode()) % table.length;
}

注意:hashCode()可能是负数,所以取绝对值。

拿到角标后,我们就可以定位到table[index],也就是链表的头节点。

如果头节点为空(null),说明当前位置还没有任何元素,我们直接创建新节点放上去:

int index = indexOf(key);
if (table[index] == null) {
    table[index] = new Node(key, value);
    return null;
}

但如果头节点不为空,说明已经有链表存在,我们需要遍历链表:

  • 如果发现某个节点的key与当前key相等,则更新其value,并返回旧value;
  • 如果遍历到链表末尾仍未找到,则将新节点插入链表尾部。

具体实现如下:

public V put(K key, V value) {
    int keyIndex = indexOf(key);
    Node<K, V> kvNode = table[keyIndex];
    if (kvNode == null) {
        table[keyIndex] = new Node<>(key, value);
        return null;
    }

    while (true) {
        if (kvNode.key.equals(key)) {
            V oldValue = kvNode.value;
            kvNode.value = value;
            return oldValue;
        }
        if (kvNode.next == null) {
            kvNode.next = new Node<>(key, value);
            return null;
        }
        kvNode = kvNode.next;
    }
}

这样,put就完成了。


接下来是get函数。

逻辑类似:先计算角标,获取链表头,然后遍历链表:

public V get(K key) {
    int keyIndex = indexOf(key);
    Node<K, V> head = table[keyIndex];
    while (head != null) {
        if (head.key.equals(key)) {
            return head.value;
        }
        head = head.next;
    }
    return null;
}

remove也差不多。

我们仍然先计算角标,拿到链表头。

先判断头节点是否为null,如果是,直接返回null。

如果不是,先判断头节点是不是要删除的key:

if (head.key.equals(key)) {
    table[index] = head.next; // 头节点移除,指向下一个
    return head.value;
}

如果不是头节点,就需要遍历链表,找到目标节点及其前驱:

Node prev = head;
Node curr = head.next;
while (curr != null) {
    if (curr.key.equals(key)) {
        prev.next = curr.next; // 跳过当前节点
        return curr.value;
    }
    prev = curr;
    curr = curr.next;
}
return null;
public V remove(K key) {
    int keyIndex = indexOf(key);
    if (head == null) {
        return null;
    }
    if (head.key.equals(key)) {
        table[keyIndex] = head.next;
        return head.value;
    }
    Node<K, V> pre = head;
    Node<K, V> current = head.next;
    while (current != null) {
        if (current.key.equals(key)) {
            pre.next = current.next;
            return current.value;
        }
        pre = current;
        current = current.next;
    }
    return null;
}

最后实现size()函数。

但现在不能再用table.length了,因为那是数组长度,不是实际元素个数。
我们需要自己维护一个size变量:

private int size = 0;

put成功插入新节点时(即未覆盖已有key),size++
remove成功删除节点时,size--

注意:只有真正新增节点才增加size,如果只是更新已有key的value,则不增加。


现在我们回到测试,尝试插入10万个元素。

虽然逻辑正确,但性能依然很慢。

我们分析原因:当前table长度是10,即使平均分配,每个链表也有约1万条数据。
查找、插入、删除都要遍历长约1万的链表,时间复杂度接近O(n),性能极差。

解决方案:动态扩容


我们设定扩容条件:当size >= table.length * 0.75时触发扩容。
这个0.75称为负载因子(load factor)。

我们写一个函数:

private void resizeIfNecessary() {
    if (this.size < table.length * 0.75) {
        return;
    }
    Node<K, V>[] newTable = new Node[this.table.length * 2];
    for (Node<K, V> head : this.table) {
        if (head == null) {
            continue;
        }
        Node<K, V> current = head;
        while (current != null) {
            int newIndex = current.key.hashCode() % newTable.length;
            if (newTable[newIndex] == null) {
                newTable[newIndex] = current;
                current.next = null;
                current = current.next;
            } else {
                Node<K, V> next = current.next;
                newTable[newIndex].next = current;
                newTable[newIndex] = current;
                current = next;
            }
        }
    }
}

这就是经典的头插法(Head Insertion):
将原链表中的节点逐个取出,插入到新数组对应位置的链表头部。

注意:头插法会导致链表顺序反转。

我们在每次put成功size++新增节点后调用resizeIfNecessary()


再次运行10万条数据的测试——这次仅耗时约30ms,远快于之前。

说明通过扩容打散链表,显著提升了性能。


进一步优化:使用位运算替代取模。

我们知道,在二进制中,只有1 & 1 = 1,其余都为0。

如果我们能让table.length是2的幂次(如16、32、64),那么table.length - 1的二进制低位全为1。

例如:

  • 16 → 二进制 10000
  • 15 → 二进制 01111

此时,hashCode & (table.length - 1) 等价于 hashCode % table.length,但速度更快。

所以我们修改初始化:

table = new Node[16]; // 改为2的幂

并修改indexOf方法:

private int indexOf(Object key) {
    return key.hashCode() & (newtable.length - 1);
}

注意:这里不再需要Math.abs(),因为负数的补码与运算也能得到合法索引。

同时在扩容时也要保持新长度为2的幂(oldLength * 2自然满足)。

修改完成后运行测试——依然通过,说明逻辑正确。


我们还可以加日志观察扩容过程:

System.out.println("当前扩容了,扩容到:" + newTable.length);

运行后可以看到:table从16不断翻倍,最终达到262144(约26万),说明经历了多次扩容。


但还有一个问题:即使扩容了,也不能保证每个链表都很短。

如果哈希分布极不均匀(例如大量key的hashCode集中在某几个值),就会导致某些链表特别长,其他为空。

这时搜索性能又退化为O(n)。

为了解决这个问题,JDK 8引入了红黑树优化
当某个链表长度超过阈值(默认8)且数组长度大于64时,链表转为红黑树,使查找变为O(log n)。

不过在这里我不打算带着大家手写红黑树,因为实现复杂且偏离主题。

但你要明白:仅靠扩容无法解决极端哈希冲突带来的性能问题


最后我们讨论一个经典问题:为什么HashMap在多线程环境下可能造成死循环?

关键就在于我们使用的头插法

假设有两个线程同时触发扩容:

线程A正在迁移某个链表(比如table[i]上的链表),它取出第一个节点node1,计算其在新表中的位置,准备将其插入新表头。

但在插入前,CPU切换到线程B,线程B完成了整个链表的迁移,使用头插法改变了链表顺序。

当线程A恢复执行时,它继续迁移剩下的节点,但由于头插法的指针操作,可能导致链表形成环形结构——某个节点的next指向了自己或前面的节点。

一旦形成环,任何对该链表的操作(如getput)都会陷入无限循环。

这就是JDK 7 HashMap在并发环境下可能出现死循环的原因。


为此,JDK 8将头插法改为尾插法

迁移时总是把节点插入链表尾部,保证顺序不变,避免成环。

但尾插法也有代价:必须遍历到链表末尾才能插入,性能略低。


我个人的看法是:

  1. 头插法实现简单、性能高:只需修改头指针,无需遍历。
  2. 顺序变化无关紧要:HashMap本就不保证遍历顺序。
  3. HashMap本身就是线程不安全的类:在多线程环境使用本身就是错误做法,出现问题是使用者的责任。

正因为少数人错误使用,JDK被迫牺牲所有用户的性能(改为尾插法),让所有人承担扩容成本。
这在我看来是一件非常讽刺的事。


这是今天的思考题:
你可以给HashMap增加一个遍历的函数吗?