主要内容

モデル関数を使用したネットワ,クの学習

この例では,層グラフまたはdlnetworkではなく関数を使用して深層学習ネットワ,クを作成し,学習させる方法を説明します。関数の使用には、幅広いネットワークを記述する柔軟性が得られるという利点があります。欠点は、より多くのステップを実行し、データを慎重に準備しなければならないことです。この例では、手書き数字のイメージと、数字を分類して垂直位置からの各数字の角度を判定する双対目的関数を使用します。

学習デ,タの読み込み

関数digitTrain4DArrayDataは。イメージ,ラベル,角度についてarrayDatastoreオブジェクトを作成してから,関数结合を使用して,すべての学習デ,タを含む単一のデ,タストアを作成します。クラス名と,離散でない応答の数を抽出します。

[XTrain,T1Train,T2Train] = digitTrain4DArrayData;dsXTrain = arrayDatastore(XTrain,IterationDimension=4);dsT1Train = arrayDatastore(T1Train);dsT2Train = arrayDatastore(T2Train);dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);classNames =类别(T1Train);numClasses = numel(classNames);numResponses = size(T2Train,2);numObservations = numel(T1Train);

学習デタからの一部のメジを表示します。

idx = randperm(numObservations,64);I = imtile(XTrain(:,:,:,idx));图imshow(我)

深層学習モデルの定義

ラベルと回転角度の両方を予測する次のネットワ,クを定義します。

  • 16個の5 x 5フィルターをもつconvolution-batchnorm-ReLUブロック

  • 各ブロックに32個の3 x 3フィルターがあり,間にReLU演算をもつ,2個のconvolution-batchnormブロックの分岐

  • 32個の1 x 1の畳み込みをもつconvolution-batchnormブロックのあるスキップ接続

  • 加算とそれに続くReLU演算を使用した両方の分岐の組み合わせ

  • 回帰出力用に、サ▪▪ズが1(応答数)▪▪の全結合演算をも▪▪分岐

  • 分類出力用に、サ▪▪ズが10(クラス数)の全結合演算とソフトマックス演算をも▪▪分岐

layerGraph.png

モデルのパラメ,タ,と状態の定義および初期化

各演算のパラメ,タ,を定義して構造体に含めます。parameters.OperationName.ParameterNameの形式を使用します。ここで,参数は構造体,operationNameは演算名(“conv1”など),ParameterNameはパラメ,タ,名(“Weights”など)です。

モデルパラメ,タ,を含む構造体参数を作成します。サンプル関数initializeGlorotおよびinitializeZerosを使用して,学習可能な重みとバ。サンプル関数initializeZerosおよびinitializeOnesを使用して,バッチ正規化オフセットとスケ,ルパラメ,タ,をそれぞれ初期化します。

バッチ正規化演算を使用して学習や推論を実行するには,ネットワ,クの状態も管理しなければなりません。予測の前に,学習デ,タから派生するデ,タセットの平均と分散を指定しなければなりません。状態パラメ,タ,を含む構造体状态を作成します。バッチ正規化の統計値は,dlarrayオブジェクトにしないでください。関数0およびを使用して,バッチ正規化の学習済み平均と学習済み分散の状態をそれぞれ初期化します。

この初期化サンプル関数は,この例にサポ,トファ,ルとして添付されています。

最初の畳み込み演算"conv1"のパラメ,タ,を初期化します。

filterSize = [5 5];numChannels = 1;numFilters = 16;sz = [filterSize numChannels numFilters];numOut = prod(filterSize) * numFilters;numIn = prod(filterSize) * numFilters;parameters.conv1。Weights = initializeGlorot(sz,numOut,numIn);parameters.conv1。偏差= initializeZeros([numFilters 1]);

最初のバッチ正規化演算"batchnorm1"のパラメ,タ,と状態を初期化します。

parameters.batchnorm1。Offset = initializeZeros([numFilters 1]);parameters.batchnorm1。Scale = initializeOnes([numFilters 1]);state.batchnorm1。TrainedMean = initializezero ([numFilters 1]);state.batchnorm1。TrainedVariance = initializeOnes([numFilters 1]);

2番目の畳み込み演算“conv2”のパラメタを初期化します。

filterSize = [3 3];numChannels = 16;numFilters = 32;sz = [filterSize numChannels numFilters];numOut = prod(filterSize) * numFilters;numIn = prod(filterSize) * numFilters;parameters.conv2。Weights = initializeGlorot(sz,numOut,numIn);parameters.conv2。偏差= initializeZeros([numFilters 1]);

2番目のバッチ正規化演算"batchnorm2"のパラメ,タ,と状態を初期化します。

