Principles

Background and Motivation

A basic fact:

As a network gets deeper, the original input becomes distorted and information is lost, and the model's training error increases (the "degradation problem").

An abstract description:

image with information lost + the lost information = image with complete information

The idea:

the lost information = image with complete information - image with information lost

If we could somehow preserve the lost information (call it the residual) and add it back in a later part of the network, we could curb this distortion to some extent, and thereby curb the growth in error.

This is the idea behind ResNet.
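As a sanity check, the arithmetic above can be played out with plain numbers; `degrade` here is a made-up lossy transform for illustration, not anything from ResNet itself:

```python
# Toy illustration of the residual idea. `degrade` is a hypothetical
# lossy transform (not part of ResNet).
def degrade(x):
    # loses information: everything below 1.0 is clipped away
    return max(x - 1.0, 0.0)

x = 3.5                # the complete signal
lossy = degrade(x)     # signal with information lost
residual = x - lossy   # the lost information (the residual)

# adding the residual back recovers the complete signal exactly
assert lossy + residual == x
```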

The Residual Block

In a traditional deep neural network, suppose the input to some layer is $\mathbf{x}$ and we want to learn a mapping $H(\mathbf{x})$. Optimizing this mapping directly becomes difficult, especially when the network is very deep.

ResNet's proposal is to introduce an identity mapping and have each residual block learn a "residual" function $F(\mathbf{x}) = H(\mathbf{x}) - \mathbf{x}$, that is:

$$
H(\mathbf{x}) = F(\mathbf{x}) + \mathbf{x}
$$

This structure is called a skip connection: the input is added directly to the output, which lets gradients propagate through the network more smoothly and alleviates the vanishing- and exploding-gradient problems.
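A one-line calculus note makes this concrete: since $H(\mathbf{x}) = F(\mathbf{x}) + \mathbf{x}$, the derivative is $\partial H/\partial \mathbf{x} = \partial F/\partial \mathbf{x} + 1$, so the identity path always contributes a gradient component of 1. A scalar finite-difference sketch (the tiny-slope `F` is hypothetical, chosen to mimic a branch with a near-vanishing gradient):

```python
# Even when F has a nearly vanishing derivative, H = F(x) + x still
# passes a gradient of about 1 through, thanks to the identity term.
def F(x):
    return 1e-4 * x  # residual branch with a very small slope

def H(x):
    return F(x) + x  # residual block: F(x) plus the identity

eps = 1e-6
x = 2.0
dF = (F(x + eps) - F(x - eps)) / (2 * eps)  # about 0.0001
dH = (H(x + eps) - H(x - eps)) / (2 * eps)  # about 1.0001

assert dH > 1.0  # the identity path guarantees a gradient of at least ~1
```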

Step-by-Step Analysis

From a data-flow perspective, the input goes through the following steps inside a residual block:

Conv layer 1: the input $\mathbf{x}$ passes through the first convolution for feature extraction, producing a feature map.

Batch Normalization: normalizes the convolution output, which speeds up and stabilizes training.

Activation (ReLU): introduces nonlinearity so the model can fit more complex functions.

Conv layer 2: a second convolution extracts more complex features.

Batch Normalization: normalizes the convolution output once more.

Skip connection: the input $\mathbf{x}$ is added directly to the output of the two convolution-and-normalization stages, giving the final output $\mathbf{y} = F(\mathbf{x}) + \mathbf{x}$.

Activation (ReLU): applies the nonlinearity to the combined output.

The overall data flow looks like this:

```
x ----> [Conv1] -> [BN1] -> [ReLU] -> [Conv2] -> [BN2] --(+)--> [ReLU] --> y
|                                                          ^
+----------------------------------------------------------+
```
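Both 3x3 convolutions use padding 1, so with stride 1 the spatial size is preserved and the element-wise addition in the skip connection is well defined. A quick check with the standard convolution output-size formula, in plain Python:

```python
def conv_out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

# a 3x3 conv with padding 1 and stride 1 keeps the size unchanged,
# so the skip connection can add x and F(x) element-wise
assert conv_out_size(32, kernel=3, stride=1, padding=1) == 32

# with stride 2 (the first block of a new stage) the size is halved,
# which is why the identity path then needs a downsampling layer
assert conv_out_size(32, kernel=3, stride=2, padding=1) == 16
```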

Network Architecture

A typical ResNet is built by stacking residual blocks. Taking ResNet-18 as an example:

Stem: a 7x7 convolution with stride 2, followed by a 3x3 max-pooling layer with stride 2.

Residual stages: 4 stages, each containing several basic residual blocks; the stages use 64, 128, 256, and 512 feature channels respectively.

Global average pooling: compresses the feature maps into a global feature vector.

Fully connected layer: produces the final classification output.
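For the standard ImageNet-style input of 224x224, the stage layout above implies the following feature-map sizes; a small arithmetic check (strides as described, standard paddings assumed):

```python
def conv_out_size(n, kernel, stride, padding):
    # floor((n + 2p - k) / s) + 1
    return (n + 2 * padding - kernel) // stride + 1

n = 224
n = conv_out_size(n, kernel=7, stride=2, padding=3)  # 7x7 conv, stride 2 -> 112
n = conv_out_size(n, kernel=3, stride=2, padding=1)  # 3x3 max pool, stride 2 -> 56
sizes = [n]
for _ in range(3):  # stages 2-4 each start with a stride-2 block
    n = conv_out_size(n, kernel=3, stride=2, padding=1)
    sizes.append(n)

assert sizes == [56, 28, 14, 7]  # feature-map size entering each stage's output
```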

Implementing ResNet

The Residual Block

```python
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1  # for BasicBlock, the channel expansion factor is 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # first convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # second convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # downsampling for the identity path
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)    # conv 1
        out = self.bn1(out)    # batch norm 1
        out = self.relu(out)   # ReLU

        out = self.conv2(out)  # conv 2
        out = self.bn2(out)    # batch norm 2

        if self.downsample is not None:
            identity = self.downsample(x)  # adjust identity dimensions

        out += identity        # skip connection
        out = self.relu(out)   # ReLU

        return out
```

The core of it is the out += identity right before the final activation.

ResNet

```python
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):  # CIFAR-10 as the example
        super(ResNet, self).__init__()
        self.in_channels = 64
        # initial convolution
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)  # CIFAR-10 needs no 7x7 conv or max pooling
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        # residual stages
        self.layer1 = self._make_layer(block, 64, layers[0])             # 64 channels
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # 128 channels, stride 2
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # 256 channels, stride 2
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # 512 channels, stride 2
        # global average pooling and classifier
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        # downsample when the input and output dimensions differ, or stride != 1
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))  # stride=1
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)  # initial convolution
        out = self.bn1(out)
        out = self.relu(out)

        out = self.layer1(out)  # stage 1
        out = self.layer2(out)  # stage 2
        out = self.layer3(out)  # stage 3
        out = self.layer4(out)  # stage 4

        out = self.avg_pool(out)  # global average pooling
        out = torch.flatten(out, 1)
        out = self.fc(out)  # fully connected layer

        return out
```

Building ResNet-18

ResNet-18 consists of 4 stages, each containing 2 BasicBlocks.

```python
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
```
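The "18" counts the weighted layers: the stem convolution, two 3x3 convolutions in each of the 2+2+2+2 BasicBlocks, and the final fully connected layer (the 1x1 downsampling convolutions on the identity path are conventionally not counted):

```python
blocks_per_stage = [2, 2, 2, 2]  # the `layers` argument passed to ResNet
convs_per_basic_block = 2        # conv1 and conv2 inside BasicBlock

stem = 1                         # initial convolution
block_convs = sum(blocks_per_stage) * convs_per_basic_block  # 16
fc = 1                           # final classifier

assert stem + block_convs + fc == 18
```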

Preparing the Dataset

We use the CIFAR-10 dataset for training and testing.

```python
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# data preprocessing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # random horizontal flip
    transforms.RandomCrop(32, padding=4),  # random crop
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

# load the data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```
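With 50,000 training images and a batch size of 128, each epoch has 391 batches, which is why the training loop below only logs at batches 100, 200, and 300. The arithmetic:

```python
import math

train_images = 50_000  # CIFAR-10 training set size
batch_size = 128

batches_per_epoch = math.ceil(train_images / batch_size)
assert batches_per_epoch == 391

# logging every 100 batches therefore fires 3 times per epoch
assert batches_per_epoch // 100 == 3
```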

Training and Testing

```python
def train(model, device, trainloader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()               # clear gradients
        outputs = model(inputs)             # forward pass
        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()                     # backward pass
        optimizer.step()                    # update parameters

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 99:  # print every 100 batches
            print(f'Epoch [{epoch}], Batch [{batch_idx+1}], Loss: {running_loss / 100:.3f}, '
                  f'Accuracy: {100. * correct / total:.2f}%')
            running_loss = 0.0

def test(model, device, testloader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print(f'Test Loss: {test_loss / len(testloader):.3f}, '
          f'Test Accuracy: {100. * correct / total:.2f}%')
    return 100. * correct / total
```
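`outputs.max(1)` returns a (values, indices) pair along the class dimension, and the indices are the predicted classes. The accuracy bookkeeping, sketched in plain Python without tensors:

```python
# Plain-Python analogue of the accuracy bookkeeping:
# predicted class = argmax over each sample's logits.
logits = [
    [0.1, 2.3, -0.5],  # argmax -> class 1
    [1.7, 0.2, 0.4],   # argmax -> class 0
    [0.0, 0.1, 3.9],   # argmax -> class 2
]
targets = [1, 0, 0]

predicted = [row.index(max(row)) for row in logits]
correct = sum(p == t for p, t in zip(predicted, targets))

assert predicted == [1, 0, 2]
assert correct == 2  # 2 of 3 correct
```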

Main Program

```python
import time

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

model = ResNet18(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

num_epochs = 100
best_acc = 0

for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    train(model, device, trainloader, optimizer, criterion, epoch)
    acc = test(model, device, testloader, criterion)
    if acc > best_acc:
        best_acc = acc
        # save the best model so far
        torch.save(model.state_dict(), 'best_resnet18.pth')
    scheduler.step()
    end_time = time.time()
    print(f'Epoch [{epoch}] Done in {end_time - start_time:.2f}s\n')

print(f'Best accuracy {best_acc:.2f}%')
```
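With `StepLR(step_size=50, gamma=0.1)` and `scheduler.step()` called once per epoch, the learning rate is 0.1 for the first 50 epochs and 0.01 afterwards, which matches the sharp jump in test accuracy around epoch 51 in the log below. A sketch of the schedule:

```python
def step_lr(base_lr, epoch, step_size=50, gamma=0.1):
    """Learning rate after `epoch` completed epochs under StepLR."""
    return base_lr * gamma ** (epoch // step_size)

assert step_lr(0.1, epoch=0) == 0.1                    # epochs 1-50 run at 0.1
assert abs(step_lr(0.1, epoch=50) - 0.01) < 1e-12      # epoch 51 onward at 0.01
assert abs(step_lr(0.1, epoch=100) - 0.001) < 1e-12    # and 0.001 past epoch 100
```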

Checking Performance

```python
# load the best model
model.load_state_dict(torch.load('best_resnet18.pth'))
test_acc = test(model, device, testloader, criterion)
print(f'Test accuracy: {test_acc:.2f}%')
```

Test accuracy reaches about 93%, which is decent.

$ python impl.py

Files already downloaded and verified

Files already downloaded and verified

device: cuda

Epoch [1], Batch [100], Loss: 2.422, Accuracy: 20.01%

Epoch [1], Batch [200], Loss: 1.868, Accuracy: 24.83%

Epoch [1], Batch [300], Loss: 1.707, Accuracy: 28.58%

Test Loss: 1.544, Test Accuracy: 43.36%

Epoch [1] Done in 20.87s

Epoch [2], Batch [100], Loss: 1.518, Accuracy: 44.02%

Epoch [2], Batch [200], Loss: 1.406, Accuracy: 46.39%

Epoch [2], Batch [300], Loss: 1.322, Accuracy: 48.25%

Test Loss: 1.230, Test Accuracy: 57.20%

Epoch [2] Done in 20.73s

Epoch [3], Batch [100], Loss: 1.142, Accuracy: 59.45%

Epoch [3], Batch [200], Loss: 1.070, Accuracy: 60.82%

Epoch [3], Batch [300], Loss: 1.013, Accuracy: 61.90%

Test Loss: 1.114, Test Accuracy: 62.81%

Epoch [3] Done in 20.72s

Epoch [4], Batch [100], Loss: 0.906, Accuracy: 67.95%

Epoch [4], Batch [200], Loss: 0.838, Accuracy: 69.22%

Epoch [4], Batch [300], Loss: 0.808, Accuracy: 70.00%

Test Loss: 0.874, Test Accuracy: 69.42%

Epoch [4] Done in 20.72s

Epoch [5], Batch [100], Loss: 0.713, Accuracy: 75.28%

Epoch [5], Batch [200], Loss: 0.706, Accuracy: 75.52%

Epoch [5], Batch [300], Loss: 0.681, Accuracy: 75.67%

Test Loss: 0.752, Test Accuracy: 74.31%

Epoch [5] Done in 20.71s

Epoch [6], Batch [100], Loss: 0.617, Accuracy: 78.60%

Epoch [6], Batch [200], Loss: 0.620, Accuracy: 78.71%

Epoch [6], Batch [300], Loss: 0.597, Accuracy: 78.91%

Test Loss: 0.779, Test Accuracy: 73.72%

Epoch [6] Done in 20.66s

Epoch [7], Batch [100], Loss: 0.556, Accuracy: 80.91%

Epoch [7], Batch [200], Loss: 0.547, Accuracy: 81.07%

Epoch [7], Batch [300], Loss: 0.566, Accuracy: 80.84%

Test Loss: 0.697, Test Accuracy: 75.82%

Epoch [7] Done in 20.75s

Epoch [8], Batch [100], Loss: 0.520, Accuracy: 82.01%

Epoch [8], Batch [200], Loss: 0.528, Accuracy: 82.11%

Epoch [8], Batch [300], Loss: 0.530, Accuracy: 81.93%

Test Loss: 0.625, Test Accuracy: 78.78%

Epoch [8] Done in 20.72s

Epoch [9], Batch [100], Loss: 0.514, Accuracy: 82.43%

Epoch [9], Batch [200], Loss: 0.498, Accuracy: 82.74%

Epoch [9], Batch [300], Loss: 0.492, Accuracy: 82.85%

Test Loss: 0.528, Test Accuracy: 82.16%

Epoch [9] Done in 20.72s

Epoch [10], Batch [100], Loss: 0.469, Accuracy: 83.93%

Epoch [10], Batch [200], Loss: 0.474, Accuracy: 83.75%

Epoch [10], Batch [300], Loss: 0.497, Accuracy: 83.43%

Test Loss: 0.637, Test Accuracy: 78.31%

Epoch [10] Done in 20.69s

Epoch [11], Batch [100], Loss: 0.446, Accuracy: 84.35%

Epoch [11], Batch [200], Loss: 0.469, Accuracy: 84.33%

Epoch [11], Batch [300], Loss: 0.477, Accuracy: 84.14%

Test Loss: 0.527, Test Accuracy: 82.31%

Epoch [11] Done in 20.77s

Epoch [12], Batch [100], Loss: 0.427, Accuracy: 85.59%

Epoch [12], Batch [200], Loss: 0.472, Accuracy: 84.71%

Epoch [12], Batch [300], Loss: 0.472, Accuracy: 84.40%

Test Loss: 0.813, Test Accuracy: 73.67%

Epoch [12] Done in 20.71s

Epoch [13], Batch [100], Loss: 0.430, Accuracy: 85.02%

Epoch [13], Batch [200], Loss: 0.431, Accuracy: 85.09%

Epoch [13], Batch [300], Loss: 0.444, Accuracy: 84.98%

Test Loss: 0.821, Test Accuracy: 75.01%

Epoch [13] Done in 20.74s

Epoch [14], Batch [100], Loss: 0.410, Accuracy: 85.82%

Epoch [14], Batch [200], Loss: 0.431, Accuracy: 85.51%

Epoch [14], Batch [300], Loss: 0.425, Accuracy: 85.48%

Test Loss: 0.635, Test Accuracy: 79.43%

Epoch [14] Done in 20.73s

Epoch [15], Batch [100], Loss: 0.409, Accuracy: 86.09%

Epoch [15], Batch [200], Loss: 0.421, Accuracy: 85.82%

Epoch [15], Batch [300], Loss: 0.417, Accuracy: 85.93%

Test Loss: 0.450, Test Accuracy: 84.87%

Epoch [15] Done in 20.79s

Epoch [16], Batch [100], Loss: 0.385, Accuracy: 86.80%

Epoch [16], Batch [200], Loss: 0.422, Accuracy: 86.24%

Epoch [16], Batch [300], Loss: 0.417, Accuracy: 86.03%

Test Loss: 0.562, Test Accuracy: 81.20%

Epoch [16] Done in 20.70s

Epoch [17], Batch [100], Loss: 0.388, Accuracy: 86.75%

Epoch [17], Batch [200], Loss: 0.420, Accuracy: 86.22%

Epoch [17], Batch [300], Loss: 0.401, Accuracy: 86.27%

Test Loss: 0.840, Test Accuracy: 73.49%

Epoch [17] Done in 20.73s

Epoch [18], Batch [100], Loss: 0.400, Accuracy: 86.34%

Epoch [18], Batch [200], Loss: 0.387, Accuracy: 86.52%

Epoch [18], Batch [300], Loss: 0.399, Accuracy: 86.43%

Test Loss: 0.584, Test Accuracy: 81.25%

Epoch [18] Done in 20.70s

Epoch [19], Batch [100], Loss: 0.386, Accuracy: 87.20%

Epoch [19], Batch [200], Loss: 0.387, Accuracy: 86.97%

Epoch [19], Batch [300], Loss: 0.393, Accuracy: 86.82%

Test Loss: 0.570, Test Accuracy: 81.27%

Epoch [19] Done in 20.70s

Epoch [20], Batch [100], Loss: 0.386, Accuracy: 87.02%

Epoch [20], Batch [200], Loss: 0.391, Accuracy: 86.88%

Epoch [20], Batch [300], Loss: 0.398, Accuracy: 86.80%

Test Loss: 0.781, Test Accuracy: 76.02%

Epoch [20] Done in 20.72s

Epoch [21], Batch [100], Loss: 0.384, Accuracy: 86.77%

Epoch [21], Batch [200], Loss: 0.388, Accuracy: 86.75%

Epoch [21], Batch [300], Loss: 0.401, Accuracy: 86.61%

Test Loss: 0.619, Test Accuracy: 79.30%

Epoch [21] Done in 20.70s

Epoch [22], Batch [100], Loss: 0.372, Accuracy: 87.52%

Epoch [22], Batch [200], Loss: 0.362, Accuracy: 87.67%

Epoch [22], Batch [300], Loss: 0.378, Accuracy: 87.44%

Test Loss: 0.472, Test Accuracy: 84.23%

Epoch [22] Done in 20.71s

Epoch [23], Batch [100], Loss: 0.351, Accuracy: 88.22%

Epoch [23], Batch [200], Loss: 0.368, Accuracy: 87.98%

Epoch [23], Batch [300], Loss: 0.385, Accuracy: 87.66%

Test Loss: 0.591, Test Accuracy: 81.47%

Epoch [23] Done in 20.71s

Epoch [24], Batch [100], Loss: 0.343, Accuracy: 88.38%

Epoch [24], Batch [200], Loss: 0.362, Accuracy: 87.91%

Epoch [24], Batch [300], Loss: 0.380, Accuracy: 87.58%

Test Loss: 0.649, Test Accuracy: 79.41%

Epoch [24] Done in 20.71s

Epoch [25], Batch [100], Loss: 0.358, Accuracy: 87.67%

Epoch [25], Batch [200], Loss: 0.358, Accuracy: 87.72%

Epoch [25], Batch [300], Loss: 0.377, Accuracy: 87.50%

Test Loss: 0.614, Test Accuracy: 79.93%

Epoch [25] Done in 20.70s

Epoch [26], Batch [100], Loss: 0.355, Accuracy: 87.79%

Epoch [26], Batch [200], Loss: 0.351, Accuracy: 87.75%

Epoch [26], Batch [300], Loss: 0.375, Accuracy: 87.57%

Test Loss: 0.623, Test Accuracy: 79.35%

Epoch [26] Done in 20.71s

Epoch [27], Batch [100], Loss: 0.351, Accuracy: 88.19%

Epoch [27], Batch [200], Loss: 0.348, Accuracy: 88.20%

Epoch [27], Batch [300], Loss: 0.366, Accuracy: 87.99%

Test Loss: 0.673, Test Accuracy: 80.00%

Epoch [27] Done in 20.70s

Epoch [28], Batch [100], Loss: 0.349, Accuracy: 88.23%

Epoch [28], Batch [200], Loss: 0.361, Accuracy: 88.01%

Epoch [28], Batch [300], Loss: 0.369, Accuracy: 87.87%

Test Loss: 0.569, Test Accuracy: 82.28%

Epoch [28] Done in 20.71s

Epoch [29], Batch [100], Loss: 0.348, Accuracy: 88.22%

Epoch [29], Batch [200], Loss: 0.347, Accuracy: 88.32%

Epoch [29], Batch [300], Loss: 0.359, Accuracy: 88.14%

Test Loss: 0.478, Test Accuracy: 84.06%

Epoch [29] Done in 20.72s

Epoch [30], Batch [100], Loss: 0.359, Accuracy: 87.91%

Epoch [30], Batch [200], Loss: 0.346, Accuracy: 88.15%

Epoch [30], Batch [300], Loss: 0.360, Accuracy: 87.99%

Test Loss: 0.822, Test Accuracy: 74.72%

Epoch [30] Done in 20.73s

Epoch [31], Batch [100], Loss: 0.322, Accuracy: 88.74%

Epoch [31], Batch [200], Loss: 0.365, Accuracy: 88.21%

Epoch [31], Batch [300], Loss: 0.365, Accuracy: 88.01%

Test Loss: 0.517, Test Accuracy: 82.56%

Epoch [31] Done in 20.70s

Epoch [32], Batch [100], Loss: 0.334, Accuracy: 88.77%

Epoch [32], Batch [200], Loss: 0.325, Accuracy: 88.77%

Epoch [32], Batch [300], Loss: 0.353, Accuracy: 88.43%

Test Loss: 0.591, Test Accuracy: 80.54%

Epoch [32] Done in 20.70s

Epoch [33], Batch [100], Loss: 0.345, Accuracy: 88.23%

Epoch [33], Batch [200], Loss: 0.352, Accuracy: 88.08%

Epoch [33], Batch [300], Loss: 0.367, Accuracy: 87.92%

Test Loss: 0.494, Test Accuracy: 83.82%

Epoch [33] Done in 20.71s

Epoch [34], Batch [100], Loss: 0.339, Accuracy: 88.21%

Epoch [34], Batch [200], Loss: 0.345, Accuracy: 88.20%

Epoch [34], Batch [300], Loss: 0.351, Accuracy: 88.12%

Test Loss: 0.424, Test Accuracy: 85.62%

Epoch [34] Done in 20.76s

Epoch [35], Batch [100], Loss: 0.320, Accuracy: 89.38%

Epoch [35], Batch [200], Loss: 0.349, Accuracy: 88.67%

Epoch [35], Batch [300], Loss: 0.343, Accuracy: 88.48%

Test Loss: 0.448, Test Accuracy: 85.36%

Epoch [35] Done in 20.71s

Epoch [36], Batch [100], Loss: 0.328, Accuracy: 88.77%

Epoch [36], Batch [200], Loss: 0.339, Accuracy: 88.77%

Epoch [36], Batch [300], Loss: 0.374, Accuracy: 88.25%

Test Loss: 0.487, Test Accuracy: 84.15%

Epoch [36] Done in 20.71s

Epoch [37], Batch [100], Loss: 0.329, Accuracy: 88.97%

Epoch [37], Batch [200], Loss: 0.336, Accuracy: 88.75%

Epoch [37], Batch [300], Loss: 0.334, Accuracy: 88.63%

Test Loss: 0.484, Test Accuracy: 84.09%

Epoch [37] Done in 20.70s

Epoch [38], Batch [100], Loss: 0.340, Accuracy: 88.69%

Epoch [38], Batch [200], Loss: 0.329, Accuracy: 88.73%

Epoch [38], Batch [300], Loss: 0.344, Accuracy: 88.58%

Test Loss: 0.497, Test Accuracy: 83.75%

Epoch [38] Done in 20.72s

Epoch [39], Batch [100], Loss: 0.317, Accuracy: 89.23%

Epoch [39], Batch [200], Loss: 0.350, Accuracy: 88.67%

Epoch [39], Batch [300], Loss: 0.345, Accuracy: 88.57%

Test Loss: 0.410, Test Accuracy: 86.56%

Epoch [39] Done in 20.77s

Epoch [40], Batch [100], Loss: 0.331, Accuracy: 88.46%

Epoch [40], Batch [200], Loss: 0.335, Accuracy: 88.48%

Epoch [40], Batch [300], Loss: 0.326, Accuracy: 88.62%

Test Loss: 0.459, Test Accuracy: 85.06%

Epoch [40] Done in 20.71s

Epoch [41], Batch [100], Loss: 0.329, Accuracy: 88.83%

Epoch [41], Batch [200], Loss: 0.335, Accuracy: 88.59%

Epoch [41], Batch [300], Loss: 0.349, Accuracy: 88.40%

Test Loss: 0.532, Test Accuracy: 83.28%

Epoch [41] Done in 20.70s

Epoch [42], Batch [100], Loss: 0.301, Accuracy: 89.88%

Epoch [42], Batch [200], Loss: 0.340, Accuracy: 89.09%

Epoch [42], Batch [300], Loss: 0.349, Accuracy: 88.74%

Test Loss: 0.680, Test Accuracy: 78.78%

Epoch [42] Done in 20.72s

Epoch [43], Batch [100], Loss: 0.320, Accuracy: 89.06%

Epoch [43], Batch [200], Loss: 0.333, Accuracy: 88.81%

Epoch [43], Batch [300], Loss: 0.344, Accuracy: 88.68%

Test Loss: 0.401, Test Accuracy: 86.55%

Epoch [43] Done in 20.73s

Epoch [44], Batch [100], Loss: 0.343, Accuracy: 88.29%

Epoch [44], Batch [200], Loss: 0.335, Accuracy: 88.35%

Epoch [44], Batch [300], Loss: 0.333, Accuracy: 88.47%

Test Loss: 0.475, Test Accuracy: 84.56%

Epoch [44] Done in 20.70s

Epoch [45], Batch [100], Loss: 0.325, Accuracy: 89.16%

Epoch [45], Batch [200], Loss: 0.332, Accuracy: 88.90%

Epoch [45], Batch [300], Loss: 0.335, Accuracy: 88.85%

Test Loss: 0.491, Test Accuracy: 83.45%

Epoch [45] Done in 20.72s

Epoch [46], Batch [100], Loss: 0.328, Accuracy: 88.64%

Epoch [46], Batch [200], Loss: 0.335, Accuracy: 88.62%

Epoch [46], Batch [300], Loss: 0.332, Accuracy: 88.69%

Test Loss: 0.563, Test Accuracy: 81.67%

Epoch [46] Done in 20.70s

Epoch [47], Batch [100], Loss: 0.310, Accuracy: 89.43%

Epoch [47], Batch [200], Loss: 0.333, Accuracy: 88.90%

Epoch [47], Batch [300], Loss: 0.342, Accuracy: 88.62%

Test Loss: 0.556, Test Accuracy: 82.48%

Epoch [47] Done in 20.72s

Epoch [48], Batch [100], Loss: 0.316, Accuracy: 89.09%

Epoch [48], Batch [200], Loss: 0.321, Accuracy: 89.13%

Epoch [48], Batch [300], Loss: 0.333, Accuracy: 89.02%

Test Loss: 0.419, Test Accuracy: 86.28%

Epoch [48] Done in 20.72s

Epoch [49], Batch [100], Loss: 0.320, Accuracy: 88.99%

Epoch [49], Batch [200], Loss: 0.328, Accuracy: 88.95%

Epoch [49], Batch [300], Loss: 0.322, Accuracy: 88.91%

Test Loss: 0.474, Test Accuracy: 84.37%

Epoch [49] Done in 20.71s

Epoch [50], Batch [100], Loss: 0.308, Accuracy: 89.34%

Epoch [50], Batch [200], Loss: 0.333, Accuracy: 89.02%

Epoch [50], Batch [300], Loss: 0.327, Accuracy: 88.96%

Test Loss: 0.503, Test Accuracy: 84.12%

Epoch [50] Done in 20.70s

Epoch [51], Batch [100], Loss: 0.233, Accuracy: 92.05%

Epoch [51], Batch [200], Loss: 0.171, Accuracy: 93.20%

Epoch [51], Batch [300], Loss: 0.167, Accuracy: 93.58%

Test Loss: 0.219, Test Accuracy: 92.78%

Epoch [51] Done in 20.77s

Epoch [52], Batch [100], Loss: 0.139, Accuracy: 95.30%

Epoch [52], Batch [200], Loss: 0.125, Accuracy: 95.55%

Epoch [52], Batch [300], Loss: 0.133, Accuracy: 95.57%

Test Loss: 0.212, Test Accuracy: 92.96%

Epoch [52] Done in 20.89s

Epoch [53], Batch [100], Loss: 0.107, Accuracy: 96.63%

Epoch [53], Batch [200], Loss: 0.109, Accuracy: 96.41%

Epoch [53], Batch [300], Loss: 0.108, Accuracy: 96.40%

Test Loss: 0.199, Test Accuracy: 93.29%

Epoch [53] Done in 20.79s

Epoch [54], Batch [100], Loss: 0.097, Accuracy: 96.86%

Epoch [54], Batch [200], Loss: 0.098, Accuracy: 96.75%

Epoch [54], Batch [300], Loss: 0.092, Accuracy: 96.85%

Test Loss: 0.199, Test Accuracy: 93.65%

Epoch [54] Done in 20.78s

Epoch [55], Batch [100], Loss: 0.084, Accuracy: 97.21%

Epoch [55], Batch [200], Loss: 0.080, Accuracy: 97.30%

Epoch [55], Batch [300], Loss: 0.089, Accuracy: 97.20%

Test Loss: 0.200, Test Accuracy: 93.69%

Epoch [55] Done in 20.77s

Epoch [56], Batch [100], Loss: 0.073, Accuracy: 97.63%

Epoch [56], Batch [200], Loss: 0.082, Accuracy: 97.50%

Epoch [56], Batch [300], Loss: 0.071, Accuracy: 97.54%

Test Loss: 0.212, Test Accuracy: 93.24%

Epoch [56] Done in 20.71s

Epoch [57], Batch [100], Loss: 0.062, Accuracy: 98.00%

Epoch [57], Batch [200], Loss: 0.065, Accuracy: 98.00%

Epoch [57], Batch [300], Loss: 0.070, Accuracy: 97.87%

Test Loss: 0.212, Test Accuracy: 93.45%

Epoch [57] Done in 20.70s

Epoch [58], Batch [100], Loss: 0.059, Accuracy: 98.16%

Epoch [58], Batch [200], Loss: 0.057, Accuracy: 98.19%

Epoch [58], Batch [300], Loss: 0.069, Accuracy: 97.97%

Test Loss: 0.219, Test Accuracy: 93.21%

Epoch [58] Done in 20.71s

Epoch [59], Batch [100], Loss: 0.061, Accuracy: 97.95%

Epoch [59], Batch [200], Loss: 0.052, Accuracy: 98.14%

Epoch [59], Batch [300], Loss: 0.059, Accuracy: 98.12%

Test Loss: 0.211, Test Accuracy: 93.61%

Epoch [59] Done in 20.73s

Epoch [60], Batch [100], Loss: 0.046, Accuracy: 98.57%

Epoch [60], Batch [200], Loss: 0.049, Accuracy: 98.49%

Epoch [60], Batch [300], Loss: 0.057, Accuracy: 98.35%

Test Loss: 0.226, Test Accuracy: 93.26%

Epoch [60] Done in 20.69s

Epoch [61], Batch [100], Loss: 0.046, Accuracy: 98.53%

Epoch [61], Batch [200], Loss: 0.045, Accuracy: 98.51%

Epoch [61], Batch [300], Loss: 0.051, Accuracy: 98.45%

Test Loss: 0.214, Test Accuracy: 93.73%

Epoch [61] Done in 20.77s

Epoch [62], Batch [100], Loss: 0.040, Accuracy: 98.73%

Epoch [62], Batch [200], Loss: 0.046, Accuracy: 98.64%

Epoch [62], Batch [300], Loss: 0.043, Accuracy: 98.62%

Test Loss: 0.227, Test Accuracy: 93.34%

Epoch [62] Done in 20.72s

Epoch [63], Batch [100], Loss: 0.039, Accuracy: 98.73%

Epoch [63], Batch [200], Loss: 0.043, Accuracy: 98.67%

Epoch [63], Batch [300], Loss: 0.043, Accuracy: 98.63%

Test Loss: 0.221, Test Accuracy: 93.51%

Epoch [63] Done in 20.71s

Epoch [64], Batch [100], Loss: 0.039, Accuracy: 98.82%

Epoch [64], Batch [200], Loss: 0.039, Accuracy: 98.77%

Epoch [64], Batch [300], Loss: 0.036, Accuracy: 98.82%

Test Loss: 0.228, Test Accuracy: 93.55%

Epoch [64] Done in 20.71s

Epoch [65], Batch [100], Loss: 0.033, Accuracy: 98.98%

Epoch [65], Batch [200], Loss: 0.038, Accuracy: 98.85%

Epoch [65], Batch [300], Loss: 0.037, Accuracy: 98.81%

Test Loss: 0.230, Test Accuracy: 93.47%

Epoch [65] Done in 20.72s

Epoch [66], Batch [100], Loss: 0.035, Accuracy: 98.84%

Epoch [66], Batch [200], Loss: 0.036, Accuracy: 98.82%

Epoch [66], Batch [300], Loss: 0.037, Accuracy: 98.81%

Test Loss: 0.242, Test Accuracy: 93.31%

Epoch [66] Done in 20.73s

Epoch [67], Batch [100], Loss: 0.036, Accuracy: 98.88%

Epoch [67], Batch [200], Loss: 0.035, Accuracy: 98.88%

Epoch [67], Batch [300], Loss: 0.035, Accuracy: 98.88%

Test Loss: 0.232, Test Accuracy: 93.47%

Epoch [67] Done in 20.73s

Epoch [68], Batch [100], Loss: 0.036, Accuracy: 98.82%

Epoch [68], Batch [200], Loss: 0.034, Accuracy: 98.88%

Epoch [68], Batch [300], Loss: 0.035, Accuracy: 98.90%

Test Loss: 0.237, Test Accuracy: 93.47%

Epoch [68] Done in 20.70s

Epoch [69], Batch [100], Loss: 0.032, Accuracy: 99.02%

Epoch [69], Batch [200], Loss: 0.033, Accuracy: 99.00%

Epoch [69], Batch [300], Loss: 0.037, Accuracy: 98.94%

Test Loss: 0.240, Test Accuracy: 93.52%

Epoch [69] Done in 20.71s

Epoch [70], Batch [100], Loss: 0.032, Accuracy: 99.05%

Epoch [70], Batch [200], Loss: 0.038, Accuracy: 98.90%

Epoch [70], Batch [300], Loss: 0.039, Accuracy: 98.85%

Test Loss: 0.250, Test Accuracy: 93.27%

Epoch [70] Done in 20.72s

Epoch [71], Batch [100], Loss: 0.028, Accuracy: 99.14%

Epoch [71], Batch [200], Loss: 0.034, Accuracy: 99.06%

Epoch [71], Batch [300], Loss: 0.035, Accuracy: 99.03%

Test Loss: 0.262, Test Accuracy: 92.96%

Epoch [71] Done in 20.70s

Epoch [72], Batch [100], Loss: 0.034, Accuracy: 98.91%

Epoch [72], Batch [200], Loss: 0.031, Accuracy: 98.93%

Epoch [72], Batch [300], Loss: 0.039, Accuracy: 98.84%

Test Loss: 0.251, Test Accuracy: 93.20%

Epoch [72] Done in 20.71s

Epoch [73], Batch [100], Loss: 0.032, Accuracy: 98.99%

Epoch [73], Batch [200], Loss: 0.033, Accuracy: 98.98%

Epoch [73], Batch [300], Loss: 0.040, Accuracy: 98.90%

Test Loss: 0.247, Test Accuracy: 93.18%

Epoch [73] Done in 20.72s

Epoch [74], Batch [100], Loss: 0.033, Accuracy: 98.93%

Epoch [74], Batch [200], Loss: 0.033, Accuracy: 98.97%

Epoch [74], Batch [300], Loss: 0.039, Accuracy: 98.86%

Test Loss: 0.234, Test Accuracy: 93.52%

Epoch [74] Done in 20.73s

Epoch [75], Batch [100], Loss: 0.027, Accuracy: 99.13%

Epoch [75], Batch [200], Loss: 0.035, Accuracy: 99.00%

Epoch [75], Batch [300], Loss: 0.034, Accuracy: 98.98%

Test Loss: 0.256, Test Accuracy: 93.11%

Epoch [75] Done in 20.72s

Epoch [76], Batch [100], Loss: 0.041, Accuracy: 98.64%

Epoch [76], Batch [200], Loss: 0.036, Accuracy: 98.77%

Epoch [76], Batch [300], Loss: 0.031, Accuracy: 98.85%

Test Loss: 0.254, Test Accuracy: 93.17%

Epoch [76] Done in 20.72s

Epoch [77], Batch [100], Loss: 0.033, Accuracy: 98.95%

Epoch [77], Batch [200], Loss: 0.035, Accuracy: 98.89%

Epoch [77], Batch [300], Loss: 0.041, Accuracy: 98.79%

Test Loss: 0.287, Test Accuracy: 92.33%

Epoch [77] Done in 20.72s

Epoch [78], Batch [100], Loss: 0.040, Accuracy: 98.78%

Epoch [78], Batch [200], Loss: 0.038, Accuracy: 98.75%

Epoch [78], Batch [300], Loss: 0.038, Accuracy: 98.74%

Test Loss: 0.270, Test Accuracy: 92.80%

Epoch [78] Done in 20.71s

Epoch [79], Batch [100], Loss: 0.043, Accuracy: 98.56%

Epoch [79], Batch [200], Loss: 0.043, Accuracy: 98.52%

Epoch [79], Batch [300], Loss: 0.041, Accuracy: 98.53%

Test Loss: 0.243, Test Accuracy: 93.37%

Epoch [79] Done in 20.72s

Epoch [80], Batch [100], Loss: 0.036, Accuracy: 98.91%

Epoch [80], Batch [200], Loss: 0.043, Accuracy: 98.77%

Epoch [80], Batch [300], Loss: 0.041, Accuracy: 98.70%

Test Loss: 0.255, Test Accuracy: 93.22%