这个例子展示了如何使用数据集来找出是什么激活了深层神经网络的通道。这使您能够理解神经网络是如何工作的,并使用训练数据集诊断潜在问题。
这个例子介绍了许多简单的可视化技术,使用GoogLeNet在食物数据集上学习的转换。通过查看最大限度或最小限度地激活分类器的图像,您可以发现神经网络分类错误的原因。
将图像加载为图像数据存储。这个小数据集包含了9类食物的978个观察结果。
将这些数据分解为训练、验证和测试集,为使用GoogLeNet进行迁移学习做准备。显示从数据集中选择的图像。
rng默认的dataDir = fullfile (tempdir,“食品数据集”);url ="//www.tianjin-qmedu.com/万博1manbetxsupportfiles/nnet/data/ExampleFoodImageDataset.zip";如果~存在(dataDir“dir”)MKDIR(DATADIR);结束DownloadexampleFoodImimagesData(URL,Datadir);
下载MathWorks示例食物图像数据集…这可能需要几分钟下载…下载完成了…将文件解压缩……解压缩完成……完成了。
imds = imagedataStore(Datadir,...“IncludeSubfolders”,真的,“LabelSource”,“foldernames”);[imdsTrain, imdsValidation imdsTest] = splitEachLabel (imd, 0.6, 0.2);rnd = randperm(元素个数(imds.Files), 9);为i = 1:numel(rnd)子图(3,3,i)imshow(imread(imds.files {rnd(i)}))标签= imds.labels(RND(i));标题(标签,“翻译”,“没有”)结束
使用预先训练过的GoogLeNet网络,并再次训练它来分类9种食物。如果你没有深度学习工具箱™模型GoogLeNet网络万博1manbetx支持包安装,然后软件提供下载链接。
为了尝试一个不同的预训练网络,在MATLAB®中打开这个例子,选择一个不同的网络,例如squeezenet
,这个网络甚至比googlenet
.有关所有可用网络的列表,请参见预训练深度神经网络.
网= googlenet;
第一个元素是层
网络的属性是图像的输入层。这一层需要输入尺寸为224 × 224 × 3的图像,其中3是彩色通道的数量。
inputSize = net.Layers (1) .InputSize;
网络的卷积层提取图像特征,最后一个可学习层和最后一个分类层用来对输入图像进行分类。这两个层,“loss3-classifier”
和“输出”
在GoogLeNet中,包含关于如何将网络提取的特征组合成类别概率、损失值和预测标签的信息。为了训练一个预先训练的网络来分类新的图像,将这两层替换为适应新数据集的新层。
从训练好的网络中提取层图。
Lgraph = LayerGraph(网);
在大多数网络中,具有可读权重的最后一层是完全连接的图层。使用新的完全连接的图层替换此完全连接的图层,其中输出数等于新数据集中的类(9,本示例中的9)。
numClasses =元素个数(类别(imdsTrain.Labels));newfclayer = fullyConnectedLayer (numClasses,...“名字”,“new_fc”,...“WeightLearnRateFactor”10...“BiasLearnRateFactor”10);lgraph = replaceLayer (lgraph net.Layers (end-2) . name, newfclayer);
分类层指定网络的输出类。用没有类标签的新一个替换分类层。列车网络
在训练时自动设置层的输出类。
newclasslayer=classificationLayer(“名字”,“new_classoutput”);lgraph = replaceLayer (lgraph net.Layers(结束). name, newclasslayer);
网络需要大小为224 × 224 × 3的输入图像,但图像数据存储中的图像大小不同。使用扩充图像数据存储来自动调整训练图像的大小。指定要对训练图像执行的附加增强操作:沿着垂直轴随机翻转训练图像,随机将其平移至30像素,并将其水平和垂直缩放至10%。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。
pixelRange = [-30 30];scaleRange = [0.9 1.1];imageAugmenter = imageDataAugmenter (...'randxreflection',真的,...“RandXTranslation”pixelRange,...“兰迪翻译”pixelRange,...“RandXScale”scaleRange,...“RandYScale”,scalerage);augimdsTrain=增强的图像数据存储(inputSize(1:2),imdsTrain,...“DataAugmentation”图像增强仪);
若要自动调整验证图像的大小而不执行进一步的数据扩展,请使用扩展后的图像数据存储,而不指定任何额外的预处理操作。
augimdsValidation = augmentedImageDatastore (inputSize (1:2), imdsValidation);
指定培训选项。集初始学习率
在尚未冷冻的转移层中减慢学习的小值。在上一步中,您增加了最后一次学习层的学习率因子,以加速新的最终层的学习。这种学习率设置的组合导致新层次的快速学习,中间层的学习速度较慢,较早的冰冻层。
指定要训练的纪元数。当进行迁移学习时,你不需要训练那么多的纪元。epoch是整个训练数据集上的一个完整的训练周期。指定小批量大小和验证数据。每个epoch计算一次验证精度。
miniBatchSize=10;valFrequency=floor(numel(augimdsTrain.Files)/miniBatchSize);options=trainingOptions(“个”,...“MiniBatchSize”miniBatchSize,...“MaxEpochs”,4,...“初始学习率”,3e-4,...“洗牌”,“every-epoch”,...“验证数据”augimdsValidation,...“ValidationFrequency”,valFrequency,...“详细”错误的...“阴谋”,“训练进步”);
使用训练数据对网络进行训练。默认情况下,列车网络
如果GPU可用,则使用GPU。这需要并行计算工具箱™和支持的GPU设备。万博1manbetx有关支持的设备的信息,请参见万博1manbetxGPU支万博1manbetx持情况(并行计算工具箱)否则列车网络
使用CPU。您还可以使用使用指定执行环境“ExecutionEnvironment”
的名称-值对参数培训选项
.由于这个数据集很小,训练速度很快。如果您运行这个示例并亲自训练网络,您将得到不同的结果和由于训练过程中涉及的随机性造成的错误分类。
net = trainnetwork(augimdstrain,3.1页,选项);
利用微调网络对测试图像进行分类,并计算分类精度。
AugimdTest=增强的图像数据存储(inputSize(1:2),ImdTest);[PredictedClass,predictedScores]=分类(net,AugimdTest);精度=平均值(PredictedClass==imdTest.Labels)
精度= 0.8418
绘制一个测试集预测的混淆矩阵。这突出显示了哪些特定的类对网络造成了最大的问题。
图;confusionchart (imdsTest。标签,predictedClasses,“归一化”,“row-normalized”);
混乱矩阵表明,网络对某些课程的表现不佳,例如希腊沙拉,生鱼片,热狗和寿司。这些类在数据集中经常代表,这可能正在影响网络性能。调查其中一个课程,以更好地理解为什么网络正在努力。
图();直方图(imdsValidation.Labels);甘氨胆酸ax = ();ax.XAxis.TickLabelInterpreter =“没有”;
调查寿司类的网络分类。
首先,找出哪些寿司图片最强烈地激活了寿司课的网络。这就回答了“该网站认为哪些图片最像寿司?”
绘制最大激活的图像,这些是强烈激活全连接层的“寿司”神经元的输入图像。这张图显示了排名靠前的4张图片,以降序排列。
chosenClass =“寿司”;classIdx=find(net.Layers(end).Classes==chosenClass);numImgsToShow=4;[sortedScores,imgIdx]=findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);地物绘图图像(imdsTest,imgIdx,sortedScores,predictedClasses,numimmgstoshow)
对于寿司来说,网络是正确的选择吗?网络中最活跃的寿司类图片看起来都很相似——许多圆形紧密地聚集在一起。
该网站在分类这类寿司方面做得很好。然而,为了验证这一点并更好地理解为什么网络会做出它的决定,可以使用像gradc - cam这样的可视化技术。有关使用gradcam的更多信息,请参见gradcam揭示了Deep Learning决策背后的原因.
从增强图像数据存储中读取第一个调整大小的图像,然后使用gradCAM
.
imageNumber=1;观察值=augimdtest.readByIndex(imgIdx(imageNumber));img=观察。输入{1};label=predictedClasses(imgIdx(imageNumber));分数=分类的核心(图像编号);gradcamMap=gradCAM(网络、img、标签);图α=0.5;plotGradCAM(img、gradcamMap、alpha);sgtitle(字符串(标签)+”(分数:+ max(得分)+“)”)
Grad-CAM地图确认网络在图像中的寿司上集中。但是,您还可以看到网络正在查看板块和桌子的部分。
第二张图的左边是一簇寿司,右边是一个单独的寿司。要查看网络关注的内容,请阅读第二张图并绘制Grad CAM。
imageNumber=2;observation=AugimdTest.readByIndex(imgIdx(imageNumber));img=observation.input{1};label=predictedClasses(imgIdx(imageNumber));score=sortedScores(imageNumber);gradcamMap=gradCAM(net,img,label);图形绘图gradCAM(img,gradcamMap,alpha);sgtitle(字符串(label)+”(分数:+ max(得分)+“)”)
该网站将这张图片归类为寿司,因为它看到了一组寿司。然而,它是否能够单独对一种寿司进行分类呢?看一张寿司的图片来测试这一点。
img = imread(strcat(tempdir,“食品数据集/寿司/ sushi_18.jpg”));img = imresize (img net.Layers (1) .InputSize (1:2),“方法”,“双线性”,“抗锯齿”,真正的);(标签,分数)=(净,img)进行分类;gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);sgtitle(字符串(标签)+”(分数:+ max(得分)+“)”)
该网络能够正确地对这一种寿司进行分类。然而,GradCAM显示出网络集中在寿司顶部和黄瓜簇上,而不是整块。
在孤独的寿司上运行Grad-Cam可视化技术,该技术不含任何堆积的小块成分。
img = imread (“crop__sushi34-copy.jpg”);img = imresize (img net.Layers (1) .InputSize (1:2),“方法”,“双线性”,“抗锯齿”,真正的);(标签,分数)=(净,img)进行分类;gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+”(分数:+ max(得分)+“)”)
在这种情况下,可视化技术突出了网络表现不佳的原因。它错误地将寿司图像分类为汉堡包。
为了解决这个问题,您必须在训练过程中向网络提供更多的单个寿司的图像。
现在找出哪些寿司图片对寿司课的网络激活最少。这就回答了“网络认为哪些图片不太像寿司?”
这是有用的,因为它可以找到网络性能较差的图像,并提供一些对其决策的了解。
chosenClass =“寿司”;numImgsToShow = 9;[sortedScores, imgIdx] = findMinActivatingImages (imdsTest、chosenClass predictedScores, numImgsToShow);图plotImages (imdsTest imgIdx、sortedScores predictedClasses, numImgsToShow)
为什么网络将寿司归类为生鱼片?该网站将9张图片中的3张分类为生鱼片。其中一些图像,例如图像4和9,实际上包含了生鱼片,这意味着网络实际上并没有错误地对它们进行分类。这些图片被标错了。
要查看网络的焦点是什么,在其中一张图像上运行grado - cam技术。
imageNumber = 4;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:)+ max(得分)+“)”)
正如所料,该网络关注的是生鱼片而不是寿司。
为什么这个网络把寿司归类为披萨?该网站将其中四张图片分类为披萨而不是寿司。以图像1为例,该图像有一个彩色的顶点,这可能会混淆网络。
要查看网络正在查看的图像的哪一部分,请在其中一个图像上运行Grad CAM技术。
imageNumber = 1;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:)+ max(得分)+“)”)
这家电视台非常关注浇头。为了帮助网络区分披萨和带配料的寿司,添加更多带配料的寿司训练图像。网络也聚焦于板块。这可能是因为网络已经学会了将特定的食物与特定类型的盘子联系起来,在看寿司图片时也突出显示了这一点。为了提高网络的表现,训练时使用更多不同类型盘子上的食物例子。
为什么这个网站把寿司归类为汉堡?要查看网络专注的内容,请在错误分类的图像上运行渐变凸轮技术。
imageNumber = 2;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:)+ max(得分)+“)”)
网络聚焦在图像中的花朵上。紫色的花和棕色的茎混淆了网络,把这个图像识别为汉堡。
为什么这个网站把寿司归类为炸薯条?该电视台将第三张图片分类为炸薯条而不是寿司。这种寿司的顶部是黄色的,网络可能会把这种颜色和薯条联系起来。
在该图像上运行Grad CAM。
imageNumber = 3;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:)+ max(得分)+“)”,“翻译”,“没有”)
网络关注黄色寿司,将其归类为薯条。与汉堡一样,这种不寻常的颜色导致网络对寿司进行错误分类。
为了帮助网络在这个特定的情况下,用更多黄色食物的图像训练它,而不是炸薯条。
调查导致大或小等级分数的数据点,以及网络自信但不正确分类的数据点,是一种简单的技术,可以提供有用的洞察力,了解经过培训的网络是如何运作的。在食品数据集的情况下,本例强调:
测试数据包含了一些带有错误真实标签的图像,比如“生鱼片”实际上是“寿司”。数据还包含不完整的标签,例如同时包含寿司和生鱼片的图像。
该网络认为“寿司”成为“多元,聚集,圆形的东西”。但是,它必须能够区分孤独的寿司。
任何带有配料或不寻常颜色的寿司或生鱼片都会混淆网络。要解决这个问题,数据中必须包含更多种类的寿司和生鱼片。
为了提高性能,网络需要看到更多来自未充分表示的类的图像。
作用下载ExampleFoodImagesData(url,dataDir)%下载Example Food Image数据集,包含978张图片%不同种类的食物分为9类。%版权所有2019 Mathworks,Inc。文件名=“ExampleFoodImageDataset.zip”;fileFullPath = fullfile (dataDir,文件名);%下载。zip文件到临时目录。如果~exist(fileFullPath,“文件”)fprintf("下载MathWorks示例食物图像数据集…\n");流(“这可能需要几分钟下载…\n”); websave(fileFullPath,url);fprintf(“下载完成…\n”);其他的流("跳过下载,文件已经存在…\n");结束%解压文件。%通过检查文件是否存在来检查文件是否已经解压缩一个类目录的%。exampleFolderFullPath = fullfile (dataDir,“披萨”);如果~存在(exampleFolderFullPath“dir”)fprintf(“正在解压缩文件…\n”);解压缩(fileFullPath,dataDir);fprintf(“解压缩完成...... \ n”);其他的流("跳过解压缩,文件已经解压缩…\n");结束流(“完成。\ n”);结束作用[sortedScores, imgIdx] = findMaxActivatingImages (imd,类名,predictedScores numImgsToShow)%在所选班级的所有图片上查找所选班级的预测分数%(例如,寿司在所有寿司图像上的预测得分)[scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass (imd,类名,predictedScores);%按降序排列分数[sortedScores, idx] =排序(scoresForChosenClass,“下”);%只返回前几个索引imgidx = Imgsofclassidxs(IDX(1:Numimgstoshow));结束作用[sortedScores, imgIdx] = findMinActivatingImages (imd,类名,predictedScores numImgsToShow)%在所选班级的所有图片上查找所选班级的预测分数%(例如,寿司在所有寿司图像上的预测得分)[scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass (imd,类名,predictedScores);%将分数按升序排序[sortedScores, idx] =排序(scoresForChosenClass,“提升”);%只返回前几个索引imgidx = Imgsofclassidxs(IDX(1:Numimgstoshow));结束作用[scoresforchosenclass,imgsofclassidxs] = findscoresforchosenclass(IMDS,ClassName,PrediceScores)%查找类名的索引(例如,“sushi”是第9个类)uniqueClasses =独特(imds.Labels);select * from table_name (table_name) = table_name (table_name);%在imageDatastore中查找标签为“className”的图像的索引%(例如查找阶级寿司的所有图像)imgsOfClassIdxs =找到(imd)。标签= =类名);在所有的图像上找到所选班级的预测分数%选择类%(例如,寿司在所有寿司图像上的预测得分)scoresForChosenClass=预测得分(imgsOfClassIdxs,chosenClassIdx);结束作用plotImages (imd, imgIdx sortedScores、predictedClasses numImgsToShow)为i=1:numImgsToShow score = sortedScores(i);sortedImgIdx = imgIdx(我);predClass = predictedClasses (sortedImgIdx);correctClass = imds.Labels (sortedImgIdx);imgPath = imds.Files {sortedImgIdx};如果predClass==correctClass颜色=“{绿}\颜色”;其他的颜色=“\红色}”;结束predClassTitle = strrep (string (predClass),“_”,' ');trickClasstitle = strrep(字符串(trictrusclass),“_”,' ');次要情节(装天花板(numImgsToShow. / 3), i) imshow (imread (imgPath));标题(预测:“+ color + predClassTitle +”{黑}\换行符\颜色得分:“+ + num2str(得分)“\ newlineground truth:”+(课程名称);结束结束作用plotGradCAM (img, gradcamMapα)次要情节(1、2、1)imshow (img);h =情节(1、2、2);imshow (img)在;imagesc(gradcamMap,“字母数据”、α);originalSize2 =得到(h,'位置');colormap飞机colorbar集(h,'位置', originalSize2);持有关;结束
googlenet
|imageageAtastore.
|augmentedImageDatastore
|confusionchart
|dlnetwork
|分类
|咬合敏感度
|gradCAM
|imageLIME