parameters.batchnorm2。Offset = initializeZeros([numFilters 1]);parameters.batchnorm2。Scale = initializeOnes([numFilters 1]);state.batchnorm2。TrainedMean = initializezero ([numFilters 1]);state.batchnorm2。TrainedVariance = initializeOnes([numFilters 1]);

3番目の畳み込み演算"conv3"のパラメタを初期化します。

filterSize = [3 3];numChannels = 32;numFilters = 32;sz = [filterSize numChannels numFilters];numOut = prod(filterSize) * numFilters;numIn = prod(filterSize) * numFilters;parameters.conv3。Weights = initializeGlorot(sz,numOut,numIn);parameters.conv3。偏差= initializeZeros([numFilters 1]);

3番目のバッチ正規化演算"batchnorm3"のパラメタと状態を初期化します。

parameters.batchnorm3。Offset = initializeZeros([numFilters 1]);parameters.batchnorm3。Scale = initializeOnes([numFilters 1]);state.batchnorm3。TrainedMean = initializezero ([numFilters 1]);state.batchnorm3。TrainedVariance = initializeOnes([numFilters 1]);

スキップ接続“convSkip”における畳み込み演算のパラメ,タ,を初期化します。

filterSize = [1 1];numChannels = 16;numFilters = 32;sz = [filterSize numChannels numFilters];numOut = prod(filterSize) * numFilters;numIn = prod(filterSize) * numFilters;parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn);parameters.convSkip.Bias = initializeZeros([numFilters 1]);

スキップ接続“batchnormSkipにおけるバッチ正規化演算のパラメーターと状態を初期化します。

parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]);parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]);state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]);state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);

分類出力“fc1”に対応する全結合演算のパラメ,タ,を初期化します。

sz = [numClasses 6272];numOut = numClasses;numIn = 6272;parameters.fc1。Weights = initializeGlorot(sz,numOut,numIn);parameters.fc1。偏差= initializeZeros([numClasses 1]);

回帰出力“fc2”に対応する全結合演算のパラメ,タ,を初期化します。

sz = [numResponses 6272];numOut = numResponses;numIn = 6272;parameters.fc2。Weights = initializeGlorot(sz,numOut,numIn);parameters.fc2。偏差= initializeZeros([numResponses 1]);

パラメ,タ,の構造を表示します。

参数
参数=带字段的结构:conv1:(1×1结构)batchnorm1:[1×1 struct] conv2:[1×1 struct] batchnorm2:[1×1 struct] conv3:[1×1 struct] batchnorm3:[1×1 struct] convSkip:[1×1 struct] batchnormSkip:[1×1 struct] fc1:[1×1 struct] fc2:[1×1 struct]

演算"conv1"のパラメ,タ,を表示します。

parameters.conv1
ans =带字段的结构:权重:[5×5×1×16 dlarray]偏差:[16×1 dlarray]

状態パラメ,タ,の構造を表示します。

状态
状态=带字段的结构:batchnorm1: [1×1 struct] batchnorm2: [1×1 struct] batchnorm3: [1×1 struct] batchnormSkip: [1×1 struct]

演算"batchnorm1"の状態パラメ,タ,を表示します。

state.batchnorm1
ans =带字段的结构:TrainedMean: [16×1 dlarray] TrainedVariance: [16×1 dlarray]

モデルの関数の定義

この例の最後にリストされている,前に説明した深層学習モデルの出力を計算する関数模型を作成します。

関数模型は,モデルパラメ,タ,参数,入力デ,タ,モデルが学習と予測のど,らの出力を返すべきかを指定するフラグdoTraining,およびネットワ,クの状態を受け取ります。ネットワ,クはラベルの予測,角度の予測,および更新されたネットワ,クの状態を出力します。

モデル損失関数の定義

例の最後にリストされている関数modelLossを作成します。この関数は,モデルパラメーター,ならびに入力データのミニバッチとそれに対応するターゲット(ラベルと角度を含む)を受け取り,学習可能なパラメーターについての損失と損失の勾配,および更新されたネットワークの状態を返します。

学習オプションの指定

学習オプションを指定します。ミニバッチサesc escズを128として20エポック学習させます。

numEpochs = 20;miniBatchSize = 128;

モデルの学習

minibatchqueueを使用して,。各ミニバッチで次を行います。

  • カスタムミニバッチ前処理関数preprocessMiniBatch(この例の最後に定義)を使用して,クラスラベルをone-hot符号化します。

  • イメージデータを次元ラベル“SSCB”(空间,空间,通道,批次)で書式設定します。既定では,minibatchqueueオブジェクトは,基となる型がdlarrayオブジェクトにデ,タを変換します。書式をクラスラベルまたは角度に追加しないでください。

  • Gpuが利用できる場合,Gpuで学習を行います。既定では,minibatchqueueオブジェクトは,gpuが利用可能な場合,各出力をgpuArrayに変換します。GPUを使用するには,并行计算工具箱™とサポートされているGPUデバイスが必要です。サポトされているデバスにいては,Gpu計算の要件(并行计算工具箱)を参照してください。

