# LeNet

## 模型

```python
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN for 3-channel 32x32 images (e.g. CIFAR-10).

    Tensor shapes through the network, per sample (channels, height, width):

        input               (3, 32, 32)
        conv1 5x5 + ReLU    (16, 28, 28)   # 32 - 5 + 1 = 28
        pool1 2x2           (16, 14, 14)
        conv2 5x5 + ReLU    (32, 10, 10)   # 14 - 5 + 1 = 10
        pool2 2x2           (32, 5, 5)
        flatten             (32 * 5 * 5,)
        fc1 + ReLU          (120,)
        fc2 + ReLU          (84,)
        fc3                 (num_classes,)

    Args:
        num_classes: Number of output logits (one per class). Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # First convolution: RGB input has 3 channels (a grayscale image
        # would have 1). 16 filters produce 16 output channels, which is the
        # in_channels of the next convolution.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,  # number of filters
            kernel_size=5,    # 5x5 filters
            stride=1,         # filter step size
            padding=0,        # no border padding (padding pixels would be 0)
        )
        # 2x2 max pooling halves height and width, channels are unchanged.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Second convolution: 16 -> 32 channels with 5x5 filters.
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Fully connected layers work on the flattened feature map, so the
        # first one takes channels * height * width = 32 * 5 * 5 inputs.
        # nn.Linear names its sizes in_features / out_features
        # (not in_channels / out_channels like nn.Conv2d).
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(120, 84)
        # Final layer: one output per class.
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Run the forward pass.

        Invoked by ``nn.Module.__call__`` when the model is called as
        ``net(x)``.

        Args:
            x: Tensor of shape [batch, 3, 32, 32].

        Returns:
            Tensor of shape [batch, num_classes] holding raw logits
            (no softmax applied).
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 32*5*5), this never silently changes the batch size when
        # the input resolution is wrong; fc1 raises a shape error instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```

## 训练

```python
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms


def main():
    """Train LeNet on CIFAR-10 for 5 epochs and save the weights.

    Every 500 mini-batches the average training loss and the accuracy on a
    fixed batch of 5000 validation images are printed. The trained
    ``state_dict`` is written to ``./Lenet.pth``.
    """
    # Preprocessing: convert to a tensor, then normalize each of the three
    # channels with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Training split (50000 images).
    train_set = torchvision.datasets.CIFAR10(
        root='./data',       # where the dataset is stored
        train=True,          # use the training split
        download=False,      # set to True on first use to download the data
        transform=transform,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        # Mini-batch size. Larger batches give smoother gradient estimates
        # and better throughput but need more memory; bigger is not
        # automatically better for the final result.
        batch_size=36,
        shuffle=True,        # reshuffle the samples every epoch
        num_workers=0,
    )

    # Test split (10000 images), used here for validation.
    val_set = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,         # use the test split
        download=False,
        transform=transform,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=5000,
        shuffle=False,
        num_workers=0,
    )
    # Take only the first batch: accuracy below is measured on the first
    # 5000 validation images. Use the built-in next(); DataLoader iterators
    # no longer provide a .next() method.
    val_image, val_label = next(iter(val_loader))

    # classes = ('plane', 'car', 'bird', 'cat',
    #            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    # CrossEntropyLoss already applies log-softmax internally, so the
    # network outputs raw logits.
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Iterate over the training set 5 times.
    for epoch in range(5):
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(train_loader):
            # Clear gradients accumulated by the previous step.
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 == 499:  # report every 500 mini-batches
                # Evaluation only: disable gradient tracking to save memory
                # and compute. Do this for validation and inference too.
                with torch.no_grad():
                    val_outputs = net(val_image)  # [batch, 10]
                    # Index of the largest logit is the predicted class.
                    predict_y = torch.max(val_outputs, dim=1)[1]
                    accuracy = (torch.eq(predict_y, val_label).sum().item()
                                / val_label.size(0))

                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                      (epoch + 1, step + 1, running_loss / 500, accuracy))
                # Reset the accumulator for the next 500 mini-batches.
                running_loss = 0.0

    print('Finished Training')

    # Save only the learned parameters (state_dict), not the whole module.
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
```

## 预测

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet


def main():
    """Classify the image ``1.jpg`` with the LeNet weights in ``Lenet.pth``.

    Prints the predicted CIFAR-10 class name.
    """
    # Same preprocessing as in training, plus a resize so that an arbitrary
    # image matches the 32x32 resolution of the training data.
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    # Load the trained weights; map_location lets a checkpoint saved on a
    # GPU load on a CPU-only machine.
    net.load_state_dict(torch.load('Lenet.pth', map_location='cpu'))
    # Inference mode for layers such as dropout / batch norm.
    net.eval()

    with Image.open('1.jpg') as raw:
        # Force 3 channels: grayscale or RGBA files would not match the
        # 3-channel normalization and the network's first convolution.
        im = transform(raw.convert('RGB'))  # -> [C, H, W]
    # Insert a batch dimension of size 1 at the front -> [N, C, H, W].
    im = torch.unsqueeze(im, dim=0)

    with torch.no_grad():  # no gradient tracking needed for inference
        outputs = net(im)
        # Index of the most likely class as a plain Python int.
        # torch.softmax(outputs, dim=1) would give class probabilities.
        predict = torch.argmax(outputs, dim=1).item()
    # Map the class index to its name.
    print(classes[predict])


if __name__ == '__main__':
    main()
```


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://www.1024cx.top/ai/deep_learning/lenet.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
