import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
def main():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
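    # ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5 then maps them to [-1, 1]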
    # 50,000 training images
    # set download=True on the first run so the dataset is downloaded automatically
    train_set = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform)
    # within the limits of available memory, a larger batch_size speeds up each epoch,
    # but it does not by itself guarantee a better model
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, num_workers=0)
    # 10,000 validation images
    # set download=True on the first run so the dataset is downloaded automatically
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
                                             shuffle=False, num_workers=0)
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)  # use the built-in next(); the iterator's .next() method no longer exists in recent PyTorch
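    # this pulls one fixed batch of 5,000 validation images/labels (half of the 10,000-image test split),
    # which is reused for the quick accuracy check during training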
    # classes = ('plane', 'car', 'bird', 'cat',
    #            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    net = LeNet()
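    # LeNet (defined in model.py) is expected to take the 3x32x32 CIFAR-10 images and return 10 raw class scores (logits)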
    # per its documentation, CrossEntropyLoss already applies (log-)softmax internally,
    # so the network should output raw logits
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    # train for 5 epochs over the training set
    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0
        # iterate over the training set batch by batch
        for step, data in enumerate(train_loader, start=0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # zero the parameter gradients
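            # (PyTorch accumulates gradients by default, so they must be cleared before each backward pass)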
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)                  # forward pass through the network
            loss = loss_function(outputs, labels)  # loss between the predictions and the ground-truth labels
            loss.backward()                        # backpropagation
            optimizer.step()                       # parameter update with the optimizer
            # print statistics
            running_loss += loss.item()
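            # .item() returns the loss as a plain Python float, so no autograd history is kept around just for logging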
            if step % 500 == 499:    # print every 500 mini-batches
                # this block only prints metrics, so tell autograd not to track gradients here;
                # validation, testing and inference code should also run under torch.no_grad()
                with torch.no_grad():
                    # run the fixed validation batch through the network
                    outputs = net(val_image)  # [batch, 10]
                    # take the class with the highest score as the prediction
                    predict_y = torch.max(outputs, dim=1)[1]
                    # compare with the ground-truth labels to get the accuracy
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)
                    print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    # reset the running loss before continuing with the next batches
                    running_loss = 0.0
    print('Finished Training')
    # save the trained weights
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)
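    # To reuse the saved weights later (illustrative sketch, not part of this training script):
    #   net = LeNet()
    #   net.load_state_dict(torch.load(save_path))
    #   net.eval()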
if __name__ == '__main__':
    main()