{"id":444,"date":"2024-10-25T14:17:55","date_gmt":"2024-10-25T06:17:55","guid":{"rendered":"https:\/\/eve2333.top\/?p=444"},"modified":"2024-10-25T14:17:55","modified_gmt":"2024-10-25T06:17:55","slug":"lenet%e4%b8%8ealexnet%e5%ae%9e%e6%88%98","status":"publish","type":"post","link":"https:\/\/eve2333.top\/?p=444","title":{"rendered":"LeNet\u4e0eAlexNet\u5b9e\u6218"},"content":{"rendered":"\n<p>\u642d\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u7684\u8fc7\u7a0b\u53ef\u4ee5\u603b\u7ed3\u4e3a\u4ee5\u4e0b\u6b65\u9aa4\uff1a<\/p>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u521d\u59cb\u5316\u7f51\u7edc\u5c42\u4e0e\u53c2\u6570<\/strong>\uff1a\u5728\u6a21\u578b\u7684\u521d\u59cb\u5316\u9636\u6bb5\uff0c\u9700\u8981\u5b9a\u4e49\u5404\u79cd\u7f51\u7edc\u5c42\uff08\u5982\u5377\u79ef\u5c42\u3001\u5168\u8fde\u63a5\u5c42\uff09\u4ee5\u53ca\u6240\u9700\u7684\u53c2\u6570\u3002\u8fd9\u76f8\u5f53\u4e8e\u4e3a\u642d\u5efa\u7f51\u7edc\u51c6\u5907\u57fa\u7840\u7684\u7ec4\u4ef6\uff0c\u5982\u7816\u5934\u3001\u6c34\u6ce5\u7b49\u8d44\u6e90\u3002<\/li>\n\n\n\n<li><strong>\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b<\/strong>\uff1a\u5b9a\u4e49\u524d\u5411\u4f20\u64ad\u51fd\u6570\uff0c\u7528\u4e8e\u5c06\u8f93\u5165\u6570\u636e\u901a\u8fc7\u5404\u4e2a\u7f51\u7edc\u5c42\u9010\u6b65\u8fdb\u884c\u8ba1\u7b97\uff0c\u4ece\u800c\u751f\u6210\u8f93\u51fa\u3002\u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u4f5c\u4e3a\u4e0b\u4e00\u5c42\u7684\u8f93\u5165\uff0c\u76f4\u5230\u6700\u7ec8\u83b7\u5f97\u6a21\u578b\u7684\u8f93\u51fa\u3002\u8be5\u8fc7\u7a0b\u901a\u8fc7\u8c03\u7528\u5df2\u521d\u59cb\u5316\u7684\u5c42\u6765\u5b9e\u73b0\u5404\u5c42\u95f4\u7684\u8fde\u63a5\u3002<\/li>\n\n\n\n<li><strong>\u6fc0\u6d3b\u51fd\u6570\u4e0e\u5176\u4ed6\u64cd\u4f5c<\/strong>\uff1a\u5728\u524d\u5411\u4f20\u64ad\u7684\u8fc7\u7a0b\u4e2d\uff0c\u901a\u5e38\u9700\u8981\u4f7f\u7528\u6fc0\u6d3b\u51fd\u6570\uff08\u5982 
ReLU\uff09\u5bf9\u4e2d\u95f4\u7ed3\u679c\u8fdb\u884c\u975e\u7ebf\u6027\u53d8\u6362\uff0c\u63d0\u5347\u6a21\u578b\u7684\u8868\u8fbe\u80fd\u529b\u3002\u6b64\u5916\uff0c\u53ef\u80fd\u8fd8\u4f1a\u4f7f\u7528\u6c60\u5316\u5c42\u7b49\u64cd\u4f5c\u6765\u7f29\u5c0f\u7279\u5f81\u56fe\u7684\u5c3a\u5bf8\uff0c\u4ece\u800c\u51cf\u5c11\u8ba1\u7b97\u91cf\u5e76\u63d0\u53d6\u66f4\u6709\u610f\u4e49\u7684\u7279\u5f81\u3002<\/li>\n\n\n\n<li><strong>\u53cd\u5411\u4f20\u64ad\u4e0e\u8bad\u7ec3<\/strong>\uff1a\u6a21\u578b\u642d\u5efa\u5b8c\u6210\u540e\uff0c\u914d\u5408\u8bad\u7ec3\u8fc7\u7a0b\u7684\u53cd\u5411\u4f20\u64ad\u7b97\u6cd5\uff0c\u901a\u8fc7\u8ba1\u7b97\u68af\u5ea6\u548c\u66f4\u65b0\u53c2\u6570\u6765\u4f18\u5316\u6a21\u578b\u3002\u8bad\u7ec3\u8fc7\u7a0b\u901a\u5e38\u5728\u72ec\u7acb\u7684\u8bad\u7ec3\u4ee3\u7801\u4e2d\u5b8c\u6210\uff0c\u800c\u524d\u5411\u4f20\u64ad\u5219\u662f\u7f51\u7edc\u6a21\u578b\u7684\u4e00\u90e8\u5206\u3002<\/li>\n<\/ol>\n\n\n\n<p>\u8fd9\u4e2a\u6d41\u7a0b\u7c7b\u4f3c\u4e8e\u642d\u5efa\u4e00\u5ea7\u5efa\u7b51\u7269\uff0c\u5148\u51c6\u5907\u597d\u5404\u7c7b\u6750\u6599\u548c\u5de5\u5177\uff08\u5373\u521d\u59cb\u5316\u7f51\u7edc\u5c42\u548c\u53c2\u6570\uff09\uff0c\u7136\u540e\u6839\u636e\u8bbe\u8ba1\u9010\u6b65\u6784\u5efa\uff0c\u76f4\u5230\u6a21\u578b\u80fd\u591f\u901a\u8fc7\u8f93\u5165\u751f\u6210\u6709\u6548\u8f93\u51fa\u3002<\/p>\n\n\n\n<h1 class=\"wp-block-heading\"><strong>\u521d\u59cb\u5316<\/strong><\/h1>\n\n\n\n<p>\u65b0\u5efa\u4e00\u4e2aLeNet\u6587\u4ef6\u5939\uff0c\u5176\u4e2dmodel.py\u5982\u4e0b<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torchsummary import summary\n\nclass LeNet(nn. Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.c1=nn. Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)\n        self.sig=nn.Sigmoid()\n        self.s2 =nn. AvgPool2d(kernel_size=2, stride=2)\n        self.c3 =nn. 
Conv2d(in_channels=6, out_channels=16, kernel_size=5)\n        self.s4=nn.AvgPool2d(kernel_size=2, stride=2)\n\n        self. flatten =nn. Flatten()\n        self. f5 =nn. Linear( 400,  120)\n        self. f6=nn. Linear( 120,84)\n        self. f7=nn. Linear(84,10)<\/code><\/pre>\n\n\n\n<p>\u5728\u6784\u5efa\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u9996\u5148\u9700\u8981\u5b9a\u4e49\u7f51\u7edc\u7684\u5404\u5c42\u53ca\u5176\u53c2\u6570\u3002\u5728\u521d\u59cb\u5316\u9636\u6bb5\uff0c\u6211\u4eec\u6839\u636e LeNet \u7684\u7ed3\u6784\u5b9a\u4e49\u5377\u79ef\u5c42\u3001\u6fc0\u6d3b\u51fd\u6570\u3001\u6c60\u5316\u5c42\u3001\u5c55\u5e73\u5c42\u4ee5\u53ca\u5168\u8fde\u63a5\u5c42\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li><strong>\u5377\u79ef\u5c42 (Convolutional Layer)<\/strong>\uff1a\u901a\u8fc7 <code>nn.Conv2d<\/code> \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u6307\u5b9a\u8f93\u5165\u901a\u9053\u3001\u8f93\u51fa\u901a\u9053\u3001\u5377\u79ef\u6838\u7684\u5927\u5c0f\uff0c\u4ee5\u53ca\u586b\u5145\uff08padding\uff09\u7b49\u53c2\u6570\u3002LeNet \u4e2d\u7684\u7b2c\u4e00\u5c42\u5377\u79ef\u63a5\u6536\u7070\u5ea6\u56fe\uff08\u901a\u9053\u6570\u4e3a 1\uff09\uff0c\u8f93\u51fa 6 \u4e2a\u7279\u5f81\u56fe\uff0c\u5377\u79ef\u6838\u5927\u5c0f\u4e3a 5x5\uff0c\u4f7f\u7528 padding \u4e3a 2\u3002<\/li>\n\n\n\n<li><strong>\u6fc0\u6d3b\u51fd\u6570 (Activation Function)<\/strong>\uff1a\u4f7f\u7528 <code>nn.Sigmoid()<\/code> \u6765\u589e\u52a0\u7f51\u7edc\u7684\u975e\u7ebf\u6027\u8868\u793a\u80fd\u529b\u3002<\/li>\n\n\n\n<li><strong>\u6c60\u5316\u5c42 (Pooling Layer)<\/strong>\uff1a\u901a\u8fc7 <code>nn.AvgPool2d<\/code> \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u7528\u4e8e\u4e0b\u91c7\u6837\u7279\u5f81\u56fe\uff0c\u51cf\u5c11\u8ba1\u7b97\u91cf\u3002\u6c60\u5316\u6838\u5927\u5c0f\u4e3a 2x2\uff0c\u6b65\u5e45\u4e3a 2\u3002<\/li>\n\n\n\n<li><strong>\u5168\u8fde\u63a5\u5c42 (Fully Connected Layer)<\/strong>\uff1a\u901a\u8fc7 <code>nn.Linear<\/code> 
\u5b9a\u4e49\u5168\u8fde\u63a5\u5c42\uff0c\u7528\u4e8e\u5c06\u5377\u79ef\u5c42\u7684\u8f93\u51fa\u6620\u5c04\u5230\u5177\u4f53\u7684\u7c7b\u522b\u6807\u7b7e\u3002LeNet \u4e2d\u6709\u4e09\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u7b2c\u4e00\u5c42\u8f93\u5165 400 \u4e2a\u7279\u5f81\uff0c\u8f93\u51fa 120 \u4e2a\u795e\u7ecf\u5143\uff0c\u63a5\u7740\u662f 84 \u4e2a\u795e\u7ecf\u5143\uff0c\u6700\u540e\u8f93\u51fa\u4e3a 10 \u7c7b\u3002<\/li>\n<\/ul>\n\n\n\n<h1 class=\"wp-block-heading\">\u524d\u5411\u4f20\u64ad<\/h1>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torchsummary import summary\n\nclass LeNet(nn. Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.c1=nn. Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)\n        self.sig=nn.Sigmoid()\n        self.s2 =nn. AvgPool2d(kernel_size=2, stride=2)\n        self.c3 =nn. Conv2d(in_channels=6, out_channels=16, kernel_size=5)\n        self.s4=nn.AvgPool2d(kernel_size=2, stride=2)\n\n        self. flatten =nn. Flatten()\n        self. f5 =nn. Linear( 400,  120)\n        self. f6=nn. Linear( 120,84)\n        self. f7=nn. 
Linear(84,10)\n    def forward(self,x):\n        x = self.sig(self.c1(x))   # Pass input through the first convolution layer (c1), followed by a sigmoid activation function (sig).\n        x = self.s2(x)             # Pass the result through the first pooling layer (s2).\n        x = self.sig(self.c3(x))   # Pass the output through the second convolution layer (c3) and apply sigmoid activation again.\n        x = self.s4(x)             # Pass the result through the second pooling layer (s4).\n        x = self.flatten(x)        # Flatten the output to prepare it for fully connected layers.\n        x = self.f5(x)             # Pass the flattened output through the first fully connected layer (f5).\n        x = self.f6(x)             # Pass through the second fully connected layer (f6).\n        x = self.f7(x)             # Pass through the final fully connected layer (f7) to produce the output.\n        return x                   # Return the final result as the output of the forward pass.<\/code><\/pre>\n\n\n\n<p>\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u5df2\u7ecf\u786e\u5b9a\u4e86\u6a21\u578b\u7684\u67b6\u6784\uff0c\u5e76\u4e14\u521d\u59cb\u5316\u4e86\u6240\u6709\u5fc5\u8981\u7684\u4fe1\u606f\u3002\u73b0\u5728\uff0c\u6211\u4eec\u9700\u8981\u5229\u7528\u8fd9\u4e9b\u521d\u59cb\u5316\u597d\u7684\u4fe1\u606f\u6765\u642d\u5efa\u795e\u7ecf\u7f51\u7edc\u7684\u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\u3002\u4e00\u65e6\u6211\u4eec\u5c06\u6570\u636e\u8f93\u5165\u5230\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u7684\u6a21\u578b\u7ed3\u6784\u5c31\u5b8c\u6574\u4e86\u3002\u63a5\u4e0b\u6765\uff0c\u6211\u4eec\u9700\u8981\u5b9a\u4e49\u4e00\u4e2a\u540d\u4e3a<code>forward<\/code>\u7684\u51fd\u6570\uff0c\u8fd9\u4e2a\u51fd\u6570\u5c06\u8d1f\u8d23\u6574\u4e2a\u524d\u5411\u4f20\u64ad\u7684\u8fc7\u7a0b\u3002<\/p>\n\n\n\n<p>\u5f53\u4f60\u5728\u7c7b\u4e2d\u5b9a\u4e49\u4e00\u4e2a\u51fd\u6570\u65f6\uff0c\u6bd4\u5982\u6211\u4eec\u7684<code>forward<\/code>\u51fd\u6570\uff0cPython\u4f1a\u81ea\u52a8\u9075\u5faa\u4e00\u5b9a\u7684\u8bed\
u6cd5\u89c4\u5219\uff0c\u8fd9\u4e9b\u89c4\u5219\u662f\u6211\u4eec\u5728\u5b66\u4e60\u8fc7\u7a0b\u4e2d\u9700\u8981\u638c\u63e1\u7684\u3002\u5373\u4f7f\u4f60\u6ca1\u6709Python\u57fa\u7840\uff0c\u53ea\u8981\u8ddf\u7740\u6211\u7684\u6b65\u9aa4\u6765\uff0c\u4e5f\u80fd\u6210\u529f\u6784\u5efa\u51fa\u795e\u7ecf\u7f51\u7edc\u3002\u6211\u4f1a\u8be6\u7ec6\u89e3\u91ca\u6bcf\u4e00\u6b65\uff0c\u56e0\u4e3a\u6211\u7684\u8bfe\u7a0b\u662f\u9762\u5411\u521d\u5b66\u8005\u7684\uff0c\u76ee\u7684\u662f\u786e\u4fdd\u5927\u5bb6\u90fd\u80fd\u7406\u89e3\u5e76\u5b66\u5230\u4e1c\u897f\u3002\u6709\u65f6\u5019\u53ef\u80fd\u4f1a\u663e\u5f97\u6709\u4e9b\u5570\u55e6\uff0c\u4f46\u8fd9\u662f\u5fc5\u8981\u7684\uff0c\u4ee5\u786e\u4fdd\u6bcf\u4e2a\u4eba\u90fd\u80fd\u8ddf\u4e0a\u3002<\/p>\n\n\n\n<p>\u5728\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u9996\u5148\u5b9a\u4e49\u8f93\u5165\u63a5\u53e3<code>X<\/code>\u3002\u73b0\u5728\uff0c\u6211\u4eec\u5df2\u7ecf\u62e5\u6709\u4e86\u6240\u6709\u7684\u795e\u7ecf\u7f51\u7edc\u5c42\uff0c\u6211\u4eec\u9700\u8981\u642d\u5efa\u7f51\u7edc\u6a21\u578b\uff0c\u6a21\u578b\u9700\u8981\u8f93\u5165<code>X<\/code>\u5e76\u5f97\u51fa\u6211\u4eec\u7684\u8f93\u51fa<code>Y<\/code>\u3002\u5728\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u53ef\u80fd\u4f1a\u5b9a\u4e49\u4e00\u4e2a<code>X2<\/code>\u4f5c\u4e3a\u8f93\u5165\u63a5\u53e3\u3002\u6211\u4eec\u628a<code>X<\/code>\u8f93\u5165\u5230\u7b2c\u4e00\u5c42\uff0c\u4e5f\u5c31\u662f\u6211\u4eec\u7684\u5377\u79ef\u5c42<code>c1<\/code>\u3002\u7136\u540e\uff0c\u6570\u636e\u4f1a\u901a\u8fc7\u5377\u79ef\u5c42\uff0c\u63a5\u7740\u901a\u8fc7\u6fc0\u6d3b\u51fd\u6570\uff0c\u518d\u901a\u8fc7\u6c60\u5316\u5c42\uff0c\u7136\u540e\u518d\u6b21\u901a\u8fc7\u5377\u79ef\u5c42\u3002\u8fd9\u91cc\u7684<code>X<\/code>\u4f1a\u518d\u6b21\u8f93\u5165\u5230\u6211\u4eec\u7684\u6fc0\u6d3b\u51fd\u6570\u4e2d\uff0c\u7136\u540e\u8f93\u51fa\u5230\u6c60\u5316\u5c42\uff0c\u6700\u540e\u901a\u8fc7\u5168\u8fde\u63a5\u5c42\u3002<\/p>\n\n\n\n<p>\u5728\u8fd9\u4e2a\u8fc7\u
7a0b\u4e2d\uff0c\u6211\u4eec\u53ef\u4ee5\u5c06\u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u547d\u540d\u4e3a<code>A1<\/code>\u3001<code>A2<\/code>\u7b49\uff0c\u8fd9\u53ea\u662f\u6211\u4eec\u5b9a\u4e49\u7684\u4e00\u4e2a\u540d\u5b57\uff0c\u4f60\u4e5f\u53ef\u4ee5\u7528<code>A<\/code>\u6765\u8868\u793a\u3002\u5728\u642d\u5efa\u7f51\u7edc\u6a21\u578b\u65f6\uff0c\u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u90fd\u4f1a\u4f5c\u4e3a\u4e0b\u4e00\u5c42\u7684\u8f93\u5165\u3002\u4f8b\u5982\uff0c\u5377\u79ef\u5c42\u7684\u8f93\u51fa\u4f1a\u518d\u6b21\u8f93\u5165\u5230\u6fc0\u6d3b\u51fd\u6570\u4e2d\uff0c\u7136\u540e\u901a\u8fc7\u6c60\u5316\u5c42\uff0c\u6700\u540e\u901a\u8fc7\u5168\u8fde\u63a5\u5c42\u3002\u8fd9\u4e2a\u8fc7\u7a0b\u975e\u5e38\u6d41\u7545\uff0c\u6bcf\u4e00\u5c42\u7684\u64cd\u4f5c\u90fd\u662f\u6709\u5e8f\u7684\u3002<\/p>\n\n\n\n<p>\u6700\u540e\uff0c\u6211\u4eec\u901a\u8fc7\u5e73\u5c55\u5c42\u5c06\u6570\u636e\u5c55\u5e73\uff0c\u7136\u540e\u8f93\u5165\u5230\u5168\u8fde\u63a5\u5c42\u3002\u5728\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c\u6211\u4eec\u53ef\u80fd\u4f1a\u6709\u591a\u4e2a\u5168\u8fde\u63a5\u5c42\uff0c\u6bd4\u5982<code>f5<\/code>\u3001<code>f6<\/code>\u548c<code>f7<\/code>\u3002\u6700\u7ec8\uff0c\u6211\u4eec\u5f97\u5230\u8f93\u51fa<code>Y<\/code>\uff0c\u8fd9\u662f\u6211\u4eec\u7684\u524d\u5411\u4f20\u64ad\u7ed3\u679c\uff0c\u4e5f\u662f\u6211\u4eec\u8fd4\u56de\u7684\u503c\u3002\u8fd9\u6837\uff0c\u6211\u4eec\u7684\u6a21\u578b\u5c31\u642d\u5efa\u597d\u4e86\uff0c\u8f93\u5165<code>X<\/code>\u5e94\u8be5\u5f97\u51fa\u6211\u4eec\u7684<code>Y<\/code>\uff0c\u8fd9\u4e2a<code>Y<\/code>\u662f\u6211\u4eec\u7684\u8fd4\u56de\u503c\uff0c\u4e5f\u5c31\u662f\u795e\u7ecf\u7f51\u7edc\u7684\u6700\u7ec8\u7ed3\u679c\u3002<\/p>\n\n\n\n<p>\u6240\u4ee5\uff0c\u4f60\u4f1a\u53d1\u73b0\u795e\u7ecf\u7f51\u7edc\u7684\u5b9a\u4e49\u548c\u6211\u4eec\u4e4b\u524d\u8bb2\u7684\u539f\u7406\u662f\u76f8\u6263\u7684\u3002\u6211\u4eec\u4e4b\u524d\u82b1\u4e86\u5f88\u591a\u65f6\u95f4\u53bb\u8bb2\u89e3\u795e\u7ecf\u7f51\u7edc\u7684\u8
f93\u5165\u7279\u5f81\u56fe\u548c\u8f93\u51fa\u7279\u5f81\u56fe\uff0c\u8fd9\u662f\u4e3a\u4e86\u8ba9\u5927\u5bb6\u66f4\u6e05\u6670\u5730\u7406\u89e3\u524d\u5411\u4f20\u64ad\u7684\u8fc7\u7a0b\u548c\u5177\u4f53\u7684\u53c2\u6570\u3002\u5728\u642d\u5efa\u795e\u7ecf\u7f51\u7edc\u65f6\uff0c\u8fd9\u4e2a\u8fc7\u7a0b\u662f\u975e\u5e38\u6e05\u6670\u7684\u3002\u65e0\u8bba\u662f\u4ec0\u4e48\u7c7b\u578b\u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u57fa\u672c\u4e0a\u90fd\u662f\u8fd9\u6837\u7684\u6d41\u7a0b\uff0c\u53ea\u662f\u53ef\u80fd\u4f1a\u66f4\u590d\u6742\u4e00\u4e9b\u3002\u6240\u4ee5\uff0c\u6211\u4eec\u53ea\u9700\u8981\u5b9a\u4e49\u597d\u524d\u5411\u4f20\u64ad\uff0c\u7136\u540e\u6309\u7167\u8fd9\u4e2a\u903b\u8f91\u4e00\u5c42\u4e00\u5c42\u5730\u4f20\u9012\u6570\u636e\uff0c\u76f4\u5230\u6700\u540e\u4e00\u5c42\uff0c\u7136\u540e\u8fd4\u56de\u795e\u7ecf\u7f51\u7edc\u7684\u524d\u5411\u4f20\u64ad\u7ed3\u679c\u3002<\/p>\n\n\n\n<p>\u8fd9\u4e2a\u8fc7\u7a0b\u5c31\u662f\u6211\u4eec\u6240\u8bf4\u7684\u524d\u5411\u4f20\u64ad\uff0c\u5b83\u662f\u795e\u7ecf\u7f51\u7edc\u5728\u7ed9\u5b9a\u8f93\u5165\u6570\u636e\u65f6\u8fdb\u884c\u9884\u6d4b\u6216\u63a8\u65ad\u7684\u6b65\u9aa4\u3002\u5728\u8fd9\u4e2a\u8fc7\u7a0b\u4e2d\uff0c\u6bcf\u4e00\u5c42\u7684\u8f93\u51fa\u90fd\u6210\u4e3a\u4e0b\u4e00\u5c42\u7684\u8f93\u5165\uff0c\u76f4\u5230\u6700\u540e\u4e00\u5c42\u4ea7\u751f\u6700\u7ec8\u7684\u8f93\u51fa\u7ed3\u679c\u3002\u8fd9\u4e2a\u7ed3\u679c\u53ef\u4ee5\u662f\u5206\u7c7b\u6807\u7b7e\u3001\u56de\u5f52\u503c\u6216\u5176\u4ed6\u7c7b\u578b\u7684\u9884\u6d4b\u8f93\u51fa\uff0c\u5177\u4f53\u53d6\u51b3\u4e8e\u7f51\u7edc\u7684\u8bbe\u8ba1\u548c\u4efb\u52a1\u76ee\u6807\u3002<\/p>\n\n\n\n<p>\u901a\u8fc7\u8fd9\u79cd\u65b9\u5f0f\uff0cLeNet\u7c7b\u5b9a\u4e49\u4e86\u4e00\u4e2a\u5b8c\u6574\u7684\u524d\u5411\u4f20\u64ad\u6d41\u7a0b\uff0c\u4f7f\u5f97\u6211\u4eec\u53ef\u4ee5\u5c06\u8f93\u5165\u6570\u636e\u901a\u8fc7\u7f51\u7edc\u7ed3\u6784\uff0c\u6700\u7ec8\u5f97\u5230\u9884\u6d4b\u7ed3\u679c\u3002\u8fd9\u4e2a\u6d41\u7a0b\u66
2f\u6784\u5efa\u4efb\u4f55\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u7684\u57fa\u7840\uff0c\u65e0\u8bba\u662f\u7528\u4e8e\u56fe\u50cf\u8bc6\u522b\u3001\u81ea\u7136\u8bed\u8a00\u5904\u7406\u8fd8\u662f\u5176\u4ed6\u673a\u5668\u5b66\u4e60\u4efb\u52a1\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torchsummary import summary\n\nclass LeNet(nn. Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        self.c1=nn. Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)\n        self.sig=nn.Sigmoid()\n        self.s2 =nn. AvgPool2d(kernel_size=2, stride=2)\n        self.c3 =nn. Conv2d(in_channels=6, out_channels=16, kernel_size=5)\n        self.s4=nn.AvgPool2d(kernel_size=2, stride=2)\n\n        self. flatten =nn. Flatten()\n        self. f5 =nn. Linear( 400,  120)\n        self. f6=nn. Linear( 120,84)\n        self. f7=nn. Linear(84,10)\n    def forward(self,x):\n        x = self.sig(self.c1(x))\n        x = self.s2(x)\n        x=self.sig(self.c3(x))\n        x=self.s4(x)\n        x=self.flatten(x)\n        x=self.f5(x)\n        x= self.f6(x)\n        x=self.f7(x)\n        return x\nif __name__==\"__main__\":\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = LeNet().to(device)\n    print(summary(model, (1,28,28)))\n<\/code><\/pre>\n\n\n\n<h1 class=\"wp-block-heading\">\u6570\u636e\u96c6<\/h1>\n\n\n\n<p>\u7531\u4e8e\u624b\u5199\u6570\u5b57\u8bc6\u522b\u5df2\u7ecf\u505a\u70c2\u4e86,\u6240\u4ee5\u4f7f\u7528\u8863\u670d\u5206\u7c7b\u7684\u6570\u636e\u96c6,\u65b0\u5efaplot.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torchvision.datasets import FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\n\ntrain_data = FashionMNIST(root='.\/data',\n                          train=True,\n                          transform=transforms.Compose(&#91;transforms.Resize(size=224), 
transforms.ToTensor()]),\n                          download=True)\n\ntrain_loader = Data.DataLoader(dataset=train_data,\n                               batch_size=64,\n                               shuffle=True,\n                               num_workers=0)\n\n# \u83b7\u5f97\u4e00\u4e2aBatch\u7684\u6570\u636e\nfor step, (b_x, b_y) in enumerate(train_loader):\n    if step > 0:\n        break\nbatch_x = b_x.squeeze().numpy()  # \u5c06\u56db\u7ef4\u5f20\u91cf\u79fb\u9664\u7b2c1\u7ef4\uff0c\u5e76\u8f6c\u6362\u6210Numpy\u6570\u7ec4\nbatch_y = b_y.numpy()  # \u5c06\u5f20\u91cf\u8f6c\u6362\u6210Numpy\u6570\u7ec4\nclass_Label = train_data.classes  # \u8bad\u7ec3\u96c6\u7684\u6807\u7b7e\n# print(class_label)\nprint(\"The size of batch in train data:\", batch_x.shape)  # \u6bcf\u4e2amini-batch\u7684\u7ef4\u5ea6\u662f64*224*224\n\nplt.figure(figsize=(12, 5))\nfor ii in np.arange(len(batch_y)):\n    plt.subplot(4, 16, ii + 1)<\/code><\/pre>\n\n\n\n<h1 class=\"wp-block-heading\">\u6570\u636e\u52a0\u8f7d\u51fd\u6570<\/h1>\n\n\n\n<p>model_train.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torchvision.datasets import FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import LeNet\n\n\ndef train_val_data_process():\n    train_data = FashionMNIST(root='.\/data',\n                              train=True,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=28), transforms.ToTensor()]),\n                              download=True)\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8 * len(train_data)), round(0.2 * len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=128,\n                                       shuffle=True,\n                                       num_workers=8)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n           
                          batch_size=128,\n                                     shuffle=True,\n                                     num_workers=8)\n\n    return train_dataloader, val_dataloader<\/code><\/pre>\n\n\n\n<h1 class=\"wp-block-heading\">\u8bad\u7ec3\u6a21\u677f\u4ee3\u7801<\/h1>\n\n\n\n<p>model_train.py<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import time\n\nimport torch.optim\nfrom torch import nn\nfrom torchvision.datasets import FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import LeNet\n\n\ndef train_val_data_process():\n    train_data = FashionMNIST(root='.\/data',\n                              train=True,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=28), transforms.ToTensor()]),\n                              download=True)\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8 * len(train_data)), round(0.2 * len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=128,\n                                       shuffle=True,\n                                       num_workers=8)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n                                     batch_size=128,\n                                     shuffle=True,\n                                     num_workers=8)\n\n    return train_dataloader, val_dataloader\n\n\ndef train_model(model, train_dataloader, val_dataloader, num_epochs):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # \u4f18\u5316\u5668,\u68af\u5ea6\u4e0b\u964d\u6cd5\u8fdb\u4e00\u6b65\u62d3\u5c55\u662fSGD,Adam\u548cSGDM\u7b49\u7b49\u7b49\u7b49\u7684\n\n    criterion = nn.CrossEntropyLoss()\n    # 
