Saving a PyTorch model and resuming training from a checkpoint, demonstrated on the MNIST dataset.

Import dependencies

import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader

Load the data

n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 20
torch.manual_seed(33)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=1)
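
torch.manual_seed only seeds the CPU generator. If reproducibility on the GPU also matters, a minimal sketch (standard PyTorch settings, not part of the original setup):

torch.cuda.manual_seed_all(33)             # seed all visible GPUs
torch.backends.cudnn.deterministic = True  # repeatable cuDNN kernels, at some speed cost
torch.backends.cudnn.benchmark = False     # disable auto-tuning that picks kernels nondeterministically
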
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets",
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_data,
    batch_size=batch_size_train,
    shuffle=True,
)

test_data = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets",
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    test_data,
    batch_size=batch_size_test,
    shuffle=True,
)
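
The constants 0.1307 and 0.3081 are the global mean and standard deviation of the MNIST training pixels. They can be recomputed with a short sketch (same dataset path as above; it assumes holding all 60,000 images in memory at once is acceptable):

raw = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Stack all images into one (60000, 1, 28, 28) tensor and reduce globally.
pixels = torch.stack([img for img, _ in raw])
print(pixels.mean().item(), pixels.std().item())  # ≈ 0.1307, ≈ 0.3081
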
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
# print(example_targets)
print(example_data.shape)
torch.Size([1000, 1, 28, 28])
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap="gray", interpolation="none")
    plt.title(f"Ground Truth: {example_targets[i]}")
    plt.xticks([])
    plt.yticks([])
plt.show()

(Figure: six sample test images with their ground-truth labels.)

Build the model and optimizer

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten: 20 channels * 4 * 4 spatial after two conv/pool stages
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.log_softmax(self.fc2(x), dim=1)
        return x
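
The 320 passed to fc1 follows from the shape arithmetic: 28 → 24 after the first 5×5 conv, → 12 after pooling, → 8 after the second conv, → 4 after pooling, and 20 channels give 20 × 4 × 4 = 320. A quick sanity check with a dummy input (not in the original):

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)  # one fake MNIST image
    print(Net()(dummy).shape)          # torch.Size([1, 10]), log-probabilities for the 10 digits
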
network = Net().to(device)
optimizer = optim.SGD(
    network.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

Train and save the model

train_losses = []
train_counter = []
test_losses = []
# test() runs once after each epoch, so its x-coordinates are epoch 1..n_epochs.
test_counter = [i * len(train_loader.dataset) for i in range(1, n_epochs + 1)]
def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} [{str(batch_idx * len(data)).zfill(5)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset))
            )

    # Save this epoch's model and optimizer checkpoints.
    model_checkpoint = f"/workspace/disk1/datasets/models/mnist/model_epoch{epoch}.pth"
    optimizer_checkpoint = (
        f"/workspace/disk1/datasets/models/mnist/optimizer_epoch{epoch}.pth"
    )
    latest_model = "/workspace/disk1/datasets/models/mnist/model_latest.pth"
    latest_optimizer = "/workspace/disk1/datasets/models/mnist/optimizer_latest.pth"

    os.makedirs(os.path.dirname(model_checkpoint), exist_ok=True)
    torch.save(network.state_dict(), model_checkpoint)
    torch.save(optimizer.state_dict(), optimizer_checkpoint)

    # Point the "latest" symlinks at this epoch's checkpoints.
    if os.path.exists(latest_model):
        os.remove(latest_model)
    if os.path.exists(latest_optimizer):
        os.remove(latest_optimizer)
    os.symlink(model_checkpoint, latest_model)
    os.symlink(optimizer_checkpoint, latest_optimizer)
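
Saving the model and the optimizer as two separate files works, but a common alternative is to bundle everything into one checkpoint dict so a resume restores model, optimizer, and epoch together. A sketch under that convention (the ckpt_latest.pth path and the helper names are illustrative, not from the code above):

def save_checkpoint(epoch, path="/workspace/disk1/datasets/models/mnist/ckpt_latest.pth"):
    # One file holding everything needed to resume.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": network.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )

def load_checkpoint(path="/workspace/disk1/datasets/models/mnist/ckpt_latest.pth"):
    ckpt = torch.load(path, map_location=device)
    network.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"]  # resume from the following epoch
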
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = network(data)
            # size_average=False is deprecated; reduction="sum" accumulates the total loss.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print(
        f"\nTest set: Avg. loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n"
    )
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
Train Epoch: 1 [00000/60000 (0%)]	Loss: 2.288714
Train Epoch: 1 [02560/60000 (4%)]	Loss: 2.255195
Train Epoch: 1 [05120/60000 (9%)]	Loss: 2.198278
Train Epoch: 1 [07680/60000 (13%)]	Loss: 2.091046
Train Epoch: 1 [10240/60000 (17%)]	Loss: 1.832017
Train Epoch: 1 [12800/60000 (21%)]	Loss: 1.800682
Train Epoch: 1 [15360/60000 (26%)]	Loss: 1.533700
Train Epoch: 1 [17920/60000 (30%)]	Loss: 1.332488
Train Epoch: 1 [20480/60000 (34%)]	Loss: 1.289864
Train Epoch: 1 [23040/60000 (38%)]	Loss: 1.195423
Train Epoch: 1 [25600/60000 (43%)]	Loss: 1.075095
Train Epoch: 1 [28160/60000 (47%)]	Loss: 1.004397
Train Epoch: 1 [30720/60000 (51%)]	Loss: 0.832168
Train Epoch: 1 [33280/60000 (55%)]	Loss: 0.849826
Train Epoch: 1 [35840/60000 (60%)]	Loss: 0.968393
Train Epoch: 1 [38400/60000 (64%)]	Loss: 0.866699
Train Epoch: 1 [40960/60000 (68%)]	Loss: 0.721661
Train Epoch: 1 [43520/60000 (72%)]	Loss: 0.962982
Train Epoch: 1 [46080/60000 (77%)]	Loss: 0.695050
Train Epoch: 1 [48640/60000 (81%)]	Loss: 0.799360
Train Epoch: 1 [51200/60000 (85%)]	Loss: 0.683803
Train Epoch: 1 [53760/60000 (90%)]	Loss: 0.561625
Train Epoch: 1 [56320/60000 (94%)]	Loss: 0.579419
Train Epoch: 1 [58880/60000 (98%)]	Loss: 0.568427

Test set: Avg. loss: 0.3283, Accuracy: 9096/10000 (91%)

Train Epoch: 2 [00000/60000 (0%)]	Loss: 0.781268
Train Epoch: 2 [02560/60000 (4%)]	Loss: 0.679616
Train Epoch: 2 [05120/60000 (9%)]	Loss: 0.572369
Train Epoch: 2 [07680/60000 (13%)]	Loss: 0.638919
Train Epoch: 2 [10240/60000 (17%)]	Loss: 0.721595
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.723331
Train Epoch: 2 [15360/60000 (26%)]	Loss: 0.604582
Train Epoch: 2 [17920/60000 (30%)]	Loss: 0.689362
Train Epoch: 2 [20480/60000 (34%)]	Loss: 0.548551
Train Epoch: 2 [23040/60000 (38%)]	Loss: 0.650297
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.506893
Train Epoch: 2 [28160/60000 (47%)]	Loss: 0.581708
Train Epoch: 2 [30720/60000 (51%)]	Loss: 0.557465
Train Epoch: 2 [33280/60000 (55%)]	Loss: 0.461637
Train Epoch: 2 [35840/60000 (60%)]	Loss: 0.619341
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.464800
Train Epoch: 2 [40960/60000 (68%)]	Loss: 0.473921
Train Epoch: 2 [43520/60000 (72%)]	Loss: 0.567576
Train Epoch: 2 [46080/60000 (77%)]	Loss: 0.447070
Train Epoch: 2 [48640/60000 (81%)]	Loss: 0.503476
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.500809
Train Epoch: 2 [53760/60000 (90%)]	Loss: 0.553329
Train Epoch: 2 [56320/60000 (94%)]	Loss: 0.504529
Train Epoch: 2 [58880/60000 (98%)]	Loss: 0.457889

Test set: Avg. loss: 0.2128, Accuracy: 9357/10000 (94%)

Train Epoch: 3 [00000/60000 (0%)]	Loss: 0.369795
Train Epoch: 3 [02560/60000 (4%)]	Loss: 0.414531
Train Epoch: 3 [05120/60000 (9%)]	Loss: 0.604378
Train Epoch: 3 [07680/60000 (13%)]	Loss: 0.426111
Train Epoch: 3 [10240/60000 (17%)]	Loss: 0.492895
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.393350
Train Epoch: 3 [15360/60000 (26%)]	Loss: 0.555914
Train Epoch: 3 [17920/60000 (30%)]	Loss: 0.476940
Train Epoch: 3 [20480/60000 (34%)]	Loss: 0.430539
Train Epoch: 3 [23040/60000 (38%)]	Loss: 0.571562
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.488061
Train Epoch: 3 [28160/60000 (47%)]	Loss: 0.598932
Train Epoch: 3 [30720/60000 (51%)]	Loss: 0.417797
Train Epoch: 3 [33280/60000 (55%)]	Loss: 0.486182
Train Epoch: 3 [35840/60000 (60%)]	Loss: 0.375228
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.384777
Train Epoch: 3 [40960/60000 (68%)]	Loss: 0.428213
Train Epoch: 3 [43520/60000 (72%)]	Loss: 0.443456
Train Epoch: 3 [46080/60000 (77%)]	Loss: 0.308731
Train Epoch: 3 [48640/60000 (81%)]	Loss: 0.347535
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.439240
Train Epoch: 3 [53760/60000 (90%)]	Loss: 0.460515
Train Epoch: 3 [56320/60000 (94%)]	Loss: 0.521822
Train Epoch: 3 [58880/60000 (98%)]	Loss: 0.451665

Test set: Avg. loss: 0.1591, Accuracy: 9513/10000 (95%)

Train Epoch: 4 [00000/60000 (0%)]	Loss: 0.510825
Train Epoch: 4 [02560/60000 (4%)]	Loss: 0.279729
Train Epoch: 4 [05120/60000 (9%)]	Loss: 0.333447
Train Epoch: 4 [07680/60000 (13%)]	Loss: 0.434208
Train Epoch: 4 [10240/60000 (17%)]	Loss: 0.516646
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.339009
Train Epoch: 4 [15360/60000 (26%)]	Loss: 0.342047
Train Epoch: 4 [17920/60000 (30%)]	Loss: 0.315687
Train Epoch: 4 [20480/60000 (34%)]	Loss: 0.365422
Train Epoch: 4 [23040/60000 (38%)]	Loss: 0.408456
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.512156
Train Epoch: 4 [28160/60000 (47%)]	Loss: 0.246768
Train Epoch: 4 [30720/60000 (51%)]	Loss: 0.254370
Train Epoch: 4 [33280/60000 (55%)]	Loss: 0.403202
Train Epoch: 4 [35840/60000 (60%)]	Loss: 0.364687
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.313392
Train Epoch: 4 [40960/60000 (68%)]	Loss: 0.256359
Train Epoch: 4 [43520/60000 (72%)]	Loss: 0.306669
Train Epoch: 4 [46080/60000 (77%)]	Loss: 0.459862
Train Epoch: 4 [48640/60000 (81%)]	Loss: 0.227380
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.368363
Train Epoch: 4 [53760/60000 (90%)]	Loss: 0.329823
Train Epoch: 4 [56320/60000 (94%)]	Loss: 0.304764
Train Epoch: 4 [58880/60000 (98%)]	Loss: 0.288302

Test set: Avg. loss: 0.1322, Accuracy: 9596/10000 (96%)

Train Epoch: 5 [00000/60000 (0%)]	Loss: 0.247250
Train Epoch: 5 [02560/60000 (4%)]	Loss: 0.383674
Train Epoch: 5 [05120/60000 (9%)]	Loss: 0.315615
Train Epoch: 5 [07680/60000 (13%)]	Loss: 0.259549
Train Epoch: 5 [10240/60000 (17%)]	Loss: 0.322165
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.474284
Train Epoch: 5 [15360/60000 (26%)]	Loss: 0.254566
Train Epoch: 5 [17920/60000 (30%)]	Loss: 0.466621
Train Epoch: 5 [20480/60000 (34%)]	Loss: 0.320438
Train Epoch: 5 [23040/60000 (38%)]	Loss: 0.316798
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.192171
Train Epoch: 5 [28160/60000 (47%)]	Loss: 0.478387
Train Epoch: 5 [30720/60000 (51%)]	Loss: 0.346078
Train Epoch: 5 [33280/60000 (55%)]	Loss: 0.315769
Train Epoch: 5 [35840/60000 (60%)]	Loss: 0.284116
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.354763
Train Epoch: 5 [40960/60000 (68%)]	Loss: 0.369070
Train Epoch: 5 [43520/60000 (72%)]	Loss: 0.233857
Train Epoch: 5 [46080/60000 (77%)]	Loss: 0.222160
Train Epoch: 5 [48640/60000 (81%)]	Loss: 0.325135
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.226506
Train Epoch: 5 [53760/60000 (90%)]	Loss: 0.407618
Train Epoch: 5 [56320/60000 (94%)]	Loss: 0.359771
Train Epoch: 5 [58880/60000 (98%)]	Loss: 0.290950

Test set: Avg. loss: 0.1152, Accuracy: 9643/10000 (96%)

Train Epoch: 6 [00000/60000 (0%)]	Loss: 0.371519
Train Epoch: 6 [02560/60000 (4%)]	Loss: 0.294658
Train Epoch: 6 [05120/60000 (9%)]	Loss: 0.195041
Train Epoch: 6 [07680/60000 (13%)]	Loss: 0.247292
Train Epoch: 6 [10240/60000 (17%)]	Loss: 0.307761
Train Epoch: 6 [12800/60000 (21%)]	Loss: 0.183960
Train Epoch: 6 [15360/60000 (26%)]	Loss: 0.255224
Train Epoch: 6 [17920/60000 (30%)]	Loss: 0.564046
Train Epoch: 6 [20480/60000 (34%)]	Loss: 0.217146
Train Epoch: 6 [23040/60000 (38%)]	Loss: 0.364980
Train Epoch: 6 [25600/60000 (43%)]	Loss: 0.237876
Train Epoch: 6 [28160/60000 (47%)]	Loss: 0.344803
Train Epoch: 6 [30720/60000 (51%)]	Loss: 0.347686
Train Epoch: 6 [33280/60000 (55%)]	Loss: 0.197488
Train Epoch: 6 [35840/60000 (60%)]	Loss: 0.346718
Train Epoch: 6 [38400/60000 (64%)]	Loss: 0.256105
Train Epoch: 6 [40960/60000 (68%)]	Loss: 0.211900
Train Epoch: 6 [43520/60000 (72%)]	Loss: 0.264353
Train Epoch: 6 [46080/60000 (77%)]	Loss: 0.339571
Train Epoch: 6 [48640/60000 (81%)]	Loss: 0.198715
Train Epoch: 6 [51200/60000 (85%)]	Loss: 0.335813
Train Epoch: 6 [53760/60000 (90%)]	Loss: 0.244630
Train Epoch: 6 [56320/60000 (94%)]	Loss: 0.260668
Train Epoch: 6 [58880/60000 (98%)]	Loss: 0.281284

Test set: Avg. loss: 0.1029, Accuracy: 9682/10000 (97%)

Train Epoch: 7 [00000/60000 (0%)]	Loss: 0.259941
Train Epoch: 7 [02560/60000 (4%)]	Loss: 0.353288
Train Epoch: 7 [05120/60000 (9%)]	Loss: 0.345746
Train Epoch: 7 [07680/60000 (13%)]	Loss: 0.263094
Train Epoch: 7 [10240/60000 (17%)]	Loss: 0.370562
Train Epoch: 7 [12800/60000 (21%)]	Loss: 0.184917
Train Epoch: 7 [15360/60000 (26%)]	Loss: 0.358648
Train Epoch: 7 [17920/60000 (30%)]	Loss: 0.358313
Train Epoch: 7 [20480/60000 (34%)]	Loss: 0.455060
Train Epoch: 7 [23040/60000 (38%)]	Loss: 0.157829
Train Epoch: 7 [25600/60000 (43%)]	Loss: 0.255777
Train Epoch: 7 [28160/60000 (47%)]	Loss: 0.296378
Train Epoch: 7 [30720/60000 (51%)]	Loss: 0.220109
Train Epoch: 7 [33280/60000 (55%)]	Loss: 0.207805
Train Epoch: 7 [35840/60000 (60%)]	Loss: 0.333757
Train Epoch: 7 [38400/60000 (64%)]	Loss: 0.351853
Train Epoch: 7 [40960/60000 (68%)]	Loss: 0.225360
Train Epoch: 7 [43520/60000 (72%)]	Loss: 0.220420
Train Epoch: 7 [46080/60000 (77%)]	Loss: 0.281292
Train Epoch: 7 [48640/60000 (81%)]	Loss: 0.224555
Train Epoch: 7 [51200/60000 (85%)]	Loss: 0.300659
Train Epoch: 7 [53760/60000 (90%)]	Loss: 0.155560
Train Epoch: 7 [56320/60000 (94%)]	Loss: 0.322972
Train Epoch: 7 [58880/60000 (98%)]	Loss: 0.189427

Test set: Avg. loss: 0.0973, Accuracy: 9693/10000 (97%)

Train Epoch: 8 [00000/60000 (0%)]	Loss: 0.237282
Train Epoch: 8 [02560/60000 (4%)]	Loss: 0.205023
Train Epoch: 8 [05120/60000 (9%)]	Loss: 0.178199
Train Epoch: 8 [07680/60000 (13%)]	Loss: 0.335434
Train Epoch: 8 [10240/60000 (17%)]	Loss: 0.274767
Train Epoch: 8 [12800/60000 (21%)]	Loss: 0.170677
Train Epoch: 8 [15360/60000 (26%)]	Loss: 0.401550
Train Epoch: 8 [17920/60000 (30%)]	Loss: 0.360215
Train Epoch: 8 [20480/60000 (34%)]	Loss: 0.329762
Train Epoch: 8 [23040/60000 (38%)]	Loss: 0.311394
Train Epoch: 8 [25600/60000 (43%)]	Loss: 0.187782
Train Epoch: 8 [28160/60000 (47%)]	Loss: 0.242610
Train Epoch: 8 [30720/60000 (51%)]	Loss: 0.327457
Train Epoch: 8 [33280/60000 (55%)]	Loss: 0.249692
Train Epoch: 8 [35840/60000 (60%)]	Loss: 0.325073
Train Epoch: 8 [38400/60000 (64%)]	Loss: 0.210571
Train Epoch: 8 [40960/60000 (68%)]	Loss: 0.244188
Train Epoch: 8 [43520/60000 (72%)]	Loss: 0.258391
Train Epoch: 8 [46080/60000 (77%)]	Loss: 0.228774
Train Epoch: 8 [48640/60000 (81%)]	Loss: 0.253648
Train Epoch: 8 [51200/60000 (85%)]	Loss: 0.387448
Train Epoch: 8 [53760/60000 (90%)]	Loss: 0.223092
Train Epoch: 8 [56320/60000 (94%)]	Loss: 0.155341
Train Epoch: 8 [58880/60000 (98%)]	Loss: 0.233621

Test set: Avg. loss: 0.0875, Accuracy: 9730/10000 (97%)

Train Epoch: 9 [00000/60000 (0%)]	Loss: 0.356232
Train Epoch: 9 [02560/60000 (4%)]	Loss: 0.228741
Train Epoch: 9 [05120/60000 (9%)]	Loss: 0.237508
Train Epoch: 9 [07680/60000 (13%)]	Loss: 0.200507
Train Epoch: 9 [10240/60000 (17%)]	Loss: 0.180248
Train Epoch: 9 [12800/60000 (21%)]	Loss: 0.316860
Train Epoch: 9 [15360/60000 (26%)]	Loss: 0.251897
Train Epoch: 9 [17920/60000 (30%)]	Loss: 0.341667
Train Epoch: 9 [20480/60000 (34%)]	Loss: 0.256542
Train Epoch: 9 [23040/60000 (38%)]	Loss: 0.272902
Train Epoch: 9 [25600/60000 (43%)]	Loss: 0.207665
Train Epoch: 9 [28160/60000 (47%)]	Loss: 0.134075
Train Epoch: 9 [30720/60000 (51%)]	Loss: 0.198930
Train Epoch: 9 [33280/60000 (55%)]	Loss: 0.218460
Train Epoch: 9 [35840/60000 (60%)]	Loss: 0.372338
Train Epoch: 9 [38400/60000 (64%)]	Loss: 0.207677
Train Epoch: 9 [40960/60000 (68%)]	Loss: 0.274011
Train Epoch: 9 [43520/60000 (72%)]	Loss: 0.158897
Train Epoch: 9 [46080/60000 (77%)]	Loss: 0.236698
Train Epoch: 9 [48640/60000 (81%)]	Loss: 0.238054
Train Epoch: 9 [51200/60000 (85%)]	Loss: 0.264504
Train Epoch: 9 [53760/60000 (90%)]	Loss: 0.246547
Train Epoch: 9 [56320/60000 (94%)]	Loss: 0.181960
Train Epoch: 9 [58880/60000 (98%)]	Loss: 0.182172

Test set: Avg. loss: 0.0810, Accuracy: 9751/10000 (98%)

Train Epoch: 10 [00000/60000 (0%)]	Loss: 0.223450
Train Epoch: 10 [02560/60000 (4%)]	Loss: 0.146772
Train Epoch: 10 [05120/60000 (9%)]	Loss: 0.448847
Train Epoch: 10 [07680/60000 (13%)]	Loss: 0.192779
Train Epoch: 10 [10240/60000 (17%)]	Loss: 0.180810
Train Epoch: 10 [12800/60000 (21%)]	Loss: 0.256196
Train Epoch: 10 [15360/60000 (26%)]	Loss: 0.248770
Train Epoch: 10 [17920/60000 (30%)]	Loss: 0.243000
Train Epoch: 10 [20480/60000 (34%)]	Loss: 0.274350
Train Epoch: 10 [23040/60000 (38%)]	Loss: 0.256088
Train Epoch: 10 [25600/60000 (43%)]	Loss: 0.326276
Train Epoch: 10 [28160/60000 (47%)]	Loss: 0.192222
Train Epoch: 10 [30720/60000 (51%)]	Loss: 0.163986
Train Epoch: 10 [33280/60000 (55%)]	Loss: 0.261450
Train Epoch: 10 [35840/60000 (60%)]	Loss: 0.263229
Train Epoch: 10 [38400/60000 (64%)]	Loss: 0.278172
Train Epoch: 10 [40960/60000 (68%)]	Loss: 0.245064
Train Epoch: 10 [43520/60000 (72%)]	Loss: 0.298991
Train Epoch: 10 [46080/60000 (77%)]	Loss: 0.334879
Train Epoch: 10 [48640/60000 (81%)]	Loss: 0.306450
Train Epoch: 10 [51200/60000 (85%)]	Loss: 0.208277
Train Epoch: 10 [53760/60000 (90%)]	Loss: 0.211731
Train Epoch: 10 [56320/60000 (94%)]	Loss: 0.186170
Train Epoch: 10 [58880/60000 (98%)]	Loss: 0.219588

Test set: Avg. loss: 0.0760, Accuracy: 9754/10000 (98%)
fig = plt.figure()
plt.plot(train_counter, train_losses, color="blue")
plt.scatter(test_counter, test_losses, color="red")
plt.legend(["Train Loss", "Test Loss"], loc="upper right")
plt.xlabel("number of training examples seen")
plt.ylabel("negative log likelihood loss")
plt.show()

(Figure: training loss curve in blue with per-epoch test losses in red; negative log likelihood loss vs. number of training examples seen.)

Test the model

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
network.eval()  # ensure dropout is disabled for inference
with torch.no_grad():
    example_data = example_data.to(device)
    output = network(example_data)
fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(
        example_data.detach().to("cpu").numpy()[i][0], cmap="gray", interpolation="none"
    )
    plt.title(f"prediction: {output.data.max(1, keepdim=True)[1][i].item()}")
    plt.xticks([])
    plt.yticks([])
plt.show()

(Figure: six test images with the model's predicted labels.)

Resume training

continued_network = Net().to(device)
continued_optimizer = optim.SGD(
    continued_network.parameters(),  # optimize the restored model's parameters
    lr=learning_rate,
    momentum=momentum,
)
latest_model = "/workspace/disk1/datasets/models/mnist/model_latest.pth"
latest_optimizer = "/workspace/disk1/datasets/models/mnist/optimizer_latest.pth"

network_state_dict = torch.load(latest_model, map_location=device)
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(latest_optimizer, map_location=device)
continued_optimizer.load_state_dict(optimizer_state_dict)

# train() and test() refer to the module-level names, so rebind them
# to the restored model and optimizer before continuing.
network = continued_network
optimizer = continued_optimizer
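
Optionally, the restored weights can be verified before resuming: accuracy should match the last test run above (about 98%). A minimal check (not part of the original flow; it avoids calling test() so that test_losses stays aligned with test_counter):

network.eval()  # disable dropout for the check
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        pred = network(data.to(device)).argmax(dim=1)
        correct += (pred == target.to(device)).sum().item()
print(f"Restored model accuracy: {correct}/{len(test_loader.dataset)}")
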
m = 5
# range(n_epochs + 1, n_epochs + m) runs m - 1 = 4 additional epochs (11 through 14).
for i in range(n_epochs + 1, n_epochs + m):
    test_counter.append(i * len(train_loader.dataset))
    train(i)
    test()
Train Epoch: 11 [00000/60000 (0%)]	Loss: 0.208258
Train Epoch: 11 [02560/60000 (4%)]	Loss: 0.274365
Train Epoch: 11 [05120/60000 (9%)]	Loss: 0.438731
Train Epoch: 11 [07680/60000 (13%)]	Loss: 0.192449
Train Epoch: 11 [10240/60000 (17%)]	Loss: 0.333653
Train Epoch: 11 [12800/60000 (21%)]	Loss: 0.243886
Train Epoch: 11 [15360/60000 (26%)]	Loss: 0.403167
Train Epoch: 11 [17920/60000 (30%)]	Loss: 0.287619
Train Epoch: 11 [20480/60000 (34%)]	Loss: 0.169019
Train Epoch: 11 [23040/60000 (38%)]	Loss: 0.212516
Train Epoch: 11 [25600/60000 (43%)]	Loss: 0.202806
Train Epoch: 11 [28160/60000 (47%)]	Loss: 0.217581
Train Epoch: 11 [30720/60000 (51%)]	Loss: 0.230557
Train Epoch: 11 [33280/60000 (55%)]	Loss: 0.378703
Train Epoch: 11 [35840/60000 (60%)]	Loss: 0.189265
Train Epoch: 11 [38400/60000 (64%)]	Loss: 0.302335
Train Epoch: 11 [40960/60000 (68%)]	Loss: 0.132773
Train Epoch: 11 [43520/60000 (72%)]	Loss: 0.222291
Train Epoch: 11 [46080/60000 (77%)]	Loss: 0.321381
Train Epoch: 11 [48640/60000 (81%)]	Loss: 0.150813
Train Epoch: 11 [51200/60000 (85%)]	Loss: 0.249176
Train Epoch: 11 [53760/60000 (90%)]	Loss: 0.245743
Train Epoch: 11 [56320/60000 (94%)]	Loss: 0.212701
Train Epoch: 11 [58880/60000 (98%)]	Loss: 0.363840

Test set: Avg. loss: 0.0722, Accuracy: 9776/10000 (98%)

Train Epoch: 12 [00000/60000 (0%)]	Loss: 0.189078
Train Epoch: 12 [02560/60000 (4%)]	Loss: 0.184256
Train Epoch: 12 [05120/60000 (9%)]	Loss: 0.143418
Train Epoch: 12 [07680/60000 (13%)]	Loss: 0.161733
Train Epoch: 12 [10240/60000 (17%)]	Loss: 0.238975
Train Epoch: 12 [12800/60000 (21%)]	Loss: 0.200676
Train Epoch: 12 [15360/60000 (26%)]	Loss: 0.163866
Train Epoch: 12 [17920/60000 (30%)]	Loss: 0.337974
Train Epoch: 12 [20480/60000 (34%)]	Loss: 0.160897
Train Epoch: 12 [23040/60000 (38%)]	Loss: 0.156017
Train Epoch: 12 [25600/60000 (43%)]	Loss: 0.188498
Train Epoch: 12 [28160/60000 (47%)]	Loss: 0.272446
Train Epoch: 12 [30720/60000 (51%)]	Loss: 0.124439
Train Epoch: 12 [33280/60000 (55%)]	Loss: 0.131949
Train Epoch: 12 [35840/60000 (60%)]	Loss: 0.293010
Train Epoch: 12 [38400/60000 (64%)]	Loss: 0.187551
Train Epoch: 12 [40960/60000 (68%)]	Loss: 0.181151
Train Epoch: 12 [43520/60000 (72%)]	Loss: 0.270526
Train Epoch: 12 [46080/60000 (77%)]	Loss: 0.131309
Train Epoch: 12 [48640/60000 (81%)]	Loss: 0.261624
Train Epoch: 12 [51200/60000 (85%)]	Loss: 0.239715
Train Epoch: 12 [53760/60000 (90%)]	Loss: 0.163549
Train Epoch: 12 [56320/60000 (94%)]	Loss: 0.160421
Train Epoch: 12 [58880/60000 (98%)]	Loss: 0.160318

Test set: Avg. loss: 0.0687, Accuracy: 9787/10000 (98%)

Train Epoch: 13 [00000/60000 (0%)]	Loss: 0.140241
Train Epoch: 13 [02560/60000 (4%)]	Loss: 0.144069
Train Epoch: 13 [05120/60000 (9%)]	Loss: 0.323135
Train Epoch: 13 [07680/60000 (13%)]	Loss: 0.336287
Train Epoch: 13 [10240/60000 (17%)]	Loss: 0.107315
Train Epoch: 13 [12800/60000 (21%)]	Loss: 0.169032
Train Epoch: 13 [15360/60000 (26%)]	Loss: 0.162337
Train Epoch: 13 [17920/60000 (30%)]	Loss: 0.253107
Train Epoch: 13 [20480/60000 (34%)]	Loss: 0.166370
Train Epoch: 13 [23040/60000 (38%)]	Loss: 0.243374
Train Epoch: 13 [25600/60000 (43%)]	Loss: 0.160263
Train Epoch: 13 [28160/60000 (47%)]	Loss: 0.187129
Train Epoch: 13 [30720/60000 (51%)]	Loss: 0.348670
Train Epoch: 13 [33280/60000 (55%)]	Loss: 0.166424
Train Epoch: 13 [35840/60000 (60%)]	Loss: 0.184487
Train Epoch: 13 [38400/60000 (64%)]	Loss: 0.159097
Train Epoch: 13 [40960/60000 (68%)]	Loss: 0.110388
Train Epoch: 13 [43520/60000 (72%)]	Loss: 0.114675
Train Epoch: 13 [46080/60000 (77%)]	Loss: 0.193499
Train Epoch: 13 [48640/60000 (81%)]	Loss: 0.256665
Train Epoch: 13 [51200/60000 (85%)]	Loss: 0.204359
Train Epoch: 13 [53760/60000 (90%)]	Loss: 0.228794
Train Epoch: 13 [56320/60000 (94%)]	Loss: 0.229143
Train Epoch: 13 [58880/60000 (98%)]	Loss: 0.198778

Test set: Avg. loss: 0.0659, Accuracy: 9792/10000 (98%)

Train Epoch: 14 [00000/60000 (0%)]	Loss: 0.154132
Train Epoch: 14 [02560/60000 (4%)]	Loss: 0.174841
Train Epoch: 14 [05120/60000 (9%)]	Loss: 0.131765
Train Epoch: 14 [07680/60000 (13%)]	Loss: 0.163187
Train Epoch: 14 [10240/60000 (17%)]	Loss: 0.130205
Train Epoch: 14 [12800/60000 (21%)]	Loss: 0.230511
Train Epoch: 14 [15360/60000 (26%)]	Loss: 0.206032
Train Epoch: 14 [17920/60000 (30%)]	Loss: 0.209682
Train Epoch: 14 [20480/60000 (34%)]	Loss: 0.143732
Train Epoch: 14 [23040/60000 (38%)]	Loss: 0.247467
Train Epoch: 14 [25600/60000 (43%)]	Loss: 0.141316
Train Epoch: 14 [28160/60000 (47%)]	Loss: 0.156982
Train Epoch: 14 [30720/60000 (51%)]	Loss: 0.249250
Train Epoch: 14 [33280/60000 (55%)]	Loss: 0.252457
Train Epoch: 14 [35840/60000 (60%)]	Loss: 0.137284
Train Epoch: 14 [38400/60000 (64%)]	Loss: 0.212023
Train Epoch: 14 [40960/60000 (68%)]	Loss: 0.227320
Train Epoch: 14 [43520/60000 (72%)]	Loss: 0.200754
Train Epoch: 14 [46080/60000 (77%)]	Loss: 0.197454
Train Epoch: 14 [48640/60000 (81%)]	Loss: 0.200271
Train Epoch: 14 [51200/60000 (85%)]	Loss: 0.135254
Train Epoch: 14 [53760/60000 (90%)]	Loss: 0.137874
Train Epoch: 14 [56320/60000 (94%)]	Loss: 0.213711
Train Epoch: 14 [58880/60000 (98%)]	Loss: 0.286362

Test set: Avg. loss: 0.0631, Accuracy: 9794/10000 (98%)

References

  1. 用PyTorch实现MNIST手写数字识别(非常详细) (MNIST handwritten digit recognition with PyTorch, in detail)