主要内容

从检查点网络恢复培训

这个示例演示了如何在训练深度学习网络时保存检查点网络,并从先前保存的网络中恢复训练。

加载示例数据

以4-D数组的形式加载示例数据。digitTrain4DArrayData将数字训练集加载为4-D阵列数据。XTrain是一个28 × 28 × 1 × 5000的数组,其中28是图像的高度,28是图像的宽度。1为通道数,5000为手写数字合成图像数。YTrain是包含每个观察的标签的分类向量。

[XTrain, YTrain] = digitTrain4DArrayData;大小(XTrain)
ans =1×428 28 1 5000

显示一些图像XTrain

图;烫= randperm(大小(XTrain, 4), 20);I = 1:20 subplot(4,5, I);imshow (XTrain(:,:,:,烫(我)));结束

定义网络体系结构

定义神经网络结构。

layers = [imageInputLayer([28 28 1])]“填充”“相同”maxPooling2dLayer(2,“步”2) convolution2dLayer(16日“填充”“相同”maxPooling2dLayer(2,“步”32岁的,2)convolution2dLayer (3“填充”“相同”) batchNormalizationLayer relullayer averageepooling2dlayer (7) fulllyconnectedlayer (10) softmaxLayer classificationLayer];

指定培训方案和培训网络

指定带动量随机梯度下降(SGDM)的训练选项,并指定保存检查点网络的路径。

checkpointPath = pwd;选择= trainingOptions (“个”...“InitialLearnRate”, 0.1,...“MaxEpochs”, 20岁,...“详细”假的,...“阴谋”“训练进步”...“洗牌”“every-epoch”...“CheckpointPath”, checkpointPath);

培训网络。trainNetwork如果有可用的GPU,则使用GPU。如果没有可用的GPU,则使用CPU。trainNetwork每个纪元保存一个检查点网络,并自动为检查点文件分配唯一的名称。

net1 = trainNetwork (XTrain、YTrain层,选择);

负载检查点网络和恢复训练

假设训练被中断,没有完成。您可以加载最后一个检查点网络并从该点恢复训练,而不是从头开始重新开始训练。trainNetwork用表单上的文件名保存检查点文件net_checkpoint__195__2018_07_13__11_59_10.mat,其中195为迭代数,2018年_07_13是日期,和11 _59_10是时候trainNetwork保存网络。检查点网络具有变量名

将检查点网络加载到工作区中。

负载(“net_checkpoint__195__2018_07_13__11_59_10.mat”“净”

指定训练选项并减少最大纪元数。您还可以调整其他培训选项,如初始学习率。

选择= trainingOptions (“个”...“InitialLearnRate”, 0.1,...“MaxEpochs”15岁的...“详细”假的,...“阴谋”“训练进步”...“洗牌”“every-epoch”...“CheckpointPath”, checkpointPath);

使用加载了新的训练选项的检查点网络层继续训练。如果检查点网络是DAG网络,则使用layerGraph(净)作为论据而不是网。层

net2 = trainNetwork (XTrain YTrain、net.Layers选项);

另请参阅

|

相关的例子

更多关于