PyTorch image classification, using the built-in MNIST dataset.
Loading the dataset

```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import visdom
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(33)  # fix the seed for reproducibility
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
```
device(type='cuda', index=0)
```python
train_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
```
```python
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)
```
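Only the training loader is shuffled: randomizing the sample order decorrelates consecutive gradient updates, while evaluation does not depend on the order of the test set.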
```python
imgs, labels = next(iter(train_dl))
```
torch.Size([64, 1, 28, 28])
torch.Size([64])
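Each batch thus holds 64 single-channel 28×28 images together with 64 integer labels; the two shapes above are what inspecting imgs.shape and labels.shape reports.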
```python
plt.figure(figsize=(batch_size, 1))
for i, img in enumerate(imgs):
    img_np = img.numpy().squeeze()  # (1, 28, 28) -> (28, 28)
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(img_np, cmap="gray")
    plt.axis("off")
```
tensor([3, 0, 5, 2, 3, 4, 5, 9, 1, 7, 4, 7, 8, 4, 2, 1, 7, 9, 8, 3, 4, 9, 7, 5,
0, 2, 4, 2, 5, 7, 6, 4, 2, 8, 8, 5, 6, 0, 6, 4, 9, 5, 9, 9, 9, 4, 9, 8,
8, 6, 9, 3, 2, 2, 2, 5, 0, 4, 9, 3, 0, 8, 3, 2])
Creating the model

```python
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 128)
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, inputs):
        x = inputs.view(-1, 1 * 28 * 28)  # flatten (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        logits = self.linear3(x)  # raw logits; the loss applies softmax itself
        return logits
```
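The same three-layer network can be written more compactly with nn.Sequential and nn.Flatten. This equivalent sketch is my addition for illustration; the results below were produced with the class above:

```python
# Equivalent formulation (illustrative only, not used for the results below)
model_seq = nn.Sequential(
    nn.Flatten(),            # (N, 1, 28, 28) -> (N, 784)
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),       # raw logits
)
```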
```python
model = MLPModel().to(device)
model
```
MLPModel(
(linear1): Linear(in_features=784, out_features=128, bias=True)
(linear2): Linear(in_features=128, out_features=64, bias=True)
(linear3): Linear(in_features=64, out_features=10, bias=True)
)
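As a quick sanity check (my addition, not part of the original post), the printed architecture can be matched against a parameter count:

```python
# 784*128+128 + 128*64+64 + 64*10+10 = 109386 trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 109386
```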
```python
loss_fn = torch.nn.CrossEntropyLoss()
```
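CrossEntropyLoss combines log-softmax with negative log-likelihood, which is why forward returns raw logits without a final softmax. A tiny numerical check (the example values are assumptions of mine, not from the original post):

```python
# One sample, three classes; class 0 is the target.
example_logits = torch.tensor([[2.0, 0.5, -1.0]])
example_target = torch.tensor([0])
manual = -torch.log_softmax(example_logits, dim=1)[0, 0]
print(torch.isclose(manual, loss_fn(example_logits, example_target)))  # tensor(True)
```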
```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer
```
Adam (
Parameter Group 0
amsgrad: False
betas: (0.9, 0.999)
capturable: False
eps: 1e-08
foreach: None
lr: 0.0001
maximize: False
weight_decay: 0
)
```python
def train(dl, model, loss_fn, optimizer):
    size = len(dl.dataset)
    num_batches = len(dl)
    train_loss, correct = 0, 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    correct /= size            # accuracy over all samples
    train_loss /= num_batches  # mean loss per batch
    return correct, train_loss
```
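Note that correct is accumulated batch by batch while the weights are still changing, so the reported training accuracy is a running average over the epoch rather than the accuracy of the end-of-epoch weights; the torch.no_grad() block merely keeps this bookkeeping out of the autograd graph.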
```python
def test(dl, model, loss_fn):
    size = len(dl.dataset)
    num_batches = len(dl)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)

            test_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    test_loss /= num_batches
    return correct, test_loss
```
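One caveat: test() never switches the network out of training mode. That is harmless here because this MLP contains no dropout or batch-norm layers, but with such layers evaluation would be wrapped as below (a general pattern, not something the original code needs):

```python
model.eval()   # disable dropout, use running batch-norm statistics
acc, loss = test(dl=test_dl, model=model, loss_fn=loss_fn)
model.train()  # restore training mode before the next epoch
```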
```python
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "mnist"
opts = dict(
    title="MNIST",
    xlabel="epoch",
    ylabel="loss and acc",
    markers=True,
    legend=["train_loss", "train_acc", "test_loss", "test_acc"],
)
viz.line(
    [[0.0, 0.0, 0.0, 0.0]],
    [0.0],
    win=win,
    opts=opts,
)
```
Setting up a new session...
'mnist'
```python
epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_acc, epoch_loss = train(
        dl=train_dl, model=model, loss_fn=loss_fn, optimizer=optimizer
    )
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    epoch_test_acc, epoch_test_loss = test(dl=test_dl, model=model, loss_fn=loss_fn)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    print(
        f"epoch={epoch:2d}, train_loss={epoch_loss:.5f}, train_acc={epoch_acc:.5f}, "
        f"test_loss={epoch_test_loss:.5f}, test_acc={epoch_test_acc:.5f}"
    )
    # append the four metrics to the visdom window
    viz.line(
        [[epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc]],
        [epoch],
        win=win,
        update="append",
    )
print("done!")
```
epoch= 0, train_loss=0.87553, train_acc=0.78492, test_loss=0.37957, test_acc=0.89600
epoch= 1, train_loss=0.34057, train_acc=0.90542, test_loss=0.29167, test_acc=0.91640
epoch= 2, train_loss=0.28640, train_acc=0.91793, test_loss=0.26019, test_acc=0.92450
epoch= 3, train_loss=0.25685, train_acc=0.92650, test_loss=0.24007, test_acc=0.92930
epoch= 4, train_loss=0.23372, train_acc=0.93333, test_loss=0.21881, test_acc=0.93560
epoch= 5, train_loss=0.21370, train_acc=0.93912, test_loss=0.20176, test_acc=0.93990
epoch= 6, train_loss=0.19669, train_acc=0.94297, test_loss=0.18557, test_acc=0.94430
epoch= 7, train_loss=0.18122, train_acc=0.94835, test_loss=0.17462, test_acc=0.94620
epoch= 8, train_loss=0.16796, train_acc=0.95127, test_loss=0.16615, test_acc=0.94940
epoch= 9, train_loss=0.15605, train_acc=0.95473, test_loss=0.15155, test_acc=0.95460
epoch=10, train_loss=0.14516, train_acc=0.95775, test_loss=0.14506, test_acc=0.95670
epoch=11, train_loss=0.13557, train_acc=0.96103, test_loss=0.13445, test_acc=0.95970
epoch=12, train_loss=0.12738, train_acc=0.96342, test_loss=0.13094, test_acc=0.96010
epoch=13, train_loss=0.11912, train_acc=0.96610, test_loss=0.12227, test_acc=0.96210
epoch=14, train_loss=0.11219, train_acc=0.96753, test_loss=0.11731, test_acc=0.96430
epoch=15, train_loss=0.10571, train_acc=0.96947, test_loss=0.11181, test_acc=0.96500
epoch=16, train_loss=0.09996, train_acc=0.97147, test_loss=0.10745, test_acc=0.96670
epoch=17, train_loss=0.09438, train_acc=0.97308, test_loss=0.10555, test_acc=0.96800
epoch=18, train_loss=0.08965, train_acc=0.97438, test_loss=0.10191, test_acc=0.96840
epoch=19, train_loss=0.08477, train_acc=0.97557, test_loss=0.09853, test_acc=0.96930
epoch=20, train_loss=0.08065, train_acc=0.97690, test_loss=0.09546, test_acc=0.96970
epoch=21, train_loss=0.07642, train_acc=0.97827, test_loss=0.09460, test_acc=0.97060
epoch=22, train_loss=0.07243, train_acc=0.97918, test_loss=0.09040, test_acc=0.97170
epoch=23, train_loss=0.06898, train_acc=0.98013, test_loss=0.08840, test_acc=0.97270
epoch=24, train_loss=0.06559, train_acc=0.98123, test_loss=0.08831, test_acc=0.97240
epoch=25, train_loss=0.06238, train_acc=0.98242, test_loss=0.08451, test_acc=0.97450
epoch=26, train_loss=0.05947, train_acc=0.98308, test_loss=0.08525, test_acc=0.97340
epoch=27, train_loss=0.05665, train_acc=0.98370, test_loss=0.08331, test_acc=0.97420
epoch=28, train_loss=0.05389, train_acc=0.98510, test_loss=0.08325, test_acc=0.97480
epoch=29, train_loss=0.05153, train_acc=0.98535, test_loss=0.08162, test_acc=0.97450
epoch=30, train_loss=0.04908, train_acc=0.98628, test_loss=0.07992, test_acc=0.97540
epoch=31, train_loss=0.04710, train_acc=0.98658, test_loss=0.07741, test_acc=0.97600
epoch=32, train_loss=0.04476, train_acc=0.98773, test_loss=0.07945, test_acc=0.97460
epoch=33, train_loss=0.04273, train_acc=0.98813, test_loss=0.07803, test_acc=0.97500
epoch=34, train_loss=0.04049, train_acc=0.98873, test_loss=0.07625, test_acc=0.97520
epoch=35, train_loss=0.03883, train_acc=0.98968, test_loss=0.07546, test_acc=0.97660
epoch=36, train_loss=0.03686, train_acc=0.99037, test_loss=0.07731, test_acc=0.97510
epoch=37, train_loss=0.03529, train_acc=0.99060, test_loss=0.07601, test_acc=0.97570
epoch=38, train_loss=0.03339, train_acc=0.99118, test_loss=0.07800, test_acc=0.97490
epoch=39, train_loss=0.03212, train_acc=0.99150, test_loss=0.07530, test_acc=0.97650
epoch=40, train_loss=0.03038, train_acc=0.99222, test_loss=0.07336, test_acc=0.97610
epoch=41, train_loss=0.02889, train_acc=0.99262, test_loss=0.07662, test_acc=0.97680
epoch=42, train_loss=0.02742, train_acc=0.99350, test_loss=0.07404, test_acc=0.97700
epoch=43, train_loss=0.02625, train_acc=0.99347, test_loss=0.07493, test_acc=0.97660
epoch=44, train_loss=0.02480, train_acc=0.99420, test_loss=0.07400, test_acc=0.97710
epoch=45, train_loss=0.02360, train_acc=0.99417, test_loss=0.07704, test_acc=0.97700
epoch=46, train_loss=0.02243, train_acc=0.99490, test_loss=0.07595, test_acc=0.97790
epoch=47, train_loss=0.02117, train_acc=0.99525, test_loss=0.07470, test_acc=0.97700
epoch=48, train_loss=0.02004, train_acc=0.99557, test_loss=0.07563, test_acc=0.97650
epoch=49, train_loss=0.01895, train_acc=0.99598, test_loss=0.07576, test_acc=0.97750
done!
Loss and test-accuracy curves:
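The original post displays the live visdom window at this point. Since the per-epoch metrics were also collected in Python lists, the same four curves can be reproduced offline with matplotlib; this sketch is my addition, using only the variables defined above:

```python
plt.figure(figsize=(8, 5))
xs = range(epochs)
plt.plot(xs, train_loss, label="train_loss")
plt.plot(xs, train_acc, label="train_acc")
plt.plot(xs, test_loss, label="test_loss")
plt.plot(xs, test_acc, label="test_acc")
plt.xlabel("epoch")
plt.ylabel("loss and acc")
plt.title("MNIST")
plt.legend()
plt.show()
```

Finally, the trained weights can be persisted with torch.save; the filename below is a placeholder of my choosing, not from the original post:

```python
torch.save(model.state_dict(), "mnist_mlp.pth")  # hypothetical path

# Restore into a fresh instance later:
restored = MLPModel().to(device)
restored.load_state_dict(torch.load("mnist_mlp.pth", map_location=device))
restored.eval()
```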