主要内容

查看网络行为使用tsne

此示例显示如何使用tsne命令功能查看经过训练的网络中的激活情况。这个视图可以帮助您理解网络是如何工作的。

这个tsne(统计和机器学习工具箱)统计和机器学习工具箱中的函数™ 实现t-分布随机邻居嵌入(t-SNE)[1]。该技术将高维数据(如层中的网络激活)映射到二维。该技术使用一种非线性映射,试图保持距离。通过使用t-SNE可视化网络激活,您可以了解网络如何响应。

您可以使用t-SNE可视化深度学习网络在输入数据通过网络层时如何改变其表示形式。您还可以使用t-SNE查找输入数据的问题,并了解网络对哪些观察结果分类错误。

例如,t-SNE可以将softmax层的多维激活减少为具有类似结构的二维表示。生成的t-SNE图中的紧密簇与网络通常正确分类的类相对应。可视化允许您找到出现在错误簇中的点,表示观察到t网络分类不正确。观测值可能标记不正确,或者网络可能预测某个观测值是另一类的实例,因为它与该类的其他观测值相似。请注意,softmax激活的t-SNE缩减仅使用这些激活,而不使用基础观测值a是的。

下载数据集

此示例使用示例食品图像数据集,该数据集包含九个类别的978张食品照片,大小约为77MB。通过调用下载示例FoodImagesDatahelper函数;此helper函数的代码显示在本例结束.

dataDir=fullfile(tempdir,“ExampleFoodImageDataset”);网址="//www.tianjin-qmedu.com/万博1manbetxsupportfiles/nnet/data/ExampleFoodImageDataset.zip";如果~exist(dataDir,“目录”)mkdir(dataDir);终止下载ExampleFoodImagesData(url,dataDir);
正在下载MathWorks示例食品图像数据集…下载可能需要几分钟时间…下载完成…解压缩文件…解压缩完成…完成。

训练网络对食品图像进行分类

修改挤压网预训练网络,对数据集中的食品图像进行分类。将最后一个卷积层替换为新的卷积层,新的卷积层只有九个过滤器,每个过滤器对应一种食品类型。

lgraph = layerGraph (squeezenet ());lgraph = lgraph.replaceLayer (“分类层预测”,...分类层(“姓名”,“分类层预测”)); newConv=卷积2dlayer([14],9,“姓名”,“conv”,“填充”,“相同”); lgraph=lgraph.replaceLayer(“conv10”,newConv);

创建一个图像数据存储包含图像数据的路径。将数据存储分割为训练集和验证集,65%的数据用于训练,其余数据用于验证。由于数据集相当小,过拟合是一个重要问题。为了最小化过拟合,使用随机翻转和缩放来增加训练集。

