adamupdate
使用自适应矩估计更新参数(Adam)
语法
描述
使用自适应矩估计(Adam)算法更新自定义训练循环中的网络可学习参数。
请注意
该函数应用Adam优化算法来更新自定义训练循环中的网络参数,该循环使用定义为的网络dlnetwork
对象或模型函数。如果你想训练一个定义为a的网络层
数组或作为LayerGraph
,使用以下函数:
创建一个
TrainingOptionsADAM
对象使用trainingOptions
函数。使用
TrainingOptionsADAM
对象的trainNetwork
函数。
[
更新网络的可学习参数netUpdated
,averageGrad
,averageSqGrad
= adamupdate(网
,研究生
,averageGrad
,averageSqGrad
,迭代
)网
使用亚当算法。在训练循环中使用此语法可迭代更新定义为的网络dlnetwork
对象。
[
中更新可学习参数参数个数
,averageGrad
,averageSqGrad
= adamupdate(参数个数
,研究生
,averageGrad
,averageSqGrad
,迭代
)参数个数
使用亚当算法。在训练循环中使用此语法迭代更新使用函数定义的网络的可学习参数。
[___= adamupdate(___
还指定用于全局学习率、梯度衰减、平方梯度衰减和小常数epsilon的值,以及前面语法中的输入参数。learnRate
,gradDecay
,sqGradDecay
,ε
)
例子
使用更新可学习参数adamupdate
执行单个自适应时刻估计更新步骤,全局学习率为0.05
的梯度衰减因子0.75
的梯度衰减因子平方0.95
.
将参数和参数梯度创建为数值数组。
Params = rand(3,3,4);Grad = ones(3,3,4);
初始化第一次迭代的迭代计数器、平均梯度和平均梯度平方。
迭代= 1;averageGrad = [];averageSqGrad = [];
指定全局学习率、梯度衰减因子和平方梯度衰减因子的自定义值。
learnRate = 0.05;gradDecay = 0.75;sqGradDecay = 0.95;
使用更新可学习参数adamupdate
.
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);
更新迭代计数器。
迭代=迭代+ 1;
列车网络使用adamupdate
使用adamupdate
用亚当算法训练一个网络。
负荷训练数据
加载数字训练数据。
[XTrain,TTrain] = digitTrain4DArrayData;类=类别(TTrain);numClasses = nummel(类);
定义网络
属性定义网络并指定平均图像值的意思是
选项在图像输入层。
图层= [imageInputLayer([28 28 1],“的意思是”reluLayer卷积2dlayer (3,20),“填充”,1) relullayer卷积2dlayer (3,20,“填充”,1) reluLayer fullyConnectedLayer(numClasses) softmaxLayer];
创建一个dlnetwork
对象。
Net = dlnetwork(layers);
定义模型损失函数
创建helper函数modelLoss
,在示例的末尾列出。函数的参数为dlnetwork
对象和带有相应标签的小批输入数据,并返回损失和损失相对于可学习参数的梯度。
指定培训项目
指定在培训期间使用的选项。
miniBatchSize = 128;numEpochs = 20;numObservations = numel(TTrain);numIterationsPerEpoch = floor(numObservations./miniBatchSize);
列车网络的
初始化平均梯度和平均梯度的平方。
averageGrad = [];averageSqGrad = [];
计算训练进度监控器的总迭代次数。
numIterations = nummepochs * numIterationsPerEpoch;
初始化TrainingProgressMonitor
对象。因为计时器在创建监视器对象时开始,所以请确保创建的对象接近训练循环。
monitor = trainingProgressMonitor(指标=“损失”信息=“时代”包含=“迭代”);
使用自定义训练循环训练模型。对于每个纪元,洗牌数据并在小批量数据上循环。方法更新网络参数adamupdate
函数。在每次迭代结束时,显示训练进度。
在GPU上训练(如果有的话)。使用GPU需要并行计算工具箱™和受支持的GPU设备。万博1manbetx有关受支持设备的信息,请参见万博1manbetxGPU计算要求(并行计算工具箱).
迭代= 0;Epoch = 0;而epoch < numEpochs && ~monitor。停止epoch = epoch + 1;% Shuffle数据。idx = randperm(数字(TTrain));XTrain = XTrain(:,:,:,idx);TTrain = TTrain(idx);I = 0;而i < numIterationsPerEpoch && ~monitor。Stop i = i + 1;迭代=迭代+ 1;读取小批数据并将标签转换为虚拟标签%变量。idx = (i-1)*miniBatchSize+1:i*miniBatchSize;X = XTrain(:,:,:,idx);T = 0 (numClasses, miniBatchSize,“单身”);为T(c,TTrain(idx)==classes(c)) = 1;结束将小批数据转换为大数组。X = dlarray(single(X),“SSCB”);如果在GPU上训练,则将数据转换为gpuArray。如果canUseGPU X = gpuArray(X);结束计算模型损失和梯度使用dlfeval和% modelLoss函数。[loss,gradients] = dlfeval(@modelLoss,net,X,T);使用Adam优化器更新网络参数。。[net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad,averageSqGrad,iteration);更新培训进度监视器。recordMetrics(监控、迭代损失=损失);updateInfo(监视、时代=时代+“的”+ numEpochs);班长。进度= 100 * iteration/numIterations;结束结束
测试网络
通过比较测试集上的预测与真实标签来测试模型的分类准确性。
[XTest,TTest] = digitTest4DArrayData;
将数据转换为adlarray
使用维度格式“SSCB”
(空间,空间,通道,批次)。对于GPU预测,也将数据转换为agpuArray
.
XTest = dlarray(XTest,“SSCB”);如果canUseGPU XTest = gpuArray(XTest);结束
对图像进行分类dlnetwork
对象时,使用预测
计算并找出得分最高的课程。
YTest =预测(net,XTest);[~,idx] = max(extractdata(YTest),[],1);YTest = classes(idx);
评估分类准确率。
精度=平均值(YTest==TTest)
准确度= 0.9896
模型损失函数
的modelLoss
Helper函数接受dlnetwork
对象网
和一小批输入数据X
有相应的标签T
,并返回损失以及损失相对于中可学习参数的梯度网
.要自动计算梯度,请使用dlgradient
函数。
函数[loss,gradients] = modelLoss(net,X,T) Y = forward(net,X);损失=交叉熵(Y,T);gradients = dlgradient(loss,net.Learnables);结束
输入参数
网
- - - - - -网络
dlnetwork
对象
网络,指定为dlnetwork
对象。
函数更新可学的
的属性dlnetwork
对象。网可学的
是一个包含三个变量的表:
层
-层名,指定为字符串标量。参数
—参数名称,指定为字符串标量。价值
参数的值,指定为包含dlarray
.
输入参数研究生
一定是和?一样形式的表网可学的
.
参数个数
- - - - - -网络可学习参数
dlarray
|数字数组|单元阵列|结构|表格
网络可学习参数,指定为dlarray
、数字数组、单元格数组、结构体或表。
如果你指定参数个数
作为一个表,它必须包含以下三个变量:
层
-层名,指定为字符串标量。参数
—参数名称,指定为字符串标量。价值
参数的值,指定为包含dlarray
.
你可以指定参数个数
作为使用单元格数组、结构、表或嵌套单元格数组或结构的网络可学习参数的容器。单元格数组、结构或表中的可学习参数必须为dlarray
或数据类型的数值双
或单
.
输入参数研究生
必须提供与?完全相同的数据类型、顺序和字段(用于结构)或变量(用于表)参数个数
.
数据类型:单
|双
|结构体
|表格
|细胞
研究生
- - - - - -损失的梯度
dlarray
|数字数组|单元阵列|结构|表格
损耗的梯度,指定为adlarray
、数字数组、单元格数组、结构体或表。
确切的形式研究生
取决于输入网络或可学习参数。下表显示了所需的格式研究生
可能的输入adamupdate
.
输入 | 可学的参数 | 梯度 |
---|---|---|
网 |
表格网可学的 包含层 ,参数 ,价值 变量。的价值 变量由单元格数组组成,单元格数组包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序网可学的 .研究生 必须有一个价值 由包含每个可学习参数梯度的单元格数组组成的变量。 |
参数个数 |
dlarray |
dlarray 使用相同的数据类型和顺序参数个数 |
数字数组 | 具有相同数据类型和顺序的数值数组参数个数 |
|
单元阵列 | 单元格数组,具有相同的数据类型、结构和顺序参数个数 |
|
结构 | 结构,具有相同的数据类型、字段和排序参数个数 |
|
表层 ,参数 ,价值 变量。的价值 变量必须由单元格数组组成,其中包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序参数个数 .研究生 必须有一个价值 由包含每个可学习参数梯度的单元格数组组成的变量。 |
你可以获得研究生
从电话到dlfeval
对包含调用的函数求值dlgradient
.有关更多信息,请参见在深度学习工具箱中使用自动区分.
averageGrad
- - - - - -参数梯度的移动平均
[]
|dlarray
|数字数组|单元阵列|结构|表格
参数梯度的移动平均,指定为空数组,adlarray
、数字数组、单元格数组、结构体或表。
确切的形式averageGrad
取决于输入网络或可学习参数。下表显示了所需的格式averageGrad
可能的输入adamupdate
.
输入 | 可学的参数 | 平均梯度 |
---|---|---|
网 |
表格网可学的 包含层 ,参数 ,价值 变量。的价值 变量由单元格数组组成,单元格数组包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序网可学的 .averageGrad 必须有一个价值 由包含每个可学习参数的平均梯度的单元格数组组成的变量。 |
参数个数 |
dlarray |
dlarray 使用相同的数据类型和顺序参数个数 |
数字数组 | 具有相同数据类型和顺序的数值数组参数个数 |
|
单元阵列 | 单元格数组,具有相同的数据类型、结构和顺序参数个数 |
|
结构 | 结构,具有相同的数据类型、字段和排序参数个数 |
|
表层 ,参数 ,价值 变量。的价值 变量必须由单元格数组组成,其中包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序参数个数 .averageGrad 必须有一个价值 由包含每个可学习参数的平均梯度的单元格数组组成的变量。 |
如果你指定averageGrad
而且averageSqGrad
作为空数组,函数假设没有之前的渐变,并以与一系列迭代中的第一次更新相同的方式运行。要迭代地更新可学习参数,请使用averageGrad
的前一次调用的输出adamupdate
随着averageGrad
输入。
averageSqGrad
- - - - - -参数梯度平方的移动平均
[]
|dlarray
|数字数组|单元阵列|结构|表格
参数梯度平方的移动平均,指定为空数组,adlarray
、数字数组、单元格数组、结构体或表。
确切的形式averageSqGrad
取决于输入网络或可学习参数。下表显示了所需的格式averageSqGrad
可能的输入adamupdate
.
输入 | 可学的参数 | 平均梯度平方 |
---|---|---|
网 |
表格网可学的 包含层 ,参数 ,价值 变量。的价值 变量由单元格数组组成,单元格数组包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序网可学的 .averageSqGrad 必须有一个价值 由包含每个可学习参数的平均梯度平方的单元格数组组成的变量。 |
参数个数 |
dlarray |
dlarray 使用相同的数据类型和顺序参数个数 |
数字数组 | 具有相同数据类型和顺序的数值数组参数个数 |
|
单元阵列 | 单元格数组,具有相同的数据类型、结构和顺序参数个数 |
|
结构 | 结构,具有相同的数据类型、字段和排序参数个数 |
|
表层 ,参数 ,价值 变量。的价值 变量必须由单元格数组组成,其中包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序方式参数个数 .averageSqGrad 必须有一个价值 由包含每个可学习参数的平均梯度平方的单元格数组组成的变量。 |
如果你指定averageGrad
而且averageSqGrad
作为空数组,函数假设没有之前的渐变,并以与一系列迭代中的第一次更新相同的方式运行。要迭代地更新可学习参数,请使用averageSqGrad
的前一次调用的输出adamupdate
随着averageSqGrad
输入。
迭代
- - - - - -迭代数
正整数
迭代数,指定为正整数。第一次打电话给adamupdate
,使用值1
.你必须增加迭代
通过1
的一系列调用中的每个后续调用adamupdate
.Adam算法使用这个值来校正一组迭代开始时移动平均线中的偏差。
learnRate
- - - - - -全球学习率
0.001
(默认)|积极的标量
全局学习率,指定为正标量。的默认值learnRate
是0.001
.
如果指定网络参数为adlnetwork
,每个参数的学习率为全局学习率乘以网络层中定义的相应学习率因子属性。
gradDecay
- - - - - -梯度衰减因子
0.9
(默认)|之间的正标量0
而且1
梯度衰减因子,指定为之间的正标量0
而且1
.的默认值gradDecay
是0.9
.
sqGradDecay
- - - - - -平方梯度衰减因子
0.999
(默认)|之间的正标量0
而且1
梯度衰减因子的平方,指定为之间的正标量0
而且1
.的默认值sqGradDecay
是0.999
.
ε
- - - - - -小的常数
1 e-8
(默认)|积极的标量
用于防止被零除错误的小常数,指定为正标量。的默认值ε
是1 e-8
.
输出参数
netUpdated
-更新网络
dlnetwork
对象
更新后的网络,返回为dlnetwork
对象。
函数更新可学的
的属性dlnetwork
对象。
参数个数
—更新网络可学习参数
dlarray
|数字数组|单元数组|结构|表
更新网络可学习参数,返回为dlarray
类型的数字数组、单元格数组、结构体或表价值
变量,包含网络更新后的可学习参数。
averageGrad
-更新参数梯度的移动平均
dlarray
|数字数组|单元数组|结构|表
参数梯度的更新移动平均值,返回为dlarray
、数字数组、单元格数组、结构体或表。
averageSqGrad
-更新了参数梯度平方的移动平均值
dlarray
|数字数组|单元数组|结构|表
更新的参数梯度平方的移动平均值,返回为adlarray
、数字数组、单元格数组、结构体或表。
更多关于
亚当
该函数使用自适应矩估计(Adam)算法来更新可学习参数。有关更多信息,请参阅下面的Adam算法的定义随机梯度下降在trainingOptions
参考页面。
扩展功能
GPU数组
通过使用并行计算工具箱™在图形处理单元(GPU)上运行来加速代码。
使用注意事项和限制:
当以下输入参数中至少有一个是
gpuArray
或者一个dlarray
类型的底层数据gpuArray
,该函数运行在GPU上。研究生
averageGrad
averageSqGrad
参数个数
有关更多信息,请参见在图形处理器上运行MATLAB函数(并行计算工具箱).
版本历史
R2019b引入
Abrir比如
Tiene una versión modificada de este ejemplo。¿Desea abrir este ejemplo con sus modificaciones?
MATLAB突击队
Ha hecho clic en unenlace que对应一个este commando de MATLAB:
弹射突击队introduciéndolo en la ventana de commandos de MATLAB。Los navegadores web no permission comandos de MATLAB。
您也可以从以下列表中选择一个网站:
如何获得最佳的网站性能
选择中国站点(中文或英文)以获得最佳站点性能。其他MathWorks国家站点没有针对您所在位置的访问进行优化。