{"id":492,"date":"2024-10-28T16:17:21","date_gmt":"2024-10-28T08:17:21","guid":{"rendered":"https:\/\/eve2333.top\/?p=492"},"modified":"2024-10-28T16:17:21","modified_gmt":"2024-10-28T08:17:21","slug":"resnet18%e5%ae%9e%e6%88%98%e6%90%ad%e5%bb%ba","status":"publish","type":"post","link":"https:\/\/eve2333.top\/?p=492","title":{"rendered":"ResNet18\u5b9e\u6218\u642d\u5efa"},"content":{"rendered":"\n<p>model.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torchsummary import summary\n\n\n\nclass Residual(nn.Module):\n    def __init__(self, input_channels, num_channels, use_1conv=False, strides=1):\n        super(Residual, self).__init__()\n        self.ReLU = nn.ReLU()\n        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=3, padding=1, stride=strides)\n        self.conv2 = nn.Conv2d(in_channels=num_channels,  out_channels=num_channels, kernel_size=3, padding=1)\n        self.bn1 = nn.BatchNorm2d(num_channels)\n        self.bn2 = nn.BatchNorm2d(num_channels)\n        if use_1conv:\n            self.conv3 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=1, stride=strides)\n        else:\n            self.conv3 = None\n    def forward(self, x):\n        y = self.ReLU(self.bn1(self.conv1(x)))\n        y = self.bn2(self.conv2(y))\n        if self.conv3:\n            x = self.conv3(x)\n        y = self.ReLU(y+x)\n        return y\n\nclass ResNet18(nn.Module):\n    def __init__(self, Residual):\n        super(ResNet18, self).__init__()\n        self.b1 = nn.Sequential(\n            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),\n            nn.ReLU(),\n            nn.BatchNorm2d(64),\n            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n\n        self.b2 = nn.Sequential(Residual(64, 64, use_1conv=False, strides=1),\n                                Residual(64, 64, use_1conv=False, strides=1))\n\n        self.b3 = nn.Sequential(Residual(64, 128, use_1conv=True, strides=2),\n                                Residual(128, 128, use_1conv=False, strides=1))\n\n        self.b4 = nn.Sequential(Residual(128, 256, use_1conv=True, strides=2),\n                                Residual(256, 256, use_1conv=False, strides=1))\n\n        self.b5 = nn.Sequential(Residual(256, 512, use_1conv=True, strides=2),\n                                Residual(512, 512, use_1conv=False, strides=1))\n\n        self.b6 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),\n                                nn.Flatten(),\n                                nn.Linear(512, 10))\n\n\n\n    def forward(self, x):\n        x = self.b1(x)\n        x = self.b2(x)\n        x = self.b3(x)\n        x = self.b4(x)\n        x = self.b5(x)\n        x = self.b6(x)\n        return x\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = ResNet18(Residual).to(device)\n    print(summary(model, (1, 224, 224)))\n\n\n\n\n\n<\/code><\/pre>\n\n\n\n<p>model_test.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.utils.data as Data\nfrom torchvision import transforms\nfrom torchvision.datasets import FashionMNIST\nfrom model import Residual, ResNet18\n\n\n\ndef test_data_process():\n    test_data = FashionMNIST(root='.\/data',\n                              train=False,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=224), transforms.ToTensor()]),\n                              download=True)\n\n    test_dataloader = Data.DataLoader(dataset=test_data,\n                                       batch_size=1,\n                                       shuffle=True,\n                                       num_workers=0)\n    return test_dataloader\n\n\ndef test_model_process(model, test_dataloader):\n    # \u8bbe\u5b9a\u6d4b\u8bd5\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = \"cuda\" if torch.cuda.is_available() else 'cpu'\n\n    # \u8bb2\u6a21\u578b\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    test_corrects = 0.0\n    test_num = 0\n\n    # \u53ea\u8fdb\u884c\u524d\u5411\u4f20\u64ad\u8ba1\u7b97\uff0c\u4e0d\u8ba1\u7b97\u68af\u5ea6\uff0c\u4ece\u800c\u8282\u7701\u5185\u5b58\uff0c\u52a0\u5feb\u8fd0\u884c\u901f\u5ea6\n    with torch.no_grad():\n        for test_data_x, test_data_y in test_dataloader:\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u6d4b\u8bd5\u8bbe\u5907\u4e2d\n            test_data_x = test_data_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u6d4b\u8bd5\u8bbe\u5907\u4e2d\n            test_data_y = test_data_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n            model.eval()\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u6d4b\u8bd5\u6570\u636e\u96c6\uff0c\u8f93\u51fa\u4e3a\u5bf9\u6bcf\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u503c\n            output= model(test_data_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6test_corrects\u52a01\n            test_corrects += torch.sum(pre_lab == test_data_y.data)\n            # \u5c06\u6240\u6709\u7684\u6d4b\u8bd5\u6837\u672c\u8fdb\u884c\u7d2f\u52a0\n            test_num += test_data_x.size(0)\n\n    # \u8ba1\u7b97\u6d4b\u8bd5\u51c6\u786e\u7387\n    test_acc = test_corrects.double().item() \/ test_num\n    print(\"\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u4e3a\uff1a\", test_acc)\n\n\nif __name__ == \"__main__\":\n    # \u52a0\u8f7d\u6a21\u578b\n    model = ResNet18(Residual)\n    model.load_state_dict(torch.load('best_model.pth'))\n    # # \u5229\u7528\u73b0\u6709\u7684\u6a21\u578b\u8fdb\u884c\u6a21\u578b\u7684\u6d4b\u8bd5\n    test_dataloader = test_data_process()\n    test_model_process(model, test_dataloader)\n\n    # \u8bbe\u5b9a\u6d4b\u8bd5\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = \"cuda\" if torch.cuda.is_available() else 'cpu'\n    model = model.to(device)\n\n    classes = &#91;'T-shirt\/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n    with torch.no_grad():\n        for b_x, b_y in test_dataloader:\n            b_x = b_x.to(device)\n            b_y = b_y.to(device)\n\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u9a8c\u8bc1\u6a21\u578b\n            model.eval()\n            output = model(b_x)\n            pre_lab = torch.argmax(output, dim=1)\n            result = pre_lab.item()\n            label = b_y.item()\n            print(\"\u9884\u6d4b\u503c\uff1a\",  classes&#91;result], \"------\", \"\u771f\u5b9e\u503c\uff1a\", classes&#91;label])\n\n\n\n\n\n\n\n\n<\/code><\/pre>\n\n\n\n<p>model.train.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import copy\nimport time\n\nimport torch\nfrom torchvision.datasets import FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import Residual, ResNet18\nimport torch.nn as nn\nimport pandas as pd\n\n\ndef train_val_data_process():\n    train_data = FashionMNIST(root='.\/data',\n                              train=True,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=224), transforms.ToTensor()]),\n                              download=True)\n\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8*len(train_data)), round(0.2*len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=32,\n                                       shuffle=True,\n                                       num_workers=2)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n                                       batch_size=32,\n                                       shuffle=True,\n                                       num_workers=2)\n\n    return train_dataloader, val_dataloader\n\n\ndef train_model_process(model, train_dataloader, val_dataloader, num_epochs):\n    # \u8bbe\u5b9a\u8bad\u7ec3\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    # \u4f7f\u7528Adam\u4f18\u5316\u5668\uff0c\u5b66\u4e60\u7387\u4e3a0.001\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n    # \u635f\u5931\u51fd\u6570\u4e3a\u4ea4\u53c9\u71b5\u51fd\u6570\n    criterion = nn.CrossEntropyLoss()\n    # \u5c06\u6a21\u578b\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n    # \u590d\u5236\u5f53\u524d\u6a21\u578b\u7684\u53c2\u6570\n    best_model_wts = copy.deepcopy(model.state_dict())\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    # \u6700\u9ad8\u51c6\u786e\u5ea6\n    best_acc = 0.0\n    # \u8bad\u7ec3\u96c6\u635f\u5931\u5217\u8868\n    train_loss_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u635f\u5931\u5217\u8868\n    val_loss_all = &#91;]\n    # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    train_acc_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    val_acc_all = &#91;]\n    # \u5f53\u524d\u65f6\u95f4\n    since = time.time()\n\n    for epoch in range(num_epochs):\n        print(\"Epoch {}\/{}\".format(epoch, num_epochs-1))\n        print(\"-\"*10)\n\n        # \u521d\u59cb\u5316\u53c2\u6570\n        # \u8bad\u7ec3\u96c6\u635f\u5931\u51fd\u6570\n        train_loss = 0.0\n        # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\n        train_corrects = 0\n        # \u9a8c\u8bc1\u96c6\u635f\u5931\u51fd\u6570\n        val_loss = 0.0\n        # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\n        val_corrects = 0\n        # \u8bad\u7ec3\u96c6\u6837\u672c\u6570\u91cf\n        train_num = 0\n        # \u9a8c\u8bc1\u96c6\u6837\u672c\u6570\u91cf\n        val_num = 0\n\n        # \u5bf9\u6bcf\u4e00\u4e2amini-batch\u8bad\u7ec3\u548c\u8ba1\u7b97\n        for step, (b_x, b_y) in enumerate(train_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bad\u7ec3\u6a21\u5f0f\n            model.train()\n\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch\uff0c\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u8ba1\u7b97\u6bcf\u4e00\u4e2abatch\u7684\u635f\u5931\u51fd\u6570\n            loss = criterion(output, b_y)\n\n            # \u5c06\u68af\u5ea6\u521d\u59cb\u5316\u4e3a0\n            optimizer.zero_grad()\n            # \u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\n            loss.backward()\n            # \u6839\u636e\u7f51\u7edc\u53cd\u5411\u4f20\u64ad\u7684\u68af\u5ea6\u4fe1\u606f\u6765\u66f4\u65b0\u7f51\u7edc\u7684\u53c2\u6570\uff0c\u4ee5\u8d77\u5230\u964d\u4f4eloss\u51fd\u6570\u8ba1\u7b97\u503c\u7684\u4f5c\u7528\n            optimizer.step()\n            # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            train_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6train_corrects\u52a01\n            train_corrects += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u8bad\u7ec3\u7684\u6837\u672c\u6570\u91cf\n            train_num += b_x.size(0)\n        for step, (b_x, b_y) in enumerate(val_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n            model.eval()\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch\uff0c\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u8ba1\u7b97\u6bcf\u4e00\u4e2abatch\u7684\u635f\u5931\u51fd\u6570\n            loss = criterion(output, b_y)\n\n\n            # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            val_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6train_corrects\u52a01\n            val_corrects += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u9a8c\u8bc1\u7684\u6837\u672c\u6570\u91cf\n            val_num += b_x.size(0)\n\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u6bcf\u4e00\u6b21\u8fed\u4ee3\u7684loss\u503c\u548c\u51c6\u786e\u7387\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u8bad\u7ec3\u96c6\u7684loss\u503c\n        train_loss_all.append(train_loss \/ train_num)\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u8bad\u7ec3\u96c6\u7684\u51c6\u786e\u7387\n        train_acc_all.append(train_corrects.double().item() \/ train_num)\n\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u9a8c\u8bc1\u96c6\u7684loss\u503c\n        val_loss_all.append(val_loss \/ val_num)\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u9a8c\u8bc1\u96c6\u7684\u51c6\u786e\u7387\n        val_acc_all.append(val_corrects.double().item() \/ val_num)\n\n        print(\"{} train loss:{:.4f} train acc: {:.4f}\".format(epoch, train_loss_all&#91;-1], train_acc_all&#91;-1]))\n        print(\"{} val loss:{:.4f} val acc: {:.4f}\".format(epoch, val_loss_all&#91;-1], val_acc_all&#91;-1]))\n\n        if val_acc_all&#91;-1] > best_acc:\n            # \u4fdd\u5b58\u5f53\u524d\u6700\u9ad8\u51c6\u786e\u5ea6\n            best_acc = val_acc_all&#91;-1]\n            # \u4fdd\u5b58\u5f53\u524d\u6700\u9ad8\u51c6\u786e\u5ea6\u7684\u6a21\u578b\u53c2\u6570\n            best_model_wts = copy.deepcopy(model.state_dict())\n\n        # \u8ba1\u7b97\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7684\u8017\u65f6\n        time_use = time.time() - since\n        print(\"\u8bad\u7ec3\u548c\u9a8c\u8bc1\u8017\u8d39\u7684\u65f6\u95f4{:.0f}m{:.0f}s\".format(time_use\/\/60, time_use%60))\n\n    # \u9009\u62e9\u6700\u4f18\u53c2\u6570\uff0c\u4fdd\u5b58\u6700\u4f18\u53c2\u6570\u7684\u6a21\u578b\n    model.load_state_dict(best_model_wts)\n    # torch.save(model.load_state_dict(best_model_wts), \"C:\/Users\/86159\/Desktop\/LeNet\/best_model.pth\")\n    torch.save(best_model_wts, \"C:\/Users\/86159\/Desktop\/ResNet18\/best_model.pth\")\n\n\n    train_process = pd.DataFrame(data={\"epoch\":range(num_epochs),\n                                       \"train_loss_all\":train_loss_all,\n                                       \"val_loss_all\":val_loss_all,\n                                       \"train_acc_all\":train_acc_all,\n                                       \"val_acc_all\":val_acc_all,})\n\n    return train_process\n\n\ndef matplot_acc_loss(train_process):\n    # \u663e\u793a\u6bcf\u4e00\u6b21\u8fed\u4ee3\u540e\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u7684\u635f\u5931\u51fd\u6570\u548c\u51c6\u786e\u7387\n    plt.figure(figsize=(12, 4))\n    plt.subplot(1, 2, 1)\n    plt.plot(train_process&#91;'epoch'], train_process.train_loss_all, \"ro-\", label=\"Train loss\")\n    plt.plot(train_process&#91;'epoch'], train_process.val_loss_all, \"bs-\", label=\"Val loss\")\n    plt.legend()\n    plt.xlabel(\"epoch\")\n    plt.ylabel(\"Loss\")\n    plt.subplot(1, 2, 2)\n    plt.plot(train_process&#91;'epoch'], train_process.train_acc_all, \"ro-\", label=\"Train acc\")\n    plt.plot(train_process&#91;'epoch'], train_process.val_acc_all, \"bs-\", label=\"Val acc\")\n    plt.xlabel(\"epoch\")\n    plt.ylabel(\"acc\")\n    plt.legend()\n    plt.show()\n\n\nif __name__ == '__main__':\n    # \u52a0\u8f7d\u9700\u8981\u7684\u6a21\u578b\n    ResNet18 = ResNet18(Residual)\n    # \u52a0\u8f7d\u6570\u636e\u96c6\n    train_data, val_data = train_val_data_process()\n    # \u5229\u7528\u73b0\u6709\u7684\u6a21\u578b\u8fdb\u884c\u6a21\u578b\u7684\u8bad\u7ec3\n    train_process = train_model_process(ResNet18, train_data, val_data, num_epochs=20)\n    matplot_acc_loss(train_process)<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>model.py model_test.py model.train.py<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"emotion":"","emotion_color":"","title_style":"","license":"","footnotes":""},"categories":[3],"tags":[6,9],"class_list":["post-492","post","type-post","status-publish","format-standard","hentry","category-3","tag-cv","tag-9"],"_links":{"self":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/492","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=492"}],"version-history":[{"count":0,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/492\/revisions"}],"wp:attachment":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=492"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=492"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=492"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}