mbq = minibatchqueue(dsTrain,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatch,...MiniBatchFormat = [“SSCB”""""]);

各エポックにいて,デタをシャッフルしてデタのミニバッチをルプで回します。反復が終了するたびに,学習の進行状況を表示します。各ミニバッチで次を行います。

  • 関数dlfevalおよびmodelLossを使用してモデルの損失と勾配を評価。

  • 関数adamupdateを使用してネットワ,クパラメ,タ,を更新。

亚当用にパラメ,タ,を初期化します。

trailingAvg = [];trailingAvgSq = [];

学習の進行状況プロットを初期化します。

图C = colororder;lineLossTrain = animatedline(Color=C(2,:));Ylim ([0 inf]) xlabel(“迭代”) ylabel (“损失”网格)

モデルに学習させます。

迭代= 0;开始= tic;%遍历epoch。epoch = 1:numEpochs% Shuffle数据。洗牌(兆贝可)在小批上循环Hasdata (mbq)迭代=迭代+ 1;[X,T1,T2] = next(mbq);计算模型损失、梯度和状态,使用dlfeval和% modelLoss函数。[loss,gradients,state] = dlfeval(@modelLoss,parameters,X,T1,T2,state);使用Adam优化器更新网络参数。。[parameters,trailingAvg,trailingAvgSq] = adamupdate(参数,梯度,...trailingAvg trailingAvgSq,迭代);%显示培训进度。D = duration(0,0,toc(start),Format=“hh: mm: ss”);损失=双倍(损失);addpoints (lineLossTrain、迭代、失去)标题(”时代:“+ epoch +,消失:"+字符串(D))现在绘制结束结束

モデルのテスト

真のラベルと角度をもテストセットで予測を比較して,モデルの分類精度をテストします。学習デ,タと同じ設定のminibatchqueueオブジェクトを使用して,テストデ,タセットを管理します。

[XTest,T1Test,T2Test] = digitTest4DArrayData;dsXTest = arrayDatastore(XTest,IterationDimension=4);dsT1Test = arrayDatastore(T1Test);dsT2Test = arrayDatastore(T2Test);dsTest = combine(dsXTest,dsT1Test,dsT2Test);mbqTest = minibatchqueue(dsTest...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatch,...MiniBatchFormat = [“SSCB”""""]);

検証デ,タのラベルと角度を予測するために,ミニバッチをル,プ処理し,doTrainingオプションをに設定したモデル関数を使用します。予測されたクラスと角度を保存します。予測されたクラスおよび角度を真のクラスおよび角度と比較し,その結果を保存します。

doTraining = false;classesforecasts = [];anglesforecasts = [];classCorr = [];angleDiff = [];在小批上循环。hasdata (mbqTest)读取小批数据。[X,T1,T2] = next(mbqTest);使用预测函数进行预测。[Y1,Y2] = model(parameters,X,doTraining,state);确定预测的类。Y1 = onehotdecode(Y1,classNames,1);classesforecasts = [classesforecasts Y1];% Dermine预测角度Y2 = extractdata(Y2);anglesforecasts = [anglesforecasts Y2];比较预测的和真实的类。Y1Test = onehotdecode(T1,classNames,1);classCorr = [classCorr Y1 == Y1Test];比较预测角度和真实角度。angleDiffBatch = Y2 - T2;angleDiff = [angleDiff extractdata(gather(angleDiffBatch))];结束

分類精度を評価します。

精确度=平均值(classCorr)
准确度= 0.9712

回帰精度を評価します。

angleRMSE =√(mean(angleDiff.^2))
angleRMSE =6.7999

一部の。予測角度を赤,正解ラベルを緑で表示します。

idx = randperm(size(XTest,4),9);数字i = 1:9 subplot(3,3,i) i = XTest(:,:,:,idx(i));imshow (I)sz = size(I,1);Offset = sz/2;thetaPred =角预测(idx(i));plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],“r——”) thetaValidation = T2Test(idx(i));plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],,“g——”)举行label = string(classesforecasts (idx(i)));标题(”的标签:“+标签)结束

モデル関数

関数模型は,モデルパラメ,タ,参数,入力デ,タX,モデルが学習と予測のどらの出力を返すべきかを指定するフラグdoTraining,およびネットワ,クの状態状态を受け取ります。ネットワ,クはラベルの予測,角度の予測,および更新されたネットワ,クの状態を出力します。

layerGraph.png

