这个例子展示了如何使用转移学习对经过预处理的卷积神经网络SqueezeNet进行再训练,从而对一组新的图像进行分类。尝试这个例子,看看它是多么简单,开始在MATLAB®深入学习。
在深度学习应用中,迁移学习是一种常用的学习方法。你可以把一个预先训练好的网络作为学习新任务的起点。与从零开始使用随机初始化的权值训练网络相比,使用转移学习对网络进行微调通常要快得多,也容易得多。您可以使用较少的训练图像将学到的特性快速转移到新任务中。
在工作区中,提取MathWorks合并数据集。这是一个小数据集,包含MathWorks商品的75个图像,属于5个不同的类(帽,多维数据集,打牌,螺丝刀,火炬)。
解压缩(“MerchData.zip”);
开深网设计师。
deepNetworkDesigner
选择SqueezeNet从列表中预先训练的网络和点击开放.
深度网络设计器显示整个网络的缩小视图。
探索网络情节。要使用鼠标放大Ctrl+滚轮。要平移,可以使用箭头键,或者按住滚轮并拖动鼠标。选择一个层来查看它的属性。取消选中所有层以查看“。”中的网络摘要属性窗格。
将数据加载到深度网络设计器中数据选项卡上,单击导入数据.“导入数据”对话框打开。
在数据源列表中,选择文件夹.点击浏览并选择已提取的商业数据文件夹。
将数据分为70%的培训数据和30%的验证数据。
指定要在训练图像上执行的增强操作。数据扩充有助于防止网络过度拟合和记忆训练图像的准确细节。对于本例,在x轴上应用随机反射,从范围[-90,90]角度随机旋转,从范围[1,2]随机重新缩放。
点击进口将数据导入深度网络设计器。
为了重新训练SqueezeNet对新图像进行分类,需要替换掉网络的最后一个二维卷积层和最后一个分类层。在SqueezeNet中,这些层有名称“conv10”
和“ClassificationLayer_predictions”
,分别。
在设计师窗格,拖动一个新的convolutional2dLayer
到画布上。为了匹配原始的卷积层,设置FilterSize
来1,- 1
.编辑NumFilters
表示新数据中的类数,在本例中,5
.
通过设置set,改变学习速率,使新层的学习速度快于转移层的学习速度WeightLearnRateFactor
和BiasLearnRateFactor
到10。
删除最后一个2d卷积层,然后连接新层。
替换输出层。滚动到的结尾层的图书馆拖拽一个新的classificationLayer
到画布上。删除原来的输出层,然后在原来的地方连接新层。
若要选择培训选项,请选择培训选项卡并单击培训方案.集InitialLearnRate
到一个小的值,以减缓学习的层次转移。在前面的步骤中,您增加了2-D卷积层的学习速率因子,以加快在新最终层中的学习。这种学习速率设置的组合只在新层中产生快速学习,而在其他层中产生较慢的学习。
对于本例,setInitialLearnRate
来0.0001
,ValidationFrequency
来5
,MaxEpochs
来8
.由于有55个观测值,集合MiniBatchSize
将训练数据平均分配,确保每个epoch使用整个数据集。
要使用指定的训练选项来训练网络,单击关闭然后点击火车.
深度网络设计器允许您可视化和监控培训进度。然后,如果需要,您可以编辑培训选项并重新培训网络。
输出培训的结果,就可以了培训选项卡上,选择输出>输出训练网络和结果.深度网络设计者将训练好的网络作为变量导出trainedNetwork_1
并将训练信息作为变量trainInfoStruct_1
.
您还可以生成MATLAB代码,它将重新创建所使用的网络和培训选项。在培训选项卡上,选择导出>生成培训代码.研究MATLAB代码,了解如何以编程方式准备培训数据、创建网络体系结构和培训网络。
利用训练后的网络加载新图像进行分类。
我= imread (“MerchDataTest.jpg”);
调整测试图像的大小以匹配网络输入的大小。
I = imresize(I, [227 227]);
利用训练后的网络对测试图像进行分类。
(YPred,聚合氯化铝)= (trainedNetwork_1, I)进行分类;imshow(I) label = YPred;标题(string(标签)+", "+ num2str(100 *马克斯(聚合氯化铝),3)+“%”);
[1] Krizhevsky, Alex, Ilya Sutskever,和Geoffrey E. Hinton。深度卷积神经网络的ImageNet分类。神经信息处理系统的进展.2012.
[2]BVLC AlexNet模型.https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet
深层网络设计师|squeezenet
|trainNetwork
|trainingOptions