主要内容

Train Network Using Model Function

此示例显示了如何通过使用函数而不是图层图或一个图层图或训练深度学习网络dlnetwork。这advantage of using functions is the flexibility to describe a wide variety of networks. The disadvantage is that you must complete more steps and prepare your data carefully. This example uses images of handwritten digits, with the dual objectives of classifying the digits and determining the angle of each digit from the vertical.

Load Training Data

digitTrain4DArrayDatafunction loads the images, their digit labels, and their angles of rotation from the vertical. CreatearrayDatastore图像,标签和角度的对象,然后使用combine函数制作一个包含所有培训数据的单个数据存储。提取类名称和非差异响应的数量。

[Xtrain,T1Train,T2Train] = DigitTrain4DarrayData;dsxtrain = arraydatastore(xtrain,iterationDimension = 4);dst1tra​​in = arraydatastore(t1tra​​in);dst2train = arraydatastore(t2train);dstrain = combine(dsxtrain,dst1tra​​in,dst2train);classNames =类别(t1tra​​in);numClasses = numel(classNames);numRespons = size(t2train,2);numObservations = numel(t1tra​​in);

从培训数据中查看一些图像。

idx = randperm(numObservations,64); I = imtile(XTrain(:,:,:,idx)); figure imshow(I)

定义深度学习模型

定义以下网络,以预测旋转的标签和角度。

  • A convolution-batchnorm-ReLU block with 16 5-by-5 filters.

  • 两个卷积束块的分支每个分支每个分支,每个分支都有32 3 x 3滤波器

  • 带有32 1 x-1卷积的卷积束块的跳过连接。

  • Combine both branches using addition followed by a ReLU operation

  • 对于回归输出,一个具有完全连接的操作的分支(响应次数)。

  • For classification output, a branch with a fully connected operation of size 10 (the number of classes) and a softmax operation.

定义和初始化模型参数和状态

定义每个操作的参数,并将其包括在结构中。使用格式parameters.OperationName.ParameterNamewhereparametersis the struct, OPEATIONNAME是操作的名称(例如“ Conv1”)和ParameterNameis the name of the parameter (for example, "Weights").

创建一个结构parameters包含模型参数。初始化可学习的权重和偏见初始尺寸and初始Zezeros示例函数分别。初始化批处理的归一化偏移量和比例参数初始ZezerosandinitializeOnes示例函数分别。

要使用批归一化操作进行培训和推理,您还必须管理网络状态。在预测之前,您必须指定从培训数据得出的数据集均值和差异。创建一个结构statecontaining the state parameters. The batch normalization statistics must not bedlarrayobjects. Initialize the batch normalization trained mean and trained variance states using theand那些函数分别。

这initialization example functions are attached to this example as supporting files.

初始化第一个卷积操作“ Conv1”的参数。

过滤= [5 5];numchannels = 1;numfilters = 16;sz = [过滤numchannels numfilters];numout = prod(filtersize) * numFilters;numin = prod(过滤) * numFilters;parameters.conv1.Weaters = initializeGlorot(sz,numout,numin);parameters.conv1.bias = initializezeros([[numFilters 1]);

初始化第一个批归一化操作“ batchnorm1”的参数和状态。

parameters.batchnorm1.Offset = initializeZeros([numFilters 1]); parameters.batchnorm1.Scale = initializeOnes([numFilters 1]); state.batchnorm1.TrainedMean = initializeZeros([numFilters 1]); state.batchnorm1.TrainedVariance = initializeOnes([numFilters 1]);

Initialize the parameters for the second convolution operation, "conv2".

过滤= [3 3];numchannels = 16;numfilters = 32;sz = [过滤numchannels numfilters];numout = prod(filtersize) * numFilters;numin = prod(过滤) * numFilters;parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn); parameters.conv2.Bias = initializeZeros([numFilters 1]);

Initialize the parameters and state for the second batch normalization operation, "batchnorm2".