函数[Y1,Y2,state] = model(parameters,X,doTraining,state)%初始操作%卷积- conv1weights = parameters.conv1.Weights;bias = parameters.conv1.Bias;Y = dlconv(X,权重,偏差,填充=“相同”);%批规范化,ReLU - batchnorm1, relu1offset = parameters.batchnorm1.Offset;scale = parameters.batchnorm1.Scale;trainedMean = state.batchnorm1.TrainedMean;trainedVariance = state.batchnorm1.TrainedVariance;如果doTraining [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm1。受过训练的人;state.batchnorm1。trained方差= trained方差;其他的Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);结束Y = relu(Y);%主要分支机构操作%卷积- conv2weights = parameters.conv2.Weights;bias = parameters.conv2.Bias;YnoSkip = dlconv(Y,权重,偏差,填充=“相同”,步= 2);%批归一化,ReLU - batchnorm2, relu2offset = parameters.batchnorm2.Offset;scale = parameters.batchnorm2.Scale;trainedMean = state.batchnorm2.TrainedMean;trainedVariance = state.batchnorm2.TrainedVariance;如果doTraining [noskip,trainedMean,trainedVariance] = batchnorm(noskip,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm2。受过训练的人;state.batchnorm2。trained方差= trained方差;其他的noskip = batchnorm(noskip,offset,scale,trainedMean,trainedVariance);结束noskip = relu(noskip);%卷积- conv3weights = parameters.conv3.Weights;bias = parameters.conv3.Bias;YnoSkip = dlconv(YnoSkip,权重,偏差,填充=“相同”);批处理归一化- batchnorm3offset = parameters.batchnorm3.Offset;scale = parameters.batchnorm3.Scale;trainedMean = state.batchnorm3.TrainedMean;trainedVariance = state.batchnorm3.TrainedVariance;如果doTraining [noskip,trainedMean,trainedVariance] = batchnorm(noskip,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm3。受过训练的人;state.batchnorm3。trained方差= trained方差;其他的noskip = batchnorm(noskip,offset,scale,trainedMean,trainedVariance);结束%跳过连接操作%卷积,批归一化(跳过连接)- convSkip, batchnormSkipweights = parameters.convSkip.Weights;bias = parameters.convSkip.Bias;YSkip = dlconv(Y,权重,偏差,Stride=2);offset = parameters.batchnormSkip.Offset;scale = parameters.batchnormSkip.Scale;trainedMean = state.batchnormSkip.TrainedMean;trainedVariance = state.batchnormSkip.TrainedVariance;如果doTraining [YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnormSkip.TrainedMean = trainedMean;state.batchnormSkip.TrainedVariance = trainedVariance;其他的YSkip = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);结束最终操作%加法,ReLU -加法,relu4Y = YSkip + noskip;Y = relu(Y);完全连接,softmax(标签)- fc1, softmaxweights = parameters.fc1.Weights;bias = parameters.fc1.Bias;Y1 =完全连接(Y,权重,偏差);Y1 = softmax(Y1);%完全连接(角度)- fc2weights = parameters.fc2.Weights;bias = parameters.fc2.Bias;Y2 =完全连接(Y,权重,偏差);结束

モデル損失関数

関数modelLossは,モデルパラメ,タ,入力デ,タのミニバッチXとそれに対応するタ,ゲットT1およびT2(それぞれラベルと角度を含む)を受け取り,学習可能なパラメーターについての損失と損失の勾配,および更新されたネットワークの状態を返します。

函数[loss,gradients,state] = modelLoss(parameters,X,T1,T2,state) doTraining = true;[Y1,Y2,state] = model(parameters,X,doTraining,state);lossLabels = crossentropy(Y1,T1);lossAngles = mse(Y2,T2);loss = lossLabels + 0.1*lossAngles;Gradients = dlgradient(损失,参数);结束

ミニバッチ前処理関数

関数preprocessMiniBatchは,次の手順でデ,タを前処理します。

  1. 入力cell配列からメジデタを抽出して数値配列に連結します。4番目の次元でメジデタを連結することにより;3番目の次元が各メジに追加されます。この次元は,シングルトンチャネル次元として使用されます。

  2. 入力电池配列からラベルと角度データを抽出して,それを2番目の次元と共に,绝对配列および数値配列にそれぞれ連結します。

  3. カテゴリカルラベルを数値配列に一热符号化します。最初の次元への符号化は,ネットワ,ク出力の形状と一致する符号化された配列を生成します。

函数[X,T1,T2] = preprocessMiniBatch(dataX,dataT1,dataT2)从单元格和拼接中提取图像数据X = cat(4,dataX{:});从单元格和级联中提取标签数据T1 = cat(2,dataT1{:});从单元格和拼接中提取角度数据T2 = cat(2,dataT2{:});单热编码标签T1 = onehotencode(T1,1);结束

参考

||||||||||||

関連するトピック