主要内容

深度学习网络训练中的自定义输出

这个例子展示了如何定义在深度学习神经网络训练过程中的每次迭代中运行的输出函数。方法指定输出函数“OutputFcn”的名称-值对参数trainingOptions,然后trainNetwork在训练开始之前调用这些函数一次,在每次训练迭代之后调用一次,在训练结束之后调用一次。每次调用输出函数时,trainNetwork传递一个包含当前迭代次数、损失和精度等信息的结构。您可以使用输出函数来显示或绘制进度信息,或者停止训练。为了尽早停止训练,让输出函数返回真正的.如果任何输出函数返回真正的,然后训练结束trainNetwork返回最新的网络。

方法指定验证数据和验证等待时间,即可在验证集上的损失停止减少时停止训练“ValidationData”“ValidationPatience”的名称-值对参数trainingOptions,分别。验证耐心是在网络训练停止之前,验证集上的损失可以大于或等于之前最小损失的次数。您可以使用输出函数添加额外的停止条件。这个例子展示了如何创建一个输出函数,当验证数据上的分类精度停止提高时停止训练。输出函数在脚本的末尾定义。

加载训练数据,其中包含5000张数字图像。留出1000个图像用于网络验证。

[XTrain,YTrain] = digitTrain4DArrayData;idx = randperm(size(XTrain,4),1000);XValidation = XTrain(:,:,:,idx);XTrain(:,:,:,idx) = [];YValidation = YTrain(idx);YTrain(idx) = [];

构造网络对数字图像数据进行分类。

图层= [imageInputLayer([28 28 1])卷积2dlayer (3,8,“填充”“相同”maxPooling2dLayer(2,“步”2) convolution2dLayer(16日“填充”“相同”maxPooling2dLayer(2,“步”32岁的,2)convolution2dLayer (3“填充”“相同”batchNormalizationLayer reluLayer fullyConnectedLayer(10) softmaxLayer classificationLayer];

指定网络培训选项。若要在训练期间定期验证网络,请指定验证数据。选择“ValidationFrequency”值,以便每个纪元验证一次网络。

若要在验证集上的分类准确率停止提高时停止训练,请指定stopIfAccuracyNotImproving作为输出函数。的第二个输入参数stopIfAccuracyNotImproving在网络训练停止之前,验证集上的精度可以小于或等于先前最高精度的次数。为要训练的最大epoch数选择任意大的值。训练不应该到达最后一个纪元,因为训练会自动停止。

miniBatchSize = 128;validationFrequency = floor(编号(YTrain)/miniBatchSize);选项= trainingOptions(“个”...“InitialLearnRate”, 0.01,...“MaxEpochs”, 100,...“MiniBatchSize”miniBatchSize,...“VerboseFrequency”validationFrequency,...“ValidationData”{XValidation, YValidation},...“ValidationFrequency”validationFrequency,...“阴谋”“训练进步”...“OutputFcn”@(信息)stopIfAccuracyNotImproving(信息,3));

培训网络。当验证精度停止增加时,训练停止。

net = trainNetwork(XTrain,YTrain,图层,选项);
单CPU训练。初始化输入数据规范化。|======================================================================================================================| | 时代| |迭代时间| Mini-batch | |验证Mini-batch | |验证基地学习  | | | | ( hh: mm: ss) | | | | |损失损失精度精度  | |======================================================================================================================| | 1 | 1 | 00:00:06 | | 7.81% 12.70% | 2.7155 | 2.5169 | 0.0100 | | 1 | | 31日00:00:14 |71.88% | 74.50% | 0.8805 | 0.8131 | 0.0100 | | 62 | | 00:00:21 | | 86.72% 88.20% | 0.3876 | 0.4409 | 0.0100 | | 3 | 93 | 00:00:28 | | 95.31% 94.10% | 0.2188 | 0.2539 | 0.0100 | | 124 | | 00:00:34 | | 96.09% 96.90% | 0.1457 | 0.1740 | 0.0100 | | 155 | | 00:00:40 | | 99.22% 97.60% | 0.0988 | 0.1301 | 0.0100 | | 6 | 186 | 00:00:47 | | 99.22% 98.00% | 0.0777 | 0.1133 | 0.0100 | | 217 | | 00:00:54 | | 100.00% 98.20% | 0.0561 | 0.0940 | 0.0100 | | 248 | | 00:01:00 | | 98.00% | 0.0444 100.00%|0.0870 | 0.0100 | | 9 | 279 | 00:01:07 | 100.00% | 98.20% | 0.0343 | 0.0770 | 0.0100 | | 10 | 310 | 00:01:12 | 100.00% | 98.30% | 0.0274 | 0.0677 | 0.0100 | | 11 | 341 | 00:01:18 | 100.00% | 98.60% | 0.0241 | 0.0617 | 0.0100 | | 12 | 372 | 00:01:28 | 100.00% | 98.80% | 0.0213 | 0.0566 | 0.0100 | | 13 | 403 | 00:01:35 | 100.00% | 98.90% | 0.0187 | 0.0533 | 0.0100 | | 14 | 434 | 00:01:40 | 100.00% | 98.80% | 0.0165 | 0.0508 | 0.0100 | | 15 | 465 | 00:01:46 | 100.00% | 98.80% | 0.0142 | 0.0481 | 0.0100 | | 16 | 496 | 00:01:52 | 100.00% | 98.90% | 0.0126 | 0.0456 | 0.0100 | |======================================================================================================================| Training finished: Stopped by OutputFcn.

{

定义输出函数

定义输出函数stopIfAccuracyNotImproving(信息,N),如果验证数据上的最佳分类精度没有提高,则停止网络训练N一行中的网络验证。该判据与使用验证损失的内置停止判据相似,不同之处在于它适用于分类精度而不是损失。

函数stop = stopifaccuracynotimprove (info,N) stop = false;跟踪最佳验证精度及其验证次数。%准确度并没有提高。持续的bestValAccuracy持续的valLag培训开始时清除变量。%如果信息。状态= =“开始”bestValAccuracy = 0;valLag = 0;elseif~ isempty (info.ValidationLoss)%将当前验证精度与目前最佳精度进行比较,%%,或将最佳精度设置为当前精度,或增加%未得到改善的验证次数。如果信息。ValidationAccuracy > bestValAccuracy valLag = 0;bestValAccuracy = info.ValidationAccuracy;其他的valLag = valLag + 1;结束%如果验证延迟至少为N,即验证精度%在至少N次验证中没有改善,则返回true和停止训练。如果valLag >= N stop = true;结束结束结束

另请参阅

|

相关的话题