parameters.batchnorm2.offset = initializezeros([[numFilters 1]);parameters.batchnorm2.scale = initializeOnes([[numFilters 1]);state.batchnorm2.trainedmean = initializezeros([[numFilters 1]);state.batchnorm2.trainedVariance = initializeOnes([[numFilters 1]);

Initialize the parameters for the third convolution operation, "conv3".

过滤= [3 3];numchannels = 32;numfilters = 32;sz = [过滤numchannels numfilters];numout = prod(filtersize) * numFilters;numin = prod(过滤) * numFilters;parameters.conv3.weights = initializeglorot(sz,numout,numin);parameters.conv3.bias = initializezeros([[numFilters 1]);

初始化第三批归一化操作“ batchnorm3”的参数和状态。

parameters.batchnorm3.Offset = initializeZeros([numFilters 1]); parameters.batchnorm3.Scale = initializeOnes([numFilters 1]); state.batchnorm3.TrainedMean = initializeZeros([numFilters 1]); state.batchnorm3.TrainedVariance = initializeOnes([numFilters 1]);

初始化跳过连接中卷积操作的参数“ convskip”。

过滤= [1 1];numchannels = 16;numfilters = 32;sz = [过滤numchannels numfilters];numout = prod(filtersize) * numFilters;numin = prod(过滤) * numFilters;parameters.convskip.Wewights = initializeglorot(sz,numout,numin);parameters.convskip.bias = initializezeros([[numFilters 1]);

Initialize the parameters and state for the batch normalization operation in the skip connection, "batchnormSkip".

parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]); parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]); state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]); state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);

初始化与分类输出“ FC1”相对应的完全连接操作的参数。

sz = [numClasses 6272]; numOut = numClasses; numIn = 6272; parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn); parameters.fc1.Bias = initializeZeros([numClasses 1]);

Initialize the parameters for the fully connected operation corresponding to the regression output, "fc2".

