PyTorch 模型保存与再训练,基于 MNIST 数据集。
导入依赖包 1 2 3 4 5 6 7 8 9 import osimport matplotlib.pyplot as pltimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport torchvisionfrom torch.utils.data import DataLoader
# --- Load data: hyper-parameters and runtime settings ---
n_epochs = 10          # number of initial training epochs
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 20      # log every N mini-batches

# Fix the RNG seed so results are reproducible.
torch.manual_seed(33)

# Use the second GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=1)
# Pre-processing pipeline: tensor conversion + normalisation with the
# standard MNIST per-channel mean/std.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Download (if needed) and wrap both splits of MNIST.
train_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets", train=True, download=True, transform=transform
)
test_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/", train=False, download=True, transform=transform
)

# Shuffled mini-batch loaders for training and evaluation.
train_loader = DataLoader(train_data, batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size_test, shuffle=True)
# Pull one batch from the test loader to sanity-check tensor shapes.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)  # expected: torch.Size([1000, 1, 28, 28])
torch.Size([1000, 1, 28, 28])
# Visualise the first six test digits together with their labels.
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
    plt.title(f"Ground Truth: {example_targets[i]}")
    plt.xticks([])
    plt.yticks([])
plt.show()
# --- Build the model and the optimisation algorithm ---
class Net(nn.Module):
    """Small CNN for MNIST: two conv+pool stages, dropout, two FC layers.

    Input:  (N, 1, 28, 28) images.
    Output: (N, 10) log-probabilities (log_softmax over classes).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # After two conv+pool stages the feature map is 20 x 4 x 4 = 320.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        out = F.max_pool2d(self.conv1(x), 2)
        out = F.relu(out)
        # Stage 2: conv -> spatial dropout -> 2x2 max-pool -> ReLU.
        out = self.conv2_drop(self.conv2(out))
        out = F.relu(F.max_pool2d(out, 2))
        # Flatten and classify.
        out = out.view(-1, 320)
        out = F.relu(self.fc1(out))
        out = F.dropout(out, training=self.training)
        return F.log_softmax(self.fc2(out), dim=1)
# Instantiate the CNN on the chosen device and pair it with SGD + momentum.
network = Net().to(device)
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)
# --- Model training and checkpointing: loss/step histories for plotting ---
train_losses = []
train_counter = []
test_losses = []
# One x-coordinate (examples seen) per test() call.
# NOTE(review): markers start at 0 although test() first runs after epoch 1;
# range(1, n_epochs + 1) would align them exactly — confirm intent.
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs)]
def train(epoch):
    """Run one training epoch, log progress, and checkpoint model + optimizer.

    Args:
        epoch: 1-based epoch index, used for logging and checkpoint file names.

    Side effects:
        Appends to the global ``train_losses``/``train_counter`` histories and
        writes per-epoch ``.pth`` checkpoints plus ``*_latest.pth`` symlinks
        under the checkpoint directory.
    """
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} [{str(batch_idx * len(data)).zfill(5)}/"
                f"{len(train_loader.dataset)} "
                f"({100. * batch_idx / len(train_loader):.0f}%)]"
                f"\tLoss: {loss.item():.6f}"
            )
            train_losses.append(loss.item())
            # BUG FIX: the x-axis counter previously used a hard-coded 64 even
            # though batch_size_train is 128, halving the "examples seen" scale.
            train_counter.append(
                (batch_idx * batch_size_train)
                + ((epoch - 1) * len(train_loader.dataset))
            )

    # Persist this epoch's checkpoint and refresh the "latest" symlinks.
    checkpoint_dir = "/workspace/disk1/datasets/models/mnist"
    os.makedirs(checkpoint_dir, exist_ok=True)  # FIX: dir may not exist yet
    model_checkpoint = f"{checkpoint_dir}/model_epoch{epoch}.pth"
    optimizer_checkpoint = f"{checkpoint_dir}/optimizer_epoch{epoch}.pth"
    latest_model = f"{checkpoint_dir}/model_latest.pth"
    latest_optimizer = f"{checkpoint_dir}/optimizer_latest.pth"
    torch.save(network.state_dict(), model_checkpoint)
    torch.save(optimizer.state_dict(), optimizer_checkpoint)
    # Replace (not just create) the symlinks so they always track this epoch.
    for link, target_path in (
        (latest_model, model_checkpoint),
        (latest_optimizer, optimizer_checkpoint),
    ):
        if os.path.exists(link):
            os.remove(link)
        os.symlink(target_path, link)
def test():
    """Evaluate the network on the full test set; log average loss and accuracy.

    Side effects: appends the average loss to the global ``test_losses``.
    """
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            # Accumulate the summed (not averaged) loss; divide once at the end.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print(
        f"\nTest set: Avg. loss: {test_loss:.4f}, "
        f"Accuracy: {correct}/{len(test_loader.dataset)} "
        f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
    )
# Initial training run: one train pass followed by one evaluation per epoch.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
Train Epoch: 1 [00000/60000 (0%)] Loss: 2.288714
Train Epoch: 1 [02560/60000 (4%)] Loss: 2.255195
Train Epoch: 1 [05120/60000 (9%)] Loss: 2.198278
Train Epoch: 1 [07680/60000 (13%)] Loss: 2.091046
Train Epoch: 1 [10240/60000 (17%)] Loss: 1.832017
Train Epoch: 1 [12800/60000 (21%)] Loss: 1.800682
Train Epoch: 1 [15360/60000 (26%)] Loss: 1.533700
Train Epoch: 1 [17920/60000 (30%)] Loss: 1.332488
Train Epoch: 1 [20480/60000 (34%)] Loss: 1.289864
Train Epoch: 1 [23040/60000 (38%)] Loss: 1.195423
Train Epoch: 1 [25600/60000 (43%)] Loss: 1.075095
Train Epoch: 1 [28160/60000 (47%)] Loss: 1.004397
Train Epoch: 1 [30720/60000 (51%)] Loss: 0.832168
Train Epoch: 1 [33280/60000 (55%)] Loss: 0.849826
Train Epoch: 1 [35840/60000 (60%)] Loss: 0.968393
Train Epoch: 1 [38400/60000 (64%)] Loss: 0.866699
Train Epoch: 1 [40960/60000 (68%)] Loss: 0.721661
Train Epoch: 1 [43520/60000 (72%)] Loss: 0.962982
Train Epoch: 1 [46080/60000 (77%)] Loss: 0.695050
Train Epoch: 1 [48640/60000 (81%)] Loss: 0.799360
Train Epoch: 1 [51200/60000 (85%)] Loss: 0.683803
Train Epoch: 1 [53760/60000 (90%)] Loss: 0.561625
Train Epoch: 1 [56320/60000 (94%)] Loss: 0.579419
Train Epoch: 1 [58880/60000 (98%)] Loss: 0.568427
Test set: Avg. loss: 0.3283, Accuracy: 9096/10000 (91%)
Train Epoch: 2 [00000/60000 (0%)] Loss: 0.781268
Train Epoch: 2 [02560/60000 (4%)] Loss: 0.679616
Train Epoch: 2 [05120/60000 (9%)] Loss: 0.572369
Train Epoch: 2 [07680/60000 (13%)] Loss: 0.638919
Train Epoch: 2 [10240/60000 (17%)] Loss: 0.721595
Train Epoch: 2 [12800/60000 (21%)] Loss: 0.723331
Train Epoch: 2 [15360/60000 (26%)] Loss: 0.604582
Train Epoch: 2 [17920/60000 (30%)] Loss: 0.689362
Train Epoch: 2 [20480/60000 (34%)] Loss: 0.548551
Train Epoch: 2 [23040/60000 (38%)] Loss: 0.650297
Train Epoch: 2 [25600/60000 (43%)] Loss: 0.506893
Train Epoch: 2 [28160/60000 (47%)] Loss: 0.581708
Train Epoch: 2 [30720/60000 (51%)] Loss: 0.557465
Train Epoch: 2 [33280/60000 (55%)] Loss: 0.461637
Train Epoch: 2 [35840/60000 (60%)] Loss: 0.619341
Train Epoch: 2 [38400/60000 (64%)] Loss: 0.464800
Train Epoch: 2 [40960/60000 (68%)] Loss: 0.473921
Train Epoch: 2 [43520/60000 (72%)] Loss: 0.567576
Train Epoch: 2 [46080/60000 (77%)] Loss: 0.447070
Train Epoch: 2 [48640/60000 (81%)] Loss: 0.503476
Train Epoch: 2 [51200/60000 (85%)] Loss: 0.500809
Train Epoch: 2 [53760/60000 (90%)] Loss: 0.553329
Train Epoch: 2 [56320/60000 (94%)] Loss: 0.504529
Train Epoch: 2 [58880/60000 (98%)] Loss: 0.457889
Test set: Avg. loss: 0.2128, Accuracy: 9357/10000 (94%)
Train Epoch: 3 [00000/60000 (0%)] Loss: 0.369795
Train Epoch: 3 [02560/60000 (4%)] Loss: 0.414531
Train Epoch: 3 [05120/60000 (9%)] Loss: 0.604378
Train Epoch: 3 [07680/60000 (13%)] Loss: 0.426111
Train Epoch: 3 [10240/60000 (17%)] Loss: 0.492895
Train Epoch: 3 [12800/60000 (21%)] Loss: 0.393350
Train Epoch: 3 [15360/60000 (26%)] Loss: 0.555914
Train Epoch: 3 [17920/60000 (30%)] Loss: 0.476940
Train Epoch: 3 [20480/60000 (34%)] Loss: 0.430539
Train Epoch: 3 [23040/60000 (38%)] Loss: 0.571562
Train Epoch: 3 [25600/60000 (43%)] Loss: 0.488061
Train Epoch: 3 [28160/60000 (47%)] Loss: 0.598932
Train Epoch: 3 [30720/60000 (51%)] Loss: 0.417797
Train Epoch: 3 [33280/60000 (55%)] Loss: 0.486182
Train Epoch: 3 [35840/60000 (60%)] Loss: 0.375228
Train Epoch: 3 [38400/60000 (64%)] Loss: 0.384777
Train Epoch: 3 [40960/60000 (68%)] Loss: 0.428213
Train Epoch: 3 [43520/60000 (72%)] Loss: 0.443456
Train Epoch: 3 [46080/60000 (77%)] Loss: 0.308731
Train Epoch: 3 [48640/60000 (81%)] Loss: 0.347535
Train Epoch: 3 [51200/60000 (85%)] Loss: 0.439240
Train Epoch: 3 [53760/60000 (90%)] Loss: 0.460515
Train Epoch: 3 [56320/60000 (94%)] Loss: 0.521822
Train Epoch: 3 [58880/60000 (98%)] Loss: 0.451665
Test set: Avg. loss: 0.1591, Accuracy: 9513/10000 (95%)
Train Epoch: 4 [00000/60000 (0%)] Loss: 0.510825
Train Epoch: 4 [02560/60000 (4%)] Loss: 0.279729
Train Epoch: 4 [05120/60000 (9%)] Loss: 0.333447
Train Epoch: 4 [07680/60000 (13%)] Loss: 0.434208
Train Epoch: 4 [10240/60000 (17%)] Loss: 0.516646
Train Epoch: 4 [12800/60000 (21%)] Loss: 0.339009
Train Epoch: 4 [15360/60000 (26%)] Loss: 0.342047
Train Epoch: 4 [17920/60000 (30%)] Loss: 0.315687
Train Epoch: 4 [20480/60000 (34%)] Loss: 0.365422
Train Epoch: 4 [23040/60000 (38%)] Loss: 0.408456
Train Epoch: 4 [25600/60000 (43%)] Loss: 0.512156
Train Epoch: 4 [28160/60000 (47%)] Loss: 0.246768
Train Epoch: 4 [30720/60000 (51%)] Loss: 0.254370
Train Epoch: 4 [33280/60000 (55%)] Loss: 0.403202
Train Epoch: 4 [35840/60000 (60%)] Loss: 0.364687
Train Epoch: 4 [38400/60000 (64%)] Loss: 0.313392
Train Epoch: 4 [40960/60000 (68%)] Loss: 0.256359
Train Epoch: 4 [43520/60000 (72%)] Loss: 0.306669
Train Epoch: 4 [46080/60000 (77%)] Loss: 0.459862
Train Epoch: 4 [48640/60000 (81%)] Loss: 0.227380
Train Epoch: 4 [51200/60000 (85%)] Loss: 0.368363
Train Epoch: 4 [53760/60000 (90%)] Loss: 0.329823
Train Epoch: 4 [56320/60000 (94%)] Loss: 0.304764
Train Epoch: 4 [58880/60000 (98%)] Loss: 0.288302
Test set: Avg. loss: 0.1322, Accuracy: 9596/10000 (96%)
Train Epoch: 5 [00000/60000 (0%)] Loss: 0.247250
Train Epoch: 5 [02560/60000 (4%)] Loss: 0.383674
Train Epoch: 5 [05120/60000 (9%)] Loss: 0.315615
Train Epoch: 5 [07680/60000 (13%)] Loss: 0.259549
Train Epoch: 5 [10240/60000 (17%)] Loss: 0.322165
Train Epoch: 5 [12800/60000 (21%)] Loss: 0.474284
Train Epoch: 5 [15360/60000 (26%)] Loss: 0.254566
Train Epoch: 5 [17920/60000 (30%)] Loss: 0.466621
Train Epoch: 5 [20480/60000 (34%)] Loss: 0.320438
Train Epoch: 5 [23040/60000 (38%)] Loss: 0.316798
Train Epoch: 5 [25600/60000 (43%)] Loss: 0.192171
Train Epoch: 5 [28160/60000 (47%)] Loss: 0.478387
Train Epoch: 5 [30720/60000 (51%)] Loss: 0.346078
Train Epoch: 5 [33280/60000 (55%)] Loss: 0.315769
Train Epoch: 5 [35840/60000 (60%)] Loss: 0.284116
Train Epoch: 5 [38400/60000 (64%)] Loss: 0.354763
Train Epoch: 5 [40960/60000 (68%)] Loss: 0.369070
Train Epoch: 5 [43520/60000 (72%)] Loss: 0.233857
Train Epoch: 5 [46080/60000 (77%)] Loss: 0.222160
Train Epoch: 5 [48640/60000 (81%)] Loss: 0.325135
Train Epoch: 5 [51200/60000 (85%)] Loss: 0.226506
Train Epoch: 5 [53760/60000 (90%)] Loss: 0.407618
Train Epoch: 5 [56320/60000 (94%)] Loss: 0.359771
Train Epoch: 5 [58880/60000 (98%)] Loss: 0.290950
Test set: Avg. loss: 0.1152, Accuracy: 9643/10000 (96%)
Train Epoch: 6 [00000/60000 (0%)] Loss: 0.371519
Train Epoch: 6 [02560/60000 (4%)] Loss: 0.294658
Train Epoch: 6 [05120/60000 (9%)] Loss: 0.195041
Train Epoch: 6 [07680/60000 (13%)] Loss: 0.247292
Train Epoch: 6 [10240/60000 (17%)] Loss: 0.307761
Train Epoch: 6 [12800/60000 (21%)] Loss: 0.183960
Train Epoch: 6 [15360/60000 (26%)] Loss: 0.255224
Train Epoch: 6 [17920/60000 (30%)] Loss: 0.564046
Train Epoch: 6 [20480/60000 (34%)] Loss: 0.217146
Train Epoch: 6 [23040/60000 (38%)] Loss: 0.364980
Train Epoch: 6 [25600/60000 (43%)] Loss: 0.237876
Train Epoch: 6 [28160/60000 (47%)] Loss: 0.344803
Train Epoch: 6 [30720/60000 (51%)] Loss: 0.347686
Train Epoch: 6 [33280/60000 (55%)] Loss: 0.197488
Train Epoch: 6 [35840/60000 (60%)] Loss: 0.346718
Train Epoch: 6 [38400/60000 (64%)] Loss: 0.256105
Train Epoch: 6 [40960/60000 (68%)] Loss: 0.211900
Train Epoch: 6 [43520/60000 (72%)] Loss: 0.264353
Train Epoch: 6 [46080/60000 (77%)] Loss: 0.339571
Train Epoch: 6 [48640/60000 (81%)] Loss: 0.198715
Train Epoch: 6 [51200/60000 (85%)] Loss: 0.335813
Train Epoch: 6 [53760/60000 (90%)] Loss: 0.244630
Train Epoch: 6 [56320/60000 (94%)] Loss: 0.260668
Train Epoch: 6 [58880/60000 (98%)] Loss: 0.281284
Test set: Avg. loss: 0.1029, Accuracy: 9682/10000 (97%)
Train Epoch: 7 [00000/60000 (0%)] Loss: 0.259941
Train Epoch: 7 [02560/60000 (4%)] Loss: 0.353288
Train Epoch: 7 [05120/60000 (9%)] Loss: 0.345746
Train Epoch: 7 [07680/60000 (13%)] Loss: 0.263094
Train Epoch: 7 [10240/60000 (17%)] Loss: 0.370562
Train Epoch: 7 [12800/60000 (21%)] Loss: 0.184917
Train Epoch: 7 [15360/60000 (26%)] Loss: 0.358648
Train Epoch: 7 [17920/60000 (30%)] Loss: 0.358313
Train Epoch: 7 [20480/60000 (34%)] Loss: 0.455060
Train Epoch: 7 [23040/60000 (38%)] Loss: 0.157829
Train Epoch: 7 [25600/60000 (43%)] Loss: 0.255777
Train Epoch: 7 [28160/60000 (47%)] Loss: 0.296378
Train Epoch: 7 [30720/60000 (51%)] Loss: 0.220109
Train Epoch: 7 [33280/60000 (55%)] Loss: 0.207805
Train Epoch: 7 [35840/60000 (60%)] Loss: 0.333757
Train Epoch: 7 [38400/60000 (64%)] Loss: 0.351853
Train Epoch: 7 [40960/60000 (68%)] Loss: 0.225360
Train Epoch: 7 [43520/60000 (72%)] Loss: 0.220420
Train Epoch: 7 [46080/60000 (77%)] Loss: 0.281292
Train Epoch: 7 [48640/60000 (81%)] Loss: 0.224555
Train Epoch: 7 [51200/60000 (85%)] Loss: 0.300659
Train Epoch: 7 [53760/60000 (90%)] Loss: 0.155560
Train Epoch: 7 [56320/60000 (94%)] Loss: 0.322972
Train Epoch: 7 [58880/60000 (98%)] Loss: 0.189427
Test set: Avg. loss: 0.0973, Accuracy: 9693/10000 (97%)
Train Epoch: 8 [00000/60000 (0%)] Loss: 0.237282
Train Epoch: 8 [02560/60000 (4%)] Loss: 0.205023
Train Epoch: 8 [05120/60000 (9%)] Loss: 0.178199
Train Epoch: 8 [07680/60000 (13%)] Loss: 0.335434
Train Epoch: 8 [10240/60000 (17%)] Loss: 0.274767
Train Epoch: 8 [12800/60000 (21%)] Loss: 0.170677
Train Epoch: 8 [15360/60000 (26%)] Loss: 0.401550
Train Epoch: 8 [17920/60000 (30%)] Loss: 0.360215
Train Epoch: 8 [20480/60000 (34%)] Loss: 0.329762
Train Epoch: 8 [23040/60000 (38%)] Loss: 0.311394
Train Epoch: 8 [25600/60000 (43%)] Loss: 0.187782
Train Epoch: 8 [28160/60000 (47%)] Loss: 0.242610
Train Epoch: 8 [30720/60000 (51%)] Loss: 0.327457
Train Epoch: 8 [33280/60000 (55%)] Loss: 0.249692
Train Epoch: 8 [35840/60000 (60%)] Loss: 0.325073
Train Epoch: 8 [38400/60000 (64%)] Loss: 0.210571
Train Epoch: 8 [40960/60000 (68%)] Loss: 0.244188
Train Epoch: 8 [43520/60000 (72%)] Loss: 0.258391
Train Epoch: 8 [46080/60000 (77%)] Loss: 0.228774
Train Epoch: 8 [48640/60000 (81%)] Loss: 0.253648
Train Epoch: 8 [51200/60000 (85%)] Loss: 0.387448
Train Epoch: 8 [53760/60000 (90%)] Loss: 0.223092
Train Epoch: 8 [56320/60000 (94%)] Loss: 0.155341
Train Epoch: 8 [58880/60000 (98%)] Loss: 0.233621
Test set: Avg. loss: 0.0875, Accuracy: 9730/10000 (97%)
Train Epoch: 9 [00000/60000 (0%)] Loss: 0.356232
Train Epoch: 9 [02560/60000 (4%)] Loss: 0.228741
Train Epoch: 9 [05120/60000 (9%)] Loss: 0.237508
Train Epoch: 9 [07680/60000 (13%)] Loss: 0.200507
Train Epoch: 9 [10240/60000 (17%)] Loss: 0.180248
Train Epoch: 9 [12800/60000 (21%)] Loss: 0.316860
Train Epoch: 9 [15360/60000 (26%)] Loss: 0.251897
Train Epoch: 9 [17920/60000 (30%)] Loss: 0.341667
Train Epoch: 9 [20480/60000 (34%)] Loss: 0.256542
Train Epoch: 9 [23040/60000 (38%)] Loss: 0.272902
Train Epoch: 9 [25600/60000 (43%)] Loss: 0.207665
Train Epoch: 9 [28160/60000 (47%)] Loss: 0.134075
Train Epoch: 9 [30720/60000 (51%)] Loss: 0.198930
Train Epoch: 9 [33280/60000 (55%)] Loss: 0.218460
Train Epoch: 9 [35840/60000 (60%)] Loss: 0.372338
Train Epoch: 9 [38400/60000 (64%)] Loss: 0.207677
Train Epoch: 9 [40960/60000 (68%)] Loss: 0.274011
Train Epoch: 9 [43520/60000 (72%)] Loss: 0.158897
Train Epoch: 9 [46080/60000 (77%)] Loss: 0.236698
Train Epoch: 9 [48640/60000 (81%)] Loss: 0.238054
Train Epoch: 9 [51200/60000 (85%)] Loss: 0.264504
Train Epoch: 9 [53760/60000 (90%)] Loss: 0.246547
Train Epoch: 9 [56320/60000 (94%)] Loss: 0.181960
Train Epoch: 9 [58880/60000 (98%)] Loss: 0.182172
Test set: Avg. loss: 0.0810, Accuracy: 9751/10000 (98%)
Train Epoch: 10 [00000/60000 (0%)] Loss: 0.223450
Train Epoch: 10 [02560/60000 (4%)] Loss: 0.146772
Train Epoch: 10 [05120/60000 (9%)] Loss: 0.448847
Train Epoch: 10 [07680/60000 (13%)] Loss: 0.192779
Train Epoch: 10 [10240/60000 (17%)] Loss: 0.180810
Train Epoch: 10 [12800/60000 (21%)] Loss: 0.256196
Train Epoch: 10 [15360/60000 (26%)] Loss: 0.248770
Train Epoch: 10 [17920/60000 (30%)] Loss: 0.243000
Train Epoch: 10 [20480/60000 (34%)] Loss: 0.274350
Train Epoch: 10 [23040/60000 (38%)] Loss: 0.256088
Train Epoch: 10 [25600/60000 (43%)] Loss: 0.326276
Train Epoch: 10 [28160/60000 (47%)] Loss: 0.192222
Train Epoch: 10 [30720/60000 (51%)] Loss: 0.163986
Train Epoch: 10 [33280/60000 (55%)] Loss: 0.261450
Train Epoch: 10 [35840/60000 (60%)] Loss: 0.263229
Train Epoch: 10 [38400/60000 (64%)] Loss: 0.278172
Train Epoch: 10 [40960/60000 (68%)] Loss: 0.245064
Train Epoch: 10 [43520/60000 (72%)] Loss: 0.298991
Train Epoch: 10 [46080/60000 (77%)] Loss: 0.334879
Train Epoch: 10 [48640/60000 (81%)] Loss: 0.306450
Train Epoch: 10 [51200/60000 (85%)] Loss: 0.208277
Train Epoch: 10 [53760/60000 (90%)] Loss: 0.211731
Train Epoch: 10 [56320/60000 (94%)] Loss: 0.186170
Train Epoch: 10 [58880/60000 (98%)] Loss: 0.219588
Test set: Avg. loss: 0.0760, Accuracy: 9754/10000 (98%)
# Learning curves: per-log-interval training loss (line) and
# per-epoch test loss (scatter), both against examples seen.
fig = plt.figure()
plt.plot(train_counter, train_losses, color="blue")
plt.scatter(test_counter, test_losses, color="red")
plt.legend(["Train Loss", "Test Loss"], loc="upper right")
plt.xlabel("number of training examples seen")
plt.ylabel("negative log likelihood loss")
plt.show()
# --- Model testing: predict on one batch and visualise six predictions ---
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
    example_data = example_data.to(device)
    output = network(example_data)
fig = plt.figure()
# Hoist the CPU copy and the argmax out of the plotting loop.
images = example_data.detach().to("cpu").numpy()
predictions = output.data.max(1, keepdim=True)[1]
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(images[i][0], cmap="gray", interpolation="none")
    plt.title(f"prediction: {predictions[i].item()}")
    plt.xticks([])
    plt.yticks([])
plt.show()
# --- Model retraining: fresh model + optimizer to restore from checkpoints ---
# Move the new model to the same device as the original so restored training
# and inference run where the checkpoint expects.
continued_network = Net().to(device)
# BUG FIX: the optimizer must manage the *continued* network's parameters;
# it previously pointed at the original ``network``, so loading its state
# and stepping it would never update ``continued_network``.
continued_optimizer = optim.SGD(
    continued_network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
# Restore model weights and optimizer state from the "latest" symlinks.
latest_model = "/workspace/disk1/datasets/models/mnist/model_latest.pth"
latest_optimizer = "/workspace/disk1/datasets/models/mnist/optimizer_latest.pth"
# FIX: map_location lets a checkpoint saved on cuda:1 load on machines where
# that device is unavailable (e.g. CPU-only hosts).
network_state_dict = torch.load(latest_model, map_location=device)
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(latest_optimizer, map_location=device)
continued_optimizer.load_state_dict(optimizer_state_dict)
# Resume training for additional epochs after the initial run.
# NOTE(review): range(n_epochs + 1, n_epochs + m) executes m - 1 = 4 extra
# epochs (11..14), and train()/test() still operate on the global
# ``network``/``optimizer`` rather than the restored ``continued_network`` —
# confirm both are intended.
m = 5
for i in range(n_epochs + 1, n_epochs + m):
    test_counter.append(i * len(train_loader.dataset))
    train(i)
    test()
Train Epoch: 11 [00000/60000 (0%)] Loss: 0.208258
Train Epoch: 11 [02560/60000 (4%)] Loss: 0.274365
Train Epoch: 11 [05120/60000 (9%)] Loss: 0.438731
Train Epoch: 11 [07680/60000 (13%)] Loss: 0.192449
Train Epoch: 11 [10240/60000 (17%)] Loss: 0.333653
Train Epoch: 11 [12800/60000 (21%)] Loss: 0.243886
Train Epoch: 11 [15360/60000 (26%)] Loss: 0.403167
Train Epoch: 11 [17920/60000 (30%)] Loss: 0.287619
Train Epoch: 11 [20480/60000 (34%)] Loss: 0.169019
Train Epoch: 11 [23040/60000 (38%)] Loss: 0.212516
Train Epoch: 11 [25600/60000 (43%)] Loss: 0.202806
Train Epoch: 11 [28160/60000 (47%)] Loss: 0.217581
Train Epoch: 11 [30720/60000 (51%)] Loss: 0.230557
Train Epoch: 11 [33280/60000 (55%)] Loss: 0.378703
Train Epoch: 11 [35840/60000 (60%)] Loss: 0.189265
Train Epoch: 11 [38400/60000 (64%)] Loss: 0.302335
Train Epoch: 11 [40960/60000 (68%)] Loss: 0.132773
Train Epoch: 11 [43520/60000 (72%)] Loss: 0.222291
Train Epoch: 11 [46080/60000 (77%)] Loss: 0.321381
Train Epoch: 11 [48640/60000 (81%)] Loss: 0.150813
Train Epoch: 11 [51200/60000 (85%)] Loss: 0.249176
Train Epoch: 11 [53760/60000 (90%)] Loss: 0.245743
Train Epoch: 11 [56320/60000 (94%)] Loss: 0.212701
Train Epoch: 11 [58880/60000 (98%)] Loss: 0.363840
Test set: Avg. loss: 0.0722, Accuracy: 9776/10000 (98%)
Train Epoch: 12 [00000/60000 (0%)] Loss: 0.189078
Train Epoch: 12 [02560/60000 (4%)] Loss: 0.184256
Train Epoch: 12 [05120/60000 (9%)] Loss: 0.143418
Train Epoch: 12 [07680/60000 (13%)] Loss: 0.161733
Train Epoch: 12 [10240/60000 (17%)] Loss: 0.238975
Train Epoch: 12 [12800/60000 (21%)] Loss: 0.200676
Train Epoch: 12 [15360/60000 (26%)] Loss: 0.163866
Train Epoch: 12 [17920/60000 (30%)] Loss: 0.337974
Train Epoch: 12 [20480/60000 (34%)] Loss: 0.160897
Train Epoch: 12 [23040/60000 (38%)] Loss: 0.156017
Train Epoch: 12 [25600/60000 (43%)] Loss: 0.188498
Train Epoch: 12 [28160/60000 (47%)] Loss: 0.272446
Train Epoch: 12 [30720/60000 (51%)] Loss: 0.124439
Train Epoch: 12 [33280/60000 (55%)] Loss: 0.131949
Train Epoch: 12 [35840/60000 (60%)] Loss: 0.293010
Train Epoch: 12 [38400/60000 (64%)] Loss: 0.187551
Train Epoch: 12 [40960/60000 (68%)] Loss: 0.181151
Train Epoch: 12 [43520/60000 (72%)] Loss: 0.270526
Train Epoch: 12 [46080/60000 (77%)] Loss: 0.131309
Train Epoch: 12 [48640/60000 (81%)] Loss: 0.261624
Train Epoch: 12 [51200/60000 (85%)] Loss: 0.239715
Train Epoch: 12 [53760/60000 (90%)] Loss: 0.163549
Train Epoch: 12 [56320/60000 (94%)] Loss: 0.160421
Train Epoch: 12 [58880/60000 (98%)] Loss: 0.160318
Test set: Avg. loss: 0.0687, Accuracy: 9787/10000 (98%)
Train Epoch: 13 [00000/60000 (0%)] Loss: 0.140241
Train Epoch: 13 [02560/60000 (4%)] Loss: 0.144069
Train Epoch: 13 [05120/60000 (9%)] Loss: 0.323135
Train Epoch: 13 [07680/60000 (13%)] Loss: 0.336287
Train Epoch: 13 [10240/60000 (17%)] Loss: 0.107315
Train Epoch: 13 [12800/60000 (21%)] Loss: 0.169032
Train Epoch: 13 [15360/60000 (26%)] Loss: 0.162337
Train Epoch: 13 [17920/60000 (30%)] Loss: 0.253107
Train Epoch: 13 [20480/60000 (34%)] Loss: 0.166370
Train Epoch: 13 [23040/60000 (38%)] Loss: 0.243374
Train Epoch: 13 [25600/60000 (43%)] Loss: 0.160263
Train Epoch: 13 [28160/60000 (47%)] Loss: 0.187129
Train Epoch: 13 [30720/60000 (51%)] Loss: 0.348670
Train Epoch: 13 [33280/60000 (55%)] Loss: 0.166424
Train Epoch: 13 [35840/60000 (60%)] Loss: 0.184487
Train Epoch: 13 [38400/60000 (64%)] Loss: 0.159097
Train Epoch: 13 [40960/60000 (68%)] Loss: 0.110388
Train Epoch: 13 [43520/60000 (72%)] Loss: 0.114675
Train Epoch: 13 [46080/60000 (77%)] Loss: 0.193499
Train Epoch: 13 [48640/60000 (81%)] Loss: 0.256665
Train Epoch: 13 [51200/60000 (85%)] Loss: 0.204359
Train Epoch: 13 [53760/60000 (90%)] Loss: 0.228794
Train Epoch: 13 [56320/60000 (94%)] Loss: 0.229143
Train Epoch: 13 [58880/60000 (98%)] Loss: 0.198778
Test set: Avg. loss: 0.0659, Accuracy: 9792/10000 (98%)
Train Epoch: 14 [00000/60000 (0%)] Loss: 0.154132
Train Epoch: 14 [02560/60000 (4%)] Loss: 0.174841
Train Epoch: 14 [05120/60000 (9%)] Loss: 0.131765
Train Epoch: 14 [07680/60000 (13%)] Loss: 0.163187
Train Epoch: 14 [10240/60000 (17%)] Loss: 0.130205
Train Epoch: 14 [12800/60000 (21%)] Loss: 0.230511
Train Epoch: 14 [15360/60000 (26%)] Loss: 0.206032
Train Epoch: 14 [17920/60000 (30%)] Loss: 0.209682
Train Epoch: 14 [20480/60000 (34%)] Loss: 0.143732
Train Epoch: 14 [23040/60000 (38%)] Loss: 0.247467
Train Epoch: 14 [25600/60000 (43%)] Loss: 0.141316
Train Epoch: 14 [28160/60000 (47%)] Loss: 0.156982
Train Epoch: 14 [30720/60000 (51%)] Loss: 0.249250
Train Epoch: 14 [33280/60000 (55%)] Loss: 0.252457
Train Epoch: 14 [35840/60000 (60%)] Loss: 0.137284
Train Epoch: 14 [38400/60000 (64%)] Loss: 0.212023
Train Epoch: 14 [40960/60000 (68%)] Loss: 0.227320
Train Epoch: 14 [43520/60000 (72%)] Loss: 0.200754
Train Epoch: 14 [46080/60000 (77%)] Loss: 0.197454
Train Epoch: 14 [48640/60000 (81%)] Loss: 0.200271
Train Epoch: 14 [51200/60000 (85%)] Loss: 0.135254
Train Epoch: 14 [53760/60000 (90%)] Loss: 0.137874
Train Epoch: 14 [56320/60000 (94%)] Loss: 0.213711
Train Epoch: 14 [58880/60000 (98%)] Loss: 0.286362
Test set: Avg. loss: 0.0631, Accuracy: 9794/10000 (98%)
参考文献
用PyTorch实现MNIST手写数字识别(非常详细)