{"id":494,"date":"2024-10-28T16:22:08","date_gmt":"2024-10-28T08:22:08","guid":{"rendered":"https:\/\/eve2333.top\/?p=494"},"modified":"2024-10-28T16:22:08","modified_gmt":"2024-10-28T08:22:08","slug":"resnet18%e5%ae%9e%e6%88%98%e4%bd%a9%e6%88%b4%e5%8f%a3%e7%bd%a9","status":"publish","type":"post","link":"https:\/\/eve2333.top\/?p=494","title":{"rendered":"ResNet18\u5b9e\u6218\u4f69\u6234\u53e3\u7f69"},"content":{"rendered":"\n<p>data_partitioning.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import os\nfrom shutil import copy\nimport random\n\n\ndef mkfile(file):\n    if not os.path.exists(file):\n        os.makedirs(file)\n\n\n# \u83b7\u53d6data\u6587\u4ef6\u5939\u4e0b\u6240\u6709\u6587\u4ef6\u5939\u540d\uff08\u5373\u9700\u8981\u5206\u7c7b\u7684\u7c7b\u540d\uff09\nfile_path = 'dataset'\nflower_class = &#91;cla for cla in os.listdir(file_path)]\n\n# \u521b\u5efa \u8bad\u7ec3\u96c6train \u6587\u4ef6\u5939\uff0c\u5e76\u7531\u7c7b\u540d\u5728\u5176\u76ee\u5f55\u4e0b\u521b\u5efa5\u4e2a\u5b50\u76ee\u5f55\nmkfile('data\/train')\nfor cla in flower_class:\n    mkfile('data\/train\/' + cla)\n\n# \u521b\u5efa \u9a8c\u8bc1\u96c6val \u6587\u4ef6\u5939\uff0c\u5e76\u7531\u7c7b\u540d\u5728\u5176\u76ee\u5f55\u4e0b\u521b\u5efa\u5b50\u76ee\u5f55\nmkfile('data\/test')\nfor cla in flower_class:\n    mkfile('data\/test\/' + cla)\n\n# \u5212\u5206\u6bd4\u4f8b\uff0c\u8bad\u7ec3\u96c6 : \u6d4b\u8bd5\u96c6 = 9 : 1\nsplit_rate = 0.1\n\n# \u904d\u5386\u6240\u6709\u7c7b\u522b\u7684\u5168\u90e8\u56fe\u50cf\u5e76\u6309\u6bd4\u4f8b\u5206\u6210\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\nfor cla in flower_class:\n    cla_path = file_path + '\/' + cla + '\/'  # \u67d0\u4e00\u7c7b\u522b\u7684\u5b50\u76ee\u5f55\n    images = os.listdir(cla_path)  # iamges \u5217\u8868\u5b58\u50a8\u4e86\u8be5\u76ee\u5f55\u4e0b\u6240\u6709\u56fe\u50cf\u7684\u540d\u79f0\n    num = len(images)\n    eval_index = random.sample(images, k=int(num * split_rate))  # \u4eceimages\u5217\u8868\u4e2d\u968f\u673a\u62bd\u53d6 k \u4e2a\u56fe\u50cf\u540d\u79f0\n    for index, image in enumerate(images):\n        # eval_index \u4e2d\u4fdd\u5b58\u9a8c\u8bc1\u96c6val\u7684\u56fe\u50cf\u540d\u79f0\n        if image in eval_index:\n            image_path = cla_path + image\n            new_path = 'data\/test\/' + cla\n            copy(image_path, new_path)  # \u5c06\u9009\u4e2d\u7684\u56fe\u50cf\u590d\u5236\u5230\u65b0\u8def\u5f84\n\n        # \u5176\u4f59\u7684\u56fe\u50cf\u4fdd\u5b58\u5728\u8bad\u7ec3\u96c6train\u4e2d\n        else:\n            image_path = cla_path + image\n            new_path = 'data\/train\/' + cla\n            copy(image_path, new_path)\n        print(\"\\r&#91;{}] processing &#91;{}\/{}]\".format(cla, index + 1, num), end=\"\")  # processing bar\n    print()\n\nprint(\"processing done!\")<\/code><\/pre>\n\n\n\n<p>mean_std.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from PIL import Image\nimport os\nimport numpy as np\n\n# \u6587\u4ef6\u5939\u8def\u5f84\uff0c\u5305\u542b\u6240\u6709\u56fe\u7247\u6587\u4ef6\nfolder_path = 'dataset'\n\n# \u521d\u59cb\u5316\u7d2f\u79ef\u53d8\u91cf\ntotal_pixels = 0\nsum_normalized_pixel_values = np.zeros(3)  # \u5982\u679c\u662fRGB\u56fe\u50cf\uff0c\u9700\u8981\u4e09\u4e2a\u901a\u9053\u7684\u5747\u503c\u548c\u65b9\u5dee\n\n# \u904d\u5386\u6587\u4ef6\u5939\u4e2d\u7684\u56fe\u7247\u6587\u4ef6\nfor root, dirs, files in os.walk(folder_path):\n    for filename in files:\n        if filename.endswith(('.jpg', '.jpeg', '.png', '.bmp')):  # \u53ef\u6839\u636e\u5b9e\u9645\u60c5\u51b5\u6dfb\u52a0\u5176\u4ed6\u683c\u5f0f\n            image_path = os.path.join(root, filename)\n            image = Image.open(image_path)\n            image_array = np.array(image)\n\n            # \u5f52\u4e00\u5316\u50cf\u7d20\u503c\u52300-1\u4e4b\u95f4\n            normalized_image_array = image_array \/ 255.0\n\n            # print(image_path)\n            # print(normalized_image_array.shape)\n            # \u7d2f\u79ef\u5f52\u4e00\u5316\u540e\u7684\u50cf\u7d20\u503c\u548c\u50cf\u7d20\u6570\u91cf\n            total_pixels += normalized_image_array.size\n            sum_normalized_pixel_values += np.sum(normalized_image_array, axis=(0, 1))\n\n# \u8ba1\u7b97\u5747\u503c\u548c\u65b9\u5dee\nmean = sum_normalized_pixel_values \/ total_pixels\n\n\nsum_squared_diff = np.zeros(3)\nfor root, dirs, files in os.walk(folder_path):\n    for filename in files:\n        if filename.endswith(('.jpg', '.jpeg', '.png', '.bmp')):\n            image_path = os.path.join(root, filename)\n            image = Image.open(image_path)\n            image_array = np.array(image)\n            # \u5f52\u4e00\u5316\u50cf\u7d20\u503c\u52300-1\u4e4b\u95f4\n            normalized_image_array = image_array \/ 255.0\n            # print(normalized_image_array.shape)\n            # print(mean.shape)\n            # print(image_path)\n\n            try:\n                diff = (normalized_image_array - mean) ** 2\n                sum_squared_diff += np.sum(diff, axis=(0, 1))\n            except:\n                print(f\"\u6355\u83b7\u5230\u81ea\u5b9a\u4e49\u5f02\u5e38\")\n            # diff = (normalized_image_array - mean) ** 2\n            # sum_squared_diff += np.sum(diff, axis=(0, 1))\n\nvariance = sum_squared_diff \/ total_pixels\n\nprint(\"Mean:\", mean)\nprint(\"Variance:\", variance)\n<\/code><\/pre>\n\n\n\n<p>model.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torchsummary import summary\n\n\n\nclass Residual(nn.Module):\n    def __init__(self, input_channels, num_channels, use_1conv=False, strides=1):\n        super(Residual, self).__init__()\n        self.ReLU = nn.ReLU()\n        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=3, padding=1, stride=strides)\n        self.conv2 = nn.Conv2d(in_channels=num_channels,  out_channels=num_channels, kernel_size=3, padding=1)\n        self.bn1 = nn.BatchNorm2d(num_channels)\n        self.bn2 = nn.BatchNorm2d(num_channels)\n        if use_1conv:\n            self.conv3 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels, kernel_size=1, stride=strides)\n        else:\n            self.conv3 = None\n    def forward(self, x):\n        y = self.ReLU(self.bn1(self.conv1(x)))\n        y = self.bn2(self.conv2(y))\n        if self.conv3:\n            x = self.conv3(x)\n        y = self.ReLU(y+x)\n        return y\n\nclass ResNet18(nn.Module):\n    def __init__(self, Residual):\n        super(ResNet18, self).__init__()\n        self.b1 = nn.Sequential(\n            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),\n            nn.ReLU(),\n            nn.BatchNorm2d(64),\n            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n\n        self.b2 = nn.Sequential(Residual(64, 64, use_1conv=False, strides=1),\n                                Residual(64, 64, use_1conv=False, strides=1))\n\n        self.b3 = nn.Sequential(Residual(64, 128, use_1conv=True, strides=2),\n                                Residual(128, 128, use_1conv=False, strides=1))\n\n        self.b4 = nn.Sequential(Residual(128, 256, use_1conv=True, strides=2),\n                                Residual(256, 256, use_1conv=False, strides=1))\n\n        self.b5 = nn.Sequential(Residual(256, 512, use_1conv=True, strides=2),\n                                Residual(512, 512, use_1conv=False, strides=1))\n\n        self.b6 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),\n                                nn.Flatten(),\n                                nn.Linear(512, 2))\n\n\n\n    def forward(self, x):\n        x = self.b1(x)\n        x = self.b2(x)\n        x = self.b3(x)\n        x = self.b4(x)\n        x = self.b5(x)\n        x = self.b6(x)\n        return x\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = ResNet18(Residual).to(device)\n    print(summary(model, (1, 224, 224)))\n\n\n\n\n\n<\/code><\/pre>\n\n\n\n<p>model_test.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.utils.data as Data\nfrom torchvision import transforms\nfrom torchvision.datasets import FashionMNIST\nfrom model import ResNet18, Residual\nfrom torchvision.datasets import ImageFolder\nfrom PIL import Image\n\ndef test_data_process():\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u7684\u8def\u5f84\n    ROOT_TRAIN = r'data\\test'\n\n    normalize = transforms.Normalize(&#91;0.17263485, 0.15147247, 0.14267451], &#91;0.0736155,  0.06216329, 0.05930814])\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u5904\u7406\u65b9\u6cd5\u53d8\u91cf\n    test_transform = transforms.Compose(&#91;transforms.Resize((224, 224)), transforms.ToTensor(), normalize])\n    # \u52a0\u8f7d\u6570\u636e\u96c6\n    test_data = ImageFolder(ROOT_TRAIN, transform=test_transform)\n\n    test_dataloader = Data.DataLoader(dataset=test_data,\n                                       batch_size=1,\n                                       shuffle=True,\n                                       num_workers=0)\n    return test_dataloader\n\n\ndef test_model_process(model, test_dataloader):\n    # \u8bbe\u5b9a\u6d4b\u8bd5\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = \"cuda\" if torch.cuda.is_available() else 'cpu'\n\n    # \u8bb2\u6a21\u578b\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    test_corrects = 0.0\n    test_num = 0\n\n    # \u53ea\u8fdb\u884c\u524d\u5411\u4f20\u64ad\u8ba1\u7b97\uff0c\u4e0d\u8ba1\u7b97\u68af\u5ea6\uff0c\u4ece\u800c\u8282\u7701\u5185\u5b58\uff0c\u52a0\u5feb\u8fd0\u884c\u901f\u5ea6\n    with torch.no_grad():\n        for test_data_x, test_data_y in test_dataloader:\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u6d4b\u8bd5\u8bbe\u5907\u4e2d\n            test_data_x = test_data_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u6d4b\u8bd5\u8bbe\u5907\u4e2d\n            test_data_y = test_data_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n            model.eval()\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u6d4b\u8bd5\u6570\u636e\u96c6\uff0c\u8f93\u51fa\u4e3a\u5bf9\u6bcf\u4e2a\u6837\u672c\u7684\u9884\u6d4b\u503c\n            output= model(test_data_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6test_corrects\u52a01\n            test_corrects += torch.sum(pre_lab == test_data_y.data)\n            # \u5c06\u6240\u6709\u7684\u6d4b\u8bd5\u6837\u672c\u8fdb\u884c\u7d2f\u52a0\n            test_num += test_data_x.size(0)\n\n    # \u8ba1\u7b97\u6d4b\u8bd5\u51c6\u786e\u7387\n    test_acc = test_corrects.double().item() \/ test_num\n    print(\"\u6d4b\u8bd5\u7684\u51c6\u786e\u7387\u4e3a\uff1a\", test_acc)\n\n\nif __name__ == \"__main__\":\n    # \u52a0\u8f7d\u6a21\u578b\n    model = ResNet18(Residual)\n    model.load_state_dict(torch.load('best_model.pth'))\n    # # \u5229\u7528\u73b0\u6709\u7684\u6a21\u578b\u8fdb\u884c\u6a21\u578b\u7684\u6d4b\u8bd5\n    test_dataloader = test_data_process()\n    test_model_process(model, test_dataloader)\n\n    # \u8bbe\u5b9a\u6d4b\u8bd5\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = \"cuda\" if torch.cuda.is_available() else 'cpu'\n    model = model.to(device)\n    classes = &#91;'\u6234\u53e3\u7f69', '\u4e0d\u5e26\u53e3\u7f69']\n    with torch.no_grad():\n        for b_x, b_y in test_dataloader:\n            b_x = b_x.to(device)\n            b_y = b_y.to(device)\n\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u9a8c\u8bc1\u6a21\u578b\n            model.eval()\n            output = model(b_x)\n            pre_lab = torch.argmax(output, dim=1)\n            result = pre_lab.item()\n            label = b_y.item()\n            print(\"\u9884\u6d4b\u503c\uff1a\",  classes&#91;result], \"------\", \"\u771f\u5b9e\u503c\uff1a\", classes&#91;label])\n\n\n    image = Image.open('no_mask.jfif')\n\n    normalize = transforms.Normalize(&#91;0.17263485, 0.15147247, 0.14267451], &#91;0.0736155,  0.06216329, 0.05930814])\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u5904\u7406\u65b9\u6cd5\u53d8\u91cf\n    test_transform = transforms.Compose(&#91;transforms.Resize((224, 224)), transforms.ToTensor(), normalize])\n    image = test_transform(image)\n\n    # \u6dfb\u52a0\u6279\u6b21\u7ef4\u5ea6\n    image = image.unsqueeze(0)\n\n    with torch.no_grad():\n        model.eval()\n        image = image.to(device)\n        output = model(image)\n        pre_lab = torch.argmax(output, dim=1)\n        result = pre_lab.item()\n    print(\"\u9884\u6d4b\u503c\uff1a\",  classes&#91;result])\n\n\n\n<\/code><\/pre>\n\n\n\n<p>model_train.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import copy\nimport time\n\nimport torch\nfrom torchvision.datasets import ImageFolder\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import ResNet18, Residual\nimport torch.nn as nn\nimport pandas as pd\n\ndef train_val_data_process():\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u7684\u8def\u5f84\n    ROOT_TRAIN = r'data\\train'\n\n    normalize = transforms.Normalize(&#91;0.17263485, 0.15147247, 0.14267451], &#91;0.0736155,  0.06216329, 0.05930814])\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u5904\u7406\u65b9\u6cd5\u53d8\u91cf\n    train_transform = transforms.Compose(&#91;transforms.Resize((224, 224)), transforms.ToTensor(), normalize])\n    # \u52a0\u8f7d\u6570\u636e\u96c6\n    train_data = ImageFolder(ROOT_TRAIN, transform=train_transform)\n\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8*len(train_data)), round(0.2*len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=32,\n                                       shuffle=True,\n                                       num_workers=2)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n                                       batch_size=32,\n                                       shuffle=True,\n                                       num_workers=2)\n\n    return train_dataloader, val_dataloader\n\n\n\ndef train_model_process(model, train_dataloader, val_dataloader, num_epochs):\n    # \u8bbe\u5b9a\u8bad\u7ec3\u6240\u7528\u5230\u7684\u8bbe\u5907\uff0c\u6709GPU\u7528GPU\u6ca1\u6709GPU\u7528CPU\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    # \u4f7f\u7528Adam\u4f18\u5316\u5668\uff0c\u5b66\u4e60\u7387\u4e3a0.001\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n    # \u635f\u5931\u51fd\u6570\u4e3a\u4ea4\u53c9\u71b5\u51fd\u6570\n    criterion = nn.CrossEntropyLoss()\n    # \u5c06\u6a21\u578b\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n    # \u590d\u5236\u5f53\u524d\u6a21\u578b\u7684\u53c2\u6570\n    best_model_wts = copy.deepcopy(model.state_dict())\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    # \u6700\u9ad8\u51c6\u786e\u5ea6\n    best_acc = 0.0\n    # \u8bad\u7ec3\u96c6\u635f\u5931\u5217\u8868\n    train_loss_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u635f\u5931\u5217\u8868\n    val_loss_all = &#91;]\n    # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    train_acc_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    val_acc_all = &#91;]\n    # \u5f53\u524d\u65f6\u95f4\n    since = time.time()\n\n    for epoch in range(num_epochs):\n        print(\"Epoch {}\/{}\".format(epoch, num_epochs-1))\n        print(\"-\"*10)\n\n        # \u521d\u59cb\u5316\u53c2\u6570\n        # \u8bad\u7ec3\u96c6\u635f\u5931\u51fd\u6570\n        train_loss = 0.0\n        # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\n        train_corrects = 0\n        # \u9a8c\u8bc1\u96c6\u635f\u5931\u51fd\u6570\n        val_loss = 0.0\n        # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\n        val_corrects = 0\n        # \u8bad\u7ec3\u96c6\u6837\u672c\u6570\u91cf\n        train_num = 0\n        # \u9a8c\u8bc1\u96c6\u6837\u672c\u6570\u91cf\n        val_num = 0\n\n        # \u5bf9\u6bcf\u4e00\u4e2amini-batch\u8bad\u7ec3\u548c\u8ba1\u7b97\n        for step, (b_x, b_y) in enumerate(train_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bad\u7ec3\u6a21\u5f0f\n            model.train()\n\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch\uff0c\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u8ba1\u7b97\u6bcf\u4e00\u4e2abatch\u7684\u635f\u5931\u51fd\u6570\n            loss = criterion(output, b_y)\n\n            # \u5c06\u68af\u5ea6\u521d\u59cb\u5316\u4e3a0\n            optimizer.zero_grad()\n            # \u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\n            loss.backward()\n            # \u6839\u636e\u7f51\u7edc\u53cd\u5411\u4f20\u64ad\u7684\u68af\u5ea6\u4fe1\u606f\u6765\u66f4\u65b0\u7f51\u7edc\u7684\u53c2\u6570\uff0c\u4ee5\u8d77\u5230\u964d\u4f4eloss\u51fd\u6570\u8ba1\u7b97\u503c\u7684\u4f5c\u7528\n            optimizer.step()\n            # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            train_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6train_corrects\u52a01\n            train_corrects += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u8bad\u7ec3\u7684\u6837\u672c\u6570\u91cf\n            train_num += b_x.size(0)\n        for step, (b_x, b_y) in enumerate(val_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n            model.eval()\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch\uff0c\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n            # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            pre_lab = torch.argmax(output, dim=1)\n            # \u8ba1\u7b97\u6bcf\u4e00\u4e2abatch\u7684\u635f\u5931\u51fd\u6570\n            loss = criterion(output, b_y)\n\n\n            # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            val_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6train_corrects\u52a01\n            val_corrects += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u9a8c\u8bc1\u7684\u6837\u672c\u6570\u91cf\n            val_num += b_x.size(0)\n\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u6bcf\u4e00\u6b21\u8fed\u4ee3\u7684loss\u503c\u548c\u51c6\u786e\u7387\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u8bad\u7ec3\u96c6\u7684loss\u503c\n        train_loss_all.append(train_loss \/ train_num)\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u8bad\u7ec3\u96c6\u7684\u51c6\u786e\u7387\n        train_acc_all.append(train_corrects.double().item() \/ train_num)\n\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u9a8c\u8bc1\u96c6\u7684loss\u503c\n        val_loss_all.append(val_loss \/ val_num)\n        # \u8ba1\u7b97\u5e76\u4fdd\u5b58\u9a8c\u8bc1\u96c6\u7684\u51c6\u786e\u7387\n        val_acc_all.append(val_corrects.double().item() \/ val_num)\n\n        print(\"{} train loss:{:.4f} train acc: {:.4f}\".format(epoch, train_loss_all&#91;-1], train_acc_all&#91;-1]))\n        print(\"{} val loss:{:.4f} val acc: {:.4f}\".format(epoch, val_loss_all&#91;-1], val_acc_all&#91;-1]))\n\n        if val_acc_all&#91;-1] > best_acc:\n            # \u4fdd\u5b58\u5f53\u524d\u6700\u9ad8\u51c6\u786e\u5ea6\n            best_acc = val_acc_all&#91;-1]\n            # \u4fdd\u5b58\u5f53\u524d\u6700\u9ad8\u51c6\u786e\u5ea6\u7684\u6a21\u578b\u53c2\u6570\n            best_model_wts = copy.deepcopy(model.state_dict())\n\n        # \u8ba1\u7b97\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7684\u8017\u65f6\n        time_use = time.time() - since\n        print(\"\u8bad\u7ec3\u548c\u9a8c\u8bc1\u8017\u8d39\u7684\u65f6\u95f4{:.0f}m{:.0f}s\".format(time_use\/\/60, time_use%60))\n\n    # \u9009\u62e9\u6700\u4f18\u53c2\u6570\uff0c\u4fdd\u5b58\u6700\u4f18\u53c2\u6570\u7684\u6a21\u578b\n    model.load_state_dict(best_model_wts)\n    # torch.save(model.load_state_dict(best_model_wts), \"C:\/Users\/86159\/Desktop\/LeNet\/best_model.pth\")\n    torch.save(best_model_wts, \"C:\/Users\/86159\/Desktop\/ResNet18-1\/best_model.pth\")\n\n\n    train_process = pd.DataFrame(data={\"epoch\":range(num_epochs),\n                                       \"train_loss_all\":train_loss_all,\n                                       \"val_loss_all\":val_loss_all,\n                                       \"train_acc_all\":train_acc_all,\n                                       \"val_acc_all\":val_acc_all,})\n\n    return train_process\n\n\ndef matplot_acc_loss(train_process):\n    # \u663e\u793a\u6bcf\u4e00\u6b21\u8fed\u4ee3\u540e\u7684\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u7684\u635f\u5931\u51fd\u6570\u548c\u51c6\u786e\u7387\n    plt.figure(figsize=(12, 4))\n    plt.subplot(1, 2, 1)\n    plt.plot(train_process&#91;'epoch'], train_process.train_loss_all, \"ro-\", label=\"Train loss\")\n    plt.plot(train_process&#91;'epoch'], train_process.val_loss_all, \"bs-\", label=\"Val loss\")\n    plt.legend()\n    plt.xlabel(\"epoch\")\n    plt.ylabel(\"Loss\")\n    plt.subplot(1, 2, 2)\n    plt.plot(train_process&#91;'epoch'], train_process.train_acc_all, \"ro-\", label=\"Train acc\")\n    plt.plot(train_process&#91;'epoch'], train_process.val_acc_all, \"bs-\", label=\"Val acc\")\n    plt.xlabel(\"epoch\")\n    plt.ylabel(\"acc\")\n    plt.legend()\n    plt.show()\n\n\nif __name__ == '__main__':\n    # \u52a0\u8f7d\u9700\u8981\u7684\u6a21\u578b\n    ResNet18 = ResNet18(Residual)\n    # \u52a0\u8f7d\u6570\u636e\u96c6\n    train_data, val_data = train_val_data_process()\n    # \u5229\u7528\u73b0\u6709\u7684\u6a21\u578b\u8fdb\u884c\u6a21\u578b\u7684\u8bad\u7ec3\n    train_process = train_model_process(ResNet18, train_data, val_data, num_epochs=50)\n    matplot_acc_loss(train_process)<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>data_partitioning.py mean_std.py model.py model_test.py model_tra &#8230;<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"emotion":"","emotion_color":"","title_style":"","license":"","footnotes":""},"categories":[3],"tags":[6,9],"class_list":["post-494","post","type-post","status-publish","format-standard","hentry","category-3","tag-cv","tag-9"],"_links":{"self":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/494","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=494"}],"version-history":[{"count":0,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/494\/revisions"}],"wp:attachment":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=494"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=494"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=494"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}