sz = [numRespons 6272];numOut = numRespons;numin = 6272;parameters.fc2.apters = initializeglorot(sz,numout,numin);parameters.fc2.bias = initializezeros([[numRespons 1]);

View the structure of the parameters.

parameters
参数=struct with fields:conv1: [1×1 struct] batchnorm1: [1×1 struct] conv2: [1×1 struct] batchnorm2: [1×1 struct] conv3: [1×1 struct] batchnorm3: [1×1 struct] convSkip: [1×1 struct] batchnormSkip: [1×1 struct] fc1: [1×1 struct] fc2: [1×1 struct]

View the parameters for the "conv1" operation.

parameters.conv1
ans =struct with fields:Weights: [5×5×1×16 dlarray] Bias: [16×1 dlarray]

View the structure of the state parameters.

state
state =struct with fields:batchnorm1:[1×1 struct] batchnorm2:[1×1 struct] batchnorm3:[1×1 struct] batchnormskip:[1×1 struct]

查看“ batchnorm1”操作的状态参数。

state.batchnorm1
ans =struct with fields:TrainedMean: [16×1 dlarray] TrainedVariance: [16×1 dlarray]

定义模型函数

Create the functionmodel,在示例的末尾列出,计算了前面描述的深度学习模型的输出。

这functionmodel采用模型参数parameters,输入数据,旗帜dotrainingwhich specifies whether to model should return outputs for training or prediction, and the network state. The network outputs the predictions for the labels, the predictions for the angles, and the updated network state.

定义模型损失函数

Create the functionmodelLoss, listed at the end of the example, that takes the model parameters, a mini-batch of input data with corresponding targets containing the labels and angles, and returns the loss, the gradients of the loss with respect to the learnable parameters, and the updated network state.

指定培训选项

指定培训选项。训练20个时期,小批量大小为128。

numEpochs = 20; miniBatchSize = 128;

Train Model

Useminibatchqueue处理和管理迷你图像。对于每个迷你批次:

  • 使用自定义迷你批次预处理功能预处理(defined at the end of this example) to one-hot encode the class labels.

  • Format the image data with the dimension labels“ SSCB”(spatial, spatial, channel, batch). By default, theminibatchqueue对象将数据转换为dlarrayobjects with underlying typesingle。请勿在类标签或角度添加格式。

  • 如果有可用的话,请训练GPU。默认情况下,minibatchqueue对象将每个输出转换为gpuArray如果有GPU可用。使用GPU需要并行计算工具箱™和支持的GPU设备。万博1manbetx有关支持设备的信息,请参阅万博1manbetx释放的G万博1manbetxPU支持(Parallel Computing Toolbox)

mbq = minibatchqueue(dstrain,。。。MiniBatchSize=miniBatchSize,。。。MiniBatchFcn=@preprocessMiniBatch,。。。minibatchformat = [“ SSCB”“”“”]);

For each epoch, shuffle the data and loop over mini-batches of data. At the end of each iteration, display the training progress. For each mini-batch:

  • 使用模型丢失和梯度评估dlfevaland themodelLoss功能。

  • Update the network parameters using theAdamupdate功能。

初始化亚当的参数。

trailingavg = [];trailingavgsq = [];

Initialize the training progress plot.

图C =颜色订单;LineLostrain = AnimatedLine(color = c(2,:));Ylim([0 Inf])Xlabel("Iteration") ylabel(“失利”) 网格on

训练模型。

迭代= 0;start = tic;%循环在时期。为了epoch = 1:numEpochs% Shuffle data.洗牌(MBQ)小额批次的循环whilehasdata(MBQ)迭代=迭代 + 1;[X,T1,T2] = Next(MBQ);% Evaluate the model loss, gradients, and state, using dlfeval and the%Modelloss功能。[损失,渐变,状态] = dlfeval(@modelloss,参数,x,t1,t2,state);%使用ADAM优化器更新网络参数。[参数,trailingavg,trailingavgsq] = adamupdate(参数,渐变,,。。。trailingAvg,trailingAvgSq,iteration);%显示培训进度。D = duration(0,0,toc(start),Format=“ HH:MM:SS”); loss = double(loss); addpoints(lineLossTrain,iteration,loss) title("Epoch: "+ epoch +“,经过:”+ string(D)) drawnow结尾结尾

Test Model

通过将测试集上的预测与真实的标签和角度进行比较,测试模型的分类精度。使用一个minibatchqueue对象具有与培训数据相同的设置。

[xtest,t1test,t2test] = digittest4darraydata;dsxtest = arraydatastore(xtest,iterationDimension = 4);dst1test = arraydatastore(t1test);dst2test = arraydatastore(t2test);dstest = combine(dsxtest,dst1test,dst2test);mbqtest = minibatchqueue(dstest,。。。MiniBatchSize=miniBatchSize,。。。MiniBatchFcn=@preprocessMiniBatch,。。。minibatchformat = [“ SSCB”“”“”]);

为了预测验证数据的标签和角度,请循环在微型批次上,然后将模型函数与dotraining选项设置为false。存储预测的类和角度。比较预测的和真实的类和角度,并存储结果。

dotraining = false;classSpredictions = [];Anglespredictions = [];classCorr = [];anglediff = [];%循环在迷你批次上。whilehasdata(mbqtest)% Read mini-batch of data.[X,T1,T2] = next(mbqTest);% Make predictions using the predict function.[y1,y2] =模型(参数,x,dotraining,state);% Determine predicted classes.Y1 = OneHotDecode(Y1,ClassNames,1);classPredictions = [classphictions y1];%皮肤预测角度Y2 = ExtractData(Y2);Anglespredictions = [anglespredictions y2];% Compare predicted and true classesY1Test = onehotdecode(T1,classNames,1); classCorr = [classCorr Y1 == Y1Test];% Compare predicted and true anglesAngleDiffBatch = Y2 -T2;Anglediff = [Anglediff ExtractData(gather(anglediffbatch))];结尾

评估分类精度。

精度=平均值(classCorr)
accuracy = 0.9712

Evaluate the regression accuracy.

anglermse = sqrt(平均(anglediff。^2)))
angleRMSE =single6.7999

View some of the images with their predictions. Display the predicted angles in red and the correct labels in green.

idx = randperm(size(xtest,4),9);数字为了i = 1:9 subplot(3,3,i) I = XTest(:,:,:,idx(i)); imshow(I) holdonsz = size(I,1); offset = sz/2; thetaPred = anglesPredictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--") thetaValidation = T2Test(idx(i)); plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],"g--") holdofflabel = string(classSpredictions(idx(i)));标题(“标签: ”+标签)结尾

Model Function

这functionmodel采用模型参数parameters,输入数据X,旗帜dotrainingwhich specifies whether to model should return outputs for training or prediction, and the network statestate。网络输出标签的预测,角度的预测以及更新的网络状态。

