LeNet

模型

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 第一层卷积网络,输入图片是 32 * 32,彩色图片有RGB 3个通道,记为:input(3, 32, 32)
        # 设置 16 个 filter 卷积核,那么本层网络输出的图像就有16层
        # 那么传递给下一层网络,下一层网络的 in_channels 就是 16
        # filter 大小是 5*5,那么输出图像的大小就是 32-5 +1 = 28,即 28 * 28 的输出大小
        # 本层网络输出记为:output(16, 28, 28)
        self.conv1 = nn.Conv2d(
            in_channels=3, # 第一层网络的通道数,如果是黑白的那就只有1个通道
            out_channels=16, # 相当于 filter 的个数
            kernel_size=5,  # filter 宽高设置为 5*5
            stride=1,  # filter每次移动的步长
            padding=0  # 扩展图片的边界,这里扩展0层,如果是大于0,那么扩展的像素值是设置为0的,0是黑色,255是白色
            )
        # 2*2 的池化层, 把图片宽高变成一半,channel不变
        # input(16, 28, 28) -> output(16, 14, 14)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层, in_channels 是上一层的输出 16
        # 设置本层卷积 out_channels 为 32,filter 宽高设置为 5*5
        # 14 - 5 + 1 = 10
        # input(16, 14, 14) -> output(32, 10, 10)
        self.conv2 = nn.Conv2d(16, 32, 5)
        # 池化, input(32, 10, 10) -> output(32, 5, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 全连接层, 在 forward 中会把图片拉平成一维向量再传给全连接层
        # 因此,全连接层的 in_channels 要用上一层的 out_channels 乘上图片的宽高,即:32*5*5
        # input(32*5*5) -> output(120)
        self.fc1 = nn.Linear(in_channels=32*5*5, out_channels=120)
        # 全连接层,input(120) -> output(84)
        self.fc2 = nn.Linear(120, 84)
        # 最后一层网络,要做分类了,out_channels 要设置为类别总数,input(84) -> output(10)
        self.fc3 = nn.Linear(84, 10)

    # 这个就是数据集在网络中的逐层前向传播触发函数, 框架会自动调用这个函数
    # x 的 shape 是 [batch, channel, height, width]
    def forward(self, x):
        # 先走 conv1,然后 relu 处理 conv1 的结果
        x = F.relu(self.conv1(x))    
        x = self.pool1(x)            
        x = F.relu(self.conv2(x))    
        x = self.pool2(x)           
        # 把图像拉平成一维向量再传递给全连接层处理
        x = x.view(-1, 32*5*5)       
        x = F.relu(self.fc1(x))      
        x = F.relu(self.fc2(x))      
        x = self.fc3(x)              
        return x

训练

预测

Last updated

Was this helpful?