\u5747\u65b9\u635f\u5931,\u5206\u7c7b\u4e2d\u6211\u4eec\u4e00\u822c\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u6765\u66f4\u65b0\u635f\u5931\u503c,\u7136\u540e\u662fw,b\n    # \u5c06\u6a21\u578b\u653e\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n    # \u590d\u5236\u5f53\u524d\u6a21\u578b\u7684\u53c2\u6570\n    best_model_wts = copy.deepcopy(model.state_dict())\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    # \u6700\u9ad8\u51c6\u786e\u5ea6\n    best_acc = 0.0\n    # \u8bad\u7ec3\u96c6\u635f\u5931\u5217\u8868\n    train_loss_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u635f\u5931\u5217\u8868\n    val_loss_all = &#91;]\n    # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    train_acc_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    val_acc_all = &#91;]\n    # \u5f53\u524d\u65f6\u95f4\n    since = time.time()<\/code><\/pre>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u53c2\u6570\u521d\u59cb\u5316<\/strong>\uff1a\u6211\u4eec\u5c06 <code>train_loss<\/code>, <code>train_accuracy<\/code>, <code>val_loss<\/code>, <code>val_accuracy<\/code> \u521d\u59cb\u5316\u4e3a 0\uff0c\u5e76\u5728\u6bcf\u4e2a epoch \u4e2d\u7d2f\u79ef\u8fd9\u4e9b\u503c\u3002<\/li>\n\n\n\n<li><strong>\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d<\/strong>\uff1a\u6bcf\u6b21\u4ece <code>train_loader<\/code> \u4e2d\u52a0\u8f7d\u4e00\u6279\u6570\u636e\uff0c\u5e76\u5c06\u8f93\u5165 <code>inputs<\/code> \u548c\u6807\u7b7e <code>labels<\/code> \u653e\u5165\u8bbe\u5907\uff08\u6bd4\u5982 GPU\uff09\u3002<\/li>\n\n\n\n<li><strong>\u68af\u5ea6\u6e05\u96f6<\/strong>\uff1a\u4e3a\u4e86\u9632\u6b62\u68af\u5ea6\u7d2f\u79ef\uff0c\u6bcf\u4e2a batch \u90fd\u9700\u8981\u8c03\u7528 <code>optimizer.zero_grad()<\/code> \u6765\u5c06\u68af\u5ea6\u91cd\u7f6e\u3002<\/li>\n\n\n\n<li><strong>\u524d\u5411\u4f20\u64ad<\/strong>\uff1a\u901a\u8fc7\u6a21\u578b\u7684\u524d\u5411\u4f20\u64ad\u8ba1\u7b97\u8f93\u51fa 
<code>outputs<\/code>\u3002<\/li>\n\n\n\n<li><strong>\u635f\u5931\u8ba1\u7b97<\/strong>\uff1a\u901a\u8fc7\u5b9a\u4e49\u7684\u635f\u5931\u51fd\u6570 <code>loss_fn<\/code> \u8ba1\u7b97\u635f\u5931\u3002<\/li>\n\n\n\n<li><strong>\u53cd\u5411\u4f20\u64ad<\/strong>\uff1a\u901a\u8fc7 <code>loss.backward()<\/code> \u8ba1\u7b97\u68af\u5ea6\u3002<\/li>\n\n\n\n<li><strong>\u53c2\u6570\u66f4\u65b0<\/strong>\uff1a\u4f7f\u7528\u4f18\u5316\u5668 <code>optimizer.step()<\/code> \u8fdb\u884c\u53c2\u6570\u66f4\u65b0\u3002<\/li>\n\n\n\n<li><strong>\u8ba1\u7b97\u8bad\u7ec3\u51c6\u786e\u5ea6<\/strong>\uff1a\u901a\u8fc7\u6bd4\u8f83\u9884\u6d4b\u7ed3\u679c <code>preds<\/code> \u4e0e\u771f\u5b9e\u6807\u7b7e <code>labels<\/code>\uff0c\u8ba1\u7b97\u51c6\u786e\u9884\u6d4b\u7684\u6570\u91cf\uff0c\u5e76\u66f4\u65b0\u7d2f\u8ba1\u51c6\u786e\u7387\u3002<\/li>\n<\/ol>\n\n\n\n<p>\u4f60\u53ef\u4ee5\u5728\u6bcf\u4e2a epoch \u540e\u9762\u6dfb\u52a0\u9a8c\u8bc1\u7684\u90e8\u5206\u548c\u76f8\u5e94\u7684\u7edf\u8ba1\u3002<\/p>\n\n\n\n<p>\u8fd9\u6837\u6211\u4eec\u5df2\u7ecf\u628a\u4e00\u4e2a\u5b8c\u6574\u7684\u8bad\u7ec3\u6d41\u7a0b\u6846\u67b6\u6574\u7406\u51fa\u6765\uff0c\u63a5\u4e0b\u6765\u4f60\u53ef\u4ee5\u6839\u636e\u9700\u8981\u7ee7\u7eed\u5b8c\u5584\u9a8c\u8bc1\u90e8\u5206\u6216\u8005\u6dfb\u52a0\u7ec6\u8282\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">1. <strong>\u6a21\u578b\u8bad\u7ec3\u7684\u521d\u59cb\u5316<\/strong><\/h3>\n\n\n\n<p>\u6211\u4eec\u5df2\u7ecf\u5b9a\u4e49\u4e86\u6a21\u578b\u52a0\u8f7d\u5230\u5408\u9002\u7684\u8bbe\u5907\uff08\u5982GPU\u6216CPU\uff09\u3001\u8bbe\u7f6e\u4e86\u4f18\u5316\u5668\uff08\u5982Adam\uff09\u4ee5\u53ca\u635f\u5931\u51fd\u6570\uff08\u5982\u4ea4\u53c9\u71b5\uff09\u3002\u63a5\u4e0b\u6765\u6211\u4eec\u8fdb\u5165\u5b9e\u9645\u7684\u8bad\u7ec3\u6b65\u9aa4\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">2. 
<strong>\u8bad\u7ec3\u8f6e\u6b21\uff08Epoch\uff09<\/strong><\/h3>\n\n\n\n<p>\u9996\u5148\uff0c\u6211\u4eec\u5728\u8bad\u7ec3\u51fd\u6570\u91cc\u4f1a\u6709\u4e00\u4e2a <code>for<\/code> \u5faa\u73af\uff0c\u5faa\u73af\u7684\u6b21\u6570\u5c31\u662f\u8bad\u7ec3\u7684\u8f6e\u6b21\uff08<code>epochs<\/code>\uff09\u3002\u6bcf\u8f6e\u8bad\u7ec3\u7684\u76ee\u7684\u662f\u901a\u8fc7\u53cd\u5411\u4f20\u64ad\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570\uff08<code>W<\/code> \u548c <code>B<\/code>\uff09\uff0c\u4ee5\u4fbf\u4f7f\u635f\u5931\u51fd\u6570\u7684\u503c\u9010\u6e10\u4e0b\u964d\uff0c\u6a21\u578b\u7684\u6027\u80fd\u4e0d\u65ad\u63d0\u9ad8\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">3. <strong>\u6279\u6b21\u8bad\u7ec3\uff08Mini-batch Training\uff09<\/strong><\/h3>\n\n\n\n<p>\u5bf9\u4e8e\u6bcf\u8f6e\u8bad\u7ec3\uff0c\u6211\u4eec\u9700\u8981\u5c06\u8bad\u7ec3\u6570\u636e\u5206\u6210\u5c0f\u6279\u6b21\uff08mini-batches\uff09\uff0c\u8fd9\u6837\u53ef\u4ee5\u8282\u7701\u5185\u5b58\u5e76\u63d0\u9ad8\u8bad\u7ec3\u901f\u5ea6\u3002\u5e38\u7528\u7684\u65b9\u5f0f\u662f\u7528 <code>DataLoader<\/code> \u6765\u8fed\u4ee3\u6279\u6b21\u6570\u636e\u3002\u5728\u6bcf\u4e2a\u6279\u6b21\u4e2d\uff0c\u6211\u4eec\u4f1a\uff1a<\/p>\n\n\n\n<ul class=\"wp-block-list\">\n<li>\u5c06\u6570\u636e\u548c\u6807\u7b7e\u4f20\u5165\u6a21\u578b\u3002<\/li>\n\n\n\n<li>\u524d\u5411\u4f20\u64ad\uff1a\u6a21\u578b\u8ba1\u7b97\u8f93\u51fa\u3002<\/li>\n\n\n\n<li>\u8ba1\u7b97\u635f\u5931\u503c\uff1a\u5229\u7528\u635f\u5931\u51fd\u6570\u8ba1\u7b97\u6a21\u578b\u8f93\u51fa\u548c\u771f\u5b9e\u6807\u7b7e\u4e4b\u95f4\u7684\u5dee\u5f02\u3002<\/li>\n\n\n\n<li>\u53cd\u5411\u4f20\u64ad\uff1a\u901a\u8fc7\u635f\u5931\u503c\u7684\u68af\u5ea6\u66f4\u65b0\u6a21\u578b\u53c2\u6570\u3002<\/li>\n\n\n\n<li>\u4f18\u5316\u5668 <code>step()<\/code>\uff1a\u66f4\u65b0\u6a21\u578b\u7684\u6743\u91cd\u3002<\/li>\n<\/ul>\n\n\n\n<h3 class=\"wp-block-heading\">4. 
<strong>\u9a8c\u8bc1\u96c6\u8bc4\u4f30<\/strong><\/h3>\n\n\n\n<p>\u5728\u6bcf\u4e2aepoch\u7ed3\u675f\u65f6\uff0c\u6211\u4eec\u901a\u5e38\u4f1a\u4f7f\u7528\u9a8c\u8bc1\u96c6\u6765\u8bc4\u4f30\u6a21\u578b\u7684\u6027\u80fd\u3002\u9a8c\u8bc1\u96c6\u4e0d\u53c2\u4e0e\u6a21\u578b\u53c2\u6570\u7684\u66f4\u65b0\uff0c\u53ea\u7528\u4e8e\u8861\u91cf\u6a21\u578b\u7684\u6cdb\u5316\u80fd\u529b\u3002\u8fd9\u91cc\u9700\u8981\u5173\u95ed\u68af\u5ea6\u8ba1\u7b97\uff08<code>torch.no_grad()<\/code>\uff09\uff0c\u4ee5\u51cf\u5c11\u5185\u5b58\u6d88\u8017\u548c\u52a0\u901f\u8bc4\u4f30\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">5. <strong>\u4fdd\u5b58\u6700\u4f73\u6a21\u578b<\/strong><\/h3>\n\n\n\n<p>\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u4f1a\u4fdd\u5b58\u5f53\u524d\u6700\u4f73\u6a21\u578b\u7684\u53c2\u6570\uff08\u5373\u9a8c\u8bc1\u96c6\u4e0a\u51c6\u786e\u7387\u6700\u9ad8\u7684\u53c2\u6570\uff09\uff0c\u7528\u4e8e\u540e\u7eed\u7684\u6a21\u578b\u6d4b\u8bd5\u6216\u90e8\u7f72\u3002\u4fdd\u5b58\u7684\u6a21\u578b\u901a\u5e38\u5305\u62ec\u7f51\u7edc\u7ed3\u6784\u548c\u6743\u91cd\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">6. 
<strong>\u8bb0\u5f55\u8bad\u7ec3\u8fc7\u7a0b<\/strong><\/h3>\n\n\n\n<p>\u4e3a\u4e86\u53ef\u89c6\u5316\u8bad\u7ec3\u8fdb\u5ea6\uff0c\u6211\u4eec\u4f1a\u4fdd\u5b58\u6bcf\u4e2aepoch\u7684\u8bad\u7ec3\u635f\u5931\u3001\u9a8c\u8bc1\u635f\u5931\u3001\u8bad\u7ec3\u7cbe\u5ea6\u3001\u9a8c\u8bc1\u7cbe\u5ea6\u7b49\u4fe1\u606f\uff0c\u4f9b\u540e\u7eed\u5206\u6790\u548c\u7ed8\u56fe\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">\u603b\u7ed3\u6d41\u7a0b\uff1a<\/h3>\n\n\n\n<ol class=\"wp-block-list\">\n<li><strong>\u521d\u59cb\u5316<\/strong>\uff1a\n<ul class=\"wp-block-list\">\n<li>\u786e\u5b9a\u8bbe\u5907\u3001\u5b9a\u4e49\u4f18\u5316\u5668\u548c\u635f\u5931\u51fd\u6570\u3002<\/li>\n\n\n\n<li>\u5c06\u6a21\u578b\u548c\u6570\u636e\u52a0\u8f7d\u5230\u8bbe\u5907\u4e0a\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u8bad\u7ec3\u5faa\u73af<\/strong>\uff1a\n<ul class=\"wp-block-list\">\n<li>\u5bf9\u4e8e\u6bcf\u4e2aepoch\uff1a\n<ul class=\"wp-block-list\">\n<li>\u5bf9\u6bcf\u4e2amini-batch\u8fdb\u884c\u524d\u5411\u4f20\u64ad\u3001\u8ba1\u7b97\u635f\u5931\u3001\u53cd\u5411\u4f20\u64ad\u3001\u66f4\u65b0\u53c2\u6570\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li>\u4f7f\u7528\u9a8c\u8bc1\u96c6\u8bc4\u4f30\u6a21\u578b\u6027\u80fd\u3002<\/li>\n\n\n\n<li>\u4fdd\u5b58\u6700\u4f73\u6a21\u578b\u3002<\/li>\n<\/ul>\n<\/li>\n\n\n\n<li><strong>\u8bb0\u5f55\u548c\u8f93\u51fa<\/strong>\uff1a\n<ul class=\"wp-block-list\">\n<li>\u4fdd\u5b58\u6bcf\u8f6e\u8bad\u7ec3\u548c\u9a8c\u8bc1\u7684\u635f\u5931\u3001\u7cbe\u5ea6\uff0c\u8bb0\u5f55\u8bad\u7ec3\u65f6\u95f4\u3002<\/li>\n<\/ul>\n<\/li>\n<\/ol>\n\n\n\n<p>\u8fd9\u6837\uff0c\u6574\u4e2a\u8bad\u7ec3\u51fd\u6570\u7684\u7ed3\u6784\u548c\u601d\u8def\u5c31\u6e05\u6670\u4e86\uff0c\u4f60\u53ef\u4ee5\u901a\u8fc7\u5b9e\u73b0\u5b83\u6765\u8fdb\u884c\u6a21\u578b\u8bad\u7ec3\u3002<\/p>\n\n\n\n<h1 class=\"wp-block-heading\">&nbsp;\u8bad\u7ec3\u53cd\u5411\u4f20\u64ad<\/h1>\n\n\n\n<pre class=\"wp-block-code\"><code>import time\n\nimport torch.optim\nfrom torch import nn\nfrom 
torchvision.datasets import FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import LeNet\nimport copy\n\n\ndef train_val_data_process():\n    train_data = FashionMNIST(root='.\/data',\n                              train=True,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=28), transforms.ToTensor()]),\n                              download=True)\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8 * len(train_data)), round(0.2 * len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=128,\n                                       shuffle=True,\n                                       num_workers=8)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n                                     batch_size=128,\n                                     shuffle=True,\n                                     num_workers=8)\n\n    return train_dataloader, val_dataloader\n\n\ndef train_model(model, train_dataloader, val_dataloader, num_epochs):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # \u4f18\u5316\u5668,\u68af\u5ea6\u4e0b\u964d\u6cd5\u8fdb\u4e00\u6b65\u62d3\u5c55\u662fSGD,Adam\u548cSGDM\u7b49\u7b49\u7b49\u7b49\u7684\n\n    criterion = nn.CrossEntropyLoss()\n    # \u5747\u65b9\u635f\u5931,\u5206\u7c7b\u4e2d\u6211\u4eec\u4e00\u822c\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u6765\u66f4\u65b0\u635f\u5931\u503c,\u7136\u540e\u662fw,b\n    # \u5c06\u6a21\u578b\u653e\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n    # \u590d\u5236\u5f53\u524d\u6a21\u578b\u7684\u53c2\u6570\n    best_model_wts = copy.deepcopy(model.state_dict())\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    # \u6700\u9ad8\u51c6\u786e\u5ea6\n    best_acc = 0.0\n    # 
\u8bad\u7ec3\u96c6\u635f\u5931\u5217\u8868\n    train_loss_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u635f\u5931\u5217\u8868\n    val_loss_all = &#91;]\n    # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    train_acc_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    val_acc_all = &#91;]\n    # \u5f53\u524d\u65f6\u95f4\n    since = time.time()\n\n    for epoch in range(num_epochs):\n        print('Epoch {}\/{}'.format(epoch + 1, num_epochs))\n        print(\"-\" * 10)\n\n        train_loss = 0.0\n        train_correct = 0.0\n        val_loss = 0.0\n        val_correct = 0.0\n\n        train_num = 0\n        val_num = 0\n\n        # \u5bf9\u6bcf\u4e00\u4e2amini-batch\u8bad\u7ec3\u548c\u8ba1\u7b97\n        for step, (b_x, b_y) in enumerate(train_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bad\u7ec3\u6a21\u5f0f\n            model.train()\n            \n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch,\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n<\/code><\/pre>\n\n\n\n<p>\u9996\u5148train_dataloader\u662f\u8bad\u7ec3\u7684\u6240\u6709\u6570\u636e, 
\u52a0\u516560000*0.8\u7684\u6570\u636e,\u5176\u4e2d\u6bcf\u4e2a\u6279\u6b21\u6709128\u4e2a\u610f\u5473\u7740,\u90a3\u53ef\u80fd\u6709N\u4e2a\u6279\u6b21\u7684\u6570\u636e\u4e00\u6279\u6b21\u4e00\u6279\u6b21\u5c31\u53d6\u561b,\u76f4\u5230\u628a\u8fd9\u7b2cN\u4e2a\u6570\u636e\u53d6\u5b8c,\u5047\u8bbe\u8fd9\u91cc\u662f\u7b2c\u4e00\u6b21\u5faa\u73af,\u90a3\u4e48\u5c31\u53d6\u7b2c\u4e00\u4e2a\u6279\u6b21\u7684\u6570\u636e,\u5982\u679c\u7b2c\u4e8c\u6b21\u5faa\u73af\u7684\u8bdd,\u662f\u7b2c\u4e8c\u6279\u6b21\u7684\u6570\u636e,128\u4e2a\u5411\u91cf\u56fe,128\u4e58\u4ee5\u4ec0\u4e48\u5427,28\u00d728\u53ef\u80fd\u8fd8\u6709\u4e2a\u901a\u9053\u561b,\u8fd9\u91cc\u7684\u4ec0\u4e48BX\u7b49\u4e8e\u8fd9\u4e2a\u7684,\u90a3\u4e48\u5f88\u663e\u7136BY\u662f\u4ec0\u4e48\u90a3BY\u5c31\u662f128\u4e58\u4ee5\u5b83\u7684\u6807\u7b7e\u561b,\u5c31128\u4e2alabel,\u653e\u5230\u6211\u4eec\u7684\u8bbe\u5907\u5f53\u4e2d\u8fd9\u91cc\u8fdb\u884c\u4e00\u4e2a\u8bad\u7ec3,\u8fdb\u884c\u524d\u5411\u4f20\u64ad,\u5f97\u51fa\u8fd9\u6837\u7684\u4e00\u4e2a\u7ed3\u679c<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u5bf9\u6bcf\u4e00\u4e2amini-batch\u8bad\u7ec3\u548c\u8ba1\u7b97\n        for step, (b_x, b_y) in enumerate(train_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bad\u7ec3\u6984\u5f0f\n            model.train()\n\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch,\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n\n            pre_lab=torch.argmax(output,dim=1)\n            #model\u5df2\u7ecf\u8f93\u51fa\u6765\u4e00\u4e2a\u503c\u4e86\u8f93\u51fa\u6765\u7684\u503c\u8981\u7ecf\u8fc7\u4ec0\u4e48\u7ecf\u8fc7\u8fd9\u4e2a\u4ee3\u7801\u554atorch.arg\n            
#\u8fd8\u8bb0\u5f97\u524d\u9762\u6211\u4eec\u8bf4\u6211\u4eec\u6700\u540e\u7684\u8f93\u51fa\u662f\u4e2a\u795e\u7ecf\u5143\u662f10\u4e2a\u503c,\u628a\u8fd9\u4e2a\u8f93\u51fa\u5341\u4e2a\u503c\u8f93\u5165\u5230\u8fd9\u4e2asoft max\u91cc\u9762\n            #\u53d6\u6982\u7387\u6700\u5927\u7684\u4e00\u4e2a\u503c\u4f5c\u4e3a\u6807\u7b7e,\u56e0\u4e3a\u5b83\u7684\u8f93\u51fa\u7684\u503c\u662f\u5341\u4e2a\u503c,\u6240\u4ee5\u67e5\u627e\u6bcf\u4e00\u884c\u5f53\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            #\u6700\u5927\u5bf9\u5e94\u4e0b\u6807\u7684\u4e00\u4e2a\u6570\u503c\n\n            loss = criterion(output, b_y)#b_y\u662f\u6807\u7b7e\u554a\uff0c\u770b\u4f60\u7684\u8f93\u51fa\u548c\u6807\u7b7e\u5229\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u505a\u5bf9\u6bd4,output\u5c31\u662f\u4f60\u8bad\u7ec3\u7684\u6a21\u578b\u8bad\u7ec3\u51fa\u6765\u7684\u503c\uff0c\u7136\u540e\u548c\u6807\u7b7e\u503c\u53bb\u7b97\u635f\u5931\n            #\u8fd9\u91cc\u4e0d\u4e86\u89e3\u95ee\u4ec0\u4e48\u7528output\u548cb_y\u505a\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u7684\u53ef\u4ee5\u53bb\u770b\u4e00\u4e0b\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u8ba1\u7b97\u8fc7\u7a0b\u9700\u8981\u7528\u5230\u6bcf\u4e00\u4e2a\u6807\u7b7e\u7684\u9884\u6d4b\u6982\u7387\n\n            #\u68af\u5ea6\u503c\u5316\u4e3a0\n            optimizer.zero_grad()\n            #\u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\n            loss.backward()\n            #\u6839\u636e\u7f51\u7edc\u53cd\u5411\u4f20\u64ad\u7684\u68af\u5ea6\u4fe1\u606f\u66f4\u65b0\u7f51\u8def\u53c2\u6570,\u8d77\u5230\u964d\u4f4eloss\u51fd\u6570\u8ba1\u7b97\u503c\u7684\u4f5c\u7528\n            optimizer.step()\n            #\u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            train_loss += loss.item()*b_x.size()\n            #\u5982\u679c\u9884\u6d4b\u6b63\u786e,\u51c6\u786e\u5ea6\u52a01\n            train_correct +=torch.sum(pre_lab==b_y.data)\n            
#\u5f53\u524d\u7528\u4e8e\u8bad\u7ec3\u7684\u6837\u672c\u6570\u91cf\n            train_num += b_x.size<\/code><\/pre>\n\n\n\n<h1 class=\"wp-block-heading\">\u6a21\u578b\u9a8c\u8bc1<\/h1>\n\n\n\n<pre class=\"wp-block-code\"><code>train_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e,\u51c6\u786e\u5ea6\u52a01\n            train_correct += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u8bad\u7ec3\u7684\u6837\u672c\u6570\u91cf\n            train_num += b_x.size(0)<\/code><\/pre>\n\n\n\n<p>\u8fd8\u6709\u6211\u4eec\u7684\u6837\u672c\u6570\u91cf,\u5bf9\u4e0d\u5bf9,\u6211\u4eec\u6765\u89e3\u91ca\u4e00\u4e0b\u8fd9\u91cc\u7684loss\u503c,\u5b83\u83b7\u53d6\u6211\u4eecloss\u503c,\u8fd9\u4e2a\u503c\u662f\u4ec0\u4e48,\u662f\u6bcf\u4e2a\u6837\u672c\u7684,\u5e73\u5747\u503c,\u5e73\u5747loss\u503c,\u7136\u540e\u8fd9\u91cc\u4e3a\u4ec0\u4e48\u8981\u4e58\u4ec0\u4e48,\u4e58\u4ee5\u6211\u4eec\u8be5\u6837\u672c\u7684\u6570\u91cf,\u4e58\u4ee5\u6211\u4eec\u8be5\u6837\u672c\u7684\u6570\u91cf\u5462,\u56e0\u4e3a\u4f60\u8be5\u4f60\u4eec\u4e00\u6279\u8f6e\u6b21\u53c8\u4e00\u6279,\u5c31\u5047\u8bbe\u662f100\u4e2a\u6837\u672c,\u5b83\u7684\u4e00\u4e2a\u5e73\u5747\u503c,\u5047\u8bbe\u5e73\u5747\u5e73\u5747loss\u4e58\u4ee5\u4e58\u4ee5100,\u662f\u4e0d\u662f\u8be5\u6279\u6b21\u7684\u7d2f\u7d2f\u52a0,\u78be\u538bloss\u5bf9\u4e0d\u5bf9,\u7136\u540e\u4f60\u8be5\u6279\u6b21\u7684\u4e00\u4e2a\u7d2f\u52a0loss,\u662f\u4e0d\u662f\u7b49\u4e8e\u6211\u4eec\u90a3\u4e2atrain 
loss,\u5bf9\u4e0d\u5bf9,\u6ca1\u6709\u95ee\u9898\u5427,\u56e0\u4e3a\u5e94\u8be5\u662f\u521d\u59cb\u5316\u5b83\u662f\u96f6\u561b,\u4f60\u73b0\u5728\u628a\u5b83\u7d2f\u52a0\u5728\u4e00\u8d77\u561b,\u7136\u540e\u5728\u7b2c\u4e8c\u4e2aP4\u7684\u65f6\u5019,\u4f60\u518d\u628a\u4ec0\u4e48,\u628a\u628a\u4ec0\u4e48\u8fd9\u4e2a\u5e73\u5747loss\u4e58\u4ee5\u4ec0\u4e48\u4e58\u4ee5,\u6211\u4eec\u5bf9\u90a3\u4e2a\u6279\u6b21\u518d\u7d2f\u52a0\u5230\u8fd9\u4e2a\u503c\u4e0a\u9762,\u6240\u4ee5\u4f60\u4e0d\u65ad\u7684\u7d2f\u52a0\u5b8c\u4e4b\u540e,\u4f60\u4f1a\u53d1\u73b0\u4f60\u8be5\u5e72\u561b,\u8fd9\u4e2atrain loss\u8fd9\u6837\u7684\u4e00\u4e2a\u4ec0\u4e48\u4e00\u4e2a\u503c,\u5b83\u662f\u6211\u4eec\u6211\u4eec\u5047\u8bbe5\u4e07\u4e2a\u6837\u672c\u7684loss\u503c\u7684\u7d2f\u52a0,\u8fd9\u91cc\u6ca1\u6709\u95ee\u9898\u5427,\u4f60\u53d1\u73b0\u8fd9\u91cc\u662f\u4ec0\u4e48,\u8fd9\u91cc\u662f\u6211\u4eec\u5f53\u524d\u8bad\u7ec3\u6837\u672c\u7684\u6570\u91cf\u5417,\u5c31\u662f\u6bcf\u8fd9\u662f\u4ec0\u4e48,\u8fd9\u662f\u6bcf\u6279\u6b21\u7684\u6570\u503c\u5417,\u6211\u4eec\u5047\u8bbe\u4e00\u6279\u6b21\u662f100,\u5047\u8bbe\u4e94\u4e2a\u6570\u5c31500\u4eba\u4e4b\u540e\u662f\u5427,\u4e94\u4e94\u767e\u4eba\u4e4b\u540e\u4ed6\u5c31\u662f\u4ec0\u4e48,\u4ed6\u5c31\u76f8\u5f53\u4e8e\u5c31\u662f\u6211\u4eec\u6240\u6709\u7684\u6837\u672c\u6570\u91cf,\u4e0d\u5c31\u5df2\u7ecf\u83b7\u5f97\u4e86\u5417,\u8fd9\u91cc\u80af\u5b9a\u662f\u6ca1\u6709\u95ee\u9898,\u5bf9\u4e0d\u5bf9,\u7136\u540e\u7684\u8bdd\u8fd9\u91cc\u662f\u4ec0\u4e48,\u662f\u6211\u4eec\u6240\u6709\u6837\u672c\u7684loss\u503c\u7684\u7d2f\u52a0,\u5bf9\u4e0d\u5bf9,\u6211\u628a\u90a3\u6837\u7684loss\u7d2f\u52a0\u518d\u9664\u4ee5\u4ec0\u4e48,\u9664\u4ee5\u6211\u4eec\u7684level\u662f\u4e0d\u662f\u5c31\u662f\u4ec0\u4e48,\u6211\u4eec\u8be5\u8f6e\u6b21\u7684\u6839\u561b,\u5e73\u5747\u90a3\u4e2aloss\u503c\u80fd\u660e\u767d\u6211\u7684\u610f\u601d\u5417,\u6240\u4ee5\u6240\u4ee5\u6211\u4eec\u90a3\u4e2a\u4ec0\u4e48,\u6211\u4eec\u9a8c\u8bc1\u673a\u4e0
d\u4e5f\u662f\u4e00\u6837\u5417,\u6240\u4ee5\u4f60\u4f1a\u53d1\u73b0\u5c31\u662f\u6211\u4eecloss\u503c\u9664\u4ee5\u4ec0\u4e48,\u9664\u4ee5\u6211\u4eec\u7684\u6837\u672c\u6570\u91cf,\u5c31\u662f\u8be5\u8f6e\u6b21\u7684\u4e00\u4e2a\u5e73\u5747\u7684\u4e00\u4e2aloss\u503c,\u52a0\u5230\u6211\u4eec\u8fd9\u6837\u7684\u4e00\u4e2a\u5217\u8868\u91cc\u9762\u554a,\u6211\u4eec\u6240\u4ee5\u8be5\u8be5\u5217\u8868\u91cc\u9762\u5c31\u83b7\u53d6\u4e86\u4ec0\u4e48,\u8be5\u5217\u8868\u4e00\u5f00\u59cb\u662f\u7a7a\u7684\u561b,\u6211\u4eec\u5b9a\u4e49\u8fd9\u4e2a\u662f\u7a7a\u7684\u561b,\u83b7\u53d6\u7b2c\u4e00\u4e2a\u4ec0\u4e48\u7b2c\u4e00\u6b21\u8bad\u7ec3\u7684\u4e00\u4e2a\u4ec0\u4e48,\u7b2c\u4e00\u6b21\u8bad\u7ec3\u8f6e\u6b21\u7684\u4e00\u4e2a\u4ec0\u4e48loss\u503c,\u80fd\u660e\u767d\u6211\u7684\u610f\u601d\u5417,\u8fd9\u91cc\u518d\u8bb2\u4e00\u4e0b\u554a,\u8fd9\u91cc\u6709\u70b9\u6666\u6da9\u96be\u61c2\u554a,\u8fd9\u91cc\u7684\u4e00\u4e2a\u4ec0\u4e48loss\u503c,\u662f\u6211\u4eec\u8be5\u6279\u6b21\u6837\u672c\u7684\u4e00\u4e2a\u4ec0\u4e48,\u8be5\u6279\u6b21\u6bd4\u5982\u8bf4100\u4e2a100\u4e2a100\u4e2a\u6837\u672c,\u5b83\u7684\u5e73\u5747loss,\u7136\u540e\u4e58\u4ee5\u5bf9\u5e94\u7684\u561b,\u8be5\u6279\u6b21\u7684\u4e00\u4e2a\u6837\u672c\u6570\u91cf\u4e58\u4ee5\u5bf9\u5e94P4,\u8be5\u6837\u672c\u7684\u6570\u91cf\u5c31\u662f\u4ec0\u4e48,\u5c31\u8be5\u8f6e\u6b21\u7684\u5c31\u662f\u4e0d\u8be5\u5c31\u8be5\u6279\u6b21\u6837\u672c\u7684\u6570\u91cf,\u5c31\u662ftrain loss,\u662f\u4e0d\u662f\u662f\u5427,Train 
loss,\u7136\u540e\u6211\u4eec\u4e00\u76f4\u5728\u5faa\u73af\u561b\u662f\u5427,\u628a\u76f4\u5230\u628a\u6240\u6709\u7684\u6570\u636e\u662f\u5427,\u5443\u8bad\u7ec3\u5b8c\u4e4b\u540e,\u662f\u4e0d\u662f\u5230\u6700\u540e\u6700\u540e\u4e00\u4e2a\u8865\u7684,\u5c31\u662f\u6211\u4eec\u5c31\u662f\u6700\u540e\u4e00\u4e2a\u6279\u6b21,\u5bf9\u4e0d\u5bf9,\u6700\u540e\u4e00\u4e2a\u6279\u6b21\u7684\u65f6\u5019\u554a,\u8ba1\u7b97\u5b8c\u4e4b\u540e,\u8fd9\u91cc\u7684\u503c\u662f\u4e0d\u662f\u5305\u542b\u4e86\u6211\u4eec\u6240\u6709\u6837\u672c\u7684,\u8be5\u6837\u672c\u7684\u8be5\u8f6e\u8be5\u8f6e\u6b21,\u8be5\u6837\u672c\u7684\u6240\u6709\u7684loss\u503c\u4e86,\u6240\u6709loss\u503c,\u5bf9\u4e0d\u5bf9\u554a,\u6240\u6709\u7684loss\u503c,\u7136\u540e\u8fd9\u91cc\u83b7\u53d6\u7684\u662f\u4ec0\u4e48\u5443,\u8be5\u8f6e\u6b21\u7684\u6240\u6709\u7684\u6837\u672c\u7684\u6570\u91cf,\u4e5f\u5c31\u662f\u6211\u4eec\u8bad\u7ec3\u6837\u672c\u7684\u603b\u6570\u91cf\u561b,\u90fd\u61c2\u6211\u610f\u601d\u5427,\u5b9e\u9645\u4e0a\u4f60\u8fd9\u91cc\u4e0d\u8fd9\u4e48\u5199,\u4f60\u4f60\u5728\u524d\u9762\u5176\u5b9e\u4e5f\u80fd\u7b97\u51fa\u6765\u5417,\u90a3\u6211\u4eec\u524d\u9762\u8fd9\u91cc\u5df2\u7ecf\u5df2\u7ecf\u5df2\u7ecf\u7b97\u8fc7\u4e86,\u5bf9\u4e0d\u5bf9,\u8fd9\u4e2a\u5df2\u7ecf\u7b97\u8fc7\u4e86,\u53ea\u662f\u4e3a\u4e86\u65b9\u4fbf\u561b,\u8ddf\u5927\u5bb6\u7edf\u4e00\u561b,\u4f60\u8fd9\u91cc\u90fd\u521d\u59cb\u5316\u4e86,\u6211\u8fd9\u91cc\u521d\u59cb\u5316\u4e00\u4e0b,\u5bf9\u4e0d\u5bf9\u662f\u5427,\u7136\u7136\u540e\u7684\u8bdd,\u4f60\u4f60\u4f60\u8fd9\u4e2a\u8fd9\u91cc\u9762\u90a3\u4e2a\u503c,\u5c31\u662f\u6211\u4eec\u6240\u6709\u6837\u672c\u7684\u4e00\u4e2aloss\u503c\u7684\u603b\u548c,\u9664\u4ee5\u4ec0\u4e48\u6211\u4eec\u6837\u672c\u6570\u91cf,\u90a3\u4e48\u5c31\u8be5\u8f6e\u6b21\u5e73\u5747\u7684\u4e00\u4e2aloss\u503c\u4e86\u5417\u662f\u5427<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import time\n\nimport torch.optim\nfrom torch import nn\nfrom torchvision.datasets import 
FashionMNIST\nfrom torchvision import transforms\nimport torch.utils.data as Data\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom model import LeNet\n\n\ndef train_val_data_process():\n    train_data = FashionMNIST(root='.\/data',\n                              train=True,\n                              transform=transforms.Compose(&#91;transforms.Resize(size=28), transforms.ToTensor()]),\n                              download=True)\n    train_data, val_data = Data.random_split(train_data, &#91;round(0.8 * len(train_data)), round(0.2 * len(train_data))])\n    train_dataloader = Data.DataLoader(dataset=train_data,\n                                       batch_size=128,\n                                       shuffle=True,\n                                       num_workers=8)\n\n    val_dataloader = Data.DataLoader(dataset=val_data,\n                                     batch_size=128,\n                                     shuffle=True,\n                                     num_workers=8)\n\n    return train_dataloader, val_dataloader\n\n\ndef train_model(model, train_dataloader, val_dataloader, num_epochs):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # \u4f18\u5316\u5668,\u68af\u5ea6\u4e0b\u964d\u6cd5\u8fdb\u4e00\u6b65\u62d3\u5c55\u662fSGD,Adam\u548cSGDM\u7b49\u7b49\u7b49\u7b49\u7684\n\n    criterion = nn.CrossEntropyLoss()\n    # \u5747\u65b9\u635f\u5931,\u5206\u7c7b\u4e2d\u6211\u4eec\u4e00\u822c\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u6765\u66f4\u65b0\u635f\u5931\u503c,\u7136\u540e\u662fw,b\n    # \u5c06\u6a21\u578b\u653e\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n    model = model.to(device)\n    # \u590d\u5236\u5f53\u524d\u6a21\u578b\u7684\u53c2\u6570\n    best_model_wts = copy.deepcopy(model.state_dict())\n\n    # \u521d\u59cb\u5316\u53c2\u6570\n    # \u6700\u9ad8\u51c6\u786e\u5ea6\n    best_acc = 0.0\n    # \u8bad\u7ec3\u96c6\u635f\u5931\u5217\u8868\n    
train_loss_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u635f\u5931\u5217\u8868\n    val_loss_all = &#91;]\n    # \u8bad\u7ec3\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    train_acc_all = &#91;]\n    # \u9a8c\u8bc1\u96c6\u51c6\u786e\u5ea6\u5217\u8868\n    val_acc_all = &#91;]\n    # \u5f53\u524d\u65f6\u95f4\n    since = time.time()\n\n    for epoch in range(num_epochs):\n        print('Epoch {}\/{}'.format(epoch + 1, num_epochs))\n        print(\"-\" * 10)\n\n        train_loss = 0.0\n        train_correct = 0.0\n        val_loss = 0.0\n        val_correct = 0.0\n\n        train_num = 0\n        val_num = 0\n\n        # \u5bf9\u6bcf\u4e00\u4e2amini-batch\u8bad\u7ec3\u548c\u8ba1\u7b97\n        for step, (b_x, b_y) in enumerate(train_dataloader):\n            # \u5c06\u7279\u5f81\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_x = b_x.to(device)\n            # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u8bad\u7ec3\u8bbe\u5907\u4e2d\n            b_y = b_y.to(device)\n            # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bad\u7ec3\u6984\u5f0f\n            model.train()\n\n            # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch,\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n            output = model(b_x)\n\n            pre_lab = torch.argmax(output, dim=1)\n            # model\u5df2\u7ecf\u8f93\u51fa\u6765\u4e00\u4e2a\u503c\u4e86\u8f93\u51fa\u6765\u7684\u503c\u8981\u7ecf\u8fc7\u4ec0\u4e48\u7ecf\u8fc7\u8fd9\u4e2a\u4ee3\u7801\u554atorch.arg\n            # \u8fd8\u8bb0\u5f97\u524d\u9762\u6211\u4eec\u8bf4\u6211\u4eec\u6700\u540e\u7684\u8f93\u51fa\u662f\u4e2a\u795e\u7ecf\u5143\u662f10\u4e2a\u503c,\u628a\u8fd9\u4e2a\u8f93\u51fa\u5341\u4e2a\u503c\u8f93\u5165\u5230\u8fd9\u4e2asoft max\u91cc\u9762\n            # 
\u53d6\u6982\u7387\u6700\u5927\u7684\u4e00\u4e2a\u503c\u4f5c\u4e3a\u6807\u7b7e,\u56e0\u4e3a\u5b83\u7684\u8f93\u51fa\u7684\u503c\u662f\u5341\u4e2a\u503c,\u6240\u4ee5\u67e5\u627e\u6bcf\u4e00\u884c\u5f53\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n            # \u6700\u5927\u5bf9\u5e94\u4e0b\u6807\u7684\u4e00\u4e2a\u6570\u503c\n\n            loss = criterion(output, b_y)  # b_y\u662f\u6807\u7b7e\u554a\uff0c\u770b\u4f60\u7684\u8f93\u51fa\u548c\u6807\u7b7e\u5229\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u505a\u5bf9\u6bd4,output\u5c31\u662f\u4f60\u8bad\u7ec3\u7684\u6a21\u578b\u8bad\u7ec3\u51fa\u6765\u7684\u503c\uff0c\u7136\u540e\u548c\u6807\u7b7e\u503c\u53bb\u7b97\u635f\u5931\n            # \u8fd9\u91cc\u4e0d\u4e86\u89e3\u95ee\u4ec0\u4e48\u7528output\u548cb_y\u505a\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u7684\u53ef\u4ee5\u53bb\u770b\u4e00\u4e0b\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u7684\u8ba1\u7b97\u516c\u5f0f\uff0c\u8ba1\u7b97\u8fc7\u7a0b\u9700\u8981\u7528\u5230\u6bcf\u4e00\u4e2a\u6807\u7b7e\u7684\u9884\u6d4b\u6982\u7387\n\n            # \u68af\u5ea6\u503c\u5316\u4e3a0\n            optimizer.zero_grad()\n            # \u53cd\u5411\u4f20\u64ad\u8ba1\u7b97\n            loss.backward()\n            # \u6839\u636e\u7f51\u7edc\u53cd\u5411\u4f20\u64ad\u7684\u68af\u5ea6\u4fe1\u606f\u66f4\u65b0\u7f51\u7edc\u53c2\u6570,\u8d77\u5230\u964d\u4f4eloss\u51fd\u6570\u8ba1\u7b97\u503c\u7684\u4f5c\u7528\n            optimizer.step()\n            # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n            train_loss += loss.item() * b_x.size(0)\n            # \u5982\u679c\u9884\u6d4b\u6b63\u786e,\u51c6\u786e\u5ea6\u52a01\n            train_correct += torch.sum(pre_lab == b_y.data)\n            # \u5f53\u524d\u7528\u4e8e\u8bad\u7ec3\u7684\u6837\u672c\u6570\u91cf\n            train_num += b_x.size(0)\n\n        # \u6bcf\u4e2aepoch\u8bad\u7ec3\u7ed3\u675f\u540e\u5728\u9a8c\u8bc1\u96c6\u4e0a\u8bc4\u4f30,\u9a8c\u8bc1\u9636\u6bb5\u4e0d\u9700\u8981\u8ba1\u7b97\u68af\u5ea6\n        with torch.no_grad():\n            for step, (b_x, b_y) in enumerate(val_dataloader):\n                # \u5c06\u7279\u5f81\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n      
          b_x = b_x.to(device)\n                # \u5c06\u6807\u7b7e\u653e\u5165\u5230\u9a8c\u8bc1\u8bbe\u5907\u4e2d\n                b_y = b_y.to(device)\n                # \u8bbe\u7f6e\u6a21\u578b\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n                model.eval()\n                # \u524d\u5411\u4f20\u64ad\u8fc7\u7a0b\uff0c\u8f93\u5165\u4e3a\u4e00\u4e2abatch,\u8f93\u51fa\u4e3a\u4e00\u4e2abatch\u4e2d\u5bf9\u5e94\u7684\u9884\u6d4b\n                output = model(b_x)\n                # \u67e5\u627e\u6bcf\u4e00\u884c\u4e2d\u6700\u5927\u503c\u5bf9\u5e94\u7684\u884c\u6807\n                pre_lab = torch.argmax(output, dim=1)\n                # \u8ba1\u7b97\u6bcf\u4e00\u4e2abatch\u7684\u635f\u5931\u51fd\u6570\n                loss = criterion(output, b_y)\n                # \u5bf9\u635f\u5931\u51fd\u6570\u8fdb\u884c\u7d2f\u52a0\n                val_loss += loss.item() * b_x.size(0)\n                # \u5982\u679c\u9884\u6d4b\u6b63\u786e\uff0c\u5219\u51c6\u786e\u5ea6val_correct\u52a01\n                val_correct += torch.sum(pre_lab == b_y.data)\n                # \u5f53\u524d\u7528\u4e8e\u9a8c\u8bc1\u7684\u6837\u672c\u6570\u91cf\n                val_num += b_x.size(0)\n\n        #\u8ba1\u7b97\u5e76\u4fdd\u5b58\u6bcf\u4e00\u6b21\u8fed\u4ee3\u7684loss\u503c\u548c\u51c6\u786e\u7387:\u5176\u5b9e\u5c31\u662f\u6bcf\u6b21\u6c42\u5b8closs\u4e4b\u540e\u53e0\u52a0\uff0c\u6ca1\u5565\u96be\u7406\u89e3\u7684\uff0c\u65e0\u975e\u591a\u4e86\u4e2a\u5e73\u5747\u64cd\u4f5c\n        train_loss_all.append(train_loss\/train_num)\n        train_acc_all.append(train_correct.double().item()\/train_num)\n\n        # \u9a8c\u8bc1\u96c6\u7684\u5e73\u5747loss\u548c\u51c6\u786e\u7387\n        val_loss_all.append(val_loss\/val_num)\n        val_acc_all.append(val_correct.double().item()\/val_num)\n\n        print('{} Train Loss:{:.4f} Train Accuracy:{:.4f}'.format(epoch, train_loss_all&#91;-1], train_acc_all&#91;-1]))\n        print('{} Val Loss:{:.4f} Val Accuracy:{:.4f}'.format(epoch, val_loss_all&#91;-1], val_acc_all&#91;-1]))\n<\/code><\/pre>\n\n\n\n<p>P50-P60\u8df3\u4e86 
\u200b<\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u642d\u5efa\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u7684\u8fc7\u7a0b\u53ef\u4ee5\u603b\u7ed3\u4e3a\u4ee5\u4e0b\u6b65\u9aa4\uff1a \u8fd9\u4e2a\u6d41\u7a0b\u7c7b\u4f3c\u4e8e\u642d\u5efa\u4e00\u5ea7\u5efa\u7b51\u7269\uff0c\u5148\u51c6\u5907\u597d\u5404\u7c7b\u6750\u6599\u548c\u5de5\u5177\uff08\u5373\u521d\u59cb\u5316\u7f51\u7edc\u5c42\u548c\u53c2\u6570\uff09\uff0c\u7136\u540e\u6839\u636e &#8230;<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"emotion":"","emotion_color":"","title_style":"","license":"","footnotes":""},"categories":[3],"tags":[6,9],"class_list":["post-444","post","type-post","status-publish","format-standard","hentry","category-3","tag-cv","tag-9"],"_links":{"self":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/444","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=444"}],"version-history":[{"count":0,"href":"https:\/\/eve2333.top\/index.php?rest_route=\/wp\/v2\/posts\/444\/revisions"}],"wp:attachment":[{"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=444"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=444"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/eve2333.top\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=444"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}