imds=图像数据存储(数据目录,...“包含子文件夹”符合事实的“标签源”,“文件夹名称”);8月= imageDataAugmenter (“随机选择”符合事实的...“随机反射”符合事实的...“RandXScale”, [0.8 1.2],...“RandYScale”[0.8 - 1.2]);trainingFraction = 0.65;[trainImds,valImds] = splitEachLabel(imds, trainingFraction);augImdsTrain = augmentedImageDatastore([227 227], trainImds,...“数据增强”augImdsVal=增强图像数据存储([227 227],valImds);

创建培训选项并培训网络。SqueezeNet是一个小型网络,可以快速培训。您可以在GPU或CPU上进行培训;此示例在CPU上进行培训。

opts=培训选项(“亚当”,...“初始学习率”,1e-4,...“最大时代”, 30,...“验证数据”,augImdsVal,...“冗长”错误的...“情节”,“培训进度”,...“执行环境”,“cpu”,...“最小批量大小”,128);rng违约net=列车网络(augImdsTrain、lgraph、opts);

验证数据进行分类

使用网络对验证集中的图像进行分类。要验证网络在分类新数据时是否合理准确,请绘制真实标签和预测标签的混淆矩阵。

figure();YPred=classify(net,augImdsVal);confusionchart(valImds.Labels,YPred,“专栏摘要”,“column-normalized”)

该网络对多张图像进行了很好的分类。该网络似乎对寿司图像存在问题,许多图像被分类为寿司,但一些图像被分类为比萨饼或汉堡包。该网络没有将任何图像分类为热狗类。

计算多个层的激活

要继续分析网络性能,请在早期最大池层、最终卷积层和最终softmax层计算数据集中每个观测值的激活。将激活输出为NxM矩阵,其中N为观察数,M为激活维度数。M是空间和通道尺寸的乘积。每行是一个观察值,每列是一个维度。在softmax层M=9,因为食品数据集有九个类。矩阵中的每一行包含九个元素,对应于一个观测值属于九类食物的概率。

earlyLayerName =“pool1”;最后一名=“conv”;softmaxLayerName=“问题”;Pool1激活=激活(净,...augImdsVal,earlyLayerName,“输出”,“行”);finalConvActivations =激活(网络,...augImdsVal finalConvLayerName,“输出”,“行”);softmaxActivations=激活(净,...augImdsVal、softmaxLayerName、,“输出”,“行”);

分类的模糊性

您可以使用softmax激活来计算最有可能是不正确的图像分类。定义模棱两可第二大概率与最大概率之比的分类。分类的模糊度介于0(几乎确定的分类)和1(几乎与第二类一样可能被分类为最有可能的分类)之间。模糊度接近1表示网络不确定特定图像所属的类别。这种不确定性可能是由于两个类别的观测值与网络非常相似,以至于无法了解它们之间的差异。或者,由于特定观测值包含的元素超过一个类别,因此网络无法决定哪种分类是正确的。请注意,低模糊度并不一定意味着正确的分类;即使网络对某个类别的概率很高,分类仍然可能是错误的。

[R,RI]=maxk(softmaxActivations,2,2);歧义=R(:,2)。/R(:,1);

找到最模糊的图像。

[歧义,歧义idx]=排序(歧义,“下降”);

查看模糊图像的最可能类别和真实类别。

班级名册=独特(valImds.Labels);top10Idx = ambiguityIdx (1:10);top10Ambiguity =模棱两可(1:10);最有=班级名册(RI (ambiguityIdx, 1));secondLikely =班级名册(RI (ambiguityIdx, 2));表(top10Idx top10Ambiguity,最有(1:10),secondLikely (1:10), valImds.Labels (ambiguityIdx (1:10)),...“变化无常”,[“图像#”,“歧义”,“最有可能”,“第二”,“真正的阶级”])
ans=10×5表UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU寿司生鱼片生鱼片283 0.80407披萨寿司披萨27 0.80278汉堡包披萨法式炸薯条302 0.79283生鱼片寿司寿司201 0.76034披萨希腊沙拉披萨

该网络预测,图像27最有可能是汉堡包或比萨饼。然而,这张照片实际上是炸薯条。查看图片,看看为什么会发生这种错误分类。

v=27;figure();imshow(valImds.Files{v});title(sprintf(“观察:%i\n”+...“实际:%s。预测:%s”v...字符串(valImds.Labels(v)),字符串(YPred(v)),...“翻译”,“没有”);

图像包含几个不同的区域,其中一些区域可能会混淆网络。

使用t-SNE计算数据的二维表示

计算早期max池层、最终卷积层和最终softmax层的网络数据的低维表示形式。使用tsne函数将激活数据的维数从M降低到2。激活的维数越大,t-SNE计算所需的时间越长。因此,早期最大池层(激活有200704个维度)的计算时间比最终softmax层的计算时间长。为t-SNE结果的再现性设置随机种子。

rng违约pool1tsne=tsne(pool1Activations);finalConvtsne=tsne(finalConvActivations);softmaxtsne=tsne(softmaxActivations);

比较早期和后期层的网络行为

t-SNE技术试图保持距离,以便在高维表示中彼此接近的点在低维表示中也彼此接近。如混淆矩阵所示,该网络能够有效地进行分类。因此,语义相似(或相同类型)的图像,如caesar沙拉和caprese沙拉,在softmax激活空间中彼此接近。t-SNE用一个比九维softmax分数更容易理解和绘制的二维表示来捕捉这种接近性。

早期层倾向于对边缘和颜色等低级特征进行操作。深层层学习了具有更多语义含义的高级特征,例如比萨饼和热狗之间的差异。因此,早期层的激活不会按类别显示任何聚类。两个像素相似的图像(例如,它们都包含大量绿色像素)在激活的高维空间中相互靠近,而不管其语义内容如何。来自更高层次的激活倾向于将来自同一类的点聚集在一起。这种行为在softmax层最为明显,并保留在二维t-SNE表示中。

绘制早期最大池化层、最终卷积层和最终softmax层的t-SNE数据gscatter功能。观察早期最大池激活不会在同一类别的图像之间显示任何聚集。最终卷积层的激活在某种程度上按类别聚集,但比softmax激活少。不同颜色对应不同类别的观察。

多勒根=“关”;markerSize=7;图;子批次(1,3,1);gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels,...[],'.',markerSize,doLegend);头衔(“最大池激活次数”); 子批次(1,3,2);gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels,...[],'.',markerSize,doLegend);头衔(“最终会议激活”);次要情节(1,3,3);gscatter (softmaxtsne (: 1) softmaxtsne (:, 2), valImds。标签,...[],'.',markerSize,doLegend);头衔(“Softmax激活”);

探索t-SNE图中的观测值

创建一个更大的softmax激活图,包括标记每个类的图例。从t-SNE图中,您可以了解更多关于后验概率分布结构的信息。

例如,该图显示了一个独特的、独立的薯条观测集群,而生鱼片和寿司集群并没有得到很好的解决。与混淆矩阵类似,该图表明网络在预测薯条类方面更为准确。

numClasses=长度(类列表);颜色=线(numclass);h=数字;gscatter(softmaxtsne(:,1),softmaxtsne(:,2),有效标签,颜色);l=图例;l、 翻译=“没有”;l.地点=“bestoutside”;

您还可以使用t-SNE确定哪些图像被网络错误分类以及原因。不正确的观测通常是其周围星团的错误颜色的孤立点。例如,汉堡的错误分类图像非常靠近炸薯条区域(离橙色簇中心最近的绿点)。这个点是观测值99。在t-SNE图上圈出该观察结果,并用显示图像.

obs=99;图(h)保持;hs=散射(softmaxtsne(obs,1),softmaxtsne(obs,2),...“黑色”,“线宽”,1.5);l.String{end}=“汉堡包”; 持有;图();imshow(valImds.Files{obs});title(sprintf(“观察:%i\n”+...“实际:%s。预测:%s”,obs,...字符串(有效标签(obs)),字符串(YPred(obs)),...“翻译”,“没有”);

如果一张图片包含多种食物,网络就会混淆。在这种情况下,网络将图像分类为炸薯条,尽管前景中的食物是汉堡。图片边缘可见的炸薯条造成了混乱。

类似地,模棱两可的图像27(在本例前面所示)有多个区域。检查t-SNE图,突出显示此炸薯条图像的模棱两可方面。

obs=27;图(h)保持;h=散射(softmaxtsne(obs,1),softmaxtsne(obs,2),...“k”,“d”,“线宽”,1.5);l.String{end}=“炸薯条”; 持有;

图像在图中不是一个定义良好的簇,这表明分类可能是不正确的。这张照片离炸薯条区很远,离汉堡区很近。

这个为什么错误分类必须由其他信息提供,典型的是基于图像内容的假设。然后,您可以使用其他数据或工具来测试假设,这些工具表明图像的哪些空间区域对网络分类很重要。有关示例,请参见咬合敏感度Grad CAM揭示了深度学习决策背后的原因.

参考文献

[1] 范德马腾、劳伦斯和杰弗里·辛顿。“使用t-SNE可视化数据。”机器学习研究杂志9,2008年,第2579-2605页。

辅助函数

作用dataDir downloadExampleFoodImagesData (url)%下载Example Food Image数据集,包含978张图片%不同种类的食物分为9类。版权所有2019 The MathWorks, Inc.文件名=“ExampleFoodImageDataset.zip”;fileFullPath=fullfile(dataDir,文件名);%将.zip文件下载到临时目录中。如果~exist(fileFullPath,“文件”)fprintf(“正在下载MathWorks示例食品图像数据集…\n”);fprintf(“这可能需要几分钟才能下载…\n”);websave(文件完整路径,url);fprintf(“下载完成…\n”);其他的fprintf(正在跳过下载,文件已存在…\n);终止%解压缩文件。%%通过检查文件是否存在,检查文件是否已解压缩%类目录之一的。exampleFolderFullPath=fullfile(dataDir,“比萨饼”);如果~exist(例如FolderFullPath,“目录”)fprintf(“正在解压缩文件…\n”);解压缩(fileFullPath,dataDir);fprintf(“解完成…\ n”);其他的fprintf(正在跳过解压缩,文件已解压缩…\n);终止fprintf(“完成。\n”);终止

另见

|||||||(统计和机器学习工具箱)

相关的话题