图像分类残差网络训练
这个例子展示了如何创建一个带有剩余连接的深度学习神经网络,并在CIFAR-10数据上进行训练。残差连接是卷积神经网络体系结构中一个很受欢迎的元素。使用剩余连接可以改善网络中的梯度流,并可以训练更深层次的网络。
对于许多应用程序,使用由简单层序列组成的网络就足够了。然而,一些应用程序要求网络具有更复杂的图结构,其中层可以有来自多个层的输入和多个层的输出。这些类型的网络通常被称为有向无环图(DAG)网络。残余网络(ResNet)是一种DAG网络,它具有绕过主网络层的残余(或快捷)连接。剩余连接使参数梯度更容易从输出层传播到网络的早期层,这使得训练更深层次的网络成为可能。这种增加的网络深度可以在更困难的任务中获得更高的准确性。
ResNet体系结构由初始层和栈包含剩余块,然后是最后的层。残块有三种类型:
初始剩余块-该块出现在第一个堆栈的开始。这个例子使用了瓶颈组件;因此,该块包含与下采样块相同的层,只是步幅为
[1]
在第一卷积层。有关更多信息,请参见resnetLayers
.标准剩余块-该块出现在每个堆栈中,在第一个下采样剩余块之后。该块在每个堆栈中出现多次,并保留激活大小。
向下采样剩余块-该块出现在每个堆栈的开始(除了第一个),并且在每个堆栈中只出现一次。下采样块中的第一个卷积单元将空间维度下采样2倍。
每个堆栈的深度可以变化,这个例子训练了一个剩余网络,有三个堆栈的深度递减。第一个堆栈的深度为4,第二个堆栈的深度为3,最后一个堆栈的深度为2。
每个剩余块包含深度学习层。有关每个块中的层的更多信息,请参见resnetLayers
.
要创建和训练一个适合图像分类的残差网络,请遵循以下步骤:
创建残留网络
resnetLayers
函数。训练网络使用
trainNetwork
函数。训练过的网络是一个DAGNetwork
对象。对新数据进行分类和预测
分类
而且预测
功能。
你也可以加载预训练的残差网络进行图像分类。有关更多信息,请参见预训练的深度神经网络.
准备数据
下载CIFAR-10数据集[1]。该数据集包含60,000张图像。每张图像大小为32 * 32像素,有三个颜色通道(RGB)。数据集的大小为175 MB。根据您的互联网连接,下载过程可能需要时间。
Datadir = tempdir;downloadCIFARData (datadir);
将CIFAR-10训练和测试图像加载为4-D数组。训练集包含50,000张图像,测试集包含10,000张图像。使用CIFAR-10测试图像进行网络验证。
[XTrain,TTrain,XValidation,TValidation] = loadCIFARData(datadir);
您可以使用以下代码显示训练图像的随机样本。
图;idx = randperm(size(XTrain,4),20);我= imtile (XTrain (:,:,:, idx) ThumbnailSize =(96、96));imshow (im)
创建一个augmentedImageDatastore
对象用于网络培训。在训练期间,数据存储沿垂直轴随机翻转训练图像,并在水平和垂直上将它们随机转换为4个像素。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。
imageSize = [32 32 3];pixelRange = [-4 4];imageAugmenter = imageDataAugmenter(...RandXReflection = true,...RandXTranslation = pixelRange,...RandYTranslation = pixelRange);augimdsTrain = augmentedimagedastore (imageSize,XTrain,TTrain,...DataAugmentation = imageAugmenter,...OutputSizeMode =“randcrop”);
定义网络架构
使用resnetLayers
函数创建适合此数据集的残差网络。
CIFAR-10图像是32x32像素,因此,使用一个小的初始滤波器大小为3,初始步幅为1。将初始过滤器的数量设置为16。
网络中的第一个堆栈以初始剩余块开始。每个后续堆栈都从一个下采样剩余块开始。下采样块中的第一个卷积单元对空间维度进行了2倍的下采样。为了保持整个网络中每个卷积层所需的计算量大致相同,每次执行空间下采样时,将过滤器的数量增加两倍。设置堆栈深度为
[4 3 2]
以及过滤器的数量[16 32 64]
.
initialFilterSize = 3;numInitialFilters = 16;initialStride = 1;numFilters = [16 32 64];stackDepth = [4 3 2];lgraph = resnetLayers(imageSize,10,...InitialFilterSize = InitialFilterSize,...InitialNumFilters = numInitialFilters,...InitialStride = InitialStride,...InitialPoolingLayer =“没有”,...StackDepth=[4 3 2],...NumFilters=[16 32 64]);
可视化网络。
情节(lgraph);
培训方案
指定培训选项。训练网络80个epoch。选择与迷你批处理大小成比例的学习率,并在60个epoch后将学习率降低10倍。使用验证数据每个epoch验证一次网络。
miniBatchSize = 128;learnRate = 0.1*miniBatchSize/128;valFrequency = floor(size(XTrain,4)/miniBatchSize);选项= trainingOptions(“个”,...InitialLearnRate = learnRate,...MaxEpochs = 80,...MiniBatchSize = MiniBatchSize,...VerboseFrequency = valFrequency,...洗牌=“every-epoch”,...情节=“训练进步”,...Verbose = false,...ValidationData = {XValidation, TValidation},...ValidationFrequency = valFrequency,...LearnRateSchedule =“分段”,...LearnRateDropFactor = 0.1,...LearnRateDropPeriod = 60);
列车网络的
训练网络使用trainNetwork
,设置doTraining
旗帜真正的
.否则,加载一个预先训练好的网络。在一个好的GPU上训练网络需要两个多小时。如果你没有GPU,那么训练需要更长的时间。
doTraining = false;如果doTraining net = trainNetwork(augimdsTrain,lgraph,options);其他的负载(“trainedResidualNetwork.mat”,“净”);结束
评估训练网络
计算网络在训练集(没有数据增强)和验证集上的最终精度。
[YValPred,probs] = category (net,XValidation);validationError = mean(YValPred ~= TValidation);YTrainPred = category (net,XTrain);trainError = mean(YTrainPred ~= TTrain);disp (“训练错误:”+ trainError*100 +“%”)
训练误差:3.462%
disp ("验证错误:"+ validationError*100 +“%”)
验证错误:9.27%
绘制混淆矩阵。通过使用列和行摘要显示每个类的精度和召回率。该网络最常见的混淆猫和狗。
图(单位=“归一化”,Position=[0.2 0.2 0.4 0.4]);cm = confusionchart(TValidation,YValPred);厘米。Title =验证数据混淆矩阵;厘米。ColumnSummary =“column-normalized”;厘米。RowSummary =“row-normalized”;
您可以使用以下代码显示9个测试图像的随机样本以及它们的预测类和这些类的概率。
figure idx = randperm(size(XValidation,4),9);为i = 1:元素个数(idx)次要情节(3 3 i) imshow (XValidation (:,:,:, idx(我)));Prob = num2str(100*max(probs(idx(i),:)),3);(YValPred(idx(i)));标题([predClass +”、“+问题+“%”])结束
参考文献
[1]克里哲夫斯基,亚历克斯。“从微小的图像中学习多层特征。”(2009)。https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
[2]何开明,张翔宇,任少卿,孙健。“用于图像识别的深度剩余学习。”在IEEE计算机视觉和模式识别会议论文集,第770-778页。2016.
另请参阅
resnetLayers
|resnet3dLayers
|trainNetwork
|trainingOptions
|layerGraph
|analyzeNetwork