function[Y1,Y2,state] = model(parameters,X,doTraining,state)初始操作%% Convolution - conv1重量= parameters.conv1.Weights;偏见= parameters.conv1.Bias; Y = dlconv(X,weights,bias,Padding="same");% Batch normalization, ReLU - batchnorm1, relu1offset = parameters.batchnorm1.offset;scale = parameters.batchnorm1.scale;训练的量= state.batchnorm1.trainedmean;训练的变量= state.batchnorm1.trainedVariance;ifdotraining [y,训练的训练,训练变化] = batchnorm(y,偏移,比例,训练的卑鄙,训练变量);% Update statestate.batchnorm1.trainedmean =训练的态度;state.batchnorm1.traindingvariance =训练变量;elsey = batchnorm(y,偏移,比例,训练的态度,训练变量);结尾y = relu(y);%主分支运营% Convolution - conv2weights = parameters.conv2.Weights; bias = parameters.conv2.Bias; YnoSkip = dlconv(Y,weights,bias,Padding="same",步幅= 2);% Batch normalization, ReLU - batchnorm2, relu2offset = parameters.batchnorm2.Offset; scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance;ifdoTraining [YnoSkip、trainedMean trainedVariance] =batchnorm(YnoSkip,offset,scale,trainedMean,trainedVariance);% Update statestate.batchnorm2.trainedmean =训练的态度;state.batchnorm2.TraindingVariance =训练的变化;elseynoskip = batchnorm(ynoskip,偏移,比例,训练的卑鄙,训练变量);结尾YnoSkip = relu(YnoSkip);% Convolution - conv3权重=参数。bias = parameters.conv3.bias;ynoskip = dlconv(ynoskip,重量,偏见,填充="same");%批归一化-BatchNorm3offset = parameters.batchnorm3.Offset; scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance;ifdoTraining [YnoSkip、trainedMean trainedVariance] =batchnorm(YnoSkip,offset,scale,trainedMean,trainedVariance);% Update statestate.batchnorm3.trainedmean =训练的卑鄙;state.batchnorm3.traindingvariance =训练变化;elseynoskip = batchnorm(ynoskip,偏移,比例,训练的卑鄙,训练变量);结尾% Skip connection operations% Convolution, batch normalization (Skip connection) - convSkip, batchnormSkipweights = parameters.convSkip.Weights; bias = parameters.convSkip.Bias; YSkip = dlconv(Y,weights,bias,Stride=2); offset = parameters.batchnormSkip.Offset; scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance;ifdotraining[YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);% Update statestate.batchnormskip.trainedmean =训练的态度;state.batchnormskip.traindingvariance =训练的变化;elseyskip = batchnorm(yskip,偏移,比例尺,训练的卑鄙,训练变量);结尾最终操作%%添加,relu-加法,relu4y = yskip + ynoskip;y = relu(y);%完全连接,SoftMax(标签)-FC1,SoftMax权重=参数。bias = parameters.fc1.bias;y1 =完全连接(y,权重,偏见);y1 = softmax(y1);% Fully connect (angles) - fc2weights = parameters.fc2.Weights; bias = parameters.fc2.Bias; Y2 = fullyconnect(Y,weights,bias);结尾

模型损失函数

modelLoss函数,采用模型参数,一个小批次输入数据X具有相应的目标T1andT2containing the labels and angles, respectively, and returns the loss, the gradients of the loss with respect to the learnable parameters, and the updated network state.

function[损失,梯度,状态] = modelloss(参数,x,t1,t2,state)dotraining = true;[y1,y2,state] =模型(参数,x,dotraining,state);LossLabels = CrossentRopy(Y1,T1);损耗= MSE(Y2,T2);损失=损失标签 + 0.1*损失型;渐变= dlgradient(丢失,参数);结尾

迷你批次预处理功能

预处理函数使用以下步骤预处理数据:

  1. Extract the image data from the incoming cell array and concatenate into a numeric array. Concatenating the image data over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.

  2. Extract the label and angle data from the incoming cell arrays and concatenate along the second dimension into a categorical array and a numeric array, respectively.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function[x,t1,t2] = preprocessminibatch(datax,datat1,datat2)%从细胞和连接酸盐中提取图像数据X= cat(4,dataX{:});% Extract label data from cell and concatenatet1 = cat(2,datat1 {:});%从细胞和连接酸盐提取角度数据t2 = cat(2,datat2 {:});%单速编码标签T1 = OneHotencode(T1,1);结尾

也可以看看

||||||||||